diff --git a/database/tx.go b/database/tx.go index 507ac00..d0023ed 100644 --- a/database/tx.go +++ b/database/tx.go @@ -167,6 +167,10 @@ WHERE subject = ?`, } func (t *Tx) SetTwoFactor(sub uuid.UUID, secret string, digits int) error { + if secret == "" && digits == 0 { + _, err := t.tx.Exec(`DELETE FROM otp WHERE otp.subject = ?`, sub.String()) + return err + } _, err := t.tx.Exec(`INSERT INTO otp(subject, secret, digits) VALUES (?, ?, ?) ON CONFLICT(subject) DO UPDATE SET secret = excluded.secret, digits = excluded.digits`, sub.String(), secret, digits) return err } diff --git a/pages/remove-otp.go.html b/pages/remove-otp.go.html new file mode 100644 index 0000000..8f37d83 --- /dev/null +++ b/pages/remove-otp.go.html @@ -0,0 +1,21 @@ + + + + {{.ServiceName}} + + +
+

{{.ServiceName}}

+
+
+
+ +
+ + +
+ +
+
+ + diff --git a/server/I-am-just-testing.go b/server/I-am-just-testing.go new file mode 100644 index 0000000..372f82f --- /dev/null +++ b/server/I-am-just-testing.go @@ -0,0 +1,49 @@ +package server + +import ( + "context" + "errors" + "fmt" + "github.com/go-oauth2/oauth2/v4" +) + +var _ oauth2.TokenStore = &TestingStruct{} + +type TestingStruct struct { +} + +func (t TestingStruct) Create(ctx context.Context, info oauth2.TokenInfo) error { + fmt.Println(info.GetAccessExpiresIn()) + fmt.Println(info.GetRefreshExpiresIn()) + return errors.New("error") +} + +func (t TestingStruct) RemoveByCode(ctx context.Context, code string) error { + //TODO implement me + panic("implement me") +} + +func (t TestingStruct) RemoveByAccess(ctx context.Context, access string) error { + //TODO implement me + panic("implement me") +} + +func (t TestingStruct) RemoveByRefresh(ctx context.Context, refresh string) error { + //TODO implement me + panic("implement me") +} + +func (t TestingStruct) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) { + //TODO implement me + panic("implement me") +} + +func (t TestingStruct) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) { + //TODO implement me + panic("implement me") +} + +func (t TestingStruct) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) { + //TODO implement me + panic("implement me") +} diff --git a/server/otp.go b/server/otp.go index d0a5515..f9e4f2c 100644 --- a/server/otp.go +++ b/server/otp.go @@ -76,6 +76,30 @@ func (h *HttpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub uuid.UUID, } func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { + if req.Method == http.MethodPost && req.FormValue("remove") == "1" { + if !req.Form.Has("code") { + // render page + pages.RenderPageTemplate(rw, "remove-otp", map[string]any{ + "ServiceName": h.conf.ServiceName, + }) + return + } + + otpInput := req.Form.Get("code") + if h.fetchAndValidateOtp(rw, auth.Data.ID, otpInput) { + return + } + + if h.DbTx(rw, func(tx *database.Tx) error { + return tx.SetTwoFactor(auth.Data.ID, "", 0) + }) { + return + } + + http.Redirect(rw, req, "/", http.StatusFound) + return + } + var digits int switch req.FormValue("digits") { case "6":