mirror of
https://github.com/1f349/dendrite.git
synced 2025-01-11 09:56:25 +00:00
Use returned ID from INSERT in create filter (#297)
This commit is contained in:
parent
f6bda82366
commit
c0271c2462
@ -15,8 +15,8 @@
|
|||||||
package accounts
|
package accounts
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
const filterSchema = `
|
const filterSchema = `
|
||||||
@ -41,13 +41,9 @@ const selectFilterSQL = "" +
|
|||||||
const insertFilterSQL = "" +
|
const insertFilterSQL = "" +
|
||||||
"INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id"
|
"INSERT INTO account_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id"
|
||||||
|
|
||||||
const findMaxIDSQL = "" +
|
|
||||||
"SELECT MAX(id) FROM account_filter WHERE localpart = $1"
|
|
||||||
|
|
||||||
type filterStatements struct {
|
type filterStatements struct {
|
||||||
selectFilterStmt *sql.Stmt
|
selectFilterStmt *sql.Stmt
|
||||||
insertFilterStmt *sql.Stmt
|
insertFilterStmt *sql.Stmt
|
||||||
findMaxIDStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
||||||
@ -61,10 +57,6 @@ func (s *filterStatements) prepare(db *sql.DB) (err error) {
|
|||||||
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
|
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if s.findMaxIDStmt, err = db.Prepare(findMaxIDSQL); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,14 +69,7 @@ func (s *filterStatements) selectFilter(
|
|||||||
|
|
||||||
func (s *filterStatements) insertFilter(
|
func (s *filterStatements) insertFilter(
|
||||||
ctx context.Context, filter string, localpart string,
|
ctx context.Context, filter string, localpart string,
|
||||||
) (err error) {
|
) (pos string, err error) {
|
||||||
_, err = s.insertFilterStmt.ExecContext(ctx, filter, localpart)
|
err = s.insertFilterStmt.QueryRowContext(ctx, filter, localpart).Scan(&pos)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *filterStatements) findMaxID(
|
|
||||||
ctx context.Context, localpart string,
|
|
||||||
) (id string, err error) {
|
|
||||||
err = s.findMaxIDStmt.QueryRowContext(ctx, localpart).Scan(&id)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,7 @@ import (
|
|||||||
|
|
||||||
// Database represents an account database
|
// Database represents an account database
|
||||||
type Database struct {
|
type Database struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
common.PartitionOffsetStatements
|
common.PartitionOffsetStatements
|
||||||
accounts accountsStatements
|
accounts accountsStatements
|
||||||
profiles profilesStatements
|
profiles profilesStatements
|
||||||
@ -333,11 +333,7 @@ func (d *Database) GetFilter(
|
|||||||
func (d *Database) PutFilter(
|
func (d *Database) PutFilter(
|
||||||
ctx context.Context, localpart, filter string,
|
ctx context.Context, localpart, filter string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
err := d.filter.insertFilter(ctx, filter, localpart)
|
return d.filter.insertFilter(ctx, filter, localpart)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return d.filter.findMaxID(ctx, localpart)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckAccountAvailability checks if the username/localpart is already present in the database.
|
// CheckAccountAvailability checks if the username/localpart is already present in the database.
|
||||||
|
Loading…
Reference in New Issue
Block a user