From 33b8143a9597ff8c6b75ea47a588d50dc6e72259 Mon Sep 17 00:00:00 2001 From: Kegsay Date: Thu, 3 Sep 2020 18:27:02 +0100 Subject: [PATCH] Implement more CSS storage functions in roomserver (#1388) --- .../storage/postgres/membership_table.go | 59 +++++++ roomserver/storage/postgres/rooms_table.go | 26 +++ roomserver/storage/shared/storage.go | 150 +++++++++++++++--- .../storage/sqlite3/membership_table.go | 58 +++++++ roomserver/storage/sqlite3/rooms_table.go | 25 +++ roomserver/storage/tables/interface.go | 5 + 6 files changed, 303 insertions(+), 20 deletions(-) diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 0799647e..5164f654 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -18,7 +18,9 @@ package postgres import ( "context" "database/sql" + "fmt" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -62,6 +64,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ); ` +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -102,6 +108,16 @@ const updateMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +var selectKnownUsersSQL = "" + + "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE room_nid = ANY(" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -112,6 +128,8 @@ type membershipStatements struct { selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt + selectJoinedUsersSetForRoomsStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -131,6 +149,8 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, + {&s.selectKnownUsersStmt, selectKnownUsersSQL}, }.Prepare(db) } @@ -246,3 +266,42 @@ func (s *membershipStatements) SelectRoomsWithMembership( } return roomNIDs, nil } + +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { + roomIDarray := make([]int64, len(roomNIDs)) + for i := range roomNIDs { + roomIDarray[i] = int64(roomNIDs[i]) + } + rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) + for rows.Next() { + var userID types.EventStateKeyNID + var count int + if err := rows.Scan(&userID, &count); err != nil { + return nil, err + } + result[userID] = count + } + return result, rows.Err() +} + +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 9d359146..ef1b7891 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -81,6 +81,9 @@ const selectRoomIDsSQL = "" + const bulkSelectRoomIDsSQL = "" + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" +const bulkSelectRoomNIDsSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt @@ -91,6 +94,7 @@ type roomStatements struct { selectRoomInfoStmt *sql.Stmt selectRoomIDsStmt *sql.Stmt bulkSelectRoomIDsStmt *sql.Stmt + bulkSelectRoomNIDsStmt *sql.Stmt } func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { @@ -109,6 +113,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) { {&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL}, {&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL}, + {&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL}, }.Prepare(db) } @@ -245,3 +250,24 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types } return roomIDs, nil } + +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { + var array pq.StringArray + for _, roomID := range roomIDs { + array = append(array, roomID) + } + rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err = rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 5c447d66..a3b33a4f 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" + "sort" csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables" "github.com/matrix-org/dendrite/internal/caching" @@ -13,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -717,25 +719,42 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event { // If no event could be found, returns nil // If there was an issue during the retrieval, returns an error func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) { - /* - roomInfo, err := d.RoomInfo(ctx, roomID) - if err != nil { - return nil, err + roomInfo, err := d.RoomInfo(ctx, roomID) + if err != nil { + return nil, err + } + eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) + if err != nil { + return nil, err + } + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) + if err != nil { + return nil, err + } + entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + if err != nil { + return nil, err + } + // return the event requested + for _, e := range entries { + if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID { + data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID}) + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID) + } + ev, err := gomatrixserverlib.NewEventFromTrustedJSON(data[0].EventJSON, false, roomInfo.RoomVersion) + if err != nil { + return nil, err + } + h := ev.Headered(roomInfo.RoomVersion) + return &h, nil } - eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType) - if err != nil { - return nil, err - } - stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey) - if err != nil { - return nil, err - } - blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID}) - if err != nil { - return nil, err - } - */ - return nil, nil + } + + return nil, fmt.Errorf("GetStateEvent: no event type '%s' with key '%s' exists in room %s", evType, stateKey, roomID) } // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). @@ -779,15 +798,106 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear. func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) { - return nil, fmt.Errorf("not implemented yet") + roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs) + if err != nil { + return nil, err + } + userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs) + if err != nil { + return nil, err + } + stateKeyNIDs := make([]types.EventStateKeyNID, len(userNIDToCount)) + i := 0 + for nid := range userNIDToCount { + stateKeyNIDs[i] = nid + i++ + } + nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs) + if err != nil { + return nil, err + } + if len(nidToUserID) != len(userNIDToCount) { + return nil, fmt.Errorf("found %d users but only have state key nids for %d of them", len(userNIDToCount), len(nidToUserID)) + } + result := make(map[string]int, len(userNIDToCount)) + for nid, count := range userNIDToCount { + result[nidToUserID[nid]] = count + } + return result, nil } // GetKnownUsers searches all users that userID knows about. func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) { - return nil, fmt.Errorf("not implemented yet") + stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID) + if err != nil { + return nil, err + } + return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit) } // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDs(ctx) } + +// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops +// it should live in this package! + +func (d *Database) loadStateAtSnapshot( + ctx context.Context, stateNID types.StateSnapshotNID, +) ([]types.StateEntry, error) { + stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + if err != nil { + return nil, err + } + // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. + stateBlockNIDList := stateBlockNIDLists[0] + + stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) + if err != nil { + return nil, err + } + stateEntriesMap := stateEntryListMap(stateEntryLists) + + // Combine all the state entries for this snapshot. + // The order of state block NIDs in the list tells us the order to combine them in. + var fullState []types.StateEntry + for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + entries, ok := stateEntriesMap.lookup(stateBlockNID) + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) + } + fullState = append(fullState, entries...) + } + + // Stable sort so that the most recent entry for each state key stays + // remains later in the list than the older entries for the same state key. + sort.Stable(stateEntryByStateKeySorter(fullState)) + // Unique returns the last entry and hence the most recent entry for each state key. + fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] + return fullState, nil +} + +type stateEntryListMap []types.StateEntryList + +func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { + list := []types.StateEntryList(m) + i := sort.Search(len(list), func(i int) bool { + return list[i].StateBlockNID >= stateBlockNID + }) + if i < len(list) && list[i].StateBlockNID == stateBlockNID { + ok = true + stateEntries = list[i].StateEntries + } + return +} + +type stateEntryByStateKeySorter []types.StateEntry + +func (s stateEntryByStateKeySorter) Len() int { return len(s) } +func (s stateEntryByStateKeySorter) Less(i, j int) bool { + return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) +} +func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index e850c80b..0d5ce516 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -18,6 +18,8 @@ package sqlite3 import ( "context" "database/sql" + "fmt" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -38,6 +40,10 @@ const membershipSchema = ` ); ` +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -78,6 +84,16 @@ const updateMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +var selectKnownUsersSQL = "" + + "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE room_nid IN (" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -89,6 +105,7 @@ type membershipStatements struct { selectLocalMembershipsFromRoomStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt updateMembershipStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -110,6 +127,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectKnownUsersStmt, selectKnownUsersSQL}, }.Prepare(db) } @@ -227,3 +245,43 @@ func (s *membershipStatements) SelectRoomsWithMembership( } return roomNIDs, nil } + +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { + iRoomNIDs := make([]interface{}, len(roomNIDs)) + for i, v := range roomNIDs { + iRoomNIDs[i] = v + } + query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1) + rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) + for rows.Next() { + var userID types.EventStateKeyNID + var count int + if err := rows.Scan(&userID, &count); err != nil { + return nil, err + } + result[userID] = count + } + return result, rows.Err() +} + +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index daacf86f..b4564aff 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -72,6 +72,9 @@ const selectRoomIDsSQL = "" + const bulkSelectRoomIDsSQL = "" + "SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)" +const bulkSelectRoomNIDsSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt @@ -252,3 +255,25 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types } return roomIDs, nil } + +func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) { + iRoomIDs := make([]interface{}, len(roomIDs)) + for i, v := range roomIDs { + iRoomIDs[i] = v + } + sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) + rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed") + var roomNIDs []types.RoomNID + for rows.Next() { + var roomNID types.RoomNID + if err = rows.Scan(&roomNID); err != nil { + return nil, err + } + roomNIDs = append(roomNIDs, roomNID) + } + return roomNIDs, nil +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 126c27b5..a142f2b1 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -67,6 +67,7 @@ type Rooms interface { SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) SelectRoomIDs(ctx context.Context) ([]string, error) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) + BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) } type Transactions interface { @@ -123,6 +124,10 @@ type Membership interface { SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error) UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error) + // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the + // counts of how many rooms they are joined. + SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) + SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) } type Published interface {