mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-25 21:21:35 +00:00
Refactor user API storage (#2202)
* Refactor User API database * Fix migration bugs
This commit is contained in:
parent
9bd5e414c9
commit
9f4a39e8e0
@ -21,6 +21,7 @@ import (
|
|||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
@ -56,19 +57,20 @@ type accountDataStatements struct {
|
|||||||
selectAccountDataByTypeStmt *sql.Stmt
|
selectAccountDataByTypeStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
||||||
_, err = db.Exec(accountDataSchema)
|
s := &accountDataStatements{}
|
||||||
|
_, err := db.Exec(accountDataSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertAccountDataStmt, insertAccountDataSQL},
|
{&s.insertAccountDataStmt, insertAccountDataSQL},
|
||||||
{&s.selectAccountDataStmt, selectAccountDataSQL},
|
{&s.selectAccountDataStmt, selectAccountDataSQL},
|
||||||
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
|
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) InsertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||||
@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountData(
|
func (s *accountDataStatements) SelectAccountData(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (
|
) (
|
||||||
/* global */ map[string]json.RawMessage,
|
/* global */ map[string]json.RawMessage,
|
||||||
@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
|
|||||||
return global, rooms, rows.Err()
|
return global, rooms, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountDataByType(
|
func (s *accountDataStatements) SelectAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
) (data json.RawMessage, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
var bytes []byte
|
var bytes []byte
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@ -78,14 +79,15 @@ type accountsStatements struct {
|
|||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) execSchema(db *sql.DB) error {
|
func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
|
||||||
_, err := db.Exec(accountsSchema)
|
s := &accountsStatements{
|
||||||
return err
|
serverName: serverName,
|
||||||
}
|
}
|
||||||
|
_, err := db.Exec(accountsSchema)
|
||||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
if err != nil {
|
||||||
s.serverName = server
|
return nil, err
|
||||||
return sqlutil.StatementList{
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertAccountStmt, insertAccountSQL},
|
{&s.insertAccountStmt, insertAccountSQL},
|
||||||
{&s.updatePasswordStmt, updatePasswordSQL},
|
{&s.updatePasswordStmt, updatePasswordSQL},
|
||||||
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
||||||
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||||||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) insertAccount(
|
func (s *accountsStatements) InsertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) updatePassword(
|
func (s *accountsStatements) UpdatePassword(
|
||||||
ctx context.Context, localpart, passwordHash string,
|
ctx context.Context, localpart, passwordHash string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) deactivateAccount(
|
func (s *accountsStatements) DeactivateAccount(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectPasswordHash(
|
func (s *accountsStatements) SelectPasswordHash(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (hash string, err error) {
|
) (hash string, err error) {
|
||||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectAccountByLocalpart(
|
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
var appserviceIDPtr sql.NullString
|
||||||
@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||||||
return &acc, nil
|
return &acc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectNewNumericLocalpart(
|
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
stmt := s.selectNewNumericLocalpartStmt
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -111,53 +112,32 @@ type devicesStatements struct {
|
|||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) execSchema(db *sql.DB) error {
|
func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
|
||||||
|
s := &devicesStatements{
|
||||||
|
serverName: serverName,
|
||||||
|
}
|
||||||
_, err := db.Exec(devicesSchema)
|
_, err := db.Exec(devicesSchema)
|
||||||
return err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
{&s.insertDeviceStmt, insertDeviceSQL},
|
||||||
if err = s.execSchema(db); err != nil {
|
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
|
||||||
return
|
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
|
||||||
}
|
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
|
||||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
|
||||||
return
|
{&s.deleteDeviceStmt, deleteDeviceSQL},
|
||||||
}
|
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
|
||||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
{&s.deleteDevicesStmt, deleteDevicesSQL},
|
||||||
return
|
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
|
||||||
}
|
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
|
||||||
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
}.Prepare(db)
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.serverName = server
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) insertDevice(
|
func (s *devicesStatements) InsertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||||
displayName *string, ipAddr, userAgent string,
|
displayName *string, ipAddr, userAgent string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
@ -179,7 +159,7 @@ func (s *devicesStatements) insertDevice(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// deleteDevice removes a single device by id and user localpart.
|
// deleteDevice removes a single device by id and user localpart.
|
||||||
func (s *devicesStatements) deleteDevice(
|
func (s *devicesStatements) DeleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
@ -189,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
|
|||||||
|
|
||||||
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
||||||
// Returns an error if the execution failed.
|
// Returns an error if the execution failed.
|
||||||
func (s *devicesStatements) deleteDevices(
|
func (s *devicesStatements) DeleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
||||||
@ -199,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
|
|||||||
|
|
||||||
// deleteDevicesByLocalpart removes all devices for the
|
// deleteDevicesByLocalpart removes all devices for the
|
||||||
// given user localpart.
|
// given user localpart.
|
||||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
@ -207,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceName(
|
func (s *devicesStatements) UpdateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
@ -215,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDeviceByToken(
|
func (s *devicesStatements) SelectDeviceByToken(
|
||||||
ctx context.Context, accessToken string,
|
ctx context.Context, accessToken string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
@ -231,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||||||
|
|
||||||
// selectDeviceByID retrieves a device from the database with the given user
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
// localpart and deviceID
|
// localpart and deviceID
|
||||||
func (s *devicesStatements) selectDeviceByID(
|
func (s *devicesStatements) SelectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context, localpart, deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
@ -248,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
|
|||||||
return &dev, err
|
return &dev, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
|
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -271,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
|
|||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
@ -313,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const keyBackupTableSchema = `
|
const keyBackupTableSchema = `
|
||||||
@ -72,12 +73,13 @@ type keyBackupStatements struct {
|
|||||||
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
|
||||||
_, err = db.Exec(keyBackupTableSchema)
|
s := &keyBackupStatements{}
|
||||||
|
_, err := db.Exec(keyBackupTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
||||||
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
||||||
{&s.countKeysStmt, countKeysSQL},
|
{&s.countKeysStmt, countKeysSQL},
|
||||||
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s keyBackupStatements) countKeys(
|
func (s keyBackupStatements) CountKeys(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (count int64, err error) {
|
) (count int64, err error) {
|
||||||
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
|
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) insertBackupKey(
|
func (s *keyBackupStatements) InsertBackupKey(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
|
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
|
||||||
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) updateBackupKey(
|
func (s *keyBackupStatements) UpdateBackupKey(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
|
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
|
||||||
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) selectKeys(
|
func (s *keyBackupStatements) SelectKeys(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||||
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
|
|||||||
return unpackKeys(ctx, rows)
|
return unpackKeys(ctx, rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) selectKeysByRoomID(
|
func (s *keyBackupStatements) SelectKeysByRoomID(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
||||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||||
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
|
|||||||
return unpackKeys(ctx, rows)
|
return unpackKeys(ctx, rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
||||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const keyBackupVersionTableSchema = `
|
const keyBackupVersionTableSchema = `
|
||||||
@ -69,12 +70,13 @@ type keyBackupVersionStatements struct {
|
|||||||
updateKeyBackupETagStmt *sql.Stmt
|
updateKeyBackupETagStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
|
||||||
_, err = db.Exec(keyBackupVersionTableSchema)
|
s := &keyBackupVersionStatements{}
|
||||||
|
_, err := db.Exec(keyBackupVersionTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
|
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
|
||||||
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
|
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
|
||||||
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
|
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
|
||||||
@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) insertKeyBackup(
|
func (s *keyBackupVersionStatements) InsertKeyBackup(
|
||||||
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
|
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
|
||||||
) (version string, err error) {
|
) (version string, err error) {
|
||||||
var versionInt int64
|
var versionInt int64
|
||||||
@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
|
|||||||
return strconv.FormatInt(versionInt, 10), err
|
return strconv.FormatInt(versionInt, 10), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
|
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version, etag string,
|
ctx context.Context, txn *sql.Tx, userID, version, etag string,
|
||||||
) error {
|
) error {
|
||||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) deleteKeyBackup(
|
func (s *keyBackupVersionStatements) DeleteKeyBackup(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
|
|||||||
return ra == 1, nil
|
return ra == 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) selectKeyBackup(
|
func (s *keyBackupVersionStatements) SelectKeyBackup(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||||
var versionInt int64
|
var versionInt int64
|
||||||
|
@ -21,18 +21,11 @@ import (
|
|||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type loginTokenStatements struct {
|
const loginTokenSchema = `
|
||||||
insertStmt *sql.Stmt
|
|
||||||
deleteStmt *sql.Stmt
|
|
||||||
selectByTokenStmt *sql.Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
// execSchema ensures tables and indices exist.
|
|
||||||
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
|
|
||||||
_, err := db.Exec(`
|
|
||||||
CREATE TABLE IF NOT EXISTS login_tokens (
|
CREATE TABLE IF NOT EXISTS login_tokens (
|
||||||
-- The random value of the token issued to a user
|
-- The random value of the token issued to a user
|
||||||
token TEXT NOT NULL PRIMARY KEY,
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
@ -45,24 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens (
|
|||||||
|
|
||||||
-- This index allows efficient garbage collection of expired tokens.
|
-- This index allows efficient garbage collection of expired tokens.
|
||||||
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
||||||
`)
|
`
|
||||||
return err
|
|
||||||
|
const insertLoginTokenSQL = "" +
|
||||||
|
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||||
|
|
||||||
|
const deleteLoginTokenSQL = "" +
|
||||||
|
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||||
|
|
||||||
|
const selectLoginTokenSQL = "" +
|
||||||
|
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||||
|
|
||||||
|
type loginTokenStatements struct {
|
||||||
|
insertStmt *sql.Stmt
|
||||||
|
deleteStmt *sql.Stmt
|
||||||
|
selectStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepare runs statement preparation.
|
func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
|
||||||
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
s := &loginTokenStatements{}
|
||||||
if err := s.execSchema(db); err != nil {
|
_, err := db.Exec(loginTokenSchema)
|
||||||
return err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
{&s.insertStmt, insertLoginTokenSQL},
|
||||||
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
{&s.deleteStmt, deleteLoginTokenSQL},
|
||||||
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
|
{&s.selectStmt, selectLoginTokenSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert adds an already generated token to the database.
|
// insert adds an already generated token to the database.
|
||||||
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertStmt)
|
||||||
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
|
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
|
||||||
return err
|
return err
|
||||||
@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
|
|||||||
//
|
//
|
||||||
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
||||||
// The login_tokens_expiration_idx index should make that efficient.
|
// The login_tokens_expiration_idx index should make that efficient.
|
||||||
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
|
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
||||||
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
|
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
|
||||||
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||||
var data api.LoginTokenData
|
var data api.LoginTokenData
|
||||||
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertTokenSQL = "" +
|
const insertOpenIDTokenSQL = "" +
|
||||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||||
|
|
||||||
const selectTokenSQL = "" +
|
const selectOpenIDTokenSQL = "" +
|
||||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||||
|
|
||||||
type tokenStatements struct {
|
type openIDTokenStatements struct {
|
||||||
insertTokenStmt *sql.Stmt
|
insertTokenStmt *sql.Stmt
|
||||||
selectTokenStmt *sql.Stmt
|
selectTokenStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
|
||||||
_, err = db.Exec(openIDTokenSchema)
|
s := &openIDTokenStatements{
|
||||||
if err != nil {
|
serverName: serverName,
|
||||||
return
|
|
||||||
}
|
}
|
||||||
s.serverName = server
|
_, err := db.Exec(openIDTokenSchema)
|
||||||
return sqlutil.StatementList{
|
if err != nil {
|
||||||
{&s.insertTokenStmt, insertTokenSQL},
|
return nil, err
|
||||||
{&s.selectTokenStmt, selectTokenSQL},
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.insertTokenStmt, insertOpenIDTokenSQL},
|
||||||
|
{&s.selectTokenStmt, selectOpenIDTokenSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertToken inserts a new OpenID Connect token to the DB.
|
// insertToken inserts a new OpenID Connect token to the DB.
|
||||||
// Returns new token, otherwise returns error if the token already exists.
|
// Returns new token, otherwise returns error if the token already exists.
|
||||||
func (s *tokenStatements) insertToken(
|
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx,
|
txn *sql.Tx,
|
||||||
token, localpart string,
|
token, localpart string,
|
||||||
@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken(
|
|||||||
|
|
||||||
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||||
// Returns the existing token's attributes, or err if no token is found
|
// Returns the existing token's attributes, or err if no token is found
|
||||||
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
token string,
|
token string,
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
@ -59,12 +60,13 @@ type profilesStatements struct {
|
|||||||
selectProfilesBySearchStmt *sql.Stmt
|
selectProfilesBySearchStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
|
||||||
_, err = db.Exec(profilesSchema)
|
s := &profilesStatements{}
|
||||||
|
_, err := db.Exec(profilesSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertProfileStmt, insertProfileSQL},
|
{&s.insertProfileStmt, insertProfileSQL},
|
||||||
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
|
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
|
||||||
{&s.setAvatarURLStmt, setAvatarURLSQL},
|
{&s.setAvatarURLStmt, setAvatarURLSQL},
|
||||||
@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) insertProfile(
|
func (s *profilesStatements) InsertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) selectProfileByLocalpart(
|
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
|||||||
return &profile, nil
|
return &profile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) setAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
|
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) setDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
|
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) selectProfilesBySearch(
|
func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
ctx context.Context, searchString string, limit int,
|
ctx context.Context, searchString string, limit int,
|
||||||
) ([]authtypes.Profile, error) {
|
) ([]authtypes.Profile, error) {
|
||||||
var profiles []authtypes.Profile
|
var profiles []authtypes.Profile
|
||||||
|
@ -15,76 +15,33 @@
|
|||||||
package postgres
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
||||||
|
|
||||||
// Import the postgres database driver.
|
// Import the postgres database driver.
|
||||||
_ "github.com/lib/pq"
|
_ "github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Database represents an account database
|
|
||||||
type Database struct {
|
|
||||||
db *sql.DB
|
|
||||||
writer sqlutil.Writer
|
|
||||||
sqlutil.PartitionOffsetStatements
|
|
||||||
accounts accountsStatements
|
|
||||||
profiles profilesStatements
|
|
||||||
accountDatas accountDataStatements
|
|
||||||
threepids threepidStatements
|
|
||||||
openIDTokens tokenStatements
|
|
||||||
keyBackupVersions keyBackupVersionStatements
|
|
||||||
devices devicesStatements
|
|
||||||
loginTokens loginTokenStatements
|
|
||||||
loginTokenLifetime time.Duration
|
|
||||||
keyBackups keyBackupStatements
|
|
||||||
serverName gomatrixserverlib.ServerName
|
|
||||||
bcryptCost int
|
|
||||||
openIDTokenLifetimeMS int64
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// The length of generated device IDs
|
|
||||||
deviceIDByteLength = 6
|
|
||||||
loginTokenByteLength = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
|
||||||
db, err := sqlutil.Open(dbProperties)
|
db, err := sqlutil.Open(dbProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
d := &Database{
|
|
||||||
serverName: serverName,
|
|
||||||
db: db,
|
|
||||||
writer: sqlutil.NewDummyWriter(),
|
|
||||||
loginTokenLifetime: loginTokenLifetime,
|
|
||||||
bcryptCost: bcryptCost,
|
|
||||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
m := sqlutil.NewMigrations()
|
||||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
if _, err = db.Exec(accountsSchema); err != nil {
|
||||||
if err = d.accounts.execSchema(db); err != nil {
|
// do this so that the migration can and we don't fail on
|
||||||
|
// preparing statements for columns that don't exist yet
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrations()
|
|
||||||
deltas.LoadIsActive(m)
|
deltas.LoadIsActive(m)
|
||||||
//deltas.LoadLastSeenTSIP(m)
|
//deltas.LoadLastSeenTSIP(m)
|
||||||
deltas.LoadAddAccountType(m)
|
deltas.LoadAddAccountType(m)
|
||||||
@ -92,638 +49,57 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
|
accountDataTable, err := NewPostgresAccountDataTable(db)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.accounts.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.profiles.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.accountDatas.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.threepids.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.keyBackupVersions.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.keyBackups.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.devices.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.loginTokens.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return d, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
|
||||||
func (d *Database) GetAccountByPassword(
|
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
|
||||||
}
|
}
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
accountsTable, err := NewPostgresAccountsTable(db, serverName)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
|
||||||
func (d *Database) GetProfileByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) (*authtypes.Profile, error) {
|
|
||||||
return d.profiles.selectProfileByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
|
||||||
func (d *Database) SetAvatarURL(
|
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
|
||||||
) error {
|
|
||||||
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
|
||||||
func (d *Database) SetDisplayName(
|
|
||||||
ctx context.Context, localpart string, displayName string,
|
|
||||||
) error {
|
|
||||||
return d.profiles.setDisplayName(ctx, localpart, displayName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPassword sets the account password to the given hash.
|
|
||||||
func (d *Database) SetPassword(
|
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
|
||||||
) error {
|
|
||||||
hash, err := d.hashPassword(plaintextPassword)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
|
||||||
}
|
}
|
||||||
return d.accounts.updatePassword(ctx, localpart, hash)
|
devicesTable, err := NewPostgresDevicesTable(db, serverName)
|
||||||
}
|
|
||||||
|
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
|
||||||
// account already exists, it will return nil, sqlutil.ErrUserExists.
|
|
||||||
func (d *Database) CreateAccount(
|
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
|
||||||
) (acc *api.Account, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
// For guest accounts, we create a new numeric local part
|
|
||||||
if accountType == api.AccountTypeGuest {
|
|
||||||
var numLocalpart int64
|
|
||||||
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
|
||||||
}
|
}
|
||||||
localpart = strconv.FormatInt(numLocalpart, 10)
|
keyBackupTable, err := NewPostgresKeyBackupTable(db)
|
||||||
plaintextPassword = ""
|
|
||||||
appserviceID = ""
|
|
||||||
}
|
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) createAccount(
|
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
var account *api.Account
|
|
||||||
var err error
|
|
||||||
// Generate a password hash if this is not a password-less user
|
|
||||||
hash := ""
|
|
||||||
if plaintextPassword != "" {
|
|
||||||
hash, err = d.hashPassword(plaintextPassword)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err)
|
||||||
}
|
}
|
||||||
}
|
keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db)
|
||||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
|
||||||
if sqlutil.IsUniqueConstraintViolationErr(err) {
|
|
||||||
return nil, sqlutil.ErrUserExists
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
|
||||||
"global": {
|
|
||||||
"content": [],
|
|
||||||
"override": [],
|
|
||||||
"room": [],
|
|
||||||
"sender": [],
|
|
||||||
"underride": []
|
|
||||||
}
|
|
||||||
}`)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveAccountData saves new account data for a given user and a given room.
|
|
||||||
// If the account data is not specific to a room, the room ID should be an empty string
|
|
||||||
// If an account data already exists for a given set (user, room, data type), it will
|
|
||||||
// update the corresponding row with the new content
|
|
||||||
// Returns a SQL error if there was an issue with the insertion/update
|
|
||||||
func (d *Database) SaveAccountData(
|
|
||||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
|
||||||
) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountData returns account data related to a given localpart
|
|
||||||
// If no account data could be found, returns an empty arrays
|
|
||||||
// Returns an error if there was an issue with the retrieval
|
|
||||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|
||||||
global map[string]json.RawMessage,
|
|
||||||
rooms map[string]map[string]json.RawMessage,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountDataByType returns account data matching a given
|
|
||||||
// localpart, room ID and type.
|
|
||||||
// If no account data could be found, returns nil
|
|
||||||
// Returns an error if there was an issue with the retrieval
|
|
||||||
func (d *Database) GetAccountDataByType(
|
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
|
||||||
) (data json.RawMessage, err error) {
|
|
||||||
return d.accountDatas.selectAccountDataByType(
|
|
||||||
ctx, localpart, roomID, dataType,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
|
||||||
func (d *Database) GetNewNumericLocalpart(
|
|
||||||
ctx context.Context,
|
|
||||||
) (int64, error) {
|
|
||||||
return d.accounts.selectNewNumericLocalpart(ctx, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
|
||||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
|
|
||||||
return string(hashBytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
|
||||||
// a third-party identifier which is already associated to a local user.
|
|
||||||
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
|
||||||
|
|
||||||
// SaveThreePIDAssociation saves the association between a third party identifier
|
|
||||||
// and a local Matrix user (identified by the user's ID's local part).
|
|
||||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
|
||||||
// Returns an error if there was a problem talking to the database.
|
|
||||||
func (d *Database) SaveThreePIDAssociation(
|
|
||||||
ctx context.Context, threepid, localpart, medium string,
|
|
||||||
) (err error) {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
user, err := d.threepids.selectLocalpartForThreePID(
|
|
||||||
ctx, txn, threepid, medium,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err)
|
||||||
}
|
}
|
||||||
|
loginTokenTable, err := NewPostgresLoginTokenTable(db)
|
||||||
if len(user) > 0 {
|
|
||||||
return Err3PIDInUse
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveThreePIDAssociation removes the association involving a given third-party
|
|
||||||
// identifier.
|
|
||||||
// If no association exists involving this third-party identifier, returns nothing.
|
|
||||||
// If there was a problem talking to the database, returns an error.
|
|
||||||
func (d *Database) RemoveThreePIDAssociation(
|
|
||||||
ctx context.Context, threepid string, medium string,
|
|
||||||
) (err error) {
|
|
||||||
return d.threepids.deleteThreePID(ctx, threepid, medium)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
|
||||||
// identifier.
|
|
||||||
// If no association involves the given third-party idenfitier, returns an empty
|
|
||||||
// string.
|
|
||||||
// Returns an error if there was a problem talking to the database.
|
|
||||||
func (d *Database) GetLocalpartForThreePID(
|
|
||||||
ctx context.Context, threepid string, medium string,
|
|
||||||
) (localpart string, err error) {
|
|
||||||
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
|
||||||
// a given local user.
|
|
||||||
// If no association is known for this user, returns an empty slice.
|
|
||||||
// Returns an error if there was an issue talking to the database.
|
|
||||||
func (d *Database) GetThreePIDsForLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
|
||||||
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckAccountAvailability checks if the username/localpart is already present
|
|
||||||
// in the database.
|
|
||||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
|
||||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
|
||||||
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
|
||||||
// This function assumes the request is authenticated or the account data is used only internally.
|
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
|
||||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SearchProfiles returns all profiles where the provided localpart or display name
|
|
||||||
// match any part of the profiles in the database.
|
|
||||||
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
|
|
||||||
) ([]authtypes.Profile, error) {
|
|
||||||
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
|
||||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
|
||||||
return d.accounts.deactivateAccount(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
|
|
||||||
func (d *Database) CreateOpenIDToken(
|
|
||||||
ctx context.Context,
|
|
||||||
token, localpart string,
|
|
||||||
) (int64, error) {
|
|
||||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
|
||||||
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
|
||||||
})
|
|
||||||
return expiresAtMS, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
|
|
||||||
func (d *Database) GetOpenIDTokenAttributes(
|
|
||||||
ctx context.Context,
|
|
||||||
token string,
|
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
|
||||||
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) CreateKeyBackup(
|
|
||||||
ctx context.Context, userID, algorithm string, authData json.RawMessage,
|
|
||||||
) (version string, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) UpdateKeyBackupAuthData(
|
|
||||||
ctx context.Context, userID, version string, authData json.RawMessage,
|
|
||||||
) (err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) DeleteKeyBackup(
|
|
||||||
ctx context.Context, userID, version string,
|
|
||||||
) (exists bool, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetKeyBackup(
|
|
||||||
ctx context.Context, userID, version string,
|
|
||||||
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetBackupKeys(
|
|
||||||
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
|
||||||
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
if filterSessionID != "" {
|
|
||||||
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if filterRoomID != "" {
|
|
||||||
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) CountBackupKeys(
|
|
||||||
ctx context.Context, version, userID string,
|
|
||||||
) (count int64, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
openIDTable, err := NewPostgresOpenIDTable(db, serverName)
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint:nakedret
|
|
||||||
func (d *Database) UpsertBackupKeys(
|
|
||||||
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
|
||||||
) (count int64, etag string, err error) {
|
|
||||||
// wrap the following logic in a txn to ensure we atomically upload keys
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err)
|
||||||
}
|
}
|
||||||
if deleted {
|
profilesTable, err := NewPostgresProfilesTable(db)
|
||||||
return fmt.Errorf("backup was deleted")
|
|
||||||
}
|
|
||||||
// pull out all keys for this (user_id, version)
|
|
||||||
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err)
|
||||||
}
|
}
|
||||||
|
threePIDTable, err := NewPostgresThreePIDTable(db)
|
||||||
changed := false
|
|
||||||
// loop over all the new keys (which should be smaller than the set of backed up keys)
|
|
||||||
for _, newKey := range uploads {
|
|
||||||
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
|
|
||||||
existingRoom := existingKeys[newKey.RoomID]
|
|
||||||
if existingRoom != nil {
|
|
||||||
existingSession, ok := existingRoom[newKey.SessionID]
|
|
||||||
if ok {
|
|
||||||
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
|
||||||
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
|
||||||
changed = true
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
|
return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
|
||||||
}
|
}
|
||||||
}
|
return &shared.Database{
|
||||||
// if we shouldn't replace the key we do nothing with it
|
AccountDatas: accountDataTable,
|
||||||
continue
|
Accounts: accountsTable,
|
||||||
}
|
Devices: devicesTable,
|
||||||
}
|
KeyBackups: keyBackupTable,
|
||||||
// if we're here, either the room or session are new, either way, we insert
|
KeyBackupVersions: keyBackupVersionTable,
|
||||||
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
|
LoginTokens: loginTokenTable,
|
||||||
changed = true
|
OpenIDTokens: openIDTable,
|
||||||
if err != nil {
|
Profiles: profilesTable,
|
||||||
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
|
ThreePIDs: threePIDTable,
|
||||||
}
|
ServerName: serverName,
|
||||||
}
|
DB: db,
|
||||||
|
Writer: sqlutil.NewDummyWriter(),
|
||||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
LoginTokenLifetime: loginTokenLifetime,
|
||||||
if err != nil {
|
BcryptCost: bcryptCost,
|
||||||
return err
|
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||||
}
|
}, nil
|
||||||
if changed {
|
|
||||||
// update the etag
|
|
||||||
var newETag string
|
|
||||||
if oldETag == "" {
|
|
||||||
newETag = "1"
|
|
||||||
} else {
|
|
||||||
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse old etag: %s", err)
|
|
||||||
}
|
|
||||||
newETag = strconv.FormatInt(oldETagInt+1, 10)
|
|
||||||
}
|
|
||||||
etag = newETag
|
|
||||||
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
|
|
||||||
} else {
|
|
||||||
etag = oldETag
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByAccessToken(
|
|
||||||
ctx context.Context, token string,
|
|
||||||
) (*api.Device, error) {
|
|
||||||
return d.devices.selectDeviceByToken(ctx, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByID returns the device matching the given ID.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByID(
|
|
||||||
ctx context.Context, localpart, deviceID string,
|
|
||||||
) (*api.Device, error) {
|
|
||||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) ([]api.Device, error) {
|
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
|
||||||
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
|
||||||
// an error will be returned.
|
|
||||||
// If no device ID is given one is generated.
|
|
||||||
// Returns the device on success.
|
|
||||||
func (d *Database) CreateDevice(
|
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
|
||||||
displayName *string, ipAddr, userAgent string,
|
|
||||||
) (dev *api.Device, returnErr error) {
|
|
||||||
if deviceID != nil {
|
|
||||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
// Revoke existing tokens for this device
|
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// We generate device IDs in a loop in case its already taken.
|
|
||||||
// We cap this at going round 5 times to ensure we don't spin forever
|
|
||||||
var newDeviceID string
|
|
||||||
for i := 1; i <= 5; i++ {
|
|
||||||
newDeviceID, returnErr = generateDeviceID()
|
|
||||||
if returnErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if returnErr == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
|
||||||
// random bytes.
|
|
||||||
func generateDeviceID() (string, error) {
|
|
||||||
b := make([]byte, deviceIDByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// url-safe no padding
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDevice updates the given device with the display name.
|
|
||||||
// Returns SQL error if there are problems and nil on success.
|
|
||||||
func (d *Database) UpdateDevice(
|
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
|
||||||
) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
|
||||||
// matching with the given device ID and user ID localpart.
|
|
||||||
// If the device doesn't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevice(
|
|
||||||
ctx context.Context, deviceID, localpart string,
|
|
||||||
) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
|
||||||
// matching with the given device IDs and user ID localpart.
|
|
||||||
// If the devices don't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevices(
|
|
||||||
ctx context.Context, localpart string, devices []string,
|
|
||||||
) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
|
||||||
// database matching the given user ID localpart.
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveAllDevices(
|
|
||||||
ctx context.Context, localpart, exceptDeviceID string,
|
|
||||||
) (devices []api.Device, err error) {
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
|
||||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
|
||||||
// determined by the loginTokenLifetime given to the Database constructor.
|
|
||||||
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
|
||||||
tok, err := generateLoginToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
meta := &api.LoginTokenMetadata{
|
|
||||||
Token: tok,
|
|
||||||
Expiration: time.Now().Add(d.loginTokenLifetime),
|
|
||||||
}
|
|
||||||
|
|
||||||
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.loginTokens.insert(ctx, txn, meta, data)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return meta, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateLoginToken() (string, error) {
|
|
||||||
b := make([]byte, loginTokenByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
|
||||||
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
|
||||||
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
|
|
||||||
return d.loginTokens.deleteByToken(ctx, txn, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
|
||||||
// May return sql.ErrNoRows.
|
|
||||||
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
|
||||||
return d.loginTokens.selectByToken(ctx, token)
|
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
)
|
)
|
||||||
@ -58,12 +59,13 @@ type threepidStatements struct {
|
|||||||
deleteThreePIDStmt *sql.Stmt
|
deleteThreePIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||||
_, err = db.Exec(threepidSchema)
|
s := &threepidStatements{}
|
||||||
|
_, err := db.Exec(threepidSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
|
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
|
||||||
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
|
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
|
||||||
{&s.insertThreePIDStmt, insertThreePIDSQL},
|
{&s.insertThreePIDStmt, insertThreePIDSQL},
|
||||||
@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) selectLocalpartForThreePID(
|
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||||
@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||||
@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) insertThreePID(
|
func (s *threepidStatements) InsertThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) deleteThreePID(
|
func (s *threepidStatements) DeleteThreePID(
|
||||||
ctx context.Context, threepid string, medium string) (err error) {
|
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
|
||||||
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
|
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
|
||||||
|
_, err = stmt.ExecContext(ctx, threepid, medium)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
672
userapi/storage/shared/storage.go
Normal file
672
userapi/storage/shared/storage.go
Normal file
@ -0,0 +1,672 @@
|
|||||||
|
// Copyright 2017 Vector Creations Ltd
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Database represents an account database
|
||||||
|
type Database struct {
|
||||||
|
DB *sql.DB
|
||||||
|
Writer sqlutil.Writer
|
||||||
|
Accounts tables.AccountsTable
|
||||||
|
Profiles tables.ProfileTable
|
||||||
|
AccountDatas tables.AccountDataTable
|
||||||
|
ThreePIDs tables.ThreePIDTable
|
||||||
|
OpenIDTokens tables.OpenIDTable
|
||||||
|
KeyBackups tables.KeyBackupTable
|
||||||
|
KeyBackupVersions tables.KeyBackupVersionTable
|
||||||
|
Devices tables.DevicesTable
|
||||||
|
LoginTokens tables.LoginTokenTable
|
||||||
|
LoginTokenLifetime time.Duration
|
||||||
|
ServerName gomatrixserverlib.ServerName
|
||||||
|
BcryptCost int
|
||||||
|
OpenIDTokenLifetimeMS int64
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// The length of generated device IDs
|
||||||
|
deviceIDByteLength = 6
|
||||||
|
loginTokenByteLength = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||||
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
|
func (d *Database) GetAccountByPassword(
|
||||||
|
ctx context.Context, localpart, plaintextPassword string,
|
||||||
|
) (*api.Account, error) {
|
||||||
|
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||||
|
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||||
|
func (d *Database) GetProfileByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (*authtypes.Profile, error) {
|
||||||
|
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||||
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
|
func (d *Database) SetAvatarURL(
|
||||||
|
ctx context.Context, localpart string, avatarURL string,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisplayName updates the display name of the profile associated with the given
|
||||||
|
// localpart. Returns an error if something went wrong with the SQL query
|
||||||
|
func (d *Database) SetDisplayName(
|
||||||
|
ctx context.Context, localpart string, displayName string,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPassword sets the account password to the given hash.
|
||||||
|
func (d *Database) SetPassword(
|
||||||
|
ctx context.Context, localpart, plaintextPassword string,
|
||||||
|
) error {
|
||||||
|
hash, err := d.hashPassword(plaintextPassword)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Accounts.UpdatePassword(ctx, localpart, hash)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
||||||
|
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||||
|
// account already exists, it will return nil, ErrUserExists.
|
||||||
|
func (d *Database) CreateAccount(
|
||||||
|
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||||
|
) (acc *api.Account, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
// For guest accounts, we create a new numeric local part
|
||||||
|
if accountType == api.AccountTypeGuest {
|
||||||
|
var numLocalpart int64
|
||||||
|
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
localpart = strconv.FormatInt(numLocalpart, 10)
|
||||||
|
plaintextPassword = ""
|
||||||
|
appserviceID = ""
|
||||||
|
}
|
||||||
|
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// WARNING! This function assumes that the relevant mutexes have already
|
||||||
|
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||||
|
func (d *Database) createAccount(
|
||||||
|
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||||
|
) (*api.Account, error) {
|
||||||
|
var err error
|
||||||
|
var account *api.Account
|
||||||
|
// Generate a password hash if this is not a password-less user
|
||||||
|
hash := ""
|
||||||
|
if plaintextPassword != "" {
|
||||||
|
hash, err = d.hashPassword(plaintextPassword)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||||
|
return nil, sqlutil.ErrUserExists
|
||||||
|
}
|
||||||
|
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
||||||
|
"global": {
|
||||||
|
"content": [],
|
||||||
|
"override": [],
|
||||||
|
"room": [],
|
||||||
|
"sender": [],
|
||||||
|
"underride": []
|
||||||
|
}
|
||||||
|
}`)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAccountData saves new account data for a given user and a given room.
|
||||||
|
// If the account data is not specific to a room, the room ID should be an empty string
|
||||||
|
// If an account data already exists for a given set (user, room, data type), it will
|
||||||
|
// update the corresponding row with the new content
|
||||||
|
// Returns a SQL error if there was an issue with the insertion/update
|
||||||
|
func (d *Database) SaveAccountData(
|
||||||
|
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountData returns account data related to a given localpart
|
||||||
|
// If no account data could be found, returns an empty arrays
|
||||||
|
// Returns an error if there was an issue with the retrieval
|
||||||
|
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||||
|
global map[string]json.RawMessage,
|
||||||
|
rooms map[string]map[string]json.RawMessage,
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
return d.AccountDatas.SelectAccountData(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountDataByType returns account data matching a given
|
||||||
|
// localpart, room ID and type.
|
||||||
|
// If no account data could be found, returns nil
|
||||||
|
// Returns an error if there was an issue with the retrieval
|
||||||
|
func (d *Database) GetAccountDataByType(
|
||||||
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
|
) (data json.RawMessage, err error) {
|
||||||
|
return d.AccountDatas.SelectAccountDataByType(
|
||||||
|
ctx, localpart, roomID, dataType,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||||
|
func (d *Database) GetNewNumericLocalpart(
|
||||||
|
ctx context.Context,
|
||||||
|
) (int64, error) {
|
||||||
|
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||||
|
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost)
|
||||||
|
return string(hashBytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Err3PIDInUse is the error returned when trying to save an association involving
|
||||||
|
// a third-party identifier which is already associated to a local user.
|
||||||
|
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
||||||
|
|
||||||
|
// SaveThreePIDAssociation saves the association between a third party identifier
|
||||||
|
// and a local Matrix user (identified by the user's ID's local part).
|
||||||
|
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func (d *Database) SaveThreePIDAssociation(
|
||||||
|
ctx context.Context, threepid, localpart, medium string,
|
||||||
|
) (err error) {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||||
|
ctx, txn, threepid, medium,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user) > 0 {
|
||||||
|
return Err3PIDInUse
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveThreePIDAssociation removes the association involving a given third-party
|
||||||
|
// identifier.
|
||||||
|
// If no association exists involving this third-party identifier, returns nothing.
|
||||||
|
// If there was a problem talking to the database, returns an error.
|
||||||
|
func (d *Database) RemoveThreePIDAssociation(
|
||||||
|
ctx context.Context, threepid string, medium string,
|
||||||
|
) (err error) {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
||||||
|
// identifier.
|
||||||
|
// If no association involves the given third-party idenfitier, returns an empty
|
||||||
|
// string.
|
||||||
|
// Returns an error if there was a problem talking to the database.
|
||||||
|
func (d *Database) GetLocalpartForThreePID(
|
||||||
|
ctx context.Context, threepid string, medium string,
|
||||||
|
) (localpart string, err error) {
|
||||||
|
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
||||||
|
// a given local user.
|
||||||
|
// If no association is known for this user, returns an empty slice.
|
||||||
|
// Returns an error if there was an issue talking to the database.
|
||||||
|
func (d *Database) GetThreePIDsForLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
|
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckAccountAvailability checks if the username/localpart is already present
|
||||||
|
// in the database.
|
||||||
|
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||||
|
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||||
|
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||||
|
// This function assumes the request is authenticated or the account data is used only internally.
|
||||||
|
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||||
|
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||||
|
) (*api.Account, error) {
|
||||||
|
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchProfiles returns all profiles where the provided localpart or display name
|
||||||
|
// match any part of the profiles in the database.
|
||||||
|
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
|
||||||
|
) ([]authtypes.Profile, error) {
|
||||||
|
return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||||
|
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||||
|
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Accounts.DeactivateAccount(ctx, localpart)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||||
|
func (d *Database) CreateOpenIDToken(
|
||||||
|
ctx context.Context,
|
||||||
|
token, localpart string,
|
||||||
|
) (int64, error) {
|
||||||
|
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
||||||
|
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
|
||||||
|
})
|
||||||
|
return expiresAtMS, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
|
||||||
|
func (d *Database) GetOpenIDTokenAttributes(
|
||||||
|
ctx context.Context,
|
||||||
|
token string,
|
||||||
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
|
return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) CreateKeyBackup(
|
||||||
|
ctx context.Context, userID, algorithm string, authData json.RawMessage,
|
||||||
|
) (version string, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "")
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) UpdateKeyBackupAuthData(
|
||||||
|
ctx context.Context, userID, version string, authData json.RawMessage,
|
||||||
|
) (err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) DeleteKeyBackup(
|
||||||
|
ctx context.Context, userID, version string,
|
||||||
|
) (exists bool, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetKeyBackup(
|
||||||
|
ctx context.Context, userID, version string,
|
||||||
|
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetBackupKeys(
|
||||||
|
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
||||||
|
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
if filterSessionID != "" {
|
||||||
|
result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if filterRoomID != "" {
|
||||||
|
result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) CountBackupKeys(
|
||||||
|
ctx context.Context, version, userID string,
|
||||||
|
) (count int64, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:nakedret
|
||||||
|
func (d *Database) UpsertBackupKeys(
|
||||||
|
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
||||||
|
) (count int64, etag string, err error) {
|
||||||
|
// wrap the following logic in a txn to ensure we atomically upload keys
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
_, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if deleted {
|
||||||
|
return fmt.Errorf("backup was deleted")
|
||||||
|
}
|
||||||
|
// pull out all keys for this (user_id, version)
|
||||||
|
existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
changed := false
|
||||||
|
// loop over all the new keys (which should be smaller than the set of backed up keys)
|
||||||
|
for _, newKey := range uploads {
|
||||||
|
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
|
||||||
|
existingRoom := existingKeys[newKey.RoomID]
|
||||||
|
if existingRoom != nil {
|
||||||
|
existingSession, ok := existingRoom[newKey.SessionID]
|
||||||
|
if ok {
|
||||||
|
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
||||||
|
err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey)
|
||||||
|
changed = true
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if we shouldn't replace the key we do nothing with it
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if we're here, either the room or session are new, either way, we insert
|
||||||
|
err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey)
|
||||||
|
changed = true
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if changed {
|
||||||
|
// update the etag
|
||||||
|
var newETag string
|
||||||
|
if oldETag == "" {
|
||||||
|
newETag = "1"
|
||||||
|
} else {
|
||||||
|
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse old etag: %s", err)
|
||||||
|
}
|
||||||
|
newETag = strconv.FormatInt(oldETagInt+1, 10)
|
||||||
|
}
|
||||||
|
etag = newETag
|
||||||
|
return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag)
|
||||||
|
} else {
|
||||||
|
etag = oldETag
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceByAccessToken returns the device matching the given access token.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByAccessToken(
|
||||||
|
ctx context.Context, token string,
|
||||||
|
) (*api.Device, error) {
|
||||||
|
return d.Devices.SelectDeviceByToken(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceByID returns the device matching the given ID.
|
||||||
|
// Returns sql.ErrNoRows if no matching device was found.
|
||||||
|
func (d *Database) GetDeviceByID(
|
||||||
|
ctx context.Context, localpart, deviceID string,
|
||||||
|
) (*api.Device, error) {
|
||||||
|
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||||
|
func (d *Database) GetDevicesByLocalpart(
|
||||||
|
ctx context.Context, localpart string,
|
||||||
|
) ([]api.Device, error) {
|
||||||
|
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
|
return d.Devices.SelectDevicesByID(ctx, deviceIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateDevice makes a new device associated with the given user ID localpart.
|
||||||
|
// If there is already a device with the same device ID for this user, that access token will be revoked
|
||||||
|
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
||||||
|
// an error will be returned.
|
||||||
|
// If no device ID is given one is generated.
|
||||||
|
// Returns the device on success.
|
||||||
|
func (d *Database) CreateDevice(
|
||||||
|
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||||
|
displayName *string, ipAddr, userAgent string,
|
||||||
|
) (dev *api.Device, returnErr error) {
|
||||||
|
if deviceID != nil {
|
||||||
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
// Revoke existing tokens for this device
|
||||||
|
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// We generate device IDs in a loop in case its already taken.
|
||||||
|
// We cap this at going round 5 times to ensure we don't spin forever
|
||||||
|
var newDeviceID string
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
newDeviceID, returnErr = generateDeviceID()
|
||||||
|
if returnErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
var err error
|
||||||
|
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if returnErr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
||||||
|
// random bytes.
|
||||||
|
func generateDeviceID() (string, error) {
|
||||||
|
b := make([]byte, deviceIDByteLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// url-safe no padding
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDevice updates the given device with the display name.
|
||||||
|
// Returns SQL error if there are problems and nil on success.
|
||||||
|
func (d *Database) UpdateDevice(
|
||||||
|
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevice revokes a device by deleting the entry in the database
|
||||||
|
// matching with the given device ID and user ID localpart.
|
||||||
|
// If the device doesn't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevice(
|
||||||
|
ctx context.Context, deviceID, localpart string,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
||||||
|
// matching with the given device IDs and user ID localpart.
|
||||||
|
// If the devices don't exist, it will not return an error
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveDevices(
|
||||||
|
ctx context.Context, localpart string, devices []string,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllDevices revokes devices by deleting the entry in the
|
||||||
|
// database matching the given user ID localpart.
|
||||||
|
// If something went wrong during the deletion, it will return the SQL error.
|
||||||
|
func (d *Database) RemoveAllDevices(
|
||||||
|
ctx context.Context, localpart, exceptDeviceID string,
|
||||||
|
) (devices []api.Device, err error) {
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
||||||
|
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||||
|
// determined by the loginTokenLifetime given to the Database constructor.
|
||||||
|
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
||||||
|
tok, err := generateLoginToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
meta := &api.LoginTokenMetadata{
|
||||||
|
Token: tok,
|
||||||
|
Expiration: time.Now().Add(d.LoginTokenLifetime),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return meta, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateLoginToken() (string, error) {
|
||||||
|
b := make([]byte, loginTokenByteLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||||
|
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.LoginTokens.DeleteLoginToken(ctx, txn, token)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLoginTokenDataByToken returns the data associated with the given token.
|
||||||
|
// May return sql.ErrNoRows.
|
||||||
|
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||||
|
return d.LoginTokens.SelectLoginToken(ctx, token)
|
||||||
|
}
|
@ -20,6 +20,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const accountDataSchema = `
|
const accountDataSchema = `
|
||||||
@ -56,27 +57,29 @@ type accountDataStatements struct {
|
|||||||
selectAccountDataByTypeStmt *sql.Stmt
|
selectAccountDataByTypeStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
|
func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
||||||
s.db = db
|
s := &accountDataStatements{
|
||||||
_, err = db.Exec(accountDataSchema)
|
db: db,
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
_, err := db.Exec(accountDataSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertAccountDataStmt, insertAccountDataSQL},
|
{&s.insertAccountDataStmt, insertAccountDataSQL},
|
||||||
{&s.selectAccountDataStmt, selectAccountDataSQL},
|
{&s.selectAccountDataStmt, selectAccountDataSQL},
|
||||||
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
|
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) insertAccountData(
|
func (s *accountDataStatements) InsertAccountData(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountData(
|
func (s *accountDataStatements) SelectAccountData(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (
|
) (
|
||||||
/* global */ map[string]json.RawMessage,
|
/* global */ map[string]json.RawMessage,
|
||||||
@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
|
|||||||
return global, rooms, nil
|
return global, rooms, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountDataStatements) selectAccountDataByType(
|
func (s *accountDataStatements) SelectAccountDataByType(
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
ctx context.Context, localpart, roomID, dataType string,
|
||||||
) (data json.RawMessage, err error) {
|
) (data json.RawMessage, err error) {
|
||||||
var bytes []byte
|
var bytes []byte
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@ -77,15 +78,16 @@ type accountsStatements struct {
|
|||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) execSchema(db *sql.DB) error {
|
func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
|
||||||
_, err := db.Exec(accountsSchema)
|
s := &accountsStatements{
|
||||||
return err
|
db: db,
|
||||||
|
serverName: serverName,
|
||||||
}
|
}
|
||||||
|
_, err := db.Exec(accountsSchema)
|
||||||
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
if err != nil {
|
||||||
s.db = db
|
return nil, err
|
||||||
s.serverName = server
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertAccountStmt, insertAccountSQL},
|
{&s.insertAccountStmt, insertAccountSQL},
|
||||||
{&s.updatePasswordStmt, updatePasswordSQL},
|
{&s.updatePasswordStmt, updatePasswordSQL},
|
||||||
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
{&s.deactivateAccountStmt, deactivateAccountSQL},
|
||||||
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
|
|||||||
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
|
||||||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||||
// on success.
|
// on success.
|
||||||
func (s *accountsStatements) insertAccount(
|
func (s *accountsStatements) InsertAccount(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
@ -122,28 +124,28 @@ func (s *accountsStatements) insertAccount(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) updatePassword(
|
func (s *accountsStatements) UpdatePassword(
|
||||||
ctx context.Context, localpart, passwordHash string,
|
ctx context.Context, localpart, passwordHash string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) deactivateAccount(
|
func (s *accountsStatements) DeactivateAccount(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectPasswordHash(
|
func (s *accountsStatements) SelectPasswordHash(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (hash string, err error) {
|
) (hash string, err error) {
|
||||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectAccountByLocalpart(
|
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (*api.Account, error) {
|
) (*api.Account, error) {
|
||||||
var appserviceIDPtr sql.NullString
|
var appserviceIDPtr sql.NullString
|
||||||
@ -167,7 +169,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
|
|||||||
return &acc, nil
|
return &acc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *accountsStatements) selectNewNumericLocalpart(
|
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
) (id int64, err error) {
|
) (id int64, err error) {
|
||||||
stmt := s.selectNewNumericLocalpartStmt
|
stmt := s.selectNewNumericLocalpartStmt
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" +
|
|||||||
|
|
||||||
type devicesStatements struct {
|
type devicesStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
writer sqlutil.Writer
|
|
||||||
insertDeviceStmt *sql.Stmt
|
insertDeviceStmt *sql.Stmt
|
||||||
selectDevicesCountStmt *sql.Stmt
|
selectDevicesCountStmt *sql.Stmt
|
||||||
selectDeviceByTokenStmt *sql.Stmt
|
selectDeviceByTokenStmt *sql.Stmt
|
||||||
@ -98,55 +98,33 @@ type devicesStatements struct {
|
|||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) execSchema(db *sql.DB) error {
|
func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
|
||||||
|
s := &devicesStatements{
|
||||||
|
db: db,
|
||||||
|
serverName: serverName,
|
||||||
|
}
|
||||||
_, err := db.Exec(devicesSchema)
|
_, err := db.Exec(devicesSchema)
|
||||||
return err
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
{&s.insertDeviceStmt, insertDeviceSQL},
|
||||||
s.db = db
|
{&s.selectDevicesCountStmt, selectDevicesCountSQL},
|
||||||
s.writer = writer
|
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
|
||||||
if err = s.execSchema(db); err != nil {
|
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
|
||||||
return
|
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
|
||||||
}
|
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
|
||||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
{&s.deleteDeviceStmt, deleteDeviceSQL},
|
||||||
return
|
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
|
||||||
}
|
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
|
||||||
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
|
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
|
||||||
return
|
}.Prepare(db)
|
||||||
}
|
|
||||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.serverName = server
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||||
// Returns an error if the user already has a device with the given device ID.
|
// Returns an error if the user already has a device with the given device ID.
|
||||||
// Returns the device on success.
|
// Returns the device on success.
|
||||||
func (s *devicesStatements) insertDevice(
|
func (s *devicesStatements) InsertDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||||
displayName *string, ipAddr, userAgent string,
|
displayName *string, ipAddr, userAgent string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
@ -172,7 +150,7 @@ func (s *devicesStatements) insertDevice(
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevice(
|
func (s *devicesStatements) DeleteDevice(
|
||||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||||
@ -180,7 +158,7 @@ func (s *devicesStatements) deleteDevice(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevices(
|
func (s *devicesStatements) DeleteDevices(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||||
) error {
|
) error {
|
||||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||||
@ -198,7 +176,7 @@ func (s *devicesStatements) deleteDevices(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||||
@ -206,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceName(
|
func (s *devicesStatements) UpdateDeviceName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||||
@ -214,7 +192,7 @@ func (s *devicesStatements) updateDeviceName(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDeviceByToken(
|
func (s *devicesStatements) SelectDeviceByToken(
|
||||||
ctx context.Context, accessToken string,
|
ctx context.Context, accessToken string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
@ -230,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||||||
|
|
||||||
// selectDeviceByID retrieves a device from the database with the given user
|
// selectDeviceByID retrieves a device from the database with the given user
|
||||||
// localpart and deviceID
|
// localpart and deviceID
|
||||||
func (s *devicesStatements) selectDeviceByID(
|
func (s *devicesStatements) SelectDeviceByID(
|
||||||
ctx context.Context, localpart, deviceID string,
|
ctx context.Context, localpart, deviceID string,
|
||||||
) (*api.Device, error) {
|
) (*api.Device, error) {
|
||||||
var dev api.Device
|
var dev api.Device
|
||||||
@ -247,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID(
|
|||||||
return &dev, err
|
return &dev, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||||
) ([]api.Device, error) {
|
) ([]api.Device, error) {
|
||||||
devices := []api.Device{}
|
devices := []api.Device{}
|
||||||
@ -288,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||||||
return devices, nil
|
return devices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||||
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
|
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
|
||||||
iDeviceIDs := make([]interface{}, len(deviceIDs))
|
iDeviceIDs := make([]interface{}, len(deviceIDs))
|
||||||
for i := range deviceIDs {
|
for i := range deviceIDs {
|
||||||
@ -317,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
|
|||||||
return devices, rows.Err()
|
return devices, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const keyBackupTableSchema = `
|
const keyBackupTableSchema = `
|
||||||
@ -72,12 +73,13 @@ type keyBackupStatements struct {
|
|||||||
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
|
||||||
_, err = db.Exec(keyBackupTableSchema)
|
s := &keyBackupStatements{}
|
||||||
|
_, err := db.Exec(keyBackupTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
||||||
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
||||||
{&s.countKeysStmt, countKeysSQL},
|
{&s.countKeysStmt, countKeysSQL},
|
||||||
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s keyBackupStatements) countKeys(
|
func (s keyBackupStatements) CountKeys(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (count int64, err error) {
|
) (count int64, err error) {
|
||||||
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
|
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) insertBackupKey(
|
func (s *keyBackupStatements) InsertBackupKey(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
|
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
|
||||||
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) updateBackupKey(
|
func (s *keyBackupStatements) UpdateBackupKey(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
|
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
|
||||||
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) selectKeys(
|
func (s *keyBackupStatements) SelectKeys(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
|
||||||
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
|
|||||||
return unpackKeys(ctx, rows)
|
return unpackKeys(ctx, rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) selectKeysByRoomID(
|
func (s *keyBackupStatements) SelectKeysByRoomID(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
|
||||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
|
||||||
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
|
|||||||
return unpackKeys(ctx, rows)
|
return unpackKeys(ctx, rows)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
|
func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
|
||||||
) (map[string]map[string]api.KeyBackupSession, error) {
|
) (map[string]map[string]api.KeyBackupSession, error) {
|
||||||
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const keyBackupVersionTableSchema = `
|
const keyBackupVersionTableSchema = `
|
||||||
@ -67,12 +68,13 @@ type keyBackupVersionStatements struct {
|
|||||||
updateKeyBackupETagStmt *sql.Stmt
|
updateKeyBackupETagStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
|
||||||
_, err = db.Exec(keyBackupVersionTableSchema)
|
s := &keyBackupVersionStatements{}
|
||||||
|
_, err := db.Exec(keyBackupVersionTableSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
|
{&s.insertKeyBackupStmt, insertKeyBackupSQL},
|
||||||
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
|
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
|
||||||
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
|
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
|
||||||
@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) insertKeyBackup(
|
func (s *keyBackupVersionStatements) InsertKeyBackup(
|
||||||
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
|
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
|
||||||
) (version string, err error) {
|
) (version string, err error) {
|
||||||
var versionInt int64
|
var versionInt int64
|
||||||
@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
|
|||||||
return strconv.FormatInt(versionInt, 10), err
|
return strconv.FormatInt(versionInt, 10), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
|
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
|
||||||
) error {
|
) error {
|
||||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version, etag string,
|
ctx context.Context, txn *sql.Tx, userID, version, etag string,
|
||||||
) error {
|
) error {
|
||||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) deleteKeyBackup(
|
func (s *keyBackupVersionStatements) DeleteKeyBackup(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
versionInt, err := strconv.ParseInt(version, 10, 64)
|
versionInt, err := strconv.ParseInt(version, 10, 64)
|
||||||
@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
|
|||||||
return ra == 1, nil
|
return ra == 1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *keyBackupVersionStatements) selectKeyBackup(
|
func (s *keyBackupVersionStatements) SelectKeyBackup(
|
||||||
ctx context.Context, txn *sql.Tx, userID, version string,
|
ctx context.Context, txn *sql.Tx, userID, version string,
|
||||||
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
||||||
var versionInt int64
|
var versionInt int64
|
||||||
|
@ -21,18 +21,17 @@ import (
|
|||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type loginTokenStatements struct {
|
type loginTokenStatements struct {
|
||||||
insertStmt *sql.Stmt
|
insertStmt *sql.Stmt
|
||||||
deleteStmt *sql.Stmt
|
deleteStmt *sql.Stmt
|
||||||
selectByTokenStmt *sql.Stmt
|
selectStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
// execSchema ensures tables and indices exist.
|
const loginTokenSchema = `
|
||||||
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
|
|
||||||
_, err := db.Exec(`
|
|
||||||
CREATE TABLE IF NOT EXISTS login_tokens (
|
CREATE TABLE IF NOT EXISTS login_tokens (
|
||||||
-- The random value of the token issued to a user
|
-- The random value of the token issued to a user
|
||||||
token TEXT NOT NULL PRIMARY KEY,
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
@ -45,24 +44,32 @@ CREATE TABLE IF NOT EXISTS login_tokens (
|
|||||||
|
|
||||||
-- This index allows efficient garbage collection of expired tokens.
|
-- This index allows efficient garbage collection of expired tokens.
|
||||||
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
|
||||||
`)
|
`
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepare runs statement preparation.
|
const insertLoginTokenSQL = "" +
|
||||||
func (s *loginTokenStatements) prepare(db *sql.DB) error {
|
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
|
||||||
if err := s.execSchema(db); err != nil {
|
|
||||||
return err
|
const deleteLoginTokenSQL = "" +
|
||||||
|
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
|
||||||
|
|
||||||
|
const selectLoginTokenSQL = "" +
|
||||||
|
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
|
||||||
|
|
||||||
|
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
|
||||||
|
s := &loginTokenStatements{}
|
||||||
|
_, err := db.Exec(loginTokenSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
|
{&s.insertStmt, insertLoginTokenSQL},
|
||||||
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
|
{&s.deleteStmt, deleteLoginTokenSQL},
|
||||||
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"},
|
{&s.selectStmt, selectLoginTokenSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert adds an already generated token to the database.
|
// insert adds an already generated token to the database.
|
||||||
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertStmt)
|
||||||
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
|
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
|
||||||
return err
|
return err
|
||||||
@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
|
|||||||
//
|
//
|
||||||
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
|
||||||
// The login_tokens_expiration_idx index should make that efficient.
|
// The login_tokens_expiration_idx index should make that efficient.
|
||||||
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error {
|
func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteStmt)
|
||||||
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
|
|||||||
}
|
}
|
||||||
|
|
||||||
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
|
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
|
||||||
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
||||||
var data api.LoginTokenData
|
var data api.LoginTokenData
|
||||||
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertTokenSQL = "" +
|
const insertOpenIDTokenSQL = "" +
|
||||||
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||||
|
|
||||||
const selectTokenSQL = "" +
|
const selectOpenIDTokenSQL = "" +
|
||||||
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
|
||||||
|
|
||||||
type tokenStatements struct {
|
type openIDTokenStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertTokenStmt *sql.Stmt
|
insertTokenStmt *sql.Stmt
|
||||||
selectTokenStmt *sql.Stmt
|
selectTokenStmt *sql.Stmt
|
||||||
serverName gomatrixserverlib.ServerName
|
serverName gomatrixserverlib.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
|
||||||
s.db = db
|
s := &openIDTokenStatements{
|
||||||
_, err = db.Exec(openIDTokenSchema)
|
db: db,
|
||||||
if err != nil {
|
serverName: serverName,
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
s.serverName = server
|
_, err := db.Exec(openIDTokenSchema)
|
||||||
return sqlutil.StatementList{
|
if err != nil {
|
||||||
{&s.insertTokenStmt, insertTokenSQL},
|
return nil, err
|
||||||
{&s.selectTokenStmt, selectTokenSQL},
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
|
{&s.insertTokenStmt, insertOpenIDTokenSQL},
|
||||||
|
{&s.selectTokenStmt, selectOpenIDTokenSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertToken inserts a new OpenID Connect token to the DB.
|
// insertToken inserts a new OpenID Connect token to the DB.
|
||||||
// Returns new token, otherwise returns error if the token already exists.
|
// Returns new token, otherwise returns error if the token already exists.
|
||||||
func (s *tokenStatements) insertToken(
|
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
txn *sql.Tx,
|
txn *sql.Tx,
|
||||||
token, localpart string,
|
token, localpart string,
|
||||||
@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken(
|
|||||||
|
|
||||||
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
|
||||||
// Returns the existing token's attributes, or err if no token is found
|
// Returns the existing token's attributes, or err if no token is found
|
||||||
func (s *tokenStatements) selectOpenIDTokenAtrributes(
|
func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
token string,
|
token string,
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
) (*api.OpenIDTokenAttributes, error) {
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
)
|
)
|
||||||
|
|
||||||
const profilesSchema = `
|
const profilesSchema = `
|
||||||
@ -60,13 +61,15 @@ type profilesStatements struct {
|
|||||||
selectProfilesBySearchStmt *sql.Stmt
|
selectProfilesBySearchStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
|
||||||
s.db = db
|
s := &profilesStatements{
|
||||||
_, err = db.Exec(profilesSchema)
|
db: db,
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
_, err := db.Exec(profilesSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
{&s.insertProfileStmt, insertProfileSQL},
|
{&s.insertProfileStmt, insertProfileSQL},
|
||||||
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
|
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
|
||||||
{&s.setAvatarURLStmt, setAvatarURLSQL},
|
{&s.setAvatarURLStmt, setAvatarURLSQL},
|
||||||
@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) insertProfile(
|
func (s *profilesStatements) InsertProfile(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string,
|
ctx context.Context, txn *sql.Tx, localpart string,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) selectProfileByLocalpart(
|
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (*authtypes.Profile, error) {
|
) (*authtypes.Profile, error) {
|
||||||
var profile authtypes.Profile
|
var profile authtypes.Profile
|
||||||
@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart(
|
|||||||
return &profile, nil
|
return &profile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) setAvatarURL(
|
func (s *profilesStatements) SetAvatarURL(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||||
@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) setDisplayName(
|
func (s *profilesStatements) SetDisplayName(
|
||||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||||
@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *profilesStatements) selectProfilesBySearch(
|
func (s *profilesStatements) SelectProfilesBySearch(
|
||||||
ctx context.Context, searchString string, limit int,
|
ctx context.Context, searchString string, limit int,
|
||||||
) ([]authtypes.Profile, error) {
|
) ([]authtypes.Profile, error) {
|
||||||
var profiles []authtypes.Profile
|
var profiles []authtypes.Profile
|
||||||
|
@ -15,80 +15,34 @@
|
|||||||
package sqlite3
|
package sqlite3
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/userapi/api"
|
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||||
)
|
|
||||||
|
|
||||||
// Database represents an account database
|
// Import the postgres database driver.
|
||||||
type Database struct {
|
_ "github.com/lib/pq"
|
||||||
db *sql.DB
|
|
||||||
writer sqlutil.Writer
|
|
||||||
|
|
||||||
sqlutil.PartitionOffsetStatements
|
|
||||||
accounts accountsStatements
|
|
||||||
profiles profilesStatements
|
|
||||||
accountDatas accountDataStatements
|
|
||||||
threepids threepidStatements
|
|
||||||
openIDTokens tokenStatements
|
|
||||||
keyBackupVersions keyBackupVersionStatements
|
|
||||||
keyBackups keyBackupStatements
|
|
||||||
devices devicesStatements
|
|
||||||
loginTokens loginTokenStatements
|
|
||||||
loginTokenLifetime time.Duration
|
|
||||||
serverName gomatrixserverlib.ServerName
|
|
||||||
bcryptCost int
|
|
||||||
openIDTokenLifetimeMS int64
|
|
||||||
|
|
||||||
accountsMu sync.Mutex
|
|
||||||
profilesMu sync.Mutex
|
|
||||||
accountDatasMu sync.Mutex
|
|
||||||
threepidsMu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// The length of generated device IDs
|
|
||||||
deviceIDByteLength = 6
|
|
||||||
loginTokenByteLength = 32
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewDatabase creates a new accounts and profiles database
|
// NewDatabase creates a new accounts and profiles database
|
||||||
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) {
|
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
|
||||||
db, err := sqlutil.Open(dbProperties)
|
db, err := sqlutil.Open(dbProperties)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
d := &Database{
|
|
||||||
serverName: serverName,
|
|
||||||
db: db,
|
|
||||||
writer: sqlutil.NewExclusiveWriter(),
|
|
||||||
loginTokenLifetime: loginTokenLifetime,
|
|
||||||
bcryptCost: bcryptCost,
|
|
||||||
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create tables before executing migrations so we don't fail if the table is missing,
|
m := sqlutil.NewMigrations()
|
||||||
// and THEN prepare statements so we don't fail due to referencing new columns
|
if _, err = db.Exec(accountsSchema); err != nil {
|
||||||
if err = d.accounts.execSchema(db); err != nil {
|
// do this so that the migration can and we don't fail on
|
||||||
|
// preparing statements for columns that don't exist yet
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
m := sqlutil.NewMigrations()
|
|
||||||
deltas.LoadIsActive(m)
|
deltas.LoadIsActive(m)
|
||||||
//deltas.LoadLastSeenTSIP(m)
|
//deltas.LoadLastSeenTSIP(m)
|
||||||
deltas.LoadAddAccountType(m)
|
deltas.LoadAddAccountType(m)
|
||||||
@ -96,666 +50,57 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
partitions := sqlutil.PartitionOffsetStatements{}
|
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||||
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.accounts.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.profiles.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.accountDatas.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.threepids.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.openIDTokens.prepare(db, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.keyBackupVersions.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.keyBackups.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.devices.prepare(db, d.writer, serverName); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.loginTokens.prepare(db); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return d, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
|
||||||
func (d *Database) GetAccountByPassword(
|
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||||
}
|
}
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
|
||||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
|
||||||
func (d *Database) GetProfileByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) (*authtypes.Profile, error) {
|
|
||||||
return d.profiles.selectProfileByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
|
||||||
func (d *Database) SetAvatarURL(
|
|
||||||
ctx context.Context, localpart string, avatarURL string,
|
|
||||||
) error {
|
|
||||||
d.profilesMu.Lock()
|
|
||||||
defer d.profilesMu.Unlock()
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDisplayName updates the display name of the profile associated with the given
|
|
||||||
// localpart. Returns an error if something went wrong with the SQL query
|
|
||||||
func (d *Database) SetDisplayName(
|
|
||||||
ctx context.Context, localpart string, displayName string,
|
|
||||||
) error {
|
|
||||||
d.profilesMu.Lock()
|
|
||||||
defer d.profilesMu.Unlock()
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPassword sets the account password to the given hash.
|
|
||||||
func (d *Database) SetPassword(
|
|
||||||
ctx context.Context, localpart, plaintextPassword string,
|
|
||||||
) error {
|
|
||||||
hash, err := d.hashPassword(plaintextPassword)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
||||||
}
|
}
|
||||||
return d.writer.Do(nil, nil, func(txn *sql.Tx) error {
|
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
|
||||||
return d.accounts.updatePassword(ctx, localpart, hash)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
|
|
||||||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
|
||||||
// account already exists, it will return nil, ErrUserExists.
|
|
||||||
func (d *Database) CreateAccount(
|
|
||||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
|
||||||
) (acc *api.Account, err error) {
|
|
||||||
// Create one account at a time else we can get 'database is locked'.
|
|
||||||
d.profilesMu.Lock()
|
|
||||||
d.accountDatasMu.Lock()
|
|
||||||
d.accountsMu.Lock()
|
|
||||||
defer d.profilesMu.Unlock()
|
|
||||||
defer d.accountDatasMu.Unlock()
|
|
||||||
defer d.accountsMu.Unlock()
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
// For guest accounts, we create a new numeric local part
|
|
||||||
if accountType == api.AccountTypeGuest {
|
|
||||||
var numLocalpart int64
|
|
||||||
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
|
||||||
}
|
}
|
||||||
localpart = strconv.FormatInt(numLocalpart, 10)
|
keyBackupTable, err := NewSQLiteKeyBackupTable(db)
|
||||||
plaintextPassword = ""
|
|
||||||
appserviceID = ""
|
|
||||||
}
|
|
||||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// WARNING! This function assumes that the relevant mutexes have already
|
|
||||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
|
||||||
func (d *Database) createAccount(
|
|
||||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
var err error
|
|
||||||
var account *api.Account
|
|
||||||
// Generate a password hash if this is not a password-less user
|
|
||||||
hash := ""
|
|
||||||
if plaintextPassword != "" {
|
|
||||||
hash, err = d.hashPassword(plaintextPassword)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err)
|
||||||
}
|
}
|
||||||
}
|
keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db)
|
||||||
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
|
||||||
return nil, sqlutil.ErrUserExists
|
|
||||||
}
|
|
||||||
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
|
|
||||||
"global": {
|
|
||||||
"content": [],
|
|
||||||
"override": [],
|
|
||||||
"room": [],
|
|
||||||
"sender": [],
|
|
||||||
"underride": []
|
|
||||||
}
|
|
||||||
}`)); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return account, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveAccountData saves new account data for a given user and a given room.
|
|
||||||
// If the account data is not specific to a room, the room ID should be an empty string
|
|
||||||
// If an account data already exists for a given set (user, room, data type), it will
|
|
||||||
// update the corresponding row with the new content
|
|
||||||
// Returns a SQL error if there was an issue with the insertion/update
|
|
||||||
func (d *Database) SaveAccountData(
|
|
||||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
|
||||||
) error {
|
|
||||||
d.accountDatasMu.Lock()
|
|
||||||
defer d.accountDatasMu.Unlock()
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountData returns account data related to a given localpart
|
|
||||||
// If no account data could be found, returns an empty arrays
|
|
||||||
// Returns an error if there was an issue with the retrieval
|
|
||||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|
||||||
global map[string]json.RawMessage,
|
|
||||||
rooms map[string]map[string]json.RawMessage,
|
|
||||||
err error,
|
|
||||||
) {
|
|
||||||
return d.accountDatas.selectAccountData(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountDataByType returns account data matching a given
|
|
||||||
// localpart, room ID and type.
|
|
||||||
// If no account data could be found, returns nil
|
|
||||||
// Returns an error if there was an issue with the retrieval
|
|
||||||
func (d *Database) GetAccountDataByType(
|
|
||||||
ctx context.Context, localpart, roomID, dataType string,
|
|
||||||
) (data json.RawMessage, err error) {
|
|
||||||
return d.accountDatas.selectAccountDataByType(
|
|
||||||
ctx, localpart, roomID, dataType,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
|
||||||
func (d *Database) GetNewNumericLocalpart(
|
|
||||||
ctx context.Context,
|
|
||||||
) (int64, error) {
|
|
||||||
return d.accounts.selectNewNumericLocalpart(ctx, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
|
||||||
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
|
|
||||||
return string(hashBytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Err3PIDInUse is the error returned when trying to save an association involving
|
|
||||||
// a third-party identifier which is already associated to a local user.
|
|
||||||
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
|
||||||
|
|
||||||
// SaveThreePIDAssociation saves the association between a third party identifier
|
|
||||||
// and a local Matrix user (identified by the user's ID's local part).
|
|
||||||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
|
||||||
// Returns an error if there was a problem talking to the database.
|
|
||||||
func (d *Database) SaveThreePIDAssociation(
|
|
||||||
ctx context.Context, threepid, localpart, medium string,
|
|
||||||
) (err error) {
|
|
||||||
d.threepidsMu.Lock()
|
|
||||||
defer d.threepidsMu.Unlock()
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
user, err := d.threepids.selectLocalpartForThreePID(
|
|
||||||
ctx, txn, threepid, medium,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err)
|
||||||
}
|
}
|
||||||
|
loginTokenTable, err := NewSQLiteLoginTokenTable(db)
|
||||||
if len(user) > 0 {
|
|
||||||
return Err3PIDInUse
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveThreePIDAssociation removes the association involving a given third-party
|
|
||||||
// identifier.
|
|
||||||
// If no association exists involving this third-party identifier, returns nothing.
|
|
||||||
// If there was a problem talking to the database, returns an error.
|
|
||||||
func (d *Database) RemoveThreePIDAssociation(
|
|
||||||
ctx context.Context, threepid string, medium string,
|
|
||||||
) (err error) {
|
|
||||||
d.threepidsMu.Lock()
|
|
||||||
defer d.threepidsMu.Unlock()
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
|
|
||||||
// identifier.
|
|
||||||
// If no association involves the given third-party idenfitier, returns an empty
|
|
||||||
// string.
|
|
||||||
// Returns an error if there was a problem talking to the database.
|
|
||||||
func (d *Database) GetLocalpartForThreePID(
|
|
||||||
ctx context.Context, threepid string, medium string,
|
|
||||||
) (localpart string, err error) {
|
|
||||||
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
|
|
||||||
// a given local user.
|
|
||||||
// If no association is known for this user, returns an empty slice.
|
|
||||||
// Returns an error if there was an issue talking to the database.
|
|
||||||
func (d *Database) GetThreePIDsForLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
|
||||||
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckAccountAvailability checks if the username/localpart is already present
|
|
||||||
// in the database.
|
|
||||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
|
||||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
|
||||||
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
|
||||||
// This function assumes the request is authenticated or the account data is used only internally.
|
|
||||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
|
||||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
|
||||||
) (*api.Account, error) {
|
|
||||||
return d.accounts.selectAccountByLocalpart(ctx, localpart)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SearchProfiles returns all profiles where the provided localpart or display name
|
|
||||||
// match any part of the profiles in the database.
|
|
||||||
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
|
|
||||||
) ([]authtypes.Profile, error) {
|
|
||||||
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
|
||||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
|
||||||
return d.writer.Do(nil, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.accounts.deactivateAccount(ctx, localpart)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
|
||||||
func (d *Database) CreateOpenIDToken(
|
|
||||||
ctx context.Context,
|
|
||||||
token, localpart string,
|
|
||||||
) (int64, error) {
|
|
||||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
|
|
||||||
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
|
|
||||||
})
|
|
||||||
return expiresAtMS, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
|
|
||||||
func (d *Database) GetOpenIDTokenAttributes(
|
|
||||||
ctx context.Context,
|
|
||||||
token string,
|
|
||||||
) (*api.OpenIDTokenAttributes, error) {
|
|
||||||
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) CreateKeyBackup(
|
|
||||||
ctx context.Context, userID, algorithm string, authData json.RawMessage,
|
|
||||||
) (version string, err error) {
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) UpdateKeyBackupAuthData(
|
|
||||||
ctx context.Context, userID, version string, authData json.RawMessage,
|
|
||||||
) (err error) {
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) DeleteKeyBackup(
|
|
||||||
ctx context.Context, userID, version string,
|
|
||||||
) (exists bool, err error) {
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetKeyBackup(
|
|
||||||
ctx context.Context, userID, version string,
|
|
||||||
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetBackupKeys(
|
|
||||||
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
|
|
||||||
) (result map[string]map[string]api.KeyBackupSession, err error) {
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
if filterSessionID != "" {
|
|
||||||
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if filterRoomID != "" {
|
|
||||||
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) CountBackupKeys(
|
|
||||||
ctx context.Context, version, userID string,
|
|
||||||
) (count int64, err error) {
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
openIDTable, err := NewSQLiteOpenIDTable(db, serverName)
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint:nakedret
|
|
||||||
func (d *Database) UpsertBackupKeys(
|
|
||||||
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
|
|
||||||
) (count int64, etag string, err error) {
|
|
||||||
// wrap the following logic in a txn to ensure we atomically upload keys
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err)
|
||||||
}
|
}
|
||||||
if deleted {
|
profilesTable, err := NewSQLiteProfilesTable(db)
|
||||||
return fmt.Errorf("backup was deleted")
|
|
||||||
}
|
|
||||||
// pull out all keys for this (user_id, version)
|
|
||||||
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err)
|
||||||
}
|
}
|
||||||
|
threePIDTable, err := NewSQLiteThreePIDTable(db)
|
||||||
changed := false
|
|
||||||
// loop over all the new keys (which should be smaller than the set of backed up keys)
|
|
||||||
for _, newKey := range uploads {
|
|
||||||
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
|
|
||||||
existingRoom := existingKeys[newKey.RoomID]
|
|
||||||
if existingRoom != nil {
|
|
||||||
existingSession, ok := existingRoom[newKey.SessionID]
|
|
||||||
if ok {
|
|
||||||
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
|
|
||||||
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
|
|
||||||
changed = true
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err)
|
return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
|
||||||
}
|
}
|
||||||
}
|
return &shared.Database{
|
||||||
// if we shouldn't replace the key we do nothing with it
|
AccountDatas: accountDataTable,
|
||||||
continue
|
Accounts: accountsTable,
|
||||||
}
|
Devices: devicesTable,
|
||||||
}
|
KeyBackups: keyBackupTable,
|
||||||
// if we're here, either the room or session are new, either way, we insert
|
KeyBackupVersions: keyBackupVersionTable,
|
||||||
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
|
LoginTokens: loginTokenTable,
|
||||||
changed = true
|
OpenIDTokens: openIDTable,
|
||||||
if err != nil {
|
Profiles: profilesTable,
|
||||||
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err)
|
ThreePIDs: threePIDTable,
|
||||||
}
|
ServerName: serverName,
|
||||||
}
|
DB: db,
|
||||||
|
Writer: sqlutil.NewExclusiveWriter(),
|
||||||
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
|
LoginTokenLifetime: loginTokenLifetime,
|
||||||
if err != nil {
|
BcryptCost: bcryptCost,
|
||||||
return err
|
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||||
}
|
}, nil
|
||||||
if changed {
|
|
||||||
// update the etag
|
|
||||||
var newETag string
|
|
||||||
if oldETag == "" {
|
|
||||||
newETag = "1"
|
|
||||||
} else {
|
|
||||||
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse old etag: %s", err)
|
|
||||||
}
|
|
||||||
newETag = strconv.FormatInt(oldETagInt+1, 10)
|
|
||||||
}
|
|
||||||
etag = newETag
|
|
||||||
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
|
|
||||||
} else {
|
|
||||||
etag = oldETag
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByAccessToken returns the device matching the given access token.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByAccessToken(
|
|
||||||
ctx context.Context, token string,
|
|
||||||
) (*api.Device, error) {
|
|
||||||
return d.devices.selectDeviceByToken(ctx, token)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDeviceByID returns the device matching the given ID.
|
|
||||||
// Returns sql.ErrNoRows if no matching device was found.
|
|
||||||
func (d *Database) GetDeviceByID(
|
|
||||||
ctx context.Context, localpart, deviceID string,
|
|
||||||
) (*api.Device, error) {
|
|
||||||
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
|
||||||
func (d *Database) GetDevicesByLocalpart(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) ([]api.Device, error) {
|
|
||||||
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
|
||||||
return d.devices.selectDevicesByID(ctx, deviceIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDevice makes a new device associated with the given user ID localpart.
|
|
||||||
// If there is already a device with the same device ID for this user, that access token will be revoked
|
|
||||||
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
|
|
||||||
// an error will be returned.
|
|
||||||
// If no device ID is given one is generated.
|
|
||||||
// Returns the device on success.
|
|
||||||
func (d *Database) CreateDevice(
|
|
||||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
|
||||||
displayName *string, ipAddr, userAgent string,
|
|
||||||
) (dev *api.Device, returnErr error) {
|
|
||||||
if deviceID != nil {
|
|
||||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
// Revoke existing tokens for this device
|
|
||||||
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// We generate device IDs in a loop in case its already taken.
|
|
||||||
// We cap this at going round 5 times to ensure we don't spin forever
|
|
||||||
var newDeviceID string
|
|
||||||
for i := 1; i <= 5; i++ {
|
|
||||||
newDeviceID, returnErr = generateDeviceID()
|
|
||||||
if returnErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
var err error
|
|
||||||
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if returnErr == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateDeviceID creates a new device id. Returns an error if failed to generate
|
|
||||||
// random bytes.
|
|
||||||
func generateDeviceID() (string, error) {
|
|
||||||
b := make([]byte, deviceIDByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
// url-safe no padding
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDevice updates the given device with the display name.
|
|
||||||
// Returns SQL error if there are problems and nil on success.
|
|
||||||
func (d *Database) UpdateDevice(
|
|
||||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
|
||||||
) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevice revokes a device by deleting the entry in the database
|
|
||||||
// matching with the given device ID and user ID localpart.
|
|
||||||
// If the device doesn't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevice(
|
|
||||||
ctx context.Context, deviceID, localpart string,
|
|
||||||
) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDevices revokes one or more devices by deleting the entry in the database
|
|
||||||
// matching with the given device IDs and user ID localpart.
|
|
||||||
// If the devices don't exist, it will not return an error
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveDevices(
|
|
||||||
ctx context.Context, localpart string, devices []string,
|
|
||||||
) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllDevices revokes devices by deleting the entry in the
|
|
||||||
// database matching the given user ID localpart.
|
|
||||||
// If something went wrong during the deletion, it will return the SQL error.
|
|
||||||
func (d *Database) RemoveAllDevices(
|
|
||||||
ctx context.Context, localpart, exceptDeviceID string,
|
|
||||||
) (devices []api.Device, err error) {
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
|
|
||||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
|
||||||
// determined by the loginTokenLifetime given to the Database constructor.
|
|
||||||
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
|
|
||||||
tok, err := generateLoginToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
meta := &api.LoginTokenMetadata{
|
|
||||||
Token: tok,
|
|
||||||
Expiration: time.Now().Add(d.loginTokenLifetime),
|
|
||||||
}
|
|
||||||
|
|
||||||
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.loginTokens.insert(ctx, txn, meta, data)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return meta, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateLoginToken() (string, error) {
|
|
||||||
b := make([]byte, loginTokenByteLength)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
|
||||||
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
|
|
||||||
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
|
|
||||||
return d.loginTokens.deleteByToken(ctx, txn, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLoginTokenDataByToken returns the data associated with the given token.
|
|
||||||
// May return sql.ErrNoRows.
|
|
||||||
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
|
|
||||||
return d.loginTokens.selectByToken(ctx, token)
|
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
)
|
)
|
||||||
@ -60,13 +61,15 @@ type threepidStatements struct {
|
|||||||
deleteThreePIDStmt *sql.Stmt
|
deleteThreePIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
||||||
s.db = db
|
s := &threepidStatements{
|
||||||
_, err = db.Exec(threepidSchema)
|
db: db,
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
return sqlutil.StatementList{
|
_, err := db.Exec(threepidSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s, sqlutil.StatementList{
|
||||||
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
|
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
|
||||||
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
|
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
|
||||||
{&s.insertThreePIDStmt, insertThreePIDSQL},
|
{&s.insertThreePIDStmt, insertThreePIDSQL},
|
||||||
@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
|
|||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) selectLocalpartForThreePID(
|
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||||
) (localpart string, err error) {
|
) (localpart string, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||||
@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) selectThreePIDsForLocalpart(
|
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||||
ctx context.Context, localpart string,
|
ctx context.Context, localpart string,
|
||||||
) (threepids []authtypes.ThreePID, err error) {
|
) (threepids []authtypes.ThreePID, err error) {
|
||||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||||
@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
|
|||||||
return threepids, rows.Err()
|
return threepids, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) insertThreePID(
|
func (s *threepidStatements) InsertThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||||
@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID(
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *threepidStatements) deleteThreePID(
|
func (s *threepidStatements) DeleteThreePID(
|
||||||
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
|
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
|
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
|
||||||
_, err = stmt.ExecContext(ctx, threepid, medium)
|
_, err = stmt.ExecContext(ctx, threepid, medium)
|
||||||
|
95
userapi/storage/tables/interface.go
Normal file
95
userapi/storage/tables/interface.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package tables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||||
|
"github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AccountDataTable interface {
|
||||||
|
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
|
||||||
|
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
|
||||||
|
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccountsTable interface {
|
||||||
|
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
|
||||||
|
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
|
||||||
|
DeactivateAccount(ctx context.Context, localpart string) (err error)
|
||||||
|
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
|
||||||
|
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
|
||||||
|
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DevicesTable interface {
|
||||||
|
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
|
||||||
|
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
|
||||||
|
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
|
||||||
|
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
|
||||||
|
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
|
||||||
|
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
|
||||||
|
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
|
||||||
|
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
|
||||||
|
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
|
||||||
|
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyBackupTable interface {
|
||||||
|
CountKeys(ctx context.Context, txn *sql.Tx, userID, version string) (count int64, err error)
|
||||||
|
InsertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error)
|
||||||
|
UpdateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error)
|
||||||
|
SelectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error)
|
||||||
|
SelectKeysByRoomID(ctx context.Context, txn *sql.Tx, userID, version, roomID string) (map[string]map[string]api.KeyBackupSession, error)
|
||||||
|
SelectKeysByRoomIDAndSessionID(ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string) (map[string]map[string]api.KeyBackupSession, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type KeyBackupVersionTable interface {
|
||||||
|
InsertKeyBackup(ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string) (version string, err error)
|
||||||
|
UpdateKeyBackupAuthData(ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage) error
|
||||||
|
UpdateKeyBackupETag(ctx context.Context, txn *sql.Tx, userID, version, etag string) error
|
||||||
|
DeleteKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (bool, error)
|
||||||
|
SelectKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type LoginTokenTable interface {
|
||||||
|
InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error
|
||||||
|
DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error
|
||||||
|
SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenIDTable interface {
|
||||||
|
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
|
||||||
|
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProfileTable interface {
|
||||||
|
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
|
||||||
|
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
|
||||||
|
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
|
||||||
|
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
|
||||||
|
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ThreePIDTable interface {
|
||||||
|
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
|
||||||
|
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
|
||||||
|
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
|
||||||
|
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user