diff --git a/jrpc2/client.go b/jrpc2/client.go
index a3c0aaa..135f658 100644
--- a/jrpc2/client.go
+++ b/jrpc2/client.go
@@ -47,6 +47,8 @@ func New(url string) *Client {
 		pollDuration: time.Second,
 		url:          url,
 		lcache:       NumHash{maxreads: 20},
+		bcache:       cache{maxreads: 20},
+		hcache:       cache{maxreads: 20},
 	}
 }
 
@@ -66,6 +68,8 @@ type Client struct {
 
 func (c *Client) WithMaxReads(n int) *Client {
 	c.lcache.maxreads = n
+	c.bcache.maxreads = n
+	c.hcache.maxreads = n
 	return c
 }
 
@@ -422,18 +426,28 @@ type blockResp struct {
 
 type segment struct {
 	sync.Mutex
-	done bool
-	d    []eth.Block
+	nreads int
+	done   bool
+	d      []eth.Block
 }
 
 type cache struct {
 	sync.Mutex
+	maxreads int
 	segments map[key]*segment
 }
 
 type getter func(ctx context.Context, start, limit uint64) ([]eth.Block, error)
 
-func (c *cache) prune() {
+func (c *cache) pruneMaxRead() {
+	for k, v := range c.segments {
+		if v.nreads >= c.maxreads {
+			delete(c.segments, k)
+		}
+	}
+}
+
+func (c *cache) pruneSegments() {
 	const size = 5
 	if len(c.segments) <= size {
 		return
@@ -458,16 +472,18 @@ func (c *cache) get(nocache bool, ctx context.Context, start, limit uint64, f ge
 	if c.segments == nil {
 		c.segments = make(map[key]*segment)
 	}
+	c.pruneMaxRead()
 	seg, ok := c.segments[key{start, limit}]
 	if !ok {
 		seg = &segment{}
 		c.segments[key{start, limit}] = seg
 	}
-	c.prune()
+	c.pruneSegments()
 	c.Unlock()
 
 	seg.Lock()
 	defer seg.Unlock()
+	seg.nreads++
 	if seg.done {
 		return seg.d, nil
 	}
diff --git a/jrpc2/client_test.go b/jrpc2/client_test.go
index 11e543f..ee7dcb5 100644
--- a/jrpc2/client_test.go
+++ b/jrpc2/client_test.go
@@ -18,6 +18,7 @@ import (
 
 	"github.com/indexsupply/x/eth"
 	"github.com/indexsupply/x/shovel/glf"
+	"github.com/indexsupply/x/tc"
 	"golang.org/x/sync/errgroup"
 	"kr.dev/diff"
 )
@@ -44,7 +45,7 @@ func (tg *testGetter) get(ctx context.Context, start, limit uint64) ([]eth.Block
 func TestCache_Prune(t *testing.T) {
 	ctx := context.Background()
 	tg := testGetter{}
-	c := cache{}
+	c := cache{maxreads: 2}
 	blocks, err := c.get(false, ctx, 1, 1, tg.get)
 	diff.Test(t, t.Fatalf, nil, err)
 	diff.Test(t, t.Errorf, 1, len(blocks))
@@ -76,6 +77,25 @@ func TestCache_Prune(t *testing.T) {
 	})
 }
 
+func TestCache_MaxReads(t *testing.T) {
+	var (
+		ctx = context.Background()
+		tg  = testGetter{}
+		c   = cache{maxreads: 2}
+	)
+	_, err := c.get(false, ctx, 1, 1, tg.get)
+	tc.NoErr(t, err)
+	tc.WantGot(t, 1, tg.callCount)
+
+	_, err = c.get(false, ctx, 1, 1, tg.get)
+	tc.NoErr(t, err)
+	tc.WantGot(t, 1, tg.callCount)
+
+	_, err = c.get(false, ctx, 1, 1, tg.get)
+	tc.NoErr(t, err)
+	tc.WantGot(t, 2, tg.callCount)
+}
+
 var (
 	//go:embed testdata/block-18000000.json
 	block18000000JSON string
@@ -356,6 +376,39 @@ func TestGet_Cached(t *testing.T) {
 	eg.Wait()
 }
 
+// Test that a block cache removes its segments after
+// they've been read N times. Once N is reached, subsequent
+// calls to Get should make new requests.
+func TestGet_Cached_Pruned(t *testing.T) {
+	var n int32
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		body, err := io.ReadAll(r.Body)
+		diff.Test(t, t.Fatalf, nil, err)
+		switch {
+		case strings.Contains(string(body), "eth_getBlockByNumber"):
+			atomic.AddInt32(&n, 1)
+			_, err := w.Write([]byte(block18000000JSON))
+			diff.Test(t, t.Fatalf, nil, err)
+		}
+	}))
+	defer ts.Close()
+	var (
+		ctx = context.Background()
+		c   = New(ts.URL).WithMaxReads(2)
+	)
+	_, err := c.Get(ctx, &glf.Filter{UseHeaders: true}, 18000000, 1)
+	diff.Test(t, t.Errorf, nil, err)
+	diff.Test(t, t.Errorf, n, int32(1))
+	_, err = c.Get(ctx, &glf.Filter{UseHeaders: true}, 18000000, 1)
+	diff.Test(t, t.Errorf, nil, err)
+	diff.Test(t, t.Errorf, n, int32(1))
+
+	// maxreads should have been reached with the last 2 calls
+	_, err = c.Get(ctx, &glf.Filter{UseHeaders: true}, 18000000, 1)
+	diff.Test(t, t.Errorf, nil, err)
+	diff.Test(t, t.Errorf, n, int32(2))
+}
+
 func TestNoLogs(t *testing.T) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		body, err := io.ReadAll(r.Body)
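
Below is a minimal, self-contained sketch (not part of the patch) of the read-limited caching idea the diff introduces: a cached entry may be served at most maxreads times before it is dropped, so the next lookup falls through to the fetch function. The names readLimitedCache, entry, and fetch are illustrative only and do not exist in jrpc2; locking, segment keys, and the size-based pruning of the real cache are omitted.

// Sketch only: a key-to-value cache whose entries expire after maxreads reads.
package main

import "fmt"

type entry struct {
	nreads int
	val    string
}

type readLimitedCache struct {
	maxreads int
	entries  map[uint64]*entry
}

func (c *readLimitedCache) get(k uint64, fetch func(uint64) string) string {
	// Drop any entry that has already been read maxreads times so the
	// next lookup for that key goes back to the source, analogous to
	// what pruneMaxRead does for cache segments in the patch.
	for key, e := range c.entries {
		if e.nreads >= c.maxreads {
			delete(c.entries, key)
		}
	}
	e, ok := c.entries[k]
	if !ok {
		e = &entry{val: fetch(k)}
		c.entries[k] = e
	}
	e.nreads++
	return e.val
}

func main() {
	fetches := 0
	fetch := func(k uint64) string {
		fetches++
		return fmt.Sprintf("block-%d", k)
	}
	c := &readLimitedCache{maxreads: 2, entries: map[uint64]*entry{}}
	c.get(18000000, fetch)           // miss: fetches == 1
	c.get(18000000, fetch)           // second read served from cache: fetches == 1
	c.get(18000000, fetch)           // entry pruned after 2 reads: fetches == 2
	fmt.Println("fetches:", fetches) // prints "fetches: 2"
}

The call pattern mirrors TestCache_MaxReads and TestGet_Cached_Pruned above: two reads are served from the cache, and the third forces a new fetch.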