diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index 839ca9e5..abfe830f 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -162,7 +162,7 @@ func AuthFallback( } // Success. Add recaptcha as a completed login flow - AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) serveSuccess() return nil diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go index 7ecab9d4..4426b7fd 100644 --- a/clientapi/routing/key_crosssigning.go +++ b/clientapi/routing/key_crosssigning.go @@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys( if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil { return *authErr } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) uploadReq.UserID = device.UserID keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes) diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 49951019..acac60fa 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -74,7 +74,7 @@ func Password( if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil { return *authErr } - AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. if resErr = validatePassword(r.NewPassword); resErr != nil { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index d00d9886..10cfa432 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -72,14 +72,19 @@ func init() { // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. type sessionsDict struct { - sync.Mutex + sync.RWMutex sessions map[string][]authtypes.LoginType + params map[string]registerRequest + timer map[string]*time.Timer } -// GetCompletedStages returns the completed stages for a session. -func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType { - d.Lock() - defer d.Unlock() +// defaultTimeout is the timeout used to clean up sessions +const defaultTimeOut = time.Minute * 5 + +// getCompletedStages returns the completed stages for a session. +func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType { + d.RLock() + defer d.RUnlock() if completedStages, ok := d.sessions[sessionID]; ok { return completedStages @@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp return make([]authtypes.LoginType, 0) } -func newSessionsDict() *sessionsDict { - return &sessionsDict{ - sessions: make(map[string][]authtypes.LoginType), +// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest +func (d *sessionsDict) addParams(sessionID string, r registerRequest) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + d.params[sessionID] = r +} + +func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) { + d.RLock() + defer d.RUnlock() + r, ok := d.params[sessionID] + return r, ok +} + +// deleteSession cleans up a given session, either because the registration completed +// successfully, or because a given timeout (default: 5min) was reached. +func (d *sessionsDict) deleteSession(sessionID string) { + d.Lock() + defer d.Unlock() + delete(d.params, sessionID) + delete(d.sessions, sessionID) + // stop the timer, e.g. because the registration was completed + if t, ok := d.timer[sessionID]; ok { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + delete(d.timer, sessionID) } } -// AddCompletedSessionStage records that a session has completed an auth stage. -func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) { - sessions.Lock() - defer sessions.Unlock() +func newSessionsDict() *sessionsDict { + return &sessionsDict{ + sessions: make(map[string][]authtypes.LoginType), + params: make(map[string]registerRequest), + timer: make(map[string]*time.Timer), + } +} - for _, completedStage := range sessions.sessions[sessionID] { +func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) { + d.Lock() + defer d.Unlock() + t, ok := d.timer[sessionID] + if ok { + if !t.Stop() { + <-t.C + } + t.Reset(duration) + return + } + d.timer[sessionID] = time.AfterFunc(duration, func() { + d.deleteSession(sessionID) + }) +} + +// addCompletedSessionStage records that a session has completed an auth stage +// also starts a timer to delete the session once done. +func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) { + d.startTimer(defaultTimeOut, sessionID) + d.Lock() + defer d.Unlock() + for _, completedStage := range d.sessions[sessionID] { if completedStage == stage { return } } - sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage) + d.sessions[sessionID] = append(sessions.sessions[sessionID], stage) } var ( - // TODO: Remove old sessions. Need to do so on a session-specific timeout. - // sessions stores the completed flow stages for all sessions. Referenced using their sessionID. sessions = newSessionsDict() validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) ) @@ -167,7 +223,7 @@ func newUserInteractiveResponse( params map[string]interface{}, ) userInteractiveResponse { return userInteractiveResponse{ - fs, sessions.GetCompletedStages(sessionID), params, sessionID, + fs, sessions.getCompletedStages(sessionID), params, sessionID, } } @@ -645,12 +701,12 @@ func handleRegistrationFlow( } // Add Recaptcha to the list of completed registration stages - AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) case authtypes.LoginTypeDummy: // there is nothing to do // Add Dummy to the list of completed registration stages - AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy) case "": // An empty auth type means that we want to fetch the available @@ -666,7 +722,7 @@ func handleRegistrationFlow( // Check if the user's registration flow has been completed successfully // A response with current registration flow and remaining available methods // will be returned if a flow has not been successfully completed yet - return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID), + return checkAndCompleteFlow(sessions.getCompletedStages(sessionID), req, r, sessionID, cfg, userAPI) } @@ -708,7 +764,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), + req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) } @@ -727,11 +783,11 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), + req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) } - + sessions.addParams(sessionID, r) // There are still more stages to complete. // Return the flows and those that have been completed. return util.JSONResponse{ @@ -750,11 +806,25 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.UserInternalAPI, - username, password, appserviceID, ipAddr, userAgent string, + username, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, ) util.JSONResponse { + var registrationOK bool + defer func() { + if registrationOK { + sessions.deleteSession(sessionID) + } + }() + + if data, ok := sessions.getParams(sessionID); ok { + username = data.Username + password = data.Password + deviceID = data.DeviceID + displayName = data.InitialDisplayName + inhibitLogin = data.InhibitLogin + } if username == "" { return util.JSONResponse{ Code: http.StatusBadRequest, @@ -795,6 +865,7 @@ func completeRegistration( // Check whether inhibit_login option is set. If so, don't create an access // token or a device for this user if inhibitLogin { + registrationOK = true return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ @@ -828,6 +899,7 @@ func completeRegistration( } } + registrationOK = true return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ @@ -976,5 +1048,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 1f615dc2..c6b7e61c 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -17,6 +17,7 @@ package routing import ( "regexp" "testing" + "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/setup/config" @@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) { func TestEmptyCompletedFlows(t *testing.T) { fakeEmptySessions := newSessionsDict() fakeSessionID := "aRandomSessionIDWhichDoesNotExist" - ret := fakeEmptySessions.GetCompletedStages(fakeSessionID) + ret := fakeEmptySessions.getCompletedStages(fakeSessionID) // check for [] if ret == nil || len(ret) != 0 { @@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) { t.Errorf("user_id should not have been valid: @_something_else:localhost") } } + +func TestSessionCleanUp(t *testing.T) { + s := newSessionsDict() + + t.Run("session is cleaned up after a while", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld" + // manually added, as s.addParams() would start the timer with the default timeout + s.params[dummySession] = registerRequest{Username: "Testing"} + s.startTimer(time.Millisecond, dummySession) + time.Sleep(time.Millisecond * 2) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) + + t.Run("session is deleted, once the registration completed", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld2" + s.startTimer(time.Minute, dummySession) + s.deleteSession(dummySession) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) + + t.Run("session timer is restarted after second call", func(t *testing.T) { + t.Parallel() + dummySession := "helloWorld3" + // the following will start a timer with the default timeout of 5min + s.addParams(dummySession, registerRequest{Username: "Testing"}) + s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha) + s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy) + s.getCompletedStages(dummySession) + // reset the timer with a lower timeout + s.startTimer(time.Millisecond, dummySession) + time.Sleep(time.Millisecond * 2) + if data, ok := s.getParams(dummySession); ok { + t.Errorf("expected session to be deleted: %+v", data) + } + }) +} \ No newline at end of file diff --git a/sytest-blacklist b/sytest-blacklist index 16abce8d..e8617dcd 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -24,6 +24,7 @@ Local device key changes get to remote servers with correct prev_id # Flakey Local device key changes appear in /keys/changes +/context/ with lazy_load_members filter works # we don't support groups Remove group category diff --git a/sytest-whitelist b/sytest-whitelist index d3144572..12522cfb 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -601,3 +601,7 @@ Can query remote device keys using POST after notification Device deletion propagates over federation Get left notifs in sync and /keys/changes when other user leaves Remote banned user is kicked and may not rejoin until unbanned +registration remembers parameters +registration accepts non-ascii passwords +registration with inhibit_login inhibits login +