2024-10-06 21:30:39 +01:00
|
|
|
package auth
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2024-10-25 15:08:56 +01:00
|
|
|
"errors"
|
2024-10-06 21:30:39 +01:00
|
|
|
"github.com/1f349/lavender/database"
|
|
|
|
"github.com/xlzd/gotp"
|
|
|
|
"net/http"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// isDigitsSupported reports whether digits is an OTP code length this
// package accepts: 6, 7 or 8.
func isDigitsSupported(digits int64) bool {
	switch digits {
	case 6, 7, 8:
		return true
	default:
		return false
	}
}
|
|
|
|
|
|
|
|
type otpLoginDB interface {
|
2024-10-25 15:08:56 +01:00
|
|
|
GetOtp(ctx context.Context, subject string) (database.GetOtpRow, error)
|
2024-10-06 21:30:39 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Compile-time assertion that *OtpLogin satisfies the Provider interface.
var _ Provider = (*OtpLogin)(nil)
|
|
|
|
|
|
|
|
type OtpLogin struct {
|
2024-10-25 15:08:56 +01:00
|
|
|
DB otpLoginDB
|
2024-10-06 21:30:39 +01:00
|
|
|
}
|
|
|
|
|
2024-10-25 15:08:56 +01:00
|
|
|
// Factor returns FactorSecond for the OTP provider.
func (o *OtpLogin) Factor() Factor {
	return FactorSecond
}
|
2024-10-06 21:30:39 +01:00
|
|
|
|
2024-10-25 15:08:56 +01:00
|
|
|
// Name returns the identifier string of this provider.
//
// NOTE(review): the value is "basic", which looks like it may have been
// copied from another provider; confirm against the callers that key on
// Name before changing it.
func (o *OtpLogin) Name() string {
	return "basic"
}
|
2024-10-06 21:30:39 +01:00
|
|
|
|
2024-10-25 15:08:56 +01:00
|
|
|
func (o *OtpLogin) RenderData(_ context.Context, _ *http.Request, user *database.User, data map[string]any) error {
|
|
|
|
if user == nil || user.Subject == "" {
|
2024-10-06 21:30:39 +01:00
|
|
|
return ErrRequiresPreviousFactor
|
|
|
|
}
|
|
|
|
if user.OtpSecret == "" || !isDigitsSupported(user.OtpDigits) {
|
|
|
|
return ErrUserDoesNotSupportFactor
|
|
|
|
}
|
|
|
|
|
|
|
|
// no need to provide render data
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2024-10-25 15:08:56 +01:00
|
|
|
func (o *OtpLogin) AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error {
|
2024-10-06 21:30:39 +01:00
|
|
|
if user == nil || user.Subject == "" {
|
|
|
|
return ErrRequiresPreviousFactor
|
|
|
|
}
|
|
|
|
if user.OtpSecret == "" || !isDigitsSupported(user.OtpDigits) {
|
|
|
|
return ErrUserDoesNotSupportFactor
|
|
|
|
}
|
|
|
|
|
|
|
|
code := req.FormValue("code")
|
|
|
|
|
2024-10-25 15:08:56 +01:00
|
|
|
if !validateTotp(user.OtpSecret, int(user.OtpDigits), code) {
|
2024-10-06 21:30:39 +01:00
|
|
|
return BasicUserSafeError(http.StatusBadRequest, "invalid OTP code")
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2024-10-25 15:08:56 +01:00
|
|
|
// ErrInvalidOtpCode is returned by VerifyOtpCode when the supplied code does
// not validate against the stored OTP settings.
var ErrInvalidOtpCode = errors.New("invalid OTP code")
|
|
|
|
|
|
|
|
func (o *OtpLogin) VerifyOtpCode(ctx context.Context, subject, code string) error {
|
|
|
|
otp, err := o.DB.GetOtp(ctx, subject)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if !validateTotp(otp.OtpSecret, int(otp.OtpDigits), code) {
|
|
|
|
return ErrInvalidOtpCode
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func validateTotp(secret string, digits int, code string) bool {
|
|
|
|
totp := gotp.NewTOTP(secret, int(digits), 30, nil)
|
|
|
|
return verifyTotp(totp, code)
|
|
|
|
}
|
|
|
|
|
2024-10-06 21:30:39 +01:00
|
|
|
func verifyTotp(totp *gotp.TOTP, code string) bool {
|
|
|
|
t := time.Now()
|
|
|
|
if totp.VerifyTime(code, t) {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
if totp.VerifyTime(code, t.Add(-30*time.Second)) {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
return totp.VerifyTime(code, t.Add(30*time.Second))
|
|
|
|
}
|