diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index 017682e0..4950e623 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -32,8 +32,8 @@ import ( const MRoomServerACL = "m.room.server_acl" type ServerACLDatabase interface { - // GetKnownRooms returns a list of all rooms we know about. - GetKnownRooms(ctx context.Context) ([]string, error) + // RoomsWithACLs returns all room IDs for rooms with ACLs + RoomsWithACLs(ctx context.Context) ([]string, error) // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. @@ -57,7 +57,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { } // Look up all of the rooms that the current state server knows about. - rooms, err := db.GetKnownRooms(ctx) + rooms, err := db.RoomsWithACLs(ctx) if err != nil { logrus.WithError(err).Fatalf("Failed to get known rooms") } diff --git a/roomserver/acls/acls_test.go b/roomserver/acls/acls_test.go index efe1d209..09920308 100644 --- a/roomserver/acls/acls_test.go +++ b/roomserver/acls/acls_test.go @@ -116,7 +116,7 @@ var ( type dummyACLDB struct{} -func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) { +func (d dummyACLDB) RoomsWithACLs(ctx context.Context) ([]string, error) { return []string{"1", "2"}, nil } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index ef5bc3d1..a4300d0c 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -86,6 +86,9 @@ type RoomserverInternalAPI interface { req *QueryAuthChainRequest, res *QueryAuthChainResponse, ) error + + // RoomsWithACLs returns all room IDs for rooms with ACLs + RoomsWithACLs(ctx context.Context) ([]string, error) } type UserRoomPrivateKeyCreator interface { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 74b01028..c7f8d100 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -1099,3 +1099,8 @@ func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, return nil, nil } + +// RoomsWithACLs returns all room IDs for rooms with ACLs +func (r *Queryer) RoomsWithACLs(ctx context.Context) ([]string, error) { + return r.DB.RoomsWithACLs(ctx) +} diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 88e33571..85312efd 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -1284,3 +1284,38 @@ func TestRoomConsumerRecreation(t *testing.T) { wantAckWait := input.MaximumMissingProcessingTime + (time.Second * 10) assert.Equal(t, wantAckWait, info.Config.AckWait) } + +func TestRoomsWithACLs(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + noACLRoom := test.NewRoom(t, alice) + aclRoom := test.NewRoom(t, alice) + + aclRoom.CreateAndInsert(t, alice, "m.room.server_acl", map[string]any{ + "deny": []string{"evilhost.test"}, + "allow": []string{"*"}, + }, test.WithStateKey("")) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType) + defer closeDB() + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + // start JetStream listeners + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + + for _, room := range []*test.Room{noACLRoom, aclRoom} { + // Create the rooms + err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false) + assert.NoError(t, err) + } + + // Validate that we only have one ACLd room. + roomsWithACLs, err := rsAPI.RoomsWithACLs(ctx) + assert.NoError(t, err) + assert.Equal(t, []string{aclRoom.ID}, roomsWithACLs) + }) +} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 0638252b..a1a722c5 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -170,8 +170,6 @@ type Database interface { GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) - // GetKnownRooms returns a list of all rooms we know about. - GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error @@ -193,6 +191,9 @@ type Database interface { MaybeRedactEvent( ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI, ) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error) + + // RoomsWithACLs returns all room IDs for rooms with ACLs + RoomsWithACLs(ctx context.Context) ([]string, error) } type UserRoomKeys interface { diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 1c9cd159..180a03cd 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -68,6 +68,10 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( -- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility) CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5); + +-- The following indexes are used by bulkSelectStateEventByNIDSQL +CREATE INDEX IF NOT EXISTS roomserver_event_event_type_nid_idx ON roomserver_events (event_type_nid); +CREATE INDEX IF NOT EXISTS roomserver_event_state_key_nid_idx ON roomserver_events (event_state_key_nid); ` const insertEventSQL = "" + @@ -147,6 +151,8 @@ const selectRoomNIDsForEventNIDsSQL = "" + const selectEventRejectedSQL = "" + "SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2" +const selectRoomsWithEventTypeNIDSQL = `SELECT DISTINCT room_nid FROM roomserver_events WHERE event_type_nid = $1` + type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt @@ -166,6 +172,7 @@ type eventStatements struct { selectMaxEventDepthStmt *sql.Stmt selectRoomNIDsForEventNIDsStmt *sql.Stmt selectEventRejectedStmt *sql.Stmt + selectRoomsWithEventTypeNIDStmt *sql.Stmt } func CreateEventsTable(db *sql.DB) error { @@ -206,6 +213,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL}, {&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL}, {&s.selectEventRejectedStmt, selectEventRejectedSQL}, + {&s.selectRoomsWithEventTypeNIDStmt, selectRoomsWithEventTypeNIDSQL}, }.Prepare(db) } @@ -582,3 +590,25 @@ func (s *eventStatements) SelectEventRejected( err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected) return } + +func (s *eventStatements) SelectRoomsWithEventTypeNID( + ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID, +) ([]types.RoomNID, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithEventTypeNIDStmt) + rows, err := stmt.QueryContext(ctx, eventTypeNID) + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithEventTypeNID: rows.close() failed") + if err != nil { + return nil, err + } + + var roomNIDs []types.RoomNID + var roomNID types.RoomNID + for rows.Next() { + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + + return roomNIDs, rows.Err() +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index bc3820b2..4de6dee4 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -76,9 +76,6 @@ const selectRoomVersionsForRoomNIDsSQL = "" + const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" -const selectRoomIDsSQL = "" + - "SELECT room_id FROM roomserver_rooms WHERE array_length(latest_event_nids, 1) > 0" - const bulkSelectRoomIDsSQL = "" + "SELECT room_id FROM roomserver_rooms WHERE room_nid = ANY($1)" @@ -94,7 +91,6 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt selectRoomVersionsForRoomNIDsStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt - selectRoomIDsStmt *sql.Stmt bulkSelectRoomIDsStmt *sql.Stmt bulkSelectRoomNIDsStmt *sql.Stmt } @@ -116,29 +112,11 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, {&s.selectRoomVersionsForRoomNIDsStmt, selectRoomVersionsForRoomNIDsSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, - {&s.selectRoomIDsStmt, selectRoomIDsSQL}, {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, {&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL}, }.Prepare(db) } -func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) - rows, err := stmt.QueryContext(ctx) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") - var roomIDs []string - var roomID string - for rows.Next() { - if err = rows.Scan(&roomID); err != nil { - return nil, err - } - roomIDs = append(roomIDs, roomID) - } - return roomIDs, rows.Err() -} func (s *roomStatements) InsertRoomNID( ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 682cead6..cde2e656 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1625,9 +1625,24 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } -// GetKnownRooms returns a list of all rooms we know about. -func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { - return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) +func (d *Database) RoomsWithACLs(ctx context.Context) ([]string, error) { + + eventTypeNID, err := d.GetOrCreateEventTypeNID(ctx, "m.room.server_acl") + if err != nil { + return nil, err + } + + roomNIDs, err := d.EventsTable.SelectRoomsWithEventTypeNID(ctx, nil, eventTypeNID) + if err != nil { + return nil, err + } + + roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) + if err != nil { + return nil, err + } + + return roomIDs, nil } // ForgetRoom sets a users room to forgotten diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 2c269bce..26401e45 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -44,6 +44,14 @@ const eventsSchema = ` auth_event_nids TEXT NOT NULL DEFAULT '[]', is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); + +-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility) +CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5); + +-- The following indexes are used by bulkSelectStateEventByNIDSQL +CREATE INDEX IF NOT EXISTS roomserver_event_event_type_nid_idx ON roomserver_events (event_type_nid); +CREATE INDEX IF NOT EXISTS roomserver_event_state_key_nid_idx ON roomserver_events (event_state_key_nid); + ` const insertEventSQL = ` @@ -120,6 +128,8 @@ const selectRoomNIDsForEventNIDsSQL = "" + const selectEventRejectedSQL = "" + "SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2" +const selectRoomsWithEventTypeNIDSQL = `SELECT DISTINCT room_nid FROM roomserver_events WHERE event_type_nid = $1` + type eventStatements struct { db *sql.DB insertEventStmt *sql.Stmt @@ -135,6 +145,7 @@ type eventStatements struct { bulkSelectStateAtEventAndReferenceStmt *sql.Stmt bulkSelectEventIDStmt *sql.Stmt selectEventRejectedStmt *sql.Stmt + selectRoomsWithEventTypeNIDStmt *sql.Stmt //bulkSelectEventNIDStmt *sql.Stmt //bulkSelectUnsentEventNIDStmt *sql.Stmt //selectRoomNIDsForEventNIDsStmt *sql.Stmt @@ -192,6 +203,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) { //{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL}, //{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL}, {&s.selectEventRejectedStmt, selectEventRejectedSQL}, + {&s.selectRoomsWithEventTypeNIDStmt, selectRoomsWithEventTypeNIDSQL}, }.Prepare(db) } @@ -682,3 +694,25 @@ func (s *eventStatements) SelectEventRejected( err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected) return } + +func (s *eventStatements) SelectRoomsWithEventTypeNID( + ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID, +) ([]types.RoomNID, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomsWithEventTypeNIDStmt) + rows, err := stmt.QueryContext(ctx, eventTypeNID) + defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithEventTypeNID: rows.close() failed") + if err != nil { + return nil, err + } + + var roomNIDs []types.RoomNID + var roomNID types.RoomNID + for rows.Next() { + if err := rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + + return roomNIDs, rows.Err() +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 22700a71..5034b242 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -65,9 +65,6 @@ const selectRoomVersionsForRoomNIDsSQL = "" + const selectRoomInfoSQL = "" + "SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1" -const selectRoomIDsSQL = "" + - "SELECT room_id FROM roomserver_rooms WHERE latest_event_nids != '[]'" - const bulkSelectRoomIDsSQL = "" + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" @@ -87,7 +84,6 @@ type roomStatements struct { updateLatestEventNIDsStmt *sql.Stmt //selectRoomVersionForRoomNIDStmt *sql.Stmt selectRoomInfoStmt *sql.Stmt - selectRoomIDsStmt *sql.Stmt } func CreateRoomsTable(db *sql.DB) error { @@ -108,29 +104,10 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, - {&s.selectRoomIDsStmt, selectRoomIDsSQL}, {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, }.Prepare(db) } -func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) { - stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt) - rows, err := stmt.QueryContext(ctx) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed") - var roomIDs []string - var roomID string - for rows.Next() { - if err = rows.Scan(&roomID); err != nil { - return nil, err - } - roomIDs = append(roomIDs, roomID) - } - return roomIDs, rows.Err() -} - func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) { var info types.RoomInfo var latestNIDsJSON string diff --git a/roomserver/storage/tables/events_table_test.go b/roomserver/storage/tables/events_table_test.go index 5ed80564..52aeacc2 100644 --- a/roomserver/storage/tables/events_table_test.go +++ b/roomserver/storage/tables/events_table_test.go @@ -2,6 +2,7 @@ package tables_test import ( "context" + "fmt" "testing" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -147,3 +148,38 @@ func Test_EventsTable(t *testing.T) { assert.Equal(t, int64(len(room.Events())+1), maxDepth) }) } + +func TestRoomsWithACL(t *testing.T) { + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + eventStateKeys, closeEventStateKeys := mustCreateEventTypesTable(t, dbType) + defer closeEventStateKeys() + + eventsTable, closeEventsTable := mustCreateEventsTable(t, dbType) + defer closeEventsTable() + + ctx := context.Background() + + // insert the m.room.server_acl event type + eventTypeNID, err := eventStateKeys.InsertEventTypeNID(ctx, nil, "m.room.server_acl") + assert.Nil(t, err) + + // Create ACL'd rooms + var wantRoomNIDs []types.RoomNID + for i := 0; i < 10; i++ { + _, _, err = eventsTable.InsertEvent(ctx, nil, types.RoomNID(i), eventTypeNID, types.EmptyStateKeyNID, fmt.Sprintf("$1337+%d", i), nil, 0, false) + assert.Nil(t, err) + wantRoomNIDs = append(wantRoomNIDs, types.RoomNID(i)) + } + + // Create non-ACL'd rooms (eventTypeNID+1) + for i := 10; i < 20; i++ { + _, _, err = eventsTable.InsertEvent(ctx, nil, types.RoomNID(i), eventTypeNID+1, types.EmptyStateKeyNID, fmt.Sprintf("$1337+%d", i), nil, 0, false) + assert.Nil(t, err) + } + + gotRoomNIDs, err := eventsTable.SelectRoomsWithEventTypeNID(ctx, nil, eventTypeNID) + assert.Nil(t, err) + assert.Equal(t, wantRoomNIDs, gotRoomNIDs) + }) +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index b3cb3188..ff810a2b 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -69,6 +69,8 @@ type Events interface { SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error) SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error) SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error) + + SelectRoomsWithEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID) ([]types.RoomNID, error) } type Rooms interface { @@ -80,7 +82,6 @@ type Rooms interface { UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) - SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) } diff --git a/roomserver/storage/tables/rooms_table_test.go b/roomserver/storage/tables/rooms_table_test.go index eddd012c..e97e3e33 100644 --- a/roomserver/storage/tables/rooms_table_test.go +++ b/roomserver/storage/tables/rooms_table_test.go @@ -74,11 +74,6 @@ func TestRoomsTable(t *testing.T) { assert.NoError(t, err) assert.Nil(t, roomInfo) - // There are no rooms with latestEventNIDs yet - roomIDs, err := tab.SelectRoomIDsWithEvents(ctx, nil) - assert.NoError(t, err) - assert.Equal(t, 0, len(roomIDs)) - roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337}) assert.NoError(t, err) assert.Equal(t, roomVersions[wantRoomNID], room.Version) @@ -86,7 +81,7 @@ func TestRoomsTable(t *testing.T) { _, ok := roomVersions[1337] assert.False(t, ok) - roomIDs, err = tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337}) + roomIDs, err := tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337}) assert.NoError(t, err) assert.Equal(t, []string{room.ID}, roomIDs)