mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-08 18:16:59 +00:00
Query rooms with ACLs instead of all rooms (#3338)
This now should actually speed up startup times. This is because _many_ rooms (like DMs) don't have room ACLs, this means that we had around 95% pointless DB queries. (as queried on d.m.org)
This commit is contained in:
parent
09f15a3d3f
commit
928c8c8c4a
@ -32,8 +32,8 @@ import (
|
|||||||
const MRoomServerACL = "m.room.server_acl"
|
const MRoomServerACL = "m.room.server_acl"
|
||||||
|
|
||||||
type ServerACLDatabase interface {
|
type ServerACLDatabase interface {
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
// RoomsWithACLs returns all room IDs for rooms with ACLs
|
||||||
GetKnownRooms(ctx context.Context) ([]string, error)
|
RoomsWithACLs(ctx context.Context) ([]string, error)
|
||||||
|
|
||||||
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
|
||||||
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
|
||||||
@ -57,7 +57,7 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Look up all of the rooms that the current state server knows about.
|
// Look up all of the rooms that the current state server knows about.
|
||||||
rooms, err := db.GetKnownRooms(ctx)
|
rooms, err := db.RoomsWithACLs(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Fatalf("Failed to get known rooms")
|
logrus.WithError(err).Fatalf("Failed to get known rooms")
|
||||||
}
|
}
|
||||||
|
@ -116,7 +116,7 @@ var (
|
|||||||
|
|
||||||
type dummyACLDB struct{}
|
type dummyACLDB struct{}
|
||||||
|
|
||||||
func (d dummyACLDB) GetKnownRooms(ctx context.Context) ([]string, error) {
|
func (d dummyACLDB) RoomsWithACLs(ctx context.Context) ([]string, error) {
|
||||||
return []string{"1", "2"}, nil
|
return []string{"1", "2"}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,6 +86,9 @@ type RoomserverInternalAPI interface {
|
|||||||
req *QueryAuthChainRequest,
|
req *QueryAuthChainRequest,
|
||||||
res *QueryAuthChainResponse,
|
res *QueryAuthChainResponse,
|
||||||
) error
|
) error
|
||||||
|
|
||||||
|
// RoomsWithACLs returns all room IDs for rooms with ACLs
|
||||||
|
RoomsWithACLs(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRoomPrivateKeyCreator interface {
|
type UserRoomPrivateKeyCreator interface {
|
||||||
|
@ -1099,3 +1099,8 @@ func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID,
|
|||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RoomsWithACLs returns all room IDs for rooms with ACLs
|
||||||
|
func (r *Queryer) RoomsWithACLs(ctx context.Context) ([]string, error) {
|
||||||
|
return r.DB.RoomsWithACLs(ctx)
|
||||||
|
}
|
||||||
|
@ -1284,3 +1284,38 @@ func TestRoomConsumerRecreation(t *testing.T) {
|
|||||||
wantAckWait := input.MaximumMissingProcessingTime + (time.Second * 10)
|
wantAckWait := input.MaximumMissingProcessingTime + (time.Second * 10)
|
||||||
assert.Equal(t, wantAckWait, info.Config.AckWait)
|
assert.Equal(t, wantAckWait, info.Config.AckWait)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRoomsWithACLs(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
noACLRoom := test.NewRoom(t, alice)
|
||||||
|
aclRoom := test.NewRoom(t, alice)
|
||||||
|
|
||||||
|
aclRoom.CreateAndInsert(t, alice, "m.room.server_acl", map[string]any{
|
||||||
|
"deny": []string{"evilhost.test"},
|
||||||
|
"allow": []string{"*"},
|
||||||
|
}, test.WithStateKey(""))
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
cfg, processCtx, closeDB := testrig.CreateConfig(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
|
||||||
|
cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
|
||||||
|
natsInstance := &jetstream.NATSInstance{}
|
||||||
|
caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
|
||||||
|
// start JetStream listeners
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
|
||||||
|
rsAPI.SetFederationAPI(nil, nil)
|
||||||
|
|
||||||
|
for _, room := range []*test.Room{noACLRoom, aclRoom} {
|
||||||
|
// Create the rooms
|
||||||
|
err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we only have one ACLd room.
|
||||||
|
roomsWithACLs, err := rsAPI.RoomsWithACLs(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{aclRoom.ID}, roomsWithACLs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -170,8 +170,6 @@ type Database interface {
|
|||||||
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
|
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
|
||||||
// GetKnownUsers searches all users that userID knows about.
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
|
||||||
GetKnownRooms(ctx context.Context) ([]string, error)
|
|
||||||
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
|
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
|
||||||
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
|
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
|
||||||
|
|
||||||
@ -193,6 +191,9 @@ type Database interface {
|
|||||||
MaybeRedactEvent(
|
MaybeRedactEvent(
|
||||||
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI,
|
ctx context.Context, roomInfo *types.RoomInfo, eventNID types.EventNID, event gomatrixserverlib.PDU, plResolver state.PowerLevelResolver, querier api.QuerySenderIDAPI,
|
||||||
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
|
) (gomatrixserverlib.PDU, gomatrixserverlib.PDU, error)
|
||||||
|
|
||||||
|
// RoomsWithACLs returns all room IDs for rooms with ACLs
|
||||||
|
RoomsWithACLs(ctx context.Context) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UserRoomKeys interface {
|
type UserRoomKeys interface {
|
||||||
|
@ -68,6 +68,10 @@ CREATE TABLE IF NOT EXISTS roomserver_events (
|
|||||||
|
|
||||||
-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility)
|
-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility)
|
||||||
CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5);
|
CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5);
|
||||||
|
|
||||||
|
-- The following indexes are used by bulkSelectStateEventByNIDSQL
|
||||||
|
CREATE INDEX IF NOT EXISTS roomserver_event_event_type_nid_idx ON roomserver_events (event_type_nid);
|
||||||
|
CREATE INDEX IF NOT EXISTS roomserver_event_state_key_nid_idx ON roomserver_events (event_state_key_nid);
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertEventSQL = "" +
|
const insertEventSQL = "" +
|
||||||
@ -147,6 +151,8 @@ const selectRoomNIDsForEventNIDsSQL = "" +
|
|||||||
const selectEventRejectedSQL = "" +
|
const selectEventRejectedSQL = "" +
|
||||||
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"
|
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"
|
||||||
|
|
||||||
|
const selectRoomsWithEventTypeNIDSQL = `SELECT DISTINCT room_nid FROM roomserver_events WHERE event_type_nid = $1`
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
@ -166,6 +172,7 @@ type eventStatements struct {
|
|||||||
selectMaxEventDepthStmt *sql.Stmt
|
selectMaxEventDepthStmt *sql.Stmt
|
||||||
selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
||||||
selectEventRejectedStmt *sql.Stmt
|
selectEventRejectedStmt *sql.Stmt
|
||||||
|
selectRoomsWithEventTypeNIDStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateEventsTable(db *sql.DB) error {
|
func CreateEventsTable(db *sql.DB) error {
|
||||||
@ -206,6 +213,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
|
|||||||
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
|
{&s.selectMaxEventDepthStmt, selectMaxEventDepthSQL},
|
||||||
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
|
{&s.selectRoomNIDsForEventNIDsStmt, selectRoomNIDsForEventNIDsSQL},
|
||||||
{&s.selectEventRejectedStmt, selectEventRejectedSQL},
|
{&s.selectEventRejectedStmt, selectEventRejectedSQL},
|
||||||
|
{&s.selectRoomsWithEventTypeNIDStmt, selectRoomsWithEventTypeNIDSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -582,3 +590,25 @@ func (s *eventStatements) SelectEventRejected(
|
|||||||
err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected)
|
err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) SelectRoomsWithEventTypeNID(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID,
|
||||||
|
) ([]types.RoomNID, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithEventTypeNIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, eventTypeNID)
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithEventTypeNID: rows.close() failed")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return roomNIDs, rows.Err()
|
||||||
|
}
|
||||||
|
@ -76,9 +76,6 @@ const selectRoomVersionsForRoomNIDsSQL = "" +
|
|||||||
const selectRoomInfoSQL = "" +
|
const selectRoomInfoSQL = "" +
|
||||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||||
|
|
||||||
const selectRoomIDsSQL = "" +
|
|
||||||
"SELECT room_id FROM roomserver_rooms WHERE array_length(latest_event_nids, 1) > 0"
|
|
||||||
|
|
||||||
const bulkSelectRoomIDsSQL = "" +
|
const bulkSelectRoomIDsSQL = "" +
|
||||||
"SELECT room_id FROM roomserver_rooms WHERE room_nid = ANY($1)"
|
"SELECT room_id FROM roomserver_rooms WHERE room_nid = ANY($1)"
|
||||||
|
|
||||||
@ -94,7 +91,6 @@ type roomStatements struct {
|
|||||||
updateLatestEventNIDsStmt *sql.Stmt
|
updateLatestEventNIDsStmt *sql.Stmt
|
||||||
selectRoomVersionsForRoomNIDsStmt *sql.Stmt
|
selectRoomVersionsForRoomNIDsStmt *sql.Stmt
|
||||||
selectRoomInfoStmt *sql.Stmt
|
selectRoomInfoStmt *sql.Stmt
|
||||||
selectRoomIDsStmt *sql.Stmt
|
|
||||||
bulkSelectRoomIDsStmt *sql.Stmt
|
bulkSelectRoomIDsStmt *sql.Stmt
|
||||||
bulkSelectRoomNIDsStmt *sql.Stmt
|
bulkSelectRoomNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
@ -116,29 +112,11 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||||
{&s.selectRoomVersionsForRoomNIDsStmt, selectRoomVersionsForRoomNIDsSQL},
|
{&s.selectRoomVersionsForRoomNIDsStmt, selectRoomVersionsForRoomNIDsSQL},
|
||||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||||
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
|
||||||
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
|
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
|
||||||
{&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL},
|
{&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
|
||||||
rows, err := stmt.QueryContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
|
||||||
var roomIDs []string
|
|
||||||
var roomID string
|
|
||||||
for rows.Next() {
|
|
||||||
if err = rows.Scan(&roomID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
roomIDs = append(roomIDs, roomID)
|
|
||||||
}
|
|
||||||
return roomIDs, rows.Err()
|
|
||||||
}
|
|
||||||
func (s *roomStatements) InsertRoomNID(
|
func (s *roomStatements) InsertRoomNID(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
roomID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||||
|
@ -1625,9 +1625,24 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
|
|||||||
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
|
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
func (d *Database) RoomsWithACLs(ctx context.Context) ([]string, error) {
|
||||||
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
|
||||||
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
|
eventTypeNID, err := d.GetOrCreateEventTypeNID(ctx, "m.room.server_acl")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
roomNIDs, err := d.EventsTable.SelectRoomsWithEventTypeNID(ctx, nil, eventTypeNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return roomIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForgetRoom sets a users room to forgotten
|
// ForgetRoom sets a users room to forgotten
|
||||||
|
@ -44,6 +44,14 @@ const eventsSchema = `
|
|||||||
auth_event_nids TEXT NOT NULL DEFAULT '[]',
|
auth_event_nids TEXT NOT NULL DEFAULT '[]',
|
||||||
is_rejected BOOLEAN NOT NULL DEFAULT FALSE
|
is_rejected BOOLEAN NOT NULL DEFAULT FALSE
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility)
|
||||||
|
CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5);
|
||||||
|
|
||||||
|
-- The following indexes are used by bulkSelectStateEventByNIDSQL
|
||||||
|
CREATE INDEX IF NOT EXISTS roomserver_event_event_type_nid_idx ON roomserver_events (event_type_nid);
|
||||||
|
CREATE INDEX IF NOT EXISTS roomserver_event_state_key_nid_idx ON roomserver_events (event_state_key_nid);
|
||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertEventSQL = `
|
const insertEventSQL = `
|
||||||
@ -120,6 +128,8 @@ const selectRoomNIDsForEventNIDsSQL = "" +
|
|||||||
const selectEventRejectedSQL = "" +
|
const selectEventRejectedSQL = "" +
|
||||||
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"
|
"SELECT is_rejected FROM roomserver_events WHERE room_nid = $1 AND event_id = $2"
|
||||||
|
|
||||||
|
const selectRoomsWithEventTypeNIDSQL = `SELECT DISTINCT room_nid FROM roomserver_events WHERE event_type_nid = $1`
|
||||||
|
|
||||||
type eventStatements struct {
|
type eventStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
@ -135,6 +145,7 @@ type eventStatements struct {
|
|||||||
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
bulkSelectStateAtEventAndReferenceStmt *sql.Stmt
|
||||||
bulkSelectEventIDStmt *sql.Stmt
|
bulkSelectEventIDStmt *sql.Stmt
|
||||||
selectEventRejectedStmt *sql.Stmt
|
selectEventRejectedStmt *sql.Stmt
|
||||||
|
selectRoomsWithEventTypeNIDStmt *sql.Stmt
|
||||||
//bulkSelectEventNIDStmt *sql.Stmt
|
//bulkSelectEventNIDStmt *sql.Stmt
|
||||||
//bulkSelectUnsentEventNIDStmt *sql.Stmt
|
//bulkSelectUnsentEventNIDStmt *sql.Stmt
|
||||||
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
//selectRoomNIDsForEventNIDsStmt *sql.Stmt
|
||||||
@ -192,6 +203,7 @@ func PrepareEventsTable(db *sql.DB) (tables.Events, error) {
|
|||||||
//{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
|
//{&s.bulkSelectUnsentEventNIDStmt, bulkSelectUnsentEventNIDSQL},
|
||||||
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
|
//{&s.selectRoomNIDForEventNIDStmt, selectRoomNIDForEventNIDSQL},
|
||||||
{&s.selectEventRejectedStmt, selectEventRejectedSQL},
|
{&s.selectEventRejectedStmt, selectEventRejectedSQL},
|
||||||
|
{&s.selectRoomsWithEventTypeNIDStmt, selectRoomsWithEventTypeNIDSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -682,3 +694,25 @@ func (s *eventStatements) SelectEventRejected(
|
|||||||
err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected)
|
err = stmt.QueryRowContext(ctx, roomNID, eventID).Scan(&rejected)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *eventStatements) SelectRoomsWithEventTypeNID(
|
||||||
|
ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID,
|
||||||
|
) ([]types.RoomNID, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithEventTypeNIDStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, eventTypeNID)
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomsWithEventTypeNID: rows.close() failed")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return roomNIDs, rows.Err()
|
||||||
|
}
|
||||||
|
@ -65,9 +65,6 @@ const selectRoomVersionsForRoomNIDsSQL = "" +
|
|||||||
const selectRoomInfoSQL = "" +
|
const selectRoomInfoSQL = "" +
|
||||||
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
"SELECT room_version, room_nid, state_snapshot_nid, latest_event_nids FROM roomserver_rooms WHERE room_id = $1"
|
||||||
|
|
||||||
const selectRoomIDsSQL = "" +
|
|
||||||
"SELECT room_id FROM roomserver_rooms WHERE latest_event_nids != '[]'"
|
|
||||||
|
|
||||||
const bulkSelectRoomIDsSQL = "" +
|
const bulkSelectRoomIDsSQL = "" +
|
||||||
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||||
|
|
||||||
@ -87,7 +84,6 @@ type roomStatements struct {
|
|||||||
updateLatestEventNIDsStmt *sql.Stmt
|
updateLatestEventNIDsStmt *sql.Stmt
|
||||||
//selectRoomVersionForRoomNIDStmt *sql.Stmt
|
//selectRoomVersionForRoomNIDStmt *sql.Stmt
|
||||||
selectRoomInfoStmt *sql.Stmt
|
selectRoomInfoStmt *sql.Stmt
|
||||||
selectRoomIDsStmt *sql.Stmt
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateRoomsTable(db *sql.DB) error {
|
func CreateRoomsTable(db *sql.DB) error {
|
||||||
@ -108,29 +104,10 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||||||
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
{&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL},
|
||||||
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
|
//{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL},
|
||||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||||
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
|
||||||
{&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL},
|
{&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
|
||||||
rows, err := stmt.QueryContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
|
|
||||||
var roomIDs []string
|
|
||||||
var roomID string
|
|
||||||
for rows.Next() {
|
|
||||||
if err = rows.Scan(&roomID); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
roomIDs = append(roomIDs, roomID)
|
|
||||||
}
|
|
||||||
return roomIDs, rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
|
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
|
||||||
var info types.RoomInfo
|
var info types.RoomInfo
|
||||||
var latestNIDsJSON string
|
var latestNIDsJSON string
|
||||||
|
@ -2,6 +2,7 @@ package tables_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
@ -147,3 +148,38 @@ func Test_EventsTable(t *testing.T) {
|
|||||||
assert.Equal(t, int64(len(room.Events())+1), maxDepth)
|
assert.Equal(t, int64(len(room.Events())+1), maxDepth)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRoomsWithACL(t *testing.T) {
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
eventStateKeys, closeEventStateKeys := mustCreateEventTypesTable(t, dbType)
|
||||||
|
defer closeEventStateKeys()
|
||||||
|
|
||||||
|
eventsTable, closeEventsTable := mustCreateEventsTable(t, dbType)
|
||||||
|
defer closeEventsTable()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// insert the m.room.server_acl event type
|
||||||
|
eventTypeNID, err := eventStateKeys.InsertEventTypeNID(ctx, nil, "m.room.server_acl")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
// Create ACL'd rooms
|
||||||
|
var wantRoomNIDs []types.RoomNID
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
_, _, err = eventsTable.InsertEvent(ctx, nil, types.RoomNID(i), eventTypeNID, types.EmptyStateKeyNID, fmt.Sprintf("$1337+%d", i), nil, 0, false)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
wantRoomNIDs = append(wantRoomNIDs, types.RoomNID(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create non-ACL'd rooms (eventTypeNID+1)
|
||||||
|
for i := 10; i < 20; i++ {
|
||||||
|
_, _, err = eventsTable.InsertEvent(ctx, nil, types.RoomNID(i), eventTypeNID+1, types.EmptyStateKeyNID, fmt.Sprintf("$1337+%d", i), nil, 0, false)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotRoomNIDs, err := eventsTable.SelectRoomsWithEventTypeNID(ctx, nil, eventTypeNID)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, wantRoomNIDs, gotRoomNIDs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -69,6 +69,8 @@ type Events interface {
|
|||||||
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
|
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
|
||||||
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
|
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
|
||||||
SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error)
|
SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error)
|
||||||
|
|
||||||
|
SelectRoomsWithEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypeNID types.EventTypeNID) ([]types.RoomNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Rooms interface {
|
type Rooms interface {
|
||||||
@ -80,7 +82,6 @@ type Rooms interface {
|
|||||||
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
|
||||||
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
|
||||||
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
|
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
|
||||||
SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error)
|
|
||||||
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
|
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
|
||||||
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
|
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
|
||||||
}
|
}
|
||||||
|
@ -74,11 +74,6 @@ func TestRoomsTable(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Nil(t, roomInfo)
|
assert.Nil(t, roomInfo)
|
||||||
|
|
||||||
// There are no rooms with latestEventNIDs yet
|
|
||||||
roomIDs, err := tab.SelectRoomIDsWithEvents(ctx, nil)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, 0, len(roomIDs))
|
|
||||||
|
|
||||||
roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
|
roomVersions, err := tab.SelectRoomVersionsForRoomNIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, roomVersions[wantRoomNID], room.Version)
|
assert.Equal(t, roomVersions[wantRoomNID], room.Version)
|
||||||
@ -86,7 +81,7 @@ func TestRoomsTable(t *testing.T) {
|
|||||||
_, ok := roomVersions[1337]
|
_, ok := roomVersions[1337]
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
|
|
||||||
roomIDs, err = tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
|
roomIDs, err := tab.BulkSelectRoomIDs(ctx, nil, []types.RoomNID{wantRoomNID, 1337})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, []string{room.ID}, roomIDs)
|
assert.Equal(t, []string{room.ID}, roomIDs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user