Don't prematurely stop trying to join using servers (#1041)

* Don't prematurely stop trying to join using servers

* Factor out performJoinUsingServer
This commit is contained in:
Neil Alexander 2020-05-15 13:55:14 +01:00 committed by GitHub
parent f4f032381b
commit 5f6f8adaa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -52,13 +52,46 @@ func (r *FederationSenderInternalAPI) PerformJoin(
// Try each server that we were provided until we land on one that
// successfully completes the make-join send-join dance.
for _, serverName := range request.ServerNames {
if err := r.performJoinUsingServer(
ctx,
request.RoomID,
request.UserID,
request.Content,
serverName,
supportedVersions,
); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"server_name": serverName,
"room_id": request.RoomID,
}).Warnf("Failed to join room through server")
continue
}
// We're all good.
return nil
}
// If we reach here then we didn't complete a join for some reason.
return fmt.Errorf(
"failed to join user %q to room %q through %d server(s)",
request.UserID, request.RoomID, len(request.ServerNames),
)
}
func (r *FederationSenderInternalAPI) performJoinUsingServer(
ctx context.Context,
roomID, userID string,
content map[string]interface{},
serverName gomatrixserverlib.ServerName,
supportedVersions []gomatrixserverlib.RoomVersion,
) error {
// Try to perform a make_join using the information supplied in the
// request.
respMakeJoin, err := r.federation.MakeJoin(
ctx,
serverName,
request.RoomID,
request.UserID,
roomID,
userID,
supportedVersions,
)
if err != nil {
@ -66,19 +99,20 @@ func (r *FederationSenderInternalAPI) PerformJoin(
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.MakeJoin: %w", err)
}
r.statistics.ForServer(serverName).Success()
// Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd"
respMakeJoin.JoinEvent.Type = gomatrixserverlib.MRoomMember
respMakeJoin.JoinEvent.Sender = request.UserID
respMakeJoin.JoinEvent.StateKey = &request.UserID
respMakeJoin.JoinEvent.RoomID = request.RoomID
respMakeJoin.JoinEvent.Sender = userID
respMakeJoin.JoinEvent.StateKey = &userID
respMakeJoin.JoinEvent.RoomID = roomID
respMakeJoin.JoinEvent.Redacts = ""
if request.Content == nil {
request.Content = map[string]interface{}{}
if content == nil {
content = map[string]interface{}{}
}
request.Content["membership"] = "join"
if err = respMakeJoin.JoinEvent.SetContent(request.Content); err != nil {
content["membership"] = "join"
if err = respMakeJoin.JoinEvent.SetContent(content); err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err)
}
if err = respMakeJoin.JoinEvent.SetUnsigned(struct{}{}); err != nil {
@ -114,18 +148,17 @@ func (r *FederationSenderInternalAPI) PerformJoin(
respMakeJoin.RoomVersion,
)
if err != nil {
logrus.WithError(err).Warnf("r.federation.SendJoin failed")
r.statistics.ForServer(serverName).Failure()
continue
return fmt.Errorf("r.federation.SendJoin: %w", err)
}
r.statistics.ForServer(serverName).Success()
// Check that the send_join response was valid.
joinCtx := perform.JoinContext(r.federation, r.keyRing)
if err = joinCtx.CheckSendJoinResponse(
ctx, event, serverName, respMakeJoin, respSendJoin,
); err != nil {
logrus.WithError(err).Warnf("joinCtx.CheckSendJoinResponse failed")
continue
return fmt.Errorf("joinCtx.CheckSendJoinResponse: %w", err)
}
// If we successfully performed a send_join above then the other
@ -136,22 +169,12 @@ func (r *FederationSenderInternalAPI) PerformJoin(
respSendJoin.ToRespState(),
event.Headered(respMakeJoin.RoomVersion),
); err != nil {
logrus.WithError(err).Warnf("r.producer.SendEventWithState failed")
continue
return fmt.Errorf("r.producer.SendEventWithState: %w", err)
}
// We're all good.
r.statistics.ForServer(serverName).Success()
return nil
}
// If we reach here then we didn't complete a join for some reason.
return fmt.Errorf(
"failed to join user %q to room %q through %d server(s)",
request.UserID, request.RoomID, len(request.ServerNames),
)
}
// PerformLeaveRequest implements api.FederationSenderInternalAPI
func (r *FederationSenderInternalAPI) PerformLeave(
ctx context.Context,