dendrite/setup/mscs/msc2836/storage.go
Till 699f5ca8c1
More rows.Close() and rows.Err() (#3262)
Looks like we missed some `rows.Close()`

Even though `rows.Err()` is mostly not necessary, we should be more
consistent in the DB layer.

[skip ci]
2023-11-09 08:42:33 +01:00

367 lines
12 KiB
Go

package msc2836
import (
"bytes"
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)
// eventInfo is the per-event result type returned by ChildrenForParent and
// ParentForChild.
type eventInfo struct {
	// EventID is the event's ID (msc2836_edges.child_event_id for children,
	// msc2836_edges.parent_event_id for parents).
	EventID string
	// OriginServerTS is read from msc2836_nodes.origin_server_ts by
	// ChildrenForParent. ParentForChild does not scan it, so it is zero there.
	OriginServerTS spec.Timestamp
	// RoomID is msc2836_nodes.room_id for children (ChildrenForParent) and
	// msc2836_edges.parent_room_id for parents (ParentForChild).
	RoomID string
}
// Database persists m.relationship edges between events together with
// per-event node metadata: origin_server_ts, room ID, the children
// count/hash taken from the event's unsigned data, and an "explored" flag.
type Database interface {
	// StoreRelation stores the parent->child and child->parent relationship for later querying.
	// It also stores the child event's metadata, e.g. its timestamp and unsigned children count/hash.
	// Events without a usable m.relationship are ignored.
	StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error
	// ChildrenForParent returns the events which have the given `eventID` as an m.relationship with the
	// provided `relType`. The returned slice is sorted by origin_server_ts, newest first when
	// `recentFirst` is true and oldest first otherwise.
	ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error)
	// ParentForChild returns the parent event for the given child `eventID` and `relType`. The eventInfo
	// is nil, with a nil error, if there is no parent for this child event. Implementations may populate
	// only EventID and RoomID, so OriginServerTS can be zero even when the parent event is known.
	ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error)
	// UpdateChildMetadata persists the children count and children hash from this event's unsigned
	// data when they supersede what is stored, i.e. the count is greater than the stored count, or
	// equal to it with a different hash. When the metadata is updated, the event is marked unexplored.
	// Events that carry no children count are ignored.
	UpdateChildMetadata(ctx context.Context, ev *types.HeaderedEvent) error
	// ChildMetadata returns the children count and children hash for the event ID in question.
	// It also returns the `explored` flag, which is set to true by MarkChildrenExplored and reset
	// to false when UpdateChildMetadata stores newer metadata.
	// If the event ID does not exist, it returns zero values and a nil error.
	ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error)
	// MarkChildrenExplored sets the 'explored' flag on this event to `true`.
	MarkChildrenExplored(ctx context.Context, eventID string) error
}
// DB is the SQL-backed implementation of Database. db and writer come from
// sqlutil.Connections.Connection; writer is used to run write transactions.
// Every statement below is prepared once by newPostgresDatabase or
// newSQLiteDatabase.
type DB struct {
	db     *sql.DB
	writer sqlutil.Writer

	// Inserts into msc2836_edges and msc2836_nodes respectively.
	insertEdgeStmt, insertNodeStmt *sql.Stmt
	// Children of a parent, ordered by origin_server_ts ASC and DESC respectively.
	selectChildrenForParentOldestFirstStmt, selectChildrenForParentRecentFirstStmt *sql.Stmt
	// Parent event ID and parent room ID for a child event.
	selectParentForChildStmt *sql.Stmt
	// Children metadata: full update, read, and explored-flag update.
	updateChildMetadataStmt, selectChildMetadataStmt, updateChildMetadataExploredStmt *sql.Stmt
}
// NewDatabase loads the database for msc2836. A PostgreSQL connection string
// selects the PostgreSQL implementation; anything else falls back to SQLite.
func NewDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
	switch {
	case dbOpts.ConnectionString.IsPostgres():
		return newPostgresDatabase(conMan, dbOpts)
	default:
		return newSQLiteDatabase(conMan, dbOpts)
	}
}
// newPostgresDatabase builds the PostgreSQL flavour of the msc2836 store.
// It takes a connection and writer from conMan, creates the msc2836_edges and
// msc2836_nodes tables if they do not exist yet, and then prepares every
// statement used by the DB methods, in a fixed order. Construction stops at
// the first error, which is returned with a nil Database.
func newPostgresDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
	var d DB
	var err error
	if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil {
		return nil, err
	}

	const schema = `
CREATE TABLE IF NOT EXISTS msc2836_edges (
parent_event_id TEXT NOT NULL,
child_event_id TEXT NOT NULL,
rel_type TEXT NOT NULL,
parent_room_id TEXT NOT NULL,
parent_servers TEXT NOT NULL,
CONSTRAINT msc2836_edges_uniq UNIQUE (parent_event_id, child_event_id, rel_type)
);
CREATE TABLE IF NOT EXISTS msc2836_nodes (
event_id TEXT PRIMARY KEY NOT NULL,
origin_server_ts BIGINT NOT NULL,
room_id TEXT NOT NULL,
unsigned_children_count BIGINT NOT NULL,
unsigned_children_hash TEXT NOT NULL,
explored SMALLINT NOT NULL
);
`
	if _, err = d.db.Exec(schema); err != nil {
		return nil, err
	}

	// Shared body of the two children queries; the sort direction is
	// appended per statement below.
	const selectChildrenQuery = `
SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
WHERE parent_event_id = $1 AND rel_type = $2
ORDER BY origin_server_ts
`

	// Each entry prepares query and stores the statement into dst.
	statements := []struct {
		dst   **sql.Stmt
		query string
	}{
		{&d.insertEdgeStmt, `
INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
`},
		{&d.insertNodeStmt, `
INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
VALUES($1, $2, $3, $4, $5, $6)
ON CONFLICT DO NOTHING
`},
		{&d.selectChildrenForParentOldestFirstStmt, selectChildrenQuery + "ASC"},
		{&d.selectChildrenForParentRecentFirstStmt, selectChildrenQuery + "DESC"},
		{&d.selectParentForChildStmt, `
SELECT parent_event_id, parent_room_id FROM msc2836_edges
WHERE child_event_id = $1 AND rel_type = $2
`},
		{&d.updateChildMetadataStmt, `
UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
`},
		{&d.selectChildMetadataStmt, `
SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
`},
		{&d.updateChildMetadataExploredStmt, `
UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
`},
	}
	for _, s := range statements {
		if *s.dst, err = d.db.Prepare(s.query); err != nil {
			return nil, err
		}
	}
	return &d, nil
}
// newSQLiteDatabase builds the SQLite flavour of the msc2836 store. It takes
// a connection and writer from conMan, creates the msc2836_edges and
// msc2836_nodes tables if they do not exist yet, and then prepares every
// statement used by the DB methods, in a fixed order. Construction stops at
// the first error, which is returned with a nil Database.
func newSQLiteDatabase(conMan *sqlutil.Connections, dbOpts *config.DatabaseOptions) (Database, error) {
	var d DB
	var err error
	if d.db, d.writer, err = conMan.Connection(dbOpts); err != nil {
		return nil, err
	}

	const schema = `
CREATE TABLE IF NOT EXISTS msc2836_edges (
parent_event_id TEXT NOT NULL,
child_event_id TEXT NOT NULL,
rel_type TEXT NOT NULL,
parent_room_id TEXT NOT NULL,
parent_servers TEXT NOT NULL,
UNIQUE (parent_event_id, child_event_id, rel_type)
);
CREATE TABLE IF NOT EXISTS msc2836_nodes (
event_id TEXT PRIMARY KEY NOT NULL,
origin_server_ts BIGINT NOT NULL,
room_id TEXT NOT NULL,
unsigned_children_count BIGINT NOT NULL,
unsigned_children_hash TEXT NOT NULL,
explored SMALLINT NOT NULL
);
`
	if _, err = d.db.Exec(schema); err != nil {
		return nil, err
	}

	// Shared body of the two children queries; the sort direction is
	// appended per statement below.
	const selectChildrenQuery = `
SELECT child_event_id, origin_server_ts, room_id FROM msc2836_edges
LEFT JOIN msc2836_nodes ON msc2836_edges.child_event_id = msc2836_nodes.event_id
WHERE parent_event_id = $1 AND rel_type = $2
ORDER BY origin_server_ts
`

	// Each entry prepares query and stores the statement into dst.
	statements := []struct {
		dst   **sql.Stmt
		query string
	}{
		{&d.insertEdgeStmt, `
INSERT INTO msc2836_edges(parent_event_id, child_event_id, rel_type, parent_room_id, parent_servers)
VALUES($1, $2, $3, $4, $5)
ON CONFLICT (parent_event_id, child_event_id, rel_type) DO NOTHING
`},
		{&d.insertNodeStmt, `
INSERT INTO msc2836_nodes(event_id, origin_server_ts, room_id, unsigned_children_count, unsigned_children_hash, explored)
VALUES($1, $2, $3, $4, $5, $6)
ON CONFLICT DO NOTHING
`},
		{&d.selectChildrenForParentOldestFirstStmt, selectChildrenQuery + "ASC"},
		{&d.selectChildrenForParentRecentFirstStmt, selectChildrenQuery + "DESC"},
		{&d.selectParentForChildStmt, `
SELECT parent_event_id, parent_room_id FROM msc2836_edges
WHERE child_event_id = $1 AND rel_type = $2
`},
		{&d.updateChildMetadataStmt, `
UPDATE msc2836_nodes SET unsigned_children_count=$1, unsigned_children_hash=$2, explored=$3 WHERE event_id=$4
`},
		{&d.selectChildMetadataStmt, `
SELECT unsigned_children_count, unsigned_children_hash, explored FROM msc2836_nodes WHERE event_id=$1
`},
		{&d.updateChildMetadataExploredStmt, `
UPDATE msc2836_nodes SET explored=$1 WHERE event_id=$2
`},
	}
	for _, s := range statements {
		if *s.dst, err = d.db.Prepare(s.query); err != nil {
			return nil, err
		}
	}
	return &d, nil
}
// StoreRelation records the parent->child edge described by ev's
// m.relationship content, plus a node row for ev itself (origin_server_ts,
// room ID, unsigned children count and base64 hash, explored=0). Events with no
// parent or child ID are ignored. Both inserts run inside one writer
// transaction; the first failing insert aborts it.
func (p *DB) StoreRelation(ctx context.Context, ev *types.HeaderedEvent) error {
	parent, child, relType := parentChildEventIDs(ev)
	if child == "" || parent == "" {
		return nil
	}
	relationRoomID, relationServers := roomIDAndServers(ev)
	serversJSON, err := json.Marshal(relationServers)
	if err != nil {
		return err
	}
	count, hash := extractChildMetadata(ev)
	encodedHash := base64.RawStdEncoding.EncodeToString(hash)

	return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
		edgeStmt := txn.Stmt(p.insertEdgeStmt)
		if _, execErr := edgeStmt.ExecContext(ctx, parent, child, relType, relationRoomID, string(serversJSON)); execErr != nil {
			return execErr
		}
		util.GetLogger(ctx).Infof("StoreRelation child=%s parent=%s rel_type=%s", child, parent, relType)
		nodeStmt := txn.Stmt(p.insertNodeStmt)
		_, execErr := nodeStmt.ExecContext(ctx, ev.EventID(), ev.OriginServerTS(), ev.RoomID().String(), count, encodedHash, 0)
		return execErr
	})
}
// UpdateChildMetadata stores the children count/hash from ev's unsigned data
// when it supersedes what is already recorded for ev: either the count is
// larger, or the count is equal but the hash differs. A successful update also
// resets the explored flag to 0 (unexplored). Events carrying no children
// count, and events with no stored node row, are left untouched.
//
// The read and the conditional write run inside a single writer transaction,
// the same way StoreRelation writes. This routes the write through p.writer
// instead of bypassing it, and keeps another writer from changing the row
// between the comparison and the UPDATE.
func (p *DB) UpdateChildMetadata(ctx context.Context, ev *types.HeaderedEvent) error {
	eventCount, eventHash := extractChildMetadata(ev)
	if eventCount == 0 {
		return nil // nothing to update with
	}
	eventID := ev.EventID()
	return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
		var (
			count       int
			b64hash     string
			exploredInt int
		)
		err := txn.Stmt(p.selectChildMetadataStmt).QueryRowContext(ctx, eventID).Scan(&count, &b64hash, &exploredInt)
		if err == sql.ErrNoRows {
			// No node row for this event: the UPDATE would match nothing anyway.
			return nil
		}
		if err != nil {
			return err
		}
		hash, err := base64.RawStdEncoding.DecodeString(b64hash)
		if err != nil {
			return err
		}
		if eventCount > count || (eventCount == count && !bytes.Equal(hash, eventHash)) {
			_, err = txn.Stmt(p.updateChildMetadataStmt).ExecContext(ctx, eventCount, base64.RawStdEncoding.EncodeToString(eventHash), 0, eventID)
			return err
		}
		return nil
	})
}
// ChildMetadata returns the stored children count, the decoded children hash
// and the explored flag (stored as an integer, true when > 0) for eventID.
// An unknown eventID yields zero values and a nil error. If the stored hash is
// not valid unpadded standard base64, the decode error is returned alongside
// whatever was decoded.
func (p *DB) ChildMetadata(ctx context.Context, eventID string) (count int, hash []byte, explored bool, err error) {
	var (
		encodedHash  string
		exploredFlag int
	)
	row := p.selectChildMetadataStmt.QueryRowContext(ctx, eventID)
	if scanErr := row.Scan(&count, &encodedHash, &exploredFlag); scanErr != nil {
		if scanErr == sql.ErrNoRows {
			return count, nil, false, nil
		}
		return count, nil, false, scanErr
	}
	decoded, decodeErr := base64.RawStdEncoding.DecodeString(encodedHash)
	return count, decoded, exploredFlag > 0, decodeErr
}
// MarkChildrenExplored sets the explored flag for eventID to 1 (true).
//
// The UPDATE runs through p.writer, like StoreRelation, rather than directly
// on the prepared statement, so that this write goes through the same write
// path as the other writes in this file.
func (p *DB) MarkChildrenExplored(ctx context.Context, eventID string) error {
	return p.writer.Do(p.db, nil, func(txn *sql.Tx) error {
		_, err := txn.Stmt(p.updateChildMetadataExploredStmt).ExecContext(ctx, 1, eventID)
		return err
	})
}
// ChildrenForParent lists the children of eventID linked with relType,
// ordered by origin_server_ts: descending when recentFirst is set, ascending
// otherwise. A scan failure aborts with a nil slice. Otherwise the collected
// children are returned together with any iteration error reported by rows.Err.
func (p *DB) ChildrenForParent(ctx context.Context, eventID, relType string, recentFirst bool) ([]eventInfo, error) {
	var stmt *sql.Stmt
	if recentFirst {
		stmt = p.selectChildrenForParentRecentFirstStmt
	} else {
		stmt = p.selectChildrenForParentOldestFirstStmt
	}
	rows, err := stmt.QueryContext(ctx, eventID, relType)
	if err != nil {
		return nil, err
	}
	defer rows.Close() // nolint: errcheck

	var children []eventInfo
	for rows.Next() {
		var info eventInfo
		if scanErr := rows.Scan(&info.EventID, &info.OriginServerTS, &info.RoomID); scanErr != nil {
			return nil, scanErr
		}
		children = append(children, info)
	}
	return children, rows.Err()
}
// ParentForChild looks up the parent of child eventID for relType. It returns
// (nil, nil) when no edge exists. Only EventID and RoomID are filled in on the
// result; OriginServerTS is left zero.
func (p *DB) ParentForChild(ctx context.Context, eventID, relType string) (*eventInfo, error) {
	parent := &eventInfo{}
	row := p.selectParentForChildStmt.QueryRowContext(ctx, eventID, relType)
	err := row.Scan(&parent.EventID, &parent.RoomID)
	switch {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}
	return parent, nil
}
// parentChildEventIDs reads the m.relationship object from ev's content. It
// returns the related event ID as parent, ev's own ID as child, and the
// relationship's rel_type. All three are empty when ev is nil, when the
// content does not decode, or when event_id or rel_type is missing or empty.
func parentChildEventIDs(ev *types.HeaderedEvent) (parent, child, relType string) {
	if ev == nil {
		return "", "", ""
	}
	var content struct {
		Relationship struct {
			RelType string `json:"rel_type"`
			EventID string `json:"event_id"`
		} `json:"m.relationship"`
	}
	if err := json.Unmarshal(ev.Content(), &content); err != nil {
		return "", "", ""
	}
	rel := content.Relationship
	if rel.EventID == "" || rel.RelType == "" {
		return "", "", ""
	}
	return rel.EventID, ev.EventID(), rel.RelType
}
// roomIDAndServers reads relationship_room_id and relationship_servers from
// ev's unsigned data. servers is never nil: it is an empty slice when ev is
// nil, when the unsigned data does not decode, or when relationship_servers
// is absent or JSON null. As a result, StoreRelation always marshals it as a
// JSON array ("[]") rather than "null".
func roomIDAndServers(ev *types.HeaderedEvent) (roomID string, servers []string) {
	if ev == nil {
		return "", []string{}
	}
	var unsigned struct {
		RoomID  string   `json:"relationship_room_id"`
		Servers []string `json:"relationship_servers"`
	}
	if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil {
		return "", []string{}
	}
	// A missing field leaves Servers nil and an explicit null sets it to nil.
	// Normalise both so every path returns a non-nil slice.
	if unsigned.Servers == nil {
		unsigned.Servers = []string{}
	}
	return unsigned.RoomID, unsigned.Servers
}
// extractChildMetadata reads the "children" map and "children_hash" from ev's
// unsigned data. count is the sum of every value in the "children" object, and
// hash is children_hash as decoded by spec.Base64Bytes. Both are zero/nil when
// ev is nil or its unsigned data cannot be decoded.
func extractChildMetadata(ev *types.HeaderedEvent) (count int, hash []byte) {
	if ev == nil {
		// Match the nil guards in parentChildEventIDs / roomIDAndServers:
		// UpdateChildMetadata calls this without checking ev first.
		return 0, nil
	}
	var unsigned struct {
		Counts map[string]int   `json:"children"`
		Hash   spec.Base64Bytes `json:"children_hash"`
	}
	if err := json.Unmarshal(ev.Unsigned(), &unsigned); err != nil {
		// expected if there is no unsigned field at all
		return 0, nil
	}
	for _, c := range unsigned.Counts {
		count += c
	}
	return count, unsigned.Hash
}