From 738b829a23d4e50e68f98acb72f7d10a16009f8b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Sep 2020 13:40:29 +0100 Subject: [PATCH] Fetch missing auth events, implement QueryMissingAuthPrevEvents, try other servers in room for /event and /get_missing_events (#1450) * Try to ask other servers in the room for missing events if the origin won't provide them * Logging * More logging * Implement QueryMissingAuthPrevEvents * Try to get missing auth events badly * Use processEvent * Logging * Update QueryMissingAuthPrevEvents * Try to find missing auth events * Patchy fix for test * Logging tweaks * Send auth events as outliers * Update check in QueryMissingAuthPrevEvents * Error responses * More return codes * Don't return error on reject/soft-fail since it was ultimately handled * More tweaks * More error tweaks --- federationapi/routing/send.go | 130 +++++++++++++----- federationapi/routing/send_test.go | 98 +++++++++---- roomserver/api/api.go | 7 + roomserver/api/api_trace.go | 10 ++ roomserver/api/query.go | 23 ++++ roomserver/internal/helpers/auth.go | 6 +- roomserver/internal/input/input_events.go | 2 +- roomserver/internal/perform/perform_invite.go | 19 +-- roomserver/internal/perform/perform_join.go | 10 +- roomserver/internal/query/query.go | 39 +++++- roomserver/inthttp/client.go | 14 ++ roomserver/inthttp/server.go | 14 ++ 12 files changed, 291 insertions(+), 81 deletions(-) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index 5f20b2d8..4a30f8d7 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -206,10 +206,10 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res return nil, &jsonErr } else { // Auth errors mean the event is 'rejected' which have to be silent to appease sytest + errMsg := "" _, rejected := err.(*gomatrixserverlib.NotAllowed) - errMsg := err.Error() - if rejected { - errMsg = "" + if !rejected { + errMsg = err.Error() } util.GetLogger(ctx).WithError(err).WithField("event_id", e.EventID()).WithField("rejected", rejected).Warn( "Failed to process incoming federation event, skipping", @@ -345,17 +345,17 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli } func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, isInboundTxn bool) error { - prevEventIDs := e.PrevEventIDs() + logger := util.GetLogger(ctx).WithField("event_id", e.EventID()).WithField("room_id", e.RoomID()) - // Fetch the state needed to authenticate the event. - needed := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{e}) - stateReq := api.QueryStateAfterEventsRequest{ + // Work out if the roomserver knows everything it needs to know to auth + // the event. + stateReq := api.QueryMissingAuthPrevEventsRequest{ RoomID: e.RoomID(), - PrevEventIDs: prevEventIDs, - StateToFetch: needed.Tuples(), + AuthEventIDs: e.AuthEventIDs(), + PrevEventIDs: e.PrevEventIDs(), } - var stateResp api.QueryStateAfterEventsResponse - if err := t.rsAPI.QueryStateAfterEvents(ctx, &stateReq, &stateResp); err != nil { + var stateResp api.QueryMissingAuthPrevEventsResponse + if err := t.rsAPI.QueryMissingAuthPrevEvents(ctx, &stateReq, &stateResp); err != nil { return err } @@ -369,7 +369,53 @@ func (t *txnReq) processEvent(ctx context.Context, e gomatrixserverlib.Event, is return roomNotFoundError{e.RoomID()} } - if !stateResp.PrevEventsExist { + if len(stateResp.MissingAuthEventIDs) > 0 { + logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) + + servers := []gomatrixserverlib.ServerName{t.Origin} + serverReq := &api.QueryServerJoinedToRoomRequest{ + RoomID: e.RoomID(), + } + serverRes := &api.QueryServerJoinedToRoomResponse{} + if err := t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { + servers = append(servers, serverRes.ServerNames...) + logger.Infof("Found %d server(s) to query for missing events", len(servers)) + } + + getAuthEvent: + for _, missingAuthEventID := range stateResp.MissingAuthEventIDs { + for _, server := range servers { + logger.Infof("Retrieving missing auth event %q from %q", missingAuthEventID, server) + tx, err := t.federation.GetEvent(ctx, server, missingAuthEventID) + if err != nil { + continue // try the next server + } + ev, err := gomatrixserverlib.NewEventFromUntrustedJSON(tx.PDUs[0], stateResp.RoomVersion) + if err != nil { + logger.WithError(err).Errorf("Failed to unmarshal auth event %q", missingAuthEventID) + continue // try the next server + } + if err = api.SendInputRoomEvents( + context.Background(), + t.rsAPI, + []api.InputRoomEvent{ + { + Kind: api.KindOutlier, + Event: ev.Headered(stateResp.RoomVersion), + AuthEventIDs: ev.AuthEventIDs(), + SendAsServer: api.DoNotSendToOtherServers, + }, + }, + ); err != nil { + logger.WithError(err).Errorf("Failed to send auth event %q to roomserver", missingAuthEventID) + continue getAuthEvent // move onto the next event + } + } + } + } + + if len(stateResp.MissingPrevEventIDs) > 0 { + logger.Infof("Event refers to %d unknown prev_events", len(stateResp.MissingPrevEventIDs)) return t.processEventWithMissingState(ctx, e, stateResp.RoomVersion, isInboundTxn) } @@ -611,6 +657,7 @@ retryAllowedState: // begin from. Returns an error only if we should terminate the transaction which initiated /get_missing_events // This function recursively calls txnReq.processEvent with the missing events, which will be processed before this function returns. // This means that we may recursively call this function, as we spider back up prev_events to the min depth. +// nolint:gocyclo func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, isInboundTxn bool) (backwardsExtremity *gomatrixserverlib.Event, err error) { if !isInboundTxn { // we've recursed here, so just take a state snapshot please! @@ -637,15 +684,46 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event if minDepth < 0 { minDepth = 0 } - missingResp, err := t.federation.LookupMissingEvents(ctx, t.Origin, e.RoomID(), gomatrixserverlib.MissingEvents{ - Limit: 20, - // synapse uses the min depth they've ever seen in that room - MinDepth: minDepth, - // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. - EarliestEvents: latestEvents, - // The event IDs to retrieve the previous events for. - LatestEvents: []string{e.EventID()}, - }, roomVersion) + + servers := []gomatrixserverlib.ServerName{t.Origin} + serverReq := &api.QueryServerJoinedToRoomRequest{ + RoomID: e.RoomID(), + } + serverRes := &api.QueryServerJoinedToRoomResponse{} + if err = t.rsAPI.QueryServerJoinedToRoom(ctx, serverReq, serverRes); err == nil { + servers = append(servers, serverRes.ServerNames...) + logger.Infof("Found %d server(s) to query for missing events", len(servers)) + } + + var missingResp *gomatrixserverlib.RespMissingEvents + for _, server := range servers { + var m gomatrixserverlib.RespMissingEvents + if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + Limit: 20, + // synapse uses the min depth they've ever seen in that room + MinDepth: minDepth, + // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. + EarliestEvents: latestEvents, + // The event IDs to retrieve the previous events for. + LatestEvents: []string{e.EventID()}, + }, roomVersion); err == nil { + missingResp = &m + break + } else { + logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.Origin, server) + } + } + + if missingResp == nil { + logger.WithError(err).Errorf( + "%s pushed us an event but %d server(s) couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", + t.Origin, len(servers), + ) + return nil, missingPrevEventsError{ + eventID: e.EventID(), + err: err, + } + } // security: how we handle failures depends on whether or not this event will become the new forward extremity for the room. // There's 2 scenarios to consider: @@ -658,16 +736,6 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.Event // https://github.com/matrix-org/synapse/pull/3456 // https://github.com/matrix-org/synapse/blob/229eb81498b0fe1da81e9b5b333a0285acde9446/synapse/handlers/federation.py#L335 // For now, we do not allow Case B, so reject the event. - if err != nil { - logger.WithError(err).Errorf( - "%s pushed us an event but couldn't give us details about prev_events via /get_missing_events - dropping this event until it can", - t.Origin, - ) - return nil, missingPrevEventsError{ - eventID: e.EventID(), - err: err, - } - } logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) // topologically sort and sanity check that we are making forward progress diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index e1211ffe..ba653c1e 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -77,10 +77,11 @@ func (p *testEDUProducer) InputSendToDeviceEvent( } type testRoomserverAPI struct { - inputRoomEvents []api.InputRoomEvent - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse + inputRoomEvents []api.InputRoomEvent + queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse + queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse + queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } func (t *testRoomserverAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {} @@ -162,6 +163,20 @@ func (t *testRoomserverAPI) QueryStateAfterEvents( return nil } +// Query the state after a list of events in a room from the room server. +func (t *testRoomserverAPI) QueryMissingAuthPrevEvents( + ctx context.Context, + request *api.QueryMissingAuthPrevEventsRequest, + response *api.QueryMissingAuthPrevEventsResponse, +) error { + response.RoomVersion = testRoomVersion + res := t.queryMissingAuthPrevEvents(request) + response.RoomExists = res.RoomExists + response.MissingAuthEventIDs = res.MissingAuthEventIDs + response.MissingPrevEventIDs = res.MissingPrevEventIDs + return nil +} + // Query a list of events by event ID. func (t *testRoomserverAPI) QueryEventsByID( ctx context.Context, @@ -453,11 +468,11 @@ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []gomatr // to the roomserver. It's the most basic test possible. func TestBasicTransaction(t *testing.T) { rsAPI := &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - RoomExists: true, - StateEvents: fromStateTuples(req.StateToFetch, nil), + queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { + return api.QueryMissingAuthPrevEventsResponse{ + RoomExists: true, + MissingAuthEventIDs: []string{}, + MissingPrevEventIDs: []string{}, } }, } @@ -473,14 +488,11 @@ func TestBasicTransaction(t *testing.T) { // as it does the auth check. func TestTransactionFailAuthChecks(t *testing.T) { rsAPI := &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - RoomExists: true, - // omit the create event so auth checks fail - StateEvents: fromStateTuples(req.StateToFetch, []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomCreate, StateKey: ""}, - }), + queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { + return api.QueryMissingAuthPrevEventsResponse{ + RoomExists: true, + MissingAuthEventIDs: []string{"create_event"}, + MissingPrevEventIDs: []string{}, } }, } @@ -504,28 +516,24 @@ func TestTransactionFetchMissingPrevEvents(t *testing.T) { var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions rsAPI = &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - // we expect this to be called three times: - // - first with input event to realise there's a gap - // - second with the prevEvent to realise there is no gap - // - third with the input event to realise there is no longer a gap - prevEventsExist := false + queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { + missingPrevEvent := []string{"missing_prev_event"} if len(req.PrevEventIDs) == 1 { switch req.PrevEventIDs[0] { case haveEvent.EventID(): - prevEventsExist = true + missingPrevEvent = []string{} case prevEvent.EventID(): // we only have this event if we've been send prevEvent if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { - prevEventsExist = true + missingPrevEvent = []string{} } } } - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: prevEventsExist, - RoomExists: true, - StateEvents: fromStateTuples(req.StateToFetch, nil), + return api.QueryMissingAuthPrevEventsResponse{ + RoomExists: true, + MissingAuthEventIDs: []string{}, + MissingPrevEventIDs: missingPrevEvent, } }, queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { @@ -626,6 +634,38 @@ func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { StateEvents: stateEvents, } }, + + queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { + askingForEvent := req.PrevEventIDs[0] + haveEventB := false + haveEventC := false + for _, ev := range rsAPI.inputRoomEvents { + switch ev.Event.EventID() { + case eventB.EventID(): + haveEventB = true + case eventC.EventID(): + haveEventC = true + } + } + prevEventExists := false + if askingForEvent == eventC.EventID() { + prevEventExists = haveEventC + } else if askingForEvent == eventB.EventID() { + prevEventExists = haveEventB + } + + var missingPrevEvent []string + if !prevEventExists { + missingPrevEvent = []string{"test"} + } + + return api.QueryMissingAuthPrevEventsResponse{ + RoomExists: true, + MissingAuthEventIDs: []string{}, + MissingPrevEventIDs: missingPrevEvent, + } + }, + queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { omitTuples := []gomatrixserverlib.StateKeyTuple{ {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}, diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 159c1829..043f7222 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -68,6 +68,13 @@ type RoomserverInternalAPI interface { response *QueryStateAfterEventsResponse, ) error + // Query whether the roomserver is missing any auth or prev events. + QueryMissingAuthPrevEvents( + ctx context.Context, + request *QueryMissingAuthPrevEventsRequest, + response *QueryMissingAuthPrevEventsResponse, + ) error + // Query a list of events by event ID. QueryEventsByID( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 5fabbc21..f4eaddc1 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -104,6 +104,16 @@ func (t *RoomserverInternalAPITrace) QueryStateAfterEvents( return err } +func (t *RoomserverInternalAPITrace) QueryMissingAuthPrevEvents( + ctx context.Context, + req *QueryMissingAuthPrevEventsRequest, + res *QueryMissingAuthPrevEventsResponse, +) error { + err := t.Impl.QueryMissingAuthPrevEvents(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryMissingAuthPrevEvents req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 5d61e862..aff6ee07 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -82,6 +82,27 @@ type QueryStateAfterEventsResponse struct { StateEvents []gomatrixserverlib.HeaderedEvent `json:"state_events"` } +type QueryMissingAuthPrevEventsRequest struct { + // The room ID to query the state in. + RoomID string `json:"room_id"` + // The list of auth events to check the existence of. + AuthEventIDs []string `json:"auth_event_ids"` + // The list of previous events to check the existence of. + PrevEventIDs []string `json:"prev_event_ids"` +} + +type QueryMissingAuthPrevEventsResponse struct { + // Does the room exist on this roomserver? + // If the room doesn't exist all other fields will be empty. + RoomExists bool `json:"room_exists"` + // The room version of the room. + RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` + // The event IDs of the auth events that we don't know locally. + MissingAuthEventIDs []string `json:"missing_auth_event_ids"` + // The event IDs of the previous events that we don't know locally. + MissingPrevEventIDs []string `json:"missing_prev_event_ids"` +} + // QueryEventsByIDRequest is a request to QueryEventsByID type QueryEventsByIDRequest struct { // The event IDs to look up. @@ -154,6 +175,8 @@ type QueryServerJoinedToRoomResponse struct { RoomExists bool `json:"room_exists"` // True if we still believe that we are participating in the room IsInRoom bool `json:"is_in_room"` + // List of servers that are also in the room + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` } // QueryServerAllowedToSeeEventRequest is a request to QueryServerAllowedToSeeEvent diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 834bc0c6..0fa89d9c 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -83,7 +83,7 @@ func CheckForSoftFail( // Check if the event is allowed. if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil { // return true, nil - return true, fmt.Errorf("gomatrixserverlib.Allowed: %w", err) + return true, err } return false, nil } @@ -99,7 +99,7 @@ func CheckAuthEvents( // Grab the numeric IDs for the supplied auth state events from the database. authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs) if err != nil { - return nil, err + return nil, fmt.Errorf("db.StateEntriesForEventIDs: %w", err) } authStateEntries = types.DeduplicateStateEntries(authStateEntries) @@ -109,7 +109,7 @@ func CheckAuthEvents( // Load the actual auth events from the database. authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries) if err != nil { - return nil, err + return nil, fmt.Errorf("loadAuthEvents: %w", err) } // Check if the event is allowed. diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index f953a925..3d44f048 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -49,7 +49,7 @@ func (r *Inputer) processRoomEvent( isRejected := false authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) if rejectionErr != nil { - logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event, rejecting event") + logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("helpers.CheckAuthEvents failed for event, rejecting event") isRejected = true } diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index d6a64e7e..734e73d4 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -136,14 +136,10 @@ func (r *Inviter) PerformInvite( log.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", event.AuthEventIDs()).Error( "processInviteEvent.checkAuthEvents failed for event", ) - if _, ok := err.(*gomatrixserverlib.NotAllowed); ok { - res.Error = &api.PerformError{ - Msg: err.Error(), - Code: api.PerformErrorNotAllowed, - } - return nil, nil + res.Error = &api.PerformError{ + Msg: err.Error(), + Code: api.PerformErrorNotAllowed, } - return nil, fmt.Errorf("checkAuthEvents: %w", err) } // If the invite originated from us and the target isn't local then we @@ -160,7 +156,7 @@ func (r *Inviter) PerformInvite( if err = r.FSAPI.PerformInvite(ctx, fsReq, fsRes); err != nil { res.Error = &api.PerformError{ Msg: err.Error(), - Code: api.PerformErrorNoOperation, + Code: api.PerformErrorNotAllowed, } log.WithError(err).WithField("event_id", event.EventID()).Error("r.FSAPI.PerformInvite failed") return nil, nil @@ -185,7 +181,12 @@ func (r *Inviter) PerformInvite( inputRes := &api.InputRoomEventsResponse{} r.Inputer.InputRoomEvents(context.Background(), inputReq, inputRes) if err = inputRes.Err(); err != nil { - return nil, fmt.Errorf("r.InputRoomEvents: %w", err) + res.Error = &api.PerformError{ + Msg: fmt.Sprintf("r.InputRoomEvents: %s", err.Error()), + Code: api.PerformErrorNotAllowed, + } + log.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") + return nil, nil } } else { // The invite originated over federation. Process the membership diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index e9aebb83..56ae6d0b 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -249,14 +249,10 @@ func (r *Joiner) performJoinRoomByID( inputRes := api.InputRoomEventsResponse{} r.Inputer.InputRoomEvents(ctx, &inputReq, &inputRes) if err = inputRes.Err(); err != nil { - var notAllowed *gomatrixserverlib.NotAllowed - if errors.As(err, ¬Allowed) { - return "", &api.PerformError{ - Code: api.PerformErrorNotAllowed, - Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err), - } + return "", &api.PerformError{ + Code: api.PerformErrorNotAllowed, + Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err), } - return "", fmt.Errorf("r.InputRoomEvents: %w", err) } } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 58cb4493..73660421 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -98,6 +98,38 @@ func (r *Queryer) QueryStateAfterEvents( return nil } +// QueryMissingAuthPrevEvents implements api.RoomserverInternalAPI +func (r *Queryer) QueryMissingAuthPrevEvents( + ctx context.Context, + request *api.QueryMissingAuthPrevEventsRequest, + response *api.QueryMissingAuthPrevEventsResponse, +) error { + info, err := r.DB.RoomInfo(ctx, request.RoomID) + if err != nil { + return err + } + if info == nil { + return errors.New("room doesn't exist") + } + + response.RoomExists = !info.IsStub + response.RoomVersion = info.RoomVersion + + for _, authEventID := range request.AuthEventIDs { + if nids, err := r.DB.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { + response.MissingAuthEventIDs = append(response.MissingAuthEventIDs, authEventID) + } + } + + for _, prevEventID := range request.PrevEventIDs { + if nids, err := r.DB.EventNIDs(ctx, []string{prevEventID}); err != nil || len(nids) == 0 { + response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) + } + } + + return nil +} + // QueryEventsByID implements api.RoomserverInternalAPI func (r *Queryer) QueryEventsByID( ctx context.Context, @@ -255,19 +287,24 @@ func (r *Queryer) QueryServerJoinedToRoom( return fmt.Errorf("r.DB.Events: %w", err) } + servers := map[gomatrixserverlib.ServerName]struct{}{} for _, e := range events { if e.Type() == gomatrixserverlib.MRoomMember && e.StateKey() != nil { _, serverName, err := gomatrixserverlib.SplitID('@', *e.StateKey()) if err != nil { continue } + servers[serverName] = struct{}{} if serverName == request.ServerName { response.IsInRoom = true - break } } } + for server := range servers { + response.ServerNames = append(response.ServerNames, server) + } + return nil } diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 3dd3edaf..24a82adf 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -35,6 +35,7 @@ const ( // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents" + RoomserverQueryMissingAuthPrevEventsPath = "/roomserver/queryMissingAuthPrevEvents" RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" @@ -262,6 +263,19 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +// QueryStateAfterEvents implements RoomserverQueryAPI +func (h *httpRoomserverInternalAPI) QueryMissingAuthPrevEvents( + ctx context.Context, + request *api.QueryMissingAuthPrevEventsRequest, + response *api.QueryMissingAuthPrevEventsResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingAuthPrevEvents") + defer span.Finish() + + apiURL := h.roomserverURL + RoomserverQueryMissingAuthPrevEventsPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + // QueryEventsByID implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryEventsByID( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index c7e541dd..9c9d4d4a 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -125,6 +125,20 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle( + RoomserverQueryMissingAuthPrevEventsPath, + httputil.MakeInternalAPI("queryMissingAuthPrevEvents", func(req *http.Request) util.JSONResponse { + var request api.QueryMissingAuthPrevEventsRequest + var response api.QueryMissingAuthPrevEventsResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := r.QueryMissingAuthPrevEvents(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle( RoomserverQueryEventsByIDPath, httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {