2023-09-09 01:38:10 +01:00
|
|
|
package server
|
|
|
|
|
|
|
|
import (
|
2023-10-16 15:18:34 +01:00
|
|
|
"bytes"
|
2023-09-09 01:38:10 +01:00
|
|
|
"encoding/base64"
|
|
|
|
"github.com/1f349/tulip/database"
|
|
|
|
"github.com/1f349/tulip/pages"
|
2023-09-15 13:06:31 +01:00
|
|
|
"github.com/google/uuid"
|
2023-09-09 01:38:10 +01:00
|
|
|
"github.com/julienschmidt/httprouter"
|
2023-10-16 16:47:18 +01:00
|
|
|
"github.com/skip2/go-qrcode"
|
|
|
|
"github.com/xlzd/gotp"
|
2023-09-09 01:38:10 +01:00
|
|
|
"html/template"
|
2023-10-16 15:18:34 +01:00
|
|
|
"image/png"
|
2023-09-09 01:38:10 +01:00
|
|
|
"net/http"
|
2023-10-16 16:47:18 +01:00
|
|
|
"time"
|
2023-09-09 01:38:10 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
func (h *HttpServer) LoginOtpGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
|
|
|
if !auth.Data.NeedOtp {
|
|
|
|
h.SafeRedirect(rw, req)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
pages.RenderPageTemplate(rw, "login-otp", map[string]any{
|
2023-10-10 18:06:43 +01:00
|
|
|
"ServiceName": h.conf.ServiceName,
|
2023-09-15 13:06:31 +01:00
|
|
|
"Redirect": req.URL.Query().Get("redirect"),
|
2023-09-09 01:38:10 +01:00
|
|
|
})
|
|
|
|
}
|
|
|
|
|
|
|
|
func (h *HttpServer) LoginOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
|
|
|
if !auth.Data.NeedOtp {
|
|
|
|
http.Redirect(rw, req, "/", http.StatusFound)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
otpInput := req.FormValue("code")
|
2023-09-15 13:06:31 +01:00
|
|
|
if h.fetchAndValidateOtp(rw, auth.Data.ID, otpInput) {
|
2023-09-09 01:38:10 +01:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
auth.Data.NeedOtp = false
|
|
|
|
if auth.SaveSessionData() != nil {
|
|
|
|
http.Error(rw, "500 Internal Server Error: Failed to safe session", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
h.SafeRedirect(rw, req)
|
|
|
|
}
|
|
|
|
|
2023-09-15 13:06:31 +01:00
|
|
|
func (h *HttpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub uuid.UUID, code string) bool {
|
|
|
|
var hasOtp bool
|
2023-10-16 16:47:18 +01:00
|
|
|
var secret string
|
|
|
|
var digits int
|
2023-09-15 13:06:31 +01:00
|
|
|
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
|
|
|
hasOtp, err = tx.HasTwoFactor(sub)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if hasOtp {
|
2023-10-16 16:47:18 +01:00
|
|
|
secret, digits, err = tx.GetTwoFactor(sub)
|
2023-09-15 13:06:31 +01:00
|
|
|
}
|
|
|
|
return
|
|
|
|
}) {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
|
|
|
|
if hasOtp {
|
2023-10-16 16:47:18 +01:00
|
|
|
totp := gotp.NewTOTP(secret, digits, 30, nil)
|
|
|
|
if !verifyTotp(totp, code) {
|
2023-09-15 13:06:31 +01:00
|
|
|
http.Error(rw, "400 Bad Request: Invalid OTP code", http.StatusBadRequest)
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
2023-10-16 18:20:31 +01:00
|
|
|
func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
2023-10-16 16:47:18 +01:00
|
|
|
var digits int
|
2023-09-09 01:38:10 +01:00
|
|
|
switch req.URL.Query().Get("digits") {
|
|
|
|
case "6":
|
|
|
|
digits = 6
|
|
|
|
case "7":
|
|
|
|
digits = 7
|
|
|
|
case "8":
|
|
|
|
digits = 8
|
|
|
|
default:
|
|
|
|
http.Error(rw, "400 Bad Request: Invalid number of digits for OTP code", http.StatusBadRequest)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-10-16 18:20:31 +01:00
|
|
|
secret := req.FormValue("secret")
|
|
|
|
if !gotp.IsSecretValid(secret) {
|
|
|
|
http.Error(rw, "400 Bad Request: Invalid secret", http.StatusBadRequest)
|
2023-10-16 16:47:18 +01:00
|
|
|
return
|
2023-09-09 01:38:10 +01:00
|
|
|
}
|
|
|
|
|
2023-10-16 16:47:18 +01:00
|
|
|
if secret == "" {
|
2023-10-16 18:20:31 +01:00
|
|
|
// get user email
|
|
|
|
var email string
|
|
|
|
if h.DbTx(rw, func(tx *database.Tx) error {
|
|
|
|
var err error
|
|
|
|
email, err = tx.GetUserEmail(auth.Data.ID)
|
|
|
|
return err
|
|
|
|
}) {
|
|
|
|
return
|
|
|
|
}
|
2023-09-09 01:38:10 +01:00
|
|
|
|
2023-10-16 18:20:31 +01:00
|
|
|
secret = gotp.RandomSecret(64)
|
|
|
|
if secret == "" {
|
|
|
|
http.Error(rw, "500 Internal Server Error: failed to generate OTP secret", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
totp := gotp.NewTOTP(secret, digits, 30, nil)
|
|
|
|
otpUri := totp.ProvisioningUri(email, h.conf.OtpIssuer)
|
|
|
|
code, err := qrcode.New(otpUri, qrcode.Medium)
|
|
|
|
if err != nil {
|
|
|
|
http.Error(rw, "500 Internal Server Error: failed to generate QR code", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
|
|
|
qrImg := code.Image(60 * 4)
|
|
|
|
qrBounds := qrImg.Bounds()
|
|
|
|
qrWidth := qrBounds.Dx()
|
2023-09-09 01:38:10 +01:00
|
|
|
|
2023-10-16 18:20:31 +01:00
|
|
|
qrBuf := new(bytes.Buffer)
|
|
|
|
if png.Encode(qrBuf, qrImg) != nil {
|
|
|
|
http.Error(rw, "500 Internal Server Error: failed to generate PNG image of QR code", http.StatusInternalServerError)
|
|
|
|
return
|
|
|
|
}
|
2023-10-16 16:47:18 +01:00
|
|
|
|
2023-10-16 18:20:31 +01:00
|
|
|
// render page
|
|
|
|
pages.RenderPageTemplate(rw, "edit-otp", map[string]any{
|
|
|
|
"ServiceName": h.conf.ServiceName,
|
|
|
|
"OtpQr": template.URL("data:qrImg/png;base64," + base64.StdEncoding.EncodeToString(qrBuf.Bytes())),
|
|
|
|
"QrWidth": qrWidth,
|
|
|
|
"OtpUrl": otpUri,
|
|
|
|
"OtpSecret": secret,
|
|
|
|
"OtpDigits": digits,
|
|
|
|
})
|
2023-09-09 01:38:10 +01:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-10-16 16:47:18 +01:00
|
|
|
totp := gotp.NewTOTP(secret, digits, 30, nil)
|
|
|
|
|
|
|
|
if !verifyTotp(totp, req.FormValue("code")) {
|
|
|
|
http.Error(rw, "400 Bad Request: invalid OTP code", http.StatusBadRequest)
|
2023-09-09 01:38:10 +01:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-10-16 18:20:31 +01:00
|
|
|
if h.DbTx(rw, func(tx *database.Tx) error {
|
|
|
|
return tx.SetTwoFactor(auth.Data.ID, secret, digits)
|
|
|
|
}) {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2023-09-15 13:06:31 +01:00
|
|
|
http.Redirect(rw, req, "/", http.StatusFound)
|
2023-09-09 01:38:10 +01:00
|
|
|
}
|
2023-10-16 16:47:18 +01:00
|
|
|
|
|
|
|
func verifyTotp(totp *gotp.TOTP, code string) bool {
|
|
|
|
t := time.Now()
|
|
|
|
if totp.VerifyTime(code, t) {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
if totp.VerifyTime(code, t.Add(-30*time.Second)) {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
return totp.VerifyTime(code, t.Add(30*time.Second))
|
|
|
|
}
|