Take write lock for rate limit map (#1532)

* Take write lock for rate limit map

* Fix potential race condition
This commit is contained in:
Neil Alexander 2020-10-16 15:44:39 +01:00 committed by GitHub
parent 4a7fb9c045
commit 640e8c50ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,6 +13,7 @@ import (
type rateLimits struct { type rateLimits struct {
limits map[string]chan struct{} limits map[string]chan struct{}
limitsMutex sync.RWMutex limitsMutex sync.RWMutex
cleanMutex sync.RWMutex
enabled bool enabled bool
requestThreshold int64 requestThreshold int64
cooloffDuration time.Duration cooloffDuration time.Duration
@ -38,6 +39,7 @@ func (l *rateLimits) clean() {
// empty. If they are then we will close and delete them, // empty. If they are then we will close and delete them,
// freeing up memory. // freeing up memory.
time.Sleep(time.Second * 30) time.Sleep(time.Second * 30)
l.cleanMutex.Lock()
l.limitsMutex.Lock() l.limitsMutex.Lock()
for k, c := range l.limits { for k, c := range l.limits {
if len(c) == 0 { if len(c) == 0 {
@ -46,6 +48,7 @@ func (l *rateLimits) clean() {
} }
} }
l.limitsMutex.Unlock() l.limitsMutex.Unlock()
l.cleanMutex.Unlock()
} }
} }
@ -55,12 +58,12 @@ func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse {
return nil return nil
} }
// Lock the map long enough to check for rate limiting. We hold it // Take a read lock out on the cleaner mutex. The cleaner expects to
// for longer here than we really need to but it makes sure that we // be able to take a write lock, which isn't possible while there are
// also don't conflict with the cleaner goroutine which might clean // readers, so this has the effect of blocking the cleaner goroutine
// up a channel after we have retrieved it otherwise. // from doing its work until there are no requests in flight.
l.limitsMutex.RLock() l.cleanMutex.RLock()
defer l.limitsMutex.RUnlock() defer l.cleanMutex.RUnlock()
// First of all, work out if X-Forwarded-For was sent to us. If not // First of all, work out if X-Forwarded-For was sent to us. If not
// then we'll just use the IP address of the caller. // then we'll just use the IP address of the caller.
@ -69,12 +72,19 @@ func (l *rateLimits) rateLimit(req *http.Request) *util.JSONResponse {
caller = forwardedFor caller = forwardedFor
} }
// Look up the caller's channel, if they have one. If they don't then // Look up the caller's channel, if they have one.
// let's create one. l.limitsMutex.RLock()
rateLimit, ok := l.limits[caller] rateLimit, ok := l.limits[caller]
l.limitsMutex.RUnlock()
// If the caller doesn't have a channel, create one and write it
// back to the map.
if !ok { if !ok {
l.limits[caller] = make(chan struct{}, l.requestThreshold) rateLimit = make(chan struct{}, l.requestThreshold)
rateLimit = l.limits[caller]
l.limitsMutex.Lock()
l.limits[caller] = rateLimit
l.limitsMutex.Unlock()
} }
// Check if the user has got free resource slots for this request. // Check if the user has got free resource slots for this request.