Move SendJoin logic to GMSL (#3084)

Moves the core matrix logic for handling the send_join endpoint over to
gmsl.
This commit is contained in:
devonh 2023-05-19 16:27:01 +00:00 committed by GitHub
parent 027a9b8ce0
commit 2eae8dc489
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 147 additions and 233 deletions

View File

@ -16,7 +16,6 @@ package routing
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sort"
@ -160,7 +159,6 @@ func MakeJoin(
BuildEventTemplate: createJoinTemplate,
}
response, internalErr := gomatrixserverlib.HandleMakeJoin(input)
if internalErr != nil {
switch e := internalErr.(type) {
case nil:
case spec.InternalServerError:
@ -178,7 +176,7 @@ func MakeJoin(
case spec.ErrorNotFound:
code = http.StatusNotFound
case spec.ErrorUnableToAuthoriseJoin:
code = http.StatusBadRequest
fallthrough // http.StatusBadRequest
case spec.ErrorBadJSON:
code = http.StatusBadRequest
}
@ -200,7 +198,6 @@ func MakeJoin(
JSON: spec.Unknown("unknown error"),
}
}
}
if response == nil {
util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeJoin returned invalid response")
@ -219,6 +216,25 @@ func MakeJoin(
}
}
type MembershipQuerier struct {
roomserver api.FederationRoomserverAPI
}
func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) {
req := api.QueryMembershipForUserRequest{
RoomID: roomID.String(),
UserID: userID.String(),
}
res := api.QueryMembershipForUserResponse{}
err := mq.roomserver.QueryMembershipForUser(ctx, &req, &res)
membership := ""
if err == nil {
membership = res.Membership
}
return membership, err
}
// SendJoin implements the /send_join API
// The make-join send-join dance makes much more sense as a single
// flow so the cyclomatic complexity is high:
@ -229,9 +245,10 @@ func SendJoin(
cfg *config.FederationAPI,
rsAPI api.FederationRoomserverAPI,
keys gomatrixserverlib.JSONVerifier,
roomID, eventID string,
roomID spec.RoomID,
eventID string,
) util.JSONResponse {
roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID)
roomVersion, err := rsAPI.QueryRoomVersionForRoom(httpReq.Context(), roomID.String())
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("rsAPI.QueryRoomVersionForRoom failed")
return util.JSONResponse{
@ -239,132 +256,71 @@ func SendJoin(
JSON: spec.InternalServerError{},
}
}
verImpl, err := gomatrixserverlib.GetRoomVersion(roomVersion)
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.UnsupportedRoomVersion(
fmt.Sprintf("QueryRoomVersionForRoom returned unknown room version: %s", roomVersion),
),
}
}
event, err := verImpl.NewEventFromUntrustedJSON(request.Content())
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("The request body could not be decoded into valid JSON: " + err.Error()),
input := gomatrixserverlib.HandleSendJoinInput{
Context: httpReq.Context(),
RoomID: roomID,
EventID: eventID,
JoinEvent: request.Content(),
RoomVersion: roomVersion,
RequestOrigin: request.Origin(),
LocalServerName: cfg.Matrix.ServerName,
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
Verifier: keys,
MembershipQuerier: &MembershipQuerier{roomserver: rsAPI},
}
}
// Check that a state key is provided.
if event.StateKey() == nil || event.StateKeyEquals("") {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("No state key was provided in the join event."),
}
}
if !event.StateKeyEquals(event.Sender()) {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Event state key must match the event sender."),
}
}
// Check that the sender belongs to the server that is sending us
// the request. By this point we've already asserted that the sender
// and the state key are equal so we don't need to check both.
var serverName spec.ServerName
if _, serverName, err = gomatrixserverlib.SplitID('@', event.Sender()); err != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender of the join is invalid"),
}
} else if serverName != request.Origin() {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("The sender does not match the server that originated the request"),
}
}
// Check that the room ID is correct.
if event.RoomID() != roomID {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(
fmt.Sprintf(
"The room ID in the request path (%q) must match the room ID in the join event JSON (%q)",
roomID, event.RoomID(),
),
),
}
}
// Check that the event ID is correct.
if event.EventID() != eventID {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(
fmt.Sprintf(
"The event ID in the request path (%q) must match the event ID in the join event JSON (%q)",
eventID, event.EventID(),
),
),
}
}
// Check that this is in fact a join event
membership, err := event.Membership()
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("missing content.membership key"),
}
}
if membership != spec.Join {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("membership must be 'join'"),
}
}
// Check that the event is signed by the server sending the request.
redacted, err := verImpl.RedactEventJSON(event.JSON())
if err != nil {
logrus.WithError(err).Errorf("XXX: join.go")
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("The event JSON could not be redacted"),
}
}
verifyRequests := []gomatrixserverlib.VerifyJSONRequest{{
ServerName: serverName,
Message: redacted,
AtTS: event.OriginServerTS(),
StrictValidityChecking: true,
}}
verifyResults, err := keys.VerifyJSONs(httpReq.Context(), verifyRequests)
if err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("keys.VerifyJSONs failed")
response, joinErr := gomatrixserverlib.HandleSendJoin(input)
switch e := joinErr.(type) {
case nil:
case spec.InternalServerError:
util.GetLogger(httpReq.Context()).WithError(joinErr)
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
case spec.MatrixError:
util.GetLogger(httpReq.Context()).WithError(joinErr)
code := http.StatusInternalServerError
switch e.ErrCode {
case spec.ErrorForbidden:
code = http.StatusForbidden
case spec.ErrorNotFound:
code = http.StatusNotFound
case spec.ErrorUnsupportedRoomVersion:
code = http.StatusInternalServerError
case spec.ErrorBadJSON:
code = http.StatusBadRequest
}
if verifyResults[0].Error != nil {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("Signature check failed: " + verifyResults[0].Error.Error()),
Code: code,
JSON: e,
}
default:
util.GetLogger(httpReq.Context()).WithError(joinErr)
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Unknown("unknown error"),
}
}
if response == nil {
util.GetLogger(httpReq.Context()).Error("gmsl.HandleMakeJoin returned invalid response")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
}
}
// Fetch the state and auth chain. We do this before we send the events
// on, in case this fails.
var stateAndAuthChainResponse api.QueryStateAndAuthChainResponse
err = rsAPI.QueryStateAndAuthChain(httpReq.Context(), &api.QueryStateAndAuthChainRequest{
PrevEventIDs: event.PrevEventIDs(),
AuthEventIDs: event.AuthEventIDs(),
RoomID: roomID,
PrevEventIDs: response.JoinEvent.PrevEventIDs(),
AuthEventIDs: response.JoinEvent.AuthEventIDs(),
RoomID: roomID.String(),
ResolveState: true,
}, &stateAndAuthChainResponse)
if err != nil {
@ -388,84 +344,27 @@ func SendJoin(
}
}
// Check if the user is already in the room. If they're already in then
// there isn't much point in sending another join event into the room.
// Also check to see if they are banned: if they are then we reject them.
alreadyJoined := false
isBanned := false
for _, se := range stateAndAuthChainResponse.StateEvents {
if !se.StateKeyEquals(*event.StateKey()) {
continue
}
if membership, merr := se.Membership(); merr == nil {
alreadyJoined = (membership == spec.Join)
isBanned = (membership == spec.Ban)
break
}
}
if isBanned {
return util.JSONResponse{
Code: http.StatusForbidden,
JSON: spec.Forbidden("user is banned"),
}
}
// If the membership content contains a user ID for a server that is not
// ours then we should kick it back.
var memberContent gomatrixserverlib.MemberContent
if err := json.Unmarshal(event.Content(), &memberContent); err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(err.Error()),
}
}
if memberContent.AuthorisedVia != "" {
_, domain, err := gomatrixserverlib.SplitID('@', memberContent.AuthorisedVia)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q is invalid.", memberContent.AuthorisedVia)),
}
}
if domain != cfg.Matrix.ServerName {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON(fmt.Sprintf("The authorising username %q does not belong to this server.", memberContent.AuthorisedVia)),
}
}
}
// Sign the membership event. This is required for restricted joins to work
// in the case that the authorised via user is one of our own users. It also
// doesn't hurt to do it even if it isn't a restricted join.
signed := event.Sign(
string(cfg.Matrix.ServerName),
cfg.Matrix.KeyID,
cfg.Matrix.PrivateKey,
)
// Send the events to the room server.
// We are responsible for notifying other servers that the user has joined
// the room, so set SendAsServer to cfg.Matrix.ServerName
if !alreadyJoined {
var response api.InputRoomEventsResponse
if !response.AlreadyJoined {
var rsResponse api.InputRoomEventsResponse
rsAPI.InputRoomEvents(httpReq.Context(), &api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{
{
Kind: api.KindNew,
Event: &types.HeaderedEvent{PDU: signed},
Event: &types.HeaderedEvent{PDU: response.JoinEvent},
SendAsServer: string(cfg.Matrix.ServerName),
TransactionID: nil,
},
},
}, &response)
if response.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, response.ErrMsg).Error("SendEvents failed")
if response.NotAllowed {
}, &rsResponse)
if rsResponse.ErrMsg != "" {
util.GetLogger(httpReq.Context()).WithField(logrus.ErrorKey, rsResponse.ErrMsg).Error("SendEvents failed")
if rsResponse.NotAllowed {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.Forbidden(response.ErrMsg),
JSON: spec.Forbidden(rsResponse.ErrMsg),
}
}
return util.JSONResponse{
@ -488,7 +387,7 @@ func SendJoin(
StateEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.StateEvents),
AuthEvents: types.NewEventJSONsFromHeaderedEvents(stateAndAuthChainResponse.AuthChainEvents),
Origin: cfg.Matrix.ServerName,
Event: signed.JSON(),
Event: response.JoinEvent.JSON(),
},
}
}

View File

@ -331,14 +331,14 @@ func Setup(
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid UserID"),
JSON: spec.InvalidParam("Invalid UserID"),
}
}
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.BadJSON("Invalid RoomID"),
JSON: spec.InvalidParam("Invalid RoomID"),
}
}
@ -358,10 +358,17 @@ func Setup(
JSON: spec.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"]
eventID := vars["eventID"]
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Invalid RoomID"),
}
}
res := SendJoin(
httpReq, request, cfg, rsAPI, keys, roomID, eventID,
httpReq, request, cfg, rsAPI, keys, *roomID, eventID,
)
// not all responses get wrapped in [code, body]
var body interface{}
@ -390,10 +397,17 @@ func Setup(
JSON: spec.Forbidden("Forbidden by server ACLs"),
}
}
roomID := vars["roomID"]
eventID := vars["eventID"]
roomID, err := spec.NewRoomID(vars["roomID"])
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Invalid RoomID"),
}
}
return SendJoin(
httpReq, request, cfg, rsAPI, keys, roomID, eventID,
httpReq, request, cfg, rsAPI, keys, *roomID, eventID,
)
},
)).Methods(http.MethodPut)

4
go.mod
View File

@ -22,7 +22,7 @@ require (
github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530
github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2
github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66
github.com/mattn/go-sqlite3 v1.14.16
@ -34,7 +34,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0
github.com/sirupsen/logrus v1.9.1
github.com/sirupsen/logrus v1.9.2
github.com/stretchr/testify v1.8.2
github.com/tidwall/gjson v1.14.4
github.com/tidwall/sjson v1.2.5

8
go.sum
View File

@ -323,8 +323,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw
github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U=
github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2 h1:V36yCWt2CoSfI1xr6WYJ9Mb3eyl95SknMRLGFvEuYak=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230517000105-1ff06fc8a6a2/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c h1:EF04pmshcDmBQOrBQbzT5htyTivetfyvR70gX2hB9AM=
github.com/matrix-org/gomatrixserverlib v0.0.0-20230519160810-b92e84b02a7c/go.mod h1:H9V9N3Uqn1bBJqYJNGK1noqtgJTaCEhtTdcH/mp50uU=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a h1:awrPDf9LEFySxTLKYBMCiObelNx/cBuv/wzllvCCH3A=
github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a/go.mod h1:HchJX9oKMXaT2xYFs0Ha/6Zs06mxLU8k6F1ODnrGkeQ=
github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y=
@ -444,8 +444,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.1 h1:Ou41VVR3nMWWmTiEUnj0OlsgOSCUFgsPAOl6jRIcVtQ=
github.com/sirupsen/logrus v1.9.1/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sirupsen/logrus v1.9.2 h1:oxx1eChJGI6Uks2ZC4W1zpLlVgqB8ner4EuQwV4Ik1Y=
github.com/sirupsen/logrus v1.9.2/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=

View File

@ -202,6 +202,7 @@ type FederationRoomserverAPI interface {
QueryBulkStateContentAPI
// QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs.
QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error
QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error
QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error
QueryRoomVersionForRoom(ctx context.Context, roomID string) (gomatrixserverlib.RoomVersion, error)
GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error

View File

@ -868,7 +868,7 @@ func (r *Queryer) QueryRoomInfo(ctx context.Context, roomID spec.RoomID) (*types
}
func (r *Queryer) CurrentStateEvent(ctx context.Context, roomID spec.RoomID, eventType string, stateKey string) (gomatrixserverlib.PDU, error) {
res, err := r.DB.GetStateEvent(ctx, roomID.String(), string(eventType), "")
res, err := r.DB.GetStateEvent(ctx, roomID.String(), eventType, "")
if res == nil {
return nil, err
}