mirror of
https://github.com/1f349/dendrite.git
synced 2024-11-25 05:01:41 +00:00
Merge SenderID & Per Room User Key work (#3109)
This commit is contained in:
parent
7a2e325d10
commit
e4665979bf
@ -181,7 +181,7 @@ func (s *OutputRoomEventConsumer) sendEvents(
|
|||||||
// Create the transaction body.
|
// Create the transaction body.
|
||||||
transaction, err := json.Marshal(
|
transaction, err := json.Marshal(
|
||||||
ApplicationServiceTransaction{
|
ApplicationServiceTransaction{
|
||||||
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
@ -236,7 +236,11 @@ func (s *appserviceState) backoffAndPause(err error) error {
|
|||||||
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
|
// TODO: This should be cached, see https://github.com/matrix-org/dendrite/issues/1682
|
||||||
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
|
func (s *OutputRoomEventConsumer) appserviceIsInterestedInEvent(ctx context.Context, event *types.HeaderedEvent, appservice *config.ApplicationService) bool {
|
||||||
user := ""
|
user := ""
|
||||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
user = userID.String()
|
user = userID.String()
|
||||||
}
|
}
|
||||||
|
@ -233,11 +233,18 @@ func RemoveLocalAlias(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomIDRes.RoomID, *userID)
|
validRoomID, err := spec.NewRoomID(roomIDRes.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusInternalServerError,
|
Code: http.StatusNotFound,
|
||||||
JSON: spec.InternalServerError{Err: "Could not find SenderID for this device"},
|
JSON: spec.NotFound("The alias does not exist."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
deviceSenderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *userID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusNotFound,
|
||||||
|
JSON: spec.NotFound("The alias does not exist."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -321,7 +328,15 @@ func SetVisibility(
|
|||||||
JSON: spec.BadJSON("userID for this device is invalid"),
|
JSON: spec.BadJSON("userID for this device is invalid"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("RoomID is invalid"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
|
@ -64,7 +64,14 @@ func SendBan(
|
|||||||
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
|
JSON: spec.Forbidden("You don't have permission to ban this user, bad userID"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("RoomID is invalid"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusForbidden,
|
Code: http.StatusForbidden,
|
||||||
@ -155,7 +162,14 @@ func SendKick(
|
|||||||
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
|
JSON: spec.Forbidden("You don't have permission to kick this user, bad userID"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("RoomID is invalid"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
senderID, err := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusForbidden,
|
Code: http.StatusForbidden,
|
||||||
@ -428,7 +442,11 @@ func buildMembershipEvent(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -437,7 +455,7 @@ func buildMembershipEvent(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *targetID)
|
targetSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *targetID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -368,7 +368,11 @@ func buildMembershipEvents(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
for _, roomID := range roomIDs {
|
for _, roomID := range roomIDs {
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,14 @@ func SendRedaction(
|
|||||||
JSON: spec.Forbidden("userID doesn't have power level to redact"),
|
JSON: spec.Forbidden("userID doesn't have power level to redact"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), roomID, *deviceUserID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("RoomID is invalid"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
senderID, queryErr := rsAPI.QuerySenderIDForUser(req.Context(), *validRoomID, *deviceUserID)
|
||||||
if queryErr != nil {
|
if queryErr != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusForbidden,
|
Code: http.StatusForbidden,
|
||||||
@ -103,8 +110,8 @@ func SendRedaction(
|
|||||||
JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."),
|
JSON: spec.Forbidden("You don't have permission to redact this event, no power_levels event in this room."),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pl, err := plEvent.PowerLevels()
|
pl, plErr := plEvent.PowerLevels()
|
||||||
if err != nil {
|
if plErr != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: 403,
|
Code: 403,
|
||||||
JSON: spec.Forbidden(
|
JSON: spec.Forbidden(
|
||||||
@ -134,7 +141,7 @@ func SendRedaction(
|
|||||||
Type: spec.MRoomRedaction,
|
Type: spec.MRoomRedaction,
|
||||||
Redacts: eventID,
|
Redacts: eventID,
|
||||||
}
|
}
|
||||||
err := proto.SetContent(r)
|
err = proto.SetContent(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed")
|
util.GetLogger(req.Context()).WithError(err).Error("proto.SetContent failed")
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
|
@ -273,7 +273,14 @@ func generateSendEvent(
|
|||||||
JSON: spec.BadJSON("Bad userID"),
|
JSON: spec.BadJSON("Bad userID"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *fullUserID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("RoomID is invalid"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusNotFound,
|
Code: http.StatusNotFound,
|
||||||
@ -344,8 +351,8 @@ func generateSendEvent(
|
|||||||
stateEvents[i] = queryRes.StateEvents[i].PDU
|
stateEvents[i] = queryRes.StateEvents[i].PDU
|
||||||
}
|
}
|
||||||
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
|
provider := gomatrixserverlib.NewAuthEvents(gomatrixserverlib.ToPDUs(stateEvents))
|
||||||
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.Allowed(e.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, *validRoomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return nil, &util.JSONResponse{
|
return nil, &util.JSONResponse{
|
||||||
Code: http.StatusForbidden,
|
Code: http.StatusForbidden,
|
||||||
|
@ -150,7 +150,7 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
|||||||
for _, ev := range stateRes.StateEvents {
|
for _, ev := range stateRes.StateEvents {
|
||||||
stateEvents = append(
|
stateEvents = append(
|
||||||
stateEvents,
|
stateEvents,
|
||||||
synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}, ev),
|
}, ev),
|
||||||
)
|
)
|
||||||
@ -173,14 +173,19 @@ func OnIncomingStateRequest(ctx context.Context, device *userapi.Device, rsAPI a
|
|||||||
}
|
}
|
||||||
for _, ev := range stateAfterRes.StateEvents {
|
for _, ev := range stateAfterRes.StateEvents {
|
||||||
sender := spec.UserID{}
|
sender := spec.UserID{}
|
||||||
userID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), ev.SenderID())
|
evRoomID, err := spec.NewRoomID(ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Error("Event roomID is invalid")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, ev.SenderID())
|
||||||
if err == nil && userID != nil {
|
if err == nil && userID != nil {
|
||||||
sender = *userID
|
sender = *userID
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := ev.StateKey()
|
sk := ev.StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, err := rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
|
skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*ev.StateKey()))
|
||||||
if err == nil && skUserID != nil {
|
if err == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
@ -367,7 +372,7 @@ func OnIncomingStateTypeRequest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
stateEvent := stateEventInStateResp{
|
stateEvent := stateEventInStateResp{
|
||||||
ClientEvent: synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
ClientEvent: synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}, event),
|
}, event),
|
||||||
}
|
}
|
||||||
|
@ -359,7 +359,11 @@ func emit3PIDInviteEvent(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sender, err := rsAPI.QuerySenderIDForUser(ctx, roomID, *userID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sender, err := rsAPI.QuerySenderIDForUser(ctx, *validRoomID, *userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -11,11 +11,13 @@ import (
|
|||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/dendrite/setup"
|
"github.com/matrix-org/dendrite/setup"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
@ -66,10 +68,14 @@ func main() {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
natsInstance := &jetstream.NATSInstance{}
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm,
|
||||||
|
natsInstance, caching.NewRistrettoCache(128*1024*1024, time.Hour, true), false)
|
||||||
|
|
||||||
roomInfo := &types.RoomInfo{
|
roomInfo := &types.RoomInfo{
|
||||||
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
|
RoomVersion: gomatrixserverlib.RoomVersion(*roomVersion),
|
||||||
}
|
}
|
||||||
stateres := state.NewStateResolution(roomserverDB, roomInfo)
|
stateres := state.NewStateResolution(roomserverDB, roomInfo, rsAPI)
|
||||||
|
|
||||||
if *difference {
|
if *difference {
|
||||||
if len(snapshotNIDs) != 2 {
|
if len(snapshotNIDs) != 2 {
|
||||||
@ -183,8 +189,8 @@ func main() {
|
|||||||
fmt.Println("Resolving state")
|
fmt.Println("Resolving state")
|
||||||
var resolved Events
|
var resolved Events
|
||||||
resolved, err = gomatrixserverlib.ResolveConflicts(
|
resolved, err = gomatrixserverlib.ResolveConflicts(
|
||||||
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
gomatrixserverlib.RoomVersion(*roomVersion), events, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return roomserverDB.GetUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -36,11 +36,11 @@ type fedRoomserverAPI struct {
|
|||||||
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
|
queryRoomsForUser func(ctx context.Context, req *rsapi.QueryRoomsForUserRequest, res *rsapi.QueryRoomsForUserResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (f *fedRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
|
func (f *fedRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
return spec.SenderID(userID.String()), nil
|
return spec.SenderID(userID.String()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -154,14 +154,9 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, roomID, *user)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
joinInput := gomatrixserverlib.PerformJoinInput{
|
joinInput := gomatrixserverlib.PerformJoinInput{
|
||||||
UserID: user,
|
UserID: user,
|
||||||
SenderID: senderID,
|
|
||||||
RoomID: room,
|
RoomID: room,
|
||||||
ServerName: serverName,
|
ServerName: serverName,
|
||||||
Content: content,
|
Content: content,
|
||||||
@ -169,12 +164,20 @@ func (r *FederationInternalAPI) performJoinUsingServer(
|
|||||||
PrivateKey: r.cfg.Matrix.PrivateKey,
|
PrivateKey: r.cfg.Matrix.PrivateKey,
|
||||||
KeyID: r.cfg.Matrix.KeyID,
|
KeyID: r.cfg.Matrix.KeyID,
|
||||||
KeyRing: r.keyRing,
|
KeyRing: r.keyRing,
|
||||||
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
EventProvider: federatedEventProvider(ctx, r.federation, r.keyRing, user.Domain(), serverName, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}),
|
}),
|
||||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
|
SenderIDCreator: func(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (spec.SenderID, error) {
|
||||||
|
key, keyErr := r.rsAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
|
||||||
|
if keyErr != nil {
|
||||||
|
return "", keyErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return spec.SenderID(spec.Base64Bytes(key).Encode()), nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
|
response, joinErr := gomatrixserverlib.PerformJoin(ctx, r, joinInput)
|
||||||
|
|
||||||
@ -368,7 +371,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
|
|||||||
|
|
||||||
// authenticate the state returned (check its auth events etc)
|
// authenticate the state returned (check its auth events etc)
|
||||||
// the equivalent of CheckSendJoinResponse()
|
// the equivalent of CheckSendJoinResponse()
|
||||||
userIDProvider := func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
userIDProvider := func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}
|
}
|
||||||
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
|
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(
|
||||||
@ -459,7 +462,11 @@ func (r *FederationInternalAPI) PerformLeave(
|
|||||||
|
|
||||||
// Set all the fields to be what they should be, this should be a no-op
|
// Set all the fields to be what they should be, this should be a no-op
|
||||||
// but it's possible that the remote server returned us something "odd"
|
// but it's possible that the remote server returned us something "odd"
|
||||||
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, request.RoomID, *userID)
|
roomID, err := spec.NewRoomID(request.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
senderID, err := r.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -527,7 +534,11 @@ func (r *FederationInternalAPI) SendInvite(
|
|||||||
event gomatrixserverlib.PDU,
|
event gomatrixserverlib.PDU,
|
||||||
strippedState []gomatrixserverlib.InviteStrippedState,
|
strippedState []gomatrixserverlib.InviteStrippedState,
|
||||||
) (gomatrixserverlib.PDU, error) {
|
) (gomatrixserverlib.PDU, error) {
|
||||||
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inviter, err := r.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -95,7 +95,7 @@ func InviteV2(
|
|||||||
StateQuerier: rsAPI.StateQuerier(),
|
StateQuerier: rsAPI.StateQuerier(),
|
||||||
InviteEvent: inviteReq.Event(),
|
InviteEvent: inviteReq.Event(),
|
||||||
StrippedState: inviteReq.InviteRoomState(),
|
StrippedState: inviteReq.InviteRoomState(),
|
||||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -188,7 +188,7 @@ func InviteV1(
|
|||||||
StateQuerier: rsAPI.StateQuerier(),
|
StateQuerier: rsAPI.StateQuerier(),
|
||||||
InviteEvent: event,
|
InviteEvent: event,
|
||||||
StrippedState: strippedState,
|
StrippedState: strippedState,
|
||||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -98,7 +98,7 @@ func MakeJoin(
|
|||||||
Roomserver: rsAPI,
|
Roomserver: rsAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID)
|
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
|
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
@ -118,7 +118,7 @@ func MakeJoin(
|
|||||||
LocalServerName: cfg.Matrix.ServerName,
|
LocalServerName: cfg.Matrix.ServerName,
|
||||||
LocalServerInRoom: res.RoomExists && res.IsInRoom,
|
LocalServerInRoom: res.RoomExists && res.IsInRoom,
|
||||||
RoomQuerier: &roomQuerier,
|
RoomQuerier: &roomQuerier,
|
||||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
||||||
},
|
},
|
||||||
BuildEventTemplate: createJoinTemplate,
|
BuildEventTemplate: createJoinTemplate,
|
||||||
@ -215,7 +215,7 @@ func SendJoin(
|
|||||||
PrivateKey: cfg.Matrix.PrivateKey,
|
PrivateKey: cfg.Matrix.PrivateKey,
|
||||||
Verifier: keys,
|
Verifier: keys,
|
||||||
MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI},
|
MembershipQuerier: &api.MembershipQuerier{Roomserver: rsAPI},
|
||||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,7 @@ func MakeLeave(
|
|||||||
return event, stateEvents, nil
|
return event, stateEvents, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID.String(), userID)
|
senderID, err := rsAPI.QuerySenderIDForUser(httpReq.Context(), roomID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
|
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QuerySenderIDForUser failed")
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
@ -105,7 +105,7 @@ func MakeLeave(
|
|||||||
LocalServerName: cfg.Matrix.ServerName,
|
LocalServerName: cfg.Matrix.ServerName,
|
||||||
LocalServerInRoom: res.RoomExists && res.IsInRoom,
|
LocalServerInRoom: res.RoomExists && res.IsInRoom,
|
||||||
BuildEventTemplate: createLeaveTemplate,
|
BuildEventTemplate: createLeaveTemplate,
|
||||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, senderID)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -236,7 +236,14 @@ func SendLeave(
|
|||||||
// Check that the sender belongs to the server that is sending us
|
// Check that the sender belongs to the server that is sending us
|
||||||
// the request. By this point we've already asserted that the sender
|
// the request. By this point we've already asserted that the sender
|
||||||
// and the state key are equal so we don't need to check both.
|
// and the state key are equal so we don't need to check both.
|
||||||
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("Room ID is invalid."),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sender, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusForbidden,
|
Code: http.StatusForbidden,
|
||||||
|
@ -140,7 +140,14 @@ func ExchangeThirdPartyInvite(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(proto.SenderID))
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("Invalid room ID"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(proto.SenderID))
|
||||||
if err != nil || userID == nil {
|
if err != nil || userID == nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
@ -150,7 +157,7 @@ func ExchangeThirdPartyInvite(
|
|||||||
senderDomain := userID.Domain()
|
senderDomain := userID.Domain()
|
||||||
|
|
||||||
// Check that the state key is correct.
|
// Check that the state key is correct.
|
||||||
targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), roomID, spec.SenderID(*proto.StateKey))
|
targetUserID, err := rsAPI.QueryUserIDForSender(httpReq.Context(), *validRoomID, spec.SenderID(*proto.StateKey))
|
||||||
if err != nil || targetUserID == nil {
|
if err != nil || targetUserID == nil {
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusBadRequest,
|
Code: http.StatusBadRequest,
|
||||||
|
10
go.mod
10
go.mod
@ -22,7 +22,7 @@ require (
|
|||||||
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
|
||||||
github.com/mattn/go-sqlite3 v1.14.16
|
github.com/mattn/go-sqlite3 v1.14.16
|
||||||
@ -42,11 +42,11 @@ require (
|
|||||||
github.com/uber/jaeger-lib v2.4.1+incompatible
|
github.com/uber/jaeger-lib v2.4.1+incompatible
|
||||||
github.com/yggdrasil-network/yggdrasil-go v0.4.6
|
github.com/yggdrasil-network/yggdrasil-go v0.4.6
|
||||||
go.uber.org/atomic v1.10.0
|
go.uber.org/atomic v1.10.0
|
||||||
golang.org/x/crypto v0.9.0
|
golang.org/x/crypto v0.10.0
|
||||||
golang.org/x/image v0.5.0
|
golang.org/x/image v0.5.0
|
||||||
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
|
golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e
|
||||||
golang.org/x/sync v0.1.0
|
golang.org/x/sync v0.1.0
|
||||||
golang.org/x/term v0.8.0
|
golang.org/x/term v0.9.0
|
||||||
gopkg.in/h2non/bimg.v1 v1.1.9
|
gopkg.in/h2non/bimg.v1 v1.1.9
|
||||||
gopkg.in/yaml.v2 v2.4.0
|
gopkg.in/yaml.v2 v2.4.0
|
||||||
gotest.tools/v3 v3.4.0
|
gotest.tools/v3 v3.4.0
|
||||||
@ -127,8 +127,8 @@ require (
|
|||||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
|
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
|
||||||
golang.org/x/mod v0.8.0 // indirect
|
golang.org/x/mod v0.8.0 // indirect
|
||||||
golang.org/x/net v0.10.0 // indirect
|
golang.org/x/net v0.10.0 // indirect
|
||||||
golang.org/x/sys v0.8.0 // indirect
|
golang.org/x/sys v0.9.0 // indirect
|
||||||
golang.org/x/text v0.9.0 // indirect
|
golang.org/x/text v0.10.0 // indirect
|
||||||
golang.org/x/time v0.3.0 // indirect
|
golang.org/x/time v0.3.0 // indirect
|
||||||
golang.org/x/tools v0.6.0 // indirect
|
golang.org/x/tools v0.6.0 // indirect
|
||||||
google.golang.org/protobuf v1.28.1 // indirect
|
google.golang.org/protobuf v1.28.1 // indirect
|
||||||
|
20
go.sum
20
go.sum
@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
|
|||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077 h1:AmKkAUjy9rZA2K+qHXm/O/dPEPnUYfRE2I6SL+Dj+LU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1 h1:k75Fy0iQVbDjvddip/x898+BdyopBNAfL1BMNx0awA0=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20230612110349-8e7766804077/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20230614140620-4dea2171c8f1/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
|
||||||
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
|
||||||
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
|
||||||
@ -511,8 +511,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
|||||||
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
|
||||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
|
||||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
@ -669,12 +669,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/term v0.8.0 h1:n5xxQn2i3PC0yLAbjTpNT85q/Kgzcr2gIoX9OrJUols=
|
golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28=
|
||||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo=
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
@ -683,8 +683,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
|||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
|
@ -115,7 +115,11 @@ func ruleMatches(rule *Rule, kind Kind, event gomatrixserverlib.PDU, ec Evaluati
|
|||||||
|
|
||||||
case SenderKind:
|
case SenderKind:
|
||||||
userID := ""
|
userID := ""
|
||||||
sender, err := userIDForSender(event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
sender, err := userIDForSender(*validRoomID, event.SenderID())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
userID = sender.String()
|
userID = sender.String()
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -73,7 +73,7 @@ func TestRuleMatches(t *testing.T) {
|
|||||||
{"emptyOverride", OverrideKind, emptyRule, `{}`, true},
|
{"emptyOverride", OverrideKind, emptyRule, `{}`, true},
|
||||||
{"emptyContent", ContentKind, emptyRule, `{}`, false},
|
{"emptyContent", ContentKind, emptyRule, `{}`, false},
|
||||||
{"emptyRoom", RoomKind, emptyRule, `{}`, true},
|
{"emptyRoom", RoomKind, emptyRule, `{}`, true},
|
||||||
{"emptySender", SenderKind, emptyRule, `{}`, true},
|
{"emptySender", SenderKind, emptyRule, `{"room_id":"!room:example.com"}`, true},
|
||||||
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
|
{"emptyUnderride", UnderrideKind, emptyRule, `{}`, true},
|
||||||
|
|
||||||
{"disabled", OverrideKind, Rule{}, `{}`, false},
|
{"disabled", OverrideKind, Rule{}, `{}`, false},
|
||||||
@ -90,8 +90,8 @@ func TestRuleMatches(t *testing.T) {
|
|||||||
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true},
|
{"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!room:example.com"}`, true},
|
||||||
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false},
|
{"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room:example.com"}, `{"room_id":"!otherroom:example.com"}`, false},
|
||||||
|
|
||||||
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com"}`, true},
|
{"senderMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@user:example.com","room_id":"!room:example.com"}`, true},
|
||||||
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com"}`, false},
|
{"senderNoMatch", SenderKind, Rule{Enabled: true, RuleID: "@user:example.com"}, `{"sender":"@otheruser:example.com","room_id":"!room:example.com"}`, false},
|
||||||
}
|
}
|
||||||
for _, tst := range tsts {
|
for _, tst := range tsts {
|
||||||
t.Run(tst.Name, func(t *testing.T) {
|
t.Run(tst.Name, func(t *testing.T) {
|
||||||
|
@ -167,7 +167,7 @@ func (t *TxnReq) ProcessTransaction(ctx context.Context) (*fclient.RespSend, *ut
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return t.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
|
||||||
|
@ -70,7 +70,7 @@ type FakeRsAPI struct {
|
|||||||
bannedFromRoom bool
|
bannedFromRoom bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (r *FakeRsAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -642,7 +642,7 @@ type testRoomserverAPI struct {
|
|||||||
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
|
queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (t *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,6 +51,7 @@ type RoomserverInternalAPI interface {
|
|||||||
UserRoomserverAPI
|
UserRoomserverAPI
|
||||||
FederationRoomserverAPI
|
FederationRoomserverAPI
|
||||||
QuerySenderIDAPI
|
QuerySenderIDAPI
|
||||||
|
UserRoomPrivateKeyCreator
|
||||||
|
|
||||||
// needed to avoid chicken and egg scenario when setting up the
|
// needed to avoid chicken and egg scenario when setting up the
|
||||||
// interdependencies between the roomserver and other input APIs
|
// interdependencies between the roomserver and other input APIs
|
||||||
@ -67,7 +68,9 @@ type RoomserverInternalAPI interface {
|
|||||||
req *QueryAuthChainRequest,
|
req *QueryAuthChainRequest,
|
||||||
res *QueryAuthChainResponse,
|
res *QueryAuthChainResponse,
|
||||||
) error
|
) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserRoomPrivateKeyCreator interface {
|
||||||
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
|
// GetOrCreateUserRoomPrivateKey gets the user room key for the specified user. If no key exists yet, a new one is created.
|
||||||
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
|
GetOrCreateUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (ed25519.PrivateKey, error)
|
||||||
}
|
}
|
||||||
@ -81,8 +84,8 @@ type InputRoomEventsAPI interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type QuerySenderIDAPI interface {
|
type QuerySenderIDAPI interface {
|
||||||
QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
|
QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error)
|
||||||
QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
|
QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query the latest events and state for a room from the room server.
|
// Query the latest events and state for a room from the room server.
|
||||||
@ -228,6 +231,7 @@ type FederationRoomserverAPI interface {
|
|||||||
QueryLatestEventsAndStateAPI
|
QueryLatestEventsAndStateAPI
|
||||||
QueryBulkStateContentAPI
|
QueryBulkStateContentAPI
|
||||||
QuerySenderIDAPI
|
QuerySenderIDAPI
|
||||||
|
UserRoomPrivateKeyCreator
|
||||||
|
|
||||||
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
|
||||||
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
|
||||||
|
@ -15,7 +15,7 @@ package auth
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
)
|
)
|
||||||
@ -25,7 +25,7 @@ import (
|
|||||||
// IsServerAllowed returns true if the server is allowed to see events in the room
|
// IsServerAllowed returns true if the server is allowed to see events in the room
|
||||||
// at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87
|
// at this particular state. This function implements https://matrix.org/docs/spec/client_server/r0.6.0#id87
|
||||||
func IsServerAllowed(
|
func IsServerAllowed(
|
||||||
ctx context.Context, db storage.RoomDatabase,
|
ctx context.Context, querier api.QuerySenderIDAPI,
|
||||||
serverName spec.ServerName,
|
serverName spec.ServerName,
|
||||||
serverCurrentlyInRoom bool,
|
serverCurrentlyInRoom bool,
|
||||||
authEvents []gomatrixserverlib.PDU,
|
authEvents []gomatrixserverlib.PDU,
|
||||||
@ -41,7 +41,7 @@ func IsServerAllowed(
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// 2. If the user's membership was join, allow.
|
// 2. If the user's membership was join, allow.
|
||||||
joinedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Join)
|
joinedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Join)
|
||||||
if joinedUserExists {
|
if joinedUserExists {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -50,7 +50,7 @@ func IsServerAllowed(
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// 4. If the user's membership was invite, and the history_visibility was set to invited, allow.
|
// 4. If the user's membership was invite, and the history_visibility was set to invited, allow.
|
||||||
invitedUserExists := IsAnyUserOnServerWithMembership(ctx, db, serverName, authEvents, spec.Invite)
|
invitedUserExists := IsAnyUserOnServerWithMembership(ctx, querier, serverName, authEvents, spec.Invite)
|
||||||
if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited {
|
if invitedUserExists && historyVisibility == gomatrixserverlib.HistoryVisibilityInvited {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@ -74,7 +74,7 @@ func HistoryVisibilityForRoom(authEvents []gomatrixserverlib.PDU) gomatrixserver
|
|||||||
return visibility
|
return visibility
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabase, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
|
func IsAnyUserOnServerWithMembership(ctx context.Context, querier api.QuerySenderIDAPI, serverName spec.ServerName, authEvents []gomatrixserverlib.PDU, wantMembership string) bool {
|
||||||
for _, ev := range authEvents {
|
for _, ev := range authEvents {
|
||||||
if ev.Type() != spec.MRoomMember {
|
if ev.Type() != spec.MRoomMember {
|
||||||
continue
|
continue
|
||||||
@ -89,7 +89,11 @@ func IsAnyUserOnServerWithMembership(ctx context.Context, db storage.RoomDatabas
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := db.GetUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*stateKey))
|
validRoomID, err := spec.NewRoomID(ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*stateKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -4,17 +4,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FakeStorageDB struct {
|
type FakeQuerier struct {
|
||||||
storage.RoomDatabase
|
api.QuerySenderIDAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *FakeStorageDB) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (f *FakeQuerier) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ func TestIsServerAllowed(t *testing.T) {
|
|||||||
authEvents = append(authEvents, ev.PDU)
|
authEvents = append(authEvents, ev.PDU)
|
||||||
}
|
}
|
||||||
|
|
||||||
if got := IsServerAllowed(context.Background(), &FakeStorageDB{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
|
if got := IsServerAllowed(context.Background(), &FakeQuerier{}, tt.serverName, tt.serverCurrentlyInRoom, authEvents); got != tt.want {
|
||||||
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
|
t.Errorf("IsServerAllowed() = %v, want %v", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -113,6 +113,7 @@ func (r *RoomserverInternalAPI) GetAliasesForRoomID(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:gocyclo
|
||||||
// RemoveRoomAlias implements alias.RoomserverInternalAPI
|
// RemoveRoomAlias implements alias.RoomserverInternalAPI
|
||||||
func (r *RoomserverInternalAPI) RemoveRoomAlias(
|
func (r *RoomserverInternalAPI) RemoveRoomAlias(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@ -129,7 +130,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sender, err := r.QueryUserIDForSender(ctx, *validRoomID, request.SenderID)
|
||||||
if err != nil || sender == nil {
|
if err != nil || sender == nil {
|
||||||
return fmt.Errorf("r.QueryUserIDForSender: %w", err)
|
return fmt.Errorf("r.QueryUserIDForSender: %w", err)
|
||||||
}
|
}
|
||||||
@ -177,7 +183,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
|
|||||||
if request.SenderID != ev.SenderID() {
|
if request.SenderID != ev.SenderID() {
|
||||||
senderID = ev.SenderID()
|
senderID = ev.SenderID()
|
||||||
}
|
}
|
||||||
sender, err := r.QueryUserIDForSender(ctx, roomID, senderID)
|
sender, err := r.QueryUserIDForSender(ctx, *validRoomID, senderID)
|
||||||
if err != nil || sender == nil {
|
if err != nil || sender == nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -206,7 +212,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
|
|||||||
}
|
}
|
||||||
|
|
||||||
stateRes := &api.QueryLatestEventsAndStateResponse{}
|
stateRes := &api.QueryLatestEventsAndStateResponse{}
|
||||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
|
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r, &api.QueryLatestEventsAndStateRequest{RoomID: roomID, StateToFetch: eventsNeeded.Tuples()}, stateRes); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -177,6 +177,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio
|
|||||||
IsLocalServerName: r.Cfg.Global.IsLocalServerName,
|
IsLocalServerName: r.Cfg.Global.IsLocalServerName,
|
||||||
DB: r.DB,
|
DB: r.DB,
|
||||||
FSAPI: r.fsAPI,
|
FSAPI: r.fsAPI,
|
||||||
|
Querier: r.Queryer,
|
||||||
KeyRing: r.KeyRing,
|
KeyRing: r.KeyRing,
|
||||||
// Perspective servers are trusted to not lie about server keys, so we will also
|
// Perspective servers are trusted to not lie about server keys, so we will also
|
||||||
// prefer these servers when backfilling (assuming they are in the room) rather
|
// prefer these servers when backfilling (assuming they are in the room) rather
|
||||||
|
@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/state"
|
"github.com/matrix-org/dendrite/roomserver/state"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
@ -36,6 +37,7 @@ func CheckForSoftFail(
|
|||||||
roomInfo *types.RoomInfo,
|
roomInfo *types.RoomInfo,
|
||||||
event *types.HeaderedEvent,
|
event *types.HeaderedEvent,
|
||||||
stateEventIDs []string,
|
stateEventIDs []string,
|
||||||
|
querier api.QuerySenderIDAPI,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
rewritesState := len(stateEventIDs) > 1
|
rewritesState := len(stateEventIDs) > 1
|
||||||
|
|
||||||
@ -49,7 +51,7 @@ func CheckForSoftFail(
|
|||||||
} else {
|
} else {
|
||||||
// Then get the state entries for the current state snapshot.
|
// Then get the state entries for the current state snapshot.
|
||||||
// We'll use this to check if the event is allowed right now.
|
// We'll use this to check if the event is allowed right now.
|
||||||
roomState := state.NewStateResolution(db, roomInfo)
|
roomState := state.NewStateResolution(db, roomInfo, querier)
|
||||||
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
|
authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
|
return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err)
|
||||||
@ -76,8 +78,8 @@ func CheckForSoftFail(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the event is allowed.
|
// Check if the event is allowed.
|
||||||
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return db.GetUserIDForSender(ctx, roomID, senderID)
|
return querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
// return true, nil
|
// return true, nil
|
||||||
return true, err
|
return true, err
|
||||||
|
@ -68,7 +68,7 @@ func UpdateToInviteMembership(
|
|||||||
// memberships. If the servername is not supplied then the local server will be
|
// memberships. If the servername is not supplied then the local server will be
|
||||||
// checked instead using a faster code path.
|
// checked instead using a faster code path.
|
||||||
// TODO: This should probably be replaced by an API call.
|
// TODO: This should probably be replaced by an API call.
|
||||||
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName spec.ServerName, roomID string) (bool, error) {
|
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, serverName spec.ServerName, roomID string) (bool, error) {
|
||||||
info, err := db.RoomInfo(ctx, roomID)
|
info, err := db.RoomInfo(ctx, roomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
@ -94,7 +94,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
|
|||||||
for i := range events {
|
for i := range events {
|
||||||
gmslEvents[i] = events[i].PDU
|
gmslEvents[i] = events[i].PDU
|
||||||
}
|
}
|
||||||
return auth.IsAnyUserOnServerWithMembership(ctx, db, serverName, gmslEvents, spec.Join), nil
|
return auth.IsAnyUserOnServerWithMembership(ctx, querier, serverName, gmslEvents, spec.Join), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsInvitePending(
|
func IsInvitePending(
|
||||||
@ -211,8 +211,8 @@ func GetMembershipsAtState(
|
|||||||
return events, nil
|
return events, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
|
func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventNID types.EventNID, querier api.QuerySenderIDAPI) ([]types.StateEntry, error) {
|
||||||
roomState := state.NewStateResolution(db, info)
|
roomState := state.NewStateResolution(db, info, querier)
|
||||||
// Lookup the event NID
|
// Lookup the event NID
|
||||||
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -229,8 +229,8 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
|
|||||||
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
||||||
}
|
}
|
||||||
|
|
||||||
func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
|
func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID, querier api.QuerySenderIDAPI) (map[string][]types.StateEntry, error) {
|
||||||
roomState := state.NewStateResolution(db, info)
|
roomState := state.NewStateResolution(db, info, querier)
|
||||||
// Fetch the state as it was when this event was fired
|
// Fetch the state as it was when this event was fired
|
||||||
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
|
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
|
||||||
}
|
}
|
||||||
@ -264,7 +264,7 @@ func LoadStateEvents(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CheckServerAllowedToSeeEvent(
|
func CheckServerAllowedToSeeEvent(
|
||||||
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool,
|
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID string, eventID string, serverName spec.ServerName, isServerInRoom bool, querier api.QuerySenderIDAPI,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
|
stateAtEvent, err := db.GetHistoryVisibilityState(ctx, info, eventID, string(serverName))
|
||||||
switch err {
|
switch err {
|
||||||
@ -273,7 +273,7 @@ func CheckServerAllowedToSeeEvent(
|
|||||||
case tables.OptimisationNotSupportedError:
|
case tables.OptimisationNotSupportedError:
|
||||||
// The database engine didn't support this optimisation, so fall back to using
|
// The database engine didn't support this optimisation, so fall back to using
|
||||||
// the old and slow method
|
// the old and slow method
|
||||||
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName)
|
stateAtEvent, err = slowGetHistoryVisibilityState(ctx, db, info, roomID, eventID, serverName, querier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -288,13 +288,13 @@ func CheckServerAllowedToSeeEvent(
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return auth.IsServerAllowed(ctx, db, serverName, isServerInRoom, stateAtEvent), nil
|
return auth.IsServerAllowed(ctx, querier, serverName, isServerInRoom, stateAtEvent), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func slowGetHistoryVisibilityState(
|
func slowGetHistoryVisibilityState(
|
||||||
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName,
|
ctx context.Context, db storage.Database, info *types.RoomInfo, roomID, eventID string, serverName spec.ServerName, querier api.QuerySenderIDAPI,
|
||||||
) ([]gomatrixserverlib.PDU, error) {
|
) ([]gomatrixserverlib.PDU, error) {
|
||||||
roomState := state.NewStateResolution(db, info)
|
roomState := state.NewStateResolution(db, info, querier)
|
||||||
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
@ -318,9 +318,13 @@ func slowGetHistoryVisibilityState(
|
|||||||
// If the event state key doesn't match the given servername
|
// If the event state key doesn't match the given servername
|
||||||
// then we'll filter it out. This does preserve state keys that
|
// then we'll filter it out. This does preserve state keys that
|
||||||
// are "" since these will contain history visibility etc.
|
// are "" since these will contain history visibility etc.
|
||||||
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
for nid, key := range stateKeys {
|
for nid, key := range stateKeys {
|
||||||
if key != "" {
|
if key != "" {
|
||||||
userID, err := db.GetUserIDForSender(ctx, roomID, spec.SenderID(key))
|
userID, err := querier.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(key))
|
||||||
if err == nil && userID != nil {
|
if err == nil && userID != nil {
|
||||||
if userID.Domain() != serverName {
|
if userID.Domain() != serverName {
|
||||||
delete(stateKeys, nid)
|
delete(stateKeys, nid)
|
||||||
@ -349,7 +353,7 @@ func slowGetHistoryVisibilityState(
|
|||||||
// TODO: Remove this when we have tests to assert correctness of this function
|
// TODO: Remove this when we have tests to assert correctness of this function
|
||||||
func ScanEventTree(
|
func ScanEventTree(
|
||||||
ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
|
ctx context.Context, db storage.Database, info *types.RoomInfo, front []string, visited map[string]bool, limit int,
|
||||||
serverName spec.ServerName,
|
serverName spec.ServerName, querier api.QuerySenderIDAPI,
|
||||||
) ([]types.EventNID, map[string]struct{}, error) {
|
) ([]types.EventNID, map[string]struct{}, error) {
|
||||||
var resultNIDs []types.EventNID
|
var resultNIDs []types.EventNID
|
||||||
var err error
|
var err error
|
||||||
@ -392,7 +396,7 @@ BFSLoop:
|
|||||||
// It's nasty that we have to extract the room ID from an event, but many federation requests
|
// It's nasty that we have to extract the room ID from an event, but many federation requests
|
||||||
// only talk in event IDs, no room IDs at all (!!!)
|
// only talk in event IDs, no room IDs at all (!!!)
|
||||||
ev := events[0]
|
ev := events[0]
|
||||||
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
|
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, querier, serverName, ev.RoomID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
|
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
|
||||||
}
|
}
|
||||||
@ -415,7 +419,7 @@ BFSLoop:
|
|||||||
// hasn't been seen before.
|
// hasn't been seen before.
|
||||||
if !visited[pre] {
|
if !visited[pre] {
|
||||||
visited[pre] = true
|
visited[pre] = true
|
||||||
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom)
|
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, ev.RoomID(), pre, serverName, isServerInRoom, querier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
||||||
"Error checking if allowed to see event",
|
"Error checking if allowed to see event",
|
||||||
@ -444,7 +448,7 @@ BFSLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func QueryLatestEventsAndState(
|
func QueryLatestEventsAndState(
|
||||||
ctx context.Context, db storage.Database,
|
ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI,
|
||||||
request *api.QueryLatestEventsAndStateRequest,
|
request *api.QueryLatestEventsAndStateRequest,
|
||||||
response *api.QueryLatestEventsAndStateResponse,
|
response *api.QueryLatestEventsAndStateResponse,
|
||||||
) error {
|
) error {
|
||||||
@ -457,7 +461,7 @@ func QueryLatestEventsAndState(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
roomState := state.NewStateResolution(db, roomInfo)
|
roomState := state.NewStateResolution(db, roomInfo, querier)
|
||||||
response.RoomExists = true
|
response.RoomExists = true
|
||||||
response.RoomVersion = roomInfo.RoomVersion
|
response.RoomVersion = roomInfo.RoomVersion
|
||||||
|
|
||||||
|
@ -128,7 +128,11 @@ func (r *Inputer) processRoomEvent(
|
|||||||
if roomInfo == nil && !isCreateEvent {
|
if roomInfo == nil && !isCreateEvent {
|
||||||
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
|
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
|
||||||
}
|
}
|
||||||
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sender, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
|
return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
|
||||||
}
|
}
|
||||||
@ -282,8 +286,8 @@ func (r *Inputer) processRoomEvent(
|
|||||||
|
|
||||||
// Check if the event is allowed by its auth events. If it isn't then
|
// Check if the event is allowed by its auth events. If it isn't then
|
||||||
// we consider the event to be "rejected" — it will still be persisted.
|
// we consider the event to be "rejected" — it will still be persisted.
|
||||||
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
isRejected = true
|
isRejected = true
|
||||||
rejectionErr = err
|
rejectionErr = err
|
||||||
@ -321,7 +325,7 @@ func (r *Inputer) processRoomEvent(
|
|||||||
if input.Kind == api.KindNew && !isCreateEvent {
|
if input.Kind == api.KindNew && !isCreateEvent {
|
||||||
// Check that the event passes authentication checks based on the
|
// Check that the event passes authentication checks based on the
|
||||||
// current room state.
|
// current room state.
|
||||||
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs)
|
softfail, err = helpers.CheckForSoftFail(ctx, r.DB, roomInfo, headered, input.StateEventIDs, r.Queryer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Warn("Error authing soft-failed event")
|
logger.WithError(err).Warn("Error authing soft-failed event")
|
||||||
}
|
}
|
||||||
@ -401,7 +405,7 @@ func (r *Inputer) processRoomEvent(
|
|||||||
redactedEvent gomatrixserverlib.PDU
|
redactedEvent gomatrixserverlib.PDU
|
||||||
)
|
)
|
||||||
if !isRejected && !isCreateEvent {
|
if !isRejected && !isCreateEvent {
|
||||||
resolver := state.NewStateResolution(r.DB, roomInfo)
|
resolver := state.NewStateResolution(r.DB, roomInfo, r.Queryer)
|
||||||
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver)
|
redactionEvent, redactedEvent, err = r.DB.MaybeRedactEvent(ctx, roomInfo, eventNID, event, &resolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -587,8 +591,8 @@ func (r *Inputer) processStateBefore(
|
|||||||
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
|
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
|
||||||
gomatrixserverlib.ToPDUs(stateBeforeEvent),
|
gomatrixserverlib.ToPDUs(stateBeforeEvent),
|
||||||
)
|
)
|
||||||
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); rejectionErr != nil {
|
}); rejectionErr != nil {
|
||||||
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
|
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
|
||||||
return
|
return
|
||||||
@ -700,8 +704,8 @@ nextAuthEvent:
|
|||||||
// Check the signatures of the event. If this fails then we'll simply
|
// Check the signatures of the event. If this fails then we'll simply
|
||||||
// skip it, because gomatrixserverlib.Allowed() will notice a problem
|
// skip it, because gomatrixserverlib.Allowed() will notice a problem
|
||||||
// if a critical event is missing anyway.
|
// if a critical event is missing anyway.
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
continue nextAuthEvent
|
continue nextAuthEvent
|
||||||
}
|
}
|
||||||
@ -718,8 +722,8 @@ nextAuthEvent:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the auth event should be rejected.
|
// Check if the auth event should be rejected.
|
||||||
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
if isRejected = err != nil; isRejected {
|
if isRejected = err != nil; isRejected {
|
||||||
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
|
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
|
||||||
@ -783,7 +787,7 @@ func (r *Inputer) calculateAndSetState(
|
|||||||
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
|
return fmt.Errorf("r.DB.GetRoomUpdater: %w", err)
|
||||||
}
|
}
|
||||||
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
|
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
|
||||||
roomState := state.NewStateResolution(updater, roomInfo)
|
roomState := state.NewStateResolution(updater, roomInfo, r.Queryer)
|
||||||
|
|
||||||
if input.HasState {
|
if input.HasState {
|
||||||
// We've been told what the state at the event is so we don't need to calculate it.
|
// We've been told what the state at the event is so we don't need to calculate it.
|
||||||
@ -836,13 +840,18 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
prevEvents := latestRes.LatestEvents
|
prevEvents := latestRes.LatestEvents
|
||||||
for _, memberEvent := range memberEvents {
|
for _, memberEvent := range memberEvents {
|
||||||
if memberEvent.StateKey() == nil {
|
if memberEvent.StateKey() == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, memberEvent.RoomID(), spec.SenderID(*memberEvent.StateKey()))
|
memberUserID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*memberEvent.StateKey()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Finally check that the event is NOT allowed
|
// Finally check that the event is NOT allowed
|
||||||
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}); err == nil {
|
}); err == nil {
|
||||||
t.Fatalf("event should not be allowed, but it was")
|
t.Fatalf("event should not be allowed, but it was")
|
||||||
|
@ -213,7 +213,7 @@ func (u *latestEventsUpdater) latestState() error {
|
|||||||
defer trace.EndRegion()
|
defer trace.EndRegion()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
roomState := state.NewStateResolution(u.updater, u.roomInfo)
|
roomState := state.NewStateResolution(u.updater, u.roomInfo, u.api.Queryer)
|
||||||
|
|
||||||
// Work out if the state at the extremities has actually changed
|
// Work out if the state at the extremities has actually changed
|
||||||
// or not. If they haven't then we won't bother doing all of the
|
// or not. If they haven't then we won't bother doing all of the
|
||||||
|
@ -139,7 +139,11 @@ func (r *Inputer) updateMembership(
|
|||||||
func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
|
func (r *Inputer) isLocalTarget(ctx context.Context, event *types.Event) bool {
|
||||||
isTargetLocalUser := false
|
isTargetLocalUser := false
|
||||||
if statekey := event.StateKey(); statekey != nil {
|
if statekey := event.StateKey(); statekey != nil {
|
||||||
userID, err := r.Queryer.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*statekey))
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return isTargetLocalUser
|
||||||
|
}
|
||||||
|
userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*statekey))
|
||||||
if err != nil || userID == nil {
|
if err != nil || userID == nil {
|
||||||
return isTargetLocalUser
|
return isTargetLocalUser
|
||||||
}
|
}
|
||||||
|
@ -383,7 +383,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
|
|||||||
defer trace.EndRegion()
|
defer trace.EndRegion()
|
||||||
|
|
||||||
var res parsedRespState
|
var res parsedRespState
|
||||||
roomState := state.NewStateResolution(t.db, t.roomInfo)
|
roomState := state.NewStateResolution(t.db, t.roomInfo, t.inputer.Queryer)
|
||||||
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
|
stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.log.WithError(err).Warnf("failed to get state after %s locally", eventID)
|
t.log.WithError(err).Warnf("failed to get state after %s locally", eventID)
|
||||||
@ -473,8 +473,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
|
|||||||
stateEventList = append(stateEventList, state.StateEvents...)
|
stateEventList = append(stateEventList, state.StateEvents...)
|
||||||
}
|
}
|
||||||
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
|
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
|
||||||
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -482,8 +482,8 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
|
|||||||
}
|
}
|
||||||
// apply the current event
|
// apply the current event
|
||||||
retryAllowedState:
|
retryAllowedState:
|
||||||
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
switch missing := err.(type) {
|
switch missing := err.(type) {
|
||||||
case gomatrixserverlib.MissingAuthEventError:
|
case gomatrixserverlib.MissingAuthEventError:
|
||||||
@ -569,8 +569,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
|
|||||||
// will be added and duplicates will be removed.
|
// will be added and duplicates will be removed.
|
||||||
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
|
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
|
||||||
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -660,8 +660,8 @@ func (t *missingStateReq) lookupMissingStateViaState(
|
|||||||
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
|
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
|
||||||
StateEvents: state.GetStateEvents(),
|
StateEvents: state.GetStateEvents(),
|
||||||
AuthEvents: state.GetAuthEvents(),
|
AuthEvents: state.GetAuthEvents(),
|
||||||
}, roomVersion, t.keys, nil, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
}, roomVersion, t.keys, nil, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -897,8 +897,8 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
|
|||||||
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
|
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
|
||||||
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
|
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
|
||||||
}
|
}
|
||||||
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return t.db.GetUserIDForSender(ctx, roomID, senderID)
|
return t.inputer.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
|
||||||
return nil, verifySigError{event.EventID(), err}
|
return nil, verifySigError{event.EventID(), err}
|
||||||
|
@ -74,6 +74,10 @@ func (r *Admin) PerformAdminEvacuateRoom(
|
|||||||
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
|
if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
prevEvents := latestRes.LatestEvents
|
prevEvents := latestRes.LatestEvents
|
||||||
var senderDomain spec.ServerName
|
var senderDomain spec.ServerName
|
||||||
@ -100,7 +104,7 @@ func (r *Admin) PerformAdminEvacuateRoom(
|
|||||||
PrevEvents: prevEvents,
|
PrevEvents: prevEvents,
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := r.Queryer.QueryUserIDForSender(ctx, roomID, spec.SenderID(fledglingEvent.SenderID))
|
userID, err := r.Queryer.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(fledglingEvent.SenderID))
|
||||||
if err != nil || userID == nil {
|
if err != nil || userID == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -264,16 +268,16 @@ func (r *Admin) PerformAdminDownloadState(
|
|||||||
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
|
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
|
||||||
}
|
}
|
||||||
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
authEventMap[authEvent.EventID()] = authEvent
|
authEventMap[authEvent.EventID()] = authEvent
|
||||||
}
|
}
|
||||||
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
|
||||||
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.Queryer.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -293,7 +297,11 @@ func (r *Admin) PerformAdminDownloadState(
|
|||||||
stateIDs = append(stateIDs, stateEvent.EventID())
|
stateIDs = append(stateIDs, stateEvent.EventID())
|
||||||
}
|
}
|
||||||
|
|
||||||
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, roomID, *fullUserID)
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
senderID, err := r.Queryer.QuerySenderIDForUser(ctx, *validRoomID, *fullUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -42,6 +42,7 @@ type Backfiller struct {
|
|||||||
DB storage.Database
|
DB storage.Database
|
||||||
FSAPI federationAPI.RoomserverFederationAPI
|
FSAPI federationAPI.RoomserverFederationAPI
|
||||||
KeyRing gomatrixserverlib.JSONVerifier
|
KeyRing gomatrixserverlib.JSONVerifier
|
||||||
|
Querier api.QuerySenderIDAPI
|
||||||
|
|
||||||
// The servers which should be preferred above other servers when backfilling
|
// The servers which should be preferred above other servers when backfilling
|
||||||
PreferServers []spec.ServerName
|
PreferServers []spec.ServerName
|
||||||
@ -79,7 +80,7 @@ func (r *Backfiller) PerformBackfill(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Scan the event tree for events to send back.
|
// Scan the event tree for events to send back.
|
||||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
|
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r.Querier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -113,7 +114,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||||||
if info == nil || info.IsStub() {
|
if info == nil || info.IsStub() {
|
||||||
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
|
return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID)
|
||||||
}
|
}
|
||||||
requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
|
requester := newBackfillRequester(r.DB, r.FSAPI, r.Querier, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers)
|
||||||
// Request 100 items regardless of what the query asks for.
|
// Request 100 items regardless of what the query asks for.
|
||||||
// We don't want to go much higher than this.
|
// We don't want to go much higher than this.
|
||||||
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
|
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
|
||||||
@ -121,8 +122,8 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||||||
// Specifically the test "Outbound federation can backfill events"
|
// Specifically the test "Outbound federation can backfill events"
|
||||||
events, err := gomatrixserverlib.RequestBackfill(
|
events, err := gomatrixserverlib.RequestBackfill(
|
||||||
ctx, req.VirtualHost, requester,
|
ctx, req.VirtualHost, requester,
|
||||||
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
// Only return an error if we really couldn't get any events.
|
// Only return an error if we really couldn't get any events.
|
||||||
@ -135,7 +136,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
|||||||
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
|
logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
|
||||||
|
|
||||||
// persist these new events - auth checks have already been done
|
// persist these new events - auth checks have already been done
|
||||||
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
|
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, r.Querier, events)
|
||||||
|
|
||||||
for _, ev := range backfilledEventMap {
|
for _, ev := range backfilledEventMap {
|
||||||
// now add state for these events
|
// now add state for these events
|
||||||
@ -212,8 +213,8 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
|
||||||
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Warn("failed to load and verify event")
|
logger.WithError(err).Warn("failed to load and verify event")
|
||||||
@ -246,13 +247,14 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
|
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
|
||||||
persistEvents(ctx, r.DB, newEvents)
|
persistEvents(ctx, r.DB, r.Querier, newEvents)
|
||||||
}
|
}
|
||||||
|
|
||||||
// backfillRequester implements gomatrixserverlib.BackfillRequester
|
// backfillRequester implements gomatrixserverlib.BackfillRequester
|
||||||
type backfillRequester struct {
|
type backfillRequester struct {
|
||||||
db storage.Database
|
db storage.Database
|
||||||
fsAPI federationAPI.RoomserverFederationAPI
|
fsAPI federationAPI.RoomserverFederationAPI
|
||||||
|
querier api.QuerySenderIDAPI
|
||||||
virtualHost spec.ServerName
|
virtualHost spec.ServerName
|
||||||
isLocalServerName func(spec.ServerName) bool
|
isLocalServerName func(spec.ServerName) bool
|
||||||
preferServer map[spec.ServerName]bool
|
preferServer map[spec.ServerName]bool
|
||||||
@ -268,6 +270,7 @@ type backfillRequester struct {
|
|||||||
|
|
||||||
func newBackfillRequester(
|
func newBackfillRequester(
|
||||||
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
|
db storage.Database, fsAPI federationAPI.RoomserverFederationAPI,
|
||||||
|
querier api.QuerySenderIDAPI,
|
||||||
virtualHost spec.ServerName,
|
virtualHost spec.ServerName,
|
||||||
isLocalServerName func(spec.ServerName) bool,
|
isLocalServerName func(spec.ServerName) bool,
|
||||||
bwExtrems map[string][]string, preferServers []spec.ServerName,
|
bwExtrems map[string][]string, preferServers []spec.ServerName,
|
||||||
@ -279,6 +282,7 @@ func newBackfillRequester(
|
|||||||
return &backfillRequester{
|
return &backfillRequester{
|
||||||
db: db,
|
db: db,
|
||||||
fsAPI: fsAPI,
|
fsAPI: fsAPI,
|
||||||
|
querier: querier,
|
||||||
virtualHost: virtualHost,
|
virtualHost: virtualHost,
|
||||||
isLocalServerName: isLocalServerName,
|
isLocalServerName: isLocalServerName,
|
||||||
eventIDToBeforeStateIDs: make(map[string][]string),
|
eventIDToBeforeStateIDs: make(map[string][]string),
|
||||||
@ -460,14 +464,14 @@ FindSuccessor:
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID)
|
stateEntries, err := helpers.StateBeforeEvent(ctx, b.db, info, NIDs[eventID].EventNID, b.querier)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
|
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// possibly return all joined servers depending on history visiblity
|
// possibly return all joined servers depending on history visiblity
|
||||||
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, info, stateEntries, b.virtualHost)
|
memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, b.querier, info, stateEntries, b.virtualHost)
|
||||||
b.historyVisiblity = visibility
|
b.historyVisiblity = visibility
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
|
||||||
@ -488,7 +492,11 @@ FindSuccessor:
|
|||||||
// Store the server names in a temporary map to avoid duplicates.
|
// Store the server names in a temporary map to avoid duplicates.
|
||||||
serverSet := make(map[spec.ServerName]bool)
|
serverSet := make(map[spec.ServerName]bool)
|
||||||
for _, event := range memberEvents {
|
for _, event := range memberEvents {
|
||||||
if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if sender, err := b.querier.QueryUserIDForSender(ctx, *validRoomID, event.SenderID()); err == nil {
|
||||||
serverSet[sender.Domain()] = true
|
serverSet[sender.Domain()] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -554,7 +562,7 @@ func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion,
|
|||||||
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
|
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
|
||||||
// pull all events and then filter by that table.
|
// pull all events and then filter by that table.
|
||||||
func joinEventsFromHistoryVisibility(
|
func joinEventsFromHistoryVisibility(
|
||||||
ctx context.Context, db storage.RoomDatabase, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
|
ctx context.Context, db storage.RoomDatabase, querier api.QuerySenderIDAPI, roomInfo *types.RoomInfo, stateEntries []types.StateEntry,
|
||||||
thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
|
thisServer spec.ServerName) ([]types.Event, gomatrixserverlib.HistoryVisibility, error) {
|
||||||
|
|
||||||
var eventNIDs []types.EventNID
|
var eventNIDs []types.EventNID
|
||||||
@ -582,7 +590,7 @@ func joinEventsFromHistoryVisibility(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Can we see events in the room?
|
// Can we see events in the room?
|
||||||
canSeeEvents := auth.IsServerAllowed(ctx, db, thisServer, true, events)
|
canSeeEvents := auth.IsServerAllowed(ctx, querier, thisServer, true, events)
|
||||||
visibility := auth.HistoryVisibilityForRoom(events)
|
visibility := auth.HistoryVisibilityForRoom(events)
|
||||||
if !canSeeEvents {
|
if !canSeeEvents {
|
||||||
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
|
logrus.Infof("ServersAtEvent history not visible to us: %s", visibility)
|
||||||
@ -597,7 +605,7 @@ func joinEventsFromHistoryVisibility(
|
|||||||
return evs, visibility, err
|
return evs, visibility, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
|
func persistEvents(ctx context.Context, db storage.Database, querier api.QuerySenderIDAPI, events []gomatrixserverlib.PDU) (types.RoomNID, map[string]types.Event) {
|
||||||
var roomNID types.RoomNID
|
var roomNID types.RoomNID
|
||||||
var eventNID types.EventNID
|
var eventNID types.EventNID
|
||||||
backfilledEventMap := make(map[string]types.Event)
|
backfilledEventMap := make(map[string]types.Event)
|
||||||
@ -639,7 +647,7 @@ func persistEvents(ctx context.Context, db storage.Database, events []gomatrixse
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
resolver := state.NewStateResolution(db, roomInfo)
|
resolver := state.NewStateResolution(db, roomInfo, querier)
|
||||||
|
|
||||||
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver)
|
_, redactedEvent, err := db.MaybeRedactEvent(ctx, roomInfo, eventNID, ev, &resolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -63,13 +63,20 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
senderID, err := c.DB.GetSenderIDForUser(ctx, roomID.String(), userID)
|
var senderID spec.SenderID
|
||||||
if err != nil {
|
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
||||||
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
|
// create user room key if needed
|
||||||
return "", &util.JSONResponse{
|
key, keyErr := c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
|
||||||
Code: http.StatusInternalServerError,
|
if keyErr != nil {
|
||||||
JSON: spec.InternalServerError{},
|
util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed")
|
||||||
|
return "", &util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
senderID = spec.SenderID(spec.Base64Bytes(key).Encode())
|
||||||
|
} else {
|
||||||
|
senderID = spec.SenderID(userID.String())
|
||||||
}
|
}
|
||||||
createContent["creator"] = senderID
|
createContent["creator"] = senderID
|
||||||
createContent["room_version"] = createRequest.RoomVersion
|
createContent["room_version"] = createRequest.RoomVersion
|
||||||
@ -323,8 +330,8 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return c.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return c.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
|
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
|
||||||
return "", &util.JSONResponse{
|
return "", &util.JSONResponse{
|
||||||
@ -364,18 +371,6 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// create user room key if needed
|
|
||||||
if createRequest.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
|
||||||
_, err = c.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, userID, roomID)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(ctx).WithError(err).Error("GetOrCreateUserRoomPrivateKey failed")
|
|
||||||
return "", &util.JSONResponse{
|
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
JSON: spec.InternalServerError{},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// send the remaining events
|
// send the remaining events
|
||||||
if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil {
|
if err = api.SendInputRoomEvents(ctx, c.RSAPI, userID.Domain(), inputs[1:], false); err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
|
util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed")
|
||||||
@ -455,7 +450,7 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
|
|||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID.String(), *inviteeUserID)
|
inviteeSenderID, queryErr := c.RSAPI.QuerySenderIDForUser(ctx, roomID, *inviteeUserID)
|
||||||
if queryErr != nil {
|
if queryErr != nil {
|
||||||
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
|
util.GetLogger(ctx).WithError(queryErr).Error("rsapi.QuerySenderIDForUser failed")
|
||||||
return "", &util.JSONResponse{
|
return "", &util.JSONResponse{
|
||||||
|
@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek(
|
|||||||
response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]}
|
response.LatestEvent = &types.HeaderedEvent{PDU: sortedLatestEvents[0]}
|
||||||
|
|
||||||
// XXX: do we actually need to do a state resolution here?
|
// XXX: do we actually need to do a state resolution here?
|
||||||
roomState := state.NewStateResolution(r.DB, info)
|
roomState := state.NewStateResolution(r.DB, info, r.Inputer.Queryer)
|
||||||
|
|
||||||
var stateEntries []types.StateEntry
|
var stateEntries []types.StateEntry
|
||||||
stateEntries, err = roomState.LoadStateAtSnapshot(
|
stateEntries, err = roomState.LoadStateAtSnapshot(
|
||||||
|
@ -34,6 +34,7 @@ import (
|
|||||||
|
|
||||||
type QueryState struct {
|
type QueryState struct {
|
||||||
storage.Database
|
storage.Database
|
||||||
|
querier api.QuerySenderIDAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) {
|
func (q *QueryState) GetAuthEvents(ctx context.Context, event gomatrixserverlib.PDU) (gomatrixserverlib.AuthEventProvider, error) {
|
||||||
@ -46,7 +47,7 @@ func (q *QueryState) GetState(ctx context.Context, roomID spec.RoomID, stateWant
|
|||||||
return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
|
return nil, fmt.Errorf("failed to load RoomInfo: %w", err)
|
||||||
}
|
}
|
||||||
if info != nil {
|
if info != nil {
|
||||||
roomState := state.NewStateResolution(q.Database, info)
|
roomState := state.NewStateResolution(q.Database, info, q.querier)
|
||||||
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
|
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
|
||||||
ctx, info.StateSnapshotNID(), stateWanted,
|
ctx, info.StateSnapshotNID(), stateWanted,
|
||||||
)
|
)
|
||||||
@ -98,7 +99,11 @@ func (r *Inviter) ProcessInviteMembership(
|
|||||||
var outputUpdates []api.OutputEvent
|
var outputUpdates []api.OutputEvent
|
||||||
var updater *shared.MembershipUpdater
|
var updater *shared.MembershipUpdater
|
||||||
|
|
||||||
userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
|
validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
userID, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
|
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
|
||||||
}
|
}
|
||||||
@ -126,7 +131,12 @@ func (r *Inviter) PerformInvite(
|
|||||||
) error {
|
) error {
|
||||||
event := req.Event
|
event := req.Event
|
||||||
|
|
||||||
sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sender, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return spec.InvalidParam("The sender user ID is invalid")
|
return spec.InvalidParam("The sender user ID is invalid")
|
||||||
}
|
}
|
||||||
@ -137,18 +147,13 @@ func (r *Inviter) PerformInvite(
|
|||||||
if event.StateKey() == nil || *event.StateKey() == "" {
|
if event.StateKey() == nil || *event.StateKey() == "" {
|
||||||
return fmt.Errorf("invite must be a state event")
|
return fmt.Errorf("invite must be a state event")
|
||||||
}
|
}
|
||||||
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
|
invitedUser, err := r.RSAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
|
||||||
if err != nil || invitedUser == nil {
|
if err != nil || invitedUser == nil {
|
||||||
return spec.InvalidParam("Could not find the matching senderID for this user")
|
return spec.InvalidParam("Could not find the matching senderID for this user")
|
||||||
}
|
}
|
||||||
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
|
isTargetLocal := r.Cfg.Matrix.IsLocalServerName(invitedUser.Domain())
|
||||||
|
|
||||||
validRoomID, err := spec.NewRoomID(event.RoomID())
|
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, *validRoomID, *invitedUser)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
invitedSenderID, err := r.RSAPI.QuerySenderIDForUser(ctx, event.RoomID(), *invitedUser)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed looking up senderID for invited user")
|
return fmt.Errorf("failed looking up senderID for invited user")
|
||||||
}
|
}
|
||||||
@ -161,9 +166,9 @@ func (r *Inviter) PerformInvite(
|
|||||||
IsTargetLocal: isTargetLocal,
|
IsTargetLocal: isTargetLocal,
|
||||||
StrippedState: req.InviteRoomState,
|
StrippedState: req.InviteRoomState,
|
||||||
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
|
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
|
||||||
StateQuerier: &QueryState{r.DB},
|
StateQuerier: &QueryState{r.DB, r.RSAPI},
|
||||||
UserIDQuerier: func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
UserIDQuerier: func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.RSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
|
||||||
|
@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||||
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
@ -174,44 +175,6 @@ func (r *Joiner) performJoinRoomByID(
|
|||||||
req.ServerNames = append(req.ServerNames, roomID.Domain())
|
req.ServerNames = append(req.ServerNames, roomID.Domain())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare the template for the join event.
|
|
||||||
userID, err := spec.NewUserID(req.UserID, true)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
|
|
||||||
}
|
|
||||||
senderID, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomIDOrAlias, *userID)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
|
|
||||||
}
|
|
||||||
senderIDString := string(senderID)
|
|
||||||
userDomain := userID.Domain()
|
|
||||||
proto := gomatrixserverlib.ProtoEvent{
|
|
||||||
Type: spec.MRoomMember,
|
|
||||||
SenderID: senderIDString,
|
|
||||||
StateKey: &senderIDString,
|
|
||||||
RoomID: req.RoomIDOrAlias,
|
|
||||||
Redacts: "",
|
|
||||||
}
|
|
||||||
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
|
||||||
return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// It is possible for the request to include some "content" for the
|
|
||||||
// event. We'll always overwrite the "membership" key, but the rest,
|
|
||||||
// like "display_name" or "avatar_url", will be kept if supplied.
|
|
||||||
if req.Content == nil {
|
|
||||||
req.Content = map[string]interface{}{}
|
|
||||||
}
|
|
||||||
req.Content["membership"] = spec.Join
|
|
||||||
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
|
|
||||||
return "", "", aerr
|
|
||||||
} else if authorisedVia != "" {
|
|
||||||
req.Content["join_authorised_via_users_server"] = authorisedVia
|
|
||||||
}
|
|
||||||
if err = proto.SetContent(req.Content); err != nil {
|
|
||||||
return "", "", fmt.Errorf("eb.SetContent: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Force a federated join if we aren't in the room and we've been
|
// Force a federated join if we aren't in the room and we've been
|
||||||
// given some server names to try joining by.
|
// given some server names to try joining by.
|
||||||
inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{
|
inRoomReq := &rsAPI.QueryServerJoinedToRoomRequest{
|
||||||
@ -224,29 +187,63 @@ func (r *Joiner) performJoinRoomByID(
|
|||||||
serverInRoom := inRoomRes.IsInRoom
|
serverInRoom := inRoomRes.IsInRoom
|
||||||
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
|
forceFederatedJoin := len(req.ServerNames) > 0 && !serverInRoom
|
||||||
|
|
||||||
|
userID, err := spec.NewUserID(req.UserID, true)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", rsAPI.ErrInvalidID{Err: fmt.Errorf("user ID %q is invalid: %w", req.UserID, err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look up the room NID for the supplied room ID.
|
||||||
|
var senderID spec.SenderID
|
||||||
|
checkInvitePending := false
|
||||||
|
info, err := r.DB.RoomInfo(ctx, req.RoomIDOrAlias)
|
||||||
|
if err == nil && info != nil {
|
||||||
|
switch info.RoomVersion {
|
||||||
|
case gomatrixserverlib.RoomVersionPseudoIDs:
|
||||||
|
senderID, err = r.Queryer.QuerySenderIDForUser(ctx, *roomID, *userID)
|
||||||
|
if err == nil {
|
||||||
|
checkInvitePending = true
|
||||||
|
} else {
|
||||||
|
// create user room key if needed
|
||||||
|
key, keyErr := r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
|
||||||
|
if keyErr != nil {
|
||||||
|
util.GetLogger(ctx).WithError(keyErr).Error("GetOrCreateUserRoomPrivateKey failed")
|
||||||
|
return "", "", fmt.Errorf("GetOrCreateUserRoomPrivateKey failed: %w", keyErr)
|
||||||
|
}
|
||||||
|
senderID = spec.SenderID(spec.Base64Bytes(key).Encode())
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
checkInvitePending = true
|
||||||
|
senderID = spec.SenderID(userID.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
userDomain := userID.Domain()
|
||||||
|
|
||||||
// Force a federated join if we're dealing with a pending invite
|
// Force a federated join if we're dealing with a pending invite
|
||||||
// and we aren't in the room.
|
// and we aren't in the room.
|
||||||
isInvitePending, inviteSender, _, inviteEvent, err := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
|
if checkInvitePending {
|
||||||
if err == nil && !serverInRoom && isInvitePending {
|
isInvitePending, inviteSender, _, inviteEvent, inviteErr := helpers.IsInvitePending(ctx, r.DB, req.RoomIDOrAlias, senderID)
|
||||||
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomIDOrAlias, inviteSender)
|
if inviteErr == nil && !serverInRoom && isInvitePending {
|
||||||
if queryErr != nil {
|
inviter, queryErr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, inviteSender)
|
||||||
return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
|
if queryErr != nil {
|
||||||
}
|
return "", "", fmt.Errorf("r.RSAPI.QueryUserIDForSender: %w", queryErr)
|
||||||
|
}
|
||||||
|
|
||||||
// If we were invited by someone from another server then we can
|
// If we were invited by someone from another server then we can
|
||||||
// assume they are in the room so we can join via them.
|
// assume they are in the room so we can join via them.
|
||||||
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
|
if inviter != nil && !r.Cfg.Matrix.IsLocalServerName(inviter.Domain()) {
|
||||||
req.ServerNames = append(req.ServerNames, inviter.Domain())
|
req.ServerNames = append(req.ServerNames, inviter.Domain())
|
||||||
forceFederatedJoin = true
|
forceFederatedJoin = true
|
||||||
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
|
memberEvent := gjson.Parse(string(inviteEvent.JSON()))
|
||||||
// only set unsigned if we've got a content.membership, which we _should_
|
// only set unsigned if we've got a content.membership, which we _should_
|
||||||
if memberEvent.Get("content.membership").Exists() {
|
if memberEvent.Get("content.membership").Exists() {
|
||||||
req.Unsigned = map[string]interface{}{
|
req.Unsigned = map[string]interface{}{
|
||||||
"prev_sender": memberEvent.Get("sender").Str,
|
"prev_sender": memberEvent.Get("sender").Str,
|
||||||
"prev_content": map[string]interface{}{
|
"prev_content": map[string]interface{}{
|
||||||
"is_direct": memberEvent.Get("content.is_direct").Bool(),
|
"is_direct": memberEvent.Get("content.is_direct").Bool(),
|
||||||
"membership": memberEvent.Get("content.membership").Str,
|
"membership": memberEvent.Get("content.membership").Str,
|
||||||
},
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -274,6 +271,7 @@ func (r *Joiner) performJoinRoomByID(
|
|||||||
// If we should do a forced federated join then do that.
|
// If we should do a forced federated join then do that.
|
||||||
var joinedVia spec.ServerName
|
var joinedVia spec.ServerName
|
||||||
if forceFederatedJoin {
|
if forceFederatedJoin {
|
||||||
|
// TODO : pseudoIDs - pass through userID here since we don't know what the senderID should be yet
|
||||||
joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
|
joinedVia, err = r.performFederatedJoinRoomByID(ctx, req)
|
||||||
return req.RoomIDOrAlias, joinedVia, err
|
return req.RoomIDOrAlias, joinedVia, err
|
||||||
}
|
}
|
||||||
@ -289,19 +287,40 @@ func (r *Joiner) performJoinRoomByID(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("error joining local room: %q", err)
|
return "", "", fmt.Errorf("error joining local room: %q", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
senderIDString := string(senderID)
|
||||||
|
|
||||||
|
// Prepare the template for the join event.
|
||||||
|
proto := gomatrixserverlib.ProtoEvent{
|
||||||
|
Type: spec.MRoomMember,
|
||||||
|
SenderID: senderIDString,
|
||||||
|
StateKey: &senderIDString,
|
||||||
|
RoomID: req.RoomIDOrAlias,
|
||||||
|
Redacts: "",
|
||||||
|
}
|
||||||
|
if err = proto.SetUnsigned(struct{}{}); err != nil {
|
||||||
|
return "", "", fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// It is possible for the request to include some "content" for the
|
||||||
|
// event. We'll always overwrite the "membership" key, but the rest,
|
||||||
|
// like "display_name" or "avatar_url", will be kept if supplied.
|
||||||
|
if req.Content == nil {
|
||||||
|
req.Content = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
req.Content["membership"] = spec.Join
|
||||||
|
if authorisedVia, aerr := r.populateAuthorisedViaUserForRestrictedJoin(ctx, req, senderID); aerr != nil {
|
||||||
|
return "", "", aerr
|
||||||
|
} else if authorisedVia != "" {
|
||||||
|
req.Content["join_authorised_via_users_server"] = authorisedVia
|
||||||
|
}
|
||||||
|
if err = proto.SetContent(req.Content); err != nil {
|
||||||
|
return "", "", fmt.Errorf("eb.SetContent: %w", err)
|
||||||
|
}
|
||||||
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes)
|
event, err := eventutil.QueryAndBuildEvent(ctx, &proto, identity, time.Now(), r.RSAPI, &buildRes)
|
||||||
|
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
// create user room key if needed
|
|
||||||
if buildRes.RoomVersion == gomatrixserverlib.RoomVersionPseudoIDs {
|
|
||||||
_, err = r.RSAPI.GetOrCreateUserRoomPrivateKey(ctx, *userID, *roomID)
|
|
||||||
if err != nil {
|
|
||||||
logrus.WithError(err).Error("GetOrCreateUserRoomPrivateKey failed")
|
|
||||||
return "", "", fmt.Errorf("failed to get user room private key: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The room join is local. Send the new join event into the
|
// The room join is local. Send the new join event into the
|
||||||
// roomserver. First of all check that the user isn't already
|
// roomserver. First of all check that the user isn't already
|
||||||
// a member of the room. This is best-effort (as in we won't
|
// a member of the room. This is best-effort (as in we won't
|
||||||
|
@ -78,7 +78,11 @@ func (r *Leaver) performLeaveRoomByID(
|
|||||||
req *api.PerformLeaveRequest,
|
req *api.PerformLeaveRequest,
|
||||||
res *api.PerformLeaveResponse, // nolint:unparam
|
res *api.PerformLeaveResponse, // nolint:unparam
|
||||||
) ([]api.OutputEvent, error) {
|
) ([]api.OutputEvent, error) {
|
||||||
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, req.RoomID, req.Leaver)
|
roomID, err := spec.NewRoomID(req.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
leaver, err := r.RSAPI.QuerySenderIDForUser(ctx, *roomID, req.Leaver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
|
return nil, fmt.Errorf("leaver %s has no matching senderID in this room", req.Leaver.String())
|
||||||
}
|
}
|
||||||
@ -87,7 +91,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||||||
// that.
|
// that.
|
||||||
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
|
isInvitePending, senderUser, eventID, _, err := helpers.IsInvitePending(ctx, r.DB, req.RoomID, leaver)
|
||||||
if err == nil && isInvitePending {
|
if err == nil && isInvitePending {
|
||||||
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, req.RoomID, senderUser)
|
sender, serr := r.RSAPI.QueryUserIDForSender(ctx, *roomID, senderUser)
|
||||||
if serr != nil || sender == nil {
|
if serr != nil || sender == nil {
|
||||||
return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
|
return nil, fmt.Errorf("sender %q has no matching userID", senderUser)
|
||||||
}
|
}
|
||||||
@ -133,7 +137,7 @@ func (r *Leaver) performLeaveRoomByID(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
latestRes := api.QueryLatestEventsAndStateResponse{}
|
latestRes := api.QueryLatestEventsAndStateResponse{}
|
||||||
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, &latestReq, &latestRes); err != nil {
|
if err = helpers.QueryLatestEventsAndState(ctx, r.DB, r.RSAPI, &latestReq, &latestRes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !latestRes.RoomExists {
|
if !latestRes.RoomExists {
|
||||||
|
@ -54,7 +54,11 @@ func (r *Upgrader) performRoomUpgrade(
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, roomID, userID)
|
fullRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
senderID, err := r.URSAPI.QuerySenderIDForUser(ctx, *fullRoomID, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
|
util.GetLogger(ctx).WithError(err).Error("Failed getting senderID for user")
|
||||||
return "", err
|
return "", err
|
||||||
@ -488,7 +492,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, send
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
|
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
|
||||||
@ -569,7 +573,7 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, send
|
|||||||
stateEvents[i] = queryRes.StateEvents[i].PDU
|
stateEvents[i] = queryRes.StateEvents[i].PDU
|
||||||
}
|
}
|
||||||
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
|
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
|
||||||
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
|
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
|
||||||
|
@ -16,6 +16,7 @@ package query
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -89,7 +90,7 @@ func (r *Queryer) QueryLatestEventsAndState(
|
|||||||
request *api.QueryLatestEventsAndStateRequest,
|
request *api.QueryLatestEventsAndStateRequest,
|
||||||
response *api.QueryLatestEventsAndStateResponse,
|
response *api.QueryLatestEventsAndStateResponse,
|
||||||
) error {
|
) error {
|
||||||
return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response)
|
return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryStateAfterEvents implements api.RoomserverInternalAPI
|
// QueryStateAfterEvents implements api.RoomserverInternalAPI
|
||||||
@ -106,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents(
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
roomState := state.NewStateResolution(r.DB, info)
|
roomState := state.NewStateResolution(r.DB, info, r)
|
||||||
response.RoomExists = true
|
response.RoomExists = true
|
||||||
response.RoomVersion = info.RoomVersion
|
response.RoomVersion = info.RoomVersion
|
||||||
|
|
||||||
@ -159,8 +160,8 @@ func (r *Queryer) QueryStateAfterEvents(
|
|||||||
}
|
}
|
||||||
|
|
||||||
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
||||||
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -271,15 +272,15 @@ func (r *Queryer) QueryMembershipForUser(
|
|||||||
request *api.QueryMembershipForUserRequest,
|
request *api.QueryMembershipForUserRequest,
|
||||||
response *api.QueryMembershipForUserResponse,
|
response *api.QueryMembershipForUserResponse,
|
||||||
) error {
|
) error {
|
||||||
senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
roomID, err := spec.NewRoomID(request.RoomID)
|
roomID, err := spec.NewRoomID(request.RoomID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
|
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -320,7 +321,7 @@ func (r *Queryer) QueryMembershipAtEvent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
response.Membership = make(map[string]*types.HeaderedEvent)
|
response.Membership = make(map[string]*types.HeaderedEvent)
|
||||||
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID])
|
stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to get state before event: %w", err)
|
return fmt.Errorf("unable to get state before event: %w", err)
|
||||||
}
|
}
|
||||||
@ -407,7 +408,7 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||||||
return fmt.Errorf("r.DB.Events: %w", err)
|
return fmt.Errorf("r.DB.Events: %w", err)
|
||||||
}
|
}
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}, event)
|
}, event)
|
||||||
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
||||||
@ -445,7 +446,7 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||||||
|
|
||||||
events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
|
events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
|
||||||
} else {
|
} else {
|
||||||
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
|
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
||||||
return err
|
return err
|
||||||
@ -458,7 +459,7 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}, event)
|
}, event)
|
||||||
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
response.JoinEvents = append(response.JoinEvents, clientEvent)
|
||||||
@ -532,7 +533,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return helpers.CheckServerAllowedToSeeEvent(
|
return helpers.CheckServerAllowedToSeeEvent(
|
||||||
ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
|
ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -573,7 +574,7 @@ func (r *Queryer) QueryMissingEvents(
|
|||||||
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
|
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
|
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -651,8 +652,8 @@ func (r *Queryer) QueryStateAndAuthChain(
|
|||||||
|
|
||||||
if request.ResolveState {
|
if request.ResolveState {
|
||||||
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
stateEvents, err = gomatrixserverlib.ResolveConflicts(
|
||||||
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
return r.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -673,7 +674,7 @@ func (r *Queryer) QueryStateAndAuthChain(
|
|||||||
|
|
||||||
// first bool: is rejected, second bool: state missing
|
// first bool: is rejected, second bool: state missing
|
||||||
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) {
|
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) {
|
||||||
roomState := state.NewStateResolution(r.DB, roomInfo)
|
roomState := state.NewStateResolution(r.DB, roomInfo, r)
|
||||||
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err.(type) {
|
switch err.(type) {
|
||||||
@ -989,10 +990,46 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
|
|||||||
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
|
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
|
func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
return r.DB.GetSenderIDForUser(ctx, roomID, userID)
|
version, err := r.DB.GetRoomVersion(ctx, roomID.String())
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch version {
|
||||||
|
case gomatrixserverlib.RoomVersionPseudoIDs:
|
||||||
|
key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return spec.SenderID(spec.Base64Bytes(key).Encode()), nil
|
||||||
|
default:
|
||||||
|
return spec.SenderID(userID.String()), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return r.DB.GetUserIDForSender(ctx, roomID, senderID)
|
userID, err := spec.NewUserID(string(senderID), true)
|
||||||
|
if err == nil {
|
||||||
|
return userID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes := spec.Base64Bytes{}
|
||||||
|
err = bytes.Decode(string(senderID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
queryMap := map[spec.RoomID][]ed25519.PublicKey{roomID: {ed25519.PublicKey(bytes)}}
|
||||||
|
result, err := r.DB.SelectUserIDsForPublicKeys(ctx, queryMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if userKeys, ok := result[roomID]; ok {
|
||||||
|
if userID, ok := userKeys[string(senderID)]; ok {
|
||||||
|
return spec.NewUserID(userID, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -516,6 +516,9 @@ func TestRedaction(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
natsInstance := &jetstream.NATSInstance{}
|
||||||
|
rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
authEvents := []types.EventNID{}
|
authEvents := []types.EventNID{}
|
||||||
@ -551,7 +554,7 @@ func TestRedaction(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the snapshotNID etc.
|
// Calculate the snapshotNID etc.
|
||||||
plResolver := state.NewStateResolution(db, roomInfo)
|
plResolver := state.NewStateResolution(db, roomInfo, rsAPI)
|
||||||
stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false)
|
stateAtEvent.BeforeStateSnapshotNID, err = plResolver.CalculateAndStoreStateBeforeEvent(ctx, ev.PDU, false)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
@ -29,6 +29,7 @@ import (
|
|||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -44,20 +45,21 @@ type StateResolutionStorage interface {
|
|||||||
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
|
||||||
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error)
|
||||||
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type StateResolution struct {
|
type StateResolution struct {
|
||||||
db StateResolutionStorage
|
db StateResolutionStorage
|
||||||
roomInfo *types.RoomInfo
|
roomInfo *types.RoomInfo
|
||||||
events map[types.EventNID]gomatrixserverlib.PDU
|
events map[types.EventNID]gomatrixserverlib.PDU
|
||||||
|
Querier api.QuerySenderIDAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo) StateResolution {
|
func NewStateResolution(db StateResolutionStorage, roomInfo *types.RoomInfo, querier api.QuerySenderIDAPI) StateResolution {
|
||||||
return StateResolution{
|
return StateResolution{
|
||||||
db: db,
|
db: db,
|
||||||
roomInfo: roomInfo,
|
roomInfo: roomInfo,
|
||||||
events: make(map[types.EventNID]gomatrixserverlib.PDU),
|
events: make(map[types.EventNID]gomatrixserverlib.PDU),
|
||||||
|
Querier: querier,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -947,8 +949,8 @@ func (v *StateResolution) resolveConflictsV1(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolve the conflicts.
|
// Resolve the conflicts.
|
||||||
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return v.db.GetUserIDForSender(ctx, roomID, senderID)
|
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Map from the full events back to numeric state entries.
|
// Map from the full events back to numeric state entries.
|
||||||
@ -1061,8 +1063,8 @@ func (v *StateResolution) resolveConflictsV2(
|
|||||||
conflictedEvents,
|
conflictedEvents,
|
||||||
nonConflictedEvents,
|
nonConflictedEvents,
|
||||||
authEvents,
|
authEvents,
|
||||||
func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return v.db.GetUserIDForSender(ctx, roomID, senderID)
|
return v.Querier.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}()
|
}()
|
||||||
|
@ -169,10 +169,6 @@ type Database interface {
|
|||||||
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
|
GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error)
|
||||||
// GetKnownUsers searches all users that userID knows about.
|
// GetKnownUsers searches all users that userID knows about.
|
||||||
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error)
|
||||||
// GetKnownUsers tries to obtain the current mxid for a given user.
|
|
||||||
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
|
|
||||||
// GetKnownUsers tries to obtain the current senderID for a given user.
|
|
||||||
GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error)
|
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
GetKnownRooms(ctx context.Context) ([]string, error)
|
GetKnownRooms(ctx context.Context) ([]string, error)
|
||||||
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
|
// ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room
|
||||||
@ -190,6 +186,7 @@ type Database interface {
|
|||||||
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
|
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
|
||||||
) (map[string]*types.HeaderedEvent, error)
|
) (map[string]*types.HeaderedEvent, error)
|
||||||
GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error)
|
GetOrCreateRoomInfo(ctx context.Context, event gomatrixserverlib.PDU) (*types.RoomInfo, error)
|
||||||
|
GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
|
||||||
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
||||||
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
||||||
MaybeRedactEvent(
|
MaybeRedactEvent(
|
||||||
@ -205,8 +202,12 @@ type UserRoomKeys interface {
|
|||||||
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
|
InsertUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID, key ed25519.PublicKey) (result ed25519.PublicKey, err error)
|
||||||
// SelectUserRoomPrivateKey selects the private key for the given user and room combination
|
// SelectUserRoomPrivateKey selects the private key for the given user and room combination
|
||||||
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
|
SelectUserRoomPrivateKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PrivateKey, err error)
|
||||||
|
// SelectUserRoomPublicKey selects the public key for the given user and room combination
|
||||||
|
SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error)
|
||||||
// SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID.
|
// SelectUserIDsForPublicKeys selects all userIDs for the requested senderKeys. Returns a map from roomID -> map from publicKey to userID.
|
||||||
// If a senderKey can't be found, it is omitted in the result.
|
// If a senderKey can't be found, it is omitted in the result.
|
||||||
|
// TODO: Why is the result map indexed by string not public key?
|
||||||
|
// TODO: Shouldn't the input & result map be changed to be indexed by string instead of the RoomID struct?
|
||||||
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error)
|
SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (map[spec.RoomID]map[string]string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,7 +234,6 @@ type RoomDatabase interface {
|
|||||||
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
|
||||||
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
|
||||||
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
|
GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error)
|
||||||
GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type EventDatabase interface {
|
type EventDatabase interface {
|
||||||
|
@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = `
|
|||||||
|
|
||||||
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
||||||
|
|
||||||
|
const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
||||||
|
|
||||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
|
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid = ANY($1) AND pseudo_id_pub_key = ANY($2)`
|
||||||
|
|
||||||
type userRoomKeysStatements struct {
|
type userRoomKeysStatements struct {
|
||||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||||
insertUserRoomPublicKeyStmt *sql.Stmt
|
insertUserRoomPublicKeyStmt *sql.Stmt
|
||||||
selectUserRoomKeyStmt *sql.Stmt
|
selectUserRoomKeyStmt *sql.Stmt
|
||||||
|
selectUserRoomPublicKeyStmt *sql.Stmt
|
||||||
selectUserNIDsStmt *sql.Stmt
|
selectUserNIDsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
|||||||
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL},
|
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomPrivateKeySQL},
|
||||||
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
||||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||||
|
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
||||||
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
|
{&s.selectUserNIDsStmt, selectUserNIDsSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
|
|||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRoomKeysStatements) SelectUserRoomPublicKey(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
userNID types.EventStateKeyNID,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
) (ed25519.PublicKey, error) {
|
||||||
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt)
|
||||||
|
var result ed25519.PublicKey
|
||||||
|
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
|
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
|
||||||
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt)
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserNIDsStmt)
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
)
|
)
|
||||||
@ -251,7 +250,3 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
|
|||||||
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
|
||||||
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
|
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
|
||||||
return u.d.GetUserIDForSender(ctx, roomID, senderID)
|
|
||||||
}
|
|
||||||
|
@ -721,6 +721,22 @@ func (d *Database) GetOrCreateRoomInfo(ctx context.Context, event gomatrixserver
|
|||||||
}, err
|
}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Database) GetRoomVersion(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error) {
|
||||||
|
cachedRoomVersion, versionOK := d.Cache.GetRoomVersion(roomID)
|
||||||
|
if versionOK {
|
||||||
|
return cachedRoomVersion, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
roomInfo, err := d.RoomInfo(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if roomInfo == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return roomInfo.RoomVersion, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
|
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil {
|
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil {
|
||||||
@ -1550,16 +1566,6 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
|
|||||||
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
|
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
|
||||||
// TODO: Use real logic once DB for pseudoIDs is in place
|
|
||||||
return spec.NewUserID(string(senderID), true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
|
|
||||||
// TODO: Use real logic once DB for pseudoIDs is in place
|
|
||||||
return spec.SenderID(userID.String()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKnownRooms returns a list of all rooms we know about.
|
// GetKnownRooms returns a list of all rooms we know about.
|
||||||
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
|
||||||
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
|
return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil)
|
||||||
@ -1718,6 +1724,35 @@ func (d *Database) SelectUserRoomPrivateKey(ctx context.Context, userID spec.Use
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SelectUserRoomPublicKey queries the users room public key.
|
||||||
|
// If no key exists, returns no key and no error. Otherwise returns
|
||||||
|
// the key and a database error, if any.
|
||||||
|
func (d *Database) SelectUserRoomPublicKey(ctx context.Context, userID spec.UserID, roomID spec.RoomID) (key ed25519.PublicKey, err error) {
|
||||||
|
uID := userID.String()
|
||||||
|
stateKeyNIDMap, sErr := d.eventStateKeyNIDs(ctx, nil, []string{uID})
|
||||||
|
if sErr != nil {
|
||||||
|
return nil, sErr
|
||||||
|
}
|
||||||
|
stateKeyNID := stateKeyNIDMap[uID]
|
||||||
|
|
||||||
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
roomInfo, rErr := d.roomInfo(ctx, txn, roomID.String())
|
||||||
|
if rErr != nil {
|
||||||
|
return rErr
|
||||||
|
}
|
||||||
|
if roomInfo == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
key, sErr = d.UserRoomKeyTable.SelectUserRoomPublicKey(ctx, txn, stateKeyNID, roomInfo.RoomNID)
|
||||||
|
if !errors.Is(sErr, sql.ErrNoRows) {
|
||||||
|
return sErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
|
// SelectUserIDsForPublicKeys returns a map from roomID -> map from senderKey -> userID
|
||||||
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
|
func (d *Database) SelectUserIDsForPublicKeys(ctx context.Context, publicKeys map[spec.RoomID][]ed25519.PublicKey) (result map[spec.RoomID]map[string]string, err error) {
|
||||||
result = make(map[spec.RoomID]map[string]string, len(publicKeys))
|
result = make(map[spec.RoomID]map[string]string, len(publicKeys))
|
||||||
|
@ -163,12 +163,17 @@ func TestUserRoomKeys(t *testing.T) {
|
|||||||
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID)
|
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *roomID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, key, gotKey)
|
assert.Equal(t, key, gotKey)
|
||||||
|
pubKey, err := db.SelectUserRoomPublicKey(context.Background(), *userID, *roomID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, key.Public(), pubKey)
|
||||||
|
|
||||||
// Key doesn't exist, we shouldn't get anything back
|
// Key doesn't exist, we shouldn't get anything back
|
||||||
assert.NoError(t, err)
|
|
||||||
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist)
|
gotKey, err = db.SelectUserRoomPrivateKey(context.Background(), *userID, *doesNotExist)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Nil(t, gotKey)
|
assert.Nil(t, gotKey)
|
||||||
|
pubKey, err = db.SelectUserRoomPublicKey(context.Background(), *userID, *doesNotExist)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, pubKey)
|
||||||
|
|
||||||
queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{
|
queryUserIDs := map[spec.RoomID][]ed25519.PublicKey{
|
||||||
*roomID: {key.Public().(ed25519.PublicKey)},
|
*roomID: {key.Public().(ed25519.PublicKey)},
|
||||||
|
@ -51,12 +51,15 @@ const insertUserRoomPublicKeySQL = `
|
|||||||
|
|
||||||
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
const selectUserRoomKeySQL = `SELECT pseudo_id_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
||||||
|
|
||||||
|
const selectUserRoomPublicKeySQL = `SELECT pseudo_id_pub_key FROM roomserver_user_room_keys WHERE user_nid = $1 AND room_nid = $2`
|
||||||
|
|
||||||
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
|
const selectUserNIDsSQL = `SELECT user_nid, room_nid, pseudo_id_pub_key FROM roomserver_user_room_keys WHERE room_nid IN ($1) AND pseudo_id_pub_key IN ($2)`
|
||||||
|
|
||||||
type userRoomKeysStatements struct {
|
type userRoomKeysStatements struct {
|
||||||
insertUserRoomPrivateKeyStmt *sql.Stmt
|
insertUserRoomPrivateKeyStmt *sql.Stmt
|
||||||
insertUserRoomPublicKeyStmt *sql.Stmt
|
insertUserRoomPublicKeyStmt *sql.Stmt
|
||||||
selectUserRoomKeyStmt *sql.Stmt
|
selectUserRoomKeyStmt *sql.Stmt
|
||||||
|
selectUserRoomPublicKeyStmt *sql.Stmt
|
||||||
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
|
//selectUserNIDsStmt *sql.Stmt //prepared at runtime
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,6 +74,7 @@ func PrepareUserRoomKeysTable(db *sql.DB) (tables.UserRoomKeys, error) {
|
|||||||
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
|
{&s.insertUserRoomPrivateKeyStmt, insertUserRoomKeySQL},
|
||||||
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
{&s.insertUserRoomPublicKeyStmt, insertUserRoomPublicKeySQL},
|
||||||
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
{&s.selectUserRoomKeyStmt, selectUserRoomKeySQL},
|
||||||
|
{&s.selectUserRoomPublicKeyStmt, selectUserRoomPublicKeySQL},
|
||||||
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
|
//{&s.selectUserNIDsStmt, selectUserNIDsSQL}, //prepared at runtime
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
@ -102,6 +106,21 @@ func (s *userRoomKeysStatements) SelectUserRoomPrivateKey(
|
|||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *userRoomKeysStatements) SelectUserRoomPublicKey(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
userNID types.EventStateKeyNID,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
) (ed25519.PublicKey, error) {
|
||||||
|
stmt := sqlutil.TxStmtContext(ctx, txn, s.selectUserRoomPublicKeyStmt)
|
||||||
|
var result ed25519.PublicKey
|
||||||
|
err := stmt.QueryRowContext(ctx, userNID, roomNID).Scan(&result)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
|
func (s *userRoomKeysStatements) BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error) {
|
||||||
|
|
||||||
roomNIDs := make([]any, 0, len(senderKeys))
|
roomNIDs := make([]any, 0, len(senderKeys))
|
||||||
|
@ -193,6 +193,8 @@ type UserRoomKeys interface {
|
|||||||
InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error)
|
InsertUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID, key ed25519.PublicKey) (ed25519.PublicKey, error)
|
||||||
// SelectUserRoomPrivateKey selects the private key for the given user and room combination
|
// SelectUserRoomPrivateKey selects the private key for the given user and room combination
|
||||||
SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
|
SelectUserRoomPrivateKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PrivateKey, error)
|
||||||
|
// SelectUserRoomPublicKey selects the public key for the given user and room combination
|
||||||
|
SelectUserRoomPublicKey(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomNID types.RoomNID) (ed25519.PublicKey, error)
|
||||||
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
|
// BulkSelectUserNIDs selects all userIDs for the requested senderKeys. Returns a map from publicKey -> types.UserRoomKeyPair.
|
||||||
// If a senderKey can't be found, it is omitted in the result.
|
// If a senderKey can't be found, it is omitted in the result.
|
||||||
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
|
BulkSelectUserNIDs(ctx context.Context, txn *sql.Tx, senderKeys map[types.RoomNID][]ed25519.PublicKey) (map[string]types.UserRoomKeyPair, error)
|
||||||
|
@ -50,6 +50,7 @@ func TestUserRoomKeysTable(t *testing.T) {
|
|||||||
|
|
||||||
err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
|
||||||
var gotKey, key2, key3 ed25519.PrivateKey
|
var gotKey, key2, key3 ed25519.PrivateKey
|
||||||
|
var pubKey ed25519.PublicKey
|
||||||
gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key)
|
gotKey, err = tab.InsertUserRoomPrivatePublicKey(context.Background(), txn, userNID, roomNID, key)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, gotKey, key)
|
assert.Equal(t, gotKey, key)
|
||||||
@ -71,6 +72,9 @@ func TestUserRoomKeysTable(t *testing.T) {
|
|||||||
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID)
|
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, roomNID)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, key, gotKey)
|
assert.Equal(t, key, gotKey)
|
||||||
|
pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, roomNID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, key.Public(), pubKey)
|
||||||
|
|
||||||
// try to update an existing key, this should only be done for users NOT on this homeserver
|
// try to update an existing key, this should only be done for users NOT on this homeserver
|
||||||
var gotPubKey ed25519.PublicKey
|
var gotPubKey ed25519.PublicKey
|
||||||
@ -82,6 +86,9 @@ func TestUserRoomKeysTable(t *testing.T) {
|
|||||||
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2)
|
gotKey, err = tab.SelectUserRoomPrivateKey(context.Background(), txn, userNID, 2)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Nil(t, gotKey)
|
assert.Nil(t, gotKey)
|
||||||
|
pubKey, err = tab.SelectUserRoomPublicKey(context.Background(), txn, userNID, 2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, pubKey)
|
||||||
|
|
||||||
// query user NIDs for senderKeys
|
// query user NIDs for senderKeys
|
||||||
var gotKeys map[string]types.UserRoomKeyPair
|
var gotKeys map[string]types.UserRoomKeyPair
|
||||||
|
@ -94,7 +94,7 @@ type MSC2836EventRelationshipsResponse struct {
|
|||||||
|
|
||||||
func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
|
func toClientResponse(ctx context.Context, res *MSC2836EventRelationshipsResponse, rsAPI roomserver.RoomserverInternalAPI) *EventRelationshipResponse {
|
||||||
out := &EventRelationshipResponse{
|
out := &EventRelationshipResponse{
|
||||||
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
Events: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(res.ParsedEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}),
|
}),
|
||||||
Limited: res.Limited,
|
Limited: res.Limited,
|
||||||
|
@ -525,11 +525,11 @@ type testRoomserverAPI struct {
|
|||||||
events map[string]*types.HeaderedEvent
|
events map[string]*types.HeaderedEvent
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (r *testRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
|
func (r *testRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
|
||||||
return spec.SenderID(userID.String()), nil
|
return spec.SenderID(userID.String()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -377,7 +377,11 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
|
|||||||
return sp, fmt.Errorf("unexpected nil state_key")
|
return sp, fmt.Errorf("unexpected nil state_key")
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
|
validRoomID, err := spec.NewRoomID(ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return sp, err
|
||||||
|
}
|
||||||
|
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*ev.StateKey()))
|
||||||
if err != nil || userID == nil {
|
if err != nil || userID == nil {
|
||||||
return sp, fmt.Errorf("failed getting userID for sender: %w", err)
|
return sp, fmt.Errorf("failed getting userID for sender: %w", err)
|
||||||
}
|
}
|
||||||
@ -404,7 +408,11 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
|
validRoomID, err := spec.NewRoomID(msg.Event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*msg.Event.StateKey()))
|
||||||
if err != nil || userID == nil {
|
if err != nil || userID == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -454,7 +462,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
|
|||||||
|
|
||||||
// Notify any active sync requests that the invite has been retired.
|
// Notify any active sync requests that the invite has been retired.
|
||||||
s.inviteStream.Advance(pduPos)
|
s.inviteStream.Advance(pduPos)
|
||||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
|
validRoomID, err := spec.NewRoomID(msg.RoomID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{
|
||||||
|
"event_id": msg.EventID,
|
||||||
|
"room_id": msg.RoomID,
|
||||||
|
log.ErrorKey: err,
|
||||||
|
}).Errorf("roomID is invalid")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, msg.TargetSenderID)
|
||||||
if err != nil || userID == nil {
|
if err != nil || userID == nil {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
"event_id": msg.EventID,
|
"event_id": msg.EventID,
|
||||||
|
@ -139,7 +139,11 @@ func ApplyHistoryVisibilityFilter(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user)
|
roomID, err := spec.NewRoomID(ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
|
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
|
||||||
eventsFiltered = append(eventsFiltered, ev)
|
eventsFiltered = append(eventsFiltered, ev)
|
||||||
|
@ -170,11 +170,15 @@ func TrackChangedUsers(
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
for roomID, state := range stateRes.Rooms {
|
for roomID, state := range stateRes.Rooms {
|
||||||
|
validRoomID, roomErr := spec.NewRoomID(roomID)
|
||||||
|
if roomErr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for tuple, membership := range state {
|
for tuple, membership := range state {
|
||||||
if membership != spec.Join {
|
if membership != spec.Join {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
|
user, queryErr := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey))
|
||||||
if queryErr != nil || user == nil {
|
if queryErr != nil || user == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -216,13 +220,17 @@ func TrackChangedUsers(
|
|||||||
return nil, left, err
|
return nil, left, err
|
||||||
}
|
}
|
||||||
for roomID, state := range stateRes.Rooms {
|
for roomID, state := range stateRes.Rooms {
|
||||||
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for tuple, membership := range state {
|
for tuple, membership := range state {
|
||||||
if membership != spec.Join {
|
if membership != spec.Join {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// new user who we weren't previously sharing rooms with
|
// new user who we weren't previously sharing rooms with
|
||||||
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
|
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
|
||||||
user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
|
user, err := rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(tuple.StateKey))
|
||||||
if err != nil || user == nil {
|
if err != nil || user == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -64,7 +64,7 @@ type mockRoomserverAPI struct {
|
|||||||
roomIDToJoinedMembers map[string][]string
|
roomIDToJoinedMembers map[string][]string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,13 +101,20 @@ func (n *Notifier) OnNewEvent(
|
|||||||
n._removeEmptyUserStreams()
|
n._removeEmptyUserStreams()
|
||||||
|
|
||||||
if ev != nil {
|
if ev != nil {
|
||||||
|
validRoomID, err := spec.NewRoomID(ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
|
||||||
|
"Notifier.OnNewEvent: RoomID is invalid",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
// Map this event's room_id to a list of joined users, and wake them up.
|
// Map this event's room_id to a list of joined users, and wake them up.
|
||||||
usersToNotify := n._joinedUsers(ev.RoomID())
|
usersToNotify := n._joinedUsers(ev.RoomID())
|
||||||
// Map this event's room_id to a list of peeking devices, and wake them up.
|
// Map this event's room_id to a list of peeking devices, and wake them up.
|
||||||
peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
|
peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
|
||||||
// If this is an invite, also add in the invitee to this list.
|
// If this is an invite, also add in the invitee to this list.
|
||||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||||
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
|
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), *validRoomID, spec.SenderID(*ev.StateKey()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
|
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
|
||||||
"Notifier.OnNewEvent: Failed to find the userID for this event",
|
"Notifier.OnNewEvent: Failed to find the userID for this event",
|
||||||
|
@ -109,7 +109,7 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
|
|||||||
|
|
||||||
type TestRoomServer struct{ api.SyncRoomserverAPI }
|
type TestRoomServer struct{ api.SyncRoomserverAPI }
|
||||||
|
|
||||||
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -200,10 +200,10 @@ func Context(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -211,7 +211,7 @@ func Context(
|
|||||||
if filter.LazyLoadMembers {
|
if filter.LazyLoadMembers {
|
||||||
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
|
allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...)
|
||||||
allEvents = append(allEvents, &requestedEvent)
|
allEvents = append(allEvents, &requestedEvent)
|
||||||
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
|
newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache)
|
||||||
@ -224,14 +224,14 @@ func Context(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
ev := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}, requestedEvent)
|
}, requestedEvent)
|
||||||
response := ContextRespsonse{
|
response := ContextRespsonse{
|
||||||
Event: &ev,
|
Event: &ev,
|
||||||
EventsAfter: eventsAfterClient,
|
EventsAfter: eventsAfterClient,
|
||||||
EventsBefore: eventsBeforeClient,
|
EventsBefore: eventsBeforeClient,
|
||||||
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
|
@ -102,14 +102,28 @@ func GetEvent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
sender := spec.UserID{}
|
sender := spec.UserID{}
|
||||||
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), roomID, events[0].SenderID())
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("roomID is invalid"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID())
|
||||||
if err == nil && senderUserID != nil {
|
if err == nil && senderUserID != nil {
|
||||||
sender = *senderUserID
|
sender = *senderUserID
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := events[0].StateKey()
|
sk := events[0].StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
|
evRoomID, err := spec.NewRoomID(events[0].RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
JSON: spec.BadJSON("roomID is invalid"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
skUserID, err := rsAPI.QueryUserIDForSender(ctx, *evRoomID, spec.SenderID(*events[0].StateKey()))
|
||||||
if err == nil && skUserID != nil {
|
if err == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
|
@ -152,7 +152,15 @@ func GetMemberships(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
|
validRoomID, err := spec.NewRoomID(ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
util.GetLogger(req.Context()).WithError(err).Error("roomID is invalid")
|
||||||
|
return util.JSONResponse{
|
||||||
|
Code: http.StatusInternalServerError,
|
||||||
|
JSON: spec.InternalServerError{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID())
|
||||||
if err != nil || userID == nil {
|
if err != nil || userID == nil {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
|
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed")
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
@ -175,7 +183,7 @@ func GetMemberships(
|
|||||||
}
|
}
|
||||||
return util.JSONResponse{
|
return util.JSONResponse{
|
||||||
Code: http.StatusOK,
|
Code: http.StatusOK,
|
||||||
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
||||||
})},
|
})},
|
||||||
}
|
}
|
||||||
|
@ -273,7 +273,7 @@ func OnIncomingMessagesRequest(
|
|||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
||||||
})...)
|
})...)
|
||||||
}
|
}
|
||||||
@ -389,7 +389,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
|
|||||||
"events_before": len(events),
|
"events_before": len(events),
|
||||||
"events_after": len(filteredEvents),
|
"events_after": len(filteredEvents),
|
||||||
}).Debug("applied history visibility (messages)")
|
}).Debug("applied history visibility (messages)")
|
||||||
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
}), start, end, err
|
}), start, end, err
|
||||||
}
|
}
|
||||||
|
@ -110,19 +110,24 @@ func Relations(
|
|||||||
return util.ErrorResponse(err)
|
return util.ErrorResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return util.ErrorResponse(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Convert the events into client events, and optionally filter based on the event
|
// Convert the events into client events, and optionally filter based on the event
|
||||||
// type if it was specified.
|
// type if it was specified.
|
||||||
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
|
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
|
||||||
for _, event := range filteredEvents {
|
for _, event := range filteredEvents {
|
||||||
sender := spec.UserID{}
|
sender := spec.UserID{}
|
||||||
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
|
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
|
||||||
if err == nil && userID != nil {
|
if err == nil && userID != nil {
|
||||||
sender = *userID
|
sender = *userID
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := event.StateKey()
|
sk := event.StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
|
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
|
||||||
if err == nil && skUserID != nil {
|
if err == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
|
@ -205,9 +205,14 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
|||||||
|
|
||||||
profileInfos := make(map[string]ProfileInfoResponse)
|
profileInfos := make(map[string]ProfileInfoResponse)
|
||||||
for _, ev := range append(eventsBefore, eventsAfter...) {
|
for _, ev := range append(eventsBefore, eventsAfter...) {
|
||||||
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID())
|
validRoomID, roomErr := spec.NewRoomID(ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(roomErr).WithField("room_id", ev.RoomID()).Warn("failed to query userprofile")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID, queryErr := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, ev.SenderID())
|
||||||
if queryErr != nil {
|
if queryErr != nil {
|
||||||
logrus.WithError(queryErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile")
|
logrus.WithError(queryErr).WithField("sender_id", ev.SenderID()).Warn("failed to query userprofile")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,14 +236,19 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
|||||||
}
|
}
|
||||||
|
|
||||||
sender := spec.UserID{}
|
sender := spec.UserID{}
|
||||||
userID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), event.SenderID())
|
validRoomID, roomErr := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithError(roomErr).WithField("room_id", event.RoomID()).Warn("failed to query userprofile")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
|
||||||
if err == nil && userID != nil {
|
if err == nil && userID != nil {
|
||||||
sender = *userID
|
sender = *userID
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := event.StateKey()
|
sk := event.StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
|
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
|
||||||
if err == nil && skUserID != nil {
|
if err == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
@ -248,10 +258,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
|||||||
Context: SearchContextResponse{
|
Context: SearchContextResponse{
|
||||||
Start: startToken.String(),
|
Start: startToken.String(),
|
||||||
End: endToken.String(),
|
End: endToken.String(),
|
||||||
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
||||||
}),
|
}),
|
||||||
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
||||||
}),
|
}),
|
||||||
ProfileInfo: profileInfos,
|
ProfileInfo: profileInfos,
|
||||||
@ -272,7 +282,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
|||||||
JSON: spec.InternalServerError{},
|
JSON: spec.InternalServerError{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ import (
|
|||||||
|
|
||||||
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
|
type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI }
|
||||||
|
|
||||||
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,7 +114,14 @@ func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Dev
|
|||||||
}).WithError(err).Warnf("Failed to add transaction ID to event")
|
}).WithError(err).Warnf("Failed to add transaction ID to event")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID)
|
roomID, err := spec.NewRoomID(in[i].RoomID())
|
||||||
|
if err != nil {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"event_id": out[i].EventID(),
|
||||||
|
}).WithError(err).Warnf("Room ID is invalid")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithFields(logrus.Fields{
|
logrus.WithFields(logrus.Fields{
|
||||||
"event_id": out[i].EventID(),
|
"event_id": out[i].EventID(),
|
||||||
@ -515,7 +522,11 @@ func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userI
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser)
|
roomID, err := spec.NewRoomID(ev.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *fullUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
|
@ -65,14 +65,18 @@ func (p *InviteStreamProvider) IncrementalSync(
|
|||||||
|
|
||||||
for roomID, inviteEvent := range invites {
|
for roomID, inviteEvent := range invites {
|
||||||
user := spec.UserID{}
|
user := spec.UserID{}
|
||||||
sender, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), inviteEvent.SenderID())
|
validRoomID, err := spec.NewRoomID(inviteEvent.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sender, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, inviteEvent.SenderID())
|
||||||
if err == nil && sender != nil {
|
if err == nil && sender != nil {
|
||||||
user = *sender
|
user = *sender
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := inviteEvent.StateKey()
|
sk := inviteEvent.StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
|
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*inviteEvent.StateKey()))
|
||||||
if err == nil && skUserID != nil {
|
if err == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
|
@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
jr.Timeline.PrevBatch = &prevBatch
|
jr.Timeline.PrevBatch = &prevBatch
|
||||||
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
// If we are limited by the filter AND the history visibility filter
|
// If we are limited by the filter AND the history visibility filter
|
||||||
// didn't "remove" events, return that the response is limited.
|
// didn't "remove" events, return that the response is limited.
|
||||||
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
|
jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined
|
||||||
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
req.Response.Rooms.Join[delta.RoomID] = jr
|
req.Response.Rooms.Join[delta.RoomID] = jr
|
||||||
@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||||||
jr := types.NewJoinResponse()
|
jr := types.NewJoinResponse()
|
||||||
jr.Timeline.PrevBatch = &prevBatch
|
jr.Timeline.PrevBatch = &prevBatch
|
||||||
// TODO: Apply history visibility on peeked rooms
|
// TODO: Apply history visibility on peeked rooms
|
||||||
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
jr.Timeline.Limited = limited
|
jr.Timeline.Limited = limited
|
||||||
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
req.Response.Rooms.Peek[delta.RoomID] = jr
|
req.Response.Rooms.Peek[delta.RoomID] = jr
|
||||||
@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
|||||||
case spec.Ban:
|
case spec.Ban:
|
||||||
lr := types.NewLeaveResponse()
|
lr := types.NewLeaveResponse()
|
||||||
lr.Timeline.PrevBatch = &prevBatch
|
lr.Timeline.PrevBatch = &prevBatch
|
||||||
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
// If we are limited by the filter AND the history visibility filter
|
// If we are limited by the filter AND the history visibility filter
|
||||||
// didn't "remove" events, return that the response is limited.
|
// didn't "remove" events, return that the response is limited.
|
||||||
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
|
lr.Timeline.Limited = limited && len(events) == len(recentEvents)
|
||||||
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
req.Response.Rooms.Leave[delta.RoomID] = lr
|
req.Response.Rooms.Leave[delta.RoomID] = lr
|
||||||
@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync(
|
|||||||
}
|
}
|
||||||
|
|
||||||
jr.Timeline.PrevBatch = prevBatch
|
jr.Timeline.PrevBatch = prevBatch
|
||||||
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
// If we are limited by the filter AND the history visibility filter
|
// If we are limited by the filter AND the history visibility filter
|
||||||
// didn't "remove" events, return that the response is limited.
|
// didn't "remove" events, return that the response is limited.
|
||||||
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
|
jr.Timeline.Limited = limited && len(events) == len(recentEvents)
|
||||||
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
return jr, nil
|
return jr, nil
|
||||||
@ -585,6 +585,10 @@ func (p *PDUStreamProvider) lazyLoadMembers(
|
|||||||
if len(timelineEvents) == 0 {
|
if len(timelineEvents) == 0 {
|
||||||
return stateEvents, nil
|
return stateEvents, nil
|
||||||
}
|
}
|
||||||
|
validRoomID, err := spec.NewRoomID(roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
// Work out which memberships to include
|
// Work out which memberships to include
|
||||||
timelineUsers := make(map[string]struct{})
|
timelineUsers := make(map[string]struct{})
|
||||||
if !incremental {
|
if !incremental {
|
||||||
@ -606,8 +610,8 @@ func (p *PDUStreamProvider) lazyLoadMembers(
|
|||||||
isGappedIncremental := limited && incremental
|
isGappedIncremental := limited && incremental
|
||||||
// We want this users membership event, keep it in the list
|
// We want this users membership event, keep it in the list
|
||||||
userID := ""
|
userID := ""
|
||||||
stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey()))
|
stateKeyUserID, queryErr := p.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
|
||||||
if err == nil && stateKeyUserID != nil {
|
if queryErr == nil && stateKeyUserID != nil {
|
||||||
userID = stateKeyUserID.String()
|
userID = stateKeyUserID.String()
|
||||||
}
|
}
|
||||||
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {
|
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {
|
||||||
|
@ -40,7 +40,7 @@ type syncRoomserverAPI struct {
|
|||||||
rooms []*test.Room
|
rooms []*test.Room
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,14 +52,18 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
|
|||||||
continue // TODO: shouldn't happen?
|
continue // TODO: shouldn't happen?
|
||||||
}
|
}
|
||||||
sender := spec.UserID{}
|
sender := spec.UserID{}
|
||||||
userID, err := userIDForSender(se.RoomID(), se.SenderID())
|
validRoomID, err := spec.NewRoomID(se.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID, err := userIDForSender(*validRoomID, se.SenderID())
|
||||||
if err == nil && userID != nil {
|
if err == nil && userID != nil {
|
||||||
sender = *userID
|
sender = *userID
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := se.StateKey()
|
sk := se.StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
|
skUserID, err := userIDForSender(*validRoomID, spec.SenderID(*sk))
|
||||||
if err == nil && skUserID != nil {
|
if err == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
@ -95,14 +99,18 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
|
|||||||
// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
|
// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
|
||||||
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
|
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
|
||||||
sender := spec.UserID{}
|
sender := spec.UserID{}
|
||||||
userID, err := userIDQuery(event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return ClientEvent{}
|
||||||
|
}
|
||||||
|
userID, err := userIDQuery(*validRoomID, event.SenderID())
|
||||||
if err == nil && userID != nil {
|
if err == nil && userID != nil {
|
||||||
sender = *userID
|
sender = *userID
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := event.StateKey()
|
sk := event.StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
|
skUserID, err := userIDQuery(*validRoomID, spec.SenderID(*event.StateKey()))
|
||||||
if err == nil && skUserID != nil {
|
if err == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
|
@ -39,7 +39,7 @@ var (
|
|||||||
roomIDCounter = int64(0)
|
roomIDCounter = int64(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
func UserIDForSender(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func UserIDForSender(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -302,14 +302,18 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *rst
|
|||||||
switch {
|
switch {
|
||||||
case event.Type() == spec.MRoomMember:
|
case event.Type() == spec.MRoomMember:
|
||||||
sender := spec.UserID{}
|
sender := spec.UserID{}
|
||||||
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
validRoomID, roomErr := spec.NewRoomID(event.RoomID())
|
||||||
|
if roomErr != nil {
|
||||||
|
return roomErr
|
||||||
|
}
|
||||||
|
userID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||||
if queryErr == nil && userID != nil {
|
if queryErr == nil && userID != nil {
|
||||||
sender = *userID
|
sender = *userID
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := event.StateKey()
|
sk := event.StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
|
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
|
||||||
if queryErr == nil && skUserID != nil {
|
if queryErr == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
@ -544,14 +548,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
|
|||||||
}
|
}
|
||||||
|
|
||||||
sender := spec.UserID{}
|
sender := spec.UserID{}
|
||||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
userID, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||||
if err == nil && userID != nil {
|
if err == nil && userID != nil {
|
||||||
sender = *userID
|
sender = *userID
|
||||||
}
|
}
|
||||||
|
|
||||||
sk := event.StateKey()
|
sk := event.StateKey()
|
||||||
if sk != nil && *sk != "" {
|
if sk != nil && *sk != "" {
|
||||||
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), spec.SenderID(*event.StateKey()))
|
skUserID, queryErr := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, spec.SenderID(*event.StateKey()))
|
||||||
if queryErr == nil && skUserID != nil {
|
if queryErr == nil && skUserID != nil {
|
||||||
skString := skUserID.String()
|
skString := skUserID.String()
|
||||||
sk = &skString
|
sk = &skString
|
||||||
@ -644,7 +652,11 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *rstype
|
|||||||
// user. Returns actions (including dont_notify).
|
// user. Returns actions (including dont_notify).
|
||||||
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
|
func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *rstypes.HeaderedEvent, mem *localMembership, roomSize int) ([]*pushrules.Action, error) {
|
||||||
user := ""
|
user := ""
|
||||||
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
user = sender.String()
|
user = sender.String()
|
||||||
}
|
}
|
||||||
@ -682,7 +694,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event *
|
|||||||
roomSize: roomSize,
|
roomSize: roomSize,
|
||||||
}
|
}
|
||||||
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
|
eval := pushrules.NewRuleSetEvaluator(ec, &ruleSets.Global)
|
||||||
rule, err := eval.MatchEvent(event.PDU, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
rule, err := eval.MatchEvent(event.PDU, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
return s.rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -790,7 +802,11 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
|
|||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
sender, err := s.rsAPI.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
|
validRoomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sender, err := s.rsAPI.QueryUserIDForSender(ctx, *validRoomID, event.SenderID())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
|
logger.WithError(err).Errorf("Failed to get userID for sender %s", event.SenderID())
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -818,7 +834,13 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *rstypes
|
|||||||
logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart)
|
logger.WithError(err).Errorf("Failed to convert local user to userID %s", localpart)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, event.RoomID(), *userID)
|
roomID, err := spec.NewRoomID(event.RoomID())
|
||||||
|
if err != nil {
|
||||||
|
logger.WithError(err).Errorf("event roomID is invalid %s", event.RoomID())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
localSender, err := s.rsAPI.QuerySenderIDForUser(ctx, *roomID, *userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID())
|
logger.WithError(err).Errorf("Failed to get local user senderID for room %s: %s", userID.String(), event.RoomID())
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -47,7 +47,7 @@ func mustCreateEvent(t *testing.T, content string) *types.HeaderedEvent {
|
|||||||
|
|
||||||
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
|
type FakeUserRoomserverAPI struct{ rsapi.UserRoomserverAPI }
|
||||||
|
|
||||||
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
func (f *FakeUserRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
|
||||||
return spec.NewUserID(string(senderID), true)
|
return spec.NewUserID(string(senderID), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,13 +68,13 @@ func Test_evaluatePushRules(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "m.receipt doesn't notify",
|
name: "m.receipt doesn't notify",
|
||||||
eventContent: `{"type":"m.receipt"}`,
|
eventContent: `{"type":"m.receipt","room_id":"!room:example.com"}`,
|
||||||
wantAction: pushrules.UnknownAction,
|
wantAction: pushrules.UnknownAction,
|
||||||
wantActions: nil,
|
wantActions: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "m.reaction doesn't notify",
|
name: "m.reaction doesn't notify",
|
||||||
eventContent: `{"type":"m.reaction"}`,
|
eventContent: `{"type":"m.reaction","room_id":"!room:example.com"}`,
|
||||||
wantAction: pushrules.DontNotifyAction,
|
wantAction: pushrules.DontNotifyAction,
|
||||||
wantActions: []*pushrules.Action{
|
wantActions: []*pushrules.Action{
|
||||||
{
|
{
|
||||||
@ -84,7 +84,7 @@ func Test_evaluatePushRules(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "m.room.message notifies",
|
name: "m.room.message notifies",
|
||||||
eventContent: `{"type":"m.room.message"}`,
|
eventContent: `{"type":"m.room.message","room_id":"!room:example.com"}`,
|
||||||
wantNotify: true,
|
wantNotify: true,
|
||||||
wantAction: pushrules.NotifyAction,
|
wantAction: pushrules.NotifyAction,
|
||||||
wantActions: []*pushrules.Action{
|
wantActions: []*pushrules.Action{
|
||||||
@ -93,7 +93,7 @@ func Test_evaluatePushRules(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "m.room.message highlights",
|
name: "m.room.message highlights",
|
||||||
eventContent: `{"type":"m.room.message", "content": {"body": "test"}}`,
|
eventContent: `{"type":"m.room.message", "content": {"body": "test"},"room_id":"!room:example.com"}`,
|
||||||
wantNotify: true,
|
wantNotify: true,
|
||||||
wantAction: pushrules.NotifyAction,
|
wantAction: pushrules.NotifyAction,
|
||||||
wantActions: []*pushrules.Action{
|
wantActions: []*pushrules.Action{
|
||||||
|
Loading…
Reference in New Issue
Block a user