mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-22 11:41:38 +00:00
Implement more CSS storage functions in roomserver (#1388)
This commit is contained in:
parent
b20386123e
commit
33b8143a95
@ -18,7 +18,9 @@ package postgres
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
@ -62,6 +64,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
|
||||||
|
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid"
|
||||||
|
|
||||||
// Insert a row in to membership table so that it can be locked by the
|
// Insert a row in to membership table so that it can be locked by the
|
||||||
// SELECT FOR UPDATE
|
// SELECT FOR UPDATE
|
||||||
const insertMembershipSQL = "" +
|
const insertMembershipSQL = "" +
|
||||||
@ -102,6 +108,16 @@ const updateMembershipSQL = "" +
|
|||||||
const selectRoomsWithMembershipSQL = "" +
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
|
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||||
|
// joined to. Since this information is used to populate the user directory, we will
|
||||||
|
// only return users that the user would ordinarily be able to see anyway.
|
||||||
|
var selectKnownUsersSQL = "" +
|
||||||
|
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
|
||||||
|
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||||
|
" WHERE room_nid = ANY(" +
|
||||||
|
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
selectMembershipForUpdateStmt *sql.Stmt
|
selectMembershipForUpdateStmt *sql.Stmt
|
||||||
@ -112,6 +128,8 @@ type membershipStatements struct {
|
|||||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||||
updateMembershipStmt *sql.Stmt
|
updateMembershipStmt *sql.Stmt
|
||||||
selectRoomsWithMembershipStmt *sql.Stmt
|
selectRoomsWithMembershipStmt *sql.Stmt
|
||||||
|
selectJoinedUsersSetForRoomsStmt *sql.Stmt
|
||||||
|
selectKnownUsersStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
@ -131,6 +149,8 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
|
|||||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||||
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||||
|
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
|
||||||
|
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -246,3 +266,42 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||||||
}
|
}
|
||||||
return roomNIDs, nil
|
return roomNIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||||
|
roomIDarray := make([]int64, len(roomNIDs))
|
||||||
|
for i := range roomNIDs {
|
||||||
|
roomIDarray[i] = int64(roomNIDs[i])
|
||||||
|
}
|
||||||
|
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||||
|
result := make(map[types.EventStateKeyNID]int)
|
||||||
|
for rows.Next() {
|
||||||
|
var userID types.EventStateKeyNID
|
||||||
|
var count int
|
||||||
|
if err := rows.Scan(&userID, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[userID] = count
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||||
|
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := []string{}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, userID)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
@ -81,6 +81,9 @@ const selectRoomIDsSQL = "" +
|
|||||||
const bulkSelectRoomIDsSQL = "" +
|
const bulkSelectRoomIDsSQL = "" +
|
||||||
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||||
|
|
||||||
|
const bulkSelectRoomNIDsSQL = "" +
|
||||||
|
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
selectRoomNIDStmt *sql.Stmt
|
selectRoomNIDStmt *sql.Stmt
|
||||||
@ -91,6 +94,7 @@ type roomStatements struct {
|
|||||||
selectRoomInfoStmt *sql.Stmt
|
selectRoomInfoStmt *sql.Stmt
|
||||||
selectRoomIDsStmt *sql.Stmt
|
selectRoomIDsStmt *sql.Stmt
|
||||||
bulkSelectRoomIDsStmt *sql.Stmt
|
bulkSelectRoomIDsStmt *sql.Stmt
|
||||||
|
bulkSelectRoomNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
||||||
@ -109,6 +113,7 @@ func NewPostgresRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||||||
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
{&s.selectRoomInfoStmt, selectRoomInfoSQL},
|
||||||
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
{&s.selectRoomIDsStmt, selectRoomIDsSQL},
|
||||||
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
|
{&s.bulkSelectRoomIDsStmt, bulkSelectRoomIDsSQL},
|
||||||
|
{&s.bulkSelectRoomNIDsStmt, bulkSelectRoomNIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -245,3 +250,24 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
|
|||||||
}
|
}
|
||||||
return roomIDs, nil
|
return roomIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||||
|
var array pq.StringArray
|
||||||
|
for _, roomID := range roomIDs {
|
||||||
|
array = append(array, roomID)
|
||||||
|
}
|
||||||
|
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
if err = rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
return roomNIDs, nil
|
||||||
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
|
csstables "github.com/matrix-org/dendrite/currentstateserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
@ -13,6 +14,7 @@ import (
|
|||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -717,25 +719,42 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
|
|||||||
// If no event could be found, returns nil
|
// If no event could be found, returns nil
|
||||||
// If there was an issue during the retrieval, returns an error
|
// If there was an issue during the retrieval, returns an error
|
||||||
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
|
func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||||
/*
|
roomInfo, err := d.RoomInfo(ctx, roomID)
|
||||||
roomInfo, err := d.RoomInfo(ctx, roomID)
|
if err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
return nil, err
|
}
|
||||||
|
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// return the event requested
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
|
||||||
|
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
return nil, fmt.Errorf("GetStateEvent: no json for event nid %d", e.EventNID)
|
||||||
|
}
|
||||||
|
ev, err := gomatrixserverlib.NewEventFromTrustedJSON(data[0].EventJSON, false, roomInfo.RoomVersion)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
h := ev.Headered(roomInfo.RoomVersion)
|
||||||
|
return &h, nil
|
||||||
}
|
}
|
||||||
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, fmt.Errorf("GetStateEvent: no event type '%s' with key '%s' exists in room %s", evType, stateKey, roomID)
|
||||||
}
|
|
||||||
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, stateKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
blockNIDs, err := d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
return nil, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
|
||||||
@ -779,15 +798,106 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
|||||||
|
|
||||||
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
|
||||||
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
|
||||||
return nil, fmt.Errorf("not implemented yet")
|
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateKeyNIDs := make([]types.EventStateKeyNID, len(userNIDToCount))
|
||||||
|
i := 0
|
||||||
|
for nid := range userNIDToCount {
|
||||||
|
stateKeyNIDs[i] = nid
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nidToUserID) != len(userNIDToCount) {
|
||||||
|
return nil, fmt.Errorf("found %d users but only have state key nids for %d of them", len(userNIDToCount), len(nidToUserID))
|
||||||
|
}
|
||||||
|
result := make(map[string]int, len(userNIDToCount))
|
||||||
|
for nid, count := range userNIDToCount {
|
||||||
|
result[nidToUserID[nid]] = count
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKnownUsers searches all users that userID knows about.
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
|
func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) {
|
||||||
return nil, fmt.Errorf("not implemented yet")
|
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||||
return d.RoomsTable.SelectRoomIDs(ctx)
|
return d.RoomsTable.SelectRoomIDs(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
|
||||||
|
// it should live in this package!
|
||||||
|
|
||||||
|
func (d *Database) loadStateAtSnapshot(
|
||||||
|
ctx context.Context, stateNID types.StateSnapshotNID,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
||||||
|
stateBlockNIDList := stateBlockNIDLists[0]
|
||||||
|
|
||||||
|
stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
||||||
|
|
||||||
|
// Combine all the state entries for this snapshot.
|
||||||
|
// The order of state block NIDs in the list tells us the order to combine them in.
|
||||||
|
var fullState []types.StateEntry
|
||||||
|
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
||||||
|
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
||||||
|
if !ok {
|
||||||
|
// This should only get hit if the database is corrupt.
|
||||||
|
// It should be impossible for an event to reference a NID that doesn't exist
|
||||||
|
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
|
||||||
|
}
|
||||||
|
fullState = append(fullState, entries...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stable sort so that the most recent entry for each state key stays
|
||||||
|
// remains later in the list than the older entries for the same state key.
|
||||||
|
sort.Stable(stateEntryByStateKeySorter(fullState))
|
||||||
|
// Unique returns the last entry and hence the most recent entry for each state key.
|
||||||
|
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
|
||||||
|
return fullState, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateEntryListMap []types.StateEntryList
|
||||||
|
|
||||||
|
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
|
||||||
|
list := []types.StateEntryList(m)
|
||||||
|
i := sort.Search(len(list), func(i int) bool {
|
||||||
|
return list[i].StateBlockNID >= stateBlockNID
|
||||||
|
})
|
||||||
|
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
|
||||||
|
ok = true
|
||||||
|
stateEntries = list[i].StateEntries
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type stateEntryByStateKeySorter []types.StateEntry
|
||||||
|
|
||||||
|
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
|
||||||
|
func (s stateEntryByStateKeySorter) Less(i, j int) bool {
|
||||||
|
return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
|
||||||
|
}
|
||||||
|
func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
@ -18,6 +18,8 @@ package sqlite3
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
@ -38,6 +40,10 @@ const membershipSchema = `
|
|||||||
);
|
);
|
||||||
`
|
`
|
||||||
|
|
||||||
|
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||||
|
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
|
||||||
|
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid"
|
||||||
|
|
||||||
// Insert a row in to membership table so that it can be locked by the
|
// Insert a row in to membership table so that it can be locked by the
|
||||||
// SELECT FOR UPDATE
|
// SELECT FOR UPDATE
|
||||||
const insertMembershipSQL = "" +
|
const insertMembershipSQL = "" +
|
||||||
@ -78,6 +84,16 @@ const updateMembershipSQL = "" +
|
|||||||
const selectRoomsWithMembershipSQL = "" +
|
const selectRoomsWithMembershipSQL = "" +
|
||||||
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
|
// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
|
||||||
|
// joined to. Since this information is used to populate the user directory, we will
|
||||||
|
// only return users that the user would ordinarily be able to see anyway.
|
||||||
|
var selectKnownUsersSQL = "" +
|
||||||
|
"SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
|
||||||
|
"roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||||
|
" WHERE room_nid IN (" +
|
||||||
|
" SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
|
||||||
|
") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
@ -89,6 +105,7 @@ type membershipStatements struct {
|
|||||||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||||
selectRoomsWithMembershipStmt *sql.Stmt
|
selectRoomsWithMembershipStmt *sql.Stmt
|
||||||
updateMembershipStmt *sql.Stmt
|
updateMembershipStmt *sql.Stmt
|
||||||
|
selectKnownUsersStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
@ -110,6 +127,7 @@ func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
|
|||||||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||||
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||||
|
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,3 +245,43 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||||||
}
|
}
|
||||||
return roomNIDs, nil
|
return roomNIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||||
|
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||||
|
for i, v := range roomNIDs {
|
||||||
|
iRoomNIDs[i] = v
|
||||||
|
}
|
||||||
|
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
|
||||||
|
result := make(map[types.EventStateKeyNID]int)
|
||||||
|
for rows.Next() {
|
||||||
|
var userID types.EventStateKeyNID
|
||||||
|
var count int
|
||||||
|
if err := rows.Scan(&userID, &count); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result[userID] = count
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||||
|
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := []string{}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
|
||||||
|
for rows.Next() {
|
||||||
|
var userID string
|
||||||
|
if err := rows.Scan(&userID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, userID)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
@ -72,6 +72,9 @@ const selectRoomIDsSQL = "" +
|
|||||||
const bulkSelectRoomIDsSQL = "" +
|
const bulkSelectRoomIDsSQL = "" +
|
||||||
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
"SELECT room_id FROM roomserver_rooms WHERE room_nid IN ($1)"
|
||||||
|
|
||||||
|
const bulkSelectRoomNIDsSQL = "" +
|
||||||
|
"SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)"
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertRoomNIDStmt *sql.Stmt
|
insertRoomNIDStmt *sql.Stmt
|
||||||
@ -252,3 +255,25 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
|
|||||||
}
|
}
|
||||||
return roomIDs, nil
|
return roomIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||||
|
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||||
|
for i, v := range roomIDs {
|
||||||
|
iRoomIDs[i] = v
|
||||||
|
}
|
||||||
|
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
||||||
|
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
|
||||||
|
var roomNIDs []types.RoomNID
|
||||||
|
for rows.Next() {
|
||||||
|
var roomNID types.RoomNID
|
||||||
|
if err = rows.Scan(&roomNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
roomNIDs = append(roomNIDs, roomNID)
|
||||||
|
}
|
||||||
|
return roomNIDs, nil
|
||||||
|
}
|
||||||
|
@ -67,6 +67,7 @@ type Rooms interface {
|
|||||||
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
|
||||||
SelectRoomIDs(ctx context.Context) ([]string, error)
|
SelectRoomIDs(ctx context.Context) ([]string, error)
|
||||||
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
|
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
|
||||||
|
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Transactions interface {
|
type Transactions interface {
|
||||||
@ -123,6 +124,10 @@ type Membership interface {
|
|||||||
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
|
||||||
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
|
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID) error
|
||||||
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
|
||||||
|
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
|
||||||
|
// counts of how many rooms they are joined.
|
||||||
|
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
|
||||||
|
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Published interface {
|
type Published interface {
|
||||||
|
Loading…
Reference in New Issue
Block a user