From a66a3b830c53c223cb939bd010d3769f02f6ccfb Mon Sep 17 00:00:00 2001 From: Kegsay Date: Wed, 17 Jun 2020 11:22:26 +0100 Subject: [PATCH] Make userapi control account creation entirely (#1139) This makes a chokepoint with which we can finally fix 'database is locked' errors on sqlite during account creation --- appservice/appservice.go | 1 + clientapi/auth/auth.go | 3 +- clientapi/auth/authtypes/account.go | 31 ------ clientapi/auth/storage/accounts/interface.go | 9 +- .../accounts/postgres/accounts_table.go | 10 +- .../auth/storage/accounts/postgres/storage.go | 11 ++- .../accounts/sqlite3/accounts_table.go | 10 +- .../auth/storage/accounts/sqlite3/storage.go | 11 ++- clientapi/routing/login.go | 5 +- clientapi/routing/register.go | 96 +++++++++++-------- clientapi/routing/routing.go | 4 +- userapi/api/api.go | 31 ++++-- userapi/internal/api.go | 22 ++++- 13 files changed, 131 insertions(+), 113 deletions(-) delete mode 100644 clientapi/auth/authtypes/account.go diff --git a/appservice/appservice.go b/appservice/appservice.go index 84a6a9b1..72869041 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -110,6 +110,7 @@ func generateAppServiceAccount( ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeUser, Localpart: as.SenderLocalpart, AppServiceID: as.ID, OnConflict: userapi.ConflictUpdate, diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index b8e40853..b4c39ae3 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -23,7 +23,6 @@ import ( "net/http" "strings" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" @@ -42,7 +41,7 @@ type DeviceDatabase interface { // AccountDatabase represents an account database. type AccountDatabase interface { // Look up the account matching the given localpart. - GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error) + GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) } // VerifyUserFromRequest authenticates the HTTP request, diff --git a/clientapi/auth/authtypes/account.go b/clientapi/auth/authtypes/account.go deleted file mode 100644 index fd3c15a8..00000000 --- a/clientapi/auth/authtypes/account.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2017 Vector Creations Ltd -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package authtypes - -import ( - "github.com/matrix-org/gomatrixserverlib" -) - -// Account represents a Matrix account on this home server. -type Account struct { - UserID string - Localpart string - ServerName gomatrixserverlib.ServerName - Profile *Profile - AppServiceID string - // TODO: Other flags like IsAdmin, IsGuest - // TODO: Devices - // TODO: Associations (e.g. with application services) -} diff --git a/clientapi/auth/storage/accounts/interface.go b/clientapi/auth/storage/accounts/interface.go index 3391ccbf..13e3e289 100644 --- a/clientapi/auth/storage/accounts/interface.go +++ b/clientapi/auth/storage/accounts/interface.go @@ -20,20 +20,21 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" ) type Database interface { internal.PartitionStorer - GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*authtypes.Account, error) + GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) - CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) + CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*api.Account, error) + CreateGuestAccount(ctx context.Context) (*api.Account, error) UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetRoomIDsByLocalPart(ctx context.Context, localpart string) ([]string, error) @@ -53,7 +54,7 @@ type Database interface { GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error) PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string) (*authtypes.Account, error) + GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) } // Err3PIDInUse is the error returned when trying to save an association involving diff --git a/clientapi/auth/storage/accounts/postgres/accounts_table.go b/clientapi/auth/storage/accounts/postgres/accounts_table.go index 85c1938a..931ffb73 100644 --- a/clientapi/auth/storage/accounts/postgres/accounts_table.go +++ b/clientapi/auth/storage/accounts/postgres/accounts_table.go @@ -19,8 +19,8 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" @@ -92,7 +92,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // on success. func (s *accountsStatements) insertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := txn.Stmt(s.insertAccountStmt) @@ -106,7 +106,7 @@ func (s *accountsStatements) insertAccount( return nil, err } - return &authtypes.Account{ + return &api.Account{ Localpart: localpart, UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, @@ -123,9 +123,9 @@ func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var appserviceIDPtr sql.NullString - var acc authtypes.Account + var acc api.Account stmt := s.selectAccountByLocalpartStmt err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go index fcb592ae..2b88cb70 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" @@ -84,7 +85,7 @@ func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serve // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, -) (*authtypes.Account, error) { +) (*api.Account, error) { hash, err := d.accounts.selectPasswordHash(ctx, localpart) if err != nil { return nil, err @@ -121,7 +122,7 @@ func (d *Database) SetDisplayName( // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var numLocalpart int64 numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) @@ -140,7 +141,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // account already exists, it will return nil, sqlutil.ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, -) (acc *authtypes.Account, err error) { +) (acc *api.Account, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err @@ -150,7 +151,7 @@ func (d *Database) CreateAccount( func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var err error // Generate a password hash if this is not a password-less user @@ -427,6 +428,6 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } diff --git a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go index fd6a09cd..768f536d 100644 --- a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go @@ -19,8 +19,8 @@ import ( "database/sql" "time" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" @@ -90,7 +90,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // on success. func (s *accountsStatements) insertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt @@ -104,7 +104,7 @@ func (s *accountsStatements) insertAccount( return nil, err } - return &authtypes.Account{ + return &api.Account{ Localpart: localpart, UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, @@ -121,9 +121,9 @@ func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) selectAccountByLocalpart( ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var appserviceIDPtr sql.NullString - var acc authtypes.Account + var acc api.Account stmt := s.selectAccountByLocalpartStmt err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index 44245a99..4dd755a7 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" // Import the sqlite3 database driver. @@ -89,7 +90,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( ctx context.Context, localpart, plaintextPassword string, -) (*authtypes.Account, error) { +) (*api.Account, error) { hash, err := d.accounts.selectPasswordHash(ctx, localpart) if err != nil { return nil, err @@ -126,7 +127,7 @@ func (d *Database) SetDisplayName( // CreateGuestAccount makes a new guest account and creates an empty profile // for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { // We need to lock so we sequentially create numeric localparts. If we don't, two calls to // this function will cause the same number to be selected and one will fail with 'database is locked' @@ -152,7 +153,7 @@ func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Accou // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, -) (acc *authtypes.Account, err error) { +) (acc *api.Account, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) return err @@ -162,7 +163,7 @@ func (d *Database) CreateAccount( func (d *Database) createAccount( ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, -) (*authtypes.Account, error) { +) (*api.Account, error) { var err error // Generate a password hash if this is not a password-less user hash := "" @@ -438,6 +439,6 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, -) (*authtypes.Account, error) { +) (*api.Account, error) { return d.accounts.selectAccountByLocalpart(ctx, localpart) } diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 2eb480ef..25231a3a 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -20,7 +20,6 @@ import ( "context" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" @@ -81,7 +80,7 @@ func Login( } } else if req.Method == http.MethodPost { var r passwordRequest - var acc *authtypes.Account + var acc *api.Account resErr := httputil.UnmarshalJSONRequest(req, &r) if resErr != nil { return *resErr @@ -156,7 +155,7 @@ func getDevice( ctx context.Context, r passwordRequest, deviceDB devices.Database, - acc *authtypes.Account, + acc *api.Account, token string, ) (dev *api.Device, err error) { dev, err = deviceDB.CreateDevice( diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 8988dbd0..fddf9253 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -34,15 +34,14 @@ import ( "github.com/matrix-org/dendrite/internal/config" "github.com/matrix-org/dendrite/internal/eventutil" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/storage/accounts" - "github.com/matrix-org/dendrite/clientapi/auth/storage/devices" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" + userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/tokens" "github.com/matrix-org/util" @@ -441,8 +440,8 @@ func validateApplicationService( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register func Register( req *http.Request, + userAPI userapi.UserInternalAPI, accountDB accounts.Database, - deviceDB devices.Database, cfg *config.Dendrite, ) util.JSONResponse { var r registerRequest @@ -451,7 +450,7 @@ func Register( return *resErr } if req.URL.Query().Get("kind") == "guest" { - return handleGuestRegistration(req, r, cfg, accountDB, deviceDB) + return handleGuestRegistration(req, r, cfg, userAPI) } // Retrieve or generate the sessionID @@ -507,17 +506,19 @@ func Register( "session_id": r.Auth.Session, }).Info("Processing registration request") - return handleRegistrationFlow(req, r, sessionID, cfg, accountDB, deviceDB) + return handleRegistrationFlow(req, r, sessionID, cfg, userAPI) } func handleGuestRegistration( req *http.Request, r registerRequest, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { - acc, err := accountDB.CreateGuestAccount(req.Context()) + var res userapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ + AccountType: userapi.AccountTypeGuest, + }, &res) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -526,8 +527,8 @@ func handleGuestRegistration( } token, err := tokens.GenerateLoginToken(tokens.TokenOptions{ ServerPrivateKey: cfg.Matrix.PrivateKey.Seed(), - ServerName: string(acc.ServerName), - UserID: acc.UserID, + ServerName: string(res.Account.ServerName), + UserID: res.Account.UserID, }) if err != nil { @@ -537,7 +538,12 @@ func handleGuestRegistration( } } //we don't allow guests to specify their own device_id - dev, err := deviceDB.CreateDevice(req.Context(), acc.Localpart, nil, token, r.InitialDisplayName) + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ + Localpart: res.Account.Localpart, + DeviceDisplayName: r.InitialDisplayName, + AccessToken: token, + }, &devRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -547,10 +553,10 @@ func handleGuestRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, - HomeServer: acc.ServerName, - DeviceID: dev.ID, + UserID: devRes.Device.UserID, + AccessToken: devRes.Device.AccessToken, + HomeServer: res.Account.ServerName, + DeviceID: devRes.Device.ID, }, } } @@ -563,8 +569,7 @@ func handleRegistrationFlow( r registerRequest, sessionID string, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { // TODO: Shared secret registration (create new user scripts) // TODO: Enable registration config flag @@ -615,7 +620,7 @@ func handleRegistrationFlow( // by whether the request contains an access token. if err == nil { return handleApplicationServiceRegistration( - accessToken, err, req, r, cfg, accountDB, deviceDB, + accessToken, err, req, r, cfg, userAPI, ) } @@ -626,7 +631,7 @@ func handleRegistrationFlow( // don't need a condition on that call since the registration is clearly // stated as being AS-related. return handleApplicationServiceRegistration( - accessToken, err, req, r, cfg, accountDB, deviceDB, + accessToken, err, req, r, cfg, userAPI, ) case authtypes.LoginTypeDummy: @@ -645,7 +650,7 @@ func handleRegistrationFlow( // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), - req, r, sessionID, cfg, accountDB, deviceDB) + req, r, sessionID, cfg, userAPI) } // handleApplicationServiceRegistration handles the registration of an @@ -662,8 +667,7 @@ func handleApplicationServiceRegistration( req *http.Request, r registerRequest, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { // Check if we previously had issues extracting the access token from the // request. @@ -687,7 +691,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, + req.Context(), userAPI, r.Username, "", appserviceID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -701,13 +705,12 @@ func checkAndCompleteFlow( r registerRequest, sessionID string, cfg *config.Dendrite, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, ) util.JSONResponse { if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), accountDB, deviceDB, r.Username, r.Password, "", + req.Context(), userAPI, r.Username, r.Password, "", r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -724,8 +727,7 @@ func checkAndCompleteFlow( // LegacyRegister process register requests from the legacy v1 API func LegacyRegister( req *http.Request, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, cfg *config.Dendrite, ) util.JSONResponse { var r legacyRegisterRequest @@ -760,10 +762,10 @@ func LegacyRegister( return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") } - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) + return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) + return completeRegistration(req.Context(), userAPI, r.Username, r.Password, "", false, nil, nil) default: return util.JSONResponse{ Code: http.StatusNotImplemented, @@ -809,8 +811,7 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u // not all func completeRegistration( ctx context.Context, - accountDB accounts.Database, - deviceDB devices.Database, + userAPI userapi.UserInternalAPI, username, password, appserviceID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, @@ -829,9 +830,16 @@ func completeRegistration( } } - acc, err := accountDB.CreateAccount(ctx, username, password, appserviceID) + var accRes userapi.PerformAccountCreationResponse + err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ + AppServiceID: appserviceID, + Localpart: username, + Password: password, + AccountType: userapi.AccountTypeUser, + OnConflict: userapi.ConflictAbort, + }, &accRes) if err != nil { - if errors.Is(err, sqlutil.ErrUserExists) { // user already exists + if _, ok := err.(*userapi.ErrorConflict); ok { // user already exists return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.UserInUse("Desired user ID is already taken."), @@ -852,8 +860,8 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: userutil.MakeUserID(username, acc.ServerName), - HomeServer: acc.ServerName, + UserID: userutil.MakeUserID(username, accRes.Account.ServerName), + HomeServer: accRes.Account.ServerName, }, } } @@ -866,7 +874,13 @@ func completeRegistration( } } - dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName) + var devRes userapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + Localpart: username, + AccessToken: token, + DeviceDisplayName: displayName, + DeviceID: deviceID, + }, &devRes) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -877,10 +891,10 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: dev.UserID, - AccessToken: dev.AccessToken, - HomeServer: acc.ServerName, - DeviceID: dev.ID, + UserID: devRes.Device.UserID, + AccessToken: devRes.Device.AccessToken, + HomeServer: accRes.Account.ServerName, + DeviceID: devRes.Device.ID, }, } } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 80d9ab66..5e8a606a 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -203,11 +203,11 @@ func Setup( ).Methods(http.MethodPut, http.MethodOptions) r0mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - return Register(req, accountDB, deviceDB, cfg) + return Register(req, userAPI, accountDB, cfg) })).Methods(http.MethodPost, http.MethodOptions) v1mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { - return LegacyRegister(req, accountDB, deviceDB, cfg) + return LegacyRegister(req, userAPI, cfg) })).Methods(http.MethodPost, http.MethodOptions) r0mux.Handle("/register/available", httputil.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse { diff --git a/userapi/api/api.go b/userapi/api/api.go index 34c74bb3..c953a5ba 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -89,16 +89,18 @@ type QueryProfileResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { - Localpart string - AppServiceID string - Password string + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + + AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. + Password string // optional: if missing then this account will be a passwordless account OnConflict Conflict } // PerformAccountCreationResponse is the response for PerformAccountCreation type PerformAccountCreationResponse struct { AccountCreated bool - UserID string + Account *Account } // PerformDeviceCreationRequest is the request for PerformDeviceCreation @@ -115,8 +117,7 @@ type PerformDeviceCreationRequest struct { // PerformDeviceCreationResponse is the response for PerformDeviceCreation type PerformDeviceCreationResponse struct { DeviceCreated bool - AccessToken string - DeviceID string + Device *Device } // Device represents a client's device (mobile, web, etc) @@ -134,6 +135,16 @@ type Device struct { DisplayName string } +// Account represents a Matrix account on this home server. +type Account struct { + UserID string + Localpart string + ServerName gomatrixserverlib.ServerName + AppServiceID string + // TODO: Other flags like IsAdmin, IsGuest + // TODO: Associations (e.g. with application services) +} + // ErrorForbidden is an error indicating that the supplied access token is forbidden type ErrorForbidden struct { Message string @@ -155,9 +166,17 @@ func (e *ErrorConflict) Error() string { // Conflict is an enum representing what to do when encountering conflicting when creating profiles/devices type Conflict int +// AccountType is an enum representing the kind of account +type AccountType int + const ( // ConflictUpdate will update matching records returning no error ConflictUpdate Conflict = 1 // ConflictAbort will reject the request with ErrorConflict ConflictAbort Conflict = 2 + + // AccountTypeUser indicates this is a user account + AccountTypeUser AccountType = 1 + // AccountTypeGuest indicates this is a guest account + AccountTypeGuest AccountType = 2 ) diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 1b34dc7b..3a413166 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -39,6 +39,15 @@ type UserInternalAPI struct { } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + if req.AccountType == api.AccountTypeGuest { + acc, err := a.AccountDB.CreateGuestAccount(ctx) + if err != nil { + return err + } + res.AccountCreated = true + res.Account = acc + return nil + } acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists @@ -51,12 +60,18 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } } } + // account already exists res.AccountCreated = false - res.UserID = fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName) + res.Account = &api.Account{ + AppServiceID: req.AppServiceID, + Localpart: req.Localpart, + ServerName: a.ServerName, + UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + } return nil } res.AccountCreated = true - res.UserID = acc.UserID + res.Account = acc return nil } func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { @@ -65,8 +80,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe return err } res.DeviceCreated = true - res.AccessToken = dev.AccessToken - res.DeviceID = dev.ID + res.Device = dev return nil }