package storage_test import ( "context" "encoding/json" "fmt" "testing" "time" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" ) const loginTokenLifetime = time.Minute var ( openIDLifetimeMS = time.Minute.Milliseconds() ctx = context.Background() ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { t.Fatalf("NewUserAPIDatabase returned %s", err) } return db, func() { close() baseclose() } } // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) events := room.Events() contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID())) err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom) assert.NoError(t, err, "unable to save account data") contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID)) err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal) assert.NoError(t, err, "unable to save account data") accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read") assert.NoError(t, err, "unable to get account data by type") assert.Equal(t, contentRoom, accountData) globalData, roomData, err := db.GetAccountData(ctx, localpart) assert.NoError(t, err) assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"]) assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"]) }) } // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") // verify the newly create account is the same as returned by CreateAccount var accGet *api.Account accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") assert.NoError(t, err, "failed to get account by password") assert.Equal(t, accAlice, accGet) accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) assert.NoError(t, err, "failed to get account by localpart") assert.Equal(t, accAlice, accGet) // check account availability available, err := db.CheckAccountAvailability(ctx, aliceLocalpart) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, false, available) available, err = db.CheckAccountAvailability(ctx, "unusedname") assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, true, available) // get guest account numeric aliceLocalpart first, err := db.GetNewNumericLocalpart(ctx) assert.NoError(t, err, "failed to get new numeric localpart") // Create a new account to verify the numeric localpart is updated _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) assert.NoError(t, err, "failed to create account") second, err := db.GetNewNumericLocalpart(ctx) assert.NoError(t, err) assert.Greater(t, second, first) // update password for alice err = db.SetPassword(ctx, aliceLocalpart, "newPassword") assert.NoError(t, err, "failed to update password") accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") assert.NoError(t, err, "failed to get account by new password") assert.Equal(t, accAlice, accGet) // deactivate account err = db.DeactivateAccount(ctx, aliceLocalpart) assert.NoError(t, err, "failed to deactivate account") // This should fail now, as the account is deactivated _, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") assert.Error(t, err, "expected an error, got none") _, err = db.GetAccountByLocalpart(ctx, "unusename") assert.Error(t, err, "expected an error for non existent localpart") // create an empty localpart; this should never happen, but is required to test getting a numeric localpart // if there's already a user without a localpart in the database _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser) assert.NoError(t, err) // test getting a numeric localpart, with an existing user without a localpart _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) assert.NoError(t, err) // Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type _, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser) assert.NoError(t, err) // Now try to create a new guest user _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) assert.NoError(t, err) }) } func Test_Devices(t *testing.T) { alice := test.NewUser(t) localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) deviceID := util.RandomString(8) accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken) assert.NoError(t, err, "unable to get device by access token") assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields // create a device without existing device ID accessToken = util.RandomString(16) deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields // Get devices devices, err := db.GetDevicesByLocalpart(ctx, localpart) assert.NoError(t, err, "unable to get devices by localpart") assert.Equal(t, 2, len(devices)) deviceIDs := make([]string, 0, len(devices)) for _, dev := range devices { deviceIDs = append(deviceIDs, dev.ID) } devices2, err := db.GetDevicesByID(ctx, deviceIDs) assert.NoError(t, err, "unable to get devices by id") assert.ElementsMatch(t, devices, devices2) // Update device newName := "new display name" err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) assert.NoError(t, err, "unable to update device displayname") updatedAfterTimestamp := time.Now().Unix() err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web") assert.NoError(t, err, "unable to update device last seen") deviceWithID.DisplayName = newName deviceWithID.LastSeenIP = "127.0.0.1" gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 2, len(devices)) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) assert.Equal(t, deviceWithID.LastSeenIP, gotDevice.LastSeenIP) assert.Greater(t, gotDevice.LastSeenTS, updatedAfterTimestamp) // create one more device and remove the devices step by step newDeviceID := util.RandomString(16) accessToken = util.RandomString(16) _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create new device") devices, err = db.GetDevicesByLocalpart(ctx, localpart) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 3, len(devices)) err = db.RemoveDevices(ctx, localpart, deviceIDs) assert.NoError(t, err, "unable to remove devices") devices, err = db.GetDevicesByLocalpart(ctx, localpart) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 1, len(devices)) deleted, err := db.RemoveAllDevices(ctx, localpart, "") assert.NoError(t, err, "unable to remove all devices") assert.Equal(t, 1, len(deleted)) assert.Equal(t, newDeviceID, deleted[0].ID) }) } func Test_KeyBackup(t *testing.T) { alice := test.NewUser(t) room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData) assert.NoError(t, err, "unable to create key backup") // get key backup by version gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion) assert.NoError(t, err, "unable to get key backup") assert.Equal(t, wantVersion, gotVersion, "backup version mismatch") assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch") assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch") // get any key backup gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "") assert.NoError(t, err, "unable to get key backup") assert.Equal(t, wantVersion, gotVersion, "backup version mismatch") assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch") assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch") err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data")) assert.NoError(t, err, "unable to update key backup auth data") uploads := []api.InternalKeyBackupSession{ { KeyBackupSession: api.KeyBackupSession{ IsVerified: true, SessionData: wantAuthData, }, RoomID: room.ID, SessionID: "1", }, { KeyBackupSession: api.KeyBackupSession{}, RoomID: room.ID, SessionID: "2", }, } count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads) assert.NoError(t, err, "unable to upsert backup keys") assert.Equal(t, int64(len(uploads)), count, "unexpected backup count") // do it again to update a key uploads[1].IsVerified = true count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:]) assert.NoError(t, err, "unable to upsert backup keys") assert.Equal(t, int64(len(uploads)), count, "unexpected backup count") // get backup keys by session id gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1") assert.NoError(t, err, "unable to get backup keys") assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"]) // get backup keys by room id gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "") assert.NoError(t, err, "unable to get backup keys") assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"]) gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID) assert.NoError(t, err, "unable to get backup keys count") assert.Equal(t, count, gotCount, "unexpected backup count") // finally delete a key exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion) assert.NoError(t, err, "unable to delete key backup") assert.True(t, exists) // this key should not exist exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3") assert.NoError(t, err, "unable to delete key backup") assert.False(t, exists) }) } func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() // create a new token wantLoginToken := &api.LoginTokenData{UserID: alice.ID} gotMetadata, err := db.CreateLoginToken(ctx, wantLoginToken) assert.NoError(t, err, "unable to create login token") assert.NotNil(t, gotMetadata) assert.Equal(t, time.Now().Add(loginTokenLifetime).Truncate(loginTokenLifetime), gotMetadata.Expiration.Truncate(loginTokenLifetime)) // get the new token gotLoginToken, err := db.GetLoginTokenDataByToken(ctx, gotMetadata.Token) assert.NoError(t, err, "unable to get login token") assert.NotNil(t, gotLoginToken) assert.Equal(t, wantLoginToken, gotLoginToken, "unexpected login token") // remove the login token again err = db.RemoveLoginToken(ctx, gotMetadata.Token) assert.NoError(t, err, "unable to remove login token") // check if the token was actually deleted _, err = db.GetLoginTokenDataByToken(ctx, gotMetadata.Token) assert.Error(t, err, "expected an error, but got none") }) } func Test_OpenID(t *testing.T) { alice := test.NewUser(t) token := util.RandomString(24) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS expires, err := db.CreateOpenIDToken(ctx, token, alice.ID) assert.NoError(t, err, "unable to create OpenID token") assert.Equal(t, expiresAtMS, expires) attributes, err := db.GetOpenIDTokenAttributes(ctx, token) assert.NoError(t, err, "unable to get OpenID token attributes") assert.Equal(t, alice.ID, attributes.UserID) assert.Equal(t, expiresAtMS, attributes.ExpiresAtMS) }) } func Test_Profile(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() // create account, which also creates a profile _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) assert.NoError(t, err, "unable to get profile by localpart") wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} assert.Equal(t, wantProfile, gotProfile) // set avatar & displayname wantProfile.DisplayName = "Alice" wantProfile.AvatarURL = "mxc://aliceAvatar" err = db.SetDisplayName(ctx, aliceLocalpart, "Alice") assert.NoError(t, err, "unable to set displayname") err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") // verify profile gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart) assert.NoError(t, err, "unable to get profile by localpart") assert.Equal(t, wantProfile, gotProfile) // search profiles searchRes, err := db.SearchProfiles(ctx, "Alice", 2) assert.NoError(t, err, "unable to search profiles") assert.Equal(t, 1, len(searchRes)) assert.Equal(t, *wantProfile, searchRes[0]) }) } func Test_Pusher(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() appID := util.RandomString(8) var pushKeys []string var gotPushers []api.Pusher for i := 0; i < 2; i++ { pushKey := util.RandomString(8) wantPusher := api.Pusher{ PushKey: pushKey, Kind: api.HTTPKind, AppID: appID, AppDisplayName: util.RandomString(8), DeviceDisplayName: util.RandomString(8), ProfileTag: util.RandomString(8), Language: util.RandomString(2), } err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart) assert.NoError(t, err, "unable to upsert pusher") // check it was actually persisted gotPushers, err = db.GetPushers(ctx, aliceLocalpart) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, i+1, len(gotPushers)) assert.Equal(t, wantPusher, gotPushers[i]) pushKeys = append(pushKeys, pushKey) } // remove single pusher err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart) assert.NoError(t, err, "unable to remove pusher") gotPushers, err := db.GetPushers(ctx, aliceLocalpart) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 1, len(gotPushers)) // remove last pusher err = db.RemovePushers(ctx, appID, pushKeys[1]) assert.NoError(t, err, "unable to remove pusher") gotPushers, err = db.GetPushers(ctx, aliceLocalpart) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 0, len(gotPushers)) }) } func Test_ThreePID(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium) assert.NoError(t, err, "unable to save threepid association") // get the stored threepid gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium) assert.NoError(t, err, "unable to get localpart for threepid") assert.Equal(t, aliceLocalpart, gotLocalpart) threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 1, len(threepids)) assert.Equal(t, authtypes.ThreePID{ Address: threePID, Medium: medium, }, threepids[0]) // remove threepid association err = db.RemoveThreePIDAssociation(ctx, threePID, medium) assert.NoError(t, err, "unexpected error") // verify it was deleted threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 0, len(threepids)) }) } func Test_Notification(t *testing.T) { alice := test.NewUser(t) aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { eventID := util.RandomString(16) roomID := room.ID ts := time.Now() if i > 5 { roomID = room2.ID // create some old notifications to test DeleteOldNotifications ts = ts.AddDate(0, -2, 0) } notification := &api.Notification{ Actions: []*pushrules.Action{ {}, }, Event: gomatrixserverlib.ClientEvent{ Content: gomatrixserverlib.RawJSON("{}"), }, Read: false, RoomID: roomID, TS: gomatrixserverlib.AsTimestamp(ts), } err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") } // get notifications count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications) assert.NoError(t, err, "unable to get notification count") assert.Equal(t, int64(10), count) notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications) assert.NoError(t, err, "unable to get notifications") assert.Equal(t, int64(10), count) assert.Equal(t, 10, len(notifs)) // ... for a specific room total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(4), total) // mark notification as read affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) // this should delete 2 notifications affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(2), total) // delete old notifications err = db.DeleteOldNotifications(ctx) assert.NoError(t, err) // this should now return 0 notifications total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(0), total) }) }