[pseudoID] More pseudo ID fixes (#3167)

Signed-off-by: `Sam Wedgwood <sam@wedgwood.dev>`
This commit is contained in:
Sam Wedgwood 2023-08-15 12:37:04 +01:00 committed by GitHub
parent fa6c7ba456
commit 9a12420428
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 472 additions and 237 deletions

View File

@ -33,23 +33,36 @@ func GetJoinedRooms(
device *userapi.Device, device *userapi.Device,
rsAPI api.ClientRoomserverAPI, rsAPI api.ClientRoomserverAPI,
) util.JSONResponse { ) util.JSONResponse {
var res api.QueryRoomsForUserResponse deviceUserID, err := spec.NewUserID(device.UserID, true)
err := rsAPI.QueryRoomsForUser(req.Context(), &api.QueryRoomsForUserRequest{ if err != nil {
UserID: device.UserID, util.GetLogger(req.Context()).WithError(err).Error("Invalid device user ID")
WantMembership: "join", return util.JSONResponse{
}, &res) Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
rooms, err := rsAPI.QueryRoomsForUser(req.Context(), *deviceUserID, "join")
if err != nil { if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(req.Context()).WithError(err).Error("QueryRoomsForUser failed")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{}, JSON: spec.Unknown("internal server error"),
} }
} }
if res.RoomIDs == nil {
res.RoomIDs = []string{} var roomIDStrs []string
if rooms == nil {
roomIDStrs = []string{}
} else {
roomIDStrs = make([]string, len(rooms))
for i, roomID := range rooms {
roomIDStrs[i] = roomID.String()
} }
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: getJoinedRoomsResponse{res.RoomIDs}, JSON: getJoinedRoomsResponse{roomIDStrs},
} }
} }

View File

@ -251,11 +251,15 @@ func updateProfile(
profile *authtypes.Profile, profile *authtypes.Profile,
userID string, evTime time.Time, userID string, evTime time.Time,
) (util.JSONResponse, error) { ) (util.JSONResponse, error) {
var res api.QueryRoomsForUserResponse deviceUserID, err := spec.NewUserID(device.UserID, true)
err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ if err != nil {
UserID: device.UserID, return util.JSONResponse{
WantMembership: "join", Code: http.StatusInternalServerError,
}, &res) JSON: spec.Unknown("internal server error"),
}, err
}
rooms, err := rsAPI.QueryRoomsForUser(ctx, *deviceUserID, "join")
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed") util.GetLogger(ctx).WithError(err).Error("QueryRoomsForUser failed")
return util.JSONResponse{ return util.JSONResponse{
@ -264,6 +268,11 @@ func updateProfile(
}, err }, err
} }
roomIDStrs := make([]string, len(rooms))
for i, room := range rooms {
roomIDStrs[i] = room.String()
}
_, domain, err := gomatrixserverlib.SplitID('@', userID) _, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil { if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
@ -274,7 +283,7 @@ func updateProfile(
} }
events, err := buildMembershipEvents( events, err := buildMembershipEvents(
ctx, res.RoomIDs, *profile, userID, evTime, rsAPI, ctx, roomIDStrs, *profile, userID, evTime, rsAPI,
) )
switch e := err.(type) { switch e := err.(type) {
case nil: case nil:

View File

@ -316,10 +316,17 @@ func generateSendEvent(
} }
} }
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID) senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
if err != nil || senderID == nil { if err != nil {
return nil, &util.JSONResponse{ return nil, &util.JSONResponse{
Code: http.StatusNotFound, Code: http.StatusInternalServerError,
JSON: spec.NotFound("Unable to find senderID for user"), JSON: spec.NotFound("internal server error"),
}
} else if senderID == nil {
// TODO: is it always the case that lack of a sender ID means they're not joined?
// And should this logic be deferred to the roomserver somehow?
return nil, &util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("not joined to room"),
} }
} }

View File

@ -94,34 +94,42 @@ func SendServerNotice(
} }
} }
userID, err := spec.NewUserID(r.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid user ID"),
}
}
// get rooms for specified user // get rooms for specified user
allUserRooms := []string{} allUserRooms := []spec.RoomID{}
userRooms := api.QueryRoomsForUserResponse{}
// Get rooms the user is either joined, invited or has left. // Get rooms the user is either joined, invited or has left.
for _, membership := range []string{"join", "invite", "leave"} { for _, membership := range []string{"join", "invite", "leave"} {
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ userRooms, queryErr := rsAPI.QueryRoomsForUser(ctx, *userID, membership)
UserID: r.UserID, if queryErr != nil {
WantMembership: membership,
}, &userRooms); err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
allUserRooms = append(allUserRooms, userRooms.RoomIDs...) allUserRooms = append(allUserRooms, userRooms...)
} }
// get rooms of the sender // get rooms of the sender
senderUserID := fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName) senderUserID, err := spec.NewUserID(fmt.Sprintf("@%s:%s", cfgNotices.LocalPart, cfgClient.Matrix.ServerName), true)
senderRooms := api.QueryRoomsForUserResponse{} if err != nil {
if err := rsAPI.QueryRoomsForUser(ctx, &api.QueryRoomsForUserRequest{ return util.JSONResponse{
UserID: senderUserID, Code: http.StatusInternalServerError,
WantMembership: "join", JSON: spec.Unknown("internal server error"),
}, &senderRooms); err != nil { }
}
senderRooms, err := rsAPI.QueryRoomsForUser(ctx, *senderUserID, "join")
if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
// check if we have rooms in common // check if we have rooms in common
commonRooms := []string{} commonRooms := []spec.RoomID{}
for _, userRoomID := range allUserRooms { for _, userRoomID := range allUserRooms {
for _, senderRoomID := range senderRooms.RoomIDs { for _, senderRoomID := range senderRooms {
if userRoomID == senderRoomID { if userRoomID == senderRoomID {
commonRooms = append(commonRooms, senderRoomID) commonRooms = append(commonRooms, senderRoomID)
} }
@ -139,7 +147,7 @@ func SendServerNotice(
// create a new room for the user // create a new room for the user
if len(commonRooms) == 0 { if len(commonRooms) == 0 {
powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID) powerLevelContent := eventutil.InitialPowerLevelsContent(senderUserID.String())
powerLevelContent.Users[r.UserID] = -10 // taken from Synapse powerLevelContent.Users[r.UserID] = -10 // taken from Synapse
pl, err := json.Marshal(powerLevelContent) pl, err := json.Marshal(powerLevelContent)
if err != nil { if err != nil {
@ -195,7 +203,7 @@ func SendServerNotice(
} }
} }
roomID = commonRooms[0] roomID = commonRooms[0].String()
membershipRes := api.QueryMembershipForUserResponse{} membershipRes := api.QueryMembershipForUserResponse{}
err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes) err = rsAPI.QueryMembershipForUser(ctx, &api.QueryMembershipForUserRequest{UserID: *deviceUserID, RoomID: roomID}, &membershipRes)
if err != nil { if err != nil {

View File

@ -117,19 +117,27 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
return true return true
} }
var queryRes roomserverAPI.QueryRoomsForUserResponse userID, err := spec.NewUserID(m.UserID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: m.UserID, sentry.CaptureException(err)
WantMembership: "join", logger.WithError(err).Error("invalid user ID")
}, &queryRes) return true
}
roomIDs, err := t.rsAPI.QueryRoomsForUser(t.ctx, *userID, "join")
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined rooms for user") logger.WithError(err).Error("failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(roomIDs))
for i, room := range roomIDs {
roomIDStrs[i] = room.String()
}
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
@ -179,18 +187,27 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
} }
logger := logrus.WithField("user_id", output.UserID) logger := logrus.WithField("user_id", output.UserID)
var queryRes roomserverAPI.QueryRoomsForUserResponse outputUserID, err := spec.NewUserID(output.UserID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: output.UserID, sentry.CaptureException(err)
WantMembership: "join", logrus.WithError(err).Errorf("invalid user ID")
}, &queryRes) return true
}
rooms, err := t.rsAPI.QueryRoomsForUser(t.ctx, *outputUserID, "join")
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(rooms))
for i, room := range rooms {
roomIDStrs[i] = room.String()
}
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")

View File

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -94,16 +95,23 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
return true return true
} }
var queryRes roomserverAPI.QueryRoomsForUserResponse parsedUserID, err := spec.NewUserID(userID, true)
err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ if err != nil {
UserID: userID, util.GetLogger(ctx).WithError(err).WithField("user_id", userID).Error("invalid user ID")
WantMembership: "join", return true
}, &queryRes) }
roomIDs, err := t.rsAPI.QueryRoomsForUser(t.ctx, *parsedUserID, "join")
if err != nil { if err != nil {
log.WithError(err).Error("failed to calculate joined rooms for user") log.WithError(err).Error("failed to calculate joined rooms for user")
return true return true
} }
roomIDStrs := make([]string, len(roomIDs))
for i, roomID := range roomIDs {
roomIDStrs[i] = roomID.String()
}
presence := msg.Header.Get("presence") presence := msg.Header.Get("presence")
ts, err := strconv.Atoi(msg.Header.Get("last_active_ts")) ts, err := strconv.Atoi(msg.Header.Get("last_active_ts"))
@ -112,7 +120,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
} }
// send this presence to all servers who share rooms with this user. // send this presence to all servers who share rooms with this user.
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) joined, err := t.db.GetJoinedHostsForRooms(t.ctx, roomIDStrs, true, true)
if err != nil { if err != nil {
log.WithError(err).Error("failed to get joined hosts") log.WithError(err).Error("failed to get joined hosts")
return true return true

View File

@ -33,7 +33,7 @@ import (
type fedRoomserverAPI struct { type fedRoomserverAPI struct {
rsapi.FederationRoomserverAPI rsapi.FederationRoomserverAPI
inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) inputRoomEvents func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse)
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error queryRoomsForUser func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
} }
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
@ -54,11 +54,11 @@ func (f *fedRoomserverAPI) InputRoomEvents(ctx context.Context, req *rsapi.Input
} }
// keychange consumer calls this // keychange consumer calls this
func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { func (f *fedRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if f.queryRoomsForUser == nil { if f.queryRoomsForUser == nil {
return nil return nil, nil
} }
return f.queryRoomsForUser(ctx, req, res) return f.queryRoomsForUser(ctx, userID, desiredMembership)
} }
// TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate // TODO: This struct isn't generic, only works for TestFederationAPIJoinThenKeyUpdate
@ -199,18 +199,22 @@ func testFederationAPIJoinThenKeyUpdate(t *testing.T, dbType test.DBType) {
fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID) fmt.Printf("creator: %v joining user: %v\n", creator.ID, joiningUser.ID)
room := test.NewRoom(t, creator) room := test.NewRoom(t, creator)
roomID, err := spec.NewRoomID(room.ID)
if err != nil {
t.Fatalf("Invalid room ID: %q", roomID)
}
rsapi := &fedRoomserverAPI{ rsapi := &fedRoomserverAPI{
inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) { inputRoomEvents: func(ctx context.Context, req *rsapi.InputRoomEventsRequest, res *rsapi.InputRoomEventsResponse) {
if req.Asynchronous { if req.Asynchronous {
t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous") t.Errorf("InputRoomEvents from PerformJoin MUST be synchronous")
} }
}, },
queryRoomsForUser: func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error { queryRoomsForUser: func(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
if req.UserID == joiningUser.ID && req.WantMembership == "join" { if userID.String() == joiningUser.ID && desiredMembership == "join" {
res.RoomIDs = []string{room.ID} return []spec.RoomID{*roomID}, nil
return nil
} }
return fmt.Errorf("unexpected queryRoomsForUser: %+v", *req) return nil, fmt.Errorf("unexpected queryRoomsForUser: %v, %v", userID, desiredMembership)
}, },
} }
fc := &fedClient{ fc := &fedClient{

View File

@ -141,11 +141,28 @@ type QueryRoomHierarchyAPI interface {
QueryNextRoomHierarchyPage(ctx context.Context, walker RoomHierarchyWalker, limit int) ([]fclient.RoomHierarchyRoom, *RoomHierarchyWalker, error) QueryNextRoomHierarchyPage(ctx context.Context, walker RoomHierarchyWalker, limit int) ([]fclient.RoomHierarchyRoom, *RoomHierarchyWalker, error)
} }
type QueryMembershipAPI interface {
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to *types.HeaderedEvent of membership events.
QueryMembershipAtEvent(
ctx context.Context,
roomID spec.RoomID,
eventIDs []string,
senderID spec.SenderID,
) (map[string]*types.HeaderedEvent, error)
}
// API functions required by the syncapi // API functions required by the syncapi
type SyncRoomserverAPI interface { type SyncRoomserverAPI interface {
QueryLatestEventsAndStateAPI QueryLatestEventsAndStateAPI
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
QueryMembershipAPI
// QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user.
QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
@ -155,12 +172,6 @@ type SyncRoomserverAPI interface {
req *QueryEventsByIDRequest, req *QueryEventsByIDRequest,
res *QueryEventsByIDResponse, res *QueryEventsByIDResponse,
) error ) error
// Query the membership event for an user for a room.
QueryMembershipForUser(
ctx context.Context,
req *QueryMembershipForUserRequest,
res *QueryMembershipForUserResponse,
) error
// Query the state after a list of events in a room from the room server. // Query the state after a list of events in a room from the room server.
QueryStateAfterEvents( QueryStateAfterEvents(
@ -175,14 +186,6 @@ type SyncRoomserverAPI interface {
req *PerformBackfillRequest, req *PerformBackfillRequest,
res *PerformBackfillResponse, res *PerformBackfillResponse,
) error ) error
// QueryMembershipAtEvent queries the memberships at the given events.
// Returns a map from eventID to a slice of types.HeaderedEvent.
QueryMembershipAtEvent(
ctx context.Context,
request *QueryMembershipAtEventRequest,
response *QueryMembershipAtEventResponse,
) error
} }
type AppserviceRoomserverAPI interface { type AppserviceRoomserverAPI interface {
@ -219,7 +222,7 @@ type ClientRoomserverAPI interface {
DefaultRoomVersionAPI DefaultRoomVersionAPI
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error QueryStateAfterEvents(ctx context.Context, req *QueryStateAfterEventsRequest, res *QueryStateAfterEventsResponse) error
// QueryKnownUsers returns a list of users that we know about from our joined rooms. // QueryKnownUsers returns a list of users that we know about from our joined rooms.
QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error QueryKnownUsers(ctx context.Context, req *QueryKnownUsersRequest, res *QueryKnownUsersResponse) error
@ -278,15 +281,12 @@ type FederationRoomserverAPI interface {
QueryBulkStateContentAPI QueryBulkStateContentAPI
QuerySenderIDAPI QuerySenderIDAPI
QueryRoomHierarchyAPI QueryRoomHierarchyAPI
QueryMembershipAPI
UserRoomPrivateKeyCreator UserRoomPrivateKeyCreator
AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error) AssignRoomNID(ctx context.Context, roomID spec.RoomID, roomVersion gomatrixserverlib.RoomVersion) (roomNID types.RoomNID, err error)
SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error) SigningIdentityFor(ctx context.Context, roomID spec.RoomID, senderID spec.UserID) (fclient.SigningIdentity, error)
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipForSenderID(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID. // which room to use by querying the first events roomID.
@ -300,7 +300,7 @@ type FederationRoomserverAPI interface {
QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error QueryMissingEvents(ctx context.Context, req *QueryMissingEventsRequest, res *QueryMissingEventsResponse) error
// Query whether a server is allowed to see an event // Query whether a server is allowed to see an event
QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error) QueryServerAllowedToSeeEvent(ctx context.Context, serverName spec.ServerName, eventID string, roomID string) (allowed bool, err error)
QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error)
QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error)
PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error PerformInboundPeek(ctx context.Context, req *PerformInboundPeekRequest, res *PerformInboundPeekResponse) error
HandleInvite(ctx context.Context, event *types.HeaderedEvent) error HandleInvite(ctx context.Context, event *types.HeaderedEvent) error

View File

@ -132,6 +132,8 @@ type QueryMembershipForUserResponse struct {
// True if the user asked to forget this room. // True if the user asked to forget this room.
IsRoomForgotten bool `json:"is_room_forgotten"` IsRoomForgotten bool `json:"is_room_forgotten"`
RoomExists bool `json:"room_exists"` RoomExists bool `json:"room_exists"`
// The sender ID of the user in the room, if it exists
SenderID *spec.SenderID
} }
// QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom // QueryMembershipsForRoomRequest is a request to QueryMembershipsForRoom
@ -289,16 +291,6 @@ type QuerySharedUsersResponse struct {
UserIDsToCount map[string]int UserIDsToCount map[string]int
} }
type QueryRoomsForUserRequest struct {
UserID string
// The desired membership of the user. If this is the empty string then no rooms are returned.
WantMembership string
}
type QueryRoomsForUserResponse struct {
RoomIDs []string
}
type QueryBulkStateContentRequest struct { type QueryBulkStateContentRequest struct {
// Returns state events in these rooms // Returns state events in these rooms
RoomIDs []string RoomIDs []string
@ -414,22 +406,6 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error {
return nil return nil
} }
// QueryMembershipAtEventRequest requests the membership event for a user
// for a list of eventIDs.
type QueryMembershipAtEventRequest struct {
RoomID string
EventIDs []string
UserID string
}
// QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest.
type QueryMembershipAtEventResponse struct {
// Membership is a map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility.
Membership map[string]*types.HeaderedEvent `json:"membership"`
}
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a // QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a
// a room with anymore. This is used to cleanup stale device list entries, where we would // a room with anymore. This is used to cleanup stale device list entries, where we would
// otherwise keep on trying to get device lists. // otherwise keep on trying to get device lists.

View File

@ -161,12 +161,12 @@ func (r *Admin) PerformAdminEvacuateUser(
return nil, fmt.Errorf("can only evacuate local users using this endpoint") return nil, fmt.Errorf("can only evacuate local users using this endpoint")
} }
roomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Join) roomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Join)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, userID, spec.Invite) inviteRoomIDs, err := r.DB.GetRoomsByMembership(ctx, *fullUserID, spec.Invite)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return nil, err return nil, err
} }

View File

@ -230,6 +230,33 @@ func (r *Queryer) QueryMembershipForSenderID(
senderID spec.SenderID, senderID spec.SenderID,
response *api.QueryMembershipForUserResponse, response *api.QueryMembershipForUserResponse,
) error { ) error {
return r.queryMembershipForOptionalSenderID(ctx, roomID, &senderID, response)
}
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
return r.queryMembershipForOptionalSenderID(ctx, *roomID, senderID, response)
}
// Query membership information for provided sender ID and room ID
//
// If sender ID is nil, then act as if the provided sender is not a member of the room.
func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID spec.RoomID, senderID *spec.SenderID, response *api.QueryMembershipForUserResponse) error {
response.SenderID = senderID
info, err := r.DB.RoomInfo(ctx, roomID.String()) info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil { if err != nil {
return err return err
@ -240,7 +267,11 @@ func (r *Queryer) QueryMembershipForSenderID(
} }
response.RoomExists = true response.RoomExists = true
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID) if senderID == nil {
return nil
}
membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, *senderID)
if err != nil { if err != nil {
return err return err
} }
@ -268,70 +299,55 @@ func (r *Queryer) QueryMembershipForSenderID(
return err return err
} }
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *Queryer) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
return r.QueryMembershipForSenderID(ctx, *roomID, *senderID, response)
}
// QueryMembershipAtEvent returns the known memberships at a given event. // QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned // If the state before an event is not known, an empty list will be returned
// for that event instead. // for that event instead.
//
// Returned map from eventID to membership event. Events that
// do not have known state will return a nil event, resulting in a "leave" membership
// when calculating history visibility.
func (r *Queryer) QueryMembershipAtEvent( func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context, ctx context.Context,
request *api.QueryMembershipAtEventRequest, roomID spec.RoomID,
response *api.QueryMembershipAtEventResponse, eventIDs []string,
) error { senderID spec.SenderID,
response.Membership = make(map[string]*types.HeaderedEvent) ) (map[string]*types.HeaderedEvent, error) {
info, err := r.DB.RoomInfo(ctx, roomID.String())
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil { if err != nil {
return fmt.Errorf("unable to get roomInfo: %w", err) return nil, fmt.Errorf("unable to get roomInfo: %w", err)
} }
if info == nil { if info == nil {
return fmt.Errorf("no roomInfo found") return nil, fmt.Errorf("no roomInfo found")
} }
// get the users stateKeyNID // get the users stateKeyNID
stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID}) stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil { if err != nil {
return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err) return nil, fmt.Errorf("unable to get stateKeyNIDs for %s: %w", senderID, err)
} }
if _, ok := stateKeyNIDs[request.UserID]; !ok { if _, ok := stateKeyNIDs[string(senderID)]; !ok {
return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) return nil, fmt.Errorf("requested stateKeyNID for %s was not found", senderID)
} }
response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...) eventIDMembershipMap, err := r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[string(senderID)], info, eventIDs...)
switch err { switch err {
case nil: case nil:
return nil return eventIDMembershipMap, nil
case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
default: default:
return err return eventIDMembershipMap, err
} }
response.Membership = make(map[string]*types.HeaderedEvent) eventIDMembershipMap = make(map[string]*types.HeaderedEvent)
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r) stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, eventIDs, stateKeyNIDs[string(senderID)], r)
if err != nil { if err != nil {
return fmt.Errorf("unable to get state before event: %w", err) return eventIDMembershipMap, fmt.Errorf("unable to get state before event: %w", err)
} }
// If we only have one or less state entries, we can short circuit the below // If we only have one or less state entries, we can short circuit the below
// loop and avoid hitting the database // loop and avoid hitting the database
allStateEventNIDs := make(map[types.EventNID]types.StateEntry) allStateEventNIDs := make(map[types.EventNID]types.StateEntry)
for _, eventID := range request.EventIDs { for _, eventID := range eventIDs {
stateEntry := stateEntries[eventID] stateEntry := stateEntries[eventID]
for _, s := range stateEntry { for _, s := range stateEntry {
allStateEventNIDs[s.EventNID] = s allStateEventNIDs[s.EventNID] = s
@ -344,10 +360,10 @@ func (r *Queryer) QueryMembershipAtEvent(
} }
var memberships []types.Event var memberships []types.Event
for _, eventID := range request.EventIDs { for _, eventID := range eventIDs {
stateEntry, ok := stateEntries[eventID] stateEntry, ok := stateEntries[eventID]
if !ok || len(stateEntry) == 0 { if !ok || len(stateEntry) == 0 {
response.Membership[eventID] = nil eventIDMembershipMap[eventID] = nil
continue continue
} }
@ -361,7 +377,7 @@ func (r *Queryer) QueryMembershipAtEvent(
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false) memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
} }
if err != nil { if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err) return eventIDMembershipMap, fmt.Errorf("unable to get memberships at state: %w", err)
} }
// Iterate over all membership events we got. Given we only query the membership for // Iterate over all membership events we got. Given we only query the membership for
@ -369,13 +385,13 @@ func (r *Queryer) QueryMembershipAtEvent(
// a given event, overwrite any other existing membership events. // a given event, overwrite any other existing membership events.
for i := range memberships { for i := range memberships {
ev := memberships[i] ev := memberships[i]
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU} eventIDMembershipMap[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
} }
} }
} }
return nil return eventIDMembershipMap, nil
} }
// QueryMembershipsForRoom implements api.RoomserverInternalAPI // QueryMembershipsForRoom implements api.RoomserverInternalAPI
@ -830,13 +846,20 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
return nil return nil
} }
func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { func (r *Queryer) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership) roomIDStrs, err := r.DB.GetRoomsByMembership(ctx, userID, desiredMembership)
if err != nil { if err != nil {
return err return nil, err
} }
res.RoomIDs = roomIDs roomIDs := make([]spec.RoomID, len(roomIDStrs))
return nil for i, roomIDStr := range roomIDStrs {
roomID, err := spec.NewRoomID(roomIDStr)
if err != nil {
return nil, err
}
roomIDs[i] = *roomID
}
return roomIDs, nil
} }
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error { func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
@ -879,7 +902,12 @@ func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersReq
} }
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") parsedUserID, err := spec.NewUserID(req.UserID, true)
if err != nil {
return err
}
roomIDs, err := r.DB.GetRoomsByMembership(ctx, *parsedUserID, "join")
if err != nil { if err != nil {
return err return err
} }

View File

@ -158,7 +158,7 @@ type Database interface {
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error) GetStateEventsWithEventType(ctx context.Context, roomID, evType string) ([]*types.HeaderedEvent, error)
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) GetRoomsByMembership(ctx context.Context, userID spec.UserID, membership string) ([]string, error)
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)

View File

@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
const selectAllUserRoomPublicKeyForUserSQL = `SELECT room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt
selectUserNIDsStmt *sql.Stmt selectUserNIDsStmt *sql.Stmt
selectAllUserRoomPublicKeysForUser *sql.Stmt
} }
func CreateUserRoomKeysTable(db *sql.DB) error { func CreateUserRoomKeysTable(db *sql.DB) error {
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectUserNIDsStmt, selectUserNIDsSQL}, {&s.selectUserNIDsStmt, selectUserNIDsSQL},
{&s.selectAllUserRoomPublicKeysForUser, selectAllUserRoomPublicKeyForUserSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -150,3 +154,24 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectAllUserRoomPublicKeysForUser)
rows, err := stmt.QueryContext(ctx, userNID)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
resultMap := make(map[types.RoomNID]ed25519.PublicKey)
var roomNID types.RoomNID
var pubkey ed25519.PublicKey
for rows.Next() {
if err = rows.Scan(&roomNID, &pubkey); err != nil {
return nil, err
}
resultMap[roomNID] = pubkey
}
return resultMap, err
}

View File

@ -1347,7 +1347,7 @@ func (d *Database) GetStateEventsWithEventType(ctx context.Context, roomID, evTy
} }
// GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key). // GetRoomsByMembership returns a list of room IDs matching the provided membership and user ID (as state_key).
func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership string) ([]string, error) { func (d *Database) GetRoomsByMembership(ctx context.Context, userID spec.UserID, membership string) ([]string, error) {
var membershipState tables.MembershipState var membershipState tables.MembershipState
switch membership { switch membership {
case "join": case "join":
@ -1361,17 +1361,73 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
default: default:
return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership) return nil, fmt.Errorf("GetRoomsByMembership: invalid membership %s", membership)
} }
stateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID)
// Convert provided user ID to NID
userNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, nil, userID.String())
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} else {
return nil, fmt.Errorf("SelectEventStateKeyNID: cannot map user ID to state key NIDs: %w", err)
} }
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
} }
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
// Use this NID to fetch all associated room keys (for pseudo ID rooms)
roomKeyMap, err := d.UserRoomKeyTable.SelectAllPublicKeysForUser(ctx, nil, userNID)
if err != nil {
if err == sql.ErrNoRows {
roomKeyMap = map[types.RoomNID]ed25519.PublicKey{}
} else {
return nil, fmt.Errorf("SelectAllPublicKeysForUser: could not select user room public keys for user: %w", err)
}
}
var eventStateKeyNIDs []types.EventStateKeyNID
// If there are room keys (i.e. this user is in pseudo ID rooms), then gather the appropriate NIDs
if len(roomKeyMap) != 0 {
// Convert keys to string representation
userRoomKeys := make([]string, len(roomKeyMap))
i := 0
for _, key := range roomKeyMap {
userRoomKeys[i] = spec.Base64Bytes(key).Encode()
i += 1
}
// Convert the string representation to its NID
pseudoIDStateKeys, sqlErr := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userRoomKeys)
if sqlErr != nil {
if sqlErr == sql.ErrNoRows {
pseudoIDStateKeys = map[string]types.EventStateKeyNID{}
} else {
return nil, fmt.Errorf("BulkSelectEventStateKeyNID: could not select state keys for public room keys: %w", err)
}
}
// Collect all NIDs together
eventStateKeyNIDs = make([]types.EventStateKeyNID, len(pseudoIDStateKeys)+1)
eventStateKeyNIDs[0] = userNID
i = 1
for _, nid := range pseudoIDStateKeys {
eventStateKeyNIDs[i] = nid
i += 1
}
} else {
// If there are no room keys (so no pseudo ID rooms), we only need to care about the user ID NID.
eventStateKeyNIDs = []types.EventStateKeyNID{userNID}
}
// Fetch rooms that match membership for each NID
roomNIDs := []types.RoomNID{}
for _, nid := range eventStateKeyNIDs {
var roomNIDsChunk []types.RoomNID
roomNIDsChunk, err = d.MembershipTable.SelectRoomsWithMembership(ctx, nil, nid, membershipState)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
} }
roomNIDs = append(roomNIDs, roomNIDsChunk...)
}
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs) roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil { if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err) return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)

View File

@ -56,12 +56,15 @@ const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_use
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)` const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
const selectAllUserRoomPublicKeyForUserSQL = `SELECT room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1`
type userRoomKeysStatements struct { type userRoomKeysStatements struct {
db *sql.DB db *sql.DB
insertUserRoomPrivateKeyStmt *sql.Stmt insertUserRoomPrivateKeyStmt *sql.Stmt
insertUserRoomPublicKeyStmt *sql.Stmt insertUserRoomPublicKeyStmt *sql.Stmt
selectUserRoomKeyStmt *sql.Stmt selectUserRoomKeyStmt *sql.Stmt
selectUserRoomPublicKeyStmt *sql.Stmt selectUserRoomPublicKeyStmt *sql.Stmt
selectAllUserRoomPublicKeysForUser *sql.Stmt
//selectUserNIDsStmt *sql.Stmt //prepared at runtime //selectUserNIDsStmt *sql.Stmt //prepared at runtime
} }
@ -77,6 +80,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL}, {&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL}, {&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL}, {&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
{&s.selectAllUserRoomPublicKeysForUser, selectAllUserRoomPublicKeyForUserSQL},
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime //{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
}.Prepare(db) }.Prepare(db)
} }
@ -165,3 +169,24 @@ func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sq
} }
return result, rows.Err() return result, rows.Err()
} }
func (s *userRoomKeysStatements) SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error) {
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectAllUserRoomPublicKeysForUser)
rows, err := stmt.QueryContext(ctx, userNID)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
resultMap := make(map[types.RoomNID]ed25519.PublicKey)
var roomNID types.RoomNID
var pubkey ed25519.PublicKey
for rows.Next() {
if err = rows.Scan(&roomNID, &pubkey); err != nil {
return nil, err
}
resultMap[roomNID] = pubkey
}
return resultMap, err
}

View File

@ -198,6 +198,8 @@ type UserRoomKeys interface {
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair. // BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
// If a senderKey can't be found, it is omitted in the result. // If a senderKey can't be found, it is omitted in the result.
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
// SelectAllPublicKeysForUser returns all known public keys for a user. Returns a map from room NID -> public key
SelectAllPublicKeysForUser(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID) (map[types.RoomNID]ed25519.PublicKey, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.

View File

@ -16,6 +16,7 @@ package internal
import ( import (
"context" "context"
"fmt"
"math" "math"
"time" "time"
@ -101,13 +102,15 @@ func (ev eventVisibility) allowed() (allowed bool) {
// ApplyHistoryVisibilityFilter applies the room history visibility filter on types.HeaderedEvents. // ApplyHistoryVisibilityFilter applies the room history visibility filter on types.HeaderedEvents.
// Returns the filtered events and an error, if any. // Returns the filtered events and an error, if any.
//
// This function assumes that all provided events are from the same room.
func ApplyHistoryVisibilityFilter( func ApplyHistoryVisibilityFilter(
ctx context.Context, ctx context.Context,
syncDB storage.DatabaseTransaction, syncDB storage.DatabaseTransaction,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*types.HeaderedEvent, events []*types.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{}, alwaysIncludeEventIDs map[string]struct{},
userID, endpoint string, userID spec.UserID, endpoint string,
) ([]*types.HeaderedEvent, error) { ) ([]*types.HeaderedEvent, error) {
if len(events) == 0 { if len(events) == 0 {
return events, nil return events, nil
@ -115,15 +118,29 @@ func ApplyHistoryVisibilityFilter(
start := time.Now() start := time.Now()
// try to get the current membership of the user // try to get the current membership of the user
membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64) membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID.String(), math.MaxInt64)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Get the mapping from eventID -> eventVisibility // Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*types.HeaderedEvent, 0, len(events)) eventsFiltered := make([]*types.HeaderedEvent, 0, len(events))
visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) firstEvRoomID, err := spec.NewRoomID(events[0].RoomID())
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *firstEvRoomID, userID)
if err != nil {
return nil, err
}
visibilities := visibilityForEvents(ctx, rsAPI, events, senderID, *firstEvRoomID)
for _, ev := range events { for _, ev := range events {
// Validate same room assumption
if ev.RoomID() != firstEvRoomID.String() {
return nil, fmt.Errorf("events from different rooms supplied to ApplyHistoryVisibilityFilter")
}
evVis := visibilities[ev.EventID()] evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent evVis.membershipCurrent = membershipCurrent
// Always include specific state events for /sync responses // Always include specific state events for /sync responses
@ -133,23 +150,15 @@ func ApplyHistoryVisibilityFilter(
continue continue
} }
} }
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
user, err := spec.NewUserID(userID, true) // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
if err != nil { if senderID != nil {
return nil, err
}
roomID, err := spec.NewRoomID(ev.RoomID())
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user)
if err == nil && senderID != nil {
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(*senderID)) { if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(*senderID)) {
eventsFiltered = append(eventsFiltered, ev) eventsFiltered = append(eventsFiltered, ev)
continue continue
} }
} }
// Always allow history evVis events on boundaries. This is done // Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive // by setting the effective evVis to the least restrictive
// of the old vs new. // of the old vs new.
@ -178,13 +187,13 @@ func ApplyHistoryVisibilityFilter(
} }
// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership // visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership
// of `userID` at the given event. // of `senderID` at the given event. If provided sender ID is nil, assume that membership is Leave
// Returns an error if the roomserver can't calculate the memberships. // Returns an error if the roomserver can't calculate the memberships.
func visibilityForEvents( func visibilityForEvents(
ctx context.Context, ctx context.Context,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
events []*types.HeaderedEvent, events []*types.HeaderedEvent,
userID, roomID string, senderID *spec.SenderID, roomID spec.RoomID,
) map[string]eventVisibility { ) map[string]eventVisibility {
eventIDs := make([]string, len(events)) eventIDs := make([]string, len(events))
for i := range events { for i := range events {
@ -194,16 +203,14 @@ func visibilityForEvents(
result := make(map[string]eventVisibility, len(eventIDs)) result := make(map[string]eventVisibility, len(eventIDs))
// get the membership events for all eventIDs // get the membership events for all eventIDs
membershipResp := &api.QueryMembershipAtEventResponse{} var err error
membershipEvents := make(map[string]*types.HeaderedEvent)
err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ if senderID != nil {
RoomID: roomID, membershipEvents, err = rsAPI.QueryMembershipAtEvent(ctx, roomID, eventIDs, *senderID)
EventIDs: eventIDs,
UserID: userID,
}, membershipResp)
if err != nil { if err != nil {
logrus.WithError(err).Error("visibilityForEvents: failed to fetch membership at event, defaulting to 'leave'") logrus.WithError(err).Error("visibilityForEvents: failed to fetch membership at event, defaulting to 'leave'")
} }
}
// Create a map from eventID -> eventVisibility // Create a map from eventID -> eventVisibility
for _, event := range events { for _, event := range events {
@ -212,7 +219,7 @@ func visibilityForEvents(
membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility, visibility: event.Visibility,
} }
ev, ok := membershipResp.Membership[eventID] ev, ok := membershipEvents[eventID]
if !ok || ev == nil { if !ok || ev == nil {
result[eventID] = vis result[eventID] = vis
continue continue

View File

@ -69,8 +69,8 @@ func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spe
} }
// QueryRoomsForUser retrieves a list of room IDs matching the given query. // QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error { func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
return nil return nil, nil
} }
// QueryBulkStateContent does a bulk query for state event content in the given rooms. // QueryBulkStateContent does a bulk query for state event content in the given rooms.

View File

@ -138,7 +138,7 @@ func Context(
// verify the user is allowed to see the context for this room/event // verify the user is allowed to see the context for this room/event
startTime := time.Now() startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, *userID, "context")
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
return util.JSONResponse{ return util.JSONResponse{
@ -176,7 +176,7 @@ func Context(
} }
startTime = time.Now() startTime = time.Now()
eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID) eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, *userID)
if err != nil { if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter") logrus.WithError(err).Error("unable to apply history visibility filter")
return util.JSONResponse{ return util.JSONResponse{
@ -257,7 +257,7 @@ func Context(
func applyHistoryVisibilityOnContextEvents( func applyHistoryVisibilityOnContextEvents(
ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI, ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*rstypes.HeaderedEvent, eventsBefore, eventsAfter []*rstypes.HeaderedEvent,
userID string, userID spec.UserID,
) (filteredBefore, filteredAfter []*rstypes.HeaderedEvent, err error) { ) (filteredBefore, filteredAfter []*rstypes.HeaderedEvent, err error) {
eventIDsBefore := make(map[string]struct{}, len(eventsBefore)) eventIDsBefore := make(map[string]struct{}, len(eventsBefore))
eventIDsAfter := make(map[string]struct{}, len(eventsAfter)) eventIDsAfter := make(map[string]struct{}, len(eventsAfter))

View File

@ -37,7 +37,7 @@ import (
func GetEvent( func GetEvent(
req *http.Request, req *http.Request,
device *userapi.Device, device *userapi.Device,
roomID string, rawRoomID string,
eventID string, eventID string,
cfg *config.SyncAPI, cfg *config.SyncAPI,
syncDB storage.Database, syncDB storage.Database,
@ -47,7 +47,7 @@ func GetEvent(
db, err := syncDB.NewDatabaseTransaction(ctx) db, err := syncDB.NewDatabaseTransaction(ctx)
logger := util.GetLogger(ctx).WithFields(logrus.Fields{ logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": eventID, "event_id": eventID,
"room_id": roomID, "room_id": rawRoomID,
}) })
if err != nil { if err != nil {
logger.WithError(err).Error("GetEvent: syncDB.NewDatabaseTransaction failed") logger.WithError(err).Error("GetEvent: syncDB.NewDatabaseTransaction failed")
@ -57,6 +57,14 @@ func GetEvent(
} }
} }
roomID, err := spec.NewRoomID(rawRoomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid room ID"),
}
}
events, err := db.Events(ctx, []string{eventID}) events, err := db.Events(ctx, []string{eventID})
if err != nil { if err != nil {
logger.WithError(err).Error("GetEvent: syncDB.Events failed") logger.WithError(err).Error("GetEvent: syncDB.Events failed")
@ -76,13 +84,22 @@ func GetEvent(
} }
// If the request is coming from an appservice, get the user from the request // If the request is coming from an appservice, get the user from the request
userID := device.UserID rawUserID := device.UserID
if asUserID := req.FormValue("user_id"); device.AppserviceID != "" && asUserID != "" { if asUserID := req.FormValue("user_id"); device.AppserviceID != "" && asUserID != "" {
userID = asUserID rawUserID = asUserID
}
userID, err := spec.NewUserID(rawUserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("invalid device.UserID")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
} }
// Apply history visibility to determine if the user is allowed to view the event // Apply history visibility to determine if the user is allowed to view the event
events, err = internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, events, nil, userID, "event") events, err = internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, events, nil, *userID, "event")
if err != nil { if err != nil {
logger.WithError(err).Error("GetEvent: internal.ApplyHistoryVisibilityFilter failed") logger.WithError(err).Error("GetEvent: internal.ApplyHistoryVisibilityFilter failed")
return util.JSONResponse{ return util.JSONResponse{
@ -101,18 +118,14 @@ func GetEvent(
} }
} }
sender := spec.UserID{} senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, events[0].SenderID())
validRoomID, err := spec.NewRoomID(roomID) if err != nil || senderUserID == nil {
if err != nil { util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("QueryUserIDForSender errored or returned nil-user ID when user should be part of a room")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusInternalServerError,
JSON: spec.BadJSON("roomID is invalid"), JSON: spec.Unknown("internal server error"),
} }
} }
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID())
if err == nil && senderUserID != nil {
sender = *senderUserID
}
sk := events[0].StateKey() sk := events[0].StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
@ -131,6 +144,6 @@ func GetEvent(
} }
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusOK, Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk), JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, *senderUserID, sk),
} }
} }

View File

@ -50,6 +50,7 @@ type messagesReq struct {
from *types.TopologyToken from *types.TopologyToken
to *types.TopologyToken to *types.TopologyToken
device *userapi.Device device *userapi.Device
deviceUserID spec.UserID
wasToProvided bool wasToProvided bool
backwardOrdering bool backwardOrdering bool
filter *synctypes.RoomEventFilter filter *synctypes.RoomEventFilter
@ -77,6 +78,15 @@ func OnIncomingMessagesRequest(
) util.JSONResponse { ) util.JSONResponse {
var err error var err error
deviceUserID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
// NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we // NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we
// expect to be able to write to the database in response to a /messages // expect to be able to write to the database in response to a /messages
// request that requires backfilling from the roomserver or federation. // request that requires backfilling from the roomserver or federation.
@ -240,6 +250,7 @@ func OnIncomingMessagesRequest(
filter: filter, filter: filter,
backwardOrdering: backwardOrdering, backwardOrdering: backwardOrdering,
device: device, device: device,
deviceUserID: *deviceUserID,
} }
clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI) clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI)
@ -359,7 +370,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
// Apply room history visibility filter // Apply room history visibility filter
startTime := time.Now() startTime := time.Now()
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.deviceUserID, "messages")
if err != nil { if err != nil {
return []synctypes.ClientEvent{}, *r.from, *r.to, nil return []synctypes.ClientEvent{}, *r.from, *r.to, nil
} }

View File

@ -43,9 +43,25 @@ func Relations(
req *http.Request, device *userapi.Device, req *http.Request, device *userapi.Device,
syncDB storage.Database, syncDB storage.Database,
rsAPI api.SyncRoomserverAPI, rsAPI api.SyncRoomserverAPI,
roomID, eventID, relType, eventType string, rawRoomID, eventID, relType, eventType string,
) util.JSONResponse { ) util.JSONResponse {
var err error roomID, err := spec.NewRoomID(rawRoomID)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("invalid room ID"),
}
}
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.Unknown("internal server error"),
}
}
var from, to types.StreamPosition var from, to types.StreamPosition
var limit int var limit int
dir := req.URL.Query().Get("dir") dir := req.URL.Query().Get("dir")
@ -93,7 +109,7 @@ func Relations(
} }
var events []types.StreamEvent var events []types.StreamEvent
events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor( events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor(
req.Context(), roomID, eventID, relType, eventType, from, to, dir == "b", limit, req.Context(), roomID.String(), eventID, relType, eventType, from, to, dir == "b", limit,
) )
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
@ -105,12 +121,7 @@ func Relations(
} }
// Apply history visibility to the result events. // Apply history visibility to the result events.
filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, device.UserID, "relations") filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, *userID, "relations")
if err != nil {
return util.ErrorResponse(err)
}
validRoomID, err := spec.NewRoomID(roomID)
if err != nil { if err != nil {
return util.ErrorResponse(err) return util.ErrorResponse(err)
} }
@ -120,14 +131,14 @@ func Relations(
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents)) res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents { for _, event := range filteredEvents {
sender := spec.UserID{} sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID()) userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID())
if err == nil && userID != nil { if err == nil && userID != nil {
sender = *userID sender = *userID
} }
sk := event.StateKey() sk := event.StateKey()
if sk != nil && *sk != "" { if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey())) skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil { if err == nil && skUserID != nil {
skString := skUserID.String() skString := skUserID.String()
sk = &skString sk = &skString

View File

@ -562,8 +562,13 @@ func applyHistoryVisibilityFilter(
} }
} }
parsedUserID, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
startTime := time.Now() startTime := time.Now()
events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, *parsedUserID, "sync")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -44,6 +44,11 @@ func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spe
return spec.NewUserID(string(senderID), true) return spec.NewUserID(string(senderID), true)
} }
func (s *syncRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
senderID := spec.SenderID(userID.String())
return &senderID, nil
}
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {
var room *test.Room var room *test.Room
for _, r := range s.rooms { for _, r := range s.rooms {
@ -74,8 +79,13 @@ func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsa
return nil return nil
} }
func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error { func (s *syncRoomserverAPI) QueryMembershipAtEvent(
return nil ctx context.Context,
roomID spec.RoomID,
eventIDs []string,
senderID spec.SenderID,
) (map[string]*rstypes.HeaderedEvent, error) {
return map[string]*rstypes.HeaderedEvent{}, nil
} }
type syncUserAPI struct { type syncUserAPI struct {