mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-26 05:31:32 +00:00
Add contexts to the roomserver storage layer (#229)
* Add contexts to the roomserver storage layer * Fix rooms_table
This commit is contained in:
parent
3133bef797
commit
bfcce5bd21
@ -32,16 +32,16 @@ import (
|
||||
type RoomserverAliasAPIDatabase interface {
|
||||
// Save a given room alias with the room ID it refers to.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
SetRoomAlias(alias string, roomID string) error
|
||||
SetRoomAlias(ctx context.Context, alias string, roomID string) error
|
||||
// Look up the room ID a given alias refers to.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
GetRoomIDFromAlias(alias string) (string, error)
|
||||
GetRoomIDFromAlias(ctx context.Context, alias string) (string, error)
|
||||
// Look up all aliases referring to a given room ID.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
GetAliasesFromRoomID(roomID string) ([]string, error)
|
||||
GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
|
||||
// Remove a given room alias.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
RemoveRoomAlias(alias string) error
|
||||
RemoveRoomAlias(ctx context.Context, alias string) error
|
||||
}
|
||||
|
||||
// RoomserverAliasAPI is an implementation of api.RoomserverAliasAPI
|
||||
@ -59,7 +59,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
|
||||
response *api.SetRoomAliasResponse,
|
||||
) error {
|
||||
// Check if the alias isn't already referring to a room
|
||||
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
|
||||
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -71,7 +71,7 @@ func (r *RoomserverAliasAPI) SetRoomAlias(
|
||||
response.AliasExists = false
|
||||
|
||||
// Save the new alias
|
||||
if err := r.DB.SetRoomAlias(request.Alias, request.RoomID); err != nil {
|
||||
if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -93,7 +93,7 @@ func (r *RoomserverAliasAPI) GetAliasRoomID(
|
||||
response *api.GetAliasRoomIDResponse,
|
||||
) error {
|
||||
// Look up the room ID in the database
|
||||
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
|
||||
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -109,18 +109,21 @@ func (r *RoomserverAliasAPI) RemoveRoomAlias(
|
||||
response *api.RemoveRoomAliasResponse,
|
||||
) error {
|
||||
// Look up the room ID in the database
|
||||
roomID, err := r.DB.GetRoomIDFromAlias(request.Alias)
|
||||
roomID, err := r.DB.GetRoomIDFromAlias(ctx, request.Alias)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove the dalias from the database
|
||||
if err := r.DB.RemoveRoomAlias(request.Alias); err != nil {
|
||||
if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send an updated m.room.aliases event
|
||||
if err := r.sendUpdatedAliasesEvent(ctx, request.UserID, roomID); err != nil {
|
||||
// At this point we've already committed the alias to the database so we
|
||||
// shouldn't cancel this request.
|
||||
// TODO: Ensure that we send unsent events when if server restarts.
|
||||
if err := r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, roomID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -147,7 +150,7 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
|
||||
|
||||
// Retrieve the updated list of aliases, marhal it and set it as the
|
||||
// event's content
|
||||
aliases, err := r.DB.GetAliasesFromRoomID(roomID)
|
||||
aliases, err := r.DB.GetAliasesFromRoomID(ctx, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -15,16 +15,23 @@
|
||||
package input
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// checkAuthEvents checks that the event passes authentication checks
|
||||
// Returns the numeric IDs for the auth events.
|
||||
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]types.EventNID, error) {
|
||||
func checkAuthEvents(
|
||||
ctx context.Context,
|
||||
db RoomEventDatabase,
|
||||
event gomatrixserverlib.Event,
|
||||
authEventIDs []string,
|
||||
) ([]types.EventNID, error) {
|
||||
// Grab the numeric IDs for the supplied auth state events from the database.
|
||||
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs)
|
||||
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -34,7 +41,7 @@ func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEv
|
||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
|
||||
|
||||
// Load the actual auth events from the database.
|
||||
authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries)
|
||||
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -84,7 +91,10 @@ func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Even
|
||||
}
|
||||
|
||||
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID})
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
|
||||
EventTypeNID: typeNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
@ -100,7 +110,10 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID})
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
|
||||
EventTypeNID: typeNID,
|
||||
EventStateKeyNID: stateKeyNID,
|
||||
})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
@ -113,6 +126,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
|
||||
|
||||
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
||||
func loadAuthEvents(
|
||||
ctx context.Context,
|
||||
db RoomEventDatabase,
|
||||
needed gomatrixserverlib.StateNeeded,
|
||||
state []types.StateEntry,
|
||||
@ -121,7 +135,7 @@ func loadAuthEvents(
|
||||
var neededStateKeys []string
|
||||
neededStateKeys = append(neededStateKeys, needed.Member...)
|
||||
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
||||
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil {
|
||||
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -135,34 +149,52 @@ func loadAuthEvents(
|
||||
eventNIDs = append(eventNIDs, eventNID)
|
||||
}
|
||||
}
|
||||
if result.events, err = db.Events(eventNIDs); err != nil {
|
||||
if result.events, err = db.Events(ctx, eventNIDs); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
|
||||
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
||||
func stateKeyTuplesNeeded(
|
||||
stateKeyNIDMap map[string]types.EventStateKeyNID,
|
||||
stateNeeded gomatrixserverlib.StateNeeded,
|
||||
) []types.StateKeyTuple {
|
||||
var keyTuples []types.StateKeyTuple
|
||||
if stateNeeded.Create {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomCreateNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
if stateNeeded.PowerLevels {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomPowerLevelsNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
if stateNeeded.JoinRules {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomJoinRulesNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
for _, member := range stateNeeded.Member {
|
||||
stateKeyNID, ok := stateKeyNIDMap[member]
|
||||
if ok {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomMemberNID,
|
||||
EventStateKeyNID: stateKeyNID,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, token := range stateNeeded.ThirdPartyInvite {
|
||||
stateKeyNID, ok := stateKeyNIDMap[token]
|
||||
if ok {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomThirdPartyInviteNID,
|
||||
EventStateKeyNID: stateKeyNID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return keyTuples
|
||||
|
@ -15,6 +15,7 @@
|
||||
package input
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
@ -28,22 +29,38 @@ import (
|
||||
type RoomEventDatabase interface {
|
||||
state.RoomStateDatabase
|
||||
// Stores a matrix room event in the database
|
||||
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error)
|
||||
StoreEvent(
|
||||
ctx context.Context,
|
||||
event gomatrixserverlib.Event,
|
||||
authEventNIDs []types.EventNID,
|
||||
) (types.RoomNID, types.StateAtEvent, error)
|
||||
// Look up the state entries for a list of string event IDs
|
||||
// Returns an error if the there is an error talking to the database
|
||||
// Returns a types.MissingEventError if the event IDs aren't in the database.
|
||||
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
|
||||
StateEntriesForEventIDs(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateEntry, error)
|
||||
// Set the state at an event.
|
||||
SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error
|
||||
SetState(
|
||||
ctx context.Context,
|
||||
eventNID types.EventNID,
|
||||
stateNID types.StateSnapshotNID,
|
||||
) error
|
||||
// Look up the latest events in a room in preparation for an update.
|
||||
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
|
||||
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
|
||||
// If this returns an error then no further action is required.
|
||||
GetLatestEventsForUpdate(roomNID types.RoomNID) (updater types.RoomRecentEventsUpdater, err error)
|
||||
GetLatestEventsForUpdate(
|
||||
ctx context.Context, roomNID types.RoomNID,
|
||||
) (updater types.RoomRecentEventsUpdater, err error)
|
||||
// Look up the string event IDs for a list of numeric event IDs
|
||||
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||
EventIDs(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) (map[types.EventNID]string, error)
|
||||
// Build a membership updater for the target user in a room.
|
||||
MembershipUpdater(roomID, targerUserID string) (types.MembershipUpdater, error)
|
||||
MembershipUpdater(
|
||||
ctx context.Context, roomID, targerUserID string,
|
||||
) (types.MembershipUpdater, error)
|
||||
}
|
||||
|
||||
// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
|
||||
@ -52,18 +69,23 @@ type OutputRoomEventWriter interface {
|
||||
WriteOutputEvents(roomID string, updates []api.OutputEvent) error
|
||||
}
|
||||
|
||||
func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputRoomEvent) error {
|
||||
func processRoomEvent(
|
||||
ctx context.Context,
|
||||
db RoomEventDatabase,
|
||||
ow OutputRoomEventWriter,
|
||||
input api.InputRoomEvent,
|
||||
) error {
|
||||
// Parse and validate the event JSON
|
||||
event := input.Event
|
||||
|
||||
// Check that the event passes authentication checks and work out the numeric IDs for the auth events.
|
||||
authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs)
|
||||
authEventNIDs, err := checkAuthEvents(ctx, db, event, input.AuthEventIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Store the event
|
||||
roomNID, stateAtEvent, err := db.StoreEvent(event, authEventNIDs)
|
||||
roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, authEventNIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -82,20 +104,20 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
|
||||
// We've been told what the state at the event is so we don't need to calculate it.
|
||||
// Check that those state events are in the database and store the state.
|
||||
var entries []types.StateEntry
|
||||
if entries, err = db.StateEntriesForEventIDs(input.StateEventIDs); err != nil {
|
||||
if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(roomNID, nil, entries); err != nil {
|
||||
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
|
||||
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(db, event, roomNID); err != nil {
|
||||
if stateAtEvent.BeforeStateSnapshotNID, err = state.CalculateAndStoreStateBeforeEvent(ctx, db, event, roomNID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
db.SetState(stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
|
||||
db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
|
||||
}
|
||||
|
||||
if input.Kind == api.KindBackfill {
|
||||
@ -104,14 +126,19 @@ func processRoomEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.
|
||||
}
|
||||
|
||||
// Update the extremities of the event graph for the room
|
||||
if err := updateLatestEvents(db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil {
|
||||
if err := updateLatestEvents(ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input api.InputInviteEvent) (err error) {
|
||||
func processInviteEvent(
|
||||
ctx context.Context,
|
||||
db RoomEventDatabase,
|
||||
ow OutputRoomEventWriter,
|
||||
input api.InputInviteEvent,
|
||||
) (err error) {
|
||||
if input.Event.StateKey() == nil {
|
||||
return fmt.Errorf("invite must be a state event")
|
||||
}
|
||||
@ -119,7 +146,7 @@ func processInviteEvent(db RoomEventDatabase, ow OutputRoomEventWriter, input ap
|
||||
roomID := input.Event.RoomID()
|
||||
targetUserID := *input.Event.StateKey()
|
||||
|
||||
updater, err := db.MembershipUpdater(roomID, targetUserID)
|
||||
updater, err := db.MembershipUpdater(ctx, roomID, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -59,12 +59,12 @@ func (r *RoomserverInputAPI) InputRoomEvents(
|
||||
response *api.InputRoomEventsResponse,
|
||||
) error {
|
||||
for i := range request.InputRoomEvents {
|
||||
if err := processRoomEvent(r.DB, r, request.InputRoomEvents[i]); err != nil {
|
||||
if err := processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := range request.InputInviteEvents {
|
||||
if err := processInviteEvent(r.DB, r, request.InputInviteEvents[i]); err != nil {
|
||||
if err := processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ package input
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
@ -42,6 +43,7 @@ import (
|
||||
// 7 <----- latest
|
||||
//
|
||||
func updateLatestEvents(
|
||||
ctx context.Context,
|
||||
db RoomEventDatabase,
|
||||
ow OutputRoomEventWriter,
|
||||
roomNID types.RoomNID,
|
||||
@ -49,7 +51,7 @@ func updateLatestEvents(
|
||||
event gomatrixserverlib.Event,
|
||||
sendAsServer string,
|
||||
) (err error) {
|
||||
updater, err := db.GetLatestEventsForUpdate(roomNID)
|
||||
updater, err := db.GetLatestEventsForUpdate(ctx, roomNID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -57,7 +59,7 @@ func updateLatestEvents(
|
||||
defer common.EndTransaction(updater, &succeeded)
|
||||
|
||||
u := latestEventsUpdater{
|
||||
db: db, updater: updater, ow: ow, roomNID: roomNID,
|
||||
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
|
||||
stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer,
|
||||
}
|
||||
if err = u.doUpdateLatestEvents(); err != nil {
|
||||
@ -73,6 +75,7 @@ func updateLatestEvents(
|
||||
// The state could be passed using function arguments, but it becomes impractical
|
||||
// when there are so many variables to pass around.
|
||||
type latestEventsUpdater struct {
|
||||
ctx context.Context
|
||||
db RoomEventDatabase
|
||||
updater types.RoomRecentEventsUpdater
|
||||
ow OutputRoomEventWriter
|
||||
@ -133,7 +136,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
|
||||
return err
|
||||
}
|
||||
|
||||
updates, err := updateMemberships(u.db, u.updater, u.removed, u.added)
|
||||
updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -174,18 +177,22 @@ func (u *latestEventsUpdater) latestState() error {
|
||||
for i := range u.latest {
|
||||
latestStateAtEvents[i] = u.latest[i].StateAtEvent
|
||||
}
|
||||
u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(u.db, u.roomNID, latestStateAtEvents)
|
||||
u.newStateNID, err = state.CalculateAndStoreStateAfterEvents(
|
||||
u.ctx, u.db, u.roomNID, latestStateAtEvents,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(u.db, u.oldStateNID, u.newStateNID)
|
||||
u.removed, u.added, err = state.DifferenceBetweeenStateSnapshots(
|
||||
u.ctx, u.db, u.oldStateNID, u.newStateNID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = state.DifferenceBetweeenStateSnapshots(
|
||||
u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
|
||||
u.ctx, u.db, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -193,7 +200,12 @@ func (u *latestEventsUpdater) latestState() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func calculateLatest(oldLatest []types.StateAtEventAndReference, alreadyReferenced bool, prevEvents []gomatrixserverlib.EventReference, newEvent types.StateAtEventAndReference) []types.StateAtEventAndReference {
|
||||
func calculateLatest(
|
||||
oldLatest []types.StateAtEventAndReference,
|
||||
alreadyReferenced bool,
|
||||
prevEvents []gomatrixserverlib.EventReference,
|
||||
newEvent types.StateAtEventAndReference,
|
||||
) []types.StateAtEventAndReference {
|
||||
var alreadyInLatest bool
|
||||
var newLatest []types.StateAtEventAndReference
|
||||
for _, l := range oldLatest {
|
||||
@ -253,7 +265,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
|
||||
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
|
||||
}
|
||||
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
|
||||
eventIDMap, err := u.db.EventIDs(stateEventNIDs)
|
||||
eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package input
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
@ -27,7 +28,10 @@ import (
|
||||
// Returns a list of output events to write to the kafka log to inform the
|
||||
// consumers about the invites added or retired by the change in current state.
|
||||
func updateMemberships(
|
||||
db RoomEventDatabase, updater types.RoomRecentEventsUpdater, removed, added []types.StateEntry,
|
||||
ctx context.Context,
|
||||
db RoomEventDatabase,
|
||||
updater types.RoomRecentEventsUpdater,
|
||||
removed, added []types.StateEntry,
|
||||
) ([]api.OutputEvent, error) {
|
||||
changes := membershipChanges(removed, added)
|
||||
var eventNIDs []types.EventNID
|
||||
@ -43,7 +47,7 @@ func updateMemberships(
|
||||
// Load the event JSON so we can look up the "membership" key.
|
||||
// TODO: Maybe add a membership key to the events table so we can load that
|
||||
// key without having to load the entire event JSON?
|
||||
events, err := db.Events(eventNIDs)
|
||||
events, err := db.Events(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -33,35 +33,47 @@ type RoomserverQueryAPIDatabase interface {
|
||||
// Look up the numeric ID for the room.
|
||||
// Returns 0 if the room doesn't exists.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
RoomNID(roomID string) (types.RoomNID, error)
|
||||
RoomNID(ctx context.Context, roomID string) (types.RoomNID, error)
|
||||
// Look up event references for the latest events in the room and the current state snapshot.
|
||||
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
|
||||
LatestEventIDs(
|
||||
ctx context.Context, roomNID types.RoomNID,
|
||||
) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
|
||||
// Look up the numeric IDs for a list of events.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
EventNIDs(eventIDs []string) (map[string]types.EventNID, error)
|
||||
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
|
||||
// Lookup the event IDs for a batch of event numeric IDs.
|
||||
// Returns an error if the retrieval went wrong.
|
||||
EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
|
||||
// Lookup the membership of a given user in a given room.
|
||||
// Returns the numeric ID of the latest membership event sent from this user
|
||||
// in this room, along a boolean set to true if the user is still in this room,
|
||||
// false if not.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error)
|
||||
GetMembership(
|
||||
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
|
||||
) (membershipEventNID types.EventNID, stillInRoom bool, err error)
|
||||
// Lookup the membership event numeric IDs for all user that are or have
|
||||
// been members of a given room. Only lookup events of "join" membership if
|
||||
// joinOnly is set to true.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error)
|
||||
GetMembershipEventNIDsForRoom(
|
||||
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
|
||||
) ([]types.EventNID, error)
|
||||
// Look up the active invites targeting a user in a room and return the
|
||||
// numeric state key IDs for the user IDs who sent them.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
GetInvitesForUser(roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (senderUserNIDs []types.EventStateKeyNID, err error)
|
||||
GetInvitesForUser(
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID,
|
||||
targetUserNID types.EventStateKeyNID,
|
||||
) (senderUserNIDs []types.EventStateKeyNID, err error)
|
||||
// Look up the string event state keys for a list of numeric event state keys
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
EventStateKeys([]types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
|
||||
EventStateKeys(
|
||||
context.Context, []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error)
|
||||
}
|
||||
|
||||
// RoomserverQueryAPI is an implementation of api.RoomserverQueryAPI
|
||||
@ -76,7 +88,7 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
|
||||
response *api.QueryLatestEventsAndStateResponse,
|
||||
) error {
|
||||
response.QueryLatestEventsAndStateRequest = *request
|
||||
roomNID, err := r.DB.RoomNID(request.RoomID)
|
||||
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -85,18 +97,21 @@ func (r *RoomserverQueryAPI) QueryLatestEventsAndState(
|
||||
}
|
||||
response.RoomExists = true
|
||||
var currentStateSnapshotNID types.StateSnapshotNID
|
||||
response.LatestEvents, currentStateSnapshotNID, response.Depth, err = r.DB.LatestEventIDs(roomNID)
|
||||
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
|
||||
r.DB.LatestEventIDs(ctx, roomNID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Look up the currrent state for the requested tuples.
|
||||
stateEntries, err := state.LoadStateAtSnapshotForStringTuples(r.DB, currentStateSnapshotNID, request.StateToFetch)
|
||||
stateEntries, err := state.LoadStateAtSnapshotForStringTuples(
|
||||
ctx, r.DB, currentStateSnapshotNID, request.StateToFetch,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stateEvents, err := r.loadStateEvents(stateEntries)
|
||||
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -112,7 +127,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
|
||||
response *api.QueryStateAfterEventsResponse,
|
||||
) error {
|
||||
response.QueryStateAfterEventsRequest = *request
|
||||
roomNID, err := r.DB.RoomNID(request.RoomID)
|
||||
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -121,7 +136,7 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
|
||||
}
|
||||
response.RoomExists = true
|
||||
|
||||
prevStates, err := r.DB.StateAtEventIDs(request.PrevEventIDs)
|
||||
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case types.MissingEventError:
|
||||
@ -133,12 +148,14 @@ func (r *RoomserverQueryAPI) QueryStateAfterEvents(
|
||||
response.PrevEventsExist = true
|
||||
|
||||
// Look up the currrent state for the requested tuples.
|
||||
stateEntries, err := state.LoadStateAfterEventsForStringTuples(r.DB, prevStates, request.StateToFetch)
|
||||
stateEntries, err := state.LoadStateAfterEventsForStringTuples(
|
||||
ctx, r.DB, prevStates, request.StateToFetch,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stateEvents, err := r.loadStateEvents(stateEntries)
|
||||
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -155,7 +172,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
|
||||
) error {
|
||||
response.QueryEventsByIDRequest = *request
|
||||
|
||||
eventNIDMap, err := r.DB.EventNIDs(request.EventIDs)
|
||||
eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -165,7 +182,7 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
|
||||
eventNIDs = append(eventNIDs, nid)
|
||||
}
|
||||
|
||||
events, err := r.loadEvents(eventNIDs)
|
||||
events, err := r.loadEvents(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -174,16 +191,20 @@ func (r *RoomserverQueryAPI) QueryEventsByID(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RoomserverQueryAPI) loadStateEvents(stateEntries []types.StateEntry) ([]gomatrixserverlib.Event, error) {
|
||||
func (r *RoomserverQueryAPI) loadStateEvents(
|
||||
ctx context.Context, stateEntries []types.StateEntry,
|
||||
) ([]gomatrixserverlib.Event, error) {
|
||||
eventNIDs := make([]types.EventNID, len(stateEntries))
|
||||
for i := range stateEntries {
|
||||
eventNIDs[i] = stateEntries[i].EventNID
|
||||
}
|
||||
return r.loadEvents(eventNIDs)
|
||||
return r.loadEvents(ctx, eventNIDs)
|
||||
}
|
||||
|
||||
func (r *RoomserverQueryAPI) loadEvents(eventNIDs []types.EventNID) ([]gomatrixserverlib.Event, error) {
|
||||
stateEvents, err := r.DB.Events(eventNIDs)
|
||||
func (r *RoomserverQueryAPI) loadEvents(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) ([]gomatrixserverlib.Event, error) {
|
||||
stateEvents, err := r.DB.Events(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -201,12 +222,12 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
||||
request *api.QueryMembershipsForRoomRequest,
|
||||
response *api.QueryMembershipsForRoomResponse,
|
||||
) error {
|
||||
roomNID, err := r.DB.RoomNID(request.RoomID)
|
||||
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
membershipEventNID, stillInRoom, err := r.DB.GetMembership(roomNID, request.Sender)
|
||||
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.Sender)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@ -223,14 +244,14 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
||||
var events []types.Event
|
||||
if stillInRoom {
|
||||
var eventNIDs []types.EventNID
|
||||
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(roomNID, request.JoinedOnly)
|
||||
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events, err = r.DB.Events(eventNIDs)
|
||||
events, err = r.DB.Events(ctx, eventNIDs)
|
||||
} else {
|
||||
events, err = r.getMembershipsBeforeEventNID(membershipEventNID, request.JoinedOnly)
|
||||
events, err = r.getMembershipsBeforeEventNID(ctx, membershipEventNID, request.JoinedOnly)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@ -249,22 +270,24 @@ func (r *RoomserverQueryAPI) QueryMembershipsForRoom(
|
||||
// of the event's room as it was when this event was fired, then filters the state events to
|
||||
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
||||
// Returns an error if there was an issue fetching the events.
|
||||
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNID, joinedOnly bool) ([]types.Event, error) {
|
||||
func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(
|
||||
ctx context.Context, eventNID types.EventNID, joinedOnly bool,
|
||||
) ([]types.Event, error) {
|
||||
events := []types.Event{}
|
||||
// Lookup the event NID
|
||||
eIDs, err := r.DB.EventIDs([]types.EventNID{eventNID})
|
||||
eIDs, err := r.DB.EventIDs(ctx, []types.EventNID{eventNID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventIDs := []string{eIDs[eventNID]}
|
||||
|
||||
prevState, err := r.DB.StateAtEventIDs(eventIDs)
|
||||
prevState, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the state as it was when this event was fired
|
||||
stateEntries, err := state.LoadCombinedStateAfterEvents(r.DB, prevState)
|
||||
stateEntries, err := state.LoadCombinedStateAfterEvents(ctx, r.DB, prevState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -278,7 +301,7 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI
|
||||
}
|
||||
|
||||
// Get all of the events in this state
|
||||
stateEvents, err := r.DB.Events(eventNIDs)
|
||||
stateEvents, err := r.DB.Events(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -304,27 +327,27 @@ func (r *RoomserverQueryAPI) getMembershipsBeforeEventNID(eventNID types.EventNI
|
||||
|
||||
// QueryInvitesForUser implements api.RoomserverQueryAPI
|
||||
func (r *RoomserverQueryAPI) QueryInvitesForUser(
|
||||
_ context.Context,
|
||||
ctx context.Context,
|
||||
request *api.QueryInvitesForUserRequest,
|
||||
response *api.QueryInvitesForUserResponse,
|
||||
) error {
|
||||
roomNID, err := r.DB.RoomNID(request.RoomID)
|
||||
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
targetUserNIDs, err := r.DB.EventStateKeyNIDs([]string{request.TargetUserID})
|
||||
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
targetUserNID := targetUserNIDs[request.TargetUserID]
|
||||
|
||||
senderUserNIDs, err := r.DB.GetInvitesForUser(roomNID, targetUserNID)
|
||||
senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
senderUserIDs, err := r.DB.EventStateKeys(senderUserNIDs)
|
||||
senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -342,14 +365,14 @@ func (r *RoomserverQueryAPI) QueryServerAllowedToSeeEvent(
|
||||
request *api.QueryServerAllowedToSeeEventRequest,
|
||||
response *api.QueryServerAllowedToSeeEventResponse,
|
||||
) error {
|
||||
stateEntries, err := state.LoadStateAtEvent(r.DB, request.EventID)
|
||||
stateEntries, err := state.LoadStateAtEvent(ctx, r.DB, request.EventID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: We probably want to make it so that we don't have to pull
|
||||
// out all the state if possible.
|
||||
stateAtEvent, err := r.loadStateEvents(stateEntries)
|
||||
stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
@ -30,49 +31,58 @@ import (
|
||||
// A RoomStateDatabase has the storage APIs needed to load state from the database
|
||||
type RoomStateDatabase interface {
|
||||
// Store the room state at an event in the database
|
||||
AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||
AddState(
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID,
|
||||
stateBlockNIDs []types.StateBlockNID,
|
||||
state []types.StateEntry,
|
||||
) (types.StateSnapshotNID, error)
|
||||
// Look up the state of a room at each event for a list of string event IDs.
|
||||
// Returns an error if there is an error talking to the database
|
||||
// Returns a types.MissingEventError if the room state for the event IDs aren't in the database
|
||||
StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error)
|
||||
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
|
||||
// Look up the numeric IDs for a list of string event types.
|
||||
// Returns a map from string event type to numeric ID for the event type.
|
||||
EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
|
||||
// Look up the numeric IDs for a list of string event state keys.
|
||||
// Returns a map from string state key to numeric ID for the state key.
|
||||
EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||
// Look up the numeric state data IDs for each numeric state snapshot ID
|
||||
// The returned slice is sorted by numeric state snapshot ID.
|
||||
StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||
// Look up the state data for each numeric state data ID
|
||||
// The returned slice is sorted by numeric state data ID.
|
||||
StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||
// Look up the state data for the state key tuples for each numeric state block ID
|
||||
// This is used to fetch a subset of the room state at a snapshot.
|
||||
// If a block doesn't contain any of the requested tuples then it can be discarded from the result.
|
||||
// The returned slice is sorted by numeric state block ID.
|
||||
StateEntriesForTuples(stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) (
|
||||
[]types.StateEntryList, error,
|
||||
)
|
||||
StateEntriesForTuples(
|
||||
ctx context.Context,
|
||||
stateBlockNIDs []types.StateBlockNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntryList, error)
|
||||
// Look up the Events for a list of numeric event IDs.
|
||||
// Returns a sorted list of events.
|
||||
Events(eventNIDs []types.EventNID) ([]types.Event, error)
|
||||
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||
// Look up snapshot NID for an event ID string
|
||||
SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error)
|
||||
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
|
||||
}
|
||||
|
||||
// LoadStateAtSnapshot loads the full state of a room at a particular snapshot.
|
||||
// This is typically the state before an event or the current state of a room.
|
||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||
func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID) ([]types.StateEntry, error) {
|
||||
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID})
|
||||
func LoadStateAtSnapshot(
|
||||
ctx context.Context, db RoomStateDatabase, stateNID types.StateSnapshotNID,
|
||||
) ([]types.StateEntry, error) {
|
||||
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||
stateBlockNIDList := stateBlockNIDLists[0]
|
||||
|
||||
stateEntryLists, err := db.StateEntries(stateBlockNIDList.StateBlockNIDs)
|
||||
stateEntryLists, err := db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -100,13 +110,15 @@ func LoadStateAtSnapshot(db RoomStateDatabase, stateNID types.StateSnapshotNID)
|
||||
}
|
||||
|
||||
// LoadStateAtEvent loads the full state of a room at a particular event.
|
||||
func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry, error) {
|
||||
snapshotNID, err := db.SnapshotNIDFromEventID(eventID)
|
||||
func LoadStateAtEvent(
|
||||
ctx context.Context, db RoomStateDatabase, eventID string,
|
||||
) ([]types.StateEntry, error) {
|
||||
snapshotNID, err := db.SnapshotNIDFromEventID(ctx, eventID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stateEntries, err := LoadStateAtSnapshot(db, snapshotNID)
|
||||
stateEntries, err := LoadStateAtSnapshot(ctx, db, snapshotNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -116,7 +128,9 @@ func LoadStateAtEvent(db RoomStateDatabase, eventID string) ([]types.StateEntry,
|
||||
|
||||
// LoadCombinedStateAfterEvents loads a snapshot of the state after each of the events
|
||||
// and combines those snapshots together into a single list.
|
||||
func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.StateAtEvent) ([]types.StateEntry, error) {
|
||||
func LoadCombinedStateAfterEvents(
|
||||
ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
|
||||
) ([]types.StateEntry, error) {
|
||||
stateNIDs := make([]types.StateSnapshotNID, len(prevStates))
|
||||
for i, state := range prevStates {
|
||||
stateNIDs[i] = state.BeforeStateSnapshotNID
|
||||
@ -125,7 +139,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
|
||||
// Deduplicate the IDs before passing them to the database.
|
||||
// There could be duplicates because the events could be state events where
|
||||
// the snapshot of the room state before them was the same.
|
||||
stateBlockNIDLists, err := db.StateBlockNIDs(uniqueStateSnapshotNIDs(stateNIDs))
|
||||
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, uniqueStateSnapshotNIDs(stateNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -138,7 +152,7 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
|
||||
// Deduplicate the IDs before passing them to the database.
|
||||
// There could be duplicates because a block of state entries could be reused by
|
||||
// multiple snapshots.
|
||||
stateEntryLists, err := db.StateEntries(uniqueStateBlockNIDs(stateBlockNIDs))
|
||||
stateEntryLists, err := db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -186,9 +200,9 @@ func LoadCombinedStateAfterEvents(db RoomStateDatabase, prevStates []types.State
|
||||
}
|
||||
|
||||
// DifferenceBetweeenStateSnapshots works out which state entries have been added and removed between two snapshots.
|
||||
func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID) (
|
||||
removed, added []types.StateEntry, err error,
|
||||
) {
|
||||
func DifferenceBetweeenStateSnapshots(
|
||||
ctx context.Context, db RoomStateDatabase, oldStateNID, newStateNID types.StateSnapshotNID,
|
||||
) (removed, added []types.StateEntry, err error) {
|
||||
if oldStateNID == newStateNID {
|
||||
// If the snapshot NIDs are the same then nothing has changed
|
||||
return nil, nil, nil
|
||||
@ -197,13 +211,13 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
|
||||
var oldEntries []types.StateEntry
|
||||
var newEntries []types.StateEntry
|
||||
if oldStateNID != 0 {
|
||||
oldEntries, err = LoadStateAtSnapshot(db, oldStateNID)
|
||||
oldEntries, err = LoadStateAtSnapshot(ctx, db, oldStateNID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
if newStateNID != 0 {
|
||||
newEntries, err = LoadStateAtSnapshot(db, newStateNID)
|
||||
newEntries, err = LoadStateAtSnapshot(ctx, db, newStateNID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@ -246,19 +260,26 @@ func DifferenceBetweeenStateSnapshots(db RoomStateDatabase, oldStateNID, newStat
|
||||
// This is typically the state before an event or the current state of a room.
|
||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||
func LoadStateAtSnapshotForStringTuples(
|
||||
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
stateNID types.StateSnapshotNID,
|
||||
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples)
|
||||
numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return loadStateAtSnapshotForNumericTuples(db, stateNID, numericTuples)
|
||||
return loadStateAtSnapshotForNumericTuples(ctx, db, stateNID, numericTuples)
|
||||
}
|
||||
|
||||
// stringTuplesToNumericTuples converts the string state key tuples into numeric IDs
|
||||
// If there isn't a numeric ID for either the event type or the event state key then the tuple is discarded.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixserverlib.StateKeyTuple) ([]types.StateKeyTuple, error) {
|
||||
func stringTuplesToNumericTuples(
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
stringTuples []gomatrixserverlib.StateKeyTuple,
|
||||
) ([]types.StateKeyTuple, error) {
|
||||
eventTypes := make([]string, len(stringTuples))
|
||||
stateKeys := make([]string, len(stringTuples))
|
||||
for i := range stringTuples {
|
||||
@ -266,12 +287,12 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse
|
||||
stateKeys[i] = stringTuples[i].StateKey
|
||||
}
|
||||
eventTypes = util.UniqueStrings(eventTypes)
|
||||
eventTypeMap, err := db.EventTypeNIDs(eventTypes)
|
||||
eventTypeMap, err := db.EventTypeNIDs(ctx, eventTypes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateKeys = util.UniqueStrings(stateKeys)
|
||||
stateKeyMap, err := db.EventStateKeyNIDs(stateKeys)
|
||||
stateKeyMap, err := db.EventStateKeyNIDs(ctx, stateKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -297,16 +318,21 @@ func stringTuplesToNumericTuples(db RoomStateDatabase, stringTuples []gomatrixse
|
||||
// This is typically the state before an event or the current state of a room.
|
||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||
func loadStateAtSnapshotForNumericTuples(
|
||||
db RoomStateDatabase, stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple,
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
stateNID types.StateSnapshotNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{stateNID})
|
||||
stateBlockNIDLists, err := db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||
stateBlockNIDList := stateBlockNIDLists[0]
|
||||
|
||||
stateEntryLists, err := db.StateEntriesForTuples(stateBlockNIDList.StateBlockNIDs, stateKeyTuples)
|
||||
stateEntryLists, err := db.StateEntriesForTuples(
|
||||
ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -341,23 +367,29 @@ func loadStateAtSnapshotForNumericTuples(
|
||||
// This is typically the state before an event.
|
||||
// Returns a sorted list of state entries or an error if there was a problem talking to the database.
|
||||
func LoadStateAfterEventsForStringTuples(
|
||||
db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
prevStates []types.StateAtEvent,
|
||||
stateKeyTuples []gomatrixserverlib.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
numericTuples, err := stringTuplesToNumericTuples(db, stateKeyTuples)
|
||||
numericTuples, err := stringTuplesToNumericTuples(ctx, db, stateKeyTuples)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return loadStateAfterEventsForNumericTuples(db, prevStates, numericTuples)
|
||||
return loadStateAfterEventsForNumericTuples(ctx, db, prevStates, numericTuples)
|
||||
}
|
||||
|
||||
func loadStateAfterEventsForNumericTuples(
|
||||
db RoomStateDatabase, prevStates []types.StateAtEvent, stateKeyTuples []types.StateKeyTuple,
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
prevStates []types.StateAtEvent,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
if len(prevStates) == 1 {
|
||||
// Fast path for a single event.
|
||||
prevState := prevStates[0]
|
||||
result, err := loadStateAtSnapshotForNumericTuples(
|
||||
db, prevState.BeforeStateSnapshotNID, stateKeyTuples,
|
||||
ctx, db, prevState.BeforeStateSnapshotNID, stateKeyTuples,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -390,7 +422,7 @@ func loadStateAfterEventsForNumericTuples(
|
||||
|
||||
// TODO: Add metrics for this as it could take a long time for big rooms
|
||||
// with large conflicts.
|
||||
fullState, _, _, err := calculateStateAfterManyEvents(db, prevStates)
|
||||
fullState, _, _, err := calculateStateAfterManyEvents(ctx, db, prevStates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -403,7 +435,10 @@ func loadStateAfterEventsForNumericTuples(
|
||||
for _, tuple := range stateKeyTuples {
|
||||
eventNID, ok := stateEntryMap(fullState).lookup(tuple)
|
||||
if ok {
|
||||
result = append(result, types.StateEntry{tuple, eventNID})
|
||||
result = append(result, types.StateEntry{
|
||||
StateKeyTuple: tuple,
|
||||
EventNID: eventNID,
|
||||
})
|
||||
}
|
||||
}
|
||||
sort.Sort(stateEntrySorter(result))
|
||||
@ -509,7 +544,10 @@ func init() {
|
||||
// Stores the snapshot of the state in the database.
|
||||
// Returns a numeric ID for the snapshot of the state before the event.
|
||||
func CalculateAndStoreStateBeforeEvent(
|
||||
db RoomStateDatabase, event gomatrixserverlib.Event, roomNID types.RoomNID,
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
event gomatrixserverlib.Event,
|
||||
roomNID types.RoomNID,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
// Load the state at the prev events.
|
||||
prevEventRefs := event.PrevEvents()
|
||||
@ -518,25 +556,30 @@ func CalculateAndStoreStateBeforeEvent(
|
||||
prevEventIDs[i] = prevEventRefs[i].EventID
|
||||
}
|
||||
|
||||
prevStates, err := db.StateAtEventIDs(prevEventIDs)
|
||||
prevStates, err := db.StateAtEventIDs(ctx, prevEventIDs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// The state before this event will be the state after the events that came before it.
|
||||
return CalculateAndStoreStateAfterEvents(db, roomNID, prevStates)
|
||||
return CalculateAndStoreStateAfterEvents(ctx, db, roomNID, prevStates)
|
||||
}
|
||||
|
||||
// CalculateAndStoreStateAfterEvents finds the room state after the given events.
|
||||
// Stores the resulting state in the database and returns a numeric ID for that snapshot.
|
||||
func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent) (types.StateSnapshotNID, error) {
|
||||
func CalculateAndStoreStateAfterEvents(
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
roomNID types.RoomNID,
|
||||
prevStates []types.StateAtEvent,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
metrics := calculateStateMetrics{startTime: time.Now(), prevEventLength: len(prevStates)}
|
||||
|
||||
if len(prevStates) == 0 {
|
||||
// 2) There weren't any prev_events for this event so the state is
|
||||
// empty.
|
||||
metrics.algorithm = "empty_state"
|
||||
return metrics.stop(db.AddState(roomNID, nil, nil))
|
||||
return metrics.stop(db.AddState(ctx, roomNID, nil, nil))
|
||||
}
|
||||
|
||||
if len(prevStates) == 1 {
|
||||
@ -551,7 +594,9 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN
|
||||
}
|
||||
// The previous event was a state event so we need to store a copy
|
||||
// of the previous state updated with that event.
|
||||
stateBlockNIDLists, err := db.StateBlockNIDs([]types.StateSnapshotNID{prevState.BeforeStateSnapshotNID})
|
||||
stateBlockNIDLists, err := db.StateBlockNIDs(
|
||||
ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID},
|
||||
)
|
||||
if err != nil {
|
||||
metrics.algorithm = "_load_state_blocks"
|
||||
return metrics.stop(0, err)
|
||||
@ -562,14 +607,14 @@ func CalculateAndStoreStateAfterEvents(db RoomStateDatabase, roomNID types.RoomN
|
||||
// add the state event as a block of size one to the end of the blocks.
|
||||
metrics.algorithm = "single_delta"
|
||||
return metrics.stop(db.AddState(
|
||||
roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
||||
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
|
||||
))
|
||||
}
|
||||
// If there are too many deltas then we need to calculate the full state
|
||||
// So fall through to calculateAndStoreStateAfterManyEvents
|
||||
}
|
||||
|
||||
return calculateAndStoreStateAfterManyEvents(db, roomNID, prevStates, metrics)
|
||||
return calculateAndStoreStateAfterManyEvents(ctx, db, roomNID, prevStates, metrics)
|
||||
}
|
||||
|
||||
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
|
||||
@ -583,10 +628,15 @@ const maxStateBlockNIDs = 64
|
||||
// This handles the slow path of calculateAndStoreStateAfterEvents for when there is more than one event.
|
||||
// Stores the resulting state and returns a numeric ID for the snapshot.
|
||||
func calculateAndStoreStateAfterManyEvents(
|
||||
db RoomStateDatabase, roomNID types.RoomNID, prevStates []types.StateAtEvent, metrics calculateStateMetrics,
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
roomNID types.RoomNID,
|
||||
prevStates []types.StateAtEvent,
|
||||
metrics calculateStateMetrics,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
|
||||
state, algorithm, conflictLength, err := calculateStateAfterManyEvents(db, prevStates)
|
||||
state, algorithm, conflictLength, err :=
|
||||
calculateStateAfterManyEvents(ctx, db, prevStates)
|
||||
metrics.algorithm = algorithm
|
||||
if err != nil {
|
||||
return metrics.stop(0, err)
|
||||
@ -596,16 +646,16 @@ func calculateAndStoreStateAfterManyEvents(
|
||||
// previous state.
|
||||
metrics.conflictLength = conflictLength
|
||||
metrics.fullStateLength = len(state)
|
||||
return metrics.stop(db.AddState(roomNID, nil, state))
|
||||
return metrics.stop(db.AddState(ctx, roomNID, nil, state))
|
||||
}
|
||||
|
||||
func calculateStateAfterManyEvents(
|
||||
db RoomStateDatabase, prevStates []types.StateAtEvent,
|
||||
ctx context.Context, db RoomStateDatabase, prevStates []types.StateAtEvent,
|
||||
) (state []types.StateEntry, algorithm string, conflictLength int, err error) {
|
||||
var combined []types.StateEntry
|
||||
// Conflict resolution.
|
||||
// First stage: load the state after each of the prev events.
|
||||
combined, err = LoadCombinedStateAfterEvents(db, prevStates)
|
||||
combined, err = LoadCombinedStateAfterEvents(ctx, db, prevStates)
|
||||
if err != nil {
|
||||
algorithm = "_load_combined_state"
|
||||
return
|
||||
@ -635,7 +685,7 @@ func calculateStateAfterManyEvents(
|
||||
}
|
||||
|
||||
var resolved []types.StateEntry
|
||||
resolved, err = resolveConflicts(db, notConflicted, conflicts)
|
||||
resolved, err = resolveConflicts(ctx, db, notConflicted, conflicts)
|
||||
if err != nil {
|
||||
algorithm = "_resolve_conflicts"
|
||||
return
|
||||
@ -657,10 +707,14 @@ func calculateStateAfterManyEvents(
|
||||
// Returns a list that combines the entries without conflicts with the result of state resolution for the entries with conflicts.
|
||||
// The returned list is sorted by state key tuple.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.StateEntry) ([]types.StateEntry, error) {
|
||||
func resolveConflicts(
|
||||
ctx context.Context,
|
||||
db RoomStateDatabase,
|
||||
notConflicted, conflicted []types.StateEntry,
|
||||
) ([]types.StateEntry, error) {
|
||||
|
||||
// Load the conflicted events
|
||||
conflictedEvents, eventIDMap, err := loadStateEvents(db, conflicted)
|
||||
conflictedEvents, eventIDMap, err := loadStateEvents(ctx, db, conflicted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -672,7 +726,7 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
|
||||
var neededStateKeys []string
|
||||
neededStateKeys = append(neededStateKeys, needed.Member...)
|
||||
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
||||
stateKeyNIDMap, err := db.EventStateKeyNIDs(neededStateKeys)
|
||||
stateKeyNIDMap, err := db.EventStateKeyNIDs(ctx, neededStateKeys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -682,10 +736,13 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
|
||||
var authEntries []types.StateEntry
|
||||
for _, tuple := range tuplesNeeded {
|
||||
if eventNID, ok := stateEntryMap(notConflicted).lookup(tuple); ok {
|
||||
authEntries = append(authEntries, types.StateEntry{tuple, eventNID})
|
||||
authEntries = append(authEntries, types.StateEntry{
|
||||
StateKeyTuple: tuple,
|
||||
EventNID: eventNID,
|
||||
})
|
||||
}
|
||||
}
|
||||
authEvents, _, err := loadStateEvents(db, authEntries)
|
||||
authEvents, _, err := loadStateEvents(ctx, db, authEntries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -711,24 +768,39 @@ func resolveConflicts(db RoomStateDatabase, notConflicted, conflicted []types.St
|
||||
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
||||
var keyTuples []types.StateKeyTuple
|
||||
if stateNeeded.Create {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomCreateNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
if stateNeeded.PowerLevels {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomPowerLevelsNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
if stateNeeded.JoinRules {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomJoinRulesNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
for _, member := range stateNeeded.Member {
|
||||
stateKeyNID, ok := stateKeyNIDMap[member]
|
||||
if ok {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomMemberNID,
|
||||
EventStateKeyNID: stateKeyNID,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, token := range stateNeeded.ThirdPartyInvite {
|
||||
stateKeyNID, ok := stateKeyNIDMap[token]
|
||||
if ok {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomThirdPartyInviteNID,
|
||||
EventStateKeyNID: stateKeyNID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return keyTuples
|
||||
@ -738,12 +810,14 @@ func stateKeyTuplesNeeded(stateKeyNIDMap map[string]types.EventStateKeyNID, stat
|
||||
// Returns a list of state events in no particular order and a map from string event ID back to state entry.
|
||||
// The map can be used to recover which numeric state entry a given event is for.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func loadStateEvents(db RoomStateDatabase, entries []types.StateEntry) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) {
|
||||
func loadStateEvents(
|
||||
ctx context.Context, db RoomStateDatabase, entries []types.StateEntry,
|
||||
) ([]gomatrixserverlib.Event, map[string]types.StateEntry, error) {
|
||||
eventNIDs := make([]types.EventNID, len(entries))
|
||||
for i := range entries {
|
||||
eventNIDs[i] = entries[i].EventNID
|
||||
}
|
||||
events, err := db.Events(eventNIDs)
|
||||
events, err := db.Events(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
@ -65,8 +66,10 @@ func (s *eventJSONStatements) prepare(db *sql.DB) (err error) {
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *eventJSONStatements) insertEventJSON(eventNID types.EventNID, eventJSON []byte) error {
|
||||
_, err := s.insertEventJSONStmt.Exec(int64(eventNID), eventJSON)
|
||||
func (s *eventJSONStatements) insertEventJSON(
|
||||
ctx context.Context, eventNID types.EventNID, eventJSON []byte,
|
||||
) error {
|
||||
_, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -75,8 +78,10 @@ type eventJSONPair struct {
|
||||
EventJSON []byte
|
||||
}
|
||||
|
||||
func (s *eventJSONStatements) bulkSelectEventJSON(eventNIDs []types.EventNID) ([]eventJSONPair, error) {
|
||||
rows, err := s.bulkSelectEventJSONStmt.Query(eventNIDsAsArray(eventNIDs))
|
||||
func (s *eventJSONStatements) bulkSelectEventJSON(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) ([]eventJSONPair, error) {
|
||||
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
@ -91,20 +92,30 @@ func (s *eventStateKeyStatements) prepare(db *sql.DB) (err error) {
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) insertEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
||||
func (s *eventStateKeyStatements) insertEventStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
var eventStateKeyNID int64
|
||||
err := common.TxStmt(txn, s.insertEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||
stmt := common.TxStmt(txn, s.insertEventStateKeyNIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) selectEventStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
||||
func (s *eventStateKeyStatements) selectEventStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
var eventStateKeyNID int64
|
||||
err := common.TxStmt(txn, s.selectEventStateKeyNIDStmt).QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||
stmt := common.TxStmt(txn, s.selectEventStateKeyNIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, eventStateKey).Scan(&eventStateKeyNID)
|
||||
return types.EventStateKeyNID(eventStateKeyNID), err
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
|
||||
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys))
|
||||
func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(
|
||||
ctx context.Context, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
|
||||
ctx, pq.StringArray(eventStateKeys),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -122,18 +133,23 @@ func (s *eventStateKeyStatements) bulkSelectEventStateKeyNID(eventStateKeys []st
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) selectEventStateKey(txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID) (string, error) {
|
||||
func (s *eventStateKeyStatements) selectEventStateKey(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeyNID types.EventStateKeyNID,
|
||||
) (string, error) {
|
||||
var eventStateKey string
|
||||
err := common.TxStmt(txn, s.selectEventStateKeyStmt).QueryRow(eventStateKeyNID).Scan(&eventStateKey)
|
||||
stmt := common.TxStmt(txn, s.selectEventStateKeyStmt)
|
||||
err := stmt.QueryRowContext(ctx, eventStateKeyNID).Scan(&eventStateKey)
|
||||
return eventStateKey, err
|
||||
}
|
||||
|
||||
func (s *eventStateKeyStatements) bulkSelectEventStateKey(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) {
|
||||
func (s *eventStateKeyStatements) bulkSelectEventStateKey(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
var nIDs pq.Int64Array
|
||||
for i := range eventStateKeyNIDs {
|
||||
nIDs[i] = int64(eventStateKeyNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectEventStateKeyStmt.Query(nIDs)
|
||||
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
@ -107,20 +108,26 @@ func (s *eventTypeStatements) prepare(db *sql.DB) (err error) {
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) insertEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
||||
func (s *eventTypeStatements) insertEventTypeNID(
|
||||
ctx context.Context, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
err := s.insertEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
||||
err := s.insertEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) selectEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
||||
func (s *eventTypeStatements) selectEventTypeNID(
|
||||
ctx context.Context, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
var eventTypeNID int64
|
||||
err := s.selectEventTypeNIDStmt.QueryRow(eventType).Scan(&eventTypeNID)
|
||||
err := s.selectEventTypeNIDStmt.QueryRowContext(ctx, eventType).Scan(&eventTypeNID)
|
||||
return types.EventTypeNID(eventTypeNID), err
|
||||
}
|
||||
|
||||
func (s *eventTypeStatements) bulkSelectEventTypeNID(eventTypes []string) (map[string]types.EventTypeNID, error) {
|
||||
rows, err := s.bulkSelectEventTypeNIDStmt.Query(pq.StringArray(eventTypes))
|
||||
func (s *eventTypeStatements) bulkSelectEventTypeNID(
|
||||
ctx context.Context, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
@ -154,7 +155,10 @@ func (s *eventStatements) prepare(db *sql.DB) (err error) {
|
||||
}
|
||||
|
||||
func (s *eventStatements) insertEvent(
|
||||
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID,
|
||||
eventTypeNID types.EventTypeNID,
|
||||
eventStateKeyNID types.EventStateKeyNID,
|
||||
eventID string,
|
||||
referenceSHA256 []byte,
|
||||
authEventNIDs []types.EventNID,
|
||||
@ -162,24 +166,28 @@ func (s *eventStatements) insertEvent(
|
||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||
var eventNID int64
|
||||
var stateNID int64
|
||||
err := s.insertEventStmt.QueryRow(
|
||||
int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID), eventID, referenceSHA256,
|
||||
eventNIDsAsArray(authEventNIDs), depth,
|
||||
err := s.insertEventStmt.QueryRowContext(
|
||||
ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
|
||||
eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
|
||||
).Scan(&eventNID, &stateNID)
|
||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||
}
|
||||
|
||||
func (s *eventStatements) selectEvent(eventID string) (types.EventNID, types.StateSnapshotNID, error) {
|
||||
func (s *eventStatements) selectEvent(
|
||||
ctx context.Context, eventID string,
|
||||
) (types.EventNID, types.StateSnapshotNID, error) {
|
||||
var eventNID int64
|
||||
var stateNID int64
|
||||
err := s.selectEventStmt.QueryRow(eventID).Scan(&eventNID, &stateNID)
|
||||
err := s.selectEventStmt.QueryRowContext(ctx, eventID).Scan(&eventNID, &stateNID)
|
||||
return types.EventNID(eventNID), types.StateSnapshotNID(stateNID), err
|
||||
}
|
||||
|
||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
|
||||
rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs))
|
||||
func (s *eventStatements) bulkSelectStateEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateEntry, error) {
|
||||
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -216,8 +224,10 @@ func (s *eventStatements) bulkSelectStateEventByID(eventIDs []string) ([]types.S
|
||||
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||
func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types.StateAtEvent, error) {
|
||||
rows, err := s.bulkSelectStateAtEventByIDStmt.Query(pq.StringArray(eventIDs))
|
||||
func (s *eventStatements) bulkSelectStateAtEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateAtEvent, error) {
|
||||
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -248,28 +258,40 @@ func (s *eventStatements) bulkSelectStateAtEventByID(eventIDs []string) ([]types
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s *eventStatements) updateEventState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error {
|
||||
_, err := s.updateEventStateStmt.Exec(int64(eventNID), int64(stateNID))
|
||||
func (s *eventStatements) updateEventState(
|
||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||
) error {
|
||||
_, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventStatements) selectEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) {
|
||||
err = common.TxStmt(txn, s.selectEventSentToOutputStmt).QueryRow(int64(eventNID)).Scan(&sentToOutput)
|
||||
func (s *eventStatements) selectEventSentToOutput(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
|
||||
) (sentToOutput bool, err error) {
|
||||
stmt := common.TxStmt(txn, s.selectEventSentToOutputStmt)
|
||||
stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&sentToOutput)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *eventStatements) updateEventSentToOutput(txn *sql.Tx, eventNID types.EventNID) error {
|
||||
_, err := common.TxStmt(txn, s.updateEventSentToOutputStmt).Exec(int64(eventNID))
|
||||
func (s *eventStatements) updateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
|
||||
stmt := common.TxStmt(txn, s.updateEventSentToOutputStmt)
|
||||
_, err := stmt.ExecContext(ctx, int64(eventNID))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *eventStatements) selectEventID(txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) {
|
||||
err = common.TxStmt(txn, s.selectEventIDStmt).QueryRow(int64(eventNID)).Scan(&eventID)
|
||||
func (s *eventStatements) selectEventID(
|
||||
ctx context.Context, txn *sql.Tx, eventNID types.EventNID,
|
||||
) (eventID string, err error) {
|
||||
stmt := common.TxStmt(txn, s.selectEventIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, int64(eventNID)).Scan(&eventID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error) {
|
||||
rows, err := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt).Query(eventNIDsAsArray(eventNIDs))
|
||||
func (s *eventStatements) bulkSelectStateAtEventAndReference(
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
) ([]types.StateAtEventAndReference, error) {
|
||||
stmt := common.TxStmt(txn, s.bulkSelectStateAtEventAndReferenceStmt)
|
||||
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -304,8 +326,10 @@ func (s *eventStatements) bulkSelectStateAtEventAndReference(txn *sql.Tx, eventN
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error) {
|
||||
rows, err := s.bulkSelectEventReferenceStmt.Query(eventNIDsAsArray(eventNIDs))
|
||||
func (s *eventStatements) bulkSelectEventReference(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) ([]gomatrixserverlib.EventReference, error) {
|
||||
rows, err := s.bulkSelectEventReferenceStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -325,8 +349,8 @@ func (s *eventStatements) bulkSelectEventReference(eventNIDs []types.EventNID) (
|
||||
}
|
||||
|
||||
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||
func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
rows, err := s.bulkSelectEventIDStmt.Query(eventNIDsAsArray(eventNIDs))
|
||||
func (s *eventStatements) bulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -349,8 +373,8 @@ func (s *eventStatements) bulkSelectEventID(eventNIDs []types.EventNID) (map[typ
|
||||
|
||||
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||
// If an event ID is not in the database then it is omitted from the map.
|
||||
func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]types.EventNID, error) {
|
||||
rows, err := s.bulkSelectEventNIDStmt.Query(pq.StringArray(eventIDs))
|
||||
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
|
||||
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -367,9 +391,10 @@ func (s *eventStatements) bulkSelectEventNID(eventIDs []string) (map[string]type
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *eventStatements) selectMaxEventDepth(eventNIDs []types.EventNID) (int64, error) {
|
||||
func (s *eventStatements) selectMaxEventDepth(ctx context.Context, eventNIDs []types.EventNID) (int64, error) {
|
||||
var result int64
|
||||
err := s.selectMaxEventDepthStmt.QueryRow(eventNIDsAsArray(eventNIDs)).Scan(&result)
|
||||
stmt := s.selectMaxEventDepthStmt
|
||||
err := stmt.QueryRowContext(ctx, eventNIDsAsArray(eventNIDs)).Scan(&result)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
@ -91,12 +92,13 @@ func (s *inviteStatements) prepare(db *sql.DB) (err error) {
|
||||
}
|
||||
|
||||
func (s *inviteStatements) insertInviteEvent(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||
inviteEventJSON []byte,
|
||||
) (bool, error) {
|
||||
result, err := common.TxStmt(txn, s.insertInviteEventStmt).Exec(
|
||||
inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||
result, err := common.TxStmt(txn, s.insertInviteEventStmt).ExecContext(
|
||||
ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@ -109,9 +111,11 @@ func (s *inviteStatements) insertInviteEvent(
|
||||
}
|
||||
|
||||
func (s *inviteStatements) updateInviteRetired(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) ([]string, error) {
|
||||
rows, err := common.TxStmt(txn, s.updateInviteRetiredStmt).Query(roomNID, targetUserNID)
|
||||
stmt := common.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -129,10 +133,11 @@ func (s *inviteStatements) updateInviteRetired(
|
||||
|
||||
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||
func (s *inviteStatements) selectInviteActiveForUserInRoom(
|
||||
ctx context.Context,
|
||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||
) ([]types.EventStateKeyNID, error) {
|
||||
rows, err := s.selectInviteActiveForUserInRoomStmt.Query(
|
||||
targetUserNID, roomNID,
|
||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||
ctx, targetUserNID, roomNID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
@ -114,34 +115,38 @@ func (s *membershipStatements) prepare(db *sql.DB) (err error) {
|
||||
}
|
||||
|
||||
func (s *membershipStatements) insertMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) error {
|
||||
_, err := common.TxStmt(txn, s.insertMembershipStmt).Exec(roomNID, targetUserNID)
|
||||
stmt := common.TxStmt(txn, s.insertMembershipStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomNID, targetUserNID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) selectMembershipForUpdate(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (membership membershipState, err error) {
|
||||
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRow(
|
||||
roomNID, targetUserNID,
|
||||
err = common.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) selectMembershipFromRoomAndTarget(
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventNID types.EventNID, membership membershipState, err error) {
|
||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRow(
|
||||
roomNID, targetUserNID,
|
||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership, &eventNID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) selectMembershipsFromRoom(
|
||||
roomNID types.RoomNID,
|
||||
ctx context.Context, roomNID types.RoomNID,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
rows, err := s.selectMembershipsFromRoomStmt.Query(roomNID)
|
||||
rows, err := s.selectMembershipsFromRoomStmt.QueryContext(ctx, roomNID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -156,9 +161,11 @@ func (s *membershipStatements) selectMembershipsFromRoom(
|
||||
return
|
||||
}
|
||||
func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID, membership membershipState,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
rows, err := s.selectMembershipsFromRoomAndMembershipStmt.Query(roomNID, membership)
|
||||
stmt := s.selectMembershipsFromRoomAndMembershipStmt
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -174,12 +181,13 @@ func (s *membershipStatements) selectMembershipsFromRoomAndMembership(
|
||||
}
|
||||
|
||||
func (s *membershipStatements) updateMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
senderUserNID types.EventStateKeyNID, membership membershipState,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
_, err := common.TxStmt(txn, s.updateMembershipStmt).Exec(
|
||||
roomNID, targetUserNID, senderUserNID, membership, eventNID,
|
||||
_, err := common.TxStmt(txn, s.updateMembershipStmt).ExecContext(
|
||||
ctx, roomNID, targetUserNID, senderUserNID, membership, eventNID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
@ -73,14 +74,26 @@ func (s *previousEventStatements) prepare(db *sql.DB) (err error) {
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *previousEventStatements) insertPreviousEvent(txn *sql.Tx, previousEventID string, previousEventReferenceSHA256 []byte, eventNID types.EventNID) error {
|
||||
_, err := common.TxStmt(txn, s.insertPreviousEventStmt).Exec(previousEventID, previousEventReferenceSHA256, int64(eventNID))
|
||||
func (s *previousEventStatements) insertPreviousEvent(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
previousEventID string,
|
||||
previousEventReferenceSHA256 []byte,
|
||||
eventNID types.EventNID,
|
||||
) error {
|
||||
stmt := common.TxStmt(txn, s.insertPreviousEventStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if the event reference exists
|
||||
// Returns sql.ErrNoRows if the event reference doesn't exist.
|
||||
func (s *previousEventStatements) selectPreviousEventExists(txn *sql.Tx, eventID string, eventReferenceSHA256 []byte) error {
|
||||
func (s *previousEventStatements) selectPreviousEventExists(
|
||||
ctx context.Context, txn *sql.Tx, eventID string, eventReferenceSHA256 []byte,
|
||||
) error {
|
||||
var ok int64
|
||||
return common.TxStmt(txn, s.selectPreviousEventExistsStmt).QueryRow(eventID, eventReferenceSHA256).Scan(&ok)
|
||||
stmt := common.TxStmt(txn, s.selectPreviousEventExistsStmt)
|
||||
return stmt.QueryRowContext(ctx, eventID, eventReferenceSHA256).Scan(&ok)
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
@ -62,22 +63,28 @@ func (s *roomAliasesStatements) prepare(db *sql.DB) (err error) {
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) insertRoomAlias(alias string, roomID string) (err error) {
|
||||
_, err = s.insertRoomAliasStmt.Exec(alias, roomID)
|
||||
func (s *roomAliasesStatements) insertRoomAlias(
|
||||
ctx context.Context, alias string, roomID string,
|
||||
) (err error) {
|
||||
_, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) selectRoomIDFromAlias(alias string) (roomID string, err error) {
|
||||
err = s.selectRoomIDFromAliasStmt.QueryRow(alias).Scan(&roomID)
|
||||
func (s *roomAliasesStatements) selectRoomIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
) (roomID string, err error) {
|
||||
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases []string, err error) {
|
||||
func (s *roomAliasesStatements) selectAliasesFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
) (aliases []string, err error) {
|
||||
aliases = []string{}
|
||||
rows, err := s.selectAliasesFromRoomIDStmt.Query(roomID)
|
||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -94,7 +101,9 @@ func (s *roomAliasesStatements) selectAliasesFromRoomID(roomID string) (aliases
|
||||
return
|
||||
}
|
||||
|
||||
func (s *roomAliasesStatements) deleteRoomAlias(alias string) (err error) {
|
||||
_, err = s.deleteRoomAliasStmt.Exec(alias)
|
||||
func (s *roomAliasesStatements) deleteRoomAlias(
|
||||
ctx context.Context, alias string,
|
||||
) (err error) {
|
||||
_, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias)
|
||||
return
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
@ -81,22 +82,31 @@ func (s *roomStatements) prepare(db *sql.DB) (err error) {
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) insertRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
||||
func (s *roomStatements) insertRoomNID(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (types.RoomNID, error) {
|
||||
var roomNID int64
|
||||
err := common.TxStmt(txn, s.insertRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
|
||||
stmt := common.TxStmt(txn, s.insertRoomNIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
|
||||
return types.RoomNID(roomNID), err
|
||||
}
|
||||
|
||||
func (s *roomStatements) selectRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
||||
func (s *roomStatements) selectRoomNID(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (types.RoomNID, error) {
|
||||
var roomNID int64
|
||||
err := common.TxStmt(txn, s.selectRoomNIDStmt).QueryRow(roomID).Scan(&roomNID)
|
||||
stmt := common.TxStmt(txn, s.selectRoomNIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID)
|
||||
return types.RoomNID(roomNID), err
|
||||
}
|
||||
|
||||
func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) {
|
||||
func (s *roomStatements) selectLatestEventNIDs(
|
||||
ctx context.Context, roomNID types.RoomNID,
|
||||
) ([]types.EventNID, types.StateSnapshotNID, error) {
|
||||
var nids pq.Int64Array
|
||||
var stateSnapshotNID int64
|
||||
err := s.selectLatestEventNIDsStmt.QueryRow(int64(roomNID)).Scan(&nids, &stateSnapshotNID)
|
||||
stmt := s.selectLatestEventNIDsStmt
|
||||
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -107,13 +117,14 @@ func (s *roomStatements) selectLatestEventNIDs(roomNID types.RoomNID) ([]types.E
|
||||
return eventNIDs, types.StateSnapshotNID(stateSnapshotNID), nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID types.RoomNID) (
|
||||
[]types.EventNID, types.EventNID, types.StateSnapshotNID, error,
|
||||
) {
|
||||
func (s *roomStatements) selectLatestEventsNIDsForUpdate(
|
||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID,
|
||||
) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) {
|
||||
var nids pq.Int64Array
|
||||
var lastEventSentNID int64
|
||||
var stateSnapshotNID int64
|
||||
err := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt).QueryRow(int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
|
||||
stmt := common.TxStmt(txn, s.selectLatestEventNIDsForUpdateStmt)
|
||||
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &lastEventSentNID, &stateSnapshotNID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
@ -125,11 +136,20 @@ func (s *roomStatements) selectLatestEventsNIDsForUpdate(txn *sql.Tx, roomNID ty
|
||||
}
|
||||
|
||||
func (s *roomStatements) updateLatestEventNIDs(
|
||||
txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID,
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomNID types.RoomNID,
|
||||
eventNIDs []types.EventNID,
|
||||
lastEventSentNID types.EventNID,
|
||||
stateSnapshotNID types.StateSnapshotNID,
|
||||
) error {
|
||||
_, err := common.TxStmt(txn, s.updateLatestEventNIDsStmt).Exec(
|
||||
roomNID, eventNIDsAsArray(eventNIDs), int64(lastEventSentNID), int64(stateSnapshotNID),
|
||||
stmt := common.TxStmt(txn, s.updateLatestEventNIDsStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
roomNID,
|
||||
eventNIDsAsArray(eventNIDs),
|
||||
int64(lastEventSentNID),
|
||||
int64(stateSnapshotNID),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
@ -97,9 +98,14 @@ func (s *stateBlockStatements) prepare(db *sql.DB) (err error) {
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBlockNID, entries []types.StateEntry) error {
|
||||
func (s *stateBlockStatements) bulkInsertStateData(
|
||||
ctx context.Context,
|
||||
stateBlockNID types.StateBlockNID,
|
||||
entries []types.StateEntry,
|
||||
) error {
|
||||
for _, entry := range entries {
|
||||
_, err := s.insertStateDataStmt.Exec(
|
||||
_, err := s.insertStateDataStmt.ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
@ -112,18 +118,22 @@ func (s *stateBlockStatements) bulkInsertStateData(stateBlockNID types.StateBloc
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) selectNextStateBlockNID() (types.StateBlockNID, error) {
|
||||
func (s *stateBlockStatements) selectNextStateBlockNID(
|
||||
ctx context.Context,
|
||||
) (types.StateBlockNID, error) {
|
||||
var stateBlockNID int64
|
||||
err := s.selectNextStateBlockNIDStmt.QueryRow().Scan(&stateBlockNID)
|
||||
err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
return types.StateBlockNID(stateBlockNID), err
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
|
||||
func (s *stateBlockStatements) bulkSelectStateBlockEntries(
|
||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||
) ([]types.StateEntryList, error) {
|
||||
nids := make([]int64, len(stateBlockNIDs))
|
||||
for i := range stateBlockNIDs {
|
||||
nids[i] = int64(stateBlockNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectStateBlockEntriesStmt.Query(pq.Int64Array(nids))
|
||||
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -165,15 +175,20 @@ func (s *stateBlockStatements) bulkSelectStateBlockEntries(stateBlockNIDs []type
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) bulkSelectFilteredStateBlockEntries(
|
||||
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple,
|
||||
ctx context.Context,
|
||||
stateBlockNIDs []types.StateBlockNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntryList, error) {
|
||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
||||
sort.Sort(tuples)
|
||||
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.Query(
|
||||
stateBlockNIDsAsArray(stateBlockNIDs), eventTypeNIDArray, eventStateKeyNIDArray,
|
||||
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext(
|
||||
ctx,
|
||||
stateBlockNIDsAsArray(stateBlockNIDs),
|
||||
eventTypeNIDArray,
|
||||
eventStateKeyNIDArray,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -15,29 +15,30 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
func TestStateKeyTupleSorter(t *testing.T) {
|
||||
input := stateKeyTupleSorter{
|
||||
{1, 2},
|
||||
{1, 4},
|
||||
{2, 2},
|
||||
{1, 1},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
}
|
||||
want := []types.StateKeyTuple{
|
||||
{1, 1},
|
||||
{1, 2},
|
||||
{1, 4},
|
||||
{2, 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
}
|
||||
doNotWant := []types.StateKeyTuple{
|
||||
{0, 0},
|
||||
{1, 3},
|
||||
{2, 1},
|
||||
{3, 1},
|
||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
||||
}
|
||||
wantTypeNIDs := []int64{1, 2}
|
||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
@ -74,21 +75,25 @@ func (s *stateSnapshotStatements) prepare(db *sql.DB) (err error) {
|
||||
}.prepare(db)
|
||||
}
|
||||
|
||||
func (s *stateSnapshotStatements) insertState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) {
|
||||
func (s *stateSnapshotStatements) insertState(
|
||||
ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
|
||||
) (stateNID types.StateSnapshotNID, err error) {
|
||||
nids := make([]int64, len(stateBlockNIDs))
|
||||
for i := range stateBlockNIDs {
|
||||
nids[i] = int64(stateBlockNIDs[i])
|
||||
}
|
||||
err = s.insertStateStmt.QueryRow(int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
||||
err = s.insertStateStmt.QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
|
||||
func (s *stateSnapshotStatements) bulkSelectStateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
nids := make([]int64, len(stateNIDs))
|
||||
for i := range stateNIDs {
|
||||
nids[i] = int64(stateNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectStateBlockNIDsStmt.Query(pq.Int64Array(nids))
|
||||
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
// Import the postgres database driver.
|
||||
@ -43,7 +44,9 @@ func Open(dataSourceName string) (*Database, error) {
|
||||
}
|
||||
|
||||
// StoreEvent implements input.EventDatabase
|
||||
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []types.EventNID) (types.RoomNID, types.StateAtEvent, error) {
|
||||
func (d *Database) StoreEvent(
|
||||
ctx context.Context, event gomatrixserverlib.Event, authEventNIDs []types.EventNID,
|
||||
) (types.RoomNID, types.StateAtEvent, error) {
|
||||
var (
|
||||
roomNID types.RoomNID
|
||||
eventTypeNID types.EventTypeNID
|
||||
@ -53,11 +56,11 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
|
||||
err error
|
||||
)
|
||||
|
||||
if roomNID, err = d.assignRoomNID(nil, event.RoomID()); err != nil {
|
||||
if roomNID, err = d.assignRoomNID(ctx, nil, event.RoomID()); err != nil {
|
||||
return 0, types.StateAtEvent{}, err
|
||||
}
|
||||
|
||||
if eventTypeNID, err = d.assignEventTypeNID(event.Type()); err != nil {
|
||||
if eventTypeNID, err = d.assignEventTypeNID(ctx, event.Type()); err != nil {
|
||||
return 0, types.StateAtEvent{}, err
|
||||
}
|
||||
|
||||
@ -65,12 +68,13 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
|
||||
// Assigned a numeric ID for the state_key if there is one present.
|
||||
// Otherwise set the numeric ID for the state_key to 0.
|
||||
if eventStateKey != nil {
|
||||
if eventStateKeyNID, err = d.assignStateKeyNID(nil, *eventStateKey); err != nil {
|
||||
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, nil, *eventStateKey); err != nil {
|
||||
return 0, types.StateAtEvent{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if eventNID, stateNID, err = d.statements.insertEvent(
|
||||
ctx,
|
||||
roomNID,
|
||||
eventTypeNID,
|
||||
eventStateKeyNID,
|
||||
@ -81,14 +85,14 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
|
||||
); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
// We've already inserted the event so select the numeric event ID
|
||||
eventNID, stateNID, err = d.statements.selectEvent(event.EventID())
|
||||
eventNID, stateNID, err = d.statements.selectEvent(ctx, event.EventID())
|
||||
}
|
||||
if err != nil {
|
||||
return 0, types.StateAtEvent{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = d.statements.insertEventJSON(eventNID, event.JSON()); err != nil {
|
||||
if err = d.statements.insertEventJSON(ctx, eventNID, event.JSON()); err != nil {
|
||||
return 0, types.StateAtEvent{}, err
|
||||
}
|
||||
|
||||
@ -104,76 +108,94 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []typ
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) assignRoomNID(txn *sql.Tx, roomID string) (types.RoomNID, error) {
|
||||
func (d *Database) assignRoomNID(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (types.RoomNID, error) {
|
||||
// Check if we already have a numeric ID in the database.
|
||||
roomNID, err := d.statements.selectRoomNID(txn, roomID)
|
||||
roomNID, err := d.statements.selectRoomNID(ctx, txn, roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
// We don't have a numeric ID so insert one into the database.
|
||||
roomNID, err = d.statements.insertRoomNID(txn, roomID)
|
||||
roomNID, err = d.statements.insertRoomNID(ctx, txn, roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
// We raced with another insert so run the select again.
|
||||
roomNID, err = d.statements.selectRoomNID(txn, roomID)
|
||||
roomNID, err = d.statements.selectRoomNID(ctx, txn, roomID)
|
||||
}
|
||||
}
|
||||
return roomNID, err
|
||||
}
|
||||
|
||||
func (d *Database) assignEventTypeNID(eventType string) (types.EventTypeNID, error) {
|
||||
func (d *Database) assignEventTypeNID(
|
||||
ctx context.Context, eventType string,
|
||||
) (types.EventTypeNID, error) {
|
||||
// Check if we already have a numeric ID in the database.
|
||||
eventTypeNID, err := d.statements.selectEventTypeNID(eventType)
|
||||
eventTypeNID, err := d.statements.selectEventTypeNID(ctx, eventType)
|
||||
if err == sql.ErrNoRows {
|
||||
// We don't have a numeric ID so insert one into the database.
|
||||
eventTypeNID, err = d.statements.insertEventTypeNID(eventType)
|
||||
eventTypeNID, err = d.statements.insertEventTypeNID(ctx, eventType)
|
||||
if err == sql.ErrNoRows {
|
||||
// We raced with another insert so run the select again.
|
||||
eventTypeNID, err = d.statements.selectEventTypeNID(eventType)
|
||||
eventTypeNID, err = d.statements.selectEventTypeNID(ctx, eventType)
|
||||
}
|
||||
}
|
||||
return eventTypeNID, err
|
||||
}
|
||||
|
||||
func (d *Database) assignStateKeyNID(txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error) {
|
||||
func (d *Database) assignStateKeyNID(
|
||||
ctx context.Context, txn *sql.Tx, eventStateKey string,
|
||||
) (types.EventStateKeyNID, error) {
|
||||
// Check if we already have a numeric ID in the database.
|
||||
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(txn, eventStateKey)
|
||||
eventStateKeyNID, err := d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
|
||||
if err == sql.ErrNoRows {
|
||||
// We don't have a numeric ID so insert one into the database.
|
||||
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(txn, eventStateKey)
|
||||
eventStateKeyNID, err = d.statements.insertEventStateKeyNID(ctx, txn, eventStateKey)
|
||||
if err == sql.ErrNoRows {
|
||||
// We raced with another insert so run the select again.
|
||||
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(txn, eventStateKey)
|
||||
eventStateKeyNID, err = d.statements.selectEventStateKeyNID(ctx, txn, eventStateKey)
|
||||
}
|
||||
}
|
||||
return eventStateKeyNID, err
|
||||
}
|
||||
|
||||
// StateEntriesForEventIDs implements input.EventDatabase
|
||||
func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) {
|
||||
return d.statements.bulkSelectStateEventByID(eventIDs)
|
||||
func (d *Database) StateEntriesForEventIDs(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateEntry, error) {
|
||||
return d.statements.bulkSelectStateEventByID(ctx, eventIDs)
|
||||
}
|
||||
|
||||
// EventTypeNIDs implements state.RoomStateDatabase
|
||||
func (d *Database) EventTypeNIDs(eventTypes []string) (map[string]types.EventTypeNID, error) {
|
||||
return d.statements.bulkSelectEventTypeNID(eventTypes)
|
||||
func (d *Database) EventTypeNIDs(
|
||||
ctx context.Context, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
return d.statements.bulkSelectEventTypeNID(ctx, eventTypes)
|
||||
}
|
||||
|
||||
// EventStateKeyNIDs implements state.RoomStateDatabase
|
||||
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]types.EventStateKeyNID, error) {
|
||||
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
|
||||
func (d *Database) EventStateKeyNIDs(
|
||||
ctx context.Context, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
return d.statements.bulkSelectEventStateKeyNID(ctx, eventStateKeys)
|
||||
}
|
||||
|
||||
// EventStateKeys implements query.RoomserverQueryAPIDatabase
|
||||
func (d *Database) EventStateKeys(eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error) {
|
||||
return d.statements.bulkSelectEventStateKey(eventStateKeyNIDs)
|
||||
func (d *Database) EventStateKeys(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
return d.statements.bulkSelectEventStateKey(ctx, eventStateKeyNIDs)
|
||||
}
|
||||
|
||||
// EventNIDs implements query.RoomserverQueryAPIDatabase
|
||||
func (d *Database) EventNIDs(eventIDs []string) (map[string]types.EventNID, error) {
|
||||
return d.statements.bulkSelectEventNID(eventIDs)
|
||||
func (d *Database) EventNIDs(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) (map[string]types.EventNID, error) {
|
||||
return d.statements.bulkSelectEventNID(ctx, eventIDs)
|
||||
}
|
||||
|
||||
// Events implements input.EventDatabase
|
||||
func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
|
||||
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)
|
||||
func (d *Database) Events(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) ([]types.Event, error) {
|
||||
eventJSONs, err := d.statements.bulkSelectEventJSON(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -191,78 +213,98 @@ func (d *Database) Events(eventNIDs []types.EventNID) ([]types.Event, error) {
|
||||
}
|
||||
|
||||
// AddState implements input.EventDatabase
|
||||
func (d *Database) AddState(roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) {
|
||||
func (d *Database) AddState(
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID,
|
||||
stateBlockNIDs []types.StateBlockNID,
|
||||
state []types.StateEntry,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
if len(state) > 0 {
|
||||
stateBlockNID, err := d.statements.selectNextStateBlockNID()
|
||||
stateBlockNID, err := d.statements.selectNextStateBlockNID(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err = d.statements.bulkInsertStateData(stateBlockNID, state); err != nil {
|
||||
if err = d.statements.bulkInsertStateData(ctx, stateBlockNID, state); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
||||
}
|
||||
|
||||
return d.statements.insertState(roomNID, stateBlockNIDs)
|
||||
return d.statements.insertState(ctx, roomNID, stateBlockNIDs)
|
||||
}
|
||||
|
||||
// SetState implements input.EventDatabase
|
||||
func (d *Database) SetState(eventNID types.EventNID, stateNID types.StateSnapshotNID) error {
|
||||
return d.statements.updateEventState(eventNID, stateNID)
|
||||
func (d *Database) SetState(
|
||||
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
|
||||
) error {
|
||||
return d.statements.updateEventState(ctx, eventNID, stateNID)
|
||||
}
|
||||
|
||||
// StateAtEventIDs implements input.EventDatabase
|
||||
func (d *Database) StateAtEventIDs(eventIDs []string) ([]types.StateAtEvent, error) {
|
||||
return d.statements.bulkSelectStateAtEventByID(eventIDs)
|
||||
func (d *Database) StateAtEventIDs(
|
||||
ctx context.Context, eventIDs []string,
|
||||
) ([]types.StateAtEvent, error) {
|
||||
return d.statements.bulkSelectStateAtEventByID(ctx, eventIDs)
|
||||
}
|
||||
|
||||
// StateBlockNIDs implements state.RoomStateDatabase
|
||||
func (d *Database) StateBlockNIDs(stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) {
|
||||
return d.statements.bulkSelectStateBlockNIDs(stateNIDs)
|
||||
func (d *Database) StateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
return d.statements.bulkSelectStateBlockNIDs(ctx, stateNIDs)
|
||||
}
|
||||
|
||||
// StateEntries implements state.RoomStateDatabase
|
||||
func (d *Database) StateEntries(stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) {
|
||||
return d.statements.bulkSelectStateBlockEntries(stateBlockNIDs)
|
||||
func (d *Database) StateEntries(
|
||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||
) ([]types.StateEntryList, error) {
|
||||
return d.statements.bulkSelectStateBlockEntries(ctx, stateBlockNIDs)
|
||||
}
|
||||
|
||||
// SnapshotNIDFromEventID implements state.RoomStateDatabase
|
||||
func (d *Database) SnapshotNIDFromEventID(eventID string) (types.StateSnapshotNID, error) {
|
||||
_, stateNID, err := d.statements.selectEvent(eventID)
|
||||
func (d *Database) SnapshotNIDFromEventID(
|
||||
ctx context.Context, eventID string,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
_, stateNID, err := d.statements.selectEvent(ctx, eventID)
|
||||
return stateNID, err
|
||||
}
|
||||
|
||||
// EventIDs implements input.RoomEventDatabase
|
||||
func (d *Database) EventIDs(eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
return d.statements.bulkSelectEventID(eventNIDs)
|
||||
func (d *Database) EventIDs(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) (map[types.EventNID]string, error) {
|
||||
return d.statements.bulkSelectEventID(ctx, eventNIDs)
|
||||
}
|
||||
|
||||
// GetLatestEventsForUpdate implements input.EventDatabase
|
||||
func (d *Database) GetLatestEventsForUpdate(roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) {
|
||||
func (d *Database) GetLatestEventsForUpdate(
|
||||
ctx context.Context, roomNID types.RoomNID,
|
||||
) (types.RoomRecentEventsUpdater, error) {
|
||||
txn, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.statements.selectLatestEventsNIDsForUpdate(txn, roomNID)
|
||||
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
|
||||
d.statements.selectLatestEventsNIDsForUpdate(ctx, txn, roomNID)
|
||||
if err != nil {
|
||||
txn.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(txn, eventNIDs)
|
||||
stateAndRefs, err := d.statements.bulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
|
||||
if err != nil {
|
||||
txn.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
var lastEventIDSent string
|
||||
if lastEventNIDSent != 0 {
|
||||
lastEventIDSent, err = d.statements.selectEventID(txn, lastEventNIDSent)
|
||||
lastEventIDSent, err = d.statements.selectEventID(ctx, txn, lastEventNIDSent)
|
||||
if err != nil {
|
||||
txn.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &roomRecentEventsUpdater{
|
||||
transaction{txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
||||
transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -293,7 +335,7 @@ func (u *roomRecentEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotN
|
||||
// StorePreviousEvents implements types.RoomRecentEventsUpdater
|
||||
func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
|
||||
for _, ref := range previousEventReferences {
|
||||
if err := u.d.statements.insertPreviousEvent(u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||
if err := u.d.statements.insertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -302,7 +344,7 @@ func (u *roomRecentEventsUpdater) StorePreviousEvents(eventNID types.EventNID, p
|
||||
|
||||
// IsReferenced implements types.RoomRecentEventsUpdater
|
||||
func (u *roomRecentEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
|
||||
err := u.d.statements.selectPreviousEventExists(u.txn, eventReference.EventID, eventReference.EventSHA256)
|
||||
err := u.d.statements.selectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
@ -321,26 +363,26 @@ func (u *roomRecentEventsUpdater) SetLatestEvents(
|
||||
for i := range latest {
|
||||
eventNIDs[i] = latest[i].EventNID
|
||||
}
|
||||
return u.d.statements.updateLatestEventNIDs(u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
||||
return u.d.statements.updateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
|
||||
}
|
||||
|
||||
// HasEventBeenSent implements types.RoomRecentEventsUpdater
|
||||
func (u *roomRecentEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
|
||||
return u.d.statements.selectEventSentToOutput(u.txn, eventNID)
|
||||
return u.d.statements.selectEventSentToOutput(u.ctx, u.txn, eventNID)
|
||||
}
|
||||
|
||||
// MarkEventAsSent implements types.RoomRecentEventsUpdater
|
||||
func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
||||
return u.d.statements.updateEventSentToOutput(u.txn, eventNID)
|
||||
return u.d.statements.updateEventSentToOutput(u.ctx, u.txn, eventNID)
|
||||
}
|
||||
|
||||
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (types.MembershipUpdater, error) {
|
||||
return u.d.membershipUpdaterTxn(u.txn, u.roomNID, targetUserNID)
|
||||
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomNID, targetUserNID)
|
||||
}
|
||||
|
||||
// RoomNID implements query.RoomserverQueryAPIDB
|
||||
func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
|
||||
roomNID, err := d.statements.selectRoomNID(nil, roomID)
|
||||
func (d *Database) RoomNID(ctx context.Context, roomID string) (types.RoomNID, error) {
|
||||
roomNID, err := d.statements.selectRoomNID(ctx, nil, roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, nil
|
||||
}
|
||||
@ -348,16 +390,18 @@ func (d *Database) RoomNID(roomID string) (types.RoomNID, error) {
|
||||
}
|
||||
|
||||
// LatestEventIDs implements query.RoomserverQueryAPIDatabase
|
||||
func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
|
||||
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(roomNID)
|
||||
func (d *Database) LatestEventIDs(
|
||||
ctx context.Context, roomNID types.RoomNID,
|
||||
) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error) {
|
||||
eventNIDs, currentStateSnapshotNID, err := d.statements.selectLatestEventNIDs(ctx, roomNID)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
references, err := d.statements.bulkSelectEventReference(eventNIDs)
|
||||
references, err := d.statements.bulkSelectEventReference(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
depth, err := d.statements.selectMaxEventDepth(eventNIDs)
|
||||
depth, err := d.statements.selectMaxEventDepth(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
@ -366,40 +410,48 @@ func (d *Database) LatestEventIDs(roomNID types.RoomNID) ([]gomatrixserverlib.Ev
|
||||
|
||||
// GetInvitesForUser implements query.RoomserverQueryAPIDatabase
|
||||
func (d *Database) GetInvitesForUser(
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context,
|
||||
roomNID types.RoomNID,
|
||||
targetUserNID types.EventStateKeyNID,
|
||||
) (senderUserIDs []types.EventStateKeyNID, err error) {
|
||||
return d.statements.selectInviteActiveForUserInRoom(targetUserNID, roomNID)
|
||||
return d.statements.selectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
|
||||
}
|
||||
|
||||
// SetRoomAlias implements alias.RoomserverAliasAPIDB
|
||||
func (d *Database) SetRoomAlias(alias string, roomID string) error {
|
||||
return d.statements.insertRoomAlias(alias, roomID)
|
||||
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string) error {
|
||||
return d.statements.insertRoomAlias(ctx, alias, roomID)
|
||||
}
|
||||
|
||||
// GetRoomIDFromAlias implements alias.RoomserverAliasAPIDB
|
||||
func (d *Database) GetRoomIDFromAlias(alias string) (string, error) {
|
||||
return d.statements.selectRoomIDFromAlias(alias)
|
||||
func (d *Database) GetRoomIDFromAlias(ctx context.Context, alias string) (string, error) {
|
||||
return d.statements.selectRoomIDFromAlias(ctx, alias)
|
||||
}
|
||||
|
||||
// GetAliasesFromRoomID implements alias.RoomserverAliasAPIDB
|
||||
func (d *Database) GetAliasesFromRoomID(roomID string) ([]string, error) {
|
||||
return d.statements.selectAliasesFromRoomID(roomID)
|
||||
func (d *Database) GetAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) {
|
||||
return d.statements.selectAliasesFromRoomID(ctx, roomID)
|
||||
}
|
||||
|
||||
// RemoveRoomAlias implements alias.RoomserverAliasAPIDB
|
||||
func (d *Database) RemoveRoomAlias(alias string) error {
|
||||
return d.statements.deleteRoomAlias(alias)
|
||||
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
|
||||
return d.statements.deleteRoomAlias(ctx, alias)
|
||||
}
|
||||
|
||||
// StateEntriesForTuples implements state.RoomStateDatabase
|
||||
func (d *Database) StateEntriesForTuples(
|
||||
stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple,
|
||||
ctx context.Context,
|
||||
stateBlockNIDs []types.StateBlockNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntryList, error) {
|
||||
return d.statements.bulkSelectFilteredStateBlockEntries(stateBlockNIDs, stateKeyTuples)
|
||||
return d.statements.bulkSelectFilteredStateBlockEntries(
|
||||
ctx, stateBlockNIDs, stateKeyTuples,
|
||||
)
|
||||
}
|
||||
|
||||
// MembershipUpdater implements input.RoomEventDatabase
|
||||
func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.MembershipUpdater, error) {
|
||||
func (d *Database) MembershipUpdater(
|
||||
ctx context.Context, roomID, targetUserID string,
|
||||
) (types.MembershipUpdater, error) {
|
||||
txn, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -411,17 +463,17 @@ func (d *Database) MembershipUpdater(roomID, targetUserID string) (types.Members
|
||||
}
|
||||
}()
|
||||
|
||||
roomNID, err := d.assignRoomNID(txn, roomID)
|
||||
roomNID, err := d.assignRoomNID(ctx, txn, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUserNID, err := d.assignStateKeyNID(txn, targetUserID)
|
||||
targetUserNID, err := d.assignStateKeyNID(ctx, txn, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updater, err := d.membershipUpdaterTxn(txn, roomNID, targetUserNID)
|
||||
updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -439,20 +491,23 @@ type membershipUpdater struct {
|
||||
}
|
||||
|
||||
func (d *Database) membershipUpdaterTxn(
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomNID types.RoomNID,
|
||||
targetUserNID types.EventStateKeyNID,
|
||||
) (types.MembershipUpdater, error) {
|
||||
|
||||
if err := d.statements.insertMembership(txn, roomNID, targetUserNID); err != nil {
|
||||
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
membership, err := d.statements.selectMembershipForUpdate(txn, roomNID, targetUserNID)
|
||||
membership, err := d.statements.selectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &membershipUpdater{
|
||||
transaction{txn}, d, roomNID, targetUserNID, membership,
|
||||
transaction{ctx, txn}, d, roomNID, targetUserNID, membership,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -473,19 +528,19 @@ func (u *membershipUpdater) IsLeave() bool {
|
||||
|
||||
// SetToInvite implements types.MembershipUpdater
|
||||
func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.txn, event.Sender())
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
inserted, err := u.d.statements.insertInviteEvent(
|
||||
u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
||||
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if u.membership != membershipStateInvite {
|
||||
if err = u.d.statements.updateMembership(
|
||||
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateInvite, 0,
|
||||
); err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -497,7 +552,7 @@ func (u *membershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, er
|
||||
func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
|
||||
var inviteEventIDs []string
|
||||
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID)
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -505,7 +560,7 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||
// If this is a join event update, there is no invite to update
|
||||
if !isUpdate {
|
||||
inviteEventIDs, err = u.d.statements.updateInviteRetired(
|
||||
u.txn, u.roomNID, u.targetUserNID,
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -513,14 +568,15 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||
}
|
||||
|
||||
// Look up the NID of the new join event
|
||||
nIDs, err := u.d.EventNIDs([]string{eventID})
|
||||
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if u.membership != membershipStateJoin || isUpdate {
|
||||
if err = u.d.statements.updateMembership(
|
||||
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateJoin, nIDs[eventID],
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||
membershipStateJoin, nIDs[eventID],
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -531,26 +587,27 @@ func (u *membershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
|
||||
|
||||
// SetToLeave implements types.MembershipUpdater
|
||||
func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.txn, senderUserID)
|
||||
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inviteEventIDs, err := u.d.statements.updateInviteRetired(
|
||||
u.txn, u.roomNID, u.targetUserNID,
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Look up the NID of the new leave event
|
||||
nIDs, err := u.d.EventNIDs([]string{eventID})
|
||||
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if u.membership != membershipStateLeaveOrBan {
|
||||
if err = u.d.statements.updateMembership(
|
||||
u.txn, u.roomNID, u.targetUserNID, senderUserNID, membershipStateLeaveOrBan, nIDs[eventID],
|
||||
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
|
||||
membershipStateLeaveOrBan, nIDs[eventID],
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -559,19 +616,18 @@ func (u *membershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
|
||||
}
|
||||
|
||||
// GetMembership implements query.RoomserverQueryAPIDB
|
||||
func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID string) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
|
||||
txn, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer txn.Commit()
|
||||
|
||||
requestSenderUserNID, err := d.assignStateKeyNID(txn, requestSenderUserID)
|
||||
func (d *Database) GetMembership(
|
||||
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
|
||||
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
|
||||
requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
senderMembershipEventNID, senderMembership, err := d.statements.selectMembershipFromRoomAndTarget(roomNID, requestSenderUserNID)
|
||||
senderMembershipEventNID, senderMembership, err :=
|
||||
d.statements.selectMembershipFromRoomAndTarget(
|
||||
ctx, roomNID, requestSenderUserNID,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
// The user has never been a member of that room
|
||||
return 0, false, nil
|
||||
@ -583,15 +639,20 @@ func (d *Database) GetMembership(roomNID types.RoomNID, requestSenderUserID stri
|
||||
}
|
||||
|
||||
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
|
||||
func (d *Database) GetMembershipEventNIDsForRoom(roomNID types.RoomNID, joinOnly bool) ([]types.EventNID, error) {
|
||||
func (d *Database) GetMembershipEventNIDsForRoom(
|
||||
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
|
||||
) ([]types.EventNID, error) {
|
||||
if joinOnly {
|
||||
return d.statements.selectMembershipsFromRoomAndMembership(roomNID, membershipStateJoin)
|
||||
return d.statements.selectMembershipsFromRoomAndMembership(
|
||||
ctx, roomNID, membershipStateJoin,
|
||||
)
|
||||
}
|
||||
|
||||
return d.statements.selectMembershipsFromRoom(roomNID)
|
||||
return d.statements.selectMembershipsFromRoom(ctx, roomNID)
|
||||
}
|
||||
|
||||
type transaction struct {
|
||||
ctx context.Context
|
||||
txn *sql.Tx
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user