Skip to content

Commit

Permalink
shovel/jrpc2: fix concurrency problems unearthed by -race
Browse files Browse the repository at this point in the history
- Moved the mutex out of the Block.Tx method and instead rely on the
  callers to lock the block prior to mutation

- Removed the call to validate in Get and instead have each rpc related
  method implement its own validation

- Pruning maxread data requires locking the segemnt & cache
  • Loading branch information
ryandotsmith committed May 1, 2024
1 parent 2969ab4 commit d08c9fc
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 25 deletions.
2 changes: 0 additions & 2 deletions eth/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ func (b Block) String() string {
}

func (b *Block) Tx(idx uint64) *Tx {
b.Lock()
defer b.Unlock()
for i := range b.Txs {
if uint64(b.Txs[i].Idx) == idx {
return &b.Txs[i]
Expand Down
45 changes: 26 additions & 19 deletions jrpc2/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,7 @@ type key struct {
a, b uint64
}

type (
blockmap map[uint64]*eth.Block
)
type blockmap map[uint64]*eth.Block

func (c *Client) Get(
ctx context.Context,
Expand Down Expand Up @@ -426,7 +424,7 @@ func (c *Client) Get(
return nil, fmt.Errorf("getting traces: %w", err)
}
}
return blocks, validate("Get", start, limit, blocks)
return blocks, nil
}

type blockResp struct {
Expand All @@ -451,9 +449,11 @@ type getter func(ctx context.Context, start, limit uint64) ([]eth.Block, error)

func (c *cache) pruneMaxRead() {
for k, v := range c.segments {
v.Lock()
if v.nreads >= c.maxreads {
delete(c.segments, k)
}
v.Unlock()
}
}

Expand Down Expand Up @@ -546,22 +546,14 @@ func validate(caller string, start, limit uint64, blocks []eth.Block) error {
if len(blocks) == 0 {
return fmt.Errorf("%s: no blocks", caller)
}

first, last := blocks[0], blocks[len(blocks)-1]
if uint64(first.Num()) != start {
first, last := blocks[0].Num(), blocks[len(blocks)-1].Num()
if uint64(first) != start {
const tag = "%s: rpc response contains invalid data. requested first: %d got: %d"
return fmt.Errorf(tag, caller, start, first.Num())
return fmt.Errorf(tag, caller, start, first)
}
if uint64(last.Num()) != start+limit-1 {
if uint64(last) != start+limit-1 {
const tag = "%s: rpc response contains invalid data. requested last: %d got: %d"
return fmt.Errorf(tag, caller, start+limit-1, last.Num())
}

// some rpc responses will not return a parent hash
// so there is nothing we can do to validate the hash
// chain
if len(blocks) <= 1 || len(blocks[0].Header.Parent) < 32 {
return nil
return fmt.Errorf(tag, caller, start+limit-1, last)
}
for i := 1; i < len(blocks); i++ {
prev, curr := blocks[i-1], blocks[i]
Expand Down Expand Up @@ -663,7 +655,12 @@ func (c *Client) receipts(ctx context.Context, bm blockmap, start, limit uint64)
if len(resps[i].Result) == 0 {
return fmt.Errorf("no rpc error but empty result")
}
b, ok := bm[uint64(resps[i].Result[0].BlockNum)]
blockNum := uint64(resps[i].Result[0].BlockNum)
if blockNum < start || blockNum > start+limit {
const tag = "eth_getBlockReceipts out of range block. num=%d start=%d lim=%d"
return fmt.Errorf(tag, blockNum, start, limit)
}
b, ok := bm[blockNum]
if !ok {
return fmt.Errorf("block not found")
}
Expand Down Expand Up @@ -727,7 +724,15 @@ func (c *Client) logs(ctx context.Context, filter *glf.Filter, bm blockmap, star
}
var logsByTx = map[key][]logResult{}
for i := range lresp.Result {
k := key{uint64(lresp.Result[i].BlockNum), uint64(lresp.Result[i].TxIdx)}
var (
blockNum = uint64(lresp.Result[i].BlockNum)
txIdx = uint64(lresp.Result[i].TxIdx)
k = key{blockNum, txIdx}
)
if blockNum < start || blockNum > start+limit {
const tag = "eth_getLogs out of range block. num=%d start=%d lim=%d"
return fmt.Errorf(tag, blockNum, start, limit)
}
if logs, ok := logsByTx[k]; ok {
logsByTx[k] = append(logs, lresp.Result[i])
continue
Expand All @@ -740,12 +745,14 @@ func (c *Client) logs(ctx context.Context, filter *glf.Filter, bm blockmap, star
if !ok {
return fmt.Errorf("block not found")
}
b.Lock()
b.Header.Hash.Write(logs[0].BlockHash)
tx := b.Tx(k.b)
tx.PrecompHash.Write(logs[0].TxHash)
for i := range logs {
tx.Logs.Add(logs[i].Log)
}
b.Unlock()
}
slog.Debug("http get logs",
"start", start,
Expand Down
11 changes: 7 additions & 4 deletions jrpc2/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,21 +352,24 @@ func TestGet_Cached(t *testing.T) {
var (
ctx = context.Background()
c = New(ts.URL)
findTx = func(b eth.Block, idx uint64) (eth.Tx, error) {
findTx = func(b *eth.Block, idx uint64) (*eth.Tx, error) {
for i := range b.Txs {
if b.Txs[i].Idx == eth.Uint64(idx) {
return b.Txs[i], nil
return &b.Txs[i], nil
}
}
return eth.Tx{}, fmt.Errorf("no tx at idx %d", idx)
return nil, fmt.Errorf("no tx at idx %d", idx)
}
getcall = func() error {
blocks, err := c.Get(ctx, &glf.Filter{UseHeaders: true, UseLogs: true}, 18000000, 1)
diff.Test(t, t.Errorf, nil, err)

blocks[0].Lock()
diff.Test(t, t.Errorf, len(blocks[0].Txs), 65)
tx, err := findTx(blocks[0], 0)
tx, err := findTx(&blocks[0], 0)
diff.Test(t, t.Errorf, nil, err)
diff.Test(t, t.Errorf, len(tx.Logs), 1)
blocks[0].Unlock()
return nil
}
)
Expand Down

0 comments on commit d08c9fc

Please sign in to comment.