From 78032b3f4c0ca2d272180afb09142fce82313109 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 22 Jul 2019 15:05:38 +0100 Subject: [PATCH] Correctly create new device when device_id is passed to /login (#753) Fixes https://github.com/matrix-org/dendrite/issues/401 Currently when passing a `device_id` parameter to `/login`, which is [supposed](https://matrix.org/docs/spec/client_server/unstable#post-matrix-client-r0-login) to return a device with that ID set, it instead just generates a random `device_id` and hands that back to you. The code was already there to do this correctly, it looks like it had just been broken during some change. Hopefully sytest will prevent this from becoming broken again. --- .../auth/storage/devices/devices_table.go | 2 ++ clientapi/auth/storage/devices/storage.go | 2 +- clientapi/routing/login.go | 25 ++++++++--------- clientapi/routing/register.go | 28 ++++++++++++------- testfile | 2 ++ 5 files changed, 34 insertions(+), 25 deletions(-) diff --git a/clientapi/auth/storage/devices/devices_table.go b/clientapi/auth/storage/devices/devices_table.go index 96d6521d..60aa563a 100644 --- a/clientapi/auth/storage/devices/devices_table.go +++ b/clientapi/auth/storage/devices/devices_table.go @@ -169,6 +169,8 @@ func (s *devicesStatements) selectDeviceByToken( return &dev, err } +// selectDeviceByID retrieves a device from the database with the given user +// localpart and deviceID func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*authtypes.Device, error) { diff --git a/clientapi/auth/storage/devices/storage.go b/clientapi/auth/storage/devices/storage.go index 7032fe7b..82c8e97a 100644 --- a/clientapi/auth/storage/devices/storage.go +++ b/clientapi/auth/storage/devices/storage.go @@ -84,7 +84,7 @@ func (d *Database) CreateDevice( if deviceID != nil { returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { var err error - // Revoke existing token for this device + // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { return err } diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 2e2d409f..02d95815 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -18,7 +18,6 @@ import ( "net/http" "context" - "database/sql" "github.com/matrix-org/dendrite/clientapi/auth" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" @@ -42,10 +41,12 @@ type flow struct { } type passwordRequest struct { - User string `json:"user"` - Password string `json:"password"` + User string `json:"user"` + Password string `json:"password"` + // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") + // Thus a pointer is needed to differentiate between the two InitialDisplayName *string `json:"initial_device_display_name"` - DeviceID string `json:"device_id"` + DeviceID *string `json:"device_id"` } type loginResponse struct { @@ -110,7 +111,7 @@ func Login( return httputil.LogThenError(req, err) } - dev, err := getDevice(req.Context(), r, deviceDB, acc, localpart, token) + dev, err := getDevice(req.Context(), r, deviceDB, acc, token) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -134,20 +135,16 @@ func Login( } } -// check if device exists else create one +// getDevice returns a new or existing device func getDevice( ctx context.Context, r passwordRequest, deviceDB *devices.Database, acc *authtypes.Account, - localpart, token string, + token string, ) (dev *authtypes.Device, err error) { - dev, err = deviceDB.GetDeviceByID(ctx, localpart, r.DeviceID) - if err == sql.ErrNoRows { - // device doesn't exist, create one - dev, err = deviceDB.CreateDevice( - ctx, acc.Localpart, nil, token, r.InitialDisplayName, - ) - } + dev, err = deviceDB.CreateDevice( + ctx, acc.Localpart, r.DeviceID, token, r.InitialDisplayName, + ) return } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index fa15f4fc..c5a3d301 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -121,7 +121,10 @@ type registerRequest struct { // user-interactive auth params Auth authDict `json:"auth"` + // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") + // Thus a pointer is needed to differentiate between the two InitialDisplayName *string `json:"initial_device_display_name"` + DeviceID *string `json:"device_id"` // Prevent this user from logging in InhibitLogin common.WeakBoolean `json:"inhibit_login"` @@ -626,7 +629,7 @@ func handleApplicationServiceRegistration( // application service registration is entirely separate. return completeRegistration( req.Context(), accountDB, deviceDB, r.Username, "", appserviceID, - r.InhibitLogin, r.InitialDisplayName, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -646,7 +649,7 @@ func checkAndCompleteFlow( // This flow was completed, registration can continue return completeRegistration( req.Context(), accountDB, deviceDB, r.Username, r.Password, "", - r.InhibitLogin, r.InitialDisplayName, + r.InhibitLogin, r.InitialDisplayName, r.DeviceID, ) } @@ -697,10 +700,10 @@ func LegacyRegister( return util.MessageResponse(http.StatusForbidden, "HMAC incorrect") } - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) case authtypes.LoginTypeDummy: // there is nothing to do - return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil) + return completeRegistration(req.Context(), accountDB, deviceDB, r.Username, r.Password, "", false, nil, nil) default: return util.JSONResponse{ Code: http.StatusNotImplemented, @@ -738,13 +741,19 @@ func parseAndValidateLegacyLogin(req *http.Request, r *legacyRegisterRequest) *u return nil } +// completeRegistration runs some rudimentary checks against the submitted +// input, then if successful creates an account and a newly associated device +// We pass in each individual part of the request here instead of just passing a +// registerRequest, as this function serves requests encoded as both +// registerRequests and legacyRegisterRequests, which share some attributes but +// not all func completeRegistration( ctx context.Context, accountDB *accounts.Database, deviceDB *devices.Database, username, password, appserviceID string, inhibitLogin common.WeakBoolean, - displayName *string, + displayName, deviceID *string, ) util.JSONResponse { if username == "" { return util.JSONResponse{ @@ -773,6 +782,9 @@ func completeRegistration( } } + // Increment prometheus counter for created users + amtRegUsers.Inc() + // Check whether inhibit_login option is set. If so, don't create an access // token or a device for this user if inhibitLogin { @@ -793,8 +805,7 @@ func completeRegistration( } } - // TODO: Use the device ID in the request. - dev, err := deviceDB.CreateDevice(ctx, username, nil, token, displayName) + dev, err := deviceDB.CreateDevice(ctx, username, deviceID, token, displayName) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -802,9 +813,6 @@ func completeRegistration( } } - // Increment prometheus counter for created users - amtRegUsers.Inc() - return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ diff --git a/testfile b/testfile index 8a8225a8..1b2ad9e6 100644 --- a/testfile +++ b/testfile @@ -149,3 +149,5 @@ Typing events appear in incremental sync Typing events appear in gapped sync Inbound federation of state requires event_id as a mandatory paramater Inbound federation of state_ids requires event_id as a mandatory paramater +POST /register returns the same device_id as that in the request +POST /login returns the same device_id as that in the request