mirror of
https://github.com/1f349/tulip.git
synced 2024-11-13 23:31:37 +00:00
Correct refactored database calls
This commit is contained in:
parent
fbd49da2db
commit
37570e2157
@ -7,20 +7,16 @@ import (
|
||||
)
|
||||
|
||||
type ClientStore struct {
|
||||
db *database.DB
|
||||
db *database.Queries
|
||||
}
|
||||
|
||||
var _ oauth2.ClientStore = &ClientStore{}
|
||||
|
||||
func New(db *database.DB) *ClientStore {
|
||||
func New(db *database.Queries) *ClientStore {
|
||||
return &ClientStore{db: db}
|
||||
}
|
||||
|
||||
func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
|
||||
tx, err := c.db.BeginCtx(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
return tx.GetClientInfo(id)
|
||||
a, err := c.db.GetClientInfo(ctx, id)
|
||||
return &a, err
|
||||
}
|
||||
|
@ -118,7 +118,7 @@ func genHmacKey() []byte {
|
||||
return a
|
||||
}
|
||||
|
||||
func checkDbHasUser(db *database.DB) error {
|
||||
func checkDbHasUser(db *database.Queries) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start transaction: %w", err)
|
||||
@ -126,7 +126,7 @@ func checkDbHasUser(db *database.DB) error {
|
||||
defer tx.Rollback()
|
||||
if err := tx.HasUser(); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
_, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, database.RoleAdmin, false)
|
||||
_, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, types.RoleAdmin, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
}
|
||||
|
82
database/clientstore-wrapper.go
Normal file
82
database/clientstore-wrapper.go
Normal file
@ -0,0 +1,82 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/1f349/tulip/database/types"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
"golang.org/x/text/language"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserPatch struct {
|
||||
Name string
|
||||
Picture string
|
||||
Website string
|
||||
Pronouns types.UserPronoun
|
||||
Birthdate sql.NullTime
|
||||
ZoneInfo types.UserZone
|
||||
Locale types.UserLocale
|
||||
}
|
||||
|
||||
func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) {
|
||||
var err error
|
||||
u.Name = v.Get("name")
|
||||
u.Picture = v.Get("picture")
|
||||
u.Website = v.Get("website")
|
||||
if v.Has("reset_pronouns") {
|
||||
u.Pronouns.Pronoun = pronouns.TheyThem
|
||||
} else {
|
||||
u.Pronouns.Pronoun, err = pronouns.FindPronoun(v.Get("pronouns"))
|
||||
if err != nil {
|
||||
safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected"))
|
||||
}
|
||||
}
|
||||
if v.Has("reset_birthdate") || v.Get("birthdate") == "" {
|
||||
u.Birthdate = sql.NullTime{}
|
||||
} else {
|
||||
u.Birthdate = sql.NullTime{Valid: true}
|
||||
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
|
||||
if err != nil {
|
||||
safeErrs = append(safeErrs, fmt.Errorf("invalid time selected"))
|
||||
}
|
||||
}
|
||||
if v.Has("reset_zoneinfo") {
|
||||
u.ZoneInfo.Location = time.UTC
|
||||
} else {
|
||||
u.ZoneInfo.Location, err = time.LoadLocation(v.Get("zoneinfo"))
|
||||
if err != nil {
|
||||
safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected"))
|
||||
}
|
||||
}
|
||||
if v.Has("reset_locale") {
|
||||
u.Locale.Tag = language.AmericanEnglish
|
||||
} else {
|
||||
u.Locale.Tag, err = language.Parse(v.Get("locale"))
|
||||
if err != nil {
|
||||
safeErrs = append(safeErrs, fmt.Errorf("invalid language selected"))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var _ oauth2.ClientInfo = &ClientStore{}
|
||||
|
||||
func (c *ClientStore) GetID() string { return c.Subject }
|
||||
func (c *ClientStore) GetSecret() string { return c.Secret }
|
||||
func (c *ClientStore) GetDomain() string { return c.Domain }
|
||||
func (c *ClientStore) IsPublic() bool { return c.Public }
|
||||
func (c *ClientStore) GetUserID() string { return c.Owner }
|
||||
|
||||
// GetName is an extra field for the oauth handler to display the application
|
||||
// name
|
||||
func (c *ClientStore) GetName() string { return c.Name }
|
||||
|
||||
// IsSSO is an extra field for the oauth handler to skip the user input stage
|
||||
// this is for trusted applications to get permissions without asking the user
|
||||
func (c *ClientStore) IsSSO() bool { return c.Sso }
|
||||
|
||||
// IsActive is an extra field for the app manager to get the active state
|
||||
func (c *ClientStore) IsActive() bool { return c.Active }
|
@ -1,145 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"golang.org/x/text/language"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
_, _, _, _, _ sql.Scanner = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
||||
_, _, _, _, _ json.Marshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
||||
_, _, _, _, _ json.Unmarshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
||||
)
|
||||
|
||||
func marshalValueOrNull(null bool, data any) ([]byte, error) {
|
||||
if null {
|
||||
return json.Marshal(nil)
|
||||
}
|
||||
return json.Marshal(data)
|
||||
}
|
||||
|
||||
type NullStringScanner struct{ sql.NullString }
|
||||
|
||||
func (s *NullStringScanner) Null() bool { return !s.Valid }
|
||||
func (s *NullStringScanner) Scan(src any) error { return s.NullString.Scan(src) }
|
||||
func (s NullStringScanner) MarshalJSON() ([]byte, error) {
|
||||
return marshalValueOrNull(s.Null(), s.NullString.String)
|
||||
}
|
||||
func (s *NullStringScanner) UnmarshalJSON(bytes []byte) error {
|
||||
if string(bytes) == "null" {
|
||||
return s.Scan(nil)
|
||||
}
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Scan(&a)
|
||||
}
|
||||
func (s NullStringScanner) String() string {
|
||||
if s.Null() {
|
||||
return ""
|
||||
}
|
||||
return s.NullString.String
|
||||
}
|
||||
|
||||
type NullDateScanner struct{ sql.NullTime }
|
||||
|
||||
func (t *NullDateScanner) Null() bool { return !t.Valid }
|
||||
func (t *NullDateScanner) Scan(src any) error { return t.NullTime.Scan(src) }
|
||||
func (t NullDateScanner) MarshalJSON() ([]byte, error) {
|
||||
return marshalValueOrNull(t.Null(), t.Time.UTC().Format(time.DateOnly))
|
||||
}
|
||||
func (t *NullDateScanner) UnmarshalJSON(bytes []byte) error {
|
||||
if string(bytes) == "null" {
|
||||
return t.Scan(nil)
|
||||
}
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return t.Scan(&a)
|
||||
}
|
||||
func (t NullDateScanner) String() string {
|
||||
if t.Null() {
|
||||
return ""
|
||||
}
|
||||
return t.NullTime.Time.UTC().Format(time.DateOnly)
|
||||
}
|
||||
|
||||
type LocationScanner struct{ *time.Location }
|
||||
|
||||
func (l *LocationScanner) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
||||
}
|
||||
loc, err := time.LoadLocation(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.Location = loc
|
||||
return nil
|
||||
}
|
||||
func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) }
|
||||
func (l *LocationScanner) UnmarshalJSON(bytes []byte) error {
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return l.Scan(a)
|
||||
}
|
||||
|
||||
type LocaleScanner struct{ language.Tag }
|
||||
|
||||
func (l *LocaleScanner) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
||||
}
|
||||
lang, err := language.Parse(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.Tag = lang
|
||||
return nil
|
||||
}
|
||||
func (l LocaleScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
|
||||
func (l *LocaleScanner) UnmarshalJSON(bytes []byte) error {
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return l.Scan(a)
|
||||
}
|
||||
|
||||
type PronounScanner struct{ pronouns.Pronoun }
|
||||
|
||||
func (p *PronounScanner) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
|
||||
}
|
||||
pro, err := pronouns.FindPronoun(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Pronoun = pro
|
||||
return nil
|
||||
}
|
||||
func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
|
||||
func (p *PronounScanner) UnmarshalJSON(bytes []byte) error {
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.Scan(a)
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func encode(data any) string {
|
||||
j, err := json.Marshal(map[string]any{"value": data})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(j)
|
||||
}
|
||||
|
||||
func TestStringScanner_MarshalJSON(t *testing.T) {
|
||||
assert.Equal(t, "{\"value\":\"Hello world\"}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: true}}))
|
||||
assert.Equal(t, "{\"value\":null}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: false}}))
|
||||
}
|
||||
|
||||
func TestDateScanner_MarshalJSON(t *testing.T) {
|
||||
location, err := time.LoadLocation("Europe/London")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "{\"value\":\"2006-01-02\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), Valid: true}}))
|
||||
assert.Equal(t, "{\"value\":\"2006-08-01\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.August, 2, 0, 0, 0, 0, location), Valid: true}}))
|
||||
assert.Equal(t, "{\"value\":null}", encode(NullDateScanner{}))
|
||||
}
|
||||
|
||||
func TestLocationScanner_MarshalJSON(t *testing.T) {
|
||||
location, err := time.LoadLocation("Europe/London")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "{\"value\":\"Europe/London\"}", encode(LocationScanner{location}))
|
||||
assert.Equal(t, "{\"value\":\"UTC\"}", encode(LocationScanner{time.UTC}))
|
||||
}
|
||||
|
||||
func TestLocaleScanner_MarshalJSON(t *testing.T) {
|
||||
assert.Equal(t, "{\"value\":\"en-US\"}", encode(LocaleScanner{language.AmericanEnglish}))
|
||||
assert.Equal(t, "{\"value\":\"en-GB\"}", encode(LocaleScanner{language.BritishEnglish}))
|
||||
}
|
||||
|
||||
func TestPronounScanner_MarshalJSON(t *testing.T) {
|
||||
assert.Equal(t, "{\"value\":\"they/them\"}", encode(PronounScanner{pronouns.TheyThem}))
|
||||
assert.Equal(t, "{\"value\":\"he/him\"}", encode(PronounScanner{pronouns.HeHim}))
|
||||
assert.Equal(t, "{\"value\":\"she/her\"}", encode(PronounScanner{pronouns.SheHer}))
|
||||
assert.Equal(t, "{\"value\":\"it/its\"}", encode(PronounScanner{pronouns.ItIts}))
|
||||
assert.Equal(t, "{\"value\":\"one/one's\"}", encode(PronounScanner{pronouns.OneOnes}))
|
||||
}
|
@ -1,127 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
"golang.org/x/text/language"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Sub string `json:"sub"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Username string `json:"username"`
|
||||
Picture NullStringScanner `json:"picture,omitempty"`
|
||||
Website NullStringScanner `json:"website,omitempty"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Pronouns PronounScanner `json:"pronouns,omitempty"`
|
||||
Birthdate NullDateScanner `json:"birthdate,omitempty"`
|
||||
ZoneInfo LocationScanner `json:"zoneinfo,omitempty"`
|
||||
Locale LocaleScanner `json:"locale,omitempty"`
|
||||
Role UserRole `json:"role"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
type UserRole int
|
||||
|
||||
const (
|
||||
RoleMember UserRole = iota
|
||||
RoleAdmin
|
||||
RoleToDelete
|
||||
)
|
||||
|
||||
func (r UserRole) String() string {
|
||||
switch r {
|
||||
case RoleMember:
|
||||
return "Member"
|
||||
case RoleAdmin:
|
||||
return "Admin"
|
||||
case RoleToDelete:
|
||||
return "ToDelete"
|
||||
}
|
||||
return fmt.Sprintf("UserRole{ %d }", r)
|
||||
}
|
||||
|
||||
func (r UserRole) IsValid() bool {
|
||||
return r == RoleMember || r == RoleAdmin
|
||||
}
|
||||
|
||||
type UserPatch struct {
|
||||
Name string
|
||||
Picture string
|
||||
Website string
|
||||
Pronouns pronouns.Pronoun
|
||||
Birthdate sql.NullTime
|
||||
ZoneInfo *time.Location
|
||||
Locale language.Tag
|
||||
}
|
||||
|
||||
func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) {
|
||||
var err error
|
||||
u.Name = v.Get("name")
|
||||
u.Picture = v.Get("picture")
|
||||
u.Website = v.Get("website")
|
||||
if v.Has("reset_pronouns") {
|
||||
u.Pronouns = pronouns.TheyThem
|
||||
} else {
|
||||
u.Pronouns, err = pronouns.FindPronoun(v.Get("pronouns"))
|
||||
if err != nil {
|
||||
safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected"))
|
||||
}
|
||||
}
|
||||
if v.Has("reset_birthdate") || v.Get("birthdate") == "" {
|
||||
u.Birthdate = sql.NullTime{}
|
||||
} else {
|
||||
u.Birthdate = sql.NullTime{Valid: true}
|
||||
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
|
||||
if err != nil {
|
||||
safeErrs = append(safeErrs, fmt.Errorf("invalid time selected"))
|
||||
}
|
||||
}
|
||||
if v.Has("reset_zoneinfo") {
|
||||
u.ZoneInfo = time.UTC
|
||||
} else {
|
||||
u.ZoneInfo, err = time.LoadLocation(v.Get("zoneinfo"))
|
||||
if err != nil {
|
||||
safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected"))
|
||||
}
|
||||
}
|
||||
if v.Has("reset_locale") {
|
||||
u.Locale = language.AmericanEnglish
|
||||
} else {
|
||||
u.Locale, err = language.Parse(v.Get("locale"))
|
||||
if err != nil {
|
||||
safeErrs = append(safeErrs, fmt.Errorf("invalid language selected"))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type ClientInfoDbOutput struct {
|
||||
Sub, Name, Secret, Domain, Owner string
|
||||
Public, SSO, Active bool
|
||||
}
|
||||
|
||||
var _ oauth2.ClientInfo = &ClientInfoDbOutput{}
|
||||
|
||||
func (c *ClientInfoDbOutput) GetID() string { return c.Sub }
|
||||
func (c *ClientInfoDbOutput) GetSecret() string { return c.Secret }
|
||||
func (c *ClientInfoDbOutput) GetDomain() string { return c.Domain }
|
||||
func (c *ClientInfoDbOutput) IsPublic() bool { return c.Public }
|
||||
func (c *ClientInfoDbOutput) GetUserID() string { return c.Owner }
|
||||
|
||||
// GetName is an extra field for the oauth handler to display the application
|
||||
// name
|
||||
func (c *ClientInfoDbOutput) GetName() string { return c.Name }
|
||||
|
||||
// IsSSO is an extra field for the oauth handler to skip the user input stage
|
||||
// this is for trusted applications to get permissions without asking the user
|
||||
func (c *ClientInfoDbOutput) IsSSO() bool { return c.SSO }
|
||||
|
||||
// IsActive is an extra field for the app manager to get the active state
|
||||
func (c *ClientInfoDbOutput) IsActive() bool { return c.Active }
|
@ -1,5 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
@ -7,7 +7,6 @@ package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
const getAppList = `-- name: GetAppList :many
|
||||
@ -29,9 +28,9 @@ type GetAppListRow struct {
|
||||
Name string `json:"name"`
|
||||
Domain string `json:"domain"`
|
||||
Owner string `json:"owner"`
|
||||
Public sql.NullInt64 `json:"public"`
|
||||
Sso sql.NullInt64 `json:"sso"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
Public bool `json:"public"`
|
||||
Sso bool `json:"sso"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) {
|
||||
@ -66,28 +65,21 @@ func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAp
|
||||
}
|
||||
|
||||
const getClientInfo = `-- name: GetClientInfo :one
|
||||
SELECT secret, name, domain, public, sso, active
|
||||
SELECT subject, name, secret, domain, owner, public, sso, active
|
||||
FROM client_store
|
||||
WHERE subject = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
type GetClientInfoRow struct {
|
||||
Secret string `json:"secret"`
|
||||
Name string `json:"name"`
|
||||
Domain string `json:"domain"`
|
||||
Public sql.NullInt64 `json:"public"`
|
||||
Sso sql.NullInt64 `json:"sso"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetClientInfo(ctx context.Context, subject string) (GetClientInfoRow, error) {
|
||||
func (q *Queries) GetClientInfo(ctx context.Context, subject string) (ClientStore, error) {
|
||||
row := q.db.QueryRowContext(ctx, getClientInfo, subject)
|
||||
var i GetClientInfoRow
|
||||
var i ClientStore
|
||||
err := row.Scan(
|
||||
&i.Secret,
|
||||
&i.Subject,
|
||||
&i.Name,
|
||||
&i.Secret,
|
||||
&i.Domain,
|
||||
&i.Owner,
|
||||
&i.Public,
|
||||
&i.Sso,
|
||||
&i.Active,
|
||||
@ -106,9 +98,9 @@ type InsertClientAppParams struct {
|
||||
Secret string `json:"secret"`
|
||||
Domain string `json:"domain"`
|
||||
Owner string `json:"owner"`
|
||||
Public sql.NullInt64 `json:"public"`
|
||||
Sso sql.NullInt64 `json:"sso"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
Public bool `json:"public"`
|
||||
Sso bool `json:"sso"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error {
|
||||
@ -139,9 +131,9 @@ WHERE subject = ?
|
||||
type UpdateClientAppParams struct {
|
||||
Name string `json:"name"`
|
||||
Domain string `json:"domain"`
|
||||
Public sql.NullInt64 `json:"public"`
|
||||
Sso sql.NullInt64 `json:"sso"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
Public bool `json:"public"`
|
||||
Sso bool `json:"sso"`
|
||||
Active bool `json:"active"`
|
||||
Subject string `json:"subject"`
|
||||
Owner string `json:"owner"`
|
||||
}
|
||||
|
@ -7,7 +7,9 @@ package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/1f349/tulip/database/types"
|
||||
)
|
||||
|
||||
const getUserList = `-- name: GetUserList :many
|
||||
@ -28,12 +30,12 @@ type GetUserListRow struct {
|
||||
Subject string `json:"subject"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Picture interface{} `json:"picture"`
|
||||
Picture string `json:"picture"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified int64 `json:"email_verified"`
|
||||
Role int64 `json:"role"`
|
||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Role types.UserRole `json:"role"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
|
||||
@ -77,8 +79,8 @@ WHERE subject = ?
|
||||
`
|
||||
|
||||
type UpdateUserRoleParams struct {
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
Role int64 `json:"role"`
|
||||
Active bool `json:"active"`
|
||||
Role types.UserRole `json:"role"`
|
||||
Subject string `json:"subject"`
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,4 @@
|
||||
DROP TABLE users;
|
||||
DROP INDEX username_index;
|
||||
DROP TABLE client_store;
|
||||
DROP TABLE otp;
|
@ -1,39 +1,39 @@
|
||||
CREATE TABLE IF NOT EXISTS users
|
||||
CREATE TABLE users
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
password TEXT NOT NULL,
|
||||
picture TEXT DEFAULT "" NOT NULL,
|
||||
website TEXT DEFAULT "" NOT NULL,
|
||||
picture TEXT DEFAULT '' NOT NULL,
|
||||
website TEXT DEFAULT '' NOT NULL,
|
||||
email TEXT NOT NULL,
|
||||
email_verified INTEGER DEFAULT 0 NOT NULL,
|
||||
pronouns TEXT DEFAULT "they/them" NOT NULL,
|
||||
email_verified BOOLEAN DEFAULT 0 NOT NULL,
|
||||
pronouns TEXT DEFAULT 'they/them' NOT NULL,
|
||||
birthdate DATE,
|
||||
zoneinfo TEXT DEFAULT "UTC" NOT NULL,
|
||||
locale TEXT DEFAULT "en-US" NOT NULL,
|
||||
zoneinfo TEXT DEFAULT 'UTC' NOT NULL,
|
||||
locale TEXT DEFAULT 'en-US' NOT NULL,
|
||||
role INTEGER DEFAULT 0 NOT NULL,
|
||||
updated_at DATETIME,
|
||||
registered INTEGER DEFAULT 0,
|
||||
active INTEGER DEFAULT 1
|
||||
updated_at DATETIME NOT NULL,
|
||||
registered DATETIME NOT NULL,
|
||||
active BOOLEAN DEFAULT 1 NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS username_index ON users (username);
|
||||
CREATE UNIQUE INDEX username_index ON users (username);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS client_store
|
||||
CREATE TABLE client_store
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
secret TEXT UNIQUE NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
owner TEXT NOT NULL,
|
||||
public INTEGER,
|
||||
sso INTEGER,
|
||||
active INTEGER DEFAULT 1,
|
||||
public BOOLEAN NOT NULL,
|
||||
sso BOOLEAN NOT NULL,
|
||||
active BOOLEAN DEFAULT 1 NOT NULL,
|
||||
FOREIGN KEY (owner) REFERENCES users (subject)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS otp
|
||||
CREATE TABLE otp
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
secret TEXT NOT NULL,
|
||||
|
@ -6,6 +6,10 @@ package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/1f349/tulip/database/types"
|
||||
"github.com/1f349/tulip/password"
|
||||
)
|
||||
|
||||
type ClientStore struct {
|
||||
@ -14,9 +18,9 @@ type ClientStore struct {
|
||||
Secret string `json:"secret"`
|
||||
Domain string `json:"domain"`
|
||||
Owner string `json:"owner"`
|
||||
Public sql.NullInt64 `json:"public"`
|
||||
Sso sql.NullInt64 `json:"sso"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
Public bool `json:"public"`
|
||||
Sso bool `json:"sso"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
type Otp struct {
|
||||
@ -29,17 +33,17 @@ type User struct {
|
||||
Subject string `json:"subject"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Picture interface{} `json:"picture"`
|
||||
Website interface{} `json:"website"`
|
||||
Password password.HashString `json:"password"`
|
||||
Picture string `json:"picture"`
|
||||
Website string `json:"website"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified int64 `json:"email_verified"`
|
||||
Pronouns interface{} `json:"pronouns"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Pronouns types.UserPronoun `json:"pronouns"`
|
||||
Birthdate sql.NullTime `json:"birthdate"`
|
||||
Zoneinfo interface{} `json:"zoneinfo"`
|
||||
Locale interface{} `json:"locale"`
|
||||
Role int64 `json:"role"`
|
||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
||||
Registered sql.NullInt64 `json:"registered"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
Zoneinfo types.UserZone `json:"zoneinfo"`
|
||||
Locale types.UserLocale `json:"locale"`
|
||||
Role types.UserRole `json:"role"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Registered time.Time `json:"registered"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
47
database/password-wrapper.go
Normal file
47
database/password-wrapper.go
Normal file
@ -0,0 +1,47 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/1f349/tulip/database/types"
|
||||
"github.com/1f349/tulip/password"
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AddUserParams struct {
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Role types.UserRole `json:"role"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) {
|
||||
pwHash, err := password.HashPassword(arg.Password)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
a := addUserParams{
|
||||
Subject: uuid.NewString(),
|
||||
Name: arg.Name,
|
||||
Username: arg.Username,
|
||||
Password: pwHash,
|
||||
Email: arg.Email,
|
||||
EmailVerified: arg.EmailVerified,
|
||||
Role: arg.Role,
|
||||
UpdatedAt: arg.UpdatedAt,
|
||||
Active: arg.Active,
|
||||
}
|
||||
return a.Subject, q.addUser(ctx, a)
|
||||
}
|
||||
|
||||
type CheckLoginRow struct {
|
||||
Subject string `json:"subject"`
|
||||
Password password.HashString `json:"password"`
|
||||
HasTwoFactor bool `json:"hasTwoFactor"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
-- name: GetClientInfo :one
|
||||
SELECT secret, name, domain, public, sso, active
|
||||
SELECT *
|
||||
FROM client_store
|
||||
WHERE subject = ?
|
||||
LIMIT 1;
|
||||
|
@ -13,22 +13,21 @@ WHERE username = ?
|
||||
LIMIT 1;
|
||||
|
||||
-- name: GetUser :one
|
||||
SELECT name,
|
||||
username,
|
||||
picture,
|
||||
website,
|
||||
email,
|
||||
email_verified,
|
||||
pronouns,
|
||||
birthdate,
|
||||
zoneinfo,
|
||||
locale,
|
||||
updated_at,
|
||||
active
|
||||
SELECT *
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
LIMIT 1;
|
||||
|
||||
-- name: GetUserRole :one
|
||||
SELECT role
|
||||
FROM users
|
||||
WHERE subject = ?;
|
||||
|
||||
-- name: GetUserDisplayName :one
|
||||
SELECT name
|
||||
FROM users
|
||||
WHERE subject = ?;
|
||||
|
||||
-- name: getUserPassword :one
|
||||
SELECT password
|
||||
FROM users
|
||||
@ -68,3 +67,6 @@ WHERE otp.subject = ?;
|
||||
SELECT secret, digits
|
||||
FROM otp
|
||||
WHERE subject = ?;
|
||||
|
||||
-- name: HasTwoFactor :one
|
||||
SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN);
|
||||
|
302
database/tx.go
302
database/tx.go
@ -1,302 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/1f349/tulip/password"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
func updatedAt() string {
|
||||
return time.Now().UTC().Format(time.DateTime)
|
||||
}
|
||||
|
||||
type Tx struct{ tx *sql.Tx }
|
||||
|
||||
func (t *Tx) Commit() error {
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
func (t *Tx) Rollback() {
|
||||
_ = t.tx.Rollback()
|
||||
}
|
||||
|
||||
func (t *Tx) HasUser() error {
|
||||
var exists bool
|
||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`)
|
||||
err := row.Scan(&exists)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tx) InsertUser(name, un, pw, email string, verifyEmail bool, role UserRole, active bool) (uuid.UUID, error) {
|
||||
pwHash, err := password.HashPassword(pw)
|
||||
if err != nil {
|
||||
return uuid.UUID{}, err
|
||||
}
|
||||
u := uuid.New()
|
||||
_, err = t.tx.Exec(`INSERT INTO users (subject, name, username, password, email, email_verified, role, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, u, name, un, pwHash, email, verifyEmail, role, updatedAt(), active)
|
||||
return u, err
|
||||
}
|
||||
|
||||
func (t *Tx) CheckLogin(un, pw string) (*User, bool, bool, error) {
|
||||
var u User
|
||||
var pwHash password.HashString
|
||||
var hasOtp, hasVerify bool
|
||||
row := t.tx.QueryRow(`SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject), email, email_verified FROM users WHERE username = ?`, un)
|
||||
err := row.Scan(&u.Sub, &pwHash, &hasOtp, &u.Email, &hasVerify)
|
||||
if err != nil {
|
||||
return nil, false, false, err
|
||||
}
|
||||
err = password.CheckPasswordHash(pwHash, pw)
|
||||
return &u, hasOtp, hasVerify, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserDisplayName(sub string) (*User, error) {
|
||||
var u User
|
||||
row := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub)
|
||||
err := row.Scan(&u.Name)
|
||||
u.Sub = sub
|
||||
return &u, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserRole(sub string) (UserRole, error) {
|
||||
var r UserRole
|
||||
row := t.tx.QueryRow(`SELECT role FROM users WHERE subject = ? LIMIT 1`, sub)
|
||||
err := row.Scan(&r)
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUser(sub string) (*User, error) {
|
||||
var u User
|
||||
row := t.tx.QueryRow(`SELECT name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ?`, sub)
|
||||
err := row.Scan(&u.Name, &u.Username, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active)
|
||||
u.Sub = sub
|
||||
return &u, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserEmail(sub string) (string, error) {
|
||||
var email string
|
||||
row := t.tx.QueryRow(`SELECT email FROM users WHERE subject = ?`, sub)
|
||||
err := row.Scan(&email)
|
||||
return email, err
|
||||
}
|
||||
|
||||
func (t *Tx) ChangeUserPassword(sub, pwOld, pwNew string) error {
|
||||
q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pwHash password.HashString
|
||||
if q.Next() {
|
||||
err = q.Scan(&pwHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("invalid user")
|
||||
}
|
||||
if err := q.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
err = password.CheckPasswordHash(pwHash, pwOld)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pwNewHash, err := password.HashPassword(pwNew)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, updatedAt(), sub, pwHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := exec.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected != 1 {
|
||||
return fmt.Errorf("row wasn't updated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tx) ModifyUser(sub string, v *UserPatch) error {
|
||||
exec, err := t.tx.Exec(
|
||||
`UPDATE users
|
||||
SET name = ?,
|
||||
picture = ?,
|
||||
website = ?,
|
||||
pronouns = ?,
|
||||
birthdate = ?,
|
||||
zoneinfo = ?,
|
||||
locale = ?,
|
||||
updated_at = ?
|
||||
WHERE subject = ?`,
|
||||
v.Name,
|
||||
v.Picture,
|
||||
v.Website,
|
||||
v.Pronouns.String(),
|
||||
v.Birthdate,
|
||||
v.ZoneInfo.String(),
|
||||
v.Locale.String(),
|
||||
updatedAt(),
|
||||
sub,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := exec.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected != 1 {
|
||||
return fmt.Errorf("row wasn't updated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tx) SetTwoFactor(sub string, secret string, digits int) error {
|
||||
if secret == "" && digits == 0 {
|
||||
_, err := t.tx.Exec(`DELETE FROM otp WHERE otp.subject = ?`, sub)
|
||||
return err
|
||||
}
|
||||
_, err := t.tx.Exec(`INSERT INTO otp(subject, secret, digits) VALUES (?, ?, ?) ON CONFLICT(subject) DO UPDATE SET secret = excluded.secret, digits = excluded.digits`, sub, secret, digits)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) GetTwoFactor(sub string) (string, int, error) {
|
||||
var secret string
|
||||
var digits int
|
||||
row := t.tx.QueryRow(`SELECT secret, digits FROM otp WHERE subject = ?`, sub)
|
||||
err := row.Scan(&secret, &digits)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return secret, digits, nil
|
||||
}
|
||||
|
||||
func (t *Tx) HasTwoFactor(sub string) (bool, error) {
|
||||
var hasOtp bool
|
||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM otp WHERE otp.subject = ?)`, sub)
|
||||
err := row.Scan(&hasOtp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return hasOtp, row.Err()
|
||||
}
|
||||
|
||||
func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) {
|
||||
var u ClientInfoDbOutput
|
||||
row := t.tx.QueryRow(`SELECT secret, name, domain, public, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub)
|
||||
err := row.Scan(&u.Secret, &u.Name, &u.Domain, &u.Public, &u.SSO, &u.Active)
|
||||
u.Owner = sub
|
||||
if !u.Active {
|
||||
return nil, fmt.Errorf("client is not active")
|
||||
}
|
||||
return &u, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetAppList(owner string, admin bool, offset int) ([]ClientInfoDbOutput, error) {
|
||||
var u []ClientInfoDbOutput
|
||||
row, err := t.tx.Query(`SELECT subject, name, domain, owner, public, sso, active FROM client_store WHERE owner = ? OR ? = 1 LIMIT 25 OFFSET ?`, owner, admin, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer row.Close()
|
||||
for row.Next() {
|
||||
var a ClientInfoDbOutput
|
||||
err := row.Scan(&a.Sub, &a.Name, &a.Domain, &a.Owner, &a.Public, &a.SSO, &a.Active)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u = append(u, a)
|
||||
}
|
||||
return u, row.Err()
|
||||
}
|
||||
|
||||
func (t *Tx) InsertClientApp(name, domain string, public, sso, active bool, owner string) error {
|
||||
u := uuid.New()
|
||||
secret, err := password.GenerateApiSecret(70)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = t.tx.Exec(`INSERT INTO client_store (subject, name, secret, domain, owner, public, sso, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, u.String(), name, secret, domain, owner, public, sso, active)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) UpdateClientApp(subject, owner string, name, domain string, public, sso, active bool) error {
|
||||
_, err := t.tx.Exec(`UPDATE client_store SET name = ?, domain = ?, public = ?, sso = ?, active = ? WHERE subject = ? AND owner = ?`, name, domain, public, sso, active, subject, owner)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) ResetClientAppSecret(subject, owner string) (string, error) {
|
||||
secret, err := password.GenerateApiSecret(70)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
_, err = t.tx.Exec(`UPDATE client_store SET secret = ? WHERE subject = ? AND owner = ?`, secret, subject, owner)
|
||||
return secret, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserList(offset int) ([]User, error) {
|
||||
var u []User
|
||||
row, err := t.tx.Query(`SELECT subject, name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, active FROM users LIMIT 25 OFFSET ?`, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for row.Next() {
|
||||
var a User
|
||||
err := row.Scan(&a.Sub, &a.Name, &a.Username, &a.Picture, &a.Website, &a.Email, &a.EmailVerified, &a.Pronouns, &a.Birthdate, &a.ZoneInfo, &a.Locale, &a.Role, &a.UpdatedAt, &a.Active)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u = append(u, a)
|
||||
}
|
||||
return u, row.Err()
|
||||
}
|
||||
|
||||
func (t *Tx) UpdateUser(subject string, role UserRole, active bool) error {
|
||||
_, err := t.tx.Exec(`UPDATE users SET active = ?, role = ? WHERE subject = ?`, active, role, subject)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) VerifyUserEmail(sub string) error {
|
||||
_, err := t.tx.Exec(`UPDATE users SET email_verified = 1 WHERE subject = ?`, sub)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) UserResetPassword(sub string, pw string) error {
|
||||
hashPassword, err := password.HashPassword(pw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ?`, hashPassword, updatedAt(), sub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := exec.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected != 1 {
|
||||
return fmt.Errorf("row wasn't updated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tx) UserEmailExists(email string) (exists bool, err error) {
|
||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email)
|
||||
err = row.Scan(&exists)
|
||||
return
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/1f349/tulip/password"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTx_ChangeUserPassword(t *testing.T) {
|
||||
u := uuid.New()
|
||||
pw, err := password.HashPassword("test")
|
||||
assert.NoError(t, err)
|
||||
d, err := Open("file::memory:")
|
||||
assert.NoError(t, err)
|
||||
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt())
|
||||
assert.NoError(t, err)
|
||||
tx, err := d.Begin()
|
||||
assert.NoError(t, err)
|
||||
err = tx.ChangeUserPassword(u.String(), "test", "new")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tx.Commit())
|
||||
query, err := d.db.Query(`SELECT password FROM users WHERE subject = ? AND username = ?`, u.String(), "test")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, query.Next())
|
||||
var oldPw password.HashString
|
||||
assert.NoError(t, query.Scan(&oldPw))
|
||||
assert.NoError(t, password.CheckPasswordHash(oldPw, "new"))
|
||||
assert.NoError(t, query.Err())
|
||||
assert.NoError(t, query.Close())
|
||||
}
|
||||
|
||||
func TestTx_ModifyUser(t *testing.T) {
|
||||
u := uuid.New()
|
||||
pw, err := password.HashPassword("test")
|
||||
assert.NoError(t, err)
|
||||
d, err := Open("file::memory:")
|
||||
assert.NoError(t, err)
|
||||
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt())
|
||||
assert.NoError(t, err)
|
||||
tx, err := d.Begin()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tx.ModifyUser(u.String(), &UserPatch{
|
||||
Name: "example",
|
||||
Pronouns: pronouns.TheyThem,
|
||||
ZoneInfo: time.UTC,
|
||||
Locale: language.AmericanEnglish,
|
||||
}))
|
||||
}
|
38
database/types/userlocale.go
Normal file
38
database/types/userlocale.go
Normal file
@ -0,0 +1,38 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"golang.org/x/text/language"
|
||||
)
|
||||
|
||||
var (
|
||||
_ sql.Scanner = &UserLocale{}
|
||||
_ json.Marshaler = &UserLocale{}
|
||||
_ json.Unmarshaler = &UserLocale{}
|
||||
)
|
||||
|
||||
type UserLocale struct{ language.Tag }
|
||||
|
||||
func (l *UserLocale) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
||||
}
|
||||
lang, err := language.Parse(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.Tag = lang
|
||||
return nil
|
||||
}
|
||||
func (l UserLocale) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
|
||||
func (l *UserLocale) UnmarshalJSON(bytes []byte) error {
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return l.Scan(a)
|
||||
}
|
12
database/types/userlocale_test.go
Normal file
12
database/types/userlocale_test.go
Normal file
@ -0,0 +1,12 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserLocale_MarshalJSON(t *testing.T) {
|
||||
assert.Equal(t, "\"en-US\"", encode(UserLocale{language.AmericanEnglish}))
|
||||
assert.Equal(t, "\"en-GB\"", encode(UserLocale{language.BritishEnglish}))
|
||||
}
|
38
database/types/userpronoun.go
Normal file
38
database/types/userpronoun.go
Normal file
@ -0,0 +1,38 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
)
|
||||
|
||||
var (
|
||||
_ sql.Scanner = &UserPronoun{}
|
||||
_ json.Marshaler = &UserPronoun{}
|
||||
_ json.Unmarshaler = &UserPronoun{}
|
||||
)
|
||||
|
||||
type UserPronoun struct{ pronouns.Pronoun }
|
||||
|
||||
func (p *UserPronoun) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
|
||||
}
|
||||
pro, err := pronouns.FindPronoun(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Pronoun = pro
|
||||
return nil
|
||||
}
|
||||
func (p UserPronoun) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
|
||||
func (p *UserPronoun) UnmarshalJSON(bytes []byte) error {
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.Scan(a)
|
||||
}
|
15
database/types/userpronoun_test.go
Normal file
15
database/types/userpronoun_test.go
Normal file
@ -0,0 +1,15 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserPronoun_MarshalJSON(t *testing.T) {
|
||||
assert.Equal(t, "\"they/them\"", encode(UserPronoun{pronouns.TheyThem}))
|
||||
assert.Equal(t, "\"he/him\"", encode(UserPronoun{pronouns.HeHim}))
|
||||
assert.Equal(t, "\"she/her\"", encode(UserPronoun{pronouns.SheHer}))
|
||||
assert.Equal(t, "\"it/its\"", encode(UserPronoun{pronouns.ItIts}))
|
||||
assert.Equal(t, "\"one/one's\"", encode(UserPronoun{pronouns.OneOnes}))
|
||||
}
|
27
database/types/userrole.go
Normal file
27
database/types/userrole.go
Normal file
@ -0,0 +1,27 @@
|
||||
package types
|
||||
|
||||
import "fmt"
|
||||
|
||||
type UserRole int64
|
||||
|
||||
const (
|
||||
RoleMember UserRole = iota
|
||||
RoleAdmin
|
||||
RoleToDelete
|
||||
)
|
||||
|
||||
func (r UserRole) String() string {
|
||||
switch r {
|
||||
case RoleMember:
|
||||
return "Member"
|
||||
case RoleAdmin:
|
||||
return "Admin"
|
||||
case RoleToDelete:
|
||||
return "ToDelete"
|
||||
}
|
||||
return fmt.Sprintf("UserRole{ %d }", r)
|
||||
}
|
||||
|
||||
func (r UserRole) IsValid() bool {
|
||||
return r == RoleMember || r == RoleAdmin
|
||||
}
|
45
database/types/userzone.go
Normal file
45
database/types/userzone.go
Normal file
@ -0,0 +1,45 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
_ sql.Scanner = &UserZone{}
|
||||
_ driver.Valuer = &UserZone{}
|
||||
_ json.Marshaler = &UserZone{}
|
||||
_ json.Unmarshaler = &UserZone{}
|
||||
)
|
||||
|
||||
type UserZone struct{ *time.Location }
|
||||
|
||||
func (l *UserZone) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
||||
}
|
||||
loc, err := time.LoadLocation(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.Location = loc
|
||||
return nil
|
||||
}
|
||||
func (l UserZone) Value() (driver.Value, error) {
|
||||
return l.Location.String(), nil
|
||||
}
|
||||
func (l UserZone) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(l.Location.String())
|
||||
}
|
||||
func (l *UserZone) UnmarshalJSON(bytes []byte) error {
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return l.Scan(a)
|
||||
}
|
14
database/types/userzone_test.go
Normal file
14
database/types/userzone_test.go
Normal file
@ -0,0 +1,14 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUserZone_MarshalJSON(t *testing.T) {
|
||||
location, err := time.LoadLocation("Europe/London")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "\"Europe/London\"", encode(UserZone{location}))
|
||||
assert.Equal(t, "\"UTC\"", encode(UserZone{time.UTC}))
|
||||
}
|
11
database/types/utils_test.go
Normal file
11
database/types/utils_test.go
Normal file
@ -0,0 +1,11 @@
|
||||
package types
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
func encode(data any) string {
|
||||
j, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(j)
|
||||
}
|
@ -8,6 +8,10 @@ package database
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/1f349/tulip/database/types"
|
||||
"github.com/1f349/tulip/password"
|
||||
)
|
||||
|
||||
const deleteTwoFactor = `-- name: DeleteTwoFactor :exec
|
||||
@ -40,44 +44,20 @@ func (q *Queries) GetTwoFactor(ctx context.Context, subject string) (GetTwoFacto
|
||||
}
|
||||
|
||||
const getUser = `-- name: GetUser :one
|
||||
SELECT name,
|
||||
username,
|
||||
picture,
|
||||
website,
|
||||
email,
|
||||
email_verified,
|
||||
pronouns,
|
||||
birthdate,
|
||||
zoneinfo,
|
||||
locale,
|
||||
updated_at,
|
||||
active
|
||||
SELECT subject, name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, registered, active
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
LIMIT 1
|
||||
`
|
||||
|
||||
type GetUserRow struct {
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Picture interface{} `json:"picture"`
|
||||
Website interface{} `json:"website"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified int64 `json:"email_verified"`
|
||||
Pronouns interface{} `json:"pronouns"`
|
||||
Birthdate sql.NullTime `json:"birthdate"`
|
||||
Zoneinfo interface{} `json:"zoneinfo"`
|
||||
Locale interface{} `json:"locale"`
|
||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, error) {
|
||||
func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUser, subject)
|
||||
var i GetUserRow
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.Subject,
|
||||
&i.Name,
|
||||
&i.Username,
|
||||
&i.Password,
|
||||
&i.Picture,
|
||||
&i.Website,
|
||||
&i.Email,
|
||||
@ -86,12 +66,51 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, erro
|
||||
&i.Birthdate,
|
||||
&i.Zoneinfo,
|
||||
&i.Locale,
|
||||
&i.Role,
|
||||
&i.UpdatedAt,
|
||||
&i.Registered,
|
||||
&i.Active,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getUserDisplayName = `-- name: GetUserDisplayName :one
|
||||
SELECT name
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserDisplayName(ctx context.Context, subject string) (string, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserDisplayName, subject)
|
||||
var name string
|
||||
err := row.Scan(&name)
|
||||
return name, err
|
||||
}
|
||||
|
||||
const getUserRole = `-- name: GetUserRole :one
|
||||
SELECT role
|
||||
FROM users
|
||||
WHERE subject = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetUserRole(ctx context.Context, subject string) (types.UserRole, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserRole, subject)
|
||||
var role types.UserRole
|
||||
err := row.Scan(&role)
|
||||
return role, err
|
||||
}
|
||||
|
||||
const hasTwoFactor = `-- name: HasTwoFactor :one
|
||||
SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN)
|
||||
`
|
||||
|
||||
func (q *Queries) HasTwoFactor(ctx context.Context, subject string) (bool, error) {
|
||||
row := q.db.QueryRowContext(ctx, hasTwoFactor, subject)
|
||||
var column_1 bool
|
||||
err := row.Scan(&column_1)
|
||||
return column_1, err
|
||||
}
|
||||
|
||||
const hasUser = `-- name: HasUser :one
|
||||
SELECT cast(count(subject) AS BOOLEAN) AS hasUser
|
||||
FROM users
|
||||
@ -119,13 +138,13 @@ WHERE subject = ?
|
||||
|
||||
type ModifyUserParams struct {
|
||||
Name string `json:"name"`
|
||||
Picture interface{} `json:"picture"`
|
||||
Website interface{} `json:"website"`
|
||||
Pronouns interface{} `json:"pronouns"`
|
||||
Picture string `json:"picture"`
|
||||
Website string `json:"website"`
|
||||
Pronouns types.UserPronoun `json:"pronouns"`
|
||||
Birthdate sql.NullTime `json:"birthdate"`
|
||||
Zoneinfo interface{} `json:"zoneinfo"`
|
||||
Locale interface{} `json:"locale"`
|
||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
||||
Zoneinfo types.UserZone `json:"zoneinfo"`
|
||||
Locale types.UserLocale `json:"locale"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Subject string `json:"subject"`
|
||||
}
|
||||
|
||||
@ -174,12 +193,12 @@ type addUserParams struct {
|
||||
Subject string `json:"subject"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Password password.HashString `json:"password"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified int64 `json:"email_verified"`
|
||||
Role int64 `json:"role"`
|
||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
||||
Active sql.NullInt64 `json:"active"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Role types.UserRole `json:"role"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
|
||||
@ -206,10 +225,10 @@ WHERE subject = ?
|
||||
`
|
||||
|
||||
type changeUserPasswordParams struct {
|
||||
Password string `json:"password"`
|
||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
||||
Password password.HashString `json:"password"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Subject string `json:"subject"`
|
||||
Password_2 string `json:"password_2"`
|
||||
Password_2 password.HashString `json:"password_2"`
|
||||
}
|
||||
|
||||
func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPasswordParams) (int64, error) {
|
||||
@ -234,10 +253,10 @@ LIMIT 1
|
||||
|
||||
type checkLoginRow struct {
|
||||
Subject string `json:"subject"`
|
||||
Password string `json:"password"`
|
||||
Password password.HashString `json:"password"`
|
||||
Column3 int64 `json:"column_3"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified int64 `json:"email_verified"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
}
|
||||
|
||||
func (q *Queries) checkLogin(ctx context.Context, username string) (checkLoginRow, error) {
|
||||
@ -259,9 +278,9 @@ FROM users
|
||||
WHERE subject = ?
|
||||
`
|
||||
|
||||
func (q *Queries) getUserPassword(ctx context.Context, subject string) (string, error) {
|
||||
func (q *Queries) getUserPassword(ctx context.Context, subject string) (password.HashString, error) {
|
||||
row := q.db.QueryRowContext(ctx, getUserPassword, subject)
|
||||
var password string
|
||||
var password password.HashString
|
||||
err := row.Scan(&password)
|
||||
return password, err
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/mjwt/auth"
|
||||
"github.com/1f349/tulip/database"
|
||||
"github.com/1f349/tulip/database/types"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@ -30,14 +31,14 @@ func (u UserAuth) IsGuest() bool {
|
||||
|
||||
func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle {
|
||||
return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
var role database.UserRole
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
role, err = tx.GetUserRole(auth.ID)
|
||||
var role types.UserRole
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
role, err = tx.GetUserRole(req.Context(), auth.ID)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
}
|
||||
if role != database.RoleAdmin {
|
||||
if role != types.RoleAdmin {
|
||||
http.Error(rw, "403 Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
16
server/db.go
16
server/db.go
@ -9,25 +9,13 @@ import (
|
||||
// DbTx wraps a database transaction with http error messages and a simple action
|
||||
// function. If the action function returns an error the transaction will be
|
||||
// rolled back. If there is no error then the transaction is committed.
|
||||
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool {
|
||||
tx, err := h.db.Begin()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError)
|
||||
return true
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
err = action(tx)
|
||||
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(db *database.Queries) error) bool {
|
||||
err := action(h.db)
|
||||
if err != nil {
|
||||
http.Error(rw, "Database error", http.StatusInternalServerError)
|
||||
log.Println("Database action error:", err)
|
||||
return true
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
http.Error(rw, "Database error", http.StatusInternalServerError)
|
||||
log.Println("Database commit error:", err)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -11,12 +11,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (h *HttpServer) EditGet(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||
var user *database.User
|
||||
func (h *HttpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||
var user database.User
|
||||
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
var err error
|
||||
user, err = tx.GetUser(auth.ID)
|
||||
user, err = tx.GetUser(req.Context(), auth.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read user data: %w", err)
|
||||
}
|
||||
@ -64,8 +64,19 @@ func (h *HttpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httpr
|
||||
_, _ = fmt.Fprintln(rw, "</body>\n</html>")
|
||||
return
|
||||
}
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if err := tx.ModifyUser(auth.ID, &patch); err != nil {
|
||||
m := database.ModifyUserParams{
|
||||
Name: patch.Name,
|
||||
Picture: patch.Picture,
|
||||
Website: patch.Website,
|
||||
Pronouns: patch.Pronouns,
|
||||
Birthdate: patch.Birthdate,
|
||||
Zoneinfo: patch.ZoneInfo,
|
||||
Locale: patch.Locale,
|
||||
UpdatedAt: time.Now(),
|
||||
Subject: auth.ID,
|
||||
}
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
if _, err := tx.ModifyUser(req.Context(), m); err != nil {
|
||||
return fmt.Errorf("failed to modify user info: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/1f349/tulip/database"
|
||||
"github.com/1f349/tulip/database/types"
|
||||
"github.com/1f349/tulip/pages"
|
||||
"github.com/google/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
@ -29,18 +30,19 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
|
||||
return
|
||||
}
|
||||
|
||||
var userWithName *database.User
|
||||
var userWithName string
|
||||
var userRole types.UserRole
|
||||
var hasTwoFactor bool
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
userWithName, err = tx.GetUserDisplayName(auth.ID)
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
userWithName, err = tx.GetUserDisplayName(req.Context(), auth.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user display name: %w", err)
|
||||
}
|
||||
hasTwoFactor, err = tx.HasTwoFactor(auth.ID)
|
||||
hasTwoFactor, err = tx.HasTwoFactor(req.Context(), auth.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user two factor state: %w", err)
|
||||
}
|
||||
userWithName.Role, err = tx.GetUserRole(auth.ID)
|
||||
userRole, err = tx.GetUserRole(req.Context(), auth.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user role: %w", err)
|
||||
}
|
||||
@ -54,6 +56,6 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
|
||||
"User": userWithName,
|
||||
"Nonce": lNonce,
|
||||
"OtpEnabled": hasTwoFactor,
|
||||
"IsAdmin": userWithName.Role,
|
||||
"IsAdmin": userRole == types.RoleAdmin,
|
||||
})
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/1f349/mjwt"
|
||||
"github.com/1f349/tulip/database"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
@ -9,7 +10,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func addIdTokenSupport(srv *server.Server, db *database.DB, key mjwt.Signer) {
|
||||
func addIdTokenSupport(srv *server.Server, db *database.Queries, key mjwt.Signer) {
|
||||
srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) {
|
||||
scope := ti.GetScope()
|
||||
if containsScope(scope, "openid") {
|
||||
@ -30,18 +31,13 @@ type IdTokenClaims struct{}
|
||||
func (a IdTokenClaims) Valid() error { return nil }
|
||||
func (a IdTokenClaims) Type() string { return "access-token" }
|
||||
|
||||
func generateIDToken(ti oauth2.TokenInfo, us *database.DB, key mjwt.Signer) (token string, err error) {
|
||||
tx, err := us.Begin()
|
||||
func generateIDToken(ti oauth2.TokenInfo, us *database.Queries, key mjwt.Signer) (token string, err error) {
|
||||
user, err := us.GetUser(context.Background(), ti.GetUserID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
user, err := tx.GetUser(ti.GetUserID())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
tx.Rollback()
|
||||
|
||||
token, err = key.GenerateJwt(user.Sub, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{})
|
||||
token, err = key.GenerateJwt(user.Subject, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ func (h *HttpServer) LoginPost(rw http.ResponseWriter, req *http.Request, _ http
|
||||
var loginMismatch byte
|
||||
var hasOtp bool
|
||||
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
loginUser, hasOtpRaw, hasVerifiedEmail, err := tx.CheckLogin(un, pw)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
@ -176,7 +176,7 @@ func (h *HttpServer) LoginResetPasswordPost(rw http.ResponseWriter, req *http.Re
|
||||
}
|
||||
|
||||
var emailExists bool
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
emailExists, err = tx.UserEmailExists(email)
|
||||
return err
|
||||
}) {
|
||||
|
@ -18,7 +18,7 @@ func (h *HttpServer) MailVerify(rw http.ResponseWriter, _ *http.Request, params
|
||||
http.Error(rw, "Invalid email verification code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
return tx.VerifyUserEmail(userSub)
|
||||
}) {
|
||||
return
|
||||
@ -75,7 +75,7 @@ func (h *HttpServer) MailPasswordPost(rw http.ResponseWriter, req *http.Request,
|
||||
h.mailLinkCache.Delete(k)
|
||||
|
||||
// reset password database call
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
return tx.UserResetPassword(userSub, pw)
|
||||
}) {
|
||||
return
|
||||
@ -94,12 +94,12 @@ func (h *HttpServer) MailDelete(rw http.ResponseWriter, _ *http.Request, params
|
||||
return
|
||||
}
|
||||
var userInfo *database.User
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
userInfo, err = tx.GetUser(userSub)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return tx.UpdateUser(userSub, database.RoleToDelete, false)
|
||||
return tx.UpdateUser(userSub, types.RoleToDelete, false)
|
||||
}) {
|
||||
return
|
||||
}
|
||||
|
@ -22,14 +22,14 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
|
||||
}
|
||||
}
|
||||
|
||||
var role database.UserRole
|
||||
var appList []database.ClientInfoDbOutput
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
var role types.UserRole
|
||||
var appList []database.ClientStore
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
role, err = tx.GetUserRole(auth.ID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
appList, err = tx.GetAppList(auth.ID, role == database.RoleAdmin, offset)
|
||||
appList, err = tx.GetAppList(auth.ID, role == types.RoleAdmin, offset)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
@ -39,7 +39,7 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
|
||||
"ServiceName": h.conf.ServiceName,
|
||||
"Apps": appList,
|
||||
"Offset": offset,
|
||||
"IsAdmin": role == database.RoleAdmin,
|
||||
"IsAdmin": role == types.RoleAdmin,
|
||||
"NewAppName": q.Get("NewAppName"),
|
||||
"NewAppSecret": q.Get("NewAppSecret"),
|
||||
}
|
||||
@ -76,14 +76,14 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
||||
active := req.Form.Has("active")
|
||||
|
||||
if sso {
|
||||
var role database.UserRole
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
var role types.UserRole
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
role, err = tx.GetUserRole(auth.ID)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
}
|
||||
if role != database.RoleAdmin {
|
||||
if role != types.RoleAdmin {
|
||||
http.Error(rw, "400 Bad Request: Only admin users can create SSO client applications", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@ -91,13 +91,13 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
||||
|
||||
switch action {
|
||||
case "create":
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
return tx.InsertClientApp(name, domain, public, sso, active, auth.ID)
|
||||
}) {
|
||||
return
|
||||
}
|
||||
case "edit":
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
return tx.UpdateClientApp(req.Form.Get("subject"), auth.ID, name, domain, public, sso, active)
|
||||
}) {
|
||||
return
|
||||
@ -105,7 +105,7 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
||||
case "secret":
|
||||
var info oauth2.ClientInfo
|
||||
var secret string
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
sub := req.Form.Get("subject")
|
||||
info, err = tx.GetClientInfo(sub)
|
||||
if err != nil {
|
||||
|
@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"errors"
|
||||
"github.com/1f349/tulip/database"
|
||||
"github.com/1f349/tulip/database/types"
|
||||
"github.com/1f349/tulip/pages"
|
||||
"github.com/emersion/go-message/mail"
|
||||
"github.com/google/uuid"
|
||||
@ -27,9 +28,9 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
|
||||
}
|
||||
}
|
||||
|
||||
var role database.UserRole
|
||||
var role types.UserRole
|
||||
var userList []database.User
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
role, err = tx.GetUserRole(auth.ID)
|
||||
if err != nil {
|
||||
return
|
||||
@ -39,7 +40,7 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
|
||||
}) {
|
||||
return
|
||||
}
|
||||
if role != database.RoleAdmin {
|
||||
if role != types.RoleAdmin {
|
||||
http.Error(rw, "403 Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
@ -76,14 +77,14 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
var role database.UserRole
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
var role types.UserRole
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
role, err = tx.GetUserRole(auth.ID)
|
||||
return
|
||||
}) {
|
||||
return
|
||||
}
|
||||
if role != database.RoleAdmin {
|
||||
if role != types.RoleAdmin {
|
||||
http.Error(rw, "400 Bad Request: Only admin users can manage users", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@ -116,7 +117,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
||||
addrDomain := address.Address[n+1:]
|
||||
|
||||
var userSub uuid.UUID
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
userSub, err = tx.InsertUser(name, username, "", email, addrDomain == h.conf.Namespace, newRole, active)
|
||||
return err
|
||||
}) {
|
||||
@ -136,7 +137,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
case "edit":
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
sub := req.Form.Get("subject")
|
||||
return tx.UpdateUser(sub, newRole, active)
|
||||
}) {
|
||||
@ -151,12 +152,12 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
||||
http.Redirect(rw, req, redirectUrl.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func parseRoleValue(role string) (database.UserRole, error) {
|
||||
func parseRoleValue(role string) (types.UserRole, error) {
|
||||
switch role {
|
||||
case "member":
|
||||
return database.RoleMember, nil
|
||||
return types.RoleMember, nil
|
||||
case "admin":
|
||||
return database.RoleAdmin, nil
|
||||
return types.RoleAdmin, nil
|
||||
}
|
||||
return 0, errors.New("invalid role value")
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request
|
||||
|
||||
var user *database.User
|
||||
var hasOtp bool
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
user, err = tx.GetUserDisplayName(auth.ID)
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -47,7 +47,7 @@ func (h *HttpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub, code strin
|
||||
var hasOtp bool
|
||||
var secret string
|
||||
var digits int
|
||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
hasOtp, err = tx.HasTwoFactor(sub)
|
||||
if err != nil {
|
||||
return
|
||||
@ -86,7 +86,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
|
||||
return
|
||||
}
|
||||
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
return tx.SetTwoFactor(auth.ID, "", 0)
|
||||
}) {
|
||||
return
|
||||
@ -118,7 +118,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
|
||||
if secret == "" {
|
||||
// get user email
|
||||
var email string
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
var err error
|
||||
email, err = tx.GetUserEmail(auth.ID)
|
||||
return err
|
||||
@ -167,7 +167,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
|
||||
return
|
||||
}
|
||||
|
||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
||||
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||
return tx.SetTwoFactor(auth.ID, secret, digits)
|
||||
}) {
|
||||
return
|
||||
|
@ -31,7 +31,7 @@ type HttpServer struct {
|
||||
r *httprouter.Router
|
||||
oauthSrv *server.Server
|
||||
oauthMgr *manage.Manager
|
||||
db *database.DB
|
||||
db *database.Queries
|
||||
conf Conf
|
||||
signingKey mjwt.Signer
|
||||
|
||||
@ -50,7 +50,7 @@ type mailLinkKey struct {
|
||||
data string
|
||||
}
|
||||
|
||||
func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Server {
|
||||
func NewHttpServer(conf Conf, db *database.Queries, signingKey mjwt.Signer) *http.Server {
|
||||
r := httprouter.New()
|
||||
|
||||
// remove last slash from baseUrl
|
||||
@ -191,10 +191,10 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser
|
||||
return
|
||||
}
|
||||
|
||||
var userData *database.User
|
||||
var userData database.GetUserRow
|
||||
|
||||
if hs.DbTx(rw, func(tx *database.Tx) (err error) {
|
||||
userData, err = tx.GetUser(userId)
|
||||
if hs.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||
userData, err = tx.GetUser(req.Context(), userId)
|
||||
return err
|
||||
}) {
|
||||
return
|
||||
|
14
sqlc.yaml
14
sqlc.yaml
@ -9,7 +9,13 @@ sql:
|
||||
out: "database"
|
||||
emit_json_tags: true
|
||||
overrides:
|
||||
- column: "routes.flags"
|
||||
go_type: "github.com/1f349/violet/target.Flags"
|
||||
- column: "redirects.flags"
|
||||
go_type: "github.com/1f349/violet/target.Flags"
|
||||
- column: "users.password"
|
||||
go_type: "github.com/1f349/tulip/password.HashString"
|
||||
- column: "users.role"
|
||||
go_type: "github.com/1f349/tulip/database/types.UserRole"
|
||||
- column: "users.pronouns"
|
||||
go_type: "github.com/1f349/tulip/database/types.UserPronoun"
|
||||
- column: "users.zoneinfo"
|
||||
go_type: "github.com/1f349/tulip/database/types.UserZone"
|
||||
- column: "users.locale"
|
||||
go_type: "github.com/1f349/tulip/database/types.UserLocale"
|
||||
|
Loading…
Reference in New Issue
Block a user