diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index ca0c64fb..78ca4d6d 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -21,6 +21,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -66,28 +67,29 @@ type inviteEventsStatements struct { selectMaxInviteIDStmt *sql.Stmt } -func (s *inviteEventsStatements) prepare(db *sql.DB) (err error) { - _, err = db.Exec(inviteEventsSchema) +func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { + s := &inviteEventsStatements{} + _, err := db.Exec(inviteEventsSchema) if err != nil { - return + return nil, err } if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return + return nil, err } if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return + return nil, err } if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return + return nil, err } if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return + return nil, err } - return + return s, nil } -func (s *inviteEventsStatements) insertInviteEvent( - ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, +func (s *inviteEventsStatements) InsertInviteEvent( + ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { var headeredJSON []byte headeredJSON, err = json.Marshal(inviteEvent) @@ -105,7 +107,7 @@ func (s *inviteEventsStatements) insertInviteEvent( return } -func (s *inviteEventsStatements) deleteInviteEvent( +func (s *inviteEventsStatements) DeleteInviteEvent( ctx context.Context, inviteEventID string, ) error { _, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID) @@ -114,7 +116,7 @@ func (s *inviteEventsStatements) deleteInviteEvent( // selectInviteEventsInRange returns a map of room ID to invite event for the // active invites for the target user ID in the supplied range. -func (s *inviteEventsStatements) selectInviteEventsInRange( +func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition, ) (map[string]gomatrixserverlib.HeaderedEvent, error) { stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) @@ -143,7 +145,7 @@ func (s *inviteEventsStatements) selectInviteEventsInRange( return result, rows.Err() } -func (s *inviteEventsStatements) selectMaxInviteID( +func (s *inviteEventsStatements) SelectMaxInviteID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 90976168..9883c362 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -32,6 +32,7 @@ import ( _ "github.com/lib/pq" "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/eduserver/cache" + "github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -54,10 +55,10 @@ type SyncServerDatasource struct { accountData accountDataStatements events outputRoomEventsStatements roomstate currentRoomStateStatements - invites inviteEventsStatements eduCache *cache.EDUCache topology outputRoomEventsTopologyStatements backwardExtremities tables.BackwardsExtremities + shared *shared.Database } // NewSyncServerDatasource creates a new sync server database @@ -79,7 +80,8 @@ func NewSyncServerDatasource(dbDataSourceName string, dbProperties common.DbProp if err = d.roomstate.prepare(d.db); err != nil { return nil, err } - if err = d.invites.prepare(d.db); err != nil { + invites, err := NewPostgresInvitesTable(d.db) + if err != nil { return nil, err } if err = d.topology.prepare(d.db); err != nil { @@ -90,6 +92,10 @@ func NewSyncServerDatasource(dbDataSourceName string, dbProperties common.DbProp return nil, err } d.eduCache = cache.New() + d.shared = &shared.Database{ + DB: d.db, + Invites: invites, + } return &d, nil } @@ -340,7 +346,7 @@ func (d *SyncServerDatasource) syncStreamPositionTx( if maxAccountDataID > maxID { maxID = maxAccountDataID } - maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) + maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return 0, err } @@ -365,7 +371,7 @@ func (d *SyncServerDatasource) syncPositionTx( if maxAccountDataID > maxEventID { maxEventID = maxAccountDataID } - maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) + maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return sp, err } @@ -662,17 +668,14 @@ func (d *SyncServerDatasource) UpsertAccountData( func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, -) (types.StreamPosition, error) { - return d.invites.insertInviteEvent(ctx, inviteEvent) +) (sp types.StreamPosition, err error) { + return d.shared.AddInviteEvent(ctx, inviteEvent) } func (d *SyncServerDatasource) RetireInviteEvent( ctx context.Context, inviteEventID string, ) error { - // TODO: Record that invite has been retired in a stream so that we can - // notify the user in an incremental sync. - err := d.invites.deleteInviteEvent(ctx, inviteEventID) - return err + return d.shared.RetireInviteEvent(ctx, inviteEventID) } func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { @@ -697,7 +700,7 @@ func (d *SyncServerDatasource) addInvitesToResponse( fromPos, toPos types.StreamPosition, res *types.Response, ) error { - invites, err := d.invites.selectInviteEventsInRange( + invites, err := d.shared.Invites.SelectInviteEventsInRange( ctx, txn, userID, fromPos, toPos, ) if err != nil { diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go new file mode 100644 index 00000000..e89976df --- /dev/null +++ b/syncapi/storage/shared/syncserver.go @@ -0,0 +1,37 @@ +package shared + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database is a temporary struct until we have made syncserver.go the same for both pq/sqlite +// For now this contains the shared functions +type Database struct { + DB *sql.DB + Invites tables.Invites +} + +func (d *Database) AddInviteEvent( + ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, +) (sp types.StreamPosition, err error) { + err = common.WithTransaction(d.DB, func(txn *sql.Tx) error { + sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) + return err + }) + return +} + +func (d *Database) RetireInviteEvent( + ctx context.Context, inviteEventID string, +) error { + // TODO: Record that invite has been retired in a stream so that we can + // notify the user in an incremental sync. + err := d.Invites.DeleteInviteEvent(ctx, inviteEventID) + return err +} diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 22efeaeb..26b3a316 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -21,6 +21,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/common" + "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -62,30 +63,37 @@ type inviteEventsStatements struct { selectMaxInviteIDStmt *sql.Stmt } -func (s *inviteEventsStatements) prepare(db *sql.DB, streamID *streamIDStatements) (err error) { - s.streamIDStatements = streamID - _, err = db.Exec(inviteEventsSchema) +func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { + s := &inviteEventsStatements{ + streamIDStatements: streamID, + } + _, err := db.Exec(inviteEventsSchema) + if err != nil { + return nil, err + } + if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { + return nil, err + } + if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { + return nil, err + } + if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { + return nil, err + } + if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *inviteEventsStatements) InsertInviteEvent( + ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, +) (streamPos types.StreamPosition, err error) { + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) if err != nil { return } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return - } - return -} -func (s *inviteEventsStatements) insertInviteEvent( - ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, streamPos types.StreamPosition, -) (err error) { var headeredJSON []byte headeredJSON, err = json.Marshal(inviteEvent) if err != nil { @@ -103,7 +111,7 @@ func (s *inviteEventsStatements) insertInviteEvent( return } -func (s *inviteEventsStatements) deleteInviteEvent( +func (s *inviteEventsStatements) DeleteInviteEvent( ctx context.Context, inviteEventID string, ) error { _, err := s.deleteInviteEventStmt.ExecContext(ctx, inviteEventID) @@ -112,7 +120,7 @@ func (s *inviteEventsStatements) deleteInviteEvent( // selectInviteEventsInRange returns a map of room ID to invite event for the // active invites for the target user ID in the supplied range. -func (s *inviteEventsStatements) selectInviteEventsInRange( +func (s *inviteEventsStatements) SelectInviteEventsInRange( ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition, ) (map[string]gomatrixserverlib.HeaderedEvent, error) { stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) @@ -141,7 +149,7 @@ func (s *inviteEventsStatements) selectInviteEventsInRange( return result, nil } -func (s *inviteEventsStatements) selectMaxInviteID( +func (s *inviteEventsStatements) SelectMaxInviteID( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { var nullableID sql.NullInt64 diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 84c562bd..a2253dcd 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -35,6 +35,7 @@ import ( "github.com/matrix-org/dendrite/common" "github.com/matrix-org/dendrite/eduserver/cache" + "github.com/matrix-org/dendrite/syncapi/storage/shared" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -58,10 +59,10 @@ type SyncServerDatasource struct { accountData accountDataStatements events outputRoomEventsStatements roomstate currentRoomStateStatements - invites inviteEventsStatements eduCache *cache.EDUCache topology outputRoomEventsTopologyStatements backwardExtremities tables.BackwardsExtremities + shared *shared.Database } // NewSyncServerDatasource creates a new sync server database @@ -106,7 +107,8 @@ func (d *SyncServerDatasource) prepare() (err error) { if err = d.roomstate.prepare(d.db, &d.streamID); err != nil { return err } - if err = d.invites.prepare(d.db, &d.streamID); err != nil { + invites, err := NewSqliteInvitesTable(d.db, &d.streamID) + if err != nil { return err } if err = d.topology.prepare(d.db); err != nil { @@ -116,6 +118,10 @@ func (d *SyncServerDatasource) prepare() (err error) { if err != nil { return err } + d.shared = &shared.Database{ + DB: d.db, + Invites: invites, + } return nil } @@ -404,7 +410,7 @@ func (d *SyncServerDatasource) syncStreamPositionTx( if maxAccountDataID > maxID { maxID = maxAccountDataID } - maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) + maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return 0, err } @@ -429,7 +435,7 @@ func (d *SyncServerDatasource) syncPositionTx( if maxAccountDataID > maxEventID { maxEventID = maxAccountDataID } - maxInviteID, err := d.invites.selectMaxInviteID(ctx, txn) + maxInviteID, err := d.shared.Invites.SelectMaxInviteID(ctx, txn) if err != nil { return nil, err } @@ -756,15 +762,8 @@ func (d *SyncServerDatasource) UpsertAccountData( // Returns an error if there was a problem communicating with the database. func (d *SyncServerDatasource) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, -) (streamPos types.StreamPosition, err error) { - err = common.WithTransaction(d.db, func(txn *sql.Tx) error { - streamPos, err = d.streamID.nextStreamID(ctx, txn) - if err != nil { - return err - } - return d.invites.insertInviteEvent(ctx, txn, inviteEvent, streamPos) - }) - return +) (sp types.StreamPosition, err error) { + return d.shared.AddInviteEvent(ctx, inviteEvent) } // RetireInviteEvent removes an old invite event from the database. @@ -772,10 +771,7 @@ func (d *SyncServerDatasource) AddInviteEvent( func (d *SyncServerDatasource) RetireInviteEvent( ctx context.Context, inviteEventID string, ) error { - // TODO: Record that invite has been retired in a stream so that we can - // notify the user in an incremental sync. - err := d.invites.deleteInviteEvent(ctx, inviteEventID) - return err + return d.shared.RetireInviteEvent(ctx, inviteEventID) } func (d *SyncServerDatasource) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) { @@ -804,7 +800,7 @@ func (d *SyncServerDatasource) addInvitesToResponse( fromPos, toPos types.StreamPosition, res *types.Response, ) error { - invites, err := d.invites.selectInviteEventsInRange( + invites, err := d.shared.Invites.SelectInviteEventsInRange( ctx, txn, userID, fromPos, toPos, ) if err != nil { diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go new file mode 100644 index 00000000..1a994052 --- /dev/null +++ b/syncapi/storage/tables/interface.go @@ -0,0 +1,16 @@ +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +type Invites interface { + InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) + DeleteInviteEvent(ctx context.Context, inviteEventID string) error + SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition) (map[string]gomatrixserverlib.HeaderedEvent, error) + SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) +}