diff --git a/pages/index.go.html b/pages/index.go.html index f5551b8..737a71e 100644 --- a/pages/index.go.html +++ b/pages/index.go.html @@ -35,7 +35,7 @@
-
+ diff --git a/server/otp.go b/server/otp.go index 4856022..1d51203 100644 --- a/server/otp.go +++ b/server/otp.go @@ -75,66 +75,9 @@ func (h *HttpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub uuid.UUID, return false } -func (h *HttpServer) EditOtpGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { - var digits int - switch req.URL.Query().Get("digits") { - case "6": - digits = 6 - case "7": - digits = 7 - case "8": - digits = 8 - default: - http.Error(rw, "400 Bad Request: Invalid number of digits for OTP code", http.StatusBadRequest) - return - } - - // get user email - var email string - if h.DbTx(rw, func(tx *database.Tx) error { - var err error - email, err = tx.GetUserEmail(auth.Data.ID) - return err - }) { - return - } - - secret := gotp.RandomSecret(64) - if secret == "" { - http.Error(rw, "500 Internal Server Error: failed to generate OTP secret", http.StatusInternalServerError) - return - } - totp := gotp.NewTOTP(secret, digits, 30, nil) - otpUri := totp.ProvisioningUri(email, h.conf.OtpIssuer) - code, err := qrcode.New(otpUri, qrcode.Medium) - if err != nil { - http.Error(rw, "500 Internal Server Error: failed to generate QR code", http.StatusInternalServerError) - return - } - qrImg := code.Image(60 * 4) - qrBounds := qrImg.Bounds() - qrWidth := qrBounds.Dx() - - qrBuf := new(bytes.Buffer) - if png.Encode(qrBuf, qrImg) != nil { - http.Error(rw, "500 Internal Server Error: failed to generate PNG image of QR code", http.StatusInternalServerError) - return - } - - // render page - pages.RenderPageTemplate(rw, "edit-otp", map[string]any{ - "ServiceName": h.conf.ServiceName, - "OtpQr": template.URL("data:qrImg/png;base64," + base64.StdEncoding.EncodeToString(qrBuf.Bytes())), - "QrWidth": qrWidth, - "OtpUrl": otpUri, - "OtpSecret": secret, - "OtpDigits": digits, - }) -} - func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) { var digits int - switch req.FormValue("digits") { + switch req.URL.Query().Get("digits") { case "6": digits = 6 case "7": @@ -152,6 +95,51 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht return } + if secret == "" { + // get user email + var email string + if h.DbTx(rw, func(tx *database.Tx) error { + var err error + email, err = tx.GetUserEmail(auth.Data.ID) + return err + }) { + return + } + + secret = gotp.RandomSecret(64) + if secret == "" { + http.Error(rw, "500 Internal Server Error: failed to generate OTP secret", http.StatusInternalServerError) + return + } + totp := gotp.NewTOTP(secret, digits, 30, nil) + otpUri := totp.ProvisioningUri(email, h.conf.OtpIssuer) + code, err := qrcode.New(otpUri, qrcode.Medium) + if err != nil { + http.Error(rw, "500 Internal Server Error: failed to generate QR code", http.StatusInternalServerError) + return + } + qrImg := code.Image(60 * 4) + qrBounds := qrImg.Bounds() + qrWidth := qrBounds.Dx() + + qrBuf := new(bytes.Buffer) + if png.Encode(qrBuf, qrImg) != nil { + http.Error(rw, "500 Internal Server Error: failed to generate PNG image of QR code", http.StatusInternalServerError) + return + } + + // render page + pages.RenderPageTemplate(rw, "edit-otp", map[string]any{ + "ServiceName": h.conf.ServiceName, + "OtpQr": template.URL("data:qrImg/png;base64," + base64.StdEncoding.EncodeToString(qrBuf.Bytes())), + "QrWidth": qrWidth, + "OtpUrl": otpUri, + "OtpSecret": secret, + "OtpDigits": digits, + }) + return + } + totp := gotp.NewTOTP(secret, digits, 30, nil) if !verifyTotp(totp, req.FormValue("code")) { @@ -159,6 +147,12 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht return } + if h.DbTx(rw, func(tx *database.Tx) error { + return tx.SetTwoFactor(auth.Data.ID, secret, digits) + }) { + return + } + http.Redirect(rw, req, "/", http.StatusFound) } diff --git a/server/server.go b/server/server.go index 0185e28..69705d3 100644 --- a/server/server.go +++ b/server/server.go @@ -168,7 +168,6 @@ func NewHttpServer(conf Conf, db *database.DB, privKey []byte) *http.Server { // edit profile pages r.GET("/edit", RequireAuthentication(hs.EditGet)) r.POST("/edit", RequireAuthentication(hs.EditPost)) - r.GET("/edit/otp", RequireAuthentication(hs.EditOtpGet)) r.POST("/edit/otp", RequireAuthentication(hs.EditOtpPost)) // management pages