mirror of
https://github.com/1f349/tulip.git
synced 2024-11-12 14:51:32 +00:00
Lots of code changes
This commit is contained in:
parent
4b1e60b9bf
commit
703f3d17cd
26
client-store/client-store.go
Normal file
26
client-store/client-store.go
Normal file
@ -0,0 +1,26 @@
|
||||
package client_store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/1f349/tulip/database"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
)
|
||||
|
||||
type ClientStore struct {
|
||||
db *database.DB
|
||||
}
|
||||
|
||||
var _ oauth2.ClientStore = &ClientStore{}
|
||||
|
||||
func New(db *database.DB) *ClientStore {
|
||||
return &ClientStore{db: db}
|
||||
}
|
||||
|
||||
func (c *ClientStore) GetByID(ctx context.Context, id string) (oauth2.ClientInfo, error) {
|
||||
tx, err := c.db.BeginCtx(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
return tx.GetClientInfo(id)
|
||||
}
|
6
cmd/tulip/conf.go
Normal file
6
cmd/tulip/conf.go
Normal file
@ -0,0 +1,6 @@
|
||||
package main
|
||||
|
||||
type startUpConfig struct {
|
||||
Listen string `json:"listen"`
|
||||
Domain string `json:"domain"`
|
||||
}
|
19
cmd/tulip/main.go
Normal file
19
cmd/tulip/main.go
Normal file
@ -0,0 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"github.com/google/subcommands"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
subcommands.Register(subcommands.HelpCommand(), "")
|
||||
subcommands.Register(subcommands.FlagsCommand(), "")
|
||||
subcommands.Register(subcommands.CommandsCommand(), "")
|
||||
subcommands.Register(&serveCmd{}, "")
|
||||
|
||||
flag.Parse()
|
||||
ctx := context.Background()
|
||||
os.Exit(int(subcommands.Execute(ctx)))
|
||||
}
|
131
cmd/tulip/serve.go
Normal file
131
cmd/tulip/serve.go
Normal file
@ -0,0 +1,131 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
clientStore "github.com/1f349/tulip/client-store"
|
||||
"github.com/1f349/tulip/database"
|
||||
"github.com/1f349/tulip/server"
|
||||
"github.com/1f349/violet/utils"
|
||||
"github.com/MrMelon54/exit-reload"
|
||||
"github.com/google/subcommands"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type serveCmd struct{ configPath string }
|
||||
|
||||
func (s *serveCmd) Name() string { return "serve" }
|
||||
|
||||
func (s *serveCmd) Synopsis() string { return "Serve user authentication service" }
|
||||
|
||||
func (s *serveCmd) SetFlags(f *flag.FlagSet) {
|
||||
f.StringVar(&s.configPath, "conf", "", "/path/to/config.json : path to the config file")
|
||||
}
|
||||
|
||||
func (s *serveCmd) Usage() string {
|
||||
return `serve [-conf <config file>]
|
||||
Serve user authentication service using information from the config file
|
||||
`
|
||||
}
|
||||
|
||||
func (s *serveCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
|
||||
log.Println("[Tulip] Starting...")
|
||||
|
||||
if s.configPath == "" {
|
||||
log.Println("[Tulip] Error: config flag is missing")
|
||||
return subcommands.ExitUsageError
|
||||
}
|
||||
|
||||
openConf, err := os.Open(s.configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Println("[Tulip] Error: missing config file")
|
||||
} else {
|
||||
log.Println("[Tulip] Error: open config file: ", err)
|
||||
}
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
|
||||
var config startUpConfig
|
||||
err = json.NewDecoder(openConf).Decode(&config)
|
||||
if err != nil {
|
||||
log.Println("[Tulip] Error: invalid config file: ", err)
|
||||
return subcommands.ExitFailure
|
||||
}
|
||||
|
||||
configPathAbs, err := filepath.Abs(s.configPath)
|
||||
if err != nil {
|
||||
log.Fatal("[Tulip] Failed to get absolute config path")
|
||||
}
|
||||
wd := filepath.Dir(configPathAbs)
|
||||
normalLoad(config, wd)
|
||||
return subcommands.ExitSuccess
|
||||
}
|
||||
|
||||
func normalLoad(startUp startUpConfig, wd string) {
|
||||
key := genHmacKey()
|
||||
|
||||
db, err := database.Open(filepath.Join(wd, "tulip.db.sqlite"))
|
||||
if err != nil {
|
||||
log.Fatal("[Tulip] Failed to open database:", err)
|
||||
}
|
||||
|
||||
log.Println("[Tulip] Checking database contains at least one user")
|
||||
if err := checkDbHasUser(db); err != nil {
|
||||
log.Fatal("[Tulip] Failed check:", err)
|
||||
}
|
||||
|
||||
cs := clientStore.New(db)
|
||||
|
||||
srv := server.NewHttpServer(startUp.Listen, startUp.Domain, db, key, cs)
|
||||
log.Printf("[Tulip] Starting HTTP server on '%s'\n", srv.Addr)
|
||||
go utils.RunBackgroundHttp("HTTP", srv)
|
||||
|
||||
exit_reload.ExitReload("Tulip", func() {}, func() {
|
||||
// stop http server
|
||||
srv.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func genHmacKey() []byte {
|
||||
a := make([]byte, 32)
|
||||
n, err := rand.Reader.Read(a)
|
||||
if err != nil {
|
||||
log.Fatal("[Tulip] Failed to generate HMAC key")
|
||||
}
|
||||
if n != 32 {
|
||||
log.Fatal("[Tulip] Failed to generate HMAC key")
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func checkDbHasUser(db *database.DB) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start transaction: %w", err)
|
||||
}
|
||||
if err := tx.HasUser(); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err := tx.InsertUser("admin", "admin", "admin@localhost")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user: %w", err)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
// continue normal operation now
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("failed to check if table has a user: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
111
database/db-scanner.go
Normal file
111
database/db-scanner.go
Normal file
@ -0,0 +1,111 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"golang.org/x/text/language"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
_, _, _, _, _ sql.Scanner = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
||||
_, _, _, _, _ json.Marshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
||||
_, _, _, _, _ json.Unmarshaler = &NullStringScanner{}, &NullDateScanner{}, &LocationScanner{}, &LocaleScanner{}, &PronounScanner{}
|
||||
)
|
||||
|
||||
func marshalValueOrNull(null bool, data any) ([]byte, error) {
|
||||
if null {
|
||||
return json.Marshal(nil)
|
||||
}
|
||||
return json.Marshal(data)
|
||||
}
|
||||
|
||||
type NullStringScanner struct{ sql.NullString }
|
||||
|
||||
func (s *NullStringScanner) Null() bool { return !s.Valid }
|
||||
func (s *NullStringScanner) Scan(src any) error { return s.Scan(src) }
|
||||
func (s NullStringScanner) MarshalJSON() ([]byte, error) {
|
||||
return marshalValueOrNull(s.Null(), s.String)
|
||||
}
|
||||
func (s *NullStringScanner) UnmarshalJSON(bytes []byte) error {
|
||||
if string(bytes) == "null" {
|
||||
return s.Scan(nil)
|
||||
}
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Scan(&a)
|
||||
}
|
||||
|
||||
type NullDateScanner struct{ sql.NullTime }
|
||||
|
||||
func (t *NullDateScanner) Null() bool { return !t.Valid }
|
||||
func (t *NullDateScanner) Scan(src any) error { return t.NullTime.Scan(src) }
|
||||
func (t NullDateScanner) MarshalJSON() ([]byte, error) {
|
||||
return marshalValueOrNull(t.Null(), t.Time.UTC().Format(time.DateOnly))
|
||||
}
|
||||
func (t *NullDateScanner) UnmarshalJSON(bytes []byte) error {
|
||||
if string(bytes) == "null" {
|
||||
return t.Scan(nil)
|
||||
}
|
||||
var a string
|
||||
err := json.Unmarshal(bytes, &a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return t.Scan(&a)
|
||||
}
|
||||
|
||||
type LocationScanner struct{ *time.Location }
|
||||
|
||||
func (l *LocationScanner) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
||||
}
|
||||
loc, err := time.LoadLocation(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.Location = loc
|
||||
return nil
|
||||
}
|
||||
func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) }
|
||||
|
||||
type LocaleScanner struct{ language.Tag }
|
||||
|
||||
func (l *LocaleScanner) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, l)
|
||||
}
|
||||
lang, err := language.Parse(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.Tag = lang
|
||||
return nil
|
||||
}
|
||||
func (l LocaleScanner) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(l.Tag.String())
|
||||
}
|
||||
|
||||
type PronounScanner struct{ pronouns.Pronoun }
|
||||
|
||||
func (p *PronounScanner) Scan(src any) error {
|
||||
s, ok := src.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, p)
|
||||
}
|
||||
pro, err := pronouns.FindPronoun(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Pronoun = pro
|
||||
return nil
|
||||
}
|
||||
func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
|
52
database/db-scanner_test.go
Normal file
52
database/db-scanner_test.go
Normal file
@ -0,0 +1,52 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func encode(data any) string {
|
||||
j, err := json.Marshal(map[string]any{"value": data})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(j)
|
||||
}
|
||||
|
||||
func TestStringScanner_MarshalJSON(t *testing.T) {
|
||||
assert.Equal(t, "{\"value\":\"Hello world\"}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: true}}))
|
||||
assert.Equal(t, "{\"value\":null}", encode(NullStringScanner{sql.NullString{String: "Hello world", Valid: false}}))
|
||||
}
|
||||
|
||||
func TestDateScanner_MarshalJSON(t *testing.T) {
|
||||
location, err := time.LoadLocation("Europe/London")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "{\"value\":\"2006-01-02\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.January, 2, 0, 0, 0, 0, time.UTC), Valid: true}}))
|
||||
assert.Equal(t, "{\"value\":\"2006-08-01\"}", encode(NullDateScanner{sql.NullTime{Time: time.Date(2006, time.August, 2, 0, 0, 0, 0, location), Valid: true}}))
|
||||
assert.Equal(t, "{\"value\":null}", encode(NullDateScanner{}))
|
||||
}
|
||||
|
||||
func TestLocationScanner_MarshalJSON(t *testing.T) {
|
||||
location, err := time.LoadLocation("Europe/London")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "{\"value\":\"Europe/London\"}", encode(LocationScanner{location}))
|
||||
assert.Equal(t, "{\"value\":\"UTC\"}", encode(LocationScanner{time.UTC}))
|
||||
}
|
||||
|
||||
func TestLocaleScanner_MarshalJSON(t *testing.T) {
|
||||
assert.Equal(t, "{\"value\":\"en-US\"}", encode(LocaleScanner{language.AmericanEnglish}))
|
||||
assert.Equal(t, "{\"value\":\"en-GB\"}", encode(LocaleScanner{language.BritishEnglish}))
|
||||
}
|
||||
|
||||
func TestPronounScanner_MarshalJSON(t *testing.T) {
|
||||
assert.Equal(t, "{\"value\":\"they/them\"}", encode(PronounScanner{pronouns.TheyThem}))
|
||||
assert.Equal(t, "{\"value\":\"he/him\"}", encode(PronounScanner{pronouns.HeHim}))
|
||||
assert.Equal(t, "{\"value\":\"she/her\"}", encode(PronounScanner{pronouns.SheHer}))
|
||||
assert.Equal(t, "{\"value\":\"it/its\"}", encode(PronounScanner{pronouns.ItIts}))
|
||||
assert.Equal(t, "{\"value\":\"one/one's\"}", encode(PronounScanner{pronouns.OneOnes}))
|
||||
}
|
105
database/db-types.go
Normal file
105
database/db-types.go
Normal file
@ -0,0 +1,105 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/text/language"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Sub uuid.UUID `json:"sub"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Picture NullStringScanner `json:"picture,omitempty"`
|
||||
Website NullStringScanner `json:"website,omitempty"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Pronouns PronounScanner `json:"pronouns,omitempty"`
|
||||
Birthdate NullDateScanner `json:"birthdate,omitempty"`
|
||||
ZoneInfo LocationScanner `json:"zoneinfo,omitempty"`
|
||||
Locale LocaleScanner `json:"locale,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
type UserPatch struct {
|
||||
Name NullStringScanner `json:"name"`
|
||||
Picture NullStringScanner `json:"picture"`
|
||||
Website NullStringScanner `json:"website"`
|
||||
Pronouns PronounScanner `json:"pronouns"`
|
||||
Birthdate NullDateScanner `json:"birthdate"`
|
||||
ZoneInfo *time.Location `json:"zoneinfo"`
|
||||
Locale *language.Tag `json:"locale"`
|
||||
}
|
||||
|
||||
func (u *UserPatch) UnmarshalJSON(bytes []byte) error {
|
||||
var m struct {
|
||||
Name string `json:"name"`
|
||||
Picture string `json:"picture"`
|
||||
Website string `json:"website"`
|
||||
Pronouns string `json:"pronouns"`
|
||||
Birthdate string `json:"birthdate"`
|
||||
ZoneInfo string `json:"zoneinfo"`
|
||||
Locale string `json:"locale"`
|
||||
}
|
||||
err := json.Unmarshal(bytes, &m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Name = m.Name
|
||||
|
||||
// only parse the picture address if included
|
||||
if m.Picture != "" {
|
||||
u.Picture, err = url.Parse(m.Picture)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// only parse the website address if included
|
||||
if m.Website != "" {
|
||||
u.Website, err = url.Parse(m.Website)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// only parse the pronouns if included
|
||||
if m.Pronouns != "" {
|
||||
u.Pronouns, err = pronouns.FindPronoun(m.Pronouns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// only parse the birthdate if included
|
||||
if m.Birthdate != "" {
|
||||
u.Birthdate, err = time.Parse(time.DateOnly, m.Birthdate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// only parse the zoneinfo if included
|
||||
if m.ZoneInfo != "" {
|
||||
u.ZoneInfo, err = time.LoadLocation(m.ZoneInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if m.Locale != "" {
|
||||
locale, err := language.Parse(m.Locale)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.Locale = &locale
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ json.Unmarshaler = &UserPatch{}
|
77
database/db-types_test.go
Normal file
77
database/db-types_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"maps"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUserPatch_UnmarshalJSON(t *testing.T) {
|
||||
const a = `{
|
||||
"name": "Test",
|
||||
"picture": "https://example.com/logo.png",
|
||||
"website": "https://example.com",
|
||||
"gender": "robot",
|
||||
"pronouns": "they/them",
|
||||
"birthdate": "3070-01-01",
|
||||
"zoneinfo": "Europe/London",
|
||||
"locale": "en-GB"
|
||||
}`
|
||||
var p UserPatch
|
||||
assert.NoError(t, json.Unmarshal([]byte(a), &p))
|
||||
assert.Equal(t, "Test", p.Name)
|
||||
assert.Equal(t, "https://example.com/logo.png", p.Picture.String())
|
||||
assert.Equal(t, "https://example.com", p.Website.String())
|
||||
assert.Equal(t, pronouns.TheyThem, p.Pronouns)
|
||||
assert.Equal(t, time.Date(3070, time.January, 1, 0, 0, 0, 0, time.UTC), p.Birthdate)
|
||||
location, err := time.LoadLocation("Europe/London")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, location, p.ZoneInfo)
|
||||
assert.Equal(t, "en-GB", p.Locale.String())
|
||||
}
|
||||
|
||||
func TestUserPatch_UnmarshalJSON2(t *testing.T) {
|
||||
var userModifyChecks = map[string]struct{ valid, invalid []string }{
|
||||
"picture": {valid: []string{"https://example.com/icon.png"}, invalid: []string{"%/icon.png"}},
|
||||
"website": {valid: []string{"https://example.com"}, invalid: []string{"%/example.com"}},
|
||||
"pronouns": {valid: []string{"he/him", "she/her"}, invalid: []string{"a/a"}},
|
||||
"birthdate": {valid: []string{"2023-08-07", "2023-01-01"}, invalid: []string{"2023-00-00", "hello"}},
|
||||
"zoneinfo": {
|
||||
valid: []string{"Europe/London", "Europe/Berlin", "America/Los_Angeles", "America/Edmonton", "America/Montreal"},
|
||||
invalid: []string{"Europe/York", "Canada/Edmonton", "hello"},
|
||||
},
|
||||
"locale": {valid: []string{"en-GB", "en-US", "zh-CN"}, invalid: []string{"en-YY"}},
|
||||
}
|
||||
m := map[string]string{
|
||||
"name": "Test",
|
||||
"picture": "https://example.com/logo.png",
|
||||
"website": "https://example.com",
|
||||
"gender": "robot",
|
||||
"pronouns": "they/them",
|
||||
"birthdate": "3070-01-01",
|
||||
"zoneinfo": "Europe/London",
|
||||
"locale": "en-GB",
|
||||
}
|
||||
for k, v := range userModifyChecks {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
m2 := maps.Clone(m)
|
||||
for _, i := range v.valid {
|
||||
m2[k] = i
|
||||
marshal, err := json.Marshal(m2)
|
||||
assert.NoError(t, err)
|
||||
var m3 UserPatch
|
||||
assert.NoError(t, json.Unmarshal(marshal, &m3))
|
||||
}
|
||||
for _, i := range v.invalid {
|
||||
m2[k] = i
|
||||
marshal, err := json.Marshal(m2)
|
||||
assert.NoError(t, err)
|
||||
var m3 UserPatch
|
||||
assert.Error(t, json.Unmarshal(marshal, &m3))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
37
database/db.go
Normal file
37
database/db.go
Normal file
@ -0,0 +1,37 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
//go:embed init.sql
|
||||
var initSql string
|
||||
|
||||
type DB struct{ db *sql.DB }
|
||||
|
||||
func Open(p string) (*DB, error) {
|
||||
db, err := sql.Open("sqlite3", p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = db.Exec(initSql)
|
||||
return &DB{db: db}, err
|
||||
}
|
||||
|
||||
func (d *DB) Begin() (*Tx, error) {
|
||||
begin, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{begin}, err
|
||||
}
|
||||
|
||||
func (d *DB) BeginCtx(ctx context.Context) (*Tx, error) {
|
||||
begin, err := d.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{begin}, err
|
||||
}
|
5
database/db_test.go
Normal file
5
database/db_test.go
Normal file
@ -0,0 +1,5 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
27
database/init.sql
Normal file
27
database/init.sql
Normal file
@ -0,0 +1,27 @@
|
||||
CREATE TABLE IF NOT EXISTS users
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
username TEXT UNIQUE NOT NULL,
|
||||
password TEXT NOT NULL,
|
||||
picture TEXT,
|
||||
website TEXT,
|
||||
email TEXT NOT NULL,
|
||||
email_verified INTEGER DEFAULT 0 NOT NULL,
|
||||
pronouns TEXT DEFAULT "they/them" NOT NULL,
|
||||
birthdate DATE,
|
||||
zoneinfo TEXT DEFAULT "" NOT NULL,
|
||||
locale TEXT DEFAULT "en-US" NOT NULL,
|
||||
updated_at DATETIME,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS client_store
|
||||
(
|
||||
subject TEXT PRIMARY KEY UNIQUE NOT NULL,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
secret TEXT UNIQUE NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
sso INTEGER,
|
||||
active INTEGER DEFAULT 1
|
||||
);
|
177
database/tx.go
Normal file
177
database/tx.go
Normal file
@ -0,0 +1,177 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/1f349/tulip/password"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
"github.com/google/uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Tx struct{ tx *sql.Tx }
|
||||
|
||||
func (t *Tx) Commit() error {
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
func (t *Tx) Rollback() {
|
||||
_ = t.tx.Rollback()
|
||||
}
|
||||
|
||||
func (t *Tx) HasUser() error {
|
||||
var exists bool
|
||||
row := t.tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM users)`)
|
||||
err := row.Scan(&exists)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tx) InsertUser(un, pw, email string) error {
|
||||
pwHash, err := password.HashPassword(pw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = t.tx.Exec(`INSERT INTO users (subject, username, password, email) VALUES (?, ?, ?, ?)`, uuid.NewString(), un, pwHash, email)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *Tx) CheckLogin(un, pw string) (*User, error) {
|
||||
var u User
|
||||
row := t.tx.QueryRow(`SELECT subject, password FROM users WHERE username = ? LIMIT 1`, un)
|
||||
err := row.Scan(&u.Sub, &u.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = password.CheckPasswordHash(u.Password, pw)
|
||||
return &u, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUserDisplayName(sub uuid.UUID) (*User, error) {
|
||||
var u User
|
||||
row := t.tx.QueryRow(`SELECT name FROM users WHERE subject = ? LIMIT 1`, sub.String())
|
||||
err := row.Scan(&u.Name)
|
||||
u.Sub = sub
|
||||
return &u, err
|
||||
}
|
||||
|
||||
func (t *Tx) GetUser(sub uuid.UUID) (*User, error) {
|
||||
var u User
|
||||
row := t.tx.QueryRow(`SELECT name, username, password, picture, website, email, email_verified, pronouns, birthdate, zoneinfo, locale, updated_at, active FROM users WHERE subject = ? LIMIT 1`, sub.String())
|
||||
err := row.Scan(&u.Name, &u.Username, &u.Password, &u.Picture, &u.Website, &u.Email, &u.EmailVerified, &u.Pronouns, &u.Birthdate, &u.ZoneInfo, &u.Locale, &u.UpdatedAt, &u.Active)
|
||||
u.Sub = sub
|
||||
return &u, err
|
||||
}
|
||||
|
||||
func (t *Tx) ChangeUserPassword(sub uuid.UUID, pwOld, pwNew string) error {
|
||||
q, err := t.tx.Query(`SELECT password FROM users WHERE subject = ?`, sub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pwHash string
|
||||
if q.Next() {
|
||||
err = q.Scan(&pwHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("invalid user")
|
||||
}
|
||||
if err := q.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := q.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
err = password.CheckPasswordHash(pwHash, pwOld)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pwNewHash, err := password.HashPassword(pwNew)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
exec, err := t.tx.Exec(`UPDATE users SET password = ?, updated_at = ? WHERE subject = ? AND password = ?`, pwNewHash, time.Now().Format(time.DateTime), sub, pwHash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := exec.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected != 1 {
|
||||
return fmt.Errorf("row wasn't updated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tx) ModifyUser(sub uuid.UUID, v *UserPatch) error {
|
||||
exec, err := t.tx.Exec(
|
||||
`UPDATE users
|
||||
SET name = ifnull(?, name),
|
||||
picture = ifnull(?, picture),
|
||||
website = ifnull(?, website),
|
||||
pronouns = ifnull(?, pronouns),
|
||||
birthdate = ifnull(?, birthdate),
|
||||
zoneinfo = ifnull(?, zoneinfo),
|
||||
locale = ifnull(?, locale),
|
||||
updated_at = ?
|
||||
WHERE subject = ?`,
|
||||
v.Name,
|
||||
stringify(v.Picture),
|
||||
stringify(v.Website),
|
||||
v.Pronouns.String(),
|
||||
sql.NullTime{Time: v.Birthdate, Valid: !v.Birthdate.IsZero()},
|
||||
v.ZoneInfo.String(),
|
||||
v.Locale.String(),
|
||||
time.Now().Format(time.DateTime),
|
||||
sub,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := exec.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected != 1 {
|
||||
return fmt.Errorf("row wasn't updated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tx) GetClientInfo(sub string) (oauth2.ClientInfo, error) {
|
||||
var u clientInfoDbOutput
|
||||
row := t.tx.QueryRow(`SELECT secret, domain, sso, active FROM client_store WHERE subject = ? LIMIT 1`, sub)
|
||||
err := row.Scan(&u.secret, &u.domain, &u.sso)
|
||||
u.sub = sub
|
||||
return &u, err
|
||||
}
|
||||
|
||||
type clientInfoDbOutput struct {
|
||||
sub, secret, domain string
|
||||
sso bool
|
||||
}
|
||||
|
||||
func (c *clientInfoDbOutput) GetID() string { return c.sub }
|
||||
func (c *clientInfoDbOutput) GetSecret() string { return c.secret }
|
||||
func (c *clientInfoDbOutput) GetDomain() string { return c.domain }
|
||||
func (c *clientInfoDbOutput) IsPublic() bool { return false }
|
||||
func (c *clientInfoDbOutput) GetUserID() string { return "" }
|
||||
func (c *clientInfoDbOutput) IsSSO() bool { return c.sso }
|
||||
|
||||
func stringify(stringer fmt.Stringer) sql.NullString {
|
||||
if stringer == nil {
|
||||
return sql.NullString{}
|
||||
}
|
||||
return emptyToNull(stringer.String())
|
||||
}
|
||||
|
||||
func emptyToNull(a string) sql.NullString {
|
||||
return sql.NullString{String: a, Valid: a != ""}
|
||||
}
|
55
database/tx_test.go
Normal file
55
database/tx_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"github.com/1f349/tulip/password"
|
||||
"github.com/MrMelon54/pronouns"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/text/language"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTx_ChangeUserPassword(t *testing.T) {
|
||||
u := uuid.New()
|
||||
pw, err := password.HashPassword("test")
|
||||
assert.NoError(t, err)
|
||||
d, err := Open("file::memory:")
|
||||
assert.NoError(t, err)
|
||||
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email) VALUES (?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost")
|
||||
assert.NoError(t, err)
|
||||
tx, err := d.Begin()
|
||||
assert.NoError(t, err)
|
||||
err = tx.ChangeUserPassword(u, "test", "new")
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tx.Commit())
|
||||
query, err := d.db.Query(`SELECT password FROM users WHERE subject = ? AND username = ?`, u.String(), "test")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, query.Next())
|
||||
var oldPw string
|
||||
assert.NoError(t, query.Scan(&oldPw))
|
||||
assert.NoError(t, password.CheckPasswordHash(oldPw, "new"))
|
||||
assert.NoError(t, query.Err())
|
||||
assert.NoError(t, query.Close())
|
||||
}
|
||||
|
||||
func TestTx_ModifyUser(t *testing.T) {
|
||||
u := uuid.New()
|
||||
pw, err := password.HashPassword("test")
|
||||
assert.NoError(t, err)
|
||||
d, err := Open("file::memory:")
|
||||
assert.NoError(t, err)
|
||||
_, err = d.db.Exec(`INSERT INTO users (subject, name, username, password, email) VALUES (?, ?, ?, ?, ?)`, u.String(), "Test", "test", pw, "test@localhost")
|
||||
assert.NoError(t, err)
|
||||
tx, err := d.Begin()
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, tx.ModifyUser(u, &UserPatch{
|
||||
Name: "example",
|
||||
Picture: nil,
|
||||
Website: nil,
|
||||
Pronouns: pronouns.Pronoun{},
|
||||
Birthdate: time.Time{},
|
||||
ZoneInfo: nil,
|
||||
Locale: &language.Tag{},
|
||||
}))
|
||||
}
|
25
openid/config.go
Normal file
25
openid/config.go
Normal file
@ -0,0 +1,25 @@
|
||||
package openid
|
||||
|
||||
type Config struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"userinfo_endpoint"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported"`
|
||||
ScopesSupported []string `json:"scopes_supported"`
|
||||
ClaimsSupported []string `json:"claims_supported"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported"`
|
||||
}
|
||||
|
||||
func GenConfig(domain string, scopes, claims []string) Config {
|
||||
return Config{
|
||||
Issuer: "https://" + domain,
|
||||
AuthorizationEndpoint: "https://" + domain + "/authorize",
|
||||
TokenEndpoint: "https://" + domain + "/token",
|
||||
UserInfoEndpoint: "https://" + domain + "/userinfo",
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
ScopesSupported: scopes,
|
||||
ClaimsSupported: claims,
|
||||
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
||||
}
|
||||
}
|
19
openid/config_test.go
Normal file
19
openid/config_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package openid
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenConfig(t *testing.T) {
|
||||
assert.Equal(t, Config{
|
||||
Issuer: "https://example.com",
|
||||
AuthorizationEndpoint: "https://example.com/authorize",
|
||||
TokenEndpoint: "https://example.com/token",
|
||||
UserInfoEndpoint: "https://example.com/userinfo",
|
||||
ResponseTypesSupported: []string{"code"},
|
||||
ScopesSupported: []string{"openid", "email"},
|
||||
ClaimsSupported: []string{"name", "email", "preferred_username"},
|
||||
GrantTypesSupported: []string{"authorization_code", "refresh_token"},
|
||||
}, GenConfig("example.com", []string{"openid", "email"}, []string{"name", "email", "preferred_username"}))
|
||||
}
|
32
pages/authorize.go.html
Normal file
32
pages/authorize.go.html
Normal file
@ -0,0 +1,32 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>1f349 ID</title>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>1f349 ID</h1>
|
||||
</header>
|
||||
<main>
|
||||
<form method="POST" action="/authorize">
|
||||
<div>The application {{.AppName}} wants to access your account ({{.User.Name}}). It requests the following permissions:</div>
|
||||
<div>
|
||||
<ul>
|
||||
{{range .WantsList}}
|
||||
<li>{{.Label}}</li>
|
||||
{{end}}
|
||||
</ul>
|
||||
</div>
|
||||
<div>
|
||||
<input type="hidden" name="response_type" value="{{.ResponseType}}"/>
|
||||
<input type="hidden" name="client_id" value="{{.ClientID}}"/>
|
||||
<input type="hidden" name="state" value="{{.State}}"/>
|
||||
<input type="hidden" name="scopes" value="{{.Scope}}"/>
|
||||
<input type="hidden" name="nonce" value="{{.Nonce}}"/>
|
||||
<button class="oauth-action-authorize" name="oauth_action" value="authorize">Authorize</button>
|
||||
<button class="oauth-action-cancel" name="oauth_action" value="cancel">Cancel</button>
|
||||
</div>
|
||||
</form>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
17
pages/index-guest.go.html
Normal file
17
pages/index-guest.go.html
Normal file
@ -0,0 +1,17 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>1f349 ID</title>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>1f349 ID</h1>
|
||||
</header>
|
||||
<main>
|
||||
<div>Not logged in</div>
|
||||
<div>
|
||||
<button onclick="location.href='/login'">Login</button>
|
||||
</div>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
19
pages/index.go.html
Normal file
19
pages/index.go.html
Normal file
@ -0,0 +1,19 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>1f349 ID</title>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>1f349 ID</h1>
|
||||
</header>
|
||||
<main>
|
||||
<div>Logged in as: {{.User.Name}} ({{.User.ID}})</div>
|
||||
<div>
|
||||
<form method="POST" action="/logout"><input type="hidden" name="nonce" value="{{.Nonce}}">
|
||||
<button type="submit">Log Out</button>
|
||||
</form>
|
||||
</div>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
24
pages/login.go.html
Normal file
24
pages/login.go.html
Normal file
@ -0,0 +1,24 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>1f349 ID</title>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>1f349 ID</h1>
|
||||
</header>
|
||||
<main>
|
||||
<form method="POST" action="">
|
||||
<div>
|
||||
<label for="username">User Name</label>
|
||||
<input type="text" name="username" id="username" required/>
|
||||
</div>
|
||||
<div>
|
||||
<label for="password">Password</label>
|
||||
<input type="password" name="password" id="password" required/>
|
||||
</div>
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
24
pages/pages.go
Normal file
24
pages/pages.go
Normal file
@ -0,0 +1,24 @@
|
||||
package pages
|
||||
|
||||
import (
|
||||
"embed"
|
||||
_ "embed"
|
||||
"html/template"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed *
|
||||
embeddedTemplates embed.FS
|
||||
|
||||
pageTemplate *template.Template
|
||||
)
|
||||
|
||||
func LoadPageTemplates() (err error) {
|
||||
pageTemplate, err = template.New("pages").ParseFS(embeddedTemplates, "*.go.html")
|
||||
return
|
||||
}
|
||||
|
||||
func RenderPageTemplate(wr io.Writer, name string, data any) error {
|
||||
return pageTemplate.ExecuteTemplate(wr, name+".go.html", data)
|
||||
}
|
12
password/password.go
Normal file
12
password/password.go
Normal file
@ -0,0 +1,12 @@
|
||||
package password
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), 14)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
func CheckPasswordHash(hash, password string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
}
|
72
server/auth.go
Normal file
72
server/auth.go
Normal file
@ -0,0 +1,72 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-session/session"
|
||||
"github.com/google/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type UserHandler func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth)
|
||||
|
||||
type UserAuth struct {
|
||||
ID uuid.UUID
|
||||
Session session.Store
|
||||
}
|
||||
|
||||
func (u UserAuth) IsGuest() bool {
|
||||
return u.ID == uuid.Nil
|
||||
}
|
||||
|
||||
func (h *HttpServer) RequireAuthentication(error string, code int, next UserHandler) httprouter.Handle {
|
||||
return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
if auth.IsGuest() {
|
||||
http.Error(rw, error, code)
|
||||
return
|
||||
}
|
||||
next(rw, req, params, auth)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HttpServer) RequireAuthenticationRedirect(redirect string, code int, next UserHandler) httprouter.Handle {
|
||||
return h.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
if auth.IsGuest() {
|
||||
http.Redirect(rw, req, redirect, code)
|
||||
return
|
||||
}
|
||||
next(rw, req, params, auth)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HttpServer) OptionalAuthentication(next UserHandler) httprouter.Handle {
|
||||
return func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
auth, err := h.internalAuthenticationHandler(rw, req)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
next(rw, req, params, auth)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HttpServer) internalAuthenticationHandler(rw http.ResponseWriter, req *http.Request) (UserAuth, error) {
|
||||
ss, err := session.Start(req.Context(), rw, req)
|
||||
if err != nil {
|
||||
return UserAuth{}, fmt.Errorf("failed to start session")
|
||||
}
|
||||
|
||||
userIdRaw, ok := ss.Get("user")
|
||||
if !ok {
|
||||
return UserAuth{Session: ss}, nil
|
||||
}
|
||||
userId, ok := userIdRaw.(uuid.UUID)
|
||||
if !ok {
|
||||
ss.Delete("user")
|
||||
err := ss.Save()
|
||||
if err != nil {
|
||||
return UserAuth{Session: ss}, fmt.Errorf("failed to reset invalid session data")
|
||||
}
|
||||
}
|
||||
return UserAuth{ID: userId, Session: ss}, nil
|
||||
}
|
11
server/auth_test.go
Normal file
11
server/auth_test.go
Normal file
@ -0,0 +1,11 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUserAuth_IsGuest(t *testing.T) {
|
||||
var u UserAuth
|
||||
assert.True(t, u.IsGuest())
|
||||
}
|
30
server/db.go
Normal file
30
server/db.go
Normal file
@ -0,0 +1,30 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/1f349/tulip/database"
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (h *HttpServer) dbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool {
|
||||
tx, err := h.db.Begin()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError)
|
||||
return true
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
err = action(tx)
|
||||
if err != nil {
|
||||
http.Error(rw, "Database error", http.StatusInternalServerError)
|
||||
log.Println("Database action error:", err)
|
||||
return true
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
http.Error(rw, "Database error", http.StatusInternalServerError)
|
||||
log.Println("Database commit error:", err)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
156
server/oauth.go
Normal file
156
server/oauth.go
Normal file
@ -0,0 +1,156 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/go-session/session"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
func (h *HttpServer) authorizeEndpoint(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
|
||||
ss, err := session.Start(req.Context(), rw, req)
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to load session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := h.oauthSrv.UserAuthorizationHandler(rw, req)
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to check user", http.StatusInternalServerError)
|
||||
return
|
||||
} else if userID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// function is only called with GET or POST method
|
||||
isPost := req.Method == http.MethodPost
|
||||
|
||||
var form url.Values
|
||||
if isPost {
|
||||
err = req.ParseForm()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to parse form", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
form = req.PostForm
|
||||
} else {
|
||||
form = req.URL.Query()
|
||||
}
|
||||
|
||||
clientID := form.Get("client_id")
|
||||
client, err := h.oauthMgr.GetClient(req.Context(), clientID)
|
||||
if err != nil {
|
||||
http.Error(rw, "Invalid client", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
redirectUri := form.Get("redirect_uri")
|
||||
if redirectUri != client.GetDomain() {
|
||||
http.Error(rw, "Incorrect redirect URI", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if form.Has("cancel") {
|
||||
uCancel, err := url.Parse(client.GetDomain())
|
||||
if err != nil {
|
||||
http.Error(rw, "Invalid redirect URI", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
q := uCancel.Query()
|
||||
q.Set("error", "access_denied")
|
||||
uCancel.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(rw, req, uCancel.String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
var isSSO bool
|
||||
if clientIsSSO, ok := client.(interface{ IsSSO() bool }); ok {
|
||||
isSSO = clientIsSSO.IsSSO()
|
||||
}
|
||||
|
||||
switch {
|
||||
case isSSO && isPost:
|
||||
http.Error(rw, "400 Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
case !isSSO && !isPost:
|
||||
f := func(key string) string { return form.Get(key) }
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprintf(rw, `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Authorize</title></head>
|
||||
<body>
|
||||
<form method="POST" action="/authorize">
|
||||
<input type="hidden" name="client_id" value="%s">
|
||||
<input type="hidden" name="redirect_uri" value="%s">
|
||||
<input type="hidden" name="scope" value="%s">
|
||||
<input type="hidden" name="state" value="%s">
|
||||
<input type="hidden" name="nonce" value="%s">
|
||||
<input type="hidden" name="response_type" value="%s">
|
||||
<input type="hidden" name="response_mode" value="%s">
|
||||
<div>Scope: %s</div>
|
||||
<div><button type="submit">Authorize</button></div>
|
||||
<div><button type="submit" name="cancel" value="">Cancel</button></div>
|
||||
</form>
|
||||
</html>`, clientID, redirectUri, f("scope"), f("state"), f("nonce"), f("response_type"), f("response_mode"), f("scope"))
|
||||
return
|
||||
default:
|
||||
break
|
||||
}
|
||||
|
||||
// continue flow
|
||||
oauthDataRaw, ok := ss.Get("OAuthData")
|
||||
if ok {
|
||||
ss.Delete("OAuthData")
|
||||
if ss.Save() != nil {
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
oauthData, ok := oauthDataRaw.(url.Values)
|
||||
if !ok {
|
||||
http.Error(rw, "Failed to load session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.URL.RawQuery = oauthData.Encode()
|
||||
}
|
||||
|
||||
if err := h.oauthSrv.HandleAuthorizeRequest(rw, req); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HttpServer) oauthUserAuthorization(rw http.ResponseWriter, req *http.Request) (string, error) {
|
||||
err := req.ParseForm()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
auth, err := h.internalAuthenticationHandler(rw, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if auth.IsGuest() {
|
||||
// handle redirecting to oauth
|
||||
var q url.Values
|
||||
switch req.Method {
|
||||
case http.MethodPost:
|
||||
q = req.PostForm
|
||||
case http.MethodGet:
|
||||
q = req.URL.Query()
|
||||
default:
|
||||
http.Error(rw, "405 Method Not Allowed", http.StatusMethodNotAllowed)
|
||||
return "", err
|
||||
}
|
||||
auth.Session.Set("OAuthData", q)
|
||||
if auth.Session.Save() != nil {
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return "", err
|
||||
}
|
||||
http.Redirect(rw, req, "/login?redirect=oauth", http.StatusFound)
|
||||
return "", nil
|
||||
}
|
||||
return auth.ID.String(), nil
|
||||
}
|
291
server/server.go
Normal file
291
server/server.go
Normal file
@ -0,0 +1,291 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
errors2 "errors"
|
||||
"fmt"
|
||||
"github.com/1f349/tulip/database"
|
||||
"github.com/1f349/tulip/openid"
|
||||
"github.com/1f349/tulip/pages"
|
||||
"github.com/go-oauth2/oauth2/v4"
|
||||
"github.com/go-oauth2/oauth2/v4/errors"
|
||||
"github.com/go-oauth2/oauth2/v4/generates"
|
||||
"github.com/go-oauth2/oauth2/v4/manage"
|
||||
"github.com/go-oauth2/oauth2/v4/server"
|
||||
"github.com/go-oauth2/oauth2/v4/store"
|
||||
"github.com/google/uuid"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
var errMissingRequiredScope = errors.New("missing required scope")
|
||||
|
||||
type HttpServer struct {
|
||||
r *httprouter.Router
|
||||
oauthSrv *server.Server
|
||||
oauthMgr *manage.Manager
|
||||
db *database.DB
|
||||
domain string
|
||||
privKey []byte
|
||||
}
|
||||
|
||||
func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clientStore oauth2.ClientStore) *http.Server {
|
||||
r := httprouter.New()
|
||||
|
||||
openIdConf := openid.GenConfig(domain, []string{"openid", "email"}, []string{"sub", "name", "preferred_username", "profile", "picture", "website", "email", "email_verified", "gender", "birthdate", "zoneinfo", "locale", "updated_at"})
|
||||
openIdBytes, err := json.Marshal(openIdConf)
|
||||
if err != nil {
|
||||
log.Fatalln("Failed to generate OpenID configuration:", err)
|
||||
}
|
||||
|
||||
if err := pages.LoadPageTemplates(); err != nil {
|
||||
log.Fatalln("Failed to load page templates:", err)
|
||||
}
|
||||
|
||||
oauthManager := manage.NewDefaultManager()
|
||||
oauthSrv := server.NewServer(server.NewConfig(), oauthManager)
|
||||
hs := &HttpServer{
|
||||
r: httprouter.New(),
|
||||
oauthSrv: oauthSrv,
|
||||
oauthMgr: oauthManager,
|
||||
db: db,
|
||||
domain: domain,
|
||||
privKey: privKey,
|
||||
}
|
||||
|
||||
oauthManager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
||||
oauthManager.MustTokenStorage(store.NewMemoryTokenStore())
|
||||
oauthManager.MapAccessGenerate(generates.NewAccessGenerate())
|
||||
oauthManager.MapClientStorage(clientStore)
|
||||
|
||||
oauthSrv.SetResponseErrorHandler(func(re *errors.Response) {
|
||||
log.Printf("Response error: %#v\n", re)
|
||||
})
|
||||
oauthSrv.SetClientInfoHandler(func(req *http.Request) (clientID, clientSecret string, err error) {
|
||||
cId, cSecret, err := server.ClientBasicHandler(req)
|
||||
if cId == "" && cSecret == "" {
|
||||
cId, cSecret, err = server.ClientFormHandler(req)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return cId, cSecret, nil
|
||||
})
|
||||
oauthSrv.SetUserAuthorizationHandler(hs.oauthUserAuthorization)
|
||||
oauthSrv.SetAuthorizeScopeHandler(func(rw http.ResponseWriter, req *http.Request) (scope string, err error) {
|
||||
var form url.Values
|
||||
if req.Method == http.MethodPost {
|
||||
form = req.PostForm
|
||||
} else {
|
||||
form = req.URL.Query()
|
||||
}
|
||||
a := form.Get("scope")
|
||||
if a != "openid" {
|
||||
return "", errMissingRequiredScope
|
||||
}
|
||||
return "openid", nil
|
||||
})
|
||||
|
||||
newUserUuid := uuid.New()
|
||||
fmt.Println("New User Uuid:", newUserUuid.String())
|
||||
|
||||
r.GET("/.well-known/openid-configuration", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, _ = rw.Write(openIdBytes)
|
||||
})
|
||||
r.GET("/", hs.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
rw.Header().Set("Content-Type", "text/html")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
if auth.IsGuest() {
|
||||
_ = pages.RenderPageTemplate(rw, "index-guest", nil)
|
||||
return
|
||||
}
|
||||
|
||||
lNonce := uuid.NewString()
|
||||
auth.Session.Set("action-nonce", lNonce)
|
||||
if auth.Session.Save() != nil {
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
hs.dbTx(rw, func(tx *database.Tx) error {
|
||||
userWithName, err := tx.GetUserDisplayName(auth.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user display name: %w", err)
|
||||
}
|
||||
_ = pages.RenderPageTemplate(rw, "index", map[string]any{
|
||||
"Auth": auth,
|
||||
"User": userWithName,
|
||||
"Nonce": lNonce,
|
||||
})
|
||||
return nil
|
||||
})
|
||||
}))
|
||||
r.POST("/logout", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
lNonce, ok := auth.Session.Get("action-nonce")
|
||||
if !ok {
|
||||
http.Error(rw, "Missing nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if subtle.ConstantTimeCompare([]byte(lNonce.(string)), []byte(req.PostFormValue("nonce"))) == 1 {
|
||||
auth.Session.Delete("user")
|
||||
if auth.Session.Save() != nil {
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, req, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
http.Error(rw, "Logout failed", http.StatusInternalServerError)
|
||||
}))
|
||||
r.GET("/login", hs.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
if !auth.IsGuest() {
|
||||
http.Redirect(rw, req, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
rw.Header().Set("Content-Type", "text/html")
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_ = pages.RenderPageTemplate(rw, "login", nil)
|
||||
}))
|
||||
r.POST("/login", hs.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
un := req.FormValue("username")
|
||||
pw := req.FormValue("password")
|
||||
var userSub uuid.UUID
|
||||
if hs.dbTx(rw, func(tx *database.Tx) error {
|
||||
loginUser, err := tx.CheckLogin(un, pw)
|
||||
if err != nil {
|
||||
if errors2.Is(err, sql.ErrNoRows) || errors2.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
|
||||
http.Redirect(rw, req, "/login?mismatch=1", http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
http.Error(rw, "Internal server error", http.StatusInternalServerError)
|
||||
return err
|
||||
}
|
||||
userSub = loginUser.Sub
|
||||
return nil
|
||||
}) {
|
||||
return
|
||||
}
|
||||
|
||||
// only continues if the above tx succeeds
|
||||
auth.Session.Set("user", userSub)
|
||||
if auth.Session.Save() != nil {
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.URL.Query().Get("redirect") {
|
||||
case "oauth":
|
||||
oauthDataRaw, ok := auth.Session.Get("OAuthData")
|
||||
if !ok {
|
||||
http.Error(rw, "Failed to load session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
oauthData, ok := oauthDataRaw.(url.Values)
|
||||
if !ok {
|
||||
http.Error(rw, "Failed to load session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
authUrl := url.URL{Path: "/authorize", RawQuery: oauthData.Encode()}
|
||||
http.Redirect(rw, req, authUrl.String(), http.StatusFound)
|
||||
default:
|
||||
http.Redirect(rw, req, "/", http.StatusFound)
|
||||
}
|
||||
}))
|
||||
r.GET("/authorize", hs.authorizeEndpoint)
|
||||
r.POST("/authorize", hs.authorizeEndpoint)
|
||||
r.POST("/token", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
if err := oauthSrv.HandleTokenRequest(rw, req); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
r.GET("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
begin, err := db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
user, err := begin.GetUser(auth.ID)
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to read user data", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
lNonce := uuid.NewString()
|
||||
auth.Session.Set("action-nonce", lNonce)
|
||||
if auth.Session.Save() != nil {
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
_ = pages.RenderPageTemplate(rw, "edit", map[string]any{
|
||||
"User": user,
|
||||
"Nonce": lNonce,
|
||||
})
|
||||
}))
|
||||
r.POST("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
|
||||
if req.ParseForm() != nil {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// TODO: parse user patch from form
|
||||
req.Form.Get("")
|
||||
var patch database.UserPatch
|
||||
decoder := json.NewDecoder(req.Body)
|
||||
decoder.DisallowUnknownFields()
|
||||
err := decoder.Decode(&patch)
|
||||
if err != nil {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
begin, err := db.Begin()
|
||||
if err != nil {
|
||||
rw.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if begin.ModifyUser(auth.ID, &patch) != nil {
|
||||
http.Error(rw, "Failed to modify user info", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, req, "/", http.StatusFound)
|
||||
}))
|
||||
r.GET("/userinfo", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
|
||||
token, err := oauthSrv.ValidationBearerToken(req)
|
||||
if err != nil {
|
||||
http.Error(rw, "403 Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
fmt.Printf("Using token for user: %s by app: %s with scope: '%s'\n", token.GetUserID(), token.GetClientID(), token.GetScope())
|
||||
_ = json.NewEncoder(rw).Encode(map[string]any{
|
||||
"sub": token.GetUserID(),
|
||||
"aud": token.GetClientID(),
|
||||
"name": "Melon",
|
||||
"preferred_username": "melon",
|
||||
"profile": "https://" + domain + "/user/melon",
|
||||
"picture": "https://" + domain + "/picture/melon.svg",
|
||||
"website": "https://mrmelon54.com",
|
||||
"email": "melon@mrmelon54.com",
|
||||
"email_verified": true,
|
||||
"gender": "male",
|
||||
"birthdate": time.Now().Format(time.DateOnly),
|
||||
"zoneinfo": "Europe/London",
|
||||
"locale": "en-GB",
|
||||
"updated_at": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
return &http.Server{
|
||||
Addr: listen,
|
||||
Handler: r,
|
||||
ReadTimeout: time.Minute,
|
||||
ReadHeaderTimeout: time.Minute,
|
||||
WriteTimeout: time.Minute,
|
||||
IdleTimeout: time.Minute,
|
||||
MaxHeaderBytes: 2500,
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user