diff --git a/eth/types.go b/eth/types.go index 7799370..5faf9db 100644 --- a/eth/types.go +++ b/eth/types.go @@ -135,8 +135,6 @@ func (b Block) String() string { } func (b *Block) Tx(idx uint64) *Tx { - b.Lock() - defer b.Unlock() for i := range b.Txs { if uint64(b.Txs[i].Idx) == idx { return &b.Txs[i] diff --git a/jrpc2/client.go b/jrpc2/client.go index c15a8e4..75f2b17 100644 --- a/jrpc2/client.go +++ b/jrpc2/client.go @@ -363,9 +363,7 @@ type key struct { a, b uint64 } -type ( - blockmap map[uint64]*eth.Block -) +type blockmap map[uint64]*eth.Block func (c *Client) Get( ctx context.Context, @@ -426,7 +424,7 @@ func (c *Client) Get( return nil, fmt.Errorf("getting traces: %w", err) } } - return blocks, validate("Get", start, limit, blocks) + return blocks, nil } type blockResp struct { @@ -451,9 +449,11 @@ type getter func(ctx context.Context, start, limit uint64) ([]eth.Block, error) func (c *cache) pruneMaxRead() { for k, v := range c.segments { + v.Lock() if v.nreads >= c.maxreads { delete(c.segments, k) } + v.Unlock() } } @@ -546,22 +546,14 @@ func validate(caller string, start, limit uint64, blocks []eth.Block) error { if len(blocks) == 0 { return fmt.Errorf("%s: no blocks", caller) } - - first, last := blocks[0], blocks[len(blocks)-1] - if uint64(first.Num()) != start { + first, last := blocks[0].Num(), blocks[len(blocks)-1].Num() + if uint64(first) != start { const tag = "%s: rpc response contains invalid data. requested first: %d got: %d" - return fmt.Errorf(tag, caller, start, first.Num()) + return fmt.Errorf(tag, caller, start, first) } - if uint64(last.Num()) != start+limit-1 { + if uint64(last) != start+limit-1 { const tag = "%s: rpc response contains invalid data. requested last: %d got: %d" - return fmt.Errorf(tag, caller, start+limit-1, last.Num()) - } - - // some rpc responses will not return a parent hash - // so there is nothing we can do to validate the hash - // chain - if len(blocks) <= 1 || len(blocks[0].Header.Parent) < 32 { - return nil + return fmt.Errorf(tag, caller, start+limit-1, last) } for i := 1; i < len(blocks); i++ { prev, curr := blocks[i-1], blocks[i] @@ -663,7 +655,12 @@ func (c *Client) receipts(ctx context.Context, bm blockmap, start, limit uint64) if len(resps[i].Result) == 0 { return fmt.Errorf("no rpc error but empty result") } - b, ok := bm[uint64(resps[i].Result[0].BlockNum)] + blockNum := uint64(resps[i].Result[0].BlockNum) + if blockNum < start || blockNum > start+limit { + const tag = "eth_getBlockReceipts out of range block. num=%d start=%d lim=%d" + return fmt.Errorf(tag, blockNum, start, limit) + } + b, ok := bm[blockNum] if !ok { return fmt.Errorf("block not found") } @@ -727,7 +724,15 @@ func (c *Client) logs(ctx context.Context, filter *glf.Filter, bm blockmap, star } var logsByTx = map[key][]logResult{} for i := range lresp.Result { - k := key{uint64(lresp.Result[i].BlockNum), uint64(lresp.Result[i].TxIdx)} + var ( + blockNum = uint64(lresp.Result[i].BlockNum) + txIdx = uint64(lresp.Result[i].TxIdx) + k = key{blockNum, txIdx} + ) + if blockNum < start || blockNum > start+limit { + const tag = "eth_getLogs out of range block. num=%d start=%d lim=%d" + return fmt.Errorf(tag, blockNum, start, limit) + } if logs, ok := logsByTx[k]; ok { logsByTx[k] = append(logs, lresp.Result[i]) continue @@ -740,12 +745,14 @@ func (c *Client) logs(ctx context.Context, filter *glf.Filter, bm blockmap, star if !ok { return fmt.Errorf("block not found") } + b.Lock() b.Header.Hash.Write(logs[0].BlockHash) tx := b.Tx(k.b) tx.PrecompHash.Write(logs[0].TxHash) for i := range logs { tx.Logs.Add(logs[i].Log) } + b.Unlock() } slog.Debug("http get logs", "start", start, diff --git a/jrpc2/client_test.go b/jrpc2/client_test.go index ee7dcb5..5222c77 100644 --- a/jrpc2/client_test.go +++ b/jrpc2/client_test.go @@ -352,21 +352,24 @@ func TestGet_Cached(t *testing.T) { var ( ctx = context.Background() c = New(ts.URL) - findTx = func(b eth.Block, idx uint64) (eth.Tx, error) { + findTx = func(b *eth.Block, idx uint64) (*eth.Tx, error) { for i := range b.Txs { if b.Txs[i].Idx == eth.Uint64(idx) { - return b.Txs[i], nil + return &b.Txs[i], nil } } - return eth.Tx{}, fmt.Errorf("no tx at idx %d", idx) + return nil, fmt.Errorf("no tx at idx %d", idx) } getcall = func() error { blocks, err := c.Get(ctx, &glf.Filter{UseHeaders: true, UseLogs: true}, 18000000, 1) diff.Test(t, t.Errorf, nil, err) + + blocks[0].Lock() diff.Test(t, t.Errorf, len(blocks[0].Txs), 65) - tx, err := findTx(blocks[0], 0) + tx, err := findTx(&blocks[0], 0) diff.Test(t, t.Errorf, nil, err) diff.Test(t, t.Errorf, len(tx.Logs), 1) + blocks[0].Unlock() return nil } )