diff --git a/cmd/tulip/serve.go b/cmd/tulip/serve.go index 84045a7..213ea45 100644 --- a/cmd/tulip/serve.go +++ b/cmd/tulip/serve.go @@ -91,7 +91,8 @@ func normalLoad(startUp startUpConfig, wd string) { exit_reload.ExitReload("Tulip", func() {}, func() { // stop http server - srv.Close() + _ = srv.Close() + _ = db.Close() }) } @@ -112,6 +113,7 @@ func checkDbHasUser(db *database.DB) error { if err != nil { return fmt.Errorf("failed to start transaction: %w", err) } + defer tx.Rollback() if err := tx.HasUser(); err != nil { if errors.Is(err, sql.ErrNoRows) { err := tx.InsertUser("Admin", "admin", "admin", "admin@localhost") diff --git a/database/db-types.go b/database/db-types.go index 4ed05b3..e80e73a 100644 --- a/database/db-types.go +++ b/database/db-types.go @@ -2,6 +2,7 @@ package database import ( "database/sql" + "fmt" "github.com/MrMelon54/pronouns" "github.com/google/uuid" "golang.org/x/text/language" @@ -36,7 +37,8 @@ type UserPatch struct { Locale language.Tag } -func (u *UserPatch) ParseFromForm(v url.Values) (err error) { +func (u *UserPatch) ParseFromForm(v url.Values) (safeErrs []error) { + var err error u.Name = v.Get("name") u.Picture = v.Get("picture") u.Website = v.Get("website") @@ -45,16 +47,16 @@ func (u *UserPatch) ParseFromForm(v url.Values) (err error) { } else { u.Pronouns, err = pronouns.FindPronoun(v.Get("pronouns")) if err != nil { - return err + safeErrs = append(safeErrs, fmt.Errorf("invalid pronoun selected")) } } - if v.Has("reset_birthdate") { + if v.Has("reset_birthdate") || v.Get("birthdate") == "" { u.Birthdate = sql.NullTime{} } else { u.Birthdate = sql.NullTime{Valid: true} u.Birthdate.Time, err = time.Parse(time.DateOnly, v.Get("birthdate")) if err != nil { - return err + safeErrs = append(safeErrs, fmt.Errorf("invalid time selected")) } } if v.Has("reset_zoneinfo") { @@ -62,7 +64,7 @@ func (u *UserPatch) ParseFromForm(v url.Values) (err error) { } else { u.ZoneInfo, err = time.LoadLocation(v.Get("zoneinfo")) if err != nil { - return err + safeErrs = append(safeErrs, fmt.Errorf("invalid timezone selected")) } } if v.Has("reset_locale") { @@ -70,8 +72,8 @@ func (u *UserPatch) ParseFromForm(v url.Values) (err error) { } else { u.Locale, err = language.Parse(v.Get("locale")) if err != nil { - return err + safeErrs = append(safeErrs, fmt.Errorf("invalid language selected")) } } - return nil + return } diff --git a/database/db.go b/database/db.go index daf2c62..0e38862 100644 --- a/database/db.go +++ b/database/db.go @@ -35,3 +35,7 @@ func (d *DB) BeginCtx(ctx context.Context) (*Tx, error) { } return &Tx{begin}, err } + +func (d *DB) Close() error { + return d.db.Close() +} diff --git a/pages/edit.go.html b/pages/edit.go.html index cc01ec4..31bace3 100644 --- a/pages/edit.go.html +++ b/pages/edit.go.html @@ -22,16 +22,16 @@
- +
@@ -58,7 +58,7 @@ {{end}} - + diff --git a/server/server.go b/server/server.go index c101b8e..19d01b8 100644 --- a/server/server.go +++ b/server/server.go @@ -234,6 +234,7 @@ func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clien if err := pages.RenderPageTemplate(rw, "edit", map[string]any{ "User": user, "Nonce": lNonce, + "FieldPronoun": user.Pronouns.String(), "ListZoneInfo": lists.ListZoneInfo(), "ListLocale": lists.ListLocale(), }); err != nil { @@ -243,13 +244,22 @@ func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clien r.POST("/edit", hs.RequireAuthentication("403 Forbidden", http.StatusForbidden, func(rw http.ResponseWriter, req *http.Request, params httprouter.Params, auth UserAuth) { if req.ParseForm() != nil { rw.WriteHeader(http.StatusBadRequest) + _, _ = rw.Write([]byte("400 Bad Request\n")) return } var patch database.UserPatch - err := patch.ParseFromForm(req.Form) - if err != nil { + errs := patch.ParseFromForm(req.Form) + if len(errs) > 0 { rw.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintln(rw, "\n\n") + _, _ = fmt.Fprintln(rw, "

400 Bad Request: Failed to parse form data, press the back button in your browser, check your inputs and try again.

") + _, _ = fmt.Fprintln(rw, "") + _, _ = fmt.Fprintln(rw, "\n") return } if hs.DbTx(rw, func(tx *database.Tx) error { @@ -260,7 +270,7 @@ func NewHttpServer(listen, domain string, db *database.DB, privKey []byte, clien }) { return } - http.Redirect(rw, req, "/", http.StatusFound) + http.Redirect(rw, req, "/edit", http.StatusFound) })) r.GET("/userinfo", func(rw http.ResponseWriter, req *http.Request, params httprouter.Params) { token, err := oauthSrv.ValidationBearerToken(req)