mirror of
https://github.com/1f349/dendrite.git
synced 2024-12-23 08:44:11 +00:00
Check that events pass authentication checks. (#4)
* Check that events pass authentication checks. Record the list of events that the event passes authentication checks against.
This commit is contained in:
parent
600f56b4b8
commit
fc4eb85379
@ -30,7 +30,13 @@ type InputRoomEvent struct {
|
||||
Kind int
|
||||
// The event JSON for the event to add.
|
||||
Event []byte
|
||||
// List of state event IDs that authenticate this event.
|
||||
// These are likely derived from the "auth_events" JSON key of the event.
|
||||
// But can be different because the "auth_events" key can be incomplete or wrong.
|
||||
// For example many matrix events forget to reference the m.room.create event even though it is needed for auth.
|
||||
// (since synapse allows this to happen we have to allow it as well.)
|
||||
AuthEventIDs []string
|
||||
// Optional list of state event IDs forming the state before this event.
|
||||
// These state events must have already been persisted.
|
||||
State []string
|
||||
StateEventIDs []string
|
||||
}
|
||||
|
@ -2,12 +2,26 @@ package input
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// A RoomEventDatabase has the storage APIs needed to store a room event.
|
||||
type RoomEventDatabase interface {
|
||||
StoreEvent(event gomatrixserverlib.Event) error
|
||||
// Stores a matrix room event in the database
|
||||
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error
|
||||
// Lookup the state entries for a list of string event IDs
|
||||
// Returns a sorted list of state entries.
|
||||
// Returns a error if the there is an error talking to the database
|
||||
// or if the event IDs aren't in the database.
|
||||
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
|
||||
// Lookup the numeric IDs for a list of string event state keys.
|
||||
// Returns a map from string state key to numeric ID for the state key.
|
||||
EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error)
|
||||
// Lookup the Events for a list of numeric event IDs.
|
||||
// Returns a sorted list of events.
|
||||
Events(eventNIDs []int64) ([]types.Event, error)
|
||||
}
|
||||
|
||||
func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||
@ -17,12 +31,16 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := db.StoreEvent(event); err != nil {
|
||||
// Check that the event passes authentication checks and work out the numeric IDs for the auth events.
|
||||
authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// * Check that the event passes authentication checks.
|
||||
// Store the event
|
||||
if err := db.StoreEvent(event, authEventNIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if input.Kind == api.KindOutlier {
|
||||
// For outliers we can stop after we've stored the event itself as it
|
||||
@ -44,3 +62,193 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
|
||||
// - The changes to the current state of the room.
|
||||
panic("Not implemented")
|
||||
}
|
||||
|
||||
// checkAuthEvents checks that the event passes authentication checks
|
||||
// Returns the numeric IDs for the auth events.
|
||||
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) {
|
||||
// Grab the numeric IDs for the supplied auth state events from the database.
|
||||
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: check for duplicate state keys here.
|
||||
|
||||
// Work out which of the state events we actually need.
|
||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
|
||||
|
||||
// Load the actual auth events from the database.
|
||||
authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the event is allowed.
|
||||
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the numeric IDs for the auth events.
|
||||
result := make([]int64, len(authStateEntries))
|
||||
for i := range authStateEntries {
|
||||
result[i] = authStateEntries[i].EventNID
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type authEvents struct {
|
||||
stateKeyNIDMap map[string]int64
|
||||
state stateEntryMap
|
||||
events eventMap
|
||||
}
|
||||
|
||||
// Create implements gomatrixserverlib.AuthEvents
|
||||
func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
|
||||
}
|
||||
|
||||
// PowerLevels implements gomatrixserverlib.AuthEvents
|
||||
func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
|
||||
}
|
||||
|
||||
// JoinRules implements gomatrixserverlib.AuthEvents
|
||||
func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
|
||||
}
|
||||
|
||||
// Memmber implements gomatrixserverlib.AuthEvents
|
||||
func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
|
||||
}
|
||||
|
||||
// ThirdPartyInvite implements gomatrixserverlib.AuthEvents
|
||||
func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
|
||||
}
|
||||
|
||||
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserverlib.Event {
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
event, ok := ae.events.lookup(eventNID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &event.Event
|
||||
}
|
||||
|
||||
func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserverlib.Event {
|
||||
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
event, ok := ae.events.lookup(eventNID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &event.Event
|
||||
}
|
||||
|
||||
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
||||
func loadAuthEvents(
|
||||
db RoomEventDatabase,
|
||||
needed gomatrixserverlib.StateNeeded,
|
||||
state []types.StateEntry,
|
||||
) (result authEvents, err error) {
|
||||
// Lookup the numeric IDs for the state keys needed for auth.
|
||||
var neededStateKeys []string
|
||||
neededStateKeys = append(neededStateKeys, needed.Member...)
|
||||
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
||||
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Load the events we need.
|
||||
result.state = state
|
||||
var eventNIDs []int64
|
||||
keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
|
||||
for _, keyTuple := range keyTuplesNeeded {
|
||||
eventNID, ok := result.state.lookup(keyTuple)
|
||||
if ok {
|
||||
eventNIDs = append(eventNIDs, eventNID)
|
||||
}
|
||||
}
|
||||
if result.events, err = db.Events(eventNIDs); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
|
||||
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
|
||||
var keyTuples []types.StateKeyTuple
|
||||
if stateNeeded.Create {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
|
||||
}
|
||||
if stateNeeded.PowerLevels {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
|
||||
}
|
||||
if stateNeeded.JoinRules {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
|
||||
}
|
||||
for _, member := range stateNeeded.Member {
|
||||
stateKeyNID, ok := stateKeyNIDMap[member]
|
||||
if ok {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
|
||||
}
|
||||
}
|
||||
for _, token := range stateNeeded.ThirdPartyInvite {
|
||||
stateKeyNID, ok := stateKeyNIDMap[token]
|
||||
if ok {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
|
||||
}
|
||||
}
|
||||
return keyTuples
|
||||
}
|
||||
|
||||
// Map from event type, state key tuple to numeric event ID.
|
||||
// Implemented using binary search on a sorted array.
|
||||
type stateEntryMap []types.StateEntry
|
||||
|
||||
// lookup an entry in the event map.
|
||||
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok bool) {
|
||||
// Since the list is sorted we can implement this using binary search.
|
||||
// This is faster than using a hash map.
|
||||
// We don't have to worry about pathological cases because the keys are fixed
|
||||
// size and are controlled by us.
|
||||
list := []types.StateEntry(m)
|
||||
i := sort.Search(len(list), func(i int) bool {
|
||||
return !list[i].StateKeyTuple.LessThan(stateKey)
|
||||
})
|
||||
if i < len(list) && list[i].StateKeyTuple == stateKey {
|
||||
ok = true
|
||||
eventNID = list[i].EventNID
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Map from numeric event ID to event.
|
||||
// Implemented using binary search on a sorted array.
|
||||
type eventMap []types.Event
|
||||
|
||||
// lookup an entry in the event map.
|
||||
func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) {
|
||||
// Since the list is sorted we can implement this using binary search.
|
||||
// This is faster than using a hash map.
|
||||
// We don't have to worry about pathological cases because the keys are fixed
|
||||
// size are controlled by us.
|
||||
list := []types.Event(m)
|
||||
i := sort.Search(len(list), func(i int) bool {
|
||||
return list[i].EventNID >= eventNID
|
||||
})
|
||||
if i < len(list) && list[i].EventNID == eventNID {
|
||||
ok = true
|
||||
event = &list[i]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -0,0 +1,112 @@
|
||||
package input
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
|
||||
var list []types.StateEntry
|
||||
for i := int64(0); i < entries; i++ {
|
||||
list = append(list, types.StateEntry{types.StateKeyTuple{i, i}, i})
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
entryMap := stateEntryMap(list)
|
||||
for j := int64(0); j < lookups; j++ {
|
||||
entryMap.lookup(types.StateKeyTuple{j, j})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStateEntryMap100Lookup10(b *testing.B) {
|
||||
benchmarkStateEntryMapLookup(100, 10, b)
|
||||
}
|
||||
|
||||
func BenchmarkStateEntryMap1000Lookup100(b *testing.B) {
|
||||
benchmarkStateEntryMapLookup(1000, 100, b)
|
||||
}
|
||||
|
||||
func BenchmarkStateEntryMap100Lookup100(b *testing.B) {
|
||||
benchmarkStateEntryMapLookup(100, 100, b)
|
||||
}
|
||||
|
||||
func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) {
|
||||
benchmarkStateEntryMapLookup(1000, 10000, b)
|
||||
}
|
||||
|
||||
func TestStateEntryMap(t *testing.T) {
|
||||
entryMap := stateEntryMap([]types.StateEntry{
|
||||
{types.StateKeyTuple{1, 1}, 1},
|
||||
{types.StateKeyTuple{1, 3}, 2},
|
||||
{types.StateKeyTuple{2, 1}, 3},
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
inputTypeNID int64
|
||||
inputStateKey int64
|
||||
wantOK bool
|
||||
wantEventNID int64
|
||||
}{
|
||||
// Check that tuples that in the array are in the map.
|
||||
{1, 1, true, 1},
|
||||
{1, 3, true, 2},
|
||||
{2, 1, true, 3},
|
||||
// Check that tuples that aren't in the array aren't in the map.
|
||||
{0, 0, false, 0},
|
||||
{1, 2, false, 0},
|
||||
{3, 1, false, 0},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
keyTuple := types.StateKeyTuple{testCase.inputTypeNID, testCase.inputStateKey}
|
||||
gotEventNID, gotOK := entryMap.lookup(keyTuple)
|
||||
if testCase.wantOK != gotOK {
|
||||
t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK)
|
||||
}
|
||||
if testCase.wantEventNID != gotEventNID {
|
||||
t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMap(t *testing.T) {
|
||||
events := eventMap([]types.Event{
|
||||
{EventNID: 1},
|
||||
{EventNID: 2},
|
||||
{EventNID: 3},
|
||||
{EventNID: 5},
|
||||
{EventNID: 8},
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
inputEventNID int64
|
||||
wantOK bool
|
||||
wantEvent *types.Event
|
||||
}{
|
||||
// Check that the IDs that are in the array are in the map.
|
||||
{1, true, &events[0]},
|
||||
{2, true, &events[1]},
|
||||
{3, true, &events[2]},
|
||||
{5, true, &events[3]},
|
||||
{8, true, &events[4]},
|
||||
// Check that tuples that aren't in the array aren't in the map.
|
||||
{0, false, nil},
|
||||
{4, false, nil},
|
||||
{6, false, nil},
|
||||
{7, false, nil},
|
||||
{9, false, nil},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
gotEvent, gotOK := events.lookup(testCase.inputEventNID)
|
||||
if testCase.wantOK != gotOK {
|
||||
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
|
||||
}
|
||||
|
||||
if testCase.wantEvent != gotEvent {
|
||||
t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@ -2,20 +2,25 @@ package storage
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
type statements struct {
|
||||
selectPartitionOffsetsStmt *sql.Stmt
|
||||
upsertPartitionOffsetStmt *sql.Stmt
|
||||
insertEventTypeNIDStmt *sql.Stmt
|
||||
selectEventTypeNIDStmt *sql.Stmt
|
||||
insertEventStateKeyNIDStmt *sql.Stmt
|
||||
selectEventStateKeyNIDStmt *sql.Stmt
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
selectRoomNIDStmt *sql.Stmt
|
||||
insertEventStmt *sql.Stmt
|
||||
insertEventJSONStmt *sql.Stmt
|
||||
selectPartitionOffsetsStmt *sql.Stmt
|
||||
upsertPartitionOffsetStmt *sql.Stmt
|
||||
insertEventTypeNIDStmt *sql.Stmt
|
||||
selectEventTypeNIDStmt *sql.Stmt
|
||||
insertEventStateKeyNIDStmt *sql.Stmt
|
||||
selectEventStateKeyNIDStmt *sql.Stmt
|
||||
bulkSelectEventStateKeyNIDStmt *sql.Stmt
|
||||
insertRoomNIDStmt *sql.Stmt
|
||||
selectRoomNIDStmt *sql.Stmt
|
||||
insertEventStmt *sql.Stmt
|
||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||
insertEventJSONStmt *sql.Stmt
|
||||
bulkSelectEventJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *statements) prepare(db *sql.DB) error {
|
||||
@ -196,6 +201,9 @@ func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) {
|
||||
if s.selectEventStateKeyNIDStmt, err = db.Prepare(selectEventStateKeyNIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.bulkSelectEventStateKeyNIDStmt, err = db.Prepare(bulkSelectEventStateKeyNIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -230,6 +238,12 @@ const insertEventStateKeyNIDSQL = "" +
|
||||
const selectEventStateKeyNIDSQL = "" +
|
||||
"SELECT event_state_key_nid FROM event_state_keys WHERE event_state_key = $1"
|
||||
|
||||
// Bulk lookup from string state key to numeric ID for that state key.
|
||||
// Takes an array of strings as the query parameter.
|
||||
const bulkSelectEventStateKeyNIDSQL = "" +
|
||||
"SELECT event_state_key, event_state_key_nid FROM event_state_keys" +
|
||||
" WHERE event_state_key = ANY($1)"
|
||||
|
||||
func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) {
|
||||
err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
|
||||
return
|
||||
@ -240,6 +254,25 @@ func (s *statements) selectEventStateKeyNID(eventStateKey string) (eventStateKey
|
||||
return
|
||||
}
|
||||
|
||||
func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]int64, error) {
|
||||
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result := make(map[string]int64, len(eventStateKeys))
|
||||
for rows.Next() {
|
||||
var stateKey string
|
||||
var stateKeyNID int64
|
||||
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[stateKey] = stateKeyNID
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *statements) prepareRooms(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(roomsSchema)
|
||||
if err != nil {
|
||||
@ -307,17 +340,27 @@ CREATE TABLE IF NOT EXISTS events (
|
||||
event_id TEXT NOT NULL CONSTRAINT event_id_unique UNIQUE,
|
||||
-- The sha256 reference hash for the event.
|
||||
-- Needed for setting reference hashes when sending new events.
|
||||
reference_sha256 BYTEA NOT NULL
|
||||
reference_sha256 BYTEA NOT NULL,
|
||||
-- A list of numeric IDs for events that can authenticate this event.
|
||||
auth_event_nids BIGINT[] NOT NULL,
|
||||
);
|
||||
`
|
||||
|
||||
const insertEventSQL = "" +
|
||||
"INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)" +
|
||||
"INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT ON CONSTRAINT event_id_unique" +
|
||||
" DO UPDATE SET event_id = $1" +
|
||||
" RETURNING event_nid"
|
||||
|
||||
// Bulk lookup of events by string ID.
|
||||
// Sort by the numeric IDs for event type and state key.
|
||||
// This means we can use binary search to lookup entries by type and state key.
|
||||
const bulkSelectStateEventByIDSQL = "" +
|
||||
"SELECT event_type_nid, event_state_key_nid, event_nid FROM events" +
|
||||
" WHERE event_id = ANY($1)" +
|
||||
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||
|
||||
func (s *statements) prepareEvents(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(eventsSchema)
|
||||
if err != nil {
|
||||
@ -326,6 +369,9 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) {
|
||||
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -333,13 +379,48 @@ func (s *statements) insertEvent(
|
||||
roomNID, eventTypeNID, eventStateKeyNID int64,
|
||||
eventID string,
|
||||
referenceSHA256 []byte,
|
||||
authEventNIDs []int64,
|
||||
) (eventNID int64, err error) {
|
||||
err = s.insertEventStmt.QueryRow(
|
||||
roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256,
|
||||
pq.Int64Array(authEventNIDs),
|
||||
).Scan(&eventNID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
|
||||
rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
// We know that we will only get as many results as event IDs
|
||||
// because of the unique constraint on event IDs.
|
||||
// So we can allocate an array of the correct size now.
|
||||
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
||||
results := make([]types.StateEntry, len(eventIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
if err = rows.Scan(
|
||||
&result.EventNID,
|
||||
&result.EventTypeNID,
|
||||
&result.EventStateKeyNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if i != len(eventIDs) {
|
||||
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
||||
// If this turns out to be impossible and we do need the debug information here, it would be better
|
||||
// to do it as a separate query rather than slowing down/complicating the common case.
|
||||
return nil, fmt.Errorf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs))
|
||||
}
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s *statements) prepareEventJSON(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(eventJSONSchema)
|
||||
if err != nil {
|
||||
@ -348,6 +429,9 @@ func (s *statements) prepareEventJSON(db *sql.DB) (err error) {
|
||||
if s.insertEventJSONStmt, err = db.Prepare(insertEventJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.bulkSelectEventJSONStmt, err = db.Prepare(bulkSelectEventJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -372,7 +456,41 @@ const insertEventJSONSQL = "" +
|
||||
"INSERT INTO event_json (event_nid, event_json) VALUES ($1, $2)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
// Bulk event JSON lookup by numeric event ID.
|
||||
// Sort by the numeric event ID.
|
||||
// This means that we can use binary search to lookup by numeric event ID.
|
||||
const bulkSelectEventJSONSQL = "" +
|
||||
"SELECT event_nid, event_json FROM event_json" +
|
||||
" WHERE event_nid = ANY($1)" +
|
||||
" ORDER BY event_nid ASC"
|
||||
|
||||
func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error {
|
||||
_, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON)
|
||||
return err
|
||||
}
|
||||
|
||||
type eventJSONPair struct {
|
||||
EventNID int64
|
||||
EventJSON []byte
|
||||
}
|
||||
|
||||
func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, error) {
|
||||
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(eventNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// We know that we will only get as many results as event NIDs
|
||||
// because of the unique constraint on event NIDs.
|
||||
// So we can allocate an array of the correct size now.
|
||||
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
|
||||
results := make([]eventJSONPair, len(eventNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
if err := rows.Scan(&results[i].EventNID, &results[i].EventJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return results[:i], nil
|
||||
}
|
||||
|
@ -38,7 +38,7 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
|
||||
}
|
||||
|
||||
// StoreEvent implements input.EventDatabase
|
||||
func (d *Database) StoreEvent(event gomatrixserverlib.Event) error {
|
||||
func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error {
|
||||
var (
|
||||
roomNID int64
|
||||
eventTypeNID int64
|
||||
@ -70,6 +70,7 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event) error {
|
||||
eventStateKeyNID,
|
||||
event.EventID(),
|
||||
event.EventReference().EventSHA256,
|
||||
authEventNIDs,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -115,3 +116,32 @@ func (d *Database) assignStateKeyNID(eventStateKey string) (int64, error) {
|
||||
}
|
||||
return eventStateKeyNID, nil
|
||||
}
|
||||
|
||||
// StateEntriesForEventIDs implements input.EventDatabase
|
||||
func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) {
|
||||
return d.statements.bulkSelectStateEventByID(eventIDs)
|
||||
}
|
||||
|
||||
// EventStateKeyNIDs implements input.EventDatabase
|
||||
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) {
|
||||
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
|
||||
}
|
||||
|
||||
// Events implements input.EventDatabase
|
||||
func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) {
|
||||
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results := make([]types.Event, len(eventJSONs))
|
||||
for i, eventJSON := range eventJSONs {
|
||||
result := &results[i]
|
||||
result.EventNID = eventJSON.EventNID
|
||||
// TODO: Use NewEventFromTrustedJSON for efficiency
|
||||
result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
@ -1,6 +1,10 @@
|
||||
// Package types provides the types that are used internally within the roomserver.
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A PartitionOffset is the offset into a partition of the input log.
|
||||
type PartitionOffset struct {
|
||||
// The ID of the partition.
|
||||
@ -8,3 +12,66 @@ type PartitionOffset struct {
|
||||
// The offset into the partition.
|
||||
Offset int64
|
||||
}
|
||||
|
||||
// A StateKeyTuple is a pair of a numeric event type and a numeric state key.
|
||||
// It is used to lookup state entries.
|
||||
type StateKeyTuple struct {
|
||||
// The numeric ID for the event type.
|
||||
EventTypeNID int64
|
||||
// The numeric ID for the state key.
|
||||
EventStateKeyNID int64
|
||||
}
|
||||
|
||||
// LessThan returns true if this state key is less than the other state key.
|
||||
// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
|
||||
func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
|
||||
if a.EventTypeNID != b.EventTypeNID {
|
||||
return a.EventTypeNID < b.EventTypeNID
|
||||
}
|
||||
return a.EventStateKeyNID < b.EventStateKeyNID
|
||||
}
|
||||
|
||||
// A StateEntry is an entry in the room state of a matrix room.
|
||||
type StateEntry struct {
|
||||
StateKeyTuple
|
||||
// The numeric ID for the event.
|
||||
EventNID int64
|
||||
}
|
||||
|
||||
// LessThan returns true if this state entry is less than the other state entry.
|
||||
// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
|
||||
func (a StateEntry) LessThan(b StateEntry) bool {
|
||||
if a.StateKeyTuple != b.StateKeyTuple {
|
||||
return a.StateKeyTuple.LessThan(b.StateKeyTuple)
|
||||
}
|
||||
return a.EventNID < b.EventNID
|
||||
}
|
||||
|
||||
// An Event is a gomatrixserverlib.Event with the numeric event ID attached.
|
||||
// It is when performing bulk event lookup in the database.
|
||||
type Event struct {
|
||||
EventNID int64
|
||||
gomatrixserverlib.Event
|
||||
}
|
||||
|
||||
const (
|
||||
// MRoomCreateNID is the numeric ID for the "m.room.create" event type.
|
||||
MRoomCreateNID = 1
|
||||
// MRoomPowerLevelsNID is the numeric ID for the "m.room.power_levels" event type.
|
||||
MRoomPowerLevelsNID = 2
|
||||
// MRoomJoinRulesNID is the numeric ID for the "m.room.join_rules" event type.
|
||||
MRoomJoinRulesNID = 3
|
||||
// MRoomThirdPartyInviteNID is the numeric ID for the "m.room.third_party_invite" event type.
|
||||
MRoomThirdPartyInviteNID = 4
|
||||
// MRoomMemberNID is the numeric ID for the "m.room.member" event type.
|
||||
MRoomMemberNID = 5
|
||||
// MRoomRedactionNID is the numeric ID for the "m.room.redaction" event type.
|
||||
MRoomRedactionNID = 6
|
||||
// MRoomHistoryVisibilityNID is the numeric ID for the "m.room.history_visibility" event type.
|
||||
MRoomHistoryVisibilityNID = 7
|
||||
)
|
||||
|
||||
const (
|
||||
// EmptyStateKeyNID is the numeric ID for the empty state key.
|
||||
EmptyStateKeyNID = 1
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user