76 lines
1.8 KiB
Go
76 lines
1.8 KiB
Go
package login_checker
|
|
|
|
import (
|
|
"code.mrmelon54.com/melon/summer-utils/api"
|
|
"code.mrmelon54.com/melon/summer-utils/claims/auth"
|
|
"code.mrmelon54.com/melon/summer-utils/tables/user"
|
|
"code.mrmelon54.com/melon/summer-utils/utils"
|
|
"errors"
|
|
"github.com/mrmelon54/mjwt"
|
|
"net/http"
|
|
"net/url"
|
|
"xorm.io/xorm"
|
|
)
|
|
|
|
type LoginChecker struct {
|
|
db *xorm.Engine
|
|
signer mjwt.Provider
|
|
loginSite url.URL
|
|
mfaLoginSite url.URL
|
|
}
|
|
|
|
func NewLoginChecker(db *xorm.Engine, signer mjwt.Provider) *LoginChecker {
|
|
return &LoginChecker{
|
|
db: db,
|
|
signer: signer,
|
|
loginSite: url.URL{Path: "/login"},
|
|
mfaLoginSite: url.URL{Path: "/mfa"},
|
|
}
|
|
}
|
|
|
|
func (lc *LoginChecker) GetUserID(req *http.Request) (uint64, error) {
|
|
token := utils.GetBearerToken(req)
|
|
if token == "" {
|
|
return 0, errors.New("access token missing")
|
|
}
|
|
|
|
// Verify as access token
|
|
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](lc.signer, token)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
var u user.User
|
|
get, err := lc.db.Where("id = ?", b.Claims.UserId).Get(&u)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if !utils.SBool(u.EmailVerified) {
|
|
return 0, errors.New("requires email verification")
|
|
}
|
|
if !utils.SBool(u.MfaActive) {
|
|
return 0, errors.New("requires MFA")
|
|
}
|
|
if get {
|
|
return u.Id, nil
|
|
}
|
|
return 0, errors.New("failed login check")
|
|
}
|
|
|
|
func (lc *LoginChecker) CheckRequired(rw http.ResponseWriter, req *http.Request, cbGood func(uint64)) {
|
|
userId, err := lc.GetUserID(req)
|
|
if api.GenericErrorMsg[LoginChecker](rw, err, http.StatusUnauthorized, "Token Not Valid") {
|
|
return
|
|
}
|
|
cbGood(userId)
|
|
}
|
|
|
|
func (lc *LoginChecker) CheckOptional(rw http.ResponseWriter, req *http.Request, cbGood func(uint64), cbBad func(uint64)) {
|
|
userId, err := lc.GetUserID(req)
|
|
if err != nil {
|
|
cbBad(0)
|
|
return
|
|
}
|
|
cbGood(userId)
|
|
}
|