Skip to content

Commit

Permalink
Add OnMiss() callback to fetch missing key
Browse files Browse the repository at this point in the history
  • Loading branch information
tamalsaha committed May 10, 2018
1 parent c15de7c commit be22292
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 29 deletions.
47 changes: 38 additions & 9 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package agecache

import (
"container/list"
"errors"
"math/rand"
"sync"
"time"

"github.com/pkg/errors"
)

// Stats hold cache statistics.
Expand Down Expand Up @@ -80,6 +81,8 @@ type Config struct {
OnEviction func(key, value interface{})
// Optional callback invoked when an item expired
OnExpiration func(key, value interface{})
// Optional callback invoked when a key is missing/expired on `Get()`
OnMiss func(key interface{}) (interface{}, error)
}

// Entry pointed to by each list.Element
Expand All @@ -99,6 +102,7 @@ type Cache struct {
expirationInterval time.Duration
onEviction func(key, value interface{})
onExpiration func(key, value interface{})
onMiss func(key interface{}) (interface{}, error)

// Cache statistics
sets int64
Expand Down Expand Up @@ -154,6 +158,7 @@ func New(config Config) *Cache {
expirationInterval: interval,
onEviction: config.OnEviction,
onExpiration: config.OnExpiration,
onMiss: config.OnMiss,
items: make(map[interface{}]*list.Element),
evictionList: list.New(),
rand: rand.New(seed),
Expand All @@ -176,6 +181,10 @@ func (cache *Cache) Set(key, value interface{}) bool {
cache.mutex.Lock()
defer cache.mutex.Unlock()

return cache.set(key, value)
}

func (cache *Cache) set(key, value interface{}) bool {
cache.sets++
timestamp := cache.getTimestamp()

Expand All @@ -198,10 +207,15 @@ func (cache *Cache) Set(key, value interface{}) bool {
return evict
}

// Get returns the value stored at `key`. The boolean value reports whether or
// not the value was found. The OnExpiration callback is invoked if the value
// had expired on access
func (cache *Cache) Get(key interface{}) (interface{}, bool) {
var (
ErrNotFound = errors.New("not found")
)

// Get returns the value stored at `key`. Error is set to ErrNotFound if
// key not found or expired. If OnMiss is set, value will be fetched, set and returned.
// If fetch failed, error will be returned. The OnExpiration callback is invoked if the value
// had expired on access.
func (cache *Cache) Get(key interface{}) (interface{}, error) {
cache.mutex.Lock()
defer cache.mutex.Unlock()

Expand All @@ -212,20 +226,27 @@ func (cache *Cache) Get(key interface{}) (interface{}, bool) {
if cache.maxAge == 0 || time.Since(entry.timestamp) <= cache.maxAge {
cache.evictionList.MoveToFront(element)
cache.hits++
return entry.value, true
return entry.value, nil
}

// Entry expired
cache.deleteElement(element)
cache.misses++
if cache.onExpiration != nil {
cache.onExpiration(entry.key, entry.value)
}
return nil, false
}

cache.misses++
return nil, false
if cache.onMiss != nil {
value, err := cache.onMiss(key)
if err != nil {
return nil, errors.Wrapf(err, "failed to fetch values for key %v", key)
}
cache.set(key, value)
return value, nil
}

return nil, ErrNotFound
}

// Has returns whether or not the `key` is in the cache without updating
Expand Down Expand Up @@ -383,6 +404,14 @@ func (cache *Cache) OnExpiration(callback func(key, value interface{})) {
cache.onExpiration = callback
}

// OnMiss sets the callback to fetch value on miss.
func (cache *Cache) OnMiss(callback func(key interface{}) (interface{}, error)) {
cache.mutex.Lock()
defer cache.mutex.Unlock()

cache.onMiss = callback
}

// Stats returns cache stats.
func (cache *Cache) Stats() Stats {
cache.mutex.RLock()
Expand Down
40 changes: 20 additions & 20 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,23 @@ func TestBasicSetGet(t *testing.T) {
cache.Set("foo", 1)
cache.Set("bar", 2)

val, ok := cache.Get("foo")
assert.True(t, ok)
val, err := cache.Get("foo")
assert.Nil(t, err)
assert.Equal(t, 1, val)

val, ok = cache.Get("bar")
assert.True(t, ok)
val, err = cache.Get("bar")
assert.Nil(t, err)
assert.Equal(t, 2, val)
}

func TestBasicSetOverwrite(t *testing.T) {
cache := New(Config{Capacity: 2})
cache.Set("foo", 1)
evict := cache.Set("foo", 2)
val, ok := cache.Get("foo")
val, err := cache.Get("foo")

assert.False(t, evict)
assert.True(t, ok)
assert.Nil(t, err)
assert.Equal(t, 2, val)
}

Expand All @@ -73,10 +73,10 @@ func TestEviction(t *testing.T) {
cache.Set("foo", 1)
cache.Set("bar", 2)
evict := cache.Set("baz", 3)
val, ok := cache.Get("foo")
val, err := cache.Get("foo")

assert.True(t, evict)
assert.False(t, ok)
assert.Equal(t, ErrNotFound, err)
assert.Nil(t, val)
assert.Equal(t, "foo", k)
assert.Equal(t, 1, v)
Expand All @@ -101,8 +101,8 @@ func TestExpiration(t *testing.T) {
cache.Set("foo", 1)
<-time.After(time.Millisecond * 2)

val, ok := cache.Get("foo")
assert.False(t, ok)
val, err := cache.Get("foo")
assert.Equal(t, ErrNotFound, err)
assert.Nil(t, val)
assert.Equal(t, "foo", k)
assert.Equal(t, 1, v)
Expand All @@ -128,12 +128,12 @@ func TestJitter(t *testing.T) {
cache.Set("foo", "bar")

<-time.After(time.Millisecond * 2)
_, ok := cache.Get("foo")
assert.True(t, ok)
_, err := cache.Get("foo")
assert.Nil(t, err)

<-time.After(time.Millisecond * 3)
_, ok = cache.Get("foo")
assert.False(t, ok)
_, err = cache.Get("foo")
assert.Equal(t, ErrNotFound, err)
}

func TestHas(t *testing.T) {
Expand Down Expand Up @@ -171,8 +171,8 @@ func TestRemove(t *testing.T) {
assert.True(t, ok)
assert.False(t, eviction)

val, ok := cache.Get("foo")
assert.False(t, ok)
val, err := cache.Get("foo")
assert.Equal(t, ErrNotFound, err)
assert.Nil(t, val)
}

Expand All @@ -192,8 +192,8 @@ func TestEvictOldest(t *testing.T) {
assert.True(t, ok)
assert.True(t, eviction)

val, ok := cache.Get("foo")
assert.False(t, ok)
val, err := cache.Get("foo")
assert.Equal(t, ErrNotFound, err)
assert.Nil(t, val)

eviction = false
Expand Down Expand Up @@ -222,8 +222,8 @@ func TestClear(t *testing.T) {
cache.Clear()

for i := 0; i <= 9; i++ {
_, ok := cache.Get(i)
assert.False(t, ok)
_, err := cache.Get(i)
assert.Equal(t, ErrNotFound, err)
}
assert.Equal(t, 0, cache.Len())
}
Expand Down

0 comments on commit be22292

Please sign in to comment.