package perform

import (
	"context"
	"errors"
	"fmt"

	"github.com/matrix-org/gomatrixserverlib"
)

// This file contains helpers for the PerformJoin function.

// joinContext bundles the dependencies used while validating a send_join
// response: a federation client for fetching events from the remote server
// and a key ring that is handed to the response check.
type joinContext struct {
	federation *gomatrixserverlib.FederationClient
	keyRing    *gomatrixserverlib.KeyRing
}

// JoinContext returns a new join context that uses f to fetch events over
// federation and k when checking the send_join response.
func JoinContext(f *gomatrixserverlib.FederationClient, k *gomatrixserverlib.KeyRing) *joinContext {
	return &joinContext{
		federation: f,
		keyRing:    k,
	}
}

// CheckSendJoinResponse checks that all of the signatures are correct
// and that the join is allowed by the supplied state.
//
// If the check fails because an auth event is missing from the send_join
// response, the event is requested from server, appended to the auth events
// and the check is run again. Each missing event ID is fetched at most once;
// if the check still reports the same event as missing afterwards, the error
// is returned instead of retrying forever.
//
// respSendJoin is received by value, so the auth events appended here are
// only visible to the checks performed inside this method. respMakeJoin is
// used only for its RoomVersion when parsing fetched events.
//
// It returns nil when the check passes, or a wrapped error from the check,
// the federation request, or event parsing.
func (r joinContext) CheckSendJoinResponse(
	ctx context.Context,
	event gomatrixserverlib.Event,
	server gomatrixserverlib.ServerName,
	respMakeJoin gomatrixserverlib.RespMakeJoin,
	respSendJoin gomatrixserverlib.RespSendJoin,
) error {
	// Auth event IDs we have already fetched because they were not included
	// in the auth events supplied in the send_join.
	retried := map[string]struct{}{}

	for {
		// TODO: Can we expand Check here to return a list of missing auth
		// events rather than failing one at a time?
		err := respSendJoin.Check(ctx, r.keyRing, event)
		if err == nil {
			return nil
		}

		// errors.As matches both a bare MissingAuthEventError and one that
		// has been wrapped with %w further down the call chain.
		var missing gomatrixserverlib.MissingAuthEventError
		if !errors.As(err, &missing) {
			return fmt.Errorf("respSendJoin: %w", err)
		}

		// Check that we haven't already retried for this event, which
		// prevents us from ending up in endless loops.
		if _, seen := retried[missing.AuthEventID]; seen {
			return fmt.Errorf("respSendJoin (after retries): %w", err)
		}

		// Ask the server that we're talking to right now for the event.
		tx, txerr := r.federation.GetEvent(ctx, server, missing.AuthEventID)
		if txerr != nil {
			return fmt.Errorf("r.federation.GetEvent: %w", txerr)
		}

		// For each event returned, add it to the auth events.
		for _, pdu := range tx.PDUs {
			ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, respMakeJoin.RoomVersion)
			if everr != nil {
				return fmt.Errorf("gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr)
			}
			respSendJoin.AuthEvents = append(respSendJoin.AuthEvents, ev)
		}

		// Mark the event as retried and then give the check another go.
		retried[missing.AuthEventID] = struct{}{}
	}
}