Fix concurrent map reads/writes on t.hadEvents (#1902)

* Fix concurrent map reads/writes on t.hadEvents

* Add hadEvent function
This commit is contained in:
Neil Alexander 2021-07-07 18:55:44 +01:00 committed by GitHub
parent 5a09290c32
commit f2974721d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -233,6 +233,7 @@ type txnReq struct {
servers federationAPI.ServersInRoomProvider servers federationAPI.ServersInRoomProvider
// a list of events from the auth and prev events which we already had // a list of events from the auth and prev events which we already had
hadEvents map[string]bool hadEvents map[string]bool
hadEventsMutex sync.Mutex
// local cache of events for auth checks, etc - this may include events // local cache of events for auth checks, etc - this may include events
// which the roomserver is unaware of. // which the roomserver is unaware of.
haveEvents map[string]*gomatrixserverlib.HeaderedEvent haveEvents map[string]*gomatrixserverlib.HeaderedEvent
@ -240,6 +241,12 @@ type txnReq struct {
work string // metrics work string // metrics
} }
func (t *txnReq) hadEvent(eventID string, had bool) {
t.hadEventsMutex.Lock()
defer t.hadEventsMutex.Unlock()
t.hadEvents[eventID] = had
}
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
type txnFederationClient interface { type txnFederationClient interface {
LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
@ -595,10 +602,10 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e
// Prepare a map of all the events we already had before this point, so // Prepare a map of all the events we already had before this point, so
// that we don't send them to the roomserver again. // that we don't send them to the roomserver again.
for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) { for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) {
t.hadEvents[eventID] = true t.hadEvent(eventID, true)
} }
for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) { for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) {
t.hadEvents[eventID] = false t.hadEvent(eventID, false)
} }
if len(stateResp.MissingAuthEventIDs) > 0 { if len(stateResp.MissingAuthEventIDs) > 0 {
@ -673,7 +680,7 @@ withNextEvent:
); err != nil { ); err != nil {
return fmt.Errorf("api.SendEvents: %w", err) return fmt.Errorf("api.SendEvents: %w", err)
} }
t.hadEvents[ev.EventID()] = true // if the roomserver didn't know about the event before, it does now t.hadEvent(ev.EventID(), true) // if the roomserver didn't know about the event before, it does now
t.cacheAndReturn(ev.Headered(stateResp.RoomVersion)) t.cacheAndReturn(ev.Headered(stateResp.RoomVersion))
delete(missingAuthEvents, missingAuthEventID) delete(missingAuthEvents, missingAuthEventID)
continue withNextEvent continue withNextEvent
@ -801,14 +808,23 @@ func (t *txnReq) processEventWithMissingState(
// First of all, send the backward extremity into the roomserver with the // First of all, send the backward extremity into the roomserver with the
// newly resolved state. This marks the "oldest" point in the backfill and // newly resolved state. This marks the "oldest" point in the backfill and
// sets the baseline state for any new events after this. // sets the baseline state for any new events after this. We'll make a
// copy of the hadEvents map so that it can be taken downstream without
// worrying about concurrent map reads/writes, since t.hadEvents is meant
// to be protected by a mutex.
hadEvents := map[string]bool{}
t.hadEventsMutex.Lock()
for k, v := range t.hadEvents {
hadEvents[k] = v
}
t.hadEventsMutex.Unlock()
err = api.SendEventWithState( err = api.SendEventWithState(
context.Background(), context.Background(),
t.rsAPI, t.rsAPI,
api.KindOld, api.KindOld,
resolvedState, resolvedState,
backwardsExtremity.Headered(roomVersion), backwardsExtremity.Headered(roomVersion),
t.hadEvents, hadEvents,
) )
if err != nil { if err != nil {
return fmt.Errorf("api.SendEventWithState: %w", err) return fmt.Errorf("api.SendEventWithState: %w", err)
@ -904,7 +920,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
// set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this
// processEvent request, which is better for memory. // processEvent request, which is better for memory.
stateEvents[i] = t.cacheAndReturn(ev) stateEvents[i] = t.cacheAndReturn(ev)
t.hadEvents[ev.EventID()] = true t.hadEvent(ev.EventID(), true)
} }
// we should never access res.StateEvents again so we delete it here to make GC faster // we should never access res.StateEvents again so we delete it here to make GC faster
res.StateEvents = nil res.StateEvents = nil
@ -939,7 +955,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
} }
for i, ev := range queryRes.Events { for i, ev := range queryRes.Events {
authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap())
t.hadEvents[ev.EventID()] = true t.hadEvent(ev.EventID(), true)
} }
queryRes.Events = nil queryRes.Events = nil
} }
@ -1016,7 +1032,7 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even
latestEvents := make([]string, len(res.LatestEvents)) latestEvents := make([]string, len(res.LatestEvents))
for i, ev := range res.LatestEvents { for i, ev := range res.LatestEvents {
latestEvents[i] = res.LatestEvents[i].EventID latestEvents[i] = res.LatestEvents[i].EventID
t.hadEvents[ev.EventID] = true t.hadEvent(ev.EventID, true)
} }
var missingResp *gomatrixserverlib.RespMissingEvents var missingResp *gomatrixserverlib.RespMissingEvents
@ -1152,7 +1168,7 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
} }
for i, ev := range queryRes.Events { for i, ev := range queryRes.Events {
queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i])
t.hadEvents[ev.EventID()] = true t.hadEvent(ev.EventID(), true)
evID := queryRes.Events[i].EventID() evID := queryRes.Events[i].EventID()
if missing[evID] { if missing[evID] {
delete(missing, evID) delete(missing, evID)