mirror of
https://github.com/1f349/tulip.git
synced 2024-12-22 08:14:13 +00:00
Correct refactored database calls
This commit is contained in:
parent
fbd49da2db
commit
37570e2157
@ -7,20 +7,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ClientStore struct {
|
type ClientStore struct {
|
||||||
db *database.DB
|
db *database.Queries
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ oauth2.ClientStore = &ClientStore{}
|
var _ oauth2.ClientStore = &ClientStore{}
|
||||||
|
|
||||||
func New(db *database.DB) *ClientStore {
|
func New(db *database.Queries) *ClientStore {
|
||||||
return &ClientStore{db: db}
|
return &ClientStore{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
|
func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
|
||||||
tx, err := c.db.BeginCtx(ctx)
|
a, err := c.db.GetClientInfo(ctx, id)
|
||||||
if err != nil {
|
return &a, err
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
return tx.GetClientInfo(id)
|
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,7 @@ func genHmacKey() []byte {
|
|||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkDbHasUser(db *database.DB) error {
|
func checkDbHasUser(db *database.Queries) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to start transaction: %w", err)
|
return fmt.Errorf("failed to start transaction: %w", err)
|
||||||
@ -126,7 +126,7 @@ func checkDbHasUser(db *database.DB) error {
|
|||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if err := tx.HasUser(); err != nil {
|
if err := tx.HasUser(); err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
_, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, database.RoleAdmin, false)
|
_, err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost", false, types.RoleAdmin, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to add user: %w", err)
|
return fmt.Errorf("failed to add user: %w", err)
|
||||||
}
|
}
|
||||||
|
82
database/clientstore-wrapper.go
Normal file
82
database/clientstore-wrapper.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"github.com/1f349/tulip/database/types"
|
||||||
|
"github.com/MrMelon54/pronouns"
|
||||||
|
"github.com/go-oauth2/oauth2/v4"
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserPatch struct {
|
||||||
|
Name string
|
||||||
|
Picture string
|
||||||
|
Website string
|
||||||
|
Pronouns types.UserPronoun
|
||||||
|
Birthdate sql.NullTime
|
||||||
|
ZoneInfo types.UserZone
|
||||||
|
Locale types.UserLocale
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) {
|
||||||
|
var err error
|
||||||
|
u.Name = v.Get("name")
|
||||||
|
u.Picture = v.Get("picture")
|
||||||
|
u.Website = v.Get("website")
|
||||||
|
if v.Has("reset_pronouns") {
|
||||||
|
u.Pronouns.Pronoun = pronouns.TheyThem
|
||||||
|
} else {
|
||||||
|
u.Pronouns.Pronoun, err = pronouns.FindPronoun(v.Get("pronouns"))
|
||||||
|
if err != nil {
|
||||||
|
safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v.Has("reset_birthdate") || v.Get("birthdate") == "" {
|
||||||
|
u.Birthdate = sql.NullTime{}
|
||||||
|
} else {
|
||||||
|
u.Birthdate = sql.NullTime{Valid: true}
|
||||||
|
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
|
||||||
|
if err != nil {
|
||||||
|
safeErrs = append(safeErrs, fmt.Errorf("invalid time selected"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v.Has("reset_zoneinfo") {
|
||||||
|
u.ZoneInfo.Location = time.UTC
|
||||||
|
} else {
|
||||||
|
u.ZoneInfo.Location, err = time.LoadLocation(v.Get("zoneinfo"))
|
||||||
|
if err != nil {
|
||||||
|
safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v.Has("reset_locale") {
|
||||||
|
u.Locale.Tag = language.AmericanEnglish
|
||||||
|
} else {
|
||||||
|
u.Locale.Tag, err = language.Parse(v.Get("locale"))
|
||||||
|
if err != nil {
|
||||||
|
safeErrs = append(safeErrs, fmt.Errorf("invalid language selected"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ oauth2.ClientInfo = &ClientStore{}
|
||||||
|
|
||||||
|
func (c *ClientStore) GetID() string { return c.Subject }
|
||||||
|
func (c *ClientStore) GetSecret() string { return c.Secret }
|
||||||
|
func (c *ClientStore) GetDomain() string { return c.Domain }
|
||||||
|
func (c *ClientStore) IsPublic() bool { return c.Public }
|
||||||
|
func (c *ClientStore) GetUserID() string { return c.Owner }
|
||||||
|
|
||||||
|
// GetName is an extra field for the oauth handler to display the application
|
||||||
|
// name
|
||||||
|
func (c *ClientStore) GetName() string { return c.Name }
|
||||||
|
|
||||||
|
// IsSSO is an extra field for the oauth handler to skip the user input stage
|
||||||
|
// this is for trusted applications to get permissions without asking the user
|
||||||
|
func (c *ClientStore) IsSSO() bool { return c.Sso }
|
||||||
|
|
||||||
|
// IsActive is an extra field for the app manager to get the active state
|
||||||
|
func (c *ClientStore) IsActive() bool { return c.Active }
|
@ -1,145 +0,0 @@
|
|||||||
package database
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"github.com/MrMelon54/pronouns"
|
|
||||||
"golang.org/x/text/language"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
_, _, _, _, _ sql.Scanner = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
|
||||||
_, _, _, _, _ json.Marshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
|
||||||
_, _, _, _, _ json.Unmarshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
|
||||||
)
|
|
||||||
|
|
||||||
func marshalValueOrNull(null bool, data any) ([]byte, error) {
|
|
||||||
if null {
|
|
||||||
return json.Marshal(nil)
|
|
||||||
}
|
|
||||||
return json.Marshal(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
type NullStringScanner struct{ sql.NullString }
|
|
||||||
|
|
||||||
func (s *NullStringScanner) Null() bool { return !s.Valid }
|
|
||||||
func (s *NullStringScanner) Scan(src any) error { return s.NullString.Scan(src) }
|
|
||||||
func (s NullStringScanner) MarshalJSON() ([]byte, error) {
|
|
||||||
return marshalValueOrNull(s.Null(), s.NullString.String)
|
|
||||||
}
|
|
||||||
func (s *NullStringScanner) UnmarshalJSON(bytes []byte) error {
|
|
||||||
if string(bytes) == "null" {
|
|
||||||
return s.Scan(nil)
|
|
||||||
}
|
|
||||||
var a string
|
|
||||||
err := json.Unmarshal(bytes, &a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.Scan(&a)
|
|
||||||
}
|
|
||||||
func (s NullStringScanner) String() string {
|
|
||||||
if s.Null() {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return s.NullString.String
|
|
||||||
}
|
|
||||||
|
|
||||||
type NullDateScanner struct{ sql.NullTime }
|
|
||||||
|
|
||||||
func (t *NullDateScanner) Null() bool { return !t.Valid }
|
|
||||||
func (t *NullDateScanner) Scan(src any) error { return t.NullTime.Scan(src) }
|
|
||||||
func (t NullDateScanner) MarshalJSON() ([]byte, error) {
|
|
||||||
return marshalValueOrNull(t.Null(), t.Time.UTC().Format(time.DateOnly))
|
|
||||||
}
|
|
||||||
func (t *NullDateScanner) UnmarshalJSON(bytes []byte) error {
|
|
||||||
if string(bytes) == "null" {
|
|
||||||
return t.Scan(nil)
|
|
||||||
}
|
|
||||||
var a string
|
|
||||||
err := json.Unmarshal(bytes, &a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return t.Scan(&a)
|
|
||||||
}
|
|
||||||
func (t NullDateScanner) String() string {
|
|
||||||
if t.Null() {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return t.NullTime.Time.UTC().Format(time.DateOnly)
|
|
||||||
}
|
|
||||||
|
|
||||||
type LocationScanner struct{ *time.Location }
|
|
||||||
|
|
||||||
func (l *LocationScanner) Scan(src any) error {
|
|
||||||
s, ok := src.(string)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
|
||||||
}
|
|
||||||
loc, err := time.LoadLocation(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
l.Location = loc
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) }
|
|
||||||
func (l *LocationScanner) UnmarshalJSON(bytes []byte) error {
|
|
||||||
var a string
|
|
||||||
err := json.Unmarshal(bytes, &a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return l.Scan(a)
|
|
||||||
}
|
|
||||||
|
|
||||||
type LocaleScanner struct{ language.Tag }
|
|
||||||
|
|
||||||
func (l *LocaleScanner) Scan(src any) error {
|
|
||||||
s, ok := src.(string)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
|
||||||
}
|
|
||||||
lang, err := language.Parse(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
l.Tag = lang
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (l LocaleScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
|
|
||||||
func (l *LocaleScanner) UnmarshalJSON(bytes []byte) error {
|
|
||||||
var a string
|
|
||||||
err := json.Unmarshal(bytes, &a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return l.Scan(a)
|
|
||||||
}
|
|
||||||
|
|
||||||
type PronounScanner struct{ pronouns.Pronoun }
|
|
||||||
|
|
||||||
func (p *PronounScanner) Scan(src any) error {
|
|
||||||
s, ok := src.(string)
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
|
|
||||||
}
|
|
||||||
pro, err := pronouns.FindPronoun(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p.Pronoun = pro
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
|
|
||||||
func (p *PronounScanner) UnmarshalJSON(bytes []byte) error {
|
|
||||||
var a string
|
|
||||||
err := json.Unmarshal(bytes, &a)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.Scan(a)
|
|
||||||
}
|
|
@ -1,52 +0,0 @@
|
|||||||
package database
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"github.com/MrMelon54/pronouns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"golang.org/x/text/language"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func encode(data any) string {
|
|
||||||
j, err := json.Marshal(map[string]any{"value": data})
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return string(j)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStringScanner_MarshalJSON(t *testing.T) {
|
|
||||||
assert.Equal(t, "{\"value\":\"Hello world\"}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: true}}))
|
|
||||||
assert.Equal(t, "{\"value\":null}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: false}}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDateScanner_MarshalJSON(t *testing.T) {
|
|
||||||
location, err := time.LoadLocation("Europe/London")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "{\"value\":\"2006-01-02\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), Valid: true}}))
|
|
||||||
assert.Equal(t, "{\"value\":\"2006-08-01\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.August, 2, 0, 0, 0, 0, location), Valid: true}}))
|
|
||||||
assert.Equal(t, "{\"value\":null}", encode(NullDateScanner{}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLocationScanner_MarshalJSON(t *testing.T) {
|
|
||||||
location, err := time.LoadLocation("Europe/London")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "{\"value\":\"Europe/London\"}", encode(LocationScanner{location}))
|
|
||||||
assert.Equal(t, "{\"value\":\"UTC\"}", encode(LocationScanner{time.UTC}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLocaleScanner_MarshalJSON(t *testing.T) {
|
|
||||||
assert.Equal(t, "{\"value\":\"en-US\"}", encode(LocaleScanner{language.AmericanEnglish}))
|
|
||||||
assert.Equal(t, "{\"value\":\"en-GB\"}", encode(LocaleScanner{language.BritishEnglish}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPronounScanner_MarshalJSON(t *testing.T) {
|
|
||||||
assert.Equal(t, "{\"value\":\"they/them\"}", encode(PronounScanner{pronouns.TheyThem}))
|
|
||||||
assert.Equal(t, "{\"value\":\"he/him\"}", encode(PronounScanner{pronouns.HeHim}))
|
|
||||||
assert.Equal(t, "{\"value\":\"she/her\"}", encode(PronounScanner{pronouns.SheHer}))
|
|
||||||
assert.Equal(t, "{\"value\":\"it/its\"}", encode(PronounScanner{pronouns.ItIts}))
|
|
||||||
assert.Equal(t, "{\"value\":\"one/one's\"}", encode(PronounScanner{pronouns.OneOnes}))
|
|
||||||
}
|
|
@ -1,127 +0,0 @@
|
|||||||
package database
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"github.com/MrMelon54/pronouns"
|
|
||||||
"github.com/go-oauth2/oauth2/v4"
|
|
||||||
"golang.org/x/text/language"
|
|
||||||
"net/url"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
Sub string `json:"sub"`
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
Picture NullStringScanner `json:"picture,omitempty"`
|
|
||||||
Website NullStringScanner `json:"website,omitempty"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
EmailVerified bool `json:"email_verified"`
|
|
||||||
Pronouns PronounScanner `json:"pronouns,omitempty"`
|
|
||||||
Birthdate NullDateScanner `json:"birthdate,omitempty"`
|
|
||||||
ZoneInfo LocationScanner `json:"zoneinfo,omitempty"`
|
|
||||||
Locale LocaleScanner `json:"locale,omitempty"`
|
|
||||||
Role UserRole `json:"role"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
Active bool `json:"active"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserRole int
|
|
||||||
|
|
||||||
const (
|
|
||||||
RoleMember UserRole = iota
|
|
||||||
RoleAdmin
|
|
||||||
RoleToDelete
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r UserRole) String() string {
|
|
||||||
switch r {
|
|
||||||
case RoleMember:
|
|
||||||
return "Member"
|
|
||||||
case RoleAdmin:
|
|
||||||
return "Admin"
|
|
||||||
case RoleToDelete:
|
|
||||||
return "ToDelete"
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("UserRole{ %d }", r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r UserRole) IsValid() bool {
|
|
||||||
return r == RoleMember || r == RoleAdmin
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserPatch struct {
|
|
||||||
Name string
|
|
||||||
Picture string
|
|
||||||
Website string
|
|
||||||
Pronouns pronouns.Pronoun
|
|
||||||
Birthdate sql.NullTime
|
|
||||||
ZoneInfo *time.Location
|
|
||||||
Locale language.Tag
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) {
|
|
||||||
var err error
|
|
||||||
u.Name = v.Get("name")
|
|
||||||
u.Picture = v.Get("picture")
|
|
||||||
u.Website = v.Get("website")
|
|
||||||
if v.Has("reset_pronouns") {
|
|
||||||
u.Pronouns = pronouns.TheyThem
|
|
||||||
} else {
|
|
||||||
u.Pronouns, err = pronouns.FindPronoun(v.Get("pronouns"))
|
|
||||||
if err != nil {
|
|
||||||
safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if v.Has("reset_birthdate") || v.Get("birthdate") == "" {
|
|
||||||
u.Birthdate = sql.NullTime{}
|
|
||||||
} else {
|
|
||||||
u.Birthdate = sql.NullTime{Valid: true}
|
|
||||||
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
|
|
||||||
if err != nil {
|
|
||||||
safeErrs = append(safeErrs, fmt.Errorf("invalid time selected"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if v.Has("reset_zoneinfo") {
|
|
||||||
u.ZoneInfo = time.UTC
|
|
||||||
} else {
|
|
||||||
u.ZoneInfo, err = time.LoadLocation(v.Get("zoneinfo"))
|
|
||||||
if err != nil {
|
|
||||||
safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if v.Has("reset_locale") {
|
|
||||||
u.Locale = language.AmericanEnglish
|
|
||||||
} else {
|
|
||||||
u.Locale, err = language.Parse(v.Get("locale"))
|
|
||||||
if err != nil {
|
|
||||||
safeErrs = append(safeErrs, fmt.Errorf("invalid language selected"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClientInfoDbOutput struct {
|
|
||||||
Sub, Name, Secret, Domain, Owner string
|
|
||||||
Public, SSO, Active bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ oauth2.ClientInfo = &ClientInfoDbOutput{}
|
|
||||||
|
|
||||||
func (c *ClientInfoDbOutput) GetID() string { return c.Sub }
|
|
||||||
func (c *ClientInfoDbOutput) GetSecret() string { return c.Secret }
|
|
||||||
func (c *ClientInfoDbOutput) GetDomain() string { return c.Domain }
|
|
||||||
func (c *ClientInfoDbOutput) IsPublic() bool { return c.Public }
|
|
||||||
func (c *ClientInfoDbOutput) GetUserID() string { return c.Owner }
|
|
||||||
|
|
||||||
// GetName is an extra field for the oauth handler to display the application
|
|
||||||
// name
|
|
||||||
func (c *ClientInfoDbOutput) GetName() string { return c.Name }
|
|
||||||
|
|
||||||
// IsSSO is an extra field for the oauth handler to skip the user input stage
|
|
||||||
// this is for trusted applications to get permissions without asking the user
|
|
||||||
func (c *ClientInfoDbOutput) IsSSO() bool { return c.SSO }
|
|
||||||
|
|
||||||
// IsActive is an extra field for the app manager to get the active state
|
|
||||||
func (c *ClientInfoDbOutput) IsActive() bool { return c.Active }
|
|
@ -1,5 +0,0 @@
|
|||||||
package database
|
|
||||||
|
|
||||||
import (
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
|
||||||
)
|
|
@ -7,7 +7,6 @@ package database
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const getAppList = `-- name: GetAppList :many
|
const getAppList = `-- name: GetAppList :many
|
||||||
@ -25,13 +24,13 @@ type GetAppListParams struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GetAppListRow struct {
|
type GetAppListRow struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Domain string `json:"domain"`
|
Domain string `json:"domain"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
Public sql.NullInt64 `json:"public"`
|
Public bool `json:"public"`
|
||||||
Sso sql.NullInt64 `json:"sso"`
|
Sso bool `json:"sso"`
|
||||||
Active sql.NullInt64 `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) {
|
func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAppListRow, error) {
|
||||||
@ -66,28 +65,21 @@ func (q *Queries) GetAppList(ctx context.Context, arg GetAppListParams) ([]GetAp
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getClientInfo = `-- name: GetClientInfo :one
|
const getClientInfo = `-- name: GetClientInfo :one
|
||||||
SELECT secret, name, domain, public, sso, active
|
SELECT subject, name, secret, domain, owner, public, sso, active
|
||||||
FROM client_store
|
FROM client_store
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
`
|
`
|
||||||
|
|
||||||
type GetClientInfoRow struct {
|
func (q *Queries) GetClientInfo(ctx context.Context, subject string) (ClientStore, error) {
|
||||||
Secret string `json:"secret"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Domain string `json:"domain"`
|
|
||||||
Public sql.NullInt64 `json:"public"`
|
|
||||||
Sso sql.NullInt64 `json:"sso"`
|
|
||||||
Active sql.NullInt64 `json:"active"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) GetClientInfo(ctx context.Context, subject string) (GetClientInfoRow, error) {
|
|
||||||
row := q.db.QueryRowContext(ctx, getClientInfo, subject)
|
row := q.db.QueryRowContext(ctx, getClientInfo, subject)
|
||||||
var i GetClientInfoRow
|
var i ClientStore
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
&i.Secret,
|
&i.Subject,
|
||||||
&i.Name,
|
&i.Name,
|
||||||
|
&i.Secret,
|
||||||
&i.Domain,
|
&i.Domain,
|
||||||
|
&i.Owner,
|
||||||
&i.Public,
|
&i.Public,
|
||||||
&i.Sso,
|
&i.Sso,
|
||||||
&i.Active,
|
&i.Active,
|
||||||
@ -101,14 +93,14 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|||||||
`
|
`
|
||||||
|
|
||||||
type InsertClientAppParams struct {
|
type InsertClientAppParams struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Secret string `json:"secret"`
|
Secret string `json:"secret"`
|
||||||
Domain string `json:"domain"`
|
Domain string `json:"domain"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
Public sql.NullInt64 `json:"public"`
|
Public bool `json:"public"`
|
||||||
Sso sql.NullInt64 `json:"sso"`
|
Sso bool `json:"sso"`
|
||||||
Active sql.NullInt64 `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error {
|
func (q *Queries) InsertClientApp(ctx context.Context, arg InsertClientAppParams) error {
|
||||||
@ -137,13 +129,13 @@ WHERE subject = ?
|
|||||||
`
|
`
|
||||||
|
|
||||||
type UpdateClientAppParams struct {
|
type UpdateClientAppParams struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Domain string `json:"domain"`
|
Domain string `json:"domain"`
|
||||||
Public sql.NullInt64 `json:"public"`
|
Public bool `json:"public"`
|
||||||
Sso sql.NullInt64 `json:"sso"`
|
Sso bool `json:"sso"`
|
||||||
Active sql.NullInt64 `json:"active"`
|
Active bool `json:"active"`
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) UpdateClientApp(ctx context.Context, arg UpdateClientAppParams) error {
|
func (q *Queries) UpdateClientApp(ctx context.Context, arg UpdateClientAppParams) error {
|
||||||
|
@ -7,7 +7,9 @@ package database
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"time"
|
||||||
|
|
||||||
|
"github.com/1f349/tulip/database/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const getUserList = `-- name: GetUserList :many
|
const getUserList = `-- name: GetUserList :many
|
||||||
@ -25,15 +27,15 @@ LIMIT 25 OFFSET ?
|
|||||||
`
|
`
|
||||||
|
|
||||||
type GetUserListRow struct {
|
type GetUserListRow struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Picture interface{} `json:"picture"`
|
Picture string `json:"picture"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified int64 `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
Role int64 `json:"role"`
|
Role types.UserRole `json:"role"`
|
||||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Active sql.NullInt64 `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
|
func (q *Queries) GetUserList(ctx context.Context, offset int64) ([]GetUserListRow, error) {
|
||||||
@ -77,9 +79,9 @@ WHERE subject = ?
|
|||||||
`
|
`
|
||||||
|
|
||||||
type UpdateUserRoleParams struct {
|
type UpdateUserRoleParams struct {
|
||||||
Active sql.NullInt64 `json:"active"`
|
Active bool `json:"active"`
|
||||||
Role int64 `json:"role"`
|
Role types.UserRole `json:"role"`
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) UpdateUserRole(ctx context.Context, arg UpdateUserRoleParams) error {
|
func (q *Queries) UpdateUserRole(ctx context.Context, arg UpdateUserRoleParams) error {
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
DROP TABLE users;
|
||||||
|
DROP INDEX username_index;
|
||||||
|
DROP TABLE client_store;
|
||||||
|
DROP TABLE otp;
|
@ -1,39 +1,39 @@
|
|||||||
CREATE TABLE IF NOT EXISTS users
|
CREATE TABLE users
|
||||||
(
|
(
|
||||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
username TEXT UNIQUE NOT NULL,
|
username TEXT UNIQUE NOT NULL,
|
||||||
password TEXT NOT NULL,
|
password TEXT NOT NULL,
|
||||||
picture TEXT DEFAULT "" NOT NULL,
|
picture TEXT DEFAULT '' NOT NULL,
|
||||||
website TEXT DEFAULT "" NOT NULL,
|
website TEXT DEFAULT '' NOT NULL,
|
||||||
email TEXT NOT NULL,
|
email TEXT NOT NULL,
|
||||||
email_verified INTEGER DEFAULT 0 NOT NULL,
|
email_verified BOOLEAN DEFAULT 0 NOT NULL,
|
||||||
pronouns TEXT DEFAULT "they/them" NOT NULL,
|
pronouns TEXT DEFAULT 'they/them' NOT NULL,
|
||||||
birthdate DATE,
|
birthdate DATE,
|
||||||
zoneinfo TEXT DEFAULT "UTC" NOT NULL,
|
zoneinfo TEXT DEFAULT 'UTC' NOT NULL,
|
||||||
locale TEXT DEFAULT "en-US" NOT NULL,
|
locale TEXT DEFAULT 'en-US' NOT NULL,
|
||||||
role INTEGER DEFAULT 0 NOT NULL,
|
role INTEGER DEFAULT 0 NOT NULL,
|
||||||
updated_at DATETIME,
|
updated_at DATETIME NOT NULL,
|
||||||
registered INTEGER DEFAULT 0,
|
registered DATETIME NOT NULL,
|
||||||
active INTEGER DEFAULT 1
|
active BOOLEAN DEFAULT 1 NOT NULL
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS username_index ON users (username);
|
CREATE UNIQUE INDEX username_index ON users (username);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS client_store
|
CREATE TABLE client_store
|
||||||
(
|
(
|
||||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
secret TEXT UNIQUE NOT NULL,
|
secret TEXT UNIQUE NOT NULL,
|
||||||
domain TEXT NOT NULL,
|
domain TEXT NOT NULL,
|
||||||
owner TEXT NOT NULL,
|
owner TEXT NOT NULL,
|
||||||
public INTEGER,
|
public BOOLEAN NOT NULL,
|
||||||
sso INTEGER,
|
sso BOOLEAN NOT NULL,
|
||||||
active INTEGER DEFAULT 1,
|
active BOOLEAN DEFAULT 1 NOT NULL,
|
||||||
FOREIGN KEY (owner) REFERENCES users (subject)
|
FOREIGN KEY (owner) REFERENCES users (subject)
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS otp
|
CREATE TABLE otp
|
||||||
(
|
(
|
||||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||||
secret TEXT NOT NULL,
|
secret TEXT NOT NULL,
|
||||||
|
@ -6,17 +6,21 @@ package database
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/1f349/tulip/database/types"
|
||||||
|
"github.com/1f349/tulip/password"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ClientStore struct {
|
type ClientStore struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Secret string `json:"secret"`
|
Secret string `json:"secret"`
|
||||||
Domain string `json:"domain"`
|
Domain string `json:"domain"`
|
||||||
Owner string `json:"owner"`
|
Owner string `json:"owner"`
|
||||||
Public sql.NullInt64 `json:"public"`
|
Public bool `json:"public"`
|
||||||
Sso sql.NullInt64 `json:"sso"`
|
Sso bool `json:"sso"`
|
||||||
Active sql.NullInt64 `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Otp struct {
|
type Otp struct {
|
||||||
@ -26,20 +30,20 @@ type Otp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Password string `json:"password"`
|
Password password.HashString `json:"password"`
|
||||||
Picture interface{} `json:"picture"`
|
Picture string `json:"picture"`
|
||||||
Website interface{} `json:"website"`
|
Website string `json:"website"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified int64 `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
Pronouns interface{} `json:"pronouns"`
|
Pronouns types.UserPronoun `json:"pronouns"`
|
||||||
Birthdate sql.NullTime `json:"birthdate"`
|
Birthdate sql.NullTime `json:"birthdate"`
|
||||||
Zoneinfo interface{} `json:"zoneinfo"`
|
Zoneinfo types.UserZone `json:"zoneinfo"`
|
||||||
Locale interface{} `json:"locale"`
|
Locale types.UserLocale `json:"locale"`
|
||||||
Role int64 `json:"role"`
|
Role types.UserRole `json:"role"`
|
||||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Registered sql.NullInt64 `json:"registered"`
|
Registered time.Time `json:"registered"`
|
||||||
Active sql.NullInt64 `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
47
database/password-wrapper.go
Normal file
47
database/password-wrapper.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/1f349/tulip/database/types"
|
||||||
|
"github.com/1f349/tulip/password"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AddUserParams struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
Role types.UserRole `json:"role"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
Active bool `json:"active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Queries) AddUser(ctx context.Context, arg AddUserParams) (string, error) {
|
||||||
|
pwHash, err := password.HashPassword(arg.Password)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
a := addUserParams{
|
||||||
|
Subject: uuid.NewString(),
|
||||||
|
Name: arg.Name,
|
||||||
|
Username: arg.Username,
|
||||||
|
Password: pwHash,
|
||||||
|
Email: arg.Email,
|
||||||
|
EmailVerified: arg.EmailVerified,
|
||||||
|
Role: arg.Role,
|
||||||
|
UpdatedAt: arg.UpdatedAt,
|
||||||
|
Active: arg.Active,
|
||||||
|
}
|
||||||
|
return a.Subject, q.addUser(ctx, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
type CheckLoginRow struct {
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
Password password.HashString `json:"password"`
|
||||||
|
HasTwoFactor bool `json:"hasTwoFactor"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
}
|
@ -1,5 +1,5 @@
|
|||||||
-- name: GetClientInfo :one
|
-- name: GetClientInfo :one
|
||||||
SELECT secret, name, domain, public, sso, active
|
SELECT *
|
||||||
FROM client_store
|
FROM client_store
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
|
@ -13,22 +13,21 @@ WHERE username = ?
|
|||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
|
|
||||||
-- name: GetUser :one
|
-- name: GetUser :one
|
||||||
SELECT name,
|
SELECT *
|
||||||
username,
|
|
||||||
picture,
|
|
||||||
website,
|
|
||||||
email,
|
|
||||||
email_verified,
|
|
||||||
pronouns,
|
|
||||||
birthdate,
|
|
||||||
zoneinfo,
|
|
||||||
locale,
|
|
||||||
updated_at,
|
|
||||||
active
|
|
||||||
FROM users
|
FROM users
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
LIMIT 1;
|
LIMIT 1;
|
||||||
|
|
||||||
|
-- name: GetUserRole :one
|
||||||
|
SELECT role
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
|
-- name: GetUserDisplayName :one
|
||||||
|
SELECT name
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?;
|
||||||
|
|
||||||
-- name: getUserPassword :one
|
-- name: getUserPassword :one
|
||||||
SELECT password
|
SELECT password
|
||||||
FROM users
|
FROM users
|
||||||
@ -68,3 +67,6 @@ WHERE otp.subject = ?;
|
|||||||
SELECT secret, digits
|
SELECT secret, digits
|
||||||
FROM otp
|
FROM otp
|
||||||
WHERE subject = ?;
|
WHERE subject = ?;
|
||||||
|
|
||||||
|
-- name: HasTwoFactor :one
|
||||||
|
SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN);
|
||||||
|
302
database/tx.go
302
database/tx.go
@ -1,302 +0,0 @@
|
|||||||
package database
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"github.com/1f349/tulip/password"
|
|
||||||
"github.com/go-oauth2/oauth2/v4"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func updatedAt() string {
|
|
||||||
return time.Now().UTC().Format(time.DateTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tx struct{ tx *sql.Tx }
|
|
||||||
|
|
||||||
func (t *Tx) Commit() error {
|
|
||||||
return t.tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) Rollback() {
|
|
||||||
_ = t.tx.Rollback()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) HasUser() error {
|
|
||||||
var exists bool
|
|
||||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`)
|
|
||||||
err := row.Scan(&exists)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
return sql.ErrNoRows
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) InsertUser(name, un, pw, email string, verifyEmail bool, role UserRole, active bool) (uuid.UUID, error) {
|
|
||||||
pwHash, err := password.HashPassword(pw)
|
|
||||||
if err != nil {
|
|
||||||
return uuid.UUID{}, err
|
|
||||||
}
|
|
||||||
u := uuid.New()
|
|
||||||
_, err = t.tx.Exec(`INSERT INTO users (subject, name, username, password, email, email_verified, role, updated_at, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, u, name, un, pwHash, email, verifyEmail, role, updatedAt(), active)
|
|
||||||
return u, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) CheckLogin(un, pw string) (*User, bool, bool, error) {
|
|
||||||
var u User
|
|
||||||
var pwHash password.HashString
|
|
||||||
var hasOtp, hasVerify bool
|
|
||||||
row := t.tx.QueryRow(`SELECT subject, password, EXISTS(SELECT 1 FROM otp WHERE otp.subject = users.subject), email, email_verified FROM users WHERE username = ?`, un)
|
|
||||||
err := row.Scan(&u.Sub, &pwHash, &hasOtp, &u.Email, &hasVerify)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, false, err
|
|
||||||
}
|
|
||||||
err = password.CheckPasswordHash(pwHash, pw)
|
|
||||||
return &u, hasOtp, hasVerify, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) GetUserDisplayName(sub string) (*User, error) {
|
|
||||||
var u User
|
|
||||||
row := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub)
|
|
||||||
err := row.Scan(&u.Name)
|
|
||||||
u.Sub = sub
|
|
||||||
return &u, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) GetUserRole(sub string) (UserRole, error) {
|
|
||||||
var r UserRole
|
|
||||||
row := t.tx.QueryRow(`SELECT role FROM users WHERE subject = ? LIMIT 1`, sub)
|
|
||||||
err := row.Scan(&r)
|
|
||||||
return r, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) GetUser(sub string) (*User, error) {
|
|
||||||
var u User
|
|
||||||
row := t.tx.QueryRow(`SELECT name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ?`, sub)
|
|
||||||
err := row.Scan(&u.Name, &u.Username, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active)
|
|
||||||
u.Sub = sub
|
|
||||||
return &u, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) GetUserEmail(sub string) (string, error) {
|
|
||||||
var email string
|
|
||||||
row := t.tx.QueryRow(`SELECT email FROM users WHERE subject = ?`, sub)
|
|
||||||
err := row.Scan(&email)
|
|
||||||
return email, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) ChangeUserPassword(sub, pwOld, pwNew string) error {
|
|
||||||
q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var pwHash password.HashString
|
|
||||||
if q.Next() {
|
|
||||||
err = q.Scan(&pwHash)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("invalid user")
|
|
||||||
}
|
|
||||||
if err := q.Err(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := q.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = password.CheckPasswordHash(pwHash, pwOld)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
pwNewHash, err := password.HashPassword(pwNew)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, updatedAt(), sub, pwHash)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
affected, err := exec.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if affected != 1 {
|
|
||||||
return fmt.Errorf("row wasn't updated")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) ModifyUser(sub string, v *UserPatch) error {
|
|
||||||
exec, err := t.tx.Exec(
|
|
||||||
`UPDATE users
|
|
||||||
SET name = ?,
|
|
||||||
picture = ?,
|
|
||||||
website = ?,
|
|
||||||
pronouns = ?,
|
|
||||||
birthdate = ?,
|
|
||||||
zoneinfo = ?,
|
|
||||||
locale = ?,
|
|
||||||
updated_at = ?
|
|
||||||
WHERE subject = ?`,
|
|
||||||
v.Name,
|
|
||||||
v.Picture,
|
|
||||||
v.Website,
|
|
||||||
v.Pronouns.String(),
|
|
||||||
v.Birthdate,
|
|
||||||
v.ZoneInfo.String(),
|
|
||||||
v.Locale.String(),
|
|
||||||
updatedAt(),
|
|
||||||
sub,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
affected, err := exec.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if affected != 1 {
|
|
||||||
return fmt.Errorf("row wasn't updated")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) SetTwoFactor(sub string, secret string, digits int) error {
|
|
||||||
if secret == "" && digits == 0 {
|
|
||||||
_, err := t.tx.Exec(`DELETE FROM otp WHERE otp.subject = ?`, sub)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err := t.tx.Exec(`INSERT INTO otp(subject, secret, digits) VALUES (?, ?, ?) ON CONFLICT(subject) DO UPDATE SET secret = excluded.secret, digits = excluded.digits`, sub, secret, digits)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) GetTwoFactor(sub string) (string, int, error) {
|
|
||||||
var secret string
|
|
||||||
var digits int
|
|
||||||
row := t.tx.QueryRow(`SELECT secret, digits FROM otp WHERE subject = ?`, sub)
|
|
||||||
err := row.Scan(&secret, &digits)
|
|
||||||
if err != nil {
|
|
||||||
return "", 0, err
|
|
||||||
}
|
|
||||||
return secret, digits, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) HasTwoFactor(sub string) (bool, error) {
|
|
||||||
var hasOtp bool
|
|
||||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM otp WHERE otp.subject = ?)`, sub)
|
|
||||||
err := row.Scan(&hasOtp)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return hasOtp, row.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) {
|
|
||||||
var u ClientInfoDbOutput
|
|
||||||
row := t.tx.QueryRow(`SELECT secret, name, domain, public, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub)
|
|
||||||
err := row.Scan(&u.Secret, &u.Name, &u.Domain, &u.Public, &u.SSO, &u.Active)
|
|
||||||
u.Owner = sub
|
|
||||||
if !u.Active {
|
|
||||||
return nil, fmt.Errorf("client is not active")
|
|
||||||
}
|
|
||||||
return &u, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) GetAppList(owner string, admin bool, offset int) ([]ClientInfoDbOutput, error) {
|
|
||||||
var u []ClientInfoDbOutput
|
|
||||||
row, err := t.tx.Query(`SELECT subject, name, domain, owner, public, sso, active FROM client_store WHERE owner = ? OR ? = 1 LIMIT 25 OFFSET ?`, owner, admin, offset)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer row.Close()
|
|
||||||
for row.Next() {
|
|
||||||
var a ClientInfoDbOutput
|
|
||||||
err := row.Scan(&a.Sub, &a.Name, &a.Domain, &a.Owner, &a.Public, &a.SSO, &a.Active)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
u = append(u, a)
|
|
||||||
}
|
|
||||||
return u, row.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) InsertClientApp(name, domain string, public, sso, active bool, owner string) error {
|
|
||||||
u := uuid.New()
|
|
||||||
secret, err := password.GenerateApiSecret(70)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = t.tx.Exec(`INSERT INTO client_store (subject, name, secret, domain, owner, public, sso, active) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, u.String(), name, secret, domain, owner, public, sso, active)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) UpdateClientApp(subject, owner string, name, domain string, public, sso, active bool) error {
|
|
||||||
_, err := t.tx.Exec(`UPDATE client_store SET name = ?, domain = ?, public = ?, sso = ?, active = ? WHERE subject = ? AND owner = ?`, name, domain, public, sso, active, subject, owner)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) ResetClientAppSecret(subject, owner string) (string, error) {
|
|
||||||
secret, err := password.GenerateApiSecret(70)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
_, err = t.tx.Exec(`UPDATE client_store SET secret = ? WHERE subject = ? AND owner = ?`, secret, subject, owner)
|
|
||||||
return secret, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) GetUserList(offset int) ([]User, error) {
|
|
||||||
var u []User
|
|
||||||
row, err := t.tx.Query(`SELECT subject, name, username, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, active FROM users LIMIT 25 OFFSET ?`, offset)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
for row.Next() {
|
|
||||||
var a User
|
|
||||||
err := row.Scan(&a.Sub, &a.Name, &a.Username, &a.Picture, &a.Website, &a.Email, &a.EmailVerified, &a.Pronouns, &a.Birthdate, &a.ZoneInfo, &a.Locale, &a.Role, &a.UpdatedAt, &a.Active)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
u = append(u, a)
|
|
||||||
}
|
|
||||||
return u, row.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) UpdateUser(subject string, role UserRole, active bool) error {
|
|
||||||
_, err := t.tx.Exec(`UPDATE users SET active = ?, role = ? WHERE subject = ?`, active, role, subject)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) VerifyUserEmail(sub string) error {
|
|
||||||
_, err := t.tx.Exec(`UPDATE users SET email_verified = 1 WHERE subject = ?`, sub)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) UserResetPassword(sub string, pw string) error {
|
|
||||||
hashPassword, err := password.HashPassword(pw)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ?`, hashPassword, updatedAt(), sub)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
affected, err := exec.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if affected != 1 {
|
|
||||||
return fmt.Errorf("row wasn't updated")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Tx) UserEmailExists(email string) (exists bool, err error) {
|
|
||||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users WHERE email = ? and email_verified = 1)`, email)
|
|
||||||
err = row.Scan(&exists)
|
|
||||||
return
|
|
||||||
}
|
|
@ -1,52 +0,0 @@
|
|||||||
package database
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/1f349/tulip/password"
|
|
||||||
"github.com/MrMelon54/pronouns"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"golang.org/x/text/language"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestTx_ChangeUserPassword(t *testing.T) {
|
|
||||||
u := uuid.New()
|
|
||||||
pw, err := password.HashPassword("test")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
d, err := Open("file::memory:")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
tx, err := d.Begin()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = tx.ChangeUserPassword(u.String(), "test", "new")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NoError(t, tx.Commit())
|
|
||||||
query, err := d.db.Query(`SELECT password FROM users WHERE subject = ? AND username = ?`, u.String(), "test")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.True(t, query.Next())
|
|
||||||
var oldPw password.HashString
|
|
||||||
assert.NoError(t, query.Scan(&oldPw))
|
|
||||||
assert.NoError(t, password.CheckPasswordHash(oldPw, "new"))
|
|
||||||
assert.NoError(t, query.Err())
|
|
||||||
assert.NoError(t, query.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTx_ModifyUser(t *testing.T) {
|
|
||||||
u := uuid.New()
|
|
||||||
pw, err := password.HashPassword("test")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
d, err := Open("file::memory:")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost", updatedAt())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
tx, err := d.Begin()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NoError(t, tx.ModifyUser(u.String(), &UserPatch{
|
|
||||||
Name: "example",
|
|
||||||
Pronouns: pronouns.TheyThem,
|
|
||||||
ZoneInfo: time.UTC,
|
|
||||||
Locale: language.AmericanEnglish,
|
|
||||||
}))
|
|
||||||
}
|
|
38
database/types/userlocale.go
Normal file
38
database/types/userlocale.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ sql.Scanner = &UserLocale{}
|
||||||
|
_ json.Marshaler = &UserLocale{}
|
||||||
|
_ json.Unmarshaler = &UserLocale{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserLocale struct{ language.Tag }
|
||||||
|
|
||||||
|
func (l *UserLocale) Scan(src any) error {
|
||||||
|
s, ok := src.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
||||||
|
}
|
||||||
|
lang, err := language.Parse(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
l.Tag = lang
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (l UserLocale) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
|
||||||
|
func (l *UserLocale) UnmarshalJSON(bytes []byte) error {
|
||||||
|
var a string
|
||||||
|
err := json.Unmarshal(bytes, &a)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return l.Scan(a)
|
||||||
|
}
|
12
database/types/userlocale_test.go
Normal file
12
database/types/userlocale_test.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserLocale_MarshalJSON(t *testing.T) {
|
||||||
|
assert.Equal(t, "\"en-US\"", encode(UserLocale{language.AmericanEnglish}))
|
||||||
|
assert.Equal(t, "\"en-GB\"", encode(UserLocale{language.BritishEnglish}))
|
||||||
|
}
|
38
database/types/userpronoun.go
Normal file
38
database/types/userpronoun.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/MrMelon54/pronouns"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ sql.Scanner = &UserPronoun{}
|
||||||
|
_ json.Marshaler = &UserPronoun{}
|
||||||
|
_ json.Unmarshaler = &UserPronoun{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserPronoun struct{ pronouns.Pronoun }
|
||||||
|
|
||||||
|
func (p *UserPronoun) Scan(src any) error {
|
||||||
|
s, ok := src.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
|
||||||
|
}
|
||||||
|
pro, err := pronouns.FindPronoun(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.Pronoun = pro
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (p UserPronoun) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
|
||||||
|
func (p *UserPronoun) UnmarshalJSON(bytes []byte) error {
|
||||||
|
var a string
|
||||||
|
err := json.Unmarshal(bytes, &a)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return p.Scan(a)
|
||||||
|
}
|
15
database/types/userpronoun_test.go
Normal file
15
database/types/userpronoun_test.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/MrMelon54/pronouns"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserPronoun_MarshalJSON(t *testing.T) {
|
||||||
|
assert.Equal(t, "\"they/them\"", encode(UserPronoun{pronouns.TheyThem}))
|
||||||
|
assert.Equal(t, "\"he/him\"", encode(UserPronoun{pronouns.HeHim}))
|
||||||
|
assert.Equal(t, "\"she/her\"", encode(UserPronoun{pronouns.SheHer}))
|
||||||
|
assert.Equal(t, "\"it/its\"", encode(UserPronoun{pronouns.ItIts}))
|
||||||
|
assert.Equal(t, "\"one/one's\"", encode(UserPronoun{pronouns.OneOnes}))
|
||||||
|
}
|
27
database/types/userrole.go
Normal file
27
database/types/userrole.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type UserRole int64
|
||||||
|
|
||||||
|
const (
|
||||||
|
RoleMember UserRole = iota
|
||||||
|
RoleAdmin
|
||||||
|
RoleToDelete
|
||||||
|
)
|
||||||
|
|
||||||
|
func (r UserRole) String() string {
|
||||||
|
switch r {
|
||||||
|
case RoleMember:
|
||||||
|
return "Member"
|
||||||
|
case RoleAdmin:
|
||||||
|
return "Admin"
|
||||||
|
case RoleToDelete:
|
||||||
|
return "ToDelete"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("UserRole{ %d }", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r UserRole) IsValid() bool {
|
||||||
|
return r == RoleMember || r == RoleAdmin
|
||||||
|
}
|
45
database/types/userzone.go
Normal file
45
database/types/userzone.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ sql.Scanner = &UserZone{}
|
||||||
|
_ driver.Valuer = &UserZone{}
|
||||||
|
_ json.Marshaler = &UserZone{}
|
||||||
|
_ json.Unmarshaler = &UserZone{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserZone struct{ *time.Location }
|
||||||
|
|
||||||
|
func (l *UserZone) Scan(src any) error {
|
||||||
|
s, ok := src.(string)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
||||||
|
}
|
||||||
|
loc, err := time.LoadLocation(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
l.Location = loc
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (l UserZone) Value() (driver.Value, error) {
|
||||||
|
return l.Location.String(), nil
|
||||||
|
}
|
||||||
|
func (l UserZone) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(l.Location.String())
|
||||||
|
}
|
||||||
|
func (l *UserZone) UnmarshalJSON(bytes []byte) error {
|
||||||
|
var a string
|
||||||
|
err := json.Unmarshal(bytes, &a)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return l.Scan(a)
|
||||||
|
}
|
14
database/types/userzone_test.go
Normal file
14
database/types/userzone_test.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserZone_MarshalJSON(t *testing.T) {
|
||||||
|
location, err := time.LoadLocation("Europe/London")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "\"Europe/London\"", encode(UserZone{location}))
|
||||||
|
assert.Equal(t, "\"UTC\"", encode(UserZone{time.UTC}))
|
||||||
|
}
|
11
database/types/utils_test.go
Normal file
11
database/types/utils_test.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
func encode(data any) string {
|
||||||
|
j, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return string(j)
|
||||||
|
}
|
@ -8,6 +8,10 @@ package database
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/1f349/tulip/database/types"
|
||||||
|
"github.com/1f349/tulip/password"
|
||||||
)
|
)
|
||||||
|
|
||||||
const deleteTwoFactor = `-- name: DeleteTwoFactor :exec
|
const deleteTwoFactor = `-- name: DeleteTwoFactor :exec
|
||||||
@ -40,44 +44,20 @@ func (q *Queries) GetTwoFactor(ctx context.Context, subject string) (GetTwoFacto
|
|||||||
}
|
}
|
||||||
|
|
||||||
const getUser = `-- name: GetUser :one
|
const getUser = `-- name: GetUser :one
|
||||||
SELECT name,
|
SELECT subject, name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, role, updated_at, registered, active
|
||||||
username,
|
|
||||||
picture,
|
|
||||||
website,
|
|
||||||
email,
|
|
||||||
email_verified,
|
|
||||||
pronouns,
|
|
||||||
birthdate,
|
|
||||||
zoneinfo,
|
|
||||||
locale,
|
|
||||||
updated_at,
|
|
||||||
active
|
|
||||||
FROM users
|
FROM users
|
||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
`
|
`
|
||||||
|
|
||||||
type GetUserRow struct {
|
func (q *Queries) GetUser(ctx context.Context, subject string) (User, error) {
|
||||||
Name string `json:"name"`
|
|
||||||
Username string `json:"username"`
|
|
||||||
Picture interface{} `json:"picture"`
|
|
||||||
Website interface{} `json:"website"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
EmailVerified int64 `json:"email_verified"`
|
|
||||||
Pronouns interface{} `json:"pronouns"`
|
|
||||||
Birthdate sql.NullTime `json:"birthdate"`
|
|
||||||
Zoneinfo interface{} `json:"zoneinfo"`
|
|
||||||
Locale interface{} `json:"locale"`
|
|
||||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
|
||||||
Active sql.NullInt64 `json:"active"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, error) {
|
|
||||||
row := q.db.QueryRowContext(ctx, getUser, subject)
|
row := q.db.QueryRowContext(ctx, getUser, subject)
|
||||||
var i GetUserRow
|
var i User
|
||||||
err := row.Scan(
|
err := row.Scan(
|
||||||
|
&i.Subject,
|
||||||
&i.Name,
|
&i.Name,
|
||||||
&i.Username,
|
&i.Username,
|
||||||
|
&i.Password,
|
||||||
&i.Picture,
|
&i.Picture,
|
||||||
&i.Website,
|
&i.Website,
|
||||||
&i.Email,
|
&i.Email,
|
||||||
@ -86,12 +66,51 @@ func (q *Queries) GetUser(ctx context.Context, subject string) (GetUserRow, erro
|
|||||||
&i.Birthdate,
|
&i.Birthdate,
|
||||||
&i.Zoneinfo,
|
&i.Zoneinfo,
|
||||||
&i.Locale,
|
&i.Locale,
|
||||||
|
&i.Role,
|
||||||
&i.UpdatedAt,
|
&i.UpdatedAt,
|
||||||
|
&i.Registered,
|
||||||
&i.Active,
|
&i.Active,
|
||||||
)
|
)
|
||||||
return i, err
|
return i, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const getUserDisplayName = `-- name: GetUserDisplayName :one
|
||||||
|
SELECT name
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetUserDisplayName(ctx context.Context, subject string) (string, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getUserDisplayName, subject)
|
||||||
|
var name string
|
||||||
|
err := row.Scan(&name)
|
||||||
|
return name, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const getUserRole = `-- name: GetUserRole :one
|
||||||
|
SELECT role
|
||||||
|
FROM users
|
||||||
|
WHERE subject = ?
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) GetUserRole(ctx context.Context, subject string) (types.UserRole, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, getUserRole, subject)
|
||||||
|
var role types.UserRole
|
||||||
|
err := row.Scan(&role)
|
||||||
|
return role, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const hasTwoFactor = `-- name: HasTwoFactor :one
|
||||||
|
SELECT cast(EXISTS(SELECT 1 FROM otp WHERE subject = ?) AS BOOLEAN)
|
||||||
|
`
|
||||||
|
|
||||||
|
func (q *Queries) HasTwoFactor(ctx context.Context, subject string) (bool, error) {
|
||||||
|
row := q.db.QueryRowContext(ctx, hasTwoFactor, subject)
|
||||||
|
var column_1 bool
|
||||||
|
err := row.Scan(&column_1)
|
||||||
|
return column_1, err
|
||||||
|
}
|
||||||
|
|
||||||
const hasUser = `-- name: HasUser :one
|
const hasUser = `-- name: HasUser :one
|
||||||
SELECT cast(count(subject) AS BOOLEAN) AS hasUser
|
SELECT cast(count(subject) AS BOOLEAN) AS hasUser
|
||||||
FROM users
|
FROM users
|
||||||
@ -118,15 +137,15 @@ WHERE subject = ?
|
|||||||
`
|
`
|
||||||
|
|
||||||
type ModifyUserParams struct {
|
type ModifyUserParams struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Picture interface{} `json:"picture"`
|
Picture string `json:"picture"`
|
||||||
Website interface{} `json:"website"`
|
Website string `json:"website"`
|
||||||
Pronouns interface{} `json:"pronouns"`
|
Pronouns types.UserPronoun `json:"pronouns"`
|
||||||
Birthdate sql.NullTime `json:"birthdate"`
|
Birthdate sql.NullTime `json:"birthdate"`
|
||||||
Zoneinfo interface{} `json:"zoneinfo"`
|
Zoneinfo types.UserZone `json:"zoneinfo"`
|
||||||
Locale interface{} `json:"locale"`
|
Locale types.UserLocale `json:"locale"`
|
||||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) ModifyUser(ctx context.Context, arg ModifyUserParams) (int64, error) {
|
func (q *Queries) ModifyUser(ctx context.Context, arg ModifyUserParams) (int64, error) {
|
||||||
@ -171,15 +190,15 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|||||||
`
|
`
|
||||||
|
|
||||||
type addUserParams struct {
|
type addUserParams struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Password string `json:"password"`
|
Password password.HashString `json:"password"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified int64 `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
Role int64 `json:"role"`
|
Role types.UserRole `json:"role"`
|
||||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Active sql.NullInt64 `json:"active"`
|
Active bool `json:"active"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
|
func (q *Queries) addUser(ctx context.Context, arg addUserParams) error {
|
||||||
@ -206,10 +225,10 @@ WHERE subject = ?
|
|||||||
`
|
`
|
||||||
|
|
||||||
type changeUserPasswordParams struct {
|
type changeUserPasswordParams struct {
|
||||||
Password string `json:"password"`
|
Password password.HashString `json:"password"`
|
||||||
UpdatedAt sql.NullTime `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Password_2 string `json:"password_2"`
|
Password_2 password.HashString `json:"password_2"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPasswordParams) (int64, error) {
|
func (q *Queries) changeUserPassword(ctx context.Context, arg changeUserPasswordParams) (int64, error) {
|
||||||
@ -233,11 +252,11 @@ LIMIT 1
|
|||||||
`
|
`
|
||||||
|
|
||||||
type checkLoginRow struct {
|
type checkLoginRow struct {
|
||||||
Subject string `json:"subject"`
|
Subject string `json:"subject"`
|
||||||
Password string `json:"password"`
|
Password password.HashString `json:"password"`
|
||||||
Column3 int64 `json:"column_3"`
|
Column3 int64 `json:"column_3"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
EmailVerified int64 `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) checkLogin(ctx context.Context, username string) (checkLoginRow, error) {
|
func (q *Queries) checkLogin(ctx context.Context, username string) (checkLoginRow, error) {
|
||||||
@ -259,9 +278,9 @@ FROM users
|
|||||||
WHERE subject = ?
|
WHERE subject = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
func (q *Queries) getUserPassword(ctx context.Context, subject string) (string, error) {
|
func (q *Queries) getUserPassword(ctx context.Context, subject string) (password.HashString, error) {
|
||||||
row := q.db.QueryRowContext(ctx, getUserPassword, subject)
|
row := q.db.QueryRowContext(ctx, getUserPassword, subject)
|
||||||
var password string
|
var password password.HashString
|
||||||
err := row.Scan(&password)
|
err := row.Scan(&password)
|
||||||
return password, err
|
return password, err
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/1f349/mjwt/auth"
|
"github.com/1f349/mjwt/auth"
|
||||||
"github.com/1f349/tulip/database"
|
"github.com/1f349/tulip/database"
|
||||||
|
"github.com/1f349/tulip/database/types"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -30,14 +31,14 @@ func (u UserAuth) IsGuest() bool {
|
|||||||
|
|
||||||
func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle {
|
func (h *HttpServer) RequireAdminAuthentication(next UserHandler) httprouter.Handle {
|
||||||
return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
return h.RequireAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||||
var role database.UserRole
|
var role types.UserRole
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
role, err = tx.GetUserRole(auth.ID)
|
role, err = tx.GetUserRole(req.Context(), auth.ID)
|
||||||
return
|
return
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if role != database.RoleAdmin {
|
if role != types.RoleAdmin {
|
||||||
http.Error(rw, "403 Forbidden", http.StatusForbidden)
|
http.Error(rw, "403 Forbidden", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
16
server/db.go
16
server/db.go
@ -9,25 +9,13 @@ import (
|
|||||||
// DbTx wraps a database transaction with http error messages and a simple action
|
// DbTx wraps a database transaction with http error messages and a simple action
|
||||||
// function. If the action function returns an error the transaction will be
|
// function. If the action function returns an error the transaction will be
|
||||||
// rolled back. If there is no error then the transaction is committed.
|
// rolled back. If there is no error then the transaction is committed.
|
||||||
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool {
|
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(db *database.Queries) error) bool {
|
||||||
tx, err := h.db.Begin()
|
err := action(h.db)
|
||||||
if err != nil {
|
|
||||||
http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
err = action(tx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, "Database error", http.StatusInternalServerError)
|
http.Error(rw, "Database error", http.StatusInternalServerError)
|
||||||
log.Println("Database action error:", err)
|
log.Println("Database action error:", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
http.Error(rw, "Database error", http.StatusInternalServerError)
|
|
||||||
log.Println("Database commit error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -11,12 +11,12 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *HttpServer) EditGet(rw http.ResponseWriter, _ *http.Request, _ httprouter.Params, auth UserAuth) {
|
func (h *HttpServer) EditGet(rw http.ResponseWriter, req *http.Request, _ httprouter.Params, auth UserAuth) {
|
||||||
var user *database.User
|
var user database.User
|
||||||
|
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
var err error
|
var err error
|
||||||
user, err = tx.GetUser(auth.ID)
|
user, err = tx.GetUser(req.Context(), auth.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to read user data: %w", err)
|
return fmt.Errorf("failed to read user data: %w", err)
|
||||||
}
|
}
|
||||||
@ -64,8 +64,19 @@ func (h *HttpServer) EditPost(rw http.ResponseWriter, req *http.Request, _ httpr
|
|||||||
_, _ = fmt.Fprintln(rw, "</body>\n</html>")
|
_, _ = fmt.Fprintln(rw, "</body>\n</html>")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
m := database.ModifyUserParams{
|
||||||
if err := tx.ModifyUser(auth.ID, &patch); err != nil {
|
Name: patch.Name,
|
||||||
|
Picture: patch.Picture,
|
||||||
|
Website: patch.Website,
|
||||||
|
Pronouns: patch.Pronouns,
|
||||||
|
Birthdate: patch.Birthdate,
|
||||||
|
Zoneinfo: patch.ZoneInfo,
|
||||||
|
Locale: patch.Locale,
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
Subject: auth.ID,
|
||||||
|
}
|
||||||
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
|
if _, err := tx.ModifyUser(req.Context(), m); err != nil {
|
||||||
return fmt.Errorf("failed to modify user info: %w", err)
|
return fmt.Errorf("failed to modify user info: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/1f349/tulip/database"
|
"github.com/1f349/tulip/database"
|
||||||
|
"github.com/1f349/tulip/database/types"
|
||||||
"github.com/1f349/tulip/pages"
|
"github.com/1f349/tulip/pages"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
@ -29,18 +30,19 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var userWithName *database.User
|
var userWithName string
|
||||||
|
var userRole types.UserRole
|
||||||
var hasTwoFactor bool
|
var hasTwoFactor bool
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
userWithName, err = tx.GetUserDisplayName(auth.ID)
|
userWithName, err = tx.GetUserDisplayName(req.Context(), auth.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get user display name: %w", err)
|
return fmt.Errorf("failed to get user display name: %w", err)
|
||||||
}
|
}
|
||||||
hasTwoFactor, err = tx.HasTwoFactor(auth.ID)
|
hasTwoFactor, err = tx.HasTwoFactor(req.Context(), auth.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get user two factor state: %w", err)
|
return fmt.Errorf("failed to get user two factor state: %w", err)
|
||||||
}
|
}
|
||||||
userWithName.Role, err = tx.GetUserRole(auth.ID)
|
userRole, err = tx.GetUserRole(req.Context(), auth.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get user role: %w", err)
|
return fmt.Errorf("failed to get user role: %w", err)
|
||||||
}
|
}
|
||||||
@ -54,6 +56,6 @@ func (h *HttpServer) Home(rw http.ResponseWriter, req *http.Request, _ httproute
|
|||||||
"User": userWithName,
|
"User": userWithName,
|
||||||
"Nonce": lNonce,
|
"Nonce": lNonce,
|
||||||
"OtpEnabled": hasTwoFactor,
|
"OtpEnabled": hasTwoFactor,
|
||||||
"IsAdmin": userWithName.Role,
|
"IsAdmin": userRole == types.RoleAdmin,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"github.com/1f349/mjwt"
|
"github.com/1f349/mjwt"
|
||||||
"github.com/1f349/tulip/database"
|
"github.com/1f349/tulip/database"
|
||||||
"github.com/go-oauth2/oauth2/v4"
|
"github.com/go-oauth2/oauth2/v4"
|
||||||
@ -9,7 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func addIdTokenSupport(srv *server.Server, db *database.DB, key mjwt.Signer) {
|
func addIdTokenSupport(srv *server.Server, db *database.Queries, key mjwt.Signer) {
|
||||||
srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) {
|
srv.SetExtensionFieldsHandler(func(ti oauth2.TokenInfo) (fieldsValue map[string]interface{}) {
|
||||||
scope := ti.GetScope()
|
scope := ti.GetScope()
|
||||||
if containsScope(scope, "openid") {
|
if containsScope(scope, "openid") {
|
||||||
@ -30,18 +31,13 @@ type IdTokenClaims struct{}
|
|||||||
func (a IdTokenClaims) Valid() error { return nil }
|
func (a IdTokenClaims) Valid() error { return nil }
|
||||||
func (a IdTokenClaims) Type() string { return "access-token" }
|
func (a IdTokenClaims) Type() string { return "access-token" }
|
||||||
|
|
||||||
func generateIDToken(ti oauth2.TokenInfo, us *database.DB, key mjwt.Signer) (token string, err error) {
|
func generateIDToken(ti oauth2.TokenInfo, us *database.Queries, key mjwt.Signer) (token string, err error) {
|
||||||
tx, err := us.Begin()
|
user, err := us.GetUser(context.Background(), ti.GetUserID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
user, err := tx.GetUser(ti.GetUserID())
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
tx.Rollback()
|
|
||||||
|
|
||||||
token, err = key.GenerateJwt(user.Sub, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{})
|
token, err = key.GenerateJwt(user.Subject, "", jwt.ClaimStrings{ti.GetClientID()}, ti.GetAccessExpiresIn(), IdTokenClaims{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ func (h *HttpServer) LoginPost(rw http.ResponseWriter, req *http.Request, _ http
|
|||||||
var loginMismatch byte
|
var loginMismatch byte
|
||||||
var hasOtp bool
|
var hasOtp bool
|
||||||
|
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
loginUser, hasOtpRaw, hasVerifiedEmail, err := tx.CheckLogin(un, pw)
|
loginUser, hasOtpRaw, hasVerifiedEmail, err := tx.CheckLogin(un, pw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
if errors.Is(err, sql.ErrNoRows) || errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||||
@ -176,7 +176,7 @@ func (h *HttpServer) LoginResetPasswordPost(rw http.ResponseWriter, req *http.Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
var emailExists bool
|
var emailExists bool
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
emailExists, err = tx.UserEmailExists(email)
|
emailExists, err = tx.UserEmailExists(email)
|
||||||
return err
|
return err
|
||||||
}) {
|
}) {
|
||||||
|
@ -18,7 +18,7 @@ func (h *HttpServer) MailVerify(rw http.ResponseWriter, _ *http.Request, params
|
|||||||
http.Error(rw, "Invalid email verification code", http.StatusBadRequest)
|
http.Error(rw, "Invalid email verification code", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
return tx.VerifyUserEmail(userSub)
|
return tx.VerifyUserEmail(userSub)
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
@ -75,7 +75,7 @@ func (h *HttpServer) MailPasswordPost(rw http.ResponseWriter, req *http.Request,
|
|||||||
h.mailLinkCache.Delete(k)
|
h.mailLinkCache.Delete(k)
|
||||||
|
|
||||||
// reset password database call
|
// reset password database call
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
return tx.UserResetPassword(userSub, pw)
|
return tx.UserResetPassword(userSub, pw)
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
@ -94,12 +94,12 @@ func (h *HttpServer) MailDelete(rw http.ResponseWriter, _ *http.Request, params
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var userInfo *database.User
|
var userInfo *database.User
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
userInfo, err = tx.GetUser(userSub)
|
userInfo, err = tx.GetUser(userSub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return tx.UpdateUser(userSub, database.RoleToDelete, false)
|
return tx.UpdateUser(userSub, types.RoleToDelete, false)
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -22,14 +22,14 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var role database.UserRole
|
var role types.UserRole
|
||||||
var appList []database.ClientInfoDbOutput
|
var appList []database.ClientStore
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
role, err = tx.GetUserRole(auth.ID)
|
role, err = tx.GetUserRole(auth.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
appList, err = tx.GetAppList(auth.ID, role == database.RoleAdmin, offset)
|
appList, err = tx.GetAppList(auth.ID, role == types.RoleAdmin, offset)
|
||||||
return
|
return
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
@ -39,7 +39,7 @@ func (h *HttpServer) ManageAppsGet(rw http.ResponseWriter, req *http.Request, _
|
|||||||
"ServiceName": h.conf.ServiceName,
|
"ServiceName": h.conf.ServiceName,
|
||||||
"Apps": appList,
|
"Apps": appList,
|
||||||
"Offset": offset,
|
"Offset": offset,
|
||||||
"IsAdmin": role == database.RoleAdmin,
|
"IsAdmin": role == types.RoleAdmin,
|
||||||
"NewAppName": q.Get("NewAppName"),
|
"NewAppName": q.Get("NewAppName"),
|
||||||
"NewAppSecret": q.Get("NewAppSecret"),
|
"NewAppSecret": q.Get("NewAppSecret"),
|
||||||
}
|
}
|
||||||
@ -76,14 +76,14 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
|||||||
active := req.Form.Has("active")
|
active := req.Form.Has("active")
|
||||||
|
|
||||||
if sso {
|
if sso {
|
||||||
var role database.UserRole
|
var role types.UserRole
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
role, err = tx.GetUserRole(auth.ID)
|
role, err = tx.GetUserRole(auth.ID)
|
||||||
return
|
return
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if role != database.RoleAdmin {
|
if role != types.RoleAdmin {
|
||||||
http.Error(rw, "400 Bad Request: Only admin users can create SSO client applications", http.StatusBadRequest)
|
http.Error(rw, "400 Bad Request: Only admin users can create SSO client applications", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -91,13 +91,13 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
|||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case "create":
|
case "create":
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
return tx.InsertClientApp(name, domain, public, sso, active, auth.ID)
|
return tx.InsertClientApp(name, domain, public, sso, active, auth.ID)
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "edit":
|
case "edit":
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
return tx.UpdateClientApp(req.Form.Get("subject"), auth.ID, name, domain, public, sso, active)
|
return tx.UpdateClientApp(req.Form.Get("subject"), auth.ID, name, domain, public, sso, active)
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
@ -105,7 +105,7 @@ func (h *HttpServer) ManageAppsPost(rw http.ResponseWriter, req *http.Request, _
|
|||||||
case "secret":
|
case "secret":
|
||||||
var info oauth2.ClientInfo
|
var info oauth2.ClientInfo
|
||||||
var secret string
|
var secret string
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
sub := req.Form.Get("subject")
|
sub := req.Form.Get("subject")
|
||||||
info, err = tx.GetClientInfo(sub)
|
info, err = tx.GetClientInfo(sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/1f349/tulip/database"
|
"github.com/1f349/tulip/database"
|
||||||
|
"github.com/1f349/tulip/database/types"
|
||||||
"github.com/1f349/tulip/pages"
|
"github.com/1f349/tulip/pages"
|
||||||
"github.com/emersion/go-message/mail"
|
"github.com/emersion/go-message/mail"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@ -27,9 +28,9 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var role database.UserRole
|
var role types.UserRole
|
||||||
var userList []database.User
|
var userList []database.User
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
role, err = tx.GetUserRole(auth.ID)
|
role, err = tx.GetUserRole(auth.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -39,7 +40,7 @@ func (h *HttpServer) ManageUsersGet(rw http.ResponseWriter, req *http.Request, _
|
|||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if role != database.RoleAdmin {
|
if role != types.RoleAdmin {
|
||||||
http.Error(rw, "403 Forbidden", http.StatusForbidden)
|
http.Error(rw, "403 Forbidden", http.StatusForbidden)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -76,14 +77,14 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var role database.UserRole
|
var role types.UserRole
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
role, err = tx.GetUserRole(auth.ID)
|
role, err = tx.GetUserRole(auth.ID)
|
||||||
return
|
return
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if role != database.RoleAdmin {
|
if role != types.RoleAdmin {
|
||||||
http.Error(rw, "400 Bad Request: Only admin users can manage users", http.StatusBadRequest)
|
http.Error(rw, "400 Bad Request: Only admin users can manage users", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -116,7 +117,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
|||||||
addrDomain := address.Address[n+1:]
|
addrDomain := address.Address[n+1:]
|
||||||
|
|
||||||
var userSub uuid.UUID
|
var userSub uuid.UUID
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
userSub, err = tx.InsertUser(name, username, "", email, addrDomain == h.conf.Namespace, newRole, active)
|
userSub, err = tx.InsertUser(name, username, "", email, addrDomain == h.conf.Namespace, newRole, active)
|
||||||
return err
|
return err
|
||||||
}) {
|
}) {
|
||||||
@ -136,7 +137,7 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "edit":
|
case "edit":
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
sub := req.Form.Get("subject")
|
sub := req.Form.Get("subject")
|
||||||
return tx.UpdateUser(sub, newRole, active)
|
return tx.UpdateUser(sub, newRole, active)
|
||||||
}) {
|
}) {
|
||||||
@ -151,12 +152,12 @@ func (h *HttpServer) ManageUsersPost(rw http.ResponseWriter, req *http.Request,
|
|||||||
http.Redirect(rw, req, redirectUrl.String(), http.StatusFound)
|
http.Redirect(rw, req, redirectUrl.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRoleValue(role string) (database.UserRole, error) {
|
func parseRoleValue(role string) (types.UserRole, error) {
|
||||||
switch role {
|
switch role {
|
||||||
case "member":
|
case "member":
|
||||||
return database.RoleMember, nil
|
return types.RoleMember, nil
|
||||||
case "admin":
|
case "admin":
|
||||||
return database.RoleAdmin, nil
|
return types.RoleAdmin, nil
|
||||||
}
|
}
|
||||||
return 0, errors.New("invalid role value")
|
return 0, errors.New("invalid role value")
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,7 @@ func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request
|
|||||||
|
|
||||||
var user *database.User
|
var user *database.User
|
||||||
var hasOtp bool
|
var hasOtp bool
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
user, err = tx.GetUserDisplayName(auth.ID)
|
user, err = tx.GetUserDisplayName(auth.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
|
@ -47,7 +47,7 @@ func (h *HttpServer) fetchAndValidateOtp(rw http.ResponseWriter, sub, code strin
|
|||||||
var hasOtp bool
|
var hasOtp bool
|
||||||
var secret string
|
var secret string
|
||||||
var digits int
|
var digits int
|
||||||
if h.DbTx(rw, func(tx *database.Tx) (err error) {
|
if h.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
hasOtp, err = tx.HasTwoFactor(sub)
|
hasOtp, err = tx.HasTwoFactor(sub)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
@ -86,7 +86,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
return tx.SetTwoFactor(auth.ID, "", 0)
|
return tx.SetTwoFactor(auth.ID, "", 0)
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
@ -118,7 +118,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
|
|||||||
if secret == "" {
|
if secret == "" {
|
||||||
// get user email
|
// get user email
|
||||||
var email string
|
var email string
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
var err error
|
var err error
|
||||||
email, err = tx.GetUserEmail(auth.ID)
|
email, err = tx.GetUserEmail(auth.ID)
|
||||||
return err
|
return err
|
||||||
@ -167,7 +167,7 @@ func (h *HttpServer) EditOtpPost(rw http.ResponseWriter, req *http.Request, _ ht
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.DbTx(rw, func(tx *database.Tx) error {
|
if h.DbTx(rw, func(tx *database.Queries) error {
|
||||||
return tx.SetTwoFactor(auth.ID, secret, digits)
|
return tx.SetTwoFactor(auth.ID, secret, digits)
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
|
@ -31,7 +31,7 @@ type HttpServer struct {
|
|||||||
r *httprouter.Router
|
r *httprouter.Router
|
||||||
oauthSrv *server.Server
|
oauthSrv *server.Server
|
||||||
oauthMgr *manage.Manager
|
oauthMgr *manage.Manager
|
||||||
db *database.DB
|
db *database.Queries
|
||||||
conf Conf
|
conf Conf
|
||||||
signingKey mjwt.Signer
|
signingKey mjwt.Signer
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ type mailLinkKey struct {
|
|||||||
data string
|
data string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Server {
|
func NewHttpServer(conf Conf, db *database.Queries, signingKey mjwt.Signer) *http.Server {
|
||||||
r := httprouter.New()
|
r := httprouter.New()
|
||||||
|
|
||||||
// remove last slash from baseUrl
|
// remove last slash from baseUrl
|
||||||
@ -191,10 +191,10 @@ func NewHttpServer(conf Conf, db *database.DB, signingKey mjwt.Signer) *http.Ser
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var userData *database.User
|
var userData database.GetUserRow
|
||||||
|
|
||||||
if hs.DbTx(rw, func(tx *database.Tx) (err error) {
|
if hs.DbTx(rw, func(tx *database.Queries) (err error) {
|
||||||
userData, err = tx.GetUser(userId)
|
userData, err = tx.GetUser(req.Context(), userId)
|
||||||
return err
|
return err
|
||||||
}) {
|
}) {
|
||||||
return
|
return
|
||||||
|
14
sqlc.yaml
14
sqlc.yaml
@ -9,7 +9,13 @@ sql:
|
|||||||
out: "database"
|
out: "database"
|
||||||
emit_json_tags: true
|
emit_json_tags: true
|
||||||
overrides:
|
overrides:
|
||||||
- column: "routes.flags"
|
- column: "users.password"
|
||||||
go_type: "github.com/1f349/violet/target.Flags"
|
go_type: "github.com/1f349/tulip/password.HashString"
|
||||||
- column: "redirects.flags"
|
- column: "users.role"
|
||||||
go_type: "github.com/1f349/violet/target.Flags"
|
go_type: "github.com/1f349/tulip/database/types.UserRole"
|
||||||
|
- column: "users.pronouns"
|
||||||
|
go_type: "github.com/1f349/tulip/database/types.UserPronoun"
|
||||||
|
- column: "users.zoneinfo"
|
||||||
|
go_type: "github.com/1f349/tulip/database/types.UserZone"
|
||||||
|
- column: "users.locale"
|
||||||
|
go_type: "github.com/1f349/tulip/database/types.UserLocale"
|
||||||
|
Loading…
Reference in New Issue
Block a user