From eab87ef07dd9461e38b716b7c4b4c95cedee0ca8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 2 Aug 2022 10:22:17 +0100 Subject: [PATCH] Stronger checks for `/send_join` (#2604) --- federationapi/routing/join.go | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 41004cf5..30406a15 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -202,6 +202,14 @@ func SendJoin( } } + // Check that the event is from the server sending the request. + if event.Origin() != request.Origin() { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"), + } + } + // Check that a state key is provided. if event.StateKey() == nil || event.StateKeyEquals("") { return util.JSONResponse{ @@ -216,6 +224,22 @@ func SendJoin( } } + // Check that the sender belongs to the server that is sending us + // the request. By this point we've already asserted that the sender + // and the state key are equal so we don't need to check both. + var domain gomatrixserverlib.ServerName + if _, domain, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The sender of the join is invalid"), + } + } else if domain != request.Origin() { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The sender of the join must belong to the origin server"), + } + } + // Check that the room ID is correct. if event.RoomID() != roomID { return util.JSONResponse{ @@ -242,14 +266,6 @@ func SendJoin( } } - // Check that the event is from the server sending the request. - if event.Origin() != request.Origin() { - return util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The join must be sent by the server it originated on"), - } - } - // Check that this is in fact a join event membership, err := event.Membership() if err != nil {