diff --git a/cache.go b/cache.go index 4753e17..ca0152c 100644 --- a/cache.go +++ b/cache.go @@ -3,10 +3,11 @@ package agecache import ( "container/list" - "errors" "math/rand" "sync" "time" + + "github.com/pkg/errors" ) // Stats hold cache statistics. @@ -80,6 +81,8 @@ type Config struct { OnEviction func(key, value interface{}) // Optional callback invoked when an item expired OnExpiration func(key, value interface{}) + // Optional callback invoked when a key is missing/expired on `Get()` + OnMiss func(key interface{}) (interface{}, error) } // Entry pointed to by each list.Element @@ -99,6 +102,7 @@ type Cache struct { expirationInterval time.Duration onEviction func(key, value interface{}) onExpiration func(key, value interface{}) + onMiss func(key interface{}) (interface{}, error) // Cache statistics sets int64 @@ -154,6 +158,7 @@ func New(config Config) *Cache { expirationInterval: interval, onEviction: config.OnEviction, onExpiration: config.OnExpiration, + onMiss: config.OnMiss, items: make(map[interface{}]*list.Element), evictionList: list.New(), rand: rand.New(seed), @@ -176,6 +181,10 @@ func (cache *Cache) Set(key, value interface{}) bool { cache.mutex.Lock() defer cache.mutex.Unlock() + return cache.set(key, value) +} + +func (cache *Cache) set(key, value interface{}) bool { cache.sets++ timestamp := cache.getTimestamp() @@ -198,10 +207,15 @@ func (cache *Cache) Set(key, value interface{}) bool { return evict } -// Get returns the value stored at `key`. The boolean value reports whether or -// not the value was found. The OnExpiration callback is invoked if the value -// had expired on access -func (cache *Cache) Get(key interface{}) (interface{}, bool) { +var ( + ErrNotFound = errors.New("not found") +) + +// Get returns the value stored at `key`. Error is set to ErrNotFound if +// key not found or expired. If OnMiss is set, value will be fetched, set and returned. +// If fetch failed, error will be returned. The OnExpiration callback is invoked if the value +// had expired on access. +func (cache *Cache) Get(key interface{}) (interface{}, error) { cache.mutex.Lock() defer cache.mutex.Unlock() @@ -212,20 +226,27 @@ func (cache *Cache) Get(key interface{}) (interface{}, bool) { if cache.maxAge == 0 || time.Since(entry.timestamp) <= cache.maxAge { cache.evictionList.MoveToFront(element) cache.hits++ - return entry.value, true + return entry.value, nil } // Entry expired cache.deleteElement(element) - cache.misses++ if cache.onExpiration != nil { cache.onExpiration(entry.key, entry.value) } - return nil, false } cache.misses++ - return nil, false + if cache.onMiss != nil { + value, err := cache.onMiss(key) + if err != nil { + return nil, errors.Wrapf(err, "failed to fetch values for key %v", key) + } + cache.set(key, value) + return value, nil + } + + return nil, ErrNotFound } // Has returns whether or not the `key` is in the cache without updating @@ -383,6 +404,14 @@ func (cache *Cache) OnExpiration(callback func(key, value interface{})) { cache.onExpiration = callback } +// OnMiss sets the callback to fetch value on miss. +func (cache *Cache) OnMiss(callback func(key interface{}) (interface{}, error)) { + cache.mutex.Lock() + defer cache.mutex.Unlock() + + cache.onMiss = callback +} + // Stats returns cache stats. func (cache *Cache) Stats() Stats { cache.mutex.RLock() diff --git a/cache_test.go b/cache_test.go index eed0038..74b0662 100644 --- a/cache_test.go +++ b/cache_test.go @@ -39,12 +39,12 @@ func TestBasicSetGet(t *testing.T) { cache.Set("foo", 1) cache.Set("bar", 2) - val, ok := cache.Get("foo") - assert.True(t, ok) + val, err := cache.Get("foo") + assert.Nil(t, err) assert.Equal(t, 1, val) - val, ok = cache.Get("bar") - assert.True(t, ok) + val, err = cache.Get("bar") + assert.Nil(t, err) assert.Equal(t, 2, val) } @@ -52,10 +52,10 @@ func TestBasicSetOverwrite(t *testing.T) { cache := New(Config{Capacity: 2}) cache.Set("foo", 1) evict := cache.Set("foo", 2) - val, ok := cache.Get("foo") + val, err := cache.Get("foo") assert.False(t, evict) - assert.True(t, ok) + assert.Nil(t, err) assert.Equal(t, 2, val) } @@ -73,10 +73,10 @@ func TestEviction(t *testing.T) { cache.Set("foo", 1) cache.Set("bar", 2) evict := cache.Set("baz", 3) - val, ok := cache.Get("foo") + val, err := cache.Get("foo") assert.True(t, evict) - assert.False(t, ok) + assert.Equal(t, ErrNotFound, err) assert.Nil(t, val) assert.Equal(t, "foo", k) assert.Equal(t, 1, v) @@ -101,8 +101,8 @@ func TestExpiration(t *testing.T) { cache.Set("foo", 1) <-time.After(time.Millisecond * 2) - val, ok := cache.Get("foo") - assert.False(t, ok) + val, err := cache.Get("foo") + assert.Equal(t, ErrNotFound, err) assert.Nil(t, val) assert.Equal(t, "foo", k) assert.Equal(t, 1, v) @@ -128,12 +128,12 @@ func TestJitter(t *testing.T) { cache.Set("foo", "bar") <-time.After(time.Millisecond * 2) - _, ok := cache.Get("foo") - assert.True(t, ok) + _, err := cache.Get("foo") + assert.Nil(t, err) <-time.After(time.Millisecond * 3) - _, ok = cache.Get("foo") - assert.False(t, ok) + _, err = cache.Get("foo") + assert.Equal(t, ErrNotFound, err) } func TestHas(t *testing.T) { @@ -171,8 +171,8 @@ func TestRemove(t *testing.T) { assert.True(t, ok) assert.False(t, eviction) - val, ok := cache.Get("foo") - assert.False(t, ok) + val, err := cache.Get("foo") + assert.Equal(t, ErrNotFound, err) assert.Nil(t, val) } @@ -192,8 +192,8 @@ func TestEvictOldest(t *testing.T) { assert.True(t, ok) assert.True(t, eviction) - val, ok := cache.Get("foo") - assert.False(t, ok) + val, err := cache.Get("foo") + assert.Equal(t, ErrNotFound, err) assert.Nil(t, val) eviction = false @@ -222,8 +222,8 @@ func TestClear(t *testing.T) { cache.Clear() for i := 0; i <= 9; i++ { - _, ok := cache.Get(i) - assert.False(t, ok) + _, err := cache.Get(i) + assert.Equal(t, ErrNotFound, err) } assert.Equal(t, 0, cache.Len()) }