Improve underlying types and start writing edit page

This commit is contained in:
Melon 2023-09-07 11:45:16 +01:00
parent 703f3d17cd
commit ece74ea36a
Signed by: melon
GPG Key ID: 6C9D970C50D26A25
11 changed files with 181 additions and 203 deletions

View File

@ -114,7 +114,7 @@ func checkDbHasUser(db *database.DB) error {
} }
if err := tx.HasUser(); err != nil { if err := tx.HasUser(); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
err := tx.InsertUser("admin", "admin", "admin@localhost") err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost")
if err != nil { if err != nil {
return fmt.Errorf("failed to add user: %w", err) return fmt.Errorf("failed to add user: %w", err)
} }

View File

@ -25,7 +25,7 @@ func marshalValueOrNull(null bool, data any) ([]byte, error) {
type NullStringScanner struct{ sql.NullString } type NullStringScanner struct{ sql.NullString }
func (s *NullStringScanner) Null() bool { return !s.Valid } func (s *NullStringScanner) Null() bool { return !s.Valid }
func (s *NullStringScanner) Scan(src any) error { return s.Scan(src) } func (s *NullStringScanner) Scan(src any) error { return s.NullString.Scan(src) }
func (s NullStringScanner) MarshalJSON() ([]byte, error) { func (s NullStringScanner) MarshalJSON() ([]byte, error) {
return marshalValueOrNull(s.Null(), s.String) return marshalValueOrNull(s.Null(), s.String)
} }
@ -75,6 +75,14 @@ func (l *LocationScanner) Scan(src any) error {
return nil return nil
} }
func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) } func (l LocationScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Location.String()) }
func (l *LocationScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
}
type LocaleScanner struct{ language.Tag } type LocaleScanner struct{ language.Tag }
@ -90,8 +98,14 @@ func (l *LocaleScanner) Scan(src any) error {
l.Tag = lang l.Tag = lang
return nil return nil
} }
func (l LocaleScanner) MarshalJSON() ([]byte, error) { func (l LocaleScanner) MarshalJSON() ([]byte, error) { return json.Marshal(l.Tag.String()) }
return json.Marshal(l.Tag.String()) func (l *LocaleScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return l.Scan(a)
} }
type PronounScanner struct{ pronouns.Pronoun } type PronounScanner struct{ pronouns.Pronoun }
@ -109,3 +123,11 @@ func (p *PronounScanner) Scan(src any) error {
return nil return nil
} }
func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) } func (p PronounScanner) MarshalJSON() ([]byte, error) { return json.Marshal(p.Pronoun.String()) }
func (p *PronounScanner) UnmarshalJSON(bytes []byte) error {
var a string
err := json.Unmarshal(bytes, &a)
if err != nil {
return err
}
return p.Scan(a)
}

View File

@ -1,7 +1,7 @@
package database package database
import ( import (
"encoding/json" "database/sql"
"github.com/MrMelon54/pronouns" "github.com/MrMelon54/pronouns"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/text/language" "golang.org/x/text/language"
@ -27,79 +27,51 @@ type User struct {
} }
type UserPatch struct { type UserPatch struct {
Name NullStringScanner `json:"name"` Name string
Picture NullStringScanner `json:"picture"` Picture string
Website NullStringScanner `json:"website"` Website string
Pronouns PronounScanner `json:"pronouns"` Pronouns pronouns.Pronoun
Birthdate NullDateScanner `json:"birthdate"` Birthdate sql.NullTime
ZoneInfo *time.Location `json:"zoneinfo"` ZoneInfo *time.Location
Locale *language.Tag `json:"locale"` Locale language.Tag
} }
func (u *UserPatch) UnmarshalJSON(bytes []byte) error { func (u *UserPatch) ParseFromForm(v url.Values) (err error) {
var m struct { u.Name = v.Get("name")
Name string `json:"name"` u.Picture = v.Get("picture")
Picture string `json:"picture"` u.Website = v.Get("website")
Website string `json:"website"` if v.Has("reset_pronouns") {
Pronouns string `json:"pronouns"` u.Pronouns = pronouns.TheyThem
Birthdate string `json:"birthdate"` } else {
ZoneInfo string `json:"zoneinfo"` u.Pronouns, err = pronouns.FindPronoun(v.Get("pronouns"))
Locale string `json:"locale"`
}
err := json.Unmarshal(bytes, &m)
if err != nil {
return err
}
u.Name = m.Name
// only parse the picture address if included
if m.Picture != "" {
u.Picture, err = url.Parse(m.Picture)
if err != nil { if err != nil {
return err return err
} }
} }
if v.Has("reset_birthdate") {
// only parse the website address if included u.Birthdate = sql.NullTime{}
if m.Website != "" { } else {
u.Website, err = url.Parse(m.Website) u.Birthdate = sql.NullTime{Valid: true}
u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate"))
if err != nil { if err != nil {
return err return err
} }
} }
if v.Has("reset_zoneinfo") {
// only parse the pronouns if included u.ZoneInfo = time.UTC
if m.Pronouns != "" { } else {
u.Pronouns, err = pronouns.FindPronoun(m.Pronouns) u.ZoneInfo, err = time.LoadLocation(v.Get("zoneinfo"))
if err != nil { if err != nil {
return err return err
} }
} }
if v.Has("reset_locale") {
// only parse the birthdate if included u.Locale = language.AmericanEnglish
if m.Birthdate != "" { } else {
u.Birthdate, err = time.Parse(time.DateOnly, m.Birthdate) u.Locale, err = language.Parse(v.Get("locale"))
if err != nil { if err != nil {
return err return err
} }
} }
// only parse the zoneinfo if included
if m.ZoneInfo != "" {
u.ZoneInfo, err = time.LoadLocation(m.ZoneInfo)
if err != nil {
return err
}
}
if m.Locale != "" {
locale, err := language.Parse(m.Locale)
if err != nil {
return err
}
u.Locale = &locale
}
return nil return nil
} }
var _ json.Unmarshaler = &UserPatch{}

View File

@ -1,77 +0,0 @@
package database
import (
"encoding/json"
"github.com/MrMelon54/pronouns"
"github.com/stretchr/testify/assert"
"maps"
"testing"
"time"
)
func TestUserPatch_UnmarshalJSON(t *testing.T) {
const a = `{
"name": "Test",
"picture": "https://example.com/logo.png",
"website": "https://example.com",
"gender": "robot",
"pronouns": "they/them",
"birthdate": "3070-01-01",
"zoneinfo": "Europe/London",
"locale": "en-GB"
}`
var p UserPatch
assert.NoError(t, json.Unmarshal([]byte(a), &p))
assert.Equal(t, "Test", p.Name)
assert.Equal(t, "https://example.com/logo.png", p.Picture.String())
assert.Equal(t, "https://example.com", p.Website.String())
assert.Equal(t, pronouns.TheyThem, p.Pronouns)
assert.Equal(t, time.Date(3070, time.January, 1, 0, 0, 0, 0, time.UTC), p.Birthdate)
location, err := time.LoadLocation("Europe/London")
assert.NoError(t, err)
assert.Equal(t, location, p.ZoneInfo)
assert.Equal(t, "en-GB", p.Locale.String())
}
func TestUserPatch_UnmarshalJSON2(t *testing.T) {
var userModifyChecks = map[string]struct{ valid, invalid []string }{
"picture": {valid: []string{"https://example.com/icon.png"}, invalid: []string{"%/icon.png"}},
"website": {valid: []string{"https://example.com"}, invalid: []string{"%/example.com"}},
"pronouns": {valid: []string{"he/him", "she/her"}, invalid: []string{"a/a"}},
"birthdate": {valid: []string{"2023-08-07", "2023-01-01"}, invalid: []string{"2023-00-00", "hello"}},
"zoneinfo": {
valid: []string{"Europe/London", "Europe/Berlin", "America/Los_Angeles", "America/Edmonton", "America/Montreal"},
invalid: []string{"Europe/York", "Canada/Edmonton", "hello"},
},
"locale": {valid: []string{"en-GB", "en-US", "zh-CN"}, invalid: []string{"en-YY"}},
}
m := map[string]string{
"name": "Test",
"picture": "https://example.com/logo.png",
"website": "https://example.com",
"gender": "robot",
"pronouns": "they/them",
"birthdate": "3070-01-01",
"zoneinfo": "Europe/London",
"locale": "en-GB",
}
for k, v := range userModifyChecks {
t.Run(k, func(t *testing.T) {
m2 := maps.Clone(m)
for _, i := range v.valid {
m2[k] = i
marshal, err := json.Marshal(m2)
assert.NoError(t, err)
var m3 UserPatch
assert.NoError(t, json.Unmarshal(marshal, &m3))
}
for _, i := range v.invalid {
m2[k] = i
marshal, err := json.Marshal(m2)
assert.NoError(t, err)
var m3 UserPatch
assert.Error(t, json.Unmarshal(marshal, &m3))
}
})
}
}

View File

@ -4,13 +4,13 @@ CREATE TABLE IF NOT EXISTS users
name TEXT NOT NULL, name TEXT NOT NULL,
username TEXT UNIQUE NOT NULL, username TEXT UNIQUE NOT NULL,
password TEXT NOT NULL, password TEXT NOT NULL,
picture TEXT, picture TEXT DEFAULT "" NOT NULL,
website TEXT, website TEXT DEFAULT "" NOT NULL,
email TEXT NOT NULL, email TEXT NOT NULL,
email_verified INTEGER DEFAULT 0 NOT NULL, email_verified INTEGER DEFAULT 0 NOT NULL,
pronouns TEXT DEFAULT "they/them" NOT NULL, pronouns TEXT DEFAULT "they/them" NOT NULL,
birthdate DATE, birthdate DATE,
zoneinfo TEXT DEFAULT "" NOT NULL, zoneinfo TEXT DEFAULT "UTC" NOT NULL,
locale TEXT DEFAULT "en-US" NOT NULL, locale TEXT DEFAULT "en-US" NOT NULL,
updated_at DATETIME, updated_at DATETIME,
active INTEGER DEFAULT 1 active INTEGER DEFAULT 1

View File

@ -32,12 +32,12 @@ func (t *Tx) HasUser() error {
return nil return nil
} }
func (t *Tx) InsertUser(un, pw, email string) error { func (t *Tx) InsertUser(name, un, pw, email string) error {
pwHash, err := password.HashPassword(pw) pwHash, err := password.HashPassword(pw)
if err != nil { if err != nil {
return err return err
} }
_, err = t.tx.Exec(`INSERT INTO users (subject, username, password, email) VALUES (?, ?, ?, ?)`, uuid.NewString(), un, pwHash, email) _, err = t.tx.Exec(`INSERT INTO users (subject, name, username, password, email) VALUES (?, ?, ?, ?, ?)`, uuid.NewString(), name, un, pwHash, email)
return err return err
} }
@ -113,20 +113,20 @@ func (t *Tx) ChangeUserPassword(sub uuid.UUID, pwOld, pwNew string) error {
func (t *Tx) ModifyUser(sub uuid.UUID, v *UserPatch) error { func (t *Tx) ModifyUser(sub uuid.UUID, v *UserPatch) error {
exec, err := t.tx.Exec( exec, err := t.tx.Exec(
`UPDATE users `UPDATE users
SET name = ifnull(?, name), SET name = ?,
picture = ifnull(?, picture), picture = ?,
website = ifnull(?, website), website = ?,
pronouns = ifnull(?, pronouns), pronouns = ?,
birthdate = ifnull(?, birthdate), birthdate = ?,
zoneinfo = ifnull(?, zoneinfo), zoneinfo = ?,
locale = ifnull(?, locale), locale = ?,
updated_at = ? updated_at = ?
WHERE subject = ?`, WHERE subject = ?`,
v.Name, v.Name,
stringify(v.Picture), v.Picture,
stringify(v.Website), v.Website,
v.Pronouns.String(), v.Pronouns.String(),
sql.NullTime{Time: v.Birthdate, Valid: !v.Birthdate.IsZero()}, v.Birthdate,
v.ZoneInfo.String(), v.ZoneInfo.String(),
v.Locale.String(), v.Locale.String(),
time.Now().Format(time.DateTime), time.Now().Format(time.DateTime),
@ -164,14 +164,3 @@ func (c *clientInfoDbOutput) GetDomain() string { return c.domain }
func (c *clientInfoDbOutput) IsPublic() bool { return false } func (c *clientInfoDbOutput) IsPublic() bool { return false }
func (c *clientInfoDbOutput) GetUserID() string { return "" } func (c *clientInfoDbOutput) GetUserID() string { return "" }
func (c *clientInfoDbOutput) IsSSO() bool { return c.sso } func (c *clientInfoDbOutput) IsSSO() bool { return c.sso }
func stringify(stringer fmt.Stringer) sql.NullString {
if stringer == nil {
return sql.NullString{}
}
return emptyToNull(stringer.String())
}
func emptyToNull(a string) sql.NullString {
return sql.NullString{String: a, Valid: a != ""}
}

View File

@ -45,11 +45,8 @@ func TestTx_ModifyUser(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, tx.ModifyUser(u, &UserPatch{ assert.NoError(t, tx.ModifyUser(u, &UserPatch{
Name: "example", Name: "example",
Picture: nil, Pronouns: pronouns.TheyThem,
Website: nil, ZoneInfo: time.UTC,
Pronouns: pronouns.Pronoun{}, Locale: language.AmericanEnglish,
Birthdate: time.Time{},
ZoneInfo: nil,
Locale: &language.Tag{},
})) }))
} }

66
pages/edit.go.html Normal file
View File

@ -0,0 +1,66 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>1f349 ID</title>
</head>
<body>
<header>
<h1>1f349 ID</h1>
</header>
<main>
<div>Logged in as: {{.User.Name}} ({{.User.Sub}})</div>
<div>
<form method="POST" action="/edit">
<input type="hidden" name="nonce" value="{{.Nonce}}">
<div>
<label for="field_name">Name</label>
<input type="text" name="name" id="field_name" value="{{.User.Name}}">
</div>
<div>
<label for="field_picture">Picture</label>
<input type="text" name="picture" id="field_picture" value="{{.User.Picture}}">
</div>
<div>
<label for="field_website">Website</label>
<input type="text" name="website" id="field_website" value="{{.User.Picture}}">
</div>
<div>
<label for="field_pronouns">Pronouns</label>
<select name="pronouns" id="field_pronouns">
<option value="they/them" selected>They/Them</option>
<option value="he/him">He/Him</option>
<option value="she/her">She/Her</option>
<option value="it/its">It/Its</option>
<option value="one/one's">One/One's</option>
</select>
<label>Reset? <input type="checkbox" name="reset_pronouns"></label>
</div>
<div>
<label for="field_birthdate">Birthdate</label>
<input type="text" name="birthdate" id="field_birthdate" value="{{.User.Birthdate}}">
<label>Reset? <input type="checkbox" name="reset_birthdate"></label>
</div>
<div>
<label for="field_zoneinfo">Time Zone</label>
<input type="text" name="zoneinfo" id="field_zoneinfo" value="{{.User.ZoneInfo}}" list="list_zoneinfo">
<datalist id="list_zoneinfo">
<!-- Fill in -->
<option value="Europe/London"></option>
</datalist>
<label>Reset? <input type="checkbox" name="reset_zoneinfo"></label>
</div>
<div>
<label for="field_locale">Language</label>
<input type="text" name="locale" id="field_locale" value="{{.User.Locale}}" list="list_locale">
<datalist id="list_locale">
<!-- Fill in -->
<option value="en-US"></option>
</datalist>
<label>Reset? <input type="checkbox" name="reset_zoneinfo"></label>
</div>
<button type="submit">Edit</button>
</form>
</div>
</main>
</body>
</html>

View File

@ -8,9 +8,13 @@
<h1>1f349 ID</h1> <h1>1f349 ID</h1>
</header> </header>
<main> <main>
<div>Logged in as: {{.User.Name}} ({{.User.ID}})</div> <div>Logged in as: {{.User.Name}} ({{.User.Sub}})</div>
<div> <div>
<form method="POST" action="/logout"><input type="hidden" name="nonce" value="{{.Nonce}}"> <button onclick="location.href='/edit'">Edit Profile</button>
</div>
<div>
<form method="POST" action="/logout">
<input type="hidden" name="nonce" value="{{.Nonce}}">
<button type="submit">Log Out</button> <button type="submit">Log Out</button>
</form> </form>
</div> </div>

View File

@ -6,7 +6,10 @@ import (
"net/http" "net/http"
) )
func (h *HttpServer) dbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool { // DbTx wraps a database transaction with http error messages and a simple action
// function. If the action function returns an error the transaction will be
// rolled back. If there is no error then the transaction is committed.
func (h *HttpServer) DbTx(rw http.ResponseWriter, action func(tx *database.Tx) error) bool {
tx, err := h.db.Begin() tx, err := h.db.Begin()
if err != nil { if err != nil {
http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError) http.Error(rw, "Failed to begin database transaction", http.StatusInternalServerError)

View File

@ -93,9 +93,6 @@ func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clien
return "openid", nil return "openid", nil
}) })
newUserUuid := uuid.New()
fmt.Println("New User Uuid:", newUserUuid.String())
r.GET("/.well-known/openid-configuration", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { r.GET("/.well-known/openid-configuration", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) {
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
_, _ = rw.Write(openIdBytes) _, _ = rw.Write(openIdBytes)
@ -115,16 +112,18 @@ func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clien
return return
} }
hs.dbTx(rw, func(tx *database.Tx) error { hs.DbTx(rw, func(tx *database.Tx) error {
userWithName, err := tx.GetUserDisplayName(auth.ID) userWithName, err := tx.GetUserDisplayName(auth.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user display name: %w", err) return fmt.Errorf("failed to get user display name: %w", err)
} }
_ = pages.RenderPageTemplate(rw, "index", map[string]any{ if err := pages.RenderPageTemplate(rw, "index", map[string]any{
"Auth": auth, "Auth": auth,
"User": userWithName, "User": userWithName,
"Nonce": lNonce, "Nonce": lNonce,
}) }); err != nil {
log.Printf("Failed to render page: edit: %s\n", err)
}
return nil return nil
}) })
})) }))
@ -152,13 +151,15 @@ func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clien
} }
rw.Header().Set("Content-Type", "text/html") rw.Header().Set("Content-Type", "text/html")
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
_ = pages.RenderPageTemplate(rw, "login", nil) if err := pages.RenderPageTemplate(rw, "login", nil); err != nil {
log.Printf("Failed to render page: edit: %s\n", err)
}
})) }))
r.POST("/login", hs.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { r.POST("/login", hs.OptionalAuthentication(func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
un := req.FormValue("username") un := req.FormValue("username")
pw := req.FormValue("password") pw := req.FormValue("password")
var userSub uuid.UUID var userSub uuid.UUID
if hs.dbTx(rw, func(tx *database.Tx) error { if hs.DbTx(rw, func(tx *database.Tx) error {
loginUser, err := tx.CheckLogin(un, pw) loginUser, err := tx.CheckLogin(un, pw)
if err != nil { if err != nil {
if errors2.Is(err, sql.ErrNoRows) || errors2.Is(err, bcrypt.ErrMismatchedHashAndPassword) { if errors2.Is(err, sql.ErrNoRows) || errors2.Is(err, bcrypt.ErrMismatchedHashAndPassword) {
@ -207,13 +208,16 @@ func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clien
} }
}) })
r.GET("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { r.GET("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
begin, err := db.Begin() var user *database.User
if hs.DbTx(rw, func(tx *database.Tx) error {
var err error
user, err = tx.GetUser(auth.ID)
if err != nil { if err != nil {
return return fmt.Errorf("failed to read user data: %w", err)
} }
user, err := begin.GetUser(auth.ID) return nil
if err != nil { }) {
http.Error(rw, "Failed to read user data", http.StatusInternalServerError)
return return
} }
@ -223,33 +227,31 @@ func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clien
http.Error(rw, "Failed to save session", http.StatusInternalServerError) http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return return
} }
_ = pages.RenderPageTemplate(rw, "edit", map[string]any{ if err := pages.RenderPageTemplate(rw, "edit", map[string]any{
"User": user, "User": user,
"Nonce": lNonce, "Nonce": lNonce,
}) }); err != nil {
log.Printf("Failed to render page: edit: %s\n", err)
}
})) }))
r.POST("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { r.POST("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) {
if req.ParseForm() != nil { if req.ParseForm() != nil {
rw.WriteHeader(http.StatusBadRequest) rw.WriteHeader(http.StatusBadRequest)
return return
} }
// TODO: parse user patch from form
req.Form.Get("")
var patch database.UserPatch var patch database.UserPatch
decoder := json.NewDecoder(req.Body) err := patch.ParseFromForm(req.Form)
decoder.DisallowUnknownFields()
err := decoder.Decode(&patch)
if err != nil { if err != nil {
rw.WriteHeader(http.StatusBadRequest) rw.WriteHeader(http.StatusBadRequest)
return return
} }
begin, err := db.Begin() if hs.DbTx(rw, func(tx *database.Tx) error {
if err != nil { if err := tx.ModifyUser(auth.ID, &patch); err != nil {
rw.WriteHeader(http.StatusBadRequest) return fmt.Errorf("failed to modify user info: %w", err)
return
} }
if begin.ModifyUser(auth.ID, &patch) != nil { return nil
http.Error(rw, "Failed to modify user info", http.StatusInternalServerError) }) {
return return
} }
http.Redirect(rw, req, "/", http.StatusFound) http.Redirect(rw, req, "/", http.StatusFound)