mirror of
https://github.com/1f349/lavender.git
synced 2024-12-22 15:44:07 +00:00
Start new auth interfaces
This commit is contained in:
parent
7e5a8b9921
commit
2171cece75
87
auth/auth.go
87
auth/auth.go
@ -1,11 +1,88 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import "github.com/1f349/lavender/database"
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
type LoginProvider interface {
|
type Factor byte
|
||||||
AttemptLogin(username, password string) (database.User, error)
|
|
||||||
|
const (
|
||||||
|
FactorFirst Factor = 1 << iota
|
||||||
|
FactorSecond
|
||||||
|
// FactorAuthorized defines the "authorized" state of a session
|
||||||
|
FactorAuthorized
|
||||||
|
)
|
||||||
|
|
||||||
|
type Provider interface {
|
||||||
|
// Factor defines the factors potentially supported by the provider
|
||||||
|
// Some factors might be unavailable due to user preference
|
||||||
|
Factor() Factor
|
||||||
|
|
||||||
|
// Name defines a string value for the provider, useful for template switching
|
||||||
|
Name() string
|
||||||
|
|
||||||
|
// RenderData stores values to send to the templating function
|
||||||
|
RenderData(ctx context.Context, req *http.Request, user *database.User, data map[string]any) error
|
||||||
|
|
||||||
|
// AttemptLogin processes the login request
|
||||||
|
AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthProvider interface {
|
// ErrRequiresSecondFactor notifies the ServeHTTP function to ask for another factor
|
||||||
AttemptLogin(username string) (database.User, error)
|
var ErrRequiresSecondFactor = errors.New("requires second factor")
|
||||||
|
|
||||||
|
// ErrRequiresPreviousFactor is a generic error for providers which require a previous factor
|
||||||
|
var ErrRequiresPreviousFactor = errors.New("requires previous factor")
|
||||||
|
|
||||||
|
// ErrUserDoesNotSupportFactor is a generic error for providers with are unable to support the user
|
||||||
|
var ErrUserDoesNotSupportFactor = errors.New("user does not support factor")
|
||||||
|
|
||||||
|
type UserSafeError struct {
|
||||||
|
Display string
|
||||||
|
Code int
|
||||||
|
Internal error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e UserSafeError) Error() string {
|
||||||
|
return fmt.Sprintf("%s [%d]: %v", e.Display, e.Code, e.Internal)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e UserSafeError) Unwrap() error {
|
||||||
|
return e.Internal
|
||||||
|
}
|
||||||
|
|
||||||
|
func BasicUserSafeError(code int, message string) UserSafeError {
|
||||||
|
return UserSafeError{
|
||||||
|
Code: code,
|
||||||
|
Display: message,
|
||||||
|
Internal: errors.New(message),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AdminSafeError(inner error) UserSafeError {
|
||||||
|
return UserSafeError{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
Display: "Internal server error",
|
||||||
|
Internal: inner,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type lookupUserDB interface {
|
||||||
|
GetUser(ctx context.Context, subject string) (database.User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func lookupUser(ctx context.Context, db lookupUserDB, subject string, resolvesTwoFactor bool, user *database.User) error {
|
||||||
|
getUser, err := db.GetUser(ctx, subject)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*user = getUser
|
||||||
|
if user.NeedFactor && !resolvesTwoFactor {
|
||||||
|
return ErrRequiresSecondFactor
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1 +1,49 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type basicLoginDB interface {
|
||||||
|
lookupUserDB
|
||||||
|
CheckLogin(ctx context.Context, un, pw string) (database.CheckLoginResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Provider = (*BasicLogin)(nil)
|
||||||
|
|
||||||
|
type BasicLogin struct {
|
||||||
|
DB basicLoginDB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BasicLogin) Factor() Factor {
|
||||||
|
return FactorFirst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BasicLogin) Name() string { return "basic" }
|
||||||
|
|
||||||
|
func (b *BasicLogin) RenderData(ctx context.Context, req *http.Request, user *database.User, data map[string]any) error {
|
||||||
|
data["username"] = req.FormValue("username")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BasicLogin) AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error {
|
||||||
|
un := req.FormValue("username")
|
||||||
|
pw := req.FormValue("password")
|
||||||
|
if len(pw) < 8 {
|
||||||
|
return BasicUserSafeError(http.StatusBadRequest, "Password too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
login, err := b.DB.CheckLogin(ctx, un, pw)
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return lookupUser(ctx, b.DB, login.Subject, false, user)
|
||||||
|
case errors.Is(err, sql.ErrNoRows):
|
||||||
|
return BasicUserSafeError(http.StatusForbidden, "Username or password is invalid")
|
||||||
|
default:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
70
auth/otp.go
Normal file
70
auth/otp.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/1f349/lavender/database"
|
||||||
|
"github.com/xlzd/gotp"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isDigitsSupported(digits int64) bool {
|
||||||
|
return digits >= 6 && digits <= 8
|
||||||
|
}
|
||||||
|
|
||||||
|
type otpLoginDB interface {
|
||||||
|
lookupUserDB
|
||||||
|
CheckLogin(ctx context.Context, un, pw string) (database.CheckLoginResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Provider = (*OtpLogin)(nil)
|
||||||
|
|
||||||
|
type OtpLogin struct {
|
||||||
|
db otpLoginDB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *OtpLogin) Factor() Factor {
|
||||||
|
return FactorSecond
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *OtpLogin) Name() string { return "basic" }
|
||||||
|
|
||||||
|
func (b *OtpLogin) RenderData(_ context.Context, _ *http.Request, user *database.User, data map[string]any) error {
|
||||||
|
if user.Subject == "" {
|
||||||
|
return ErrRequiresPreviousFactor
|
||||||
|
}
|
||||||
|
if user.OtpSecret == "" || !isDigitsSupported(user.OtpDigits) {
|
||||||
|
return ErrUserDoesNotSupportFactor
|
||||||
|
}
|
||||||
|
|
||||||
|
// no need to provide render data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *OtpLogin) AttemptLogin(ctx context.Context, req *http.Request, user *database.User) error {
|
||||||
|
if user == nil || user.Subject == "" {
|
||||||
|
return ErrRequiresPreviousFactor
|
||||||
|
}
|
||||||
|
if user.OtpSecret == "" || !isDigitsSupported(user.OtpDigits) {
|
||||||
|
return ErrUserDoesNotSupportFactor
|
||||||
|
}
|
||||||
|
|
||||||
|
code := req.FormValue("code")
|
||||||
|
|
||||||
|
totp := gotp.NewTOTP(user.OtpSecret, int(user.OtpDigits), 30, nil)
|
||||||
|
if !verifyTotp(totp, code) {
|
||||||
|
return BasicUserSafeError(http.StatusBadRequest, "invalid OTP code")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyTotp(totp *gotp.TOTP, code string) bool {
|
||||||
|
t := time.Now()
|
||||||
|
if totp.VerifyTime(code, t) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if totp.VerifyTime(code, t.Add(-30*time.Second)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return totp.VerifyTime(code, t.Add(30*time.Second))
|
||||||
|
}
|
@ -35,7 +35,8 @@ CREATE TABLE users
|
|||||||
otp_secret TEXT NOT NULL DEFAULT '',
|
otp_secret TEXT NOT NULL DEFAULT '',
|
||||||
otp_digits INTEGER NOT NULL DEFAULT 0,
|
otp_digits INTEGER NOT NULL DEFAULT 0,
|
||||||
|
|
||||||
to_delete BOOLEAN NOT NULL DEFAULT 0
|
to_delete BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
need_factor BOOLEAN NOT NULL DEFAULT 0
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE INDEX users_subject ON users (subject);
|
CREATE INDEX users_subject ON users (subject);
|
||||||
|
@ -58,6 +58,7 @@ type User struct {
|
|||||||
OtpSecret string `json:"otp_secret"`
|
OtpSecret string `json:"otp_secret"`
|
||||||
OtpDigits int64 `json:"otp_digits"`
|
OtpDigits int64 `json:"otp_digits"`
|
||||||
ToDelete bool `json:"to_delete"`
|
ToDelete bool `json:"to_delete"`
|
||||||
|
NeedFactor bool `json:"need_factor"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UsersRole struct {
|
type UsersRole struct {
|
||||||
|
@ -71,7 +71,7 @@ func (q *Queries) AddOAuthUser(ctx context.Context, arg AddOAuthUserParams) (str
|
|||||||
|
|
||||||
type CheckLoginResult struct {
|
type CheckLoginResult struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
HasOtp bool `json:"has_otp"`
|
NeedFactor bool `json:"need_factor"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
}
|
}
|
||||||
@ -87,7 +87,7 @@ func (q *Queries) CheckLogin(ctx context.Context, un, pw string) (CheckLoginResu
|
|||||||
}
|
}
|
||||||
return CheckLoginResult{
|
return CheckLoginResult{
|
||||||
Subject: login.Subject,
|
Subject: login.Subject,
|
||||||
HasOtp: login.HasOtp,
|
NeedFactor: login.NeedFactor,
|
||||||
Email: login.Email,
|
Email: login.Email,
|
||||||
EmailVerified: login.EmailVerified,
|
EmailVerified: login.EmailVerified,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -7,7 +7,7 @@ INSERT INTO users (subject, password, email, email_verified, updated_at, registe
|
|||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||||
|
|
||||||
-- name: checkLogin :one
|
-- name: checkLogin :one
|
||||||
SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified
|
SELECT subject, password, need_factor, email, email_verified
|
||||||
FROM users
|
FROM users
|
||||||
WHERE users.subject = ?
|
WHERE users.subject = ?
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
|
@ -47,7 +47,7 @@ func (q *Queries) FlagUserAsDeleted(ctx context.Context, subject string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getUser = `-- name: GetUser :one
|
const getUser = `-- name: GetUser :one
|
||||||
SELECT id, subject, password, change_password, email, email_verified, updated_at, registered, active, name, picture, website, pronouns, birthdate, zone, locale, login, profile_url, auth_type, auth_namespace, auth_user, access_token, refresh_token, token_expiry, otp_secret, otp_digits, to_delete
|
SELECT id, subject, password, change_password, email, email_verified, updated_at, registered, active, name, picture, website, pronouns, birthdate, zone, locale, login, profile_url, auth_type, auth_namespace, auth_user, access_token, refresh_token, token_expiry, otp_secret, otp_digits, to_delete, need_factor
|
||||||
FROM users
|
FROM users
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
@ -84,6 +84,7 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) {
|
|||||||
&i.OtpSecret,
|
&i.OtpSecret,
|
||||||
&i.OtpDigits,
|
&i.OtpDigits,
|
||||||
&i.ToDelete,
|
&i.ToDelete,
|
||||||
|
&i.NeedFactor,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
@ -216,7 +217,7 @@ func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPassword
|
|||||||
}
|
}
|
||||||
|
|
||||||
const checkLogin = `-- name: checkLogin :one
|
const checkLogin = `-- name: checkLogin :one
|
||||||
SELECT subject, password, CAST(otp_secret != '' AS BOOLEAN) AS has_otp, email, email_verified
|
SELECT subject, password, need_factor, email, email_verified
|
||||||
FROM users
|
FROM users
|
||||||
WHERE users.subject = ?
|
WHERE users.subject = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
@ -225,7 +226,7 @@ LIMIT 1
|
|||||||
type checkLoginRow struct {
|
type checkLoginRow struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Password password.HashString `json:"password"`
|
Password password.HashString `json:"password"`
|
||||||
HasOtp bool `json:"has_otp"`
|
NeedFactor bool `json:"need_factor"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
}
|
}
|
||||||
@ -236,7 +237,7 @@ func (q *Queries) checkLogin(ctx context.Context, subject string) (checkLoginRow
|
|||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Subject,
|
&i.Subject,
|
||||||
&i.Password,
|
&i.Password,
|
||||||
&i.HasOtp,
|
&i.NeedFactor,
|
||||||
&i.Email,
|
&i.Email,
|
||||||
&i.EmailVerified,
|
&i.EmailVerified,
|
||||||
)
|
)
|
||||||
|
@ -16,15 +16,15 @@ type UserHandler func(rw http.ResponseWriter, req *http.Request, params httprout
|
|||||||
|
|
||||||
type UserAuth struct {
|
type UserAuth struct {
|
||||||
Subject string
|
Subject string
|
||||||
NeedOtp bool
|
Factor auth.Factor
|
||||||
UserInfo auth.UserInfoFields
|
UserInfo auth.UserInfoFields
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u UserAuth) IsGuest() bool { return u.Subject == "" }
|
func (u UserAuth) IsGuest() bool { return u.Subject == "" }
|
||||||
|
|
||||||
func (u UserAuth) NextFlowUrl(origin *url.URL) *url.URL {
|
func (u UserAuth) NextFlowUrl(origin *url.URL) *url.URL {
|
||||||
if u.NeedOtp {
|
if u.Factor < auth.FactorAuthorized {
|
||||||
return PrepareRedirectUrl("/login/otp", origin)
|
return PrepareRedirectUrl("/login", origin)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user