Check that events pass authentication checks. (#4)

* Check that events pass authentication checks.

Record the list of events that the event passes authentication checks
against.
This commit is contained in:
Mark Haines 2017-02-09 16:48:14 +00:00 committed by GitHub
parent 600f56b4b8
commit fc4eb85379
6 changed files with 560 additions and 19 deletions

View File

@ -30,7 +30,13 @@ type InputRoomEvent struct {
Kind int Kind int
// The event JSON for the event to add. // The event JSON for the event to add.
Event []byte Event []byte
// List of state event IDs that authenticate this event.
// These are likely derived from the "auth_events" JSON key of the event.
// But can be different because the "auth_events" key can be incomplete or wrong.
// For example many matrix events forget to reference the m.room.create event even though it is needed for auth.
// (since synapse allows this to happen we have to allow it as well.)
AuthEventIDs []string
// Optional list of state event IDs forming the state before this event. // Optional list of state event IDs forming the state before this event.
// These state events must have already been persisted. // These state events must have already been persisted.
State []string StateEventIDs []string
} }

View File

@ -2,12 +2,26 @@ package input
import ( import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"sort"
) )
// A RoomEventDatabase has the storage APIs needed to store a room event. // A RoomEventDatabase has the storage APIs needed to store a room event.
type RoomEventDatabase interface { type RoomEventDatabase interface {
StoreEvent(event gomatrixserverlib.Event) error // Stores a matrix room event in the database
StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error
// Lookup the state entries for a list of string event IDs
// Returns a sorted list of state entries.
// Returns a error if the there is an error talking to the database
// or if the event IDs aren't in the database.
StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error)
// Lookup the numeric IDs for a list of string event state keys.
// Returns a map from string state key to numeric ID for the state key.
EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error)
// Lookup the Events for a list of numeric event IDs.
// Returns a sorted list of events.
Events(eventNIDs []int64) ([]types.Event, error)
} }
func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error { func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
@ -17,12 +31,16 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
return err return err
} }
if err := db.StoreEvent(event); err != nil { // Check that the event passes authentication checks and work out the numeric IDs for the auth events.
authEventNIDs, err := checkAuthEvents(db, event, input.AuthEventIDs)
if err != nil {
return err return err
} }
// TODO: // Store the event
// * Check that the event passes authentication checks. if err := db.StoreEvent(event, authEventNIDs); err != nil {
return err
}
if input.Kind == api.KindOutlier { if input.Kind == api.KindOutlier {
// For outliers we can stop after we've stored the event itself as it // For outliers we can stop after we've stored the event itself as it
@ -44,3 +62,193 @@ func processRoomEvent(db RoomEventDatabase, input api.InputRoomEvent) error {
// - The changes to the current state of the room. // - The changes to the current state of the room.
panic("Not implemented") panic("Not implemented")
} }
// checkAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events.
func checkAuthEvents(db RoomEventDatabase, event gomatrixserverlib.Event, authEventIDs []string) ([]int64, error) {
// Grab the numeric IDs for the supplied auth state events from the database.
authStateEntries, err := db.StateEntriesForEventIDs(authEventIDs)
if err != nil {
return nil, err
}
// TODO: check for duplicate state keys here.
// Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(db, stateNeeded, authStateEntries)
if err != nil {
return nil, err
}
// Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
return nil, err
}
// Return the numeric IDs for the auth events.
result := make([]int64, len(authStateEntries))
for i := range authStateEntries {
result[i] = authStateEntries[i].EventNID
}
return result, nil
}
type authEvents struct {
stateKeyNIDMap map[string]int64
state stateEntryMap
events eventMap
}
// Create implements gomatrixserverlib.AuthEvents
func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
}
// PowerLevels implements gomatrixserverlib.AuthEvents
func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
}
// JoinRules implements gomatrixserverlib.AuthEvents
func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
}
// Memmber implements gomatrixserverlib.AuthEvents
func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) {
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
}
// ThirdPartyInvite implements gomatrixserverlib.AuthEvents
func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) {
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
}
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID int64) *gomatrixserverlib.Event {
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, types.EmptyStateKeyNID})
if !ok {
return nil
}
event, ok := ae.events.lookup(eventNID)
if !ok {
return nil
}
return &event.Event
}
func (ae *authEvents) lookupEvent(typeNID int64, stateKey string) *gomatrixserverlib.Event {
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
if !ok {
return nil
}
eventNID, ok := ae.state.lookup(types.StateKeyTuple{typeNID, stateKeyNID})
if !ok {
return nil
}
event, ok := ae.events.lookup(eventNID)
if !ok {
return nil
}
return &event.Event
}
// loadAuthEvents loads the events needed for authentication from the supplied room state.
func loadAuthEvents(
db RoomEventDatabase,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
) (result authEvents, err error) {
// Lookup the numeric IDs for the state keys needed for auth.
var neededStateKeys []string
neededStateKeys = append(neededStateKeys, needed.Member...)
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(neededStateKeys); err != nil {
return
}
// Load the events we need.
result.state = state
var eventNIDs []int64
keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
for _, keyTuple := range keyTuplesNeeded {
eventNID, ok := result.state.lookup(keyTuple)
if ok {
eventNIDs = append(eventNIDs, eventNID)
}
}
if result.events, err = db.Events(eventNIDs); err != nil {
return
}
return
}
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
func stateKeyTuplesNeeded(stateKeyNIDMap map[string]int64, stateNeeded gomatrixserverlib.StateNeeded) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple
if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomCreateNID, types.EmptyStateKeyNID})
}
if stateNeeded.PowerLevels {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomPowerLevelsNID, types.EmptyStateKeyNID})
}
if stateNeeded.JoinRules {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomJoinRulesNID, types.EmptyStateKeyNID})
}
for _, member := range stateNeeded.Member {
stateKeyNID, ok := stateKeyNIDMap[member]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomMemberNID, stateKeyNID})
}
}
for _, token := range stateNeeded.ThirdPartyInvite {
stateKeyNID, ok := stateKeyNIDMap[token]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{types.MRoomThirdPartyInviteNID, stateKeyNID})
}
}
return keyTuples
}
// Map from event type, state key tuple to numeric event ID.
// Implemented using binary search on a sorted array.
type stateEntryMap []types.StateEntry
// lookup an entry in the event map.
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID int64, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size and are controlled by us.
list := []types.StateEntry(m)
i := sort.Search(len(list), func(i int) bool {
return !list[i].StateKeyTuple.LessThan(stateKey)
})
if i < len(list) && list[i].StateKeyTuple == stateKey {
ok = true
eventNID = list[i].EventNID
}
return
}
// Map from numeric event ID to event.
// Implemented using binary search on a sorted array.
type eventMap []types.Event
// lookup an entry in the event map.
func (m eventMap) lookup(eventNID int64) (event *types.Event, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size are controlled by us.
list := []types.Event(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].EventNID >= eventNID
})
if i < len(list) && list[i].EventNID == eventNID {
ok = true
event = &list[i]
}
return
}

View File

@ -0,0 +1,112 @@
package input
import (
"github.com/matrix-org/dendrite/roomserver/types"
"testing"
)
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
var list []types.StateEntry
for i := int64(0); i < entries; i++ {
list = append(list, types.StateEntry{types.StateKeyTuple{i, i}, i})
}
for i := 0; i < b.N; i++ {
entryMap := stateEntryMap(list)
for j := int64(0); j < lookups; j++ {
entryMap.lookup(types.StateKeyTuple{j, j})
}
}
}
func BenchmarkStateEntryMap100Lookup10(b *testing.B) {
benchmarkStateEntryMapLookup(100, 10, b)
}
func BenchmarkStateEntryMap1000Lookup100(b *testing.B) {
benchmarkStateEntryMapLookup(1000, 100, b)
}
func BenchmarkStateEntryMap100Lookup100(b *testing.B) {
benchmarkStateEntryMapLookup(100, 100, b)
}
func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) {
benchmarkStateEntryMapLookup(1000, 10000, b)
}
func TestStateEntryMap(t *testing.T) {
entryMap := stateEntryMap([]types.StateEntry{
{types.StateKeyTuple{1, 1}, 1},
{types.StateKeyTuple{1, 3}, 2},
{types.StateKeyTuple{2, 1}, 3},
})
testCases := []struct {
inputTypeNID int64
inputStateKey int64
wantOK bool
wantEventNID int64
}{
// Check that tuples that in the array are in the map.
{1, 1, true, 1},
{1, 3, true, 2},
{2, 1, true, 3},
// Check that tuples that aren't in the array aren't in the map.
{0, 0, false, 0},
{1, 2, false, 0},
{3, 1, false, 0},
}
for _, testCase := range testCases {
keyTuple := types.StateKeyTuple{testCase.inputTypeNID, testCase.inputStateKey}
gotEventNID, gotOK := entryMap.lookup(keyTuple)
if testCase.wantOK != gotOK {
t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK)
}
if testCase.wantEventNID != gotEventNID {
t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID)
}
}
}
func TestEventMap(t *testing.T) {
events := eventMap([]types.Event{
{EventNID: 1},
{EventNID: 2},
{EventNID: 3},
{EventNID: 5},
{EventNID: 8},
})
testCases := []struct {
inputEventNID int64
wantOK bool
wantEvent *types.Event
}{
// Check that the IDs that are in the array are in the map.
{1, true, &events[0]},
{2, true, &events[1]},
{3, true, &events[2]},
{5, true, &events[3]},
{8, true, &events[4]},
// Check that tuples that aren't in the array aren't in the map.
{0, false, nil},
{4, false, nil},
{6, false, nil},
{7, false, nil},
{9, false, nil},
}
for _, testCase := range testCases {
gotEvent, gotOK := events.lookup(testCase.inputEventNID)
if testCase.wantOK != gotOK {
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
}
if testCase.wantEvent != gotEvent {
t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent)
}
}
}

View File

@ -2,20 +2,25 @@ package storage
import ( import (
"database/sql" "database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
) )
type statements struct { type statements struct {
selectPartitionOffsetsStmt *sql.Stmt selectPartitionOffsetsStmt *sql.Stmt
upsertPartitionOffsetStmt *sql.Stmt upsertPartitionOffsetStmt *sql.Stmt
insertEventTypeNIDStmt *sql.Stmt insertEventTypeNIDStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt selectEventTypeNIDStmt *sql.Stmt
insertEventStateKeyNIDStmt *sql.Stmt insertEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt selectEventStateKeyNIDStmt *sql.Stmt
insertRoomNIDStmt *sql.Stmt bulkSelectEventStateKeyNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt insertRoomNIDStmt *sql.Stmt
insertEventStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt
insertEventJSONStmt *sql.Stmt insertEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt
insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt
} }
func (s *statements) prepare(db *sql.DB) error { func (s *statements) prepare(db *sql.DB) error {
@ -196,6 +201,9 @@ func (s *statements) prepareEventStateKeys(db *sql.DB) (err error) {
if s.selectEventStateKeyNIDStmt, err = db.Prepare(selectEventStateKeyNIDSQL); err != nil { if s.selectEventStateKeyNIDStmt, err = db.Prepare(selectEventStateKeyNIDSQL); err != nil {
return return
} }
if s.bulkSelectEventStateKeyNIDStmt, err = db.Prepare(bulkSelectEventStateKeyNIDSQL); err != nil {
return
}
return return
} }
@ -230,6 +238,12 @@ const insertEventStateKeyNIDSQL = "" +
const selectEventStateKeyNIDSQL = "" + const selectEventStateKeyNIDSQL = "" +
"SELECT event_state_key_nid FROM event_state_keys WHERE event_state_key = $1" "SELECT event_state_key_nid FROM event_state_keys WHERE event_state_key = $1"
// Bulk lookup from string state key to numeric ID for that state key.
// Takes an array of strings as the query parameter.
const bulkSelectEventStateKeyNIDSQL = "" +
"SELECT event_state_key, event_state_key_nid FROM event_state_keys" +
" WHERE event_state_key = ANY($1)"
func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) { func (s *statements) insertEventStateKeyNID(eventStateKey string) (eventStateKeyNID int64, err error) {
err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID) err = s.insertEventStateKeyNIDStmt.QueryRow(eventStateKey).Scan(&eventStateKeyNID)
return return
@ -240,6 +254,25 @@ func (s *statements) selectEventStateKeyNID(eventStateKey string) (eventStateKey
return return
} }
func (s *statements) bulkSelectEventStateKeyNID(eventStateKeys []string) (map[string]int64, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.Query(pq.StringArray(eventStateKeys))
if err != nil {
return nil, err
}
defer rows.Close()
result := make(map[string]int64, len(eventStateKeys))
for rows.Next() {
var stateKey string
var stateKeyNID int64
if err := rows.Scan(&stateKey, &stateKeyNID); err != nil {
return nil, err
}
result[stateKey] = stateKeyNID
}
return result, nil
}
func (s *statements) prepareRooms(db *sql.DB) (err error) { func (s *statements) prepareRooms(db *sql.DB) (err error) {
_, err = db.Exec(roomsSchema) _, err = db.Exec(roomsSchema)
if err != nil { if err != nil {
@ -307,17 +340,27 @@ CREATE TABLE IF NOT EXISTS events (
event_id TEXT NOT NULL CONSTRAINT event_id_unique UNIQUE, event_id TEXT NOT NULL CONSTRAINT event_id_unique UNIQUE,
-- The sha256 reference hash for the event. -- The sha256 reference hash for the event.
-- Needed for setting reference hashes when sending new events. -- Needed for setting reference hashes when sending new events.
reference_sha256 BYTEA NOT NULL reference_sha256 BYTEA NOT NULL,
-- A list of numeric IDs for events that can authenticate this event.
auth_event_nids BIGINT[] NOT NULL,
); );
` `
const insertEventSQL = "" + const insertEventSQL = "" +
"INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256)" + "INSERT INTO events (room_nid, event_type_nid, event_state_key_nid, event_id, reference_sha256, auth_event_nids)" +
" VALUES ($1, $2, $3, $4, $5)" + " VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT event_id_unique" + " ON CONFLICT ON CONSTRAINT event_id_unique" +
" DO UPDATE SET event_id = $1" + " DO UPDATE SET event_id = $1" +
" RETURNING event_nid" " RETURNING event_nid"
// Bulk lookup of events by string ID.
// Sort by the numeric IDs for event type and state key.
// This means we can use binary search to lookup entries by type and state key.
const bulkSelectStateEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid FROM events" +
" WHERE event_id = ANY($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC"
func (s *statements) prepareEvents(db *sql.DB) (err error) { func (s *statements) prepareEvents(db *sql.DB) (err error) {
_, err = db.Exec(eventsSchema) _, err = db.Exec(eventsSchema)
if err != nil { if err != nil {
@ -326,6 +369,9 @@ func (s *statements) prepareEvents(db *sql.DB) (err error) {
if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil { if s.insertEventStmt, err = db.Prepare(insertEventSQL); err != nil {
return return
} }
if s.bulkSelectStateEventByIDStmt, err = db.Prepare(bulkSelectStateEventByIDSQL); err != nil {
return
}
return return
} }
@ -333,13 +379,48 @@ func (s *statements) insertEvent(
roomNID, eventTypeNID, eventStateKeyNID int64, roomNID, eventTypeNID, eventStateKeyNID int64,
eventID string, eventID string,
referenceSHA256 []byte, referenceSHA256 []byte,
authEventNIDs []int64,
) (eventNID int64, err error) { ) (eventNID int64, err error) {
err = s.insertEventStmt.QueryRow( err = s.insertEventStmt.QueryRow(
roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256, roomNID, eventTypeNID, eventStateKeyNID, eventID, referenceSHA256,
pq.Int64Array(authEventNIDs),
).Scan(&eventNID) ).Scan(&eventNID)
return return
} }
func (s *statements) bulkSelectStateEventByID(eventIDs []string) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.Query(pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
defer rows.Close()
// We know that we will only get as many results as event IDs
// because of the unique constraint on event IDs.
// So we can allocate an array of the correct size now.
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
results := make([]types.StateEntry, len(eventIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(
&result.EventNID,
&result.EventTypeNID,
&result.EventStateKeyNID,
); err != nil {
return nil, err
}
}
if i != len(eventIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query.
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
// If this turns out to be impossible and we do need the debug information here, it would be better
// to do it as a separate query rather than slowing down/complicating the common case.
return nil, fmt.Errorf("storage: state event IDs missing from the database (%d != %d)", i, len(eventIDs))
}
return results, err
}
func (s *statements) prepareEventJSON(db *sql.DB) (err error) { func (s *statements) prepareEventJSON(db *sql.DB) (err error) {
_, err = db.Exec(eventJSONSchema) _, err = db.Exec(eventJSONSchema)
if err != nil { if err != nil {
@ -348,6 +429,9 @@ func (s *statements) prepareEventJSON(db *sql.DB) (err error) {
if s.insertEventJSONStmt, err = db.Prepare(insertEventJSONSQL); err != nil { if s.insertEventJSONStmt, err = db.Prepare(insertEventJSONSQL); err != nil {
return return
} }
if s.bulkSelectEventJSONStmt, err = db.Prepare(bulkSelectEventJSONSQL); err != nil {
return
}
return return
} }
@ -372,7 +456,41 @@ const insertEventJSONSQL = "" +
"INSERT INTO event_json (event_nid, event_json) VALUES ($1, $2)" + "INSERT INTO event_json (event_nid, event_json) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING" " ON CONFLICT DO NOTHING"
// Bulk event JSON lookup by numeric event ID.
// Sort by the numeric event ID.
// This means that we can use binary search to lookup by numeric event ID.
const bulkSelectEventJSONSQL = "" +
"SELECT event_nid, event_json FROM event_json" +
" WHERE event_nid = ANY($1)" +
" ORDER BY event_nid ASC"
func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error { func (s *statements) insertEventJSON(eventNID int64, eventJSON []byte) error {
_, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON) _, err := s.insertEventJSONStmt.Exec(eventNID, eventJSON)
return err return err
} }
type eventJSONPair struct {
EventNID int64
EventJSON []byte
}
func (s *statements) bulkSelectEventJSON(eventNIDs []int64) ([]eventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.Query(pq.Int64Array(eventNIDs))
if err != nil {
return nil, err
}
defer rows.Close()
// We know that we will only get as many results as event NIDs
// because of the unique constraint on event NIDs.
// So we can allocate an array of the correct size now.
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
results := make([]eventJSONPair, len(eventNIDs))
i := 0
for ; rows.Next(); i++ {
if err := rows.Scan(&results[i].EventNID, &results[i].EventJSON); err != nil {
return nil, err
}
}
return results[:i], nil
}

View File

@ -38,7 +38,7 @@ func (d *Database) SetPartitionOffset(topic string, partition int32, offset int6
} }
// StoreEvent implements input.EventDatabase // StoreEvent implements input.EventDatabase
func (d *Database) StoreEvent(event gomatrixserverlib.Event) error { func (d *Database) StoreEvent(event gomatrixserverlib.Event, authEventNIDs []int64) error {
var ( var (
roomNID int64 roomNID int64
eventTypeNID int64 eventTypeNID int64
@ -70,6 +70,7 @@ func (d *Database) StoreEvent(event gomatrixserverlib.Event) error {
eventStateKeyNID, eventStateKeyNID,
event.EventID(), event.EventID(),
event.EventReference().EventSHA256, event.EventReference().EventSHA256,
authEventNIDs,
); err != nil { ); err != nil {
return err return err
} }
@ -115,3 +116,32 @@ func (d *Database) assignStateKeyNID(eventStateKey string) (int64, error) {
} }
return eventStateKeyNID, nil return eventStateKeyNID, nil
} }
// StateEntriesForEventIDs implements input.EventDatabase
func (d *Database) StateEntriesForEventIDs(eventIDs []string) ([]types.StateEntry, error) {
return d.statements.bulkSelectStateEventByID(eventIDs)
}
// EventStateKeyNIDs implements input.EventDatabase
func (d *Database) EventStateKeyNIDs(eventStateKeys []string) (map[string]int64, error) {
return d.statements.bulkSelectEventStateKeyNID(eventStateKeys)
}
// Events implements input.EventDatabase
func (d *Database) Events(eventNIDs []int64) ([]types.Event, error) {
eventJSONs, err := d.statements.bulkSelectEventJSON(eventNIDs)
if err != nil {
return nil, err
}
results := make([]types.Event, len(eventJSONs))
for i, eventJSON := range eventJSONs {
result := &results[i]
result.EventNID = eventJSON.EventNID
// TODO: Use NewEventFromTrustedJSON for efficiency
result.Event, err = gomatrixserverlib.NewEventFromUntrustedJSON(eventJSON.EventJSON)
if err != nil {
return nil, err
}
}
return results, nil
}

View File

@ -1,6 +1,10 @@
// Package types provides the types that are used internally within the roomserver. // Package types provides the types that are used internally within the roomserver.
package types package types
import (
"github.com/matrix-org/gomatrixserverlib"
)
// A PartitionOffset is the offset into a partition of the input log. // A PartitionOffset is the offset into a partition of the input log.
type PartitionOffset struct { type PartitionOffset struct {
// The ID of the partition. // The ID of the partition.
@ -8,3 +12,66 @@ type PartitionOffset struct {
// The offset into the partition. // The offset into the partition.
Offset int64 Offset int64
} }
// A StateKeyTuple is a pair of a numeric event type and a numeric state key.
// It is used to lookup state entries.
type StateKeyTuple struct {
// The numeric ID for the event type.
EventTypeNID int64
// The numeric ID for the state key.
EventStateKeyNID int64
}
// LessThan returns true if this state key is less than the other state key.
// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
func (a StateKeyTuple) LessThan(b StateKeyTuple) bool {
if a.EventTypeNID != b.EventTypeNID {
return a.EventTypeNID < b.EventTypeNID
}
return a.EventStateKeyNID < b.EventStateKeyNID
}
// A StateEntry is an entry in the room state of a matrix room.
type StateEntry struct {
StateKeyTuple
// The numeric ID for the event.
EventNID int64
}
// LessThan returns true if this state entry is less than the other state entry.
// The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries.
func (a StateEntry) LessThan(b StateEntry) bool {
if a.StateKeyTuple != b.StateKeyTuple {
return a.StateKeyTuple.LessThan(b.StateKeyTuple)
}
return a.EventNID < b.EventNID
}
// An Event is a gomatrixserverlib.Event with the numeric event ID attached.
// It is when performing bulk event lookup in the database.
type Event struct {
EventNID int64
gomatrixserverlib.Event
}
const (
// MRoomCreateNID is the numeric ID for the "m.room.create" event type.
MRoomCreateNID = 1
// MRoomPowerLevelsNID is the numeric ID for the "m.room.power_levels" event type.
MRoomPowerLevelsNID = 2
// MRoomJoinRulesNID is the numeric ID for the "m.room.join_rules" event type.
MRoomJoinRulesNID = 3
// MRoomThirdPartyInviteNID is the numeric ID for the "m.room.third_party_invite" event type.
MRoomThirdPartyInviteNID = 4
// MRoomMemberNID is the numeric ID for the "m.room.member" event type.
MRoomMemberNID = 5
// MRoomRedactionNID is the numeric ID for the "m.room.redaction" event type.
MRoomRedactionNID = 6
// MRoomHistoryVisibilityNID is the numeric ID for the "m.room.history_visibility" event type.
MRoomHistoryVisibilityNID = 7
)
const (
// EmptyStateKeyNID is the numeric ID for the empty state key.
EmptyStateKeyNID = 1
)