76 lines
1.8 KiB
Go
76 lines
1.8 KiB
Go
|
package login_checker
|
||
|
|
||
|
import (
|
||
|
"code.mrmelon54.com/melon/summer/pkg/api"
|
||
|
"code.mrmelon54.com/melon/summer/pkg/claims/auth"
|
||
|
"code.mrmelon54.com/melon/summer/pkg/tables/user"
|
||
|
"code.mrmelon54.com/melon/summer/pkg/utils"
|
||
|
"errors"
|
||
|
"github.com/mrmelon54/mjwt"
|
||
|
"net/http"
|
||
|
"net/url"
|
||
|
"xorm.io/xorm"
|
||
|
)
|
||
|
|
||
|
type LoginChecker struct {
|
||
|
db *xorm.Engine
|
||
|
signer mjwt.Provider
|
||
|
loginSite url.URL
|
||
|
mfaLoginSite url.URL
|
||
|
}
|
||
|
|
||
|
func NewLoginChecker(db *xorm.Engine, signer mjwt.Provider) *LoginChecker {
|
||
|
return &LoginChecker{
|
||
|
db: db,
|
||
|
signer: signer,
|
||
|
loginSite: url.URL{Path: "/login"},
|
||
|
mfaLoginSite: url.URL{Path: "/mfa"},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (lc *LoginChecker) GetUserID(req *http.Request) (uint64, error) {
|
||
|
token := utils.GetBearerToken(req)
|
||
|
if token == "" {
|
||
|
return 0, errors.New("access token missing")
|
||
|
}
|
||
|
|
||
|
// Verify as access token
|
||
|
_, b, err := mjwt.ExtractClaims[auth.AccessTokenClaims](lc.signer, token)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
var u user.User
|
||
|
get, err := lc.db.Where("id = ?", b.Claims.UserId).Get(&u)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
if !utils.SBool(u.EmailVerified) {
|
||
|
return 0, errors.New("requires email verification")
|
||
|
}
|
||
|
if !utils.SBool(u.MfaActive) {
|
||
|
return 0, errors.New("requires MFA")
|
||
|
}
|
||
|
if get {
|
||
|
return u.Id, nil
|
||
|
}
|
||
|
return 0, errors.New("failed login check")
|
||
|
}
|
||
|
|
||
|
func (lc *LoginChecker) CheckRequired(rw http.ResponseWriter, req *http.Request, cbGood func(uint64)) {
|
||
|
userId, err := lc.GetUserID(req)
|
||
|
if api.GenericErrorMsg[LoginChecker](rw, err, http.StatusUnauthorized, "Token Not Valid") {
|
||
|
return
|
||
|
}
|
||
|
cbGood(userId)
|
||
|
}
|
||
|
|
||
|
func (lc *LoginChecker) CheckOptional(rw http.ResponseWriter, req *http.Request, cbGood func(uint64), cbBad func(uint64)) {
|
||
|
userId, err := lc.GetUserID(req)
|
||
|
if err != nil {
|
||
|
cbBad(0)
|
||
|
return
|
||
|
}
|
||
|
cbGood(userId)
|
||
|
}
|