diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 2dd21618..7481ac5b 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -81,7 +81,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" @@ -96,7 +96,7 @@ const deleteDevicesSQL = "" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" @@ -171,6 +171,14 @@ func (s *devicesStatements) InsertDevice( }, nil } +func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, + sessionID int64, +) (*api.Device, error) { + return s.InsertDevice(ctx, txn, id, localpart, serverName, accessToken, displayName, ipAddr, userAgent) +} + // deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id string, @@ -271,7 +279,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil { return nil, err } if displayName.Valid { @@ -303,7 +311,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( var lastseents sql.NullInt64 var id, displayname, ip, useragent sql.NullString for rows.Next() { - err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID) if err != nil { return devices, err } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index f549dcef..bf94f14d 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -588,16 +588,42 @@ func (d *Database) CreateDevice( deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - var err error - // Revoke existing tokens for this device - if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { - return err - } + _, ok := d.Writer.(*sqlutil.ExclusiveWriter) + if ok { // we're using most likely using SQLite, so do things a little different + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error - dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) - return err - }) + devices, err := d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, "") + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + // No devices yet, only create a new one + if len(devices) == 0 { + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) + return err + } + sessionID := devices[0].SessionID + 1 + + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { + return err + } + // Create a new device with the session ID incremented + dev, err = d.Devices.InsertDeviceWithSessionID(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent, sessionID) + return err + }) + } else { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var err error + // Revoke existing tokens for this device + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { + return err + } + + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) + return err + }) + } } else { // We generate device IDs in a loop in case its already taken. // We cap this at going round 5 times to ensure we don't spin forever @@ -618,7 +644,7 @@ func (d *Database) CreateDevice( } } } - return + return dev, returnErr } // generateDeviceID creates a new device id. Returns an error if failed to generate diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index c5db34bd..449e4549 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -65,7 +65,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" @@ -80,7 +80,7 @@ const deleteDevicesSQL = "" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" @@ -162,6 +162,27 @@ func (s *devicesStatements) InsertDevice( }, nil } +func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, + sessionID int64, +) (*api.Device, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + return nil, err + } + return &api.Device{ + ID: id, + UserID: userutil.MakeUserID(localpart, serverName), + AccessToken: accessToken, + SessionID: sessionID, + LastSeenTS: createdTimeMS, + LastSeenIP: ipAddr, + UserAgent: userAgent, + }, nil +} + func (s *devicesStatements) DeleteDevice( ctx context.Context, txn *sql.Tx, id string, localpart string, serverName gomatrixserverlib.ServerName, @@ -271,7 +292,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( var lastseents sql.NullInt64 var id, displayname, ip, useragent sql.NullString for rows.Next() { - err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) + err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID) if err != nil { return devices, err } @@ -317,7 +338,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil { return nil, err } if displayName.Valid { diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index e14776cf..9221e571 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -44,6 +44,7 @@ type AccountsTable interface { type DevicesTable interface { InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error) DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error