diff --git a/cache.go b/cache.go index 59fb830..e0b4929 100644 --- a/cache.go +++ b/cache.go @@ -66,68 +66,33 @@ func (c *Cache[K, V]) Close() { // cleaner handles removing expired keys. The chainAdd and chainDel channels are // handled here to prevent race conditions. This ensures the expiry timer can be // stopped before modifying the chain. -// -// The cleaner is stopped whenever the chain is empty due to there being no chain -// to manage. func (c *Cache[K, V]) cleaner() { - // cleaner is always called from Set or Delete methods with a value sent on chainAdd or chainDel - select { - case node := <-c.chainAdd: - c.chainInsert(node) - case key := <-c.chainDel: - c.chainSplice(key) - default: - // skip if chainAdd or chainDel isn't ready - } - - // at this point if the chain is empty then exit - if c.chain == nil { - return - } - - // create a timer for the next expiry - t := time.NewTimer(timeUntil(c.chain.expires)) - for { select { case <-c.close: // exit the cleaner goroutine return case node := <-c.chainAdd: - // stop the timer safely - if !t.Stop() { - <-t.C - } - // the chain will not be empty after this insert so no check is required c.chainInsert(node) case key := <-c.chainDel: - // stop the timer safely - if !t.Stop() { - <-t.C - } c.chainSplice(key) - case <-t.C: - // if there is no chain then kill the expiry scheduler - if c.chain == nil { - return - } - + case <-c.nextExpiry(): // remove all expired entries for c.chain != nil && c.chain.HasExpired() { c.items.CompareAndDelete(c.chain.data, c.chain.item) c.chain = c.chain.next } } - - // if there is no chain then kill the expiry scheduler - if c.chain == nil { - return - } - - t.Reset(timeUntil(c.chain.expires)) } } +func (c *Cache[K, V]) nextExpiry() <-chan time.Time { + if c.chain == nil { + return make(chan time.Time) + } + return time.After(timeUntil(c.chain.expires)) +} + func (c *Cache[K, V]) chainInsert(node keyed[K]) { // quick path for an empty chain if c.chain == nil { diff --git a/cache_test.go b/cache_test.go index 910941b..d0ea62b 100644 --- a/cache_test.go +++ b/cache_test.go @@ -152,3 +152,21 @@ func TestCache_UpdateExpiry(t *testing.T) { assert.True(t, b) assert.Equal(t, "b", get) } + +func TestCache_ClearerDeath(t *testing.T) { + timeNow = func() time.Time { return time.Now() } + + c := New[string, string]() + + time.Sleep(10 * time.Millisecond) + + var added bool + go func() { + c.chainAdd <- keyed[string]{item: item[string]{data: "a"}} + c.chainAdd <- keyed[string]{item: item[string]{data: "b"}} + added = true + }() + + time.Sleep(10 * time.Millisecond) + assert.True(t, added) +}