Skip to content

Commit

Permalink
bug: SIKFromAddress data race (#1199)
Browse files Browse the repository at this point in the history
* use treeTxWithMutex to prevent data races on sik operations
  • Loading branch information
lucasmenendez authored Nov 23, 2023
1 parent cc54c3b commit 101d965
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 22 deletions.
46 changes: 26 additions & 20 deletions vochain/state/sik.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ type SIK []byte
// SIKFromAddress function return the current SIK value associated to the provided
// address.
func (v *State) SIKFromAddress(address common.Address) (SIK, error) {
sik, err := v.mainTreeViewer(false).DeepGet(address.Bytes(), StateTreeCfg(TreeSIK))
v.tx.RLock()
defer v.tx.RUnlock()

siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK))
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrSIKSubTree, err)
}
sik, err := siksTree.Get(address.Bytes())
if err != nil {
if errors.Is(err, arbo.ErrKeyNotFound) {
return nil, fmt.Errorf("%w: %w", ErrSIKNotFound, err)
Expand All @@ -62,24 +69,21 @@ func (v *State) SIKFromAddress(address common.Address) (SIK, error) {
// - If it exists but it is not valid, overwrite the stored value with the
// provided one.
func (v *State) SetAddressSIK(address common.Address, newSIK SIK) error {
v.tx.Lock()
defer v.tx.Unlock()
siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK))
if err != nil {
return fmt.Errorf("%w: %w", ErrSIKSubTree, err)
}
// check if exists a registered sik for the provided address, query also for
// no committed tree version
v.tx.Lock()
rawSIK, err := siksTree.Get(address.Bytes())
v.tx.Unlock()
if errors.Is(err, arbo.ErrKeyNotFound) {
// if not exists create it
log.Debugw("setSIK (create)",
"address", address.String(),
"sik", newSIK.String())
v.tx.Lock()
err = siksTree.Add(address.Bytes(), newSIK)
v.tx.Unlock()
if err != nil {
if err := siksTree.Add(address.Bytes(), newSIK); err != nil {
return fmt.Errorf("%w: %w", ErrSIKSet, err)
}
return nil
Expand All @@ -95,10 +99,7 @@ func (v *State) SetAddressSIK(address common.Address, newSIK SIK) error {
"address", address.String(),
"sik", SIK(rawSIK).String())
// if the hysteresis is reached update the sik for the address
v.tx.Lock()
err = siksTree.Set(address.Bytes(), newSIK)
v.tx.Unlock()
if err != nil {
if err := siksTree.Set(address.Bytes(), newSIK); err != nil {
return fmt.Errorf("%w: %w", ErrSIKSet, err)
}
return nil
Expand All @@ -110,20 +111,23 @@ func (v *State) SetAddressSIK(address common.Address, newSIK SIK) error {
// prevent it from being updated until all processes created before that height
// have finished.
func (v *State) InvalidateSIK(address common.Address) error {
v.tx.Lock()
defer v.tx.Unlock()
siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK))
if err != nil {
return fmt.Errorf("%w: %w", ErrSIKSubTree, err)
}
// if the sik does not exists or something fails querying return the error
rawSIK, err := v.mainTreeViewer(false).DeepGet(address.Bytes(), StateTreeCfg(TreeSIK))
rawSIK, err := siksTree.Get(address.Bytes())
if err != nil {
return fmt.Errorf("%w: %w", ErrSIKGet, err)
}
// if the stored sik is already invalidated return an error
if !SIK(rawSIK).Valid() {
return ErrSIKAlreadyInvalid
}
v.tx.Lock()
invalidatedSIK := make(SIK, sikLeafValueLen).InvalidateAt(v.CurrentHeight())
err = v.tx.DeepSet(address.Bytes(), invalidatedSIK, StateTreeCfg(TreeSIK))
v.tx.Unlock()
if err != nil {
if err := siksTree.Set(address.Bytes(), invalidatedSIK); err != nil {
return fmt.Errorf("%w: %w", ErrSIKDelete, err)
}
return nil
Expand Down Expand Up @@ -187,6 +191,8 @@ func (v *State) UpdateSIKRoots() error {
currentBlock := v.CurrentHeight()

// get sik roots key-value database associated to the siks tree
v.tx.RLock()
defer v.tx.RUnlock()
siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK))
if err != nil {
return fmt.Errorf("%w: %w", ErrSIKSubTree, err)
Expand Down Expand Up @@ -370,8 +376,8 @@ func (v *State) PurgeSIKsByElection(pid []byte) error {
// SIKGenProof returns the proof of the provided address in the SIKs tree.
// The first returned value is the leaf value and the second the proof siblings.
func (v *State) SIKGenProof(address common.Address) ([]byte, []byte, error) {
v.tx.Lock()
defer v.tx.Unlock()
v.tx.RLock()
defer v.tx.RUnlock()
siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK))
if err != nil {
return nil, nil, fmt.Errorf("%w: %w", ErrSIKSubTree, err)
Expand All @@ -382,8 +388,8 @@ func (v *State) SIKGenProof(address common.Address) ([]byte, []byte, error) {

// SIKRoot returns the last root hash of the SIK merkle tree.
func (v *State) SIKRoot() ([]byte, error) {
v.tx.Lock()
defer v.tx.Unlock()
v.tx.RLock()
defer v.tx.RUnlock()
siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK))
if err != nil {
v.tx.Unlock()
Expand Down
50 changes: 50 additions & 0 deletions vochain/state/sik_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package state
import (
"bytes"
"encoding/hex"
"errors"
"sync"
"testing"

"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -205,3 +207,51 @@ func Test_validSIK(t *testing.T) {
input, _ = hex.DecodeString("F3668000B66c61aAa08aBC559a8C78Ae7E007C2e")
qt.Assert(t, SIK(input).Valid(), qt.IsTrue)
}

func TestSIKDataRace(t *testing.T) {
c := qt.New(t)
// create a state for testing
dir := t.TempDir()
s, err := NewState(db.TypePebble, dir)
qt.Assert(t, err, qt.IsNil)

// create some siks
addrs := []common.Address{}
siks := map[common.Address]SIK{}
for i := 0; i < 10000; i++ {
s := ethereum.NewSignKeys()
c.Assert(s.Generate(), qt.IsNil)
sik, err := s.AccountSIK(nil)
c.Assert(err, qt.IsNil)
addrs = append(addrs, s.Address())
siks[s.Address()] = sik
}

wg := &sync.WaitGroup{}
iterations := len(addrs) * 10

wg.Add(2)
go func() {
defer wg.Done()

for i := 0; i < iterations; i++ {
idx := util.RandomInt(0, len(addrs)-1)
addr := addrs[idx]
if err := s.SetAddressSIK(addr, siks[addr]); err != nil {
c.Assert(errors.Is(err, ErrRegisteredValidSIK), qt.IsTrue)
}
}
}()
go func() {
defer wg.Done()

idx := util.RandomInt(0, len(addrs)-1)
addr := addrs[idx]
if _, err := s.SIKFromAddress(addr); err != nil {
c.Assert(errors.Is(err, ErrSIKNotFound), qt.IsTrue)
return
}
c.Assert(s.InvalidateSIK(addr), qt.IsNil)
}()
wg.Wait()
}
4 changes: 2 additions & 2 deletions vochain/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,12 +348,12 @@ func (v *State) Save() ([]byte, error) {
// Commit the statedb tx
// Note that we need to commit the tx after calling listeners, because
// the listeners may need to get the previous (not committed) state.
v.tx.Lock()
defer v.tx.Unlock()
// Update the SIK merkle-tree roots
if err := v.UpdateSIKRoots(); err != nil {
return nil, fmt.Errorf("cannot update SIK roots: %w", err)
}
v.tx.Lock()
defer v.tx.Unlock()
err := func() error {
var err error
if err := v.tx.SaveWithoutCommit(); err != nil {
Expand Down

0 comments on commit 101d965

Please sign in to comment.