diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 6caf7dcd..82ecf674 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -102,8 +102,7 @@ type userInteractiveFlow struct { // the user already has a valid access token, but we want to double-check // that it isn't stolen by re-authenticating them. type UserInteractive struct { - Completed []string - Flows []userInteractiveFlow + Flows []userInteractiveFlow // Map of login type to implementation Types map[string]Type // Map of session ID to completed login types, will need to be extended in future @@ -116,7 +115,6 @@ func NewUserInteractive(userAccountAPI api.UserLoginAPI, cfg *config.ClientAPI) Config: cfg, } return &UserInteractive{ - Completed: []string{}, Flows: []userInteractiveFlow{ { Stages: []string{typePassword.Name()}, @@ -140,7 +138,6 @@ func (u *UserInteractive) IsSingleStageFlow(authType string) bool { func (u *UserInteractive) AddCompletedStage(sessionID, authType string) { // TODO: Handle multi-stage flows - u.Completed = append(u.Completed, authType) delete(u.Sessions, sessionID) } @@ -157,7 +154,7 @@ func (u *UserInteractive) Challenge(sessionID string) *util.JSONResponse { return &util.JSONResponse{ Code: 401, JSON: Challenge{ - Completed: u.Completed, + Completed: u.Sessions[sessionID], Flows: u.Flows, Session: sessionID, Params: make(map[string]interface{}), diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 262e4810..001b1a6d 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -187,3 +187,38 @@ func TestUserInteractivePasswordBadLogin(t *testing.T) { } } } + +func TestUserInteractive_AddCompletedStage(t *testing.T) { + tests := []struct { + name string + sessionID string + }{ + { + name: "first user", + sessionID: util.RandomString(8), + }, + { + name: "second user", + sessionID: util.RandomString(8), + }, + { + name: "third user", + sessionID: util.RandomString(8), + }, + } + u := setup() + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, resp := u.Verify(ctx, []byte("{}"), nil) + challenge, ok := resp.JSON.(Challenge) + if !ok { + t.Fatalf("expected a Challenge, got %T", resp.JSON) + } + if len(challenge.Completed) > 0 { + t.Fatalf("expected 0 completed stages, got %d", len(challenge.Completed)) + } + u.AddCompletedStage(tt.sessionID, "") + }) + } +}