diff --git a/client-store/client-store.go b/client-store/client-store.go new file mode 100644 index 0000000..a57e841 --- /dev/null +++ b/client-store/client-store.go @@ -0,0 +1,26 @@ +package client_store + +import ( + "context" + "github.com/1f349/tulip/database" + "github.com/go-oauth2/oauth2/v4" +) + +type ClientStore struct { + db *database.DB +} + +var _ oauth2.ClientStore = &ClientStore{} + +func New(db *database.DB) *ClientStore { + return &ClientStore{db: db} +} + +func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) { + tx, err := c.db.BeginCtx(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback() + return tx.GetClientInfo(id) +} diff --git a/cmd/tulip/conf.go b/cmd/tulip/conf.go new file mode 100644 index 0000000..e2e7572 --- /dev/null +++ b/cmd/tulip/conf.go @@ -0,0 +1,6 @@ +package main + +type startUpConfig struct { + Listen string `json:"listen"` + Domain string `json:"domain"` +} diff --git a/cmd/tulip/main.go b/cmd/tulip/main.go new file mode 100644 index 0000000..f0b4be5 --- /dev/null +++ b/cmd/tulip/main.go @@ -0,0 +1,19 @@ +package main + +import ( + "context" + "flag" + "github.com/google/subcommands" + "os" +) + +func main() { + subcommands.Register(subcommands.HelpCommand(), "") + subcommands.Register(subcommands.FlagsCommand(), "") + subcommands.Register(subcommands.CommandsCommand(), "") + subcommands.Register(&serveCmd{}, "") + + flag.Parse() + ctx := context.Background() + os.Exit(int(subcommands.Execute(ctx))) +} diff --git a/cmd/tulip/serve.go b/cmd/tulip/serve.go new file mode 100644 index 0000000..8fcf189 --- /dev/null +++ b/cmd/tulip/serve.go @@ -0,0 +1,131 @@ +package main + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/json" + "errors" + "flag" + "fmt" + clientStore "github.com/1f349/tulip/client-store" + "github.com/1f349/tulip/database" + "github.com/1f349/tulip/server" + "github.com/1f349/violet/utils" + "github.com/MrMelon54/exit-reload" + "github.com/google/subcommands" + _ "github.com/mattn/go-sqlite3" + "log" + "os" + "path/filepath" +) + +type serveCmd struct{ configPath string } + +func (s *serveCmd) Name() string { return "serve" } + +func (s *serveCmd) Synopsis() string { return "Serve user authentication service" } + +func (s *serveCmd) SetFlags(f *flag.FlagSet) { + f.StringVar(&s.configPath, "conf", "", "/path/to/config.json : path to the config file") +} + +func (s *serveCmd) Usage() string { + return `serve [-conf ] + Serve user authentication service using information from the config file +` +} + +func (s *serveCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + log.Println("[Tulip] Starting...") + + if s.configPath == "" { + log.Println("[Tulip] Error: config flag is missing") + return subcommands.ExitUsageError + } + + openConf, err := os.Open(s.configPath) + if err != nil { + if os.IsNotExist(err) { + log.Println("[Tulip] Error: missing config file") + } else { + log.Println("[Tulip] Error: open config file: ", err) + } + return subcommands.ExitFailure + } + + var config startUpConfig + err = json.NewDecoder(openConf).Decode(&config) + if err != nil { + log.Println("[Tulip] Error: invalid config file: ", err) + return subcommands.ExitFailure + } + + configPathAbs, err := filepath.Abs(s.configPath) + if err != nil { + log.Fatal("[Tulip] Failed to get absolute config path") + } + wd := filepath.Dir(configPathAbs) + normalLoad(config, wd) + return subcommands.ExitSuccess +} + +func normalLoad(startUp startUpConfig, wd string) { + key := genHmacKey() + + db, err := database.Open(filepath.Join(wd, "tulip.db.sqlite")) + if err != nil { + log.Fatal("[Tulip] Failed to open database:", err) + } + + log.Println("[Tulip] Checking database contains at least one user") + if err := checkDbHasUser(db); err != nil { + log.Fatal("[Tulip] Failed check:", err) + } + + cs := clientStore.New(db) + + srv := server.NewHttpServer(startUp.Listen, startUp.Domain, db, key, cs) + log.Printf("[Tulip] Starting HTTP server on '%s'\n", srv.Addr) + go utils.RunBackgroundHttp("HTTP", srv) + + exit_reload.ExitReload("Tulip", func() {}, func() { + // stop http server + srv.Close() + }) +} + +func genHmacKey() []byte { + a := make([]byte, 32) + n, err := rand.Reader.Read(a) + if err != nil { + log.Fatal("[Tulip] Failed to generate HMAC key") + } + if n != 32 { + log.Fatal("[Tulip] Failed to generate HMAC key") + } + return a +} + +func checkDbHasUser(db *database.DB) error { + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + if err := tx.HasUser(); err != nil { + if errors.Is(err, sql.ErrNoRows) { + err := tx.InsertUser("admin", "admin", "admin@localhost") + if err != nil { + return fmt.Errorf("failed to add user: %w", err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + // continue normal operation now + return nil + } else { + return fmt.Errorf("failed to check if table has a user: %w", err) + } + } + return nil +} diff --git a/database/db-scanner.go b/database/db-scanner.go new file mode 100644 index 0000000..176eab5 --- /dev/null +++ b/database/db-scanner.go @@ -0,0 +1,111 @@ +package database + +import ( + "database/sql" + "encoding/json" + "fmt" + "github.com/MrMelon54/pronouns" + "golang.org/x/text/language" + "time" +) + +var ( + _, _, _, _, _ sql.Scanner = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{} + _, _, _, _, _ json.Marshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{} + _, _, _, _, _ json.Unmarshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{} +) + +func marshalValueOrNull(null bool, data any) ([]byte, error) { + if null { + return json.Marshal(nil) + } + return json.Marshal(data) +} + +type NullStringScanner struct{ sql.NullString } + +func (s *NullStringScanner) Null() bool { return !s.Valid } +func (s *NullStringScanner) Scan(src any) error { return s.Scan(src) } +func (s NullStringScanner) MarshalJSON() ([]byte, error) { + return marshalValueOrNull(s.Null(), s.String) +} +func (s *NullStringScanner) UnmarshalJSON(bytes []byte) error { + if string(bytes) == "null" { + return s.Scan(nil) + } + var a string + err := json.Unmarshal(bytes, &a) + if err != nil { + return err + } + return s.Scan(&a) +} + +type NullDateScanner struct{ sql.NullTime } + +func (t *NullDateScanner) Null() bool { return !t.Valid } +func (t *NullDateScanner) Scan(src any) error { return t.NullTime.Scan(src) } +func (t NullDateScanner) MarshalJSON() ([]byte, error) { + return marshalValueOrNull(t.Null(), t.Time.UTC().Format(time.DateOnly)) +} +func (t *NullDateScanner) UnmarshalJSON(bytes []byte) error { + if string(bytes) == "null" { + return t.Scan(nil) + } + var a string + err := json.Unmarshal(bytes, &a) + if err != nil { + return err + } + return t.Scan(&a) +} + +type LocationScanner struct{ *time.Location } + +func (l *LocationScanner) Scan(src any) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l) + } + loc, err := time.LoadLocation(s) + if err != nil { + return err + } + l.Location = loc + return nil +} +func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) } + +type LocaleScanner struct{ language.Tag } + +func (l *LocaleScanner) Scan(src any) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l) + } + lang, err := language.Parse(s) + if err != nil { + return err + } + l.Tag = lang + return nil +} +func (l LocaleScanner) MarshalJSON() ([]byte, error) { + return json.Marshal(l.Tag.String()) +} + +type PronounScanner struct{ pronouns.Pronoun } + +func (p *PronounScanner) Scan(src any) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p) + } + pro, err := pronouns.FindPronoun(s) + if err != nil { + return err + } + p.Pronoun = pro + return nil +} +func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) } diff --git a/database/db-scanner_test.go b/database/db-scanner_test.go new file mode 100644 index 0000000..c7217f3 --- /dev/null +++ b/database/db-scanner_test.go @@ -0,0 +1,52 @@ +package database + +import ( + "database/sql" + "encoding/json" + "github.com/MrMelon54/pronouns" + "github.com/stretchr/testify/assert" + "golang.org/x/text/language" + "testing" + "time" +) + +func encode(data any) string { + j, err := json.Marshal(map[string]any{"value": data}) + if err != nil { + panic(err) + } + return string(j) +} + +func TestStringScanner_MarshalJSON(t *testing.T) { + assert.Equal(t, "{\"value\":\"Hello world\"}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: true}})) + assert.Equal(t, "{\"value\":null}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: false}})) +} + +func TestDateScanner_MarshalJSON(t *testing.T) { + location, err := time.LoadLocation("Europe/London") + assert.NoError(t, err) + assert.Equal(t, "{\"value\":\"2006-01-02\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), Valid: true}})) + assert.Equal(t, "{\"value\":\"2006-08-01\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.August, 2, 0, 0, 0, 0, location), Valid: true}})) + assert.Equal(t, "{\"value\":null}", encode(NullDateScanner{})) +} + +func TestLocationScanner_MarshalJSON(t *testing.T) { + location, err := time.LoadLocation("Europe/London") + assert.NoError(t, err) + assert.Equal(t, "{\"value\":\"Europe/London\"}", encode(LocationScanner{location})) + assert.Equal(t, "{\"value\":\"UTC\"}", encode(LocationScanner{time.UTC})) +} + +func TestLocaleScanner_MarshalJSON(t *testing.T) { + assert.Equal(t, "{\"value\":\"en-US\"}", encode(LocaleScanner{language.AmericanEnglish})) + assert.Equal(t, "{\"value\":\"en-GB\"}", encode(LocaleScanner{language.BritishEnglish})) +} + +func TestPronounScanner_MarshalJSON(t *testing.T) { + assert.Equal(t, "{\"value\":\"they/them\"}", encode(PronounScanner{pronouns.TheyThem})) + assert.Equal(t, "{\"value\":\"he/him\"}", encode(PronounScanner{pronouns.HeHim})) + assert.Equal(t, "{\"value\":\"she/her\"}", encode(PronounScanner{pronouns.SheHer})) + assert.Equal(t, "{\"value\":\"it/its\"}", encode(PronounScanner{pronouns.ItIts})) + assert.Equal(t, "{\"value\":\"one/one's\"}", encode(PronounScanner{pronouns.OneOnes})) +} diff --git a/database/db-types.go b/database/db-types.go new file mode 100644 index 0000000..bd183cb --- /dev/null +++ b/database/db-types.go @@ -0,0 +1,105 @@ +package database + +import ( + "encoding/json" + "github.com/MrMelon54/pronouns" + "github.com/google/uuid" + "golang.org/x/text/language" + "net/url" + "time" +) + +type User struct { + Sub uuid.UUID `json:"sub"` + Name string `json:"name,omitempty"` + Username string `json:"username"` + Password string `json:"password"` + Picture NullStringScanner `json:"picture,omitempty"` + Website NullStringScanner `json:"website,omitempty"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Pronouns PronounScanner `json:"pronouns,omitempty"` + Birthdate NullDateScanner `json:"birthdate,omitempty"` + ZoneInfo LocationScanner `json:"zoneinfo,omitempty"` + Locale LocaleScanner `json:"locale,omitempty"` + UpdatedAt time.Time `json:"updated_at"` + Active bool `json:"active"` +} + +type UserPatch struct { + Name NullStringScanner `json:"name"` + Picture NullStringScanner `json:"picture"` + Website NullStringScanner `json:"website"` + Pronouns PronounScanner `json:"pronouns"` + Birthdate NullDateScanner `json:"birthdate"` + ZoneInfo *time.Location `json:"zoneinfo"` + Locale *language.Tag `json:"locale"` +} + +func (u *UserPatch) UnmarshalJSON(bytes []byte) error { + var m struct { + Name string `json:"name"` + Picture string `json:"picture"` + Website string `json:"website"` + Pronouns string `json:"pronouns"` + Birthdate string `json:"birthdate"` + ZoneInfo string `json:"zoneinfo"` + Locale string `json:"locale"` + } + err := json.Unmarshal(bytes, &m) + if err != nil { + return err + } + u.Name = m.Name + + // only parse the picture address if included + if m.Picture != "" { + u.Picture, err = url.Parse(m.Picture) + if err != nil { + return err + } + } + + // only parse the website address if included + if m.Website != "" { + u.Website, err = url.Parse(m.Website) + if err != nil { + return err + } + } + + // only parse the pronouns if included + if m.Pronouns != "" { + u.Pronouns, err = pronouns.FindPronoun(m.Pronouns) + if err != nil { + return err + } + } + + // only parse the birthdate if included + if m.Birthdate != "" { + u.Birthdate, err = time.Parse(time.DateOnly, m.Birthdate) + if err != nil { + return err + } + } + + // only parse the zoneinfo if included + if m.ZoneInfo != "" { + u.ZoneInfo, err = time.LoadLocation(m.ZoneInfo) + if err != nil { + return err + } + } + + if m.Locale != "" { + locale, err := language.Parse(m.Locale) + if err != nil { + return err + } + u.Locale = &locale + } + return nil +} + +var _ json.Unmarshaler = &UserPatch{} diff --git a/database/db-types_test.go b/database/db-types_test.go new file mode 100644 index 0000000..b129f00 --- /dev/null +++ b/database/db-types_test.go @@ -0,0 +1,77 @@ +package database + +import ( + "encoding/json" + "github.com/MrMelon54/pronouns" + "github.com/stretchr/testify/assert" + "maps" + "testing" + "time" +) + +func TestUserPatch_UnmarshalJSON(t *testing.T) { + const a = `{ + "name": "Test", + "picture": "https://example.com/logo.png", + "website": "https://example.com", + "gender": "robot", + "pronouns": "they/them", + "birthdate": "3070-01-01", + "zoneinfo": "Europe/London", + "locale": "en-GB" +}` + var p UserPatch + assert.NoError(t, json.Unmarshal([]byte(a), &p)) + assert.Equal(t, "Test", p.Name) + assert.Equal(t, "https://example.com/logo.png", p.Picture.String()) + assert.Equal(t, "https://example.com", p.Website.String()) + assert.Equal(t, pronouns.TheyThem, p.Pronouns) + assert.Equal(t, time.Date(3070, time.January, 1, 0, 0, 0, 0, time.UTC), p.Birthdate) + location, err := time.LoadLocation("Europe/London") + assert.NoError(t, err) + assert.Equal(t, location, p.ZoneInfo) + assert.Equal(t, "en-GB", p.Locale.String()) +} + +func TestUserPatch_UnmarshalJSON2(t *testing.T) { + var userModifyChecks = map[string]struct{ valid, invalid []string }{ + "picture": {valid: []string{"https://example.com/icon.png"}, invalid: []string{"%/icon.png"}}, + "website": {valid: []string{"https://example.com"}, invalid: []string{"%/example.com"}}, + "pronouns": {valid: []string{"he/him", "she/her"}, invalid: []string{"a/a"}}, + "birthdate": {valid: []string{"2023-08-07", "2023-01-01"}, invalid: []string{"2023-00-00", "hello"}}, + "zoneinfo": { + valid: []string{"Europe/London", "Europe/Berlin", "America/Los_Angeles", "America/Edmonton", "America/Montreal"}, + invalid: []string{"Europe/York", "Canada/Edmonton", "hello"}, + }, + "locale": {valid: []string{"en-GB", "en-US", "zh-CN"}, invalid: []string{"en-YY"}}, + } + m := map[string]string{ + "name": "Test", + "picture": "https://example.com/logo.png", + "website": "https://example.com", + "gender": "robot", + "pronouns": "they/them", + "birthdate": "3070-01-01", + "zoneinfo": "Europe/London", + "locale": "en-GB", + } + for k, v := range userModifyChecks { + t.Run(k, func(t *testing.T) { + m2 := maps.Clone(m) + for _, i := range v.valid { + m2[k] = i + marshal, err := json.Marshal(m2) + assert.NoError(t, err) + var m3 UserPatch + assert.NoError(t, json.Unmarshal(marshal, &m3)) + } + for _, i := range v.invalid { + m2[k] = i + marshal, err := json.Marshal(m2) + assert.NoError(t, err) + var m3 UserPatch + assert.Error(t, json.Unmarshal(marshal, &m3)) + } + }) + } +} diff --git a/database/db.go b/database/db.go new file mode 100644 index 0000000..daf2c62 --- /dev/null +++ b/database/db.go @@ -0,0 +1,37 @@ +package database + +import ( + "context" + "database/sql" + _ "embed" +) + +//go:embed init.sql +var initSql string + +type DB struct{ db *sql.DB } + +func Open(p string) (*DB, error) { + db, err := sql.Open("sqlite3", p) + if err != nil { + return nil, err + } + _, err = db.Exec(initSql) + return &DB{db: db}, err +} + +func (d *DB) Begin() (*Tx, error) { + begin, err := d.db.Begin() + if err != nil { + return nil, err + } + return &Tx{begin}, err +} + +func (d *DB) BeginCtx(ctx context.Context) (*Tx, error) { + begin, err := d.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &Tx{begin}, err +} diff --git a/database/db_test.go b/database/db_test.go new file mode 100644 index 0000000..5d73aab --- /dev/null +++ b/database/db_test.go @@ -0,0 +1,5 @@ +package database + +import ( + _ "github.com/mattn/go-sqlite3" +) diff --git a/database/init.sql b/database/init.sql new file mode 100644 index 0000000..b133162 --- /dev/null +++ b/database/init.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS users +( + subject TEXT PRIMARY KEY UNIQUE NOT NULL, + name TEXT NOT NULL, + username TEXT UNIQUE NOT NULL, + password TEXT NOT NULL, + picture TEXT, + website TEXT, + email TEXT NOT NULL, + email_verified INTEGER DEFAULT 0 NOT NULL, + pronouns TEXT DEFAULT "they/them" NOT NULL, + birthdate DATE, + zoneinfo TEXT DEFAULT "" NOT NULL, + locale TEXT DEFAULT "en-US" NOT NULL, + updated_at DATETIME, + active INTEGER DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS client_store +( + subject TEXT PRIMARY KEY UNIQUE NOT NULL, + name TEXT UNIQUE NOT NULL, + secret TEXT UNIQUE NOT NULL, + domain TEXT NOT NULL, + sso INTEGER, + active INTEGER DEFAULT 1 +); diff --git a/database/tx.go b/database/tx.go new file mode 100644 index 0000000..a451f6c --- /dev/null +++ b/database/tx.go @@ -0,0 +1,177 @@ +package database + +import ( + "database/sql" + "fmt" + "github.com/1f349/tulip/password" + "github.com/go-oauth2/oauth2/v4" + "github.com/google/uuid" + "time" +) + +type Tx struct{ tx *sql.Tx } + +func (t *Tx) Commit() error { + return t.tx.Commit() +} + +func (t *Tx) Rollback() { + _ = t.tx.Rollback() +} + +func (t *Tx) HasUser() error { + var exists bool + row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`) + err := row.Scan(&exists) + if err != nil { + return err + } + if !exists { + return sql.ErrNoRows + } + return nil +} + +func (t *Tx) InsertUser(un, pw, email string) error { + pwHash, err := password.HashPassword(pw) + if err != nil { + return err + } + _, err = t.tx.Exec(`INSERT INTO users (subject, username, password, email) VALUES (?, ?, ?, ?)`, uuid.NewString(), un, pwHash, email) + return err +} + +func (t *Tx) CheckLogin(un, pw string) (*User, error) { + var u User + row := t.tx.QueryRow(`SELECT subject, password FROM users WHERE username = ? LIMIT 1`, un) + err := row.Scan(&u.Sub, &u.Password) + if err != nil { + return nil, err + } + err = password.CheckPasswordHash(u.Password, pw) + return &u, err +} + +func (t *Tx) GetUserDisplayName(sub uuid.UUID) (*User, error) { + var u User + row := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub.String()) + err := row.Scan(&u.Name) + u.Sub = sub + return &u, err +} + +func (t *Tx) GetUser(sub uuid.UUID) (*User, error) { + var u User + row := t.tx.QueryRow(`SELECT name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ? LIMIT 1`, sub.String()) + err := row.Scan(&u.Name, &u.Username, &u.Password, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active) + u.Sub = sub + return &u, err +} + +func (t *Tx) ChangeUserPassword(sub uuid.UUID, pwOld, pwNew string) error { + q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub) + if err != nil { + return err + } + var pwHash string + if q.Next() { + err = q.Scan(&pwHash) + if err != nil { + return err + } + } else { + return fmt.Errorf("invalid user") + } + if err := q.Err(); err != nil { + return err + } + if err := q.Close(); err != nil { + return err + } + err = password.CheckPasswordHash(pwHash, pwOld) + if err != nil { + return err + } + pwNewHash, err := password.HashPassword(pwNew) + if err != nil { + return err + } + exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, time.Now().Format(time.DateTime), sub, pwHash) + if err != nil { + return err + } + affected, err := exec.RowsAffected() + if err != nil { + return err + } + if affected != 1 { + return fmt.Errorf("row wasn't updated") + } + return nil +} + +func (t *Tx) ModifyUser(sub uuid.UUID, v *UserPatch) error { + exec, err := t.tx.Exec( + `UPDATE users +SET name = ifnull(?, name), + picture = ifnull(?, picture), + website = ifnull(?, website), + pronouns = ifnull(?, pronouns), + birthdate = ifnull(?, birthdate), + zoneinfo = ifnull(?, zoneinfo), + locale = ifnull(?, locale), + updated_at = ? +WHERE subject = ?`, + v.Name, + stringify(v.Picture), + stringify(v.Website), + v.Pronouns.String(), + sql.NullTime{Time: v.Birthdate, Valid: !v.Birthdate.IsZero()}, + v.ZoneInfo.String(), + v.Locale.String(), + time.Now().Format(time.DateTime), + sub, + ) + if err != nil { + return err + } + affected, err := exec.RowsAffected() + if err != nil { + return err + } + if affected != 1 { + return fmt.Errorf("row wasn't updated") + } + return nil +} + +func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) { + var u clientInfoDbOutput + row := t.tx.QueryRow(`SELECT secret, domain, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub) + err := row.Scan(&u.secret, &u.domain, &u.sso) + u.sub = sub + return &u, err +} + +type clientInfoDbOutput struct { + sub, secret, domain string + sso bool +} + +func (c *clientInfoDbOutput) GetID() string { return c.sub } +func (c *clientInfoDbOutput) GetSecret() string { return c.secret } +func (c *clientInfoDbOutput) GetDomain() string { return c.domain } +func (c *clientInfoDbOutput) IsPublic() bool { return false } +func (c *clientInfoDbOutput) GetUserID() string { return "" } +func (c *clientInfoDbOutput) IsSSO() bool { return c.sso } + +func stringify(stringer fmt.Stringer) sql.NullString { + if stringer == nil { + return sql.NullString{} + } + return emptyToNull(stringer.String()) +} + +func emptyToNull(a string) sql.NullString { + return sql.NullString{String: a, Valid: a != ""} +} diff --git a/database/tx_test.go b/database/tx_test.go new file mode 100644 index 0000000..ee972e9 --- /dev/null +++ b/database/tx_test.go @@ -0,0 +1,55 @@ +package database + +import ( + "github.com/1f349/tulip/password" + "github.com/MrMelon54/pronouns" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "golang.org/x/text/language" + "testing" + "time" +) + +func TestTx_ChangeUserPassword(t *testing.T) { + u := uuid.New() + pw, err := password.HashPassword("test") + assert.NoError(t, err) + d, err := Open("file::memory:") + assert.NoError(t, err) + _, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email) VALUES (?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost") + assert.NoError(t, err) + tx, err := d.Begin() + assert.NoError(t, err) + err = tx.ChangeUserPassword(u, "test", "new") + assert.NoError(t, err) + assert.NoError(t, tx.Commit()) + query, err := d.db.Query(`SELECT password FROM users WHERE subject = ? AND username = ?`, u.String(), "test") + assert.NoError(t, err) + assert.True(t, query.Next()) + var oldPw string + assert.NoError(t, query.Scan(&oldPw)) + assert.NoError(t, password.CheckPasswordHash(oldPw, "new")) + assert.NoError(t, query.Err()) + assert.NoError(t, query.Close()) +} + +func TestTx_ModifyUser(t *testing.T) { + u := uuid.New() + pw, err := password.HashPassword("test") + assert.NoError(t, err) + d, err := Open("file::memory:") + assert.NoError(t, err) + _, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email) VALUES (?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost") + assert.NoError(t, err) + tx, err := d.Begin() + assert.NoError(t, err) + assert.NoError(t, tx.ModifyUser(u, &UserPatch{ + Name: "example", + Picture: nil, + Website: nil, + Pronouns: pronouns.Pronoun{}, + Birthdate: time.Time{}, + ZoneInfo: nil, + Locale: &language.Tag{}, + })) +} diff --git a/openid/config.go b/openid/config.go new file mode 100644 index 0000000..77f5e37 --- /dev/null +++ b/openid/config.go @@ -0,0 +1,25 @@ +package openid + +type Config struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserInfoEndpoint string `json:"userinfo_endpoint"` + ResponseTypesSupported []string `json:"response_types_supported"` + ScopesSupported []string `json:"scopes_supported"` + ClaimsSupported []string `json:"claims_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` +} + +func GenConfig(domain string, scopes, claims []string) Config { + return Config{ + Issuer: "https://" + domain, + AuthorizationEndpoint: "https://" + domain + "/authorize", + TokenEndpoint: "https://" + domain + "/token", + UserInfoEndpoint: "https://" + domain + "/userinfo", + ResponseTypesSupported: []string{"code"}, + ScopesSupported: scopes, + ClaimsSupported: claims, + GrantTypesSupported: []string{"authorization_code", "refresh_token"}, + } +} diff --git a/openid/config_test.go b/openid/config_test.go new file mode 100644 index 0000000..125504d --- /dev/null +++ b/openid/config_test.go @@ -0,0 +1,19 @@ +package openid + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGenConfig(t *testing.T) { + assert.Equal(t, Config{ + Issuer: "https://example.com", + AuthorizationEndpoint: "https://example.com/authorize", + TokenEndpoint: "https://example.com/token", + UserInfoEndpoint: "https://example.com/userinfo", + ResponseTypesSupported: []string{"code"}, + ScopesSupported: []string{"openid", "email"}, + ClaimsSupported: []string{"name", "email", "preferred_username"}, + GrantTypesSupported: []string{"authorization_code", "refresh_token"}, + }, GenConfig("example.com", []string{"openid", "email"}, []string{"name", "email", "preferred_username"})) +} diff --git a/pages/authorize.go.html b/pages/authorize.go.html new file mode 100644 index 0000000..7cc617b --- /dev/null +++ b/pages/authorize.go.html @@ -0,0 +1,32 @@ + + + + 1f349 ID + + +
+

1f349 ID

+
+
+
+
The application {{.AppName}} wants to access your account ({{.User.Name}}). It requests the following permissions:
+
+
    + {{range .WantsList}} +
  • {{.Label}}
  • + {{end}} +
+
+
+ + + + + + + +
+
+
+ + diff --git a/pages/index-guest.go.html b/pages/index-guest.go.html new file mode 100644 index 0000000..a31fcbf --- /dev/null +++ b/pages/index-guest.go.html @@ -0,0 +1,17 @@ + + + + 1f349 ID + + +
+

1f349 ID

+
+
+
Not logged in
+
+ +
+
+ + diff --git a/pages/index.go.html b/pages/index.go.html new file mode 100644 index 0000000..96cae19 --- /dev/null +++ b/pages/index.go.html @@ -0,0 +1,19 @@ + + + + 1f349 ID + + +
+

1f349 ID

+
+
+
Logged in as: {{.User.Name}} ({{.User.ID}})
+
+
+ +
+
+
+ + diff --git a/pages/login.go.html b/pages/login.go.html new file mode 100644 index 0000000..bb9e06d --- /dev/null +++ b/pages/login.go.html @@ -0,0 +1,24 @@ + + + + 1f349 ID + + +
+

1f349 ID

+
+
+
+
+ + +
+
+ + +
+ +
+
+ + diff --git a/pages/pages.go b/pages/pages.go new file mode 100644 index 0000000..f6dd481 --- /dev/null +++ b/pages/pages.go @@ -0,0 +1,24 @@ +package pages + +import ( + "embed" + _ "embed" + "html/template" + "io" +) + +var ( + //go:embed * + embeddedTemplates embed.FS + + pageTemplate *template.Template +) + +func LoadPageTemplates() (err error) { + pageTemplate, err = template.New("pages").ParseFS(embeddedTemplates, "*.go.html") + return +} + +func RenderPageTemplate(wr io.Writer, name string, data any) error { + return pageTemplate.ExecuteTemplate(wr, name+".go.html", data) +} diff --git a/password/password.go b/password/password.go new file mode 100644 index 0000000..36111be --- /dev/null +++ b/password/password.go @@ -0,0 +1,12 @@ +package password + +import "golang.org/x/crypto/bcrypt" + +func HashPassword(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), 14) + return string(bytes), err +} + +func CheckPasswordHash(hash, password string) error { + return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) +} diff --git a/server/auth.go b/server/auth.go new file mode 100644 index 0000000..13c31b1 --- /dev/null +++ b/server/auth.go @@ -0,0 +1,72 @@ +package server + +import ( + "fmt" + "github.com/go-session/session" + "github.com/google/uuid" + "github.com/julienschmidt/httprouter" + "net/http" +) + +type UserHandler func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) + +type UserAuth struct { + ID uuid.UUID + Session session.Store +} + +func (u UserAuth) IsGuest() bool { + return u.ID == uuid.Nil +} + +func (h *HttpServer) RequireAuthentication(error string, code int, next UserHandler) httprouter.Handle { + return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + if auth.IsGuest() { + http.Error(rw, error, code) + return + } + next(rw, req, params, auth) + }) +} + +func (h *HttpServer) RequireAuthenticationRedirect(redirect string, code int, next UserHandler) httprouter.Handle { + return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + if auth.IsGuest() { + http.Redirect(rw, req, redirect, code) + return + } + next(rw, req, params, auth) + }) +} + +func (h *HttpServer) OptionalAuthentication(next UserHandler) httprouter.Handle { + return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + auth, err := h.internalAuthenticationHandler(rw, req) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + next(rw, req, params, auth) + } +} + +func (h *HttpServer) internalAuthenticationHandler(rw http.ResponseWriter, req *http.Request) (UserAuth, error) { + ss, err := session.Start(req.Context(), rw, req) + if err != nil { + return UserAuth{}, fmt.Errorf("failed to start session") + } + + userIdRaw, ok := ss.Get("user") + if !ok { + return UserAuth{Session: ss}, nil + } + userId, ok := userIdRaw.(uuid.UUID) + if !ok { + ss.Delete("user") + err := ss.Save() + if err != nil { + return UserAuth{Session: ss}, fmt.Errorf("failed to reset invalid session data") + } + } + return UserAuth{ID: userId, Session: ss}, nil +} diff --git a/server/auth_test.go b/server/auth_test.go new file mode 100644 index 0000000..a4a06bb --- /dev/null +++ b/server/auth_test.go @@ -0,0 +1,11 @@ +package server + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestUserAuth_IsGuest(t *testing.T) { + var u UserAuth + assert.True(t, u.IsGuest()) +} diff --git a/server/db.go b/server/db.go new file mode 100644 index 0000000..f409d80 --- /dev/null +++ b/server/db.go @@ -0,0 +1,30 @@ +package server + +import ( + "github.com/1f349/tulip/database" + "log" + "net/http" +) + +func (h *HttpServer) dbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool { + tx, err := h.db.Begin() + if err != nil { + http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError) + return true + } + defer tx.Rollback() + + err = action(tx) + if err != nil { + http.Error(rw, "Database error", http.StatusInternalServerError) + log.Println("Database action error:", err) + return true + } + err = tx.Commit() + if err != nil { + http.Error(rw, "Database error", http.StatusInternalServerError) + log.Println("Database commit error:", err) + } + + return false +} diff --git a/server/oauth.go b/server/oauth.go new file mode 100644 index 0000000..892ce2d --- /dev/null +++ b/server/oauth.go @@ -0,0 +1,156 @@ +package server + +import ( + "fmt" + "github.com/go-session/session" + "github.com/julienschmidt/httprouter" + "net/http" + "net/url" +) + +func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) { + ss, err := session.Start(req.Context(), rw, req) + if err != nil { + http.Error(rw, "Failed to load session", http.StatusInternalServerError) + return + } + + userID, err := h.oauthSrv.UserAuthorizationHandler(rw, req) + if err != nil { + http.Error(rw, "Failed to check user", http.StatusInternalServerError) + return + } else if userID == "" { + return + } + + // function is only called with GET or POST method + isPost := req.Method == http.MethodPost + + var form url.Values + if isPost { + err = req.ParseForm() + if err != nil { + http.Error(rw, "Failed to parse form", http.StatusInternalServerError) + return + } + form = req.PostForm + } else { + form = req.URL.Query() + } + + clientID := form.Get("client_id") + client, err := h.oauthMgr.GetClient(req.Context(), clientID) + if err != nil { + http.Error(rw, "Invalid client", http.StatusBadRequest) + return + } + + redirectUri := form.Get("redirect_uri") + if redirectUri != client.GetDomain() { + http.Error(rw, "Incorrect redirect URI", http.StatusBadRequest) + return + } + + if form.Has("cancel") { + uCancel, err := url.Parse(client.GetDomain()) + if err != nil { + http.Error(rw, "Invalid redirect URI", http.StatusBadRequest) + return + } + q := uCancel.Query() + q.Set("error", "access_denied") + uCancel.RawQuery = q.Encode() + + http.Redirect(rw, req, uCancel.String(), http.StatusFound) + return + } + + var isSSO bool + if clientIsSSO, ok := client.(interface{ IsSSO() bool }); ok { + isSSO = clientIsSSO.IsSSO() + } + + switch { + case isSSO && isPost: + http.Error(rw, "400 Bad Request", http.StatusBadRequest) + return + case !isSSO && !isPost: + f := func(key string) string { return form.Get(key) } + rw.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(rw, ` + + +Authorize + +
+ + + + + + + +
Scope: %s
+
+
+
+`, clientID, redirectUri, f("scope"), f("state"), f("nonce"), f("response_type"), f("response_mode"), f("scope")) + return + default: + break + } + + // continue flow + oauthDataRaw, ok := ss.Get("OAuthData") + if ok { + ss.Delete("OAuthData") + if ss.Save() != nil { + http.Error(rw, "Failed to save session", http.StatusInternalServerError) + return + } + oauthData, ok := oauthDataRaw.(url.Values) + if !ok { + http.Error(rw, "Failed to load session", http.StatusInternalServerError) + return + } + req.URL.RawQuery = oauthData.Encode() + } + + if err := h.oauthSrv.HandleAuthorizeRequest(rw, req); err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + } +} + +func (h *HttpServer) oauthUserAuthorization(rw http.ResponseWriter, req *http.Request) (string, error) { + err := req.ParseForm() + if err != nil { + return "", err + } + + auth, err := h.internalAuthenticationHandler(rw, req) + if err != nil { + return "", err + } + + if auth.IsGuest() { + // handle redirecting to oauth + var q url.Values + switch req.Method { + case http.MethodPost: + q = req.PostForm + case http.MethodGet: + q = req.URL.Query() + default: + http.Error(rw, "405 Method Not Allowed", http.StatusMethodNotAllowed) + return "", err + } + auth.Session.Set("OAuthData", q) + if auth.Session.Save() != nil { + http.Error(rw, "Failed to save session", http.StatusInternalServerError) + return "", err + } + http.Redirect(rw, req, "/login?redirect=oauth", http.StatusFound) + return "", nil + } + return auth.ID.String(), nil +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..2b967fc --- /dev/null +++ b/server/server.go @@ -0,0 +1,291 @@ +package server + +import ( + "crypto/subtle" + "database/sql" + _ "embed" + "encoding/json" + errors2 "errors" + "fmt" + "github.com/1f349/tulip/database" + "github.com/1f349/tulip/openid" + "github.com/1f349/tulip/pages" + "github.com/go-oauth2/oauth2/v4" + "github.com/go-oauth2/oauth2/v4/errors" + "github.com/go-oauth2/oauth2/v4/generates" + "github.com/go-oauth2/oauth2/v4/manage" + "github.com/go-oauth2/oauth2/v4/server" + "github.com/go-oauth2/oauth2/v4/store" + "github.com/google/uuid" + "github.com/julienschmidt/httprouter" + "golang.org/x/crypto/bcrypt" + "log" + "net/http" + "net/url" + "time" +) + +var errMissingRequiredScope = errors.New("missing required scope") + +type HttpServer struct { + r *httprouter.Router + oauthSrv *server.Server + oauthMgr *manage.Manager + db *database.DB + domain string + privKey []byte +} + +func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clientStore oauth2.ClientStore) *http.Server { + r := httprouter.New() + + openIdConf := openid.GenConfig(domain, []string{"openid", "email"}, []string{"sub", "name", "preferred_username", "profile", "picture", "website", "email", "email_verified", "gender", "birthdate", "zoneinfo", "locale", "updated_at"}) + openIdBytes, err := json.Marshal(openIdConf) + if err != nil { + log.Fatalln("Failed to generate OpenID configuration:", err) + } + + if err := pages.LoadPageTemplates(); err != nil { + log.Fatalln("Failed to load page templates:", err) + } + + oauthManager := manage.NewDefaultManager() + oauthSrv := server.NewServer(server.NewConfig(), oauthManager) + hs := &HttpServer{ + r: httprouter.New(), + oauthSrv: oauthSrv, + oauthMgr: oauthManager, + db: db, + domain: domain, + privKey: privKey, + } + + oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg) + oauthManager.MustTokenStorage(store.NewMemoryTokenStore()) + oauthManager.MapAccessGenerate(generates.NewAccessGenerate()) + oauthManager.MapClientStorage(clientStore) + + oauthSrv.SetResponseErrorHandler(func(re *errors.Response) { + log.Printf("Response error: %#v\n", re) + }) + oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) { + cId, cSecret, err := server.ClientBasicHandler(req) + if cId == "" && cSecret == "" { + cId, cSecret, err = server.ClientFormHandler(req) + } + if err != nil { + return "", "", err + } + return cId, cSecret, nil + }) + oauthSrv.SetUserAuthorizationHandler(hs.oauthUserAuthorization) + oauthSrv.SetAuthorizeScopeHandler(func(rw http.ResponseWriter, req *http.Request) (scope string, err error) { + var form url.Values + if req.Method == http.MethodPost { + form = req.PostForm + } else { + form = req.URL.Query() + } + a := form.Get("scope") + if a != "openid" { + return "", errMissingRequiredScope + } + return "openid", nil + }) + + newUserUuid := uuid.New() + fmt.Println("New User Uuid:", newUserUuid.String()) + + r.GET("/.well-known/openid-configuration", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + rw.WriteHeader(http.StatusOK) + _, _ = rw.Write(openIdBytes) + }) + r.GET("/", hs.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + rw.Header().Set("Content-Type", "text/html") + rw.WriteHeader(http.StatusOK) + if auth.IsGuest() { + _ = pages.RenderPageTemplate(rw, "index-guest", nil) + return + } + + lNonce := uuid.NewString() + auth.Session.Set("action-nonce", lNonce) + if auth.Session.Save() != nil { + http.Error(rw, "Failed to save session", http.StatusInternalServerError) + return + } + + hs.dbTx(rw, func(tx *database.Tx) error { + userWithName, err := tx.GetUserDisplayName(auth.ID) + if err != nil { + return fmt.Errorf("failed to get user display name: %w", err) + } + _ = pages.RenderPageTemplate(rw, "index", map[string]any{ + "Auth": auth, + "User": userWithName, + "Nonce": lNonce, + }) + return nil + }) + })) + r.POST("/logout", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + lNonce, ok := auth.Session.Get("action-nonce") + if !ok { + http.Error(rw, "Missing nonce", http.StatusInternalServerError) + return + } + if subtle.ConstantTimeCompare([]byte(lNonce.(string)), []byte(req.PostFormValue("nonce"))) == 1 { + auth.Session.Delete("user") + if auth.Session.Save() != nil { + http.Error(rw, "Failed to save session", http.StatusInternalServerError) + return + } + http.Redirect(rw, req, "/", http.StatusFound) + return + } + http.Error(rw, "Logout failed", http.StatusInternalServerError) + })) + r.GET("/login", hs.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + if !auth.IsGuest() { + http.Redirect(rw, req, "/", http.StatusFound) + return + } + rw.Header().Set("Content-Type", "text/html") + rw.WriteHeader(http.StatusOK) + _ = pages.RenderPageTemplate(rw, "login", nil) + })) + r.POST("/login", hs.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + un := req.FormValue("username") + pw := req.FormValue("password") + var userSub uuid.UUID + if hs.dbTx(rw, func(tx *database.Tx) error { + loginUser, err := tx.CheckLogin(un, pw) + if err != nil { + if errors2.Is(err, sql.ErrNoRows) || errors2.Is(err, bcrypt.ErrMismatchedHashAndPassword) { + http.Redirect(rw, req, "/login?mismatch=1", http.StatusFound) + return nil + } + http.Error(rw, "Internal server error", http.StatusInternalServerError) + return err + } + userSub = loginUser.Sub + return nil + }) { + return + } + + // only continues if the above tx succeeds + auth.Session.Set("user", userSub) + if auth.Session.Save() != nil { + http.Error(rw, "Failed to save session", http.StatusInternalServerError) + return + } + + switch req.URL.Query().Get("redirect") { + case "oauth": + oauthDataRaw, ok := auth.Session.Get("OAuthData") + if !ok { + http.Error(rw, "Failed to load session", http.StatusInternalServerError) + return + } + oauthData, ok := oauthDataRaw.(url.Values) + if !ok { + http.Error(rw, "Failed to load session", http.StatusInternalServerError) + return + } + authUrl := url.URL{Path: "/authorize", RawQuery: oauthData.Encode()} + http.Redirect(rw, req, authUrl.String(), http.StatusFound) + default: + http.Redirect(rw, req, "/", http.StatusFound) + } + })) + r.GET("/authorize", hs.authorizeEndpoint) + r.POST("/authorize", hs.authorizeEndpoint) + r.POST("/token", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + if err := oauthSrv.HandleTokenRequest(rw, req); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } + }) + r.GET("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + begin, err := db.Begin() + if err != nil { + return + } + user, err := begin.GetUser(auth.ID) + if err != nil { + http.Error(rw, "Failed to read user data", http.StatusInternalServerError) + return + } + + lNonce := uuid.NewString() + auth.Session.Set("action-nonce", lNonce) + if auth.Session.Save() != nil { + http.Error(rw, "Failed to save session", http.StatusInternalServerError) + return + } + _ = pages.RenderPageTemplate(rw, "edit", map[string]any{ + "User": user, + "Nonce": lNonce, + }) + })) + r.POST("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { + if req.ParseForm() != nil { + rw.WriteHeader(http.StatusBadRequest) + return + } + // TODO: parse user patch from form + req.Form.Get("") + var patch database.UserPatch + decoder := json.NewDecoder(req.Body) + decoder.DisallowUnknownFields() + err := decoder.Decode(&patch) + if err != nil { + rw.WriteHeader(http.StatusBadRequest) + return + } + begin, err := db.Begin() + if err != nil { + rw.WriteHeader(http.StatusBadRequest) + return + } + if begin.ModifyUser(auth.ID, &patch) != nil { + http.Error(rw, "Failed to modify user info", http.StatusInternalServerError) + return + } + http.Redirect(rw, req, "/", http.StatusFound) + })) + r.GET("/userinfo", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { + token, err := oauthSrv.ValidationBearerToken(req) + if err != nil { + http.Error(rw, "403 Forbidden", http.StatusForbidden) + return + } + fmt.Printf("Using token for user: %s by app: %s with scope: '%s'\n", token.GetUserID(), token.GetClientID(), token.GetScope()) + _ = json.NewEncoder(rw).Encode(map[string]any{ + "sub": token.GetUserID(), + "aud": token.GetClientID(), + "name": "Melon", + "preferred_username": "melon", + "profile": "https://" + domain + "/user/melon", + "picture": "https://" + domain + "/picture/melon.svg", + "website": "https://mrmelon54.com", + "email": "melon@mrmelon54.com", + "email_verified": true, + "gender": "male", + "birthdate": time.Now().Format(time.DateOnly), + "zoneinfo": "Europe/London", + "locale": "en-GB", + "updated_at": time.Now().Unix(), + }) + }) + + return &http.Server{ + Addr: listen, + Handler: r, + ReadTimeout: time.Minute, + ReadHeaderTimeout: time.Minute, + WriteTimeout: time.Minute, + IdleTimeout: time.Minute, + MaxHeaderBytes: 2500, + } +}