From 101d965c84fa105d54e3f52f6c552d62e06daa27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Men=C3=A9ndez?= Date: Thu, 23 Nov 2023 18:42:49 +0100 Subject: [PATCH] bug: SIKFromAddress data race (#1199) * use treeTxWithMutex to prevent data races on sik operations --- vochain/state/sik.go | 46 +++++++++++++++++++---------------- vochain/state/sik_test.go | 50 +++++++++++++++++++++++++++++++++++++++ vochain/state/state.go | 4 ++-- 3 files changed, 78 insertions(+), 22 deletions(-) diff --git a/vochain/state/sik.go b/vochain/state/sik.go index 194f6996d..ac6d70723 100644 --- a/vochain/state/sik.go +++ b/vochain/state/sik.go @@ -43,7 +43,14 @@ type SIK []byte // SIKFromAddress function return the current SIK value associated to the provided // address. func (v *State) SIKFromAddress(address common.Address) (SIK, error) { - sik, err := v.mainTreeViewer(false).DeepGet(address.Bytes(), StateTreeCfg(TreeSIK)) + v.tx.RLock() + defer v.tx.RUnlock() + + siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK)) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrSIKSubTree, err) + } + sik, err := siksTree.Get(address.Bytes()) if err != nil { if errors.Is(err, arbo.ErrKeyNotFound) { return nil, fmt.Errorf("%w: %w", ErrSIKNotFound, err) @@ -62,24 +69,21 @@ func (v *State) SIKFromAddress(address common.Address) (SIK, error) { // - If it exists but it is not valid, overwrite the stored value with the // provided one. func (v *State) SetAddressSIK(address common.Address, newSIK SIK) error { + v.tx.Lock() + defer v.tx.Unlock() siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK)) if err != nil { return fmt.Errorf("%w: %w", ErrSIKSubTree, err) } // check if exists a registered sik for the provided address, query also for // no committed tree version - v.tx.Lock() rawSIK, err := siksTree.Get(address.Bytes()) - v.tx.Unlock() if errors.Is(err, arbo.ErrKeyNotFound) { // if not exists create it log.Debugw("setSIK (create)", "address", address.String(), "sik", newSIK.String()) - v.tx.Lock() - err = siksTree.Add(address.Bytes(), newSIK) - v.tx.Unlock() - if err != nil { + if err := siksTree.Add(address.Bytes(), newSIK); err != nil { return fmt.Errorf("%w: %w", ErrSIKSet, err) } return nil @@ -95,10 +99,7 @@ func (v *State) SetAddressSIK(address common.Address, newSIK SIK) error { "address", address.String(), "sik", SIK(rawSIK).String()) // if the hysteresis is reached update the sik for the address - v.tx.Lock() - err = siksTree.Set(address.Bytes(), newSIK) - v.tx.Unlock() - if err != nil { + if err := siksTree.Set(address.Bytes(), newSIK); err != nil { return fmt.Errorf("%w: %w", ErrSIKSet, err) } return nil @@ -110,8 +111,14 @@ func (v *State) SetAddressSIK(address common.Address, newSIK SIK) error { // prevent it from being updated until all processes created before that height // have finished. func (v *State) InvalidateSIK(address common.Address) error { + v.tx.Lock() + defer v.tx.Unlock() + siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK)) + if err != nil { + return fmt.Errorf("%w: %w", ErrSIKSubTree, err) + } // if the sik does not exists or something fails querying return the error - rawSIK, err := v.mainTreeViewer(false).DeepGet(address.Bytes(), StateTreeCfg(TreeSIK)) + rawSIK, err := siksTree.Get(address.Bytes()) if err != nil { return fmt.Errorf("%w: %w", ErrSIKGet, err) } @@ -119,11 +126,8 @@ func (v *State) InvalidateSIK(address common.Address) error { if !SIK(rawSIK).Valid() { return ErrSIKAlreadyInvalid } - v.tx.Lock() invalidatedSIK := make(SIK, sikLeafValueLen).InvalidateAt(v.CurrentHeight()) - err = v.tx.DeepSet(address.Bytes(), invalidatedSIK, StateTreeCfg(TreeSIK)) - v.tx.Unlock() - if err != nil { + if err := siksTree.Set(address.Bytes(), invalidatedSIK); err != nil { return fmt.Errorf("%w: %w", ErrSIKDelete, err) } return nil @@ -187,6 +191,8 @@ func (v *State) UpdateSIKRoots() error { currentBlock := v.CurrentHeight() // get sik roots key-value database associated to the siks tree + v.tx.RLock() + defer v.tx.RUnlock() siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK)) if err != nil { return fmt.Errorf("%w: %w", ErrSIKSubTree, err) @@ -370,8 +376,8 @@ func (v *State) PurgeSIKsByElection(pid []byte) error { // SIKGenProof returns the proof of the provided address in the SIKs tree. // The first returned value is the leaf value and the second the proof siblings. func (v *State) SIKGenProof(address common.Address) ([]byte, []byte, error) { - v.tx.Lock() - defer v.tx.Unlock() + v.tx.RLock() + defer v.tx.RUnlock() siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK)) if err != nil { return nil, nil, fmt.Errorf("%w: %w", ErrSIKSubTree, err) @@ -382,8 +388,8 @@ func (v *State) SIKGenProof(address common.Address) ([]byte, []byte, error) { // SIKRoot returns the last root hash of the SIK merkle tree. func (v *State) SIKRoot() ([]byte, error) { - v.tx.Lock() - defer v.tx.Unlock() + v.tx.RLock() + defer v.tx.RUnlock() siksTree, err := v.tx.DeepSubTree(StateTreeCfg(TreeSIK)) if err != nil { v.tx.Unlock() diff --git a/vochain/state/sik_test.go b/vochain/state/sik_test.go index bf4d323b4..77a335cfc 100644 --- a/vochain/state/sik_test.go +++ b/vochain/state/sik_test.go @@ -3,6 +3,8 @@ package state import ( "bytes" "encoding/hex" + "errors" + "sync" "testing" "github.com/ethereum/go-ethereum/common" @@ -205,3 +207,51 @@ func Test_validSIK(t *testing.T) { input, _ = hex.DecodeString("F3668000B66c61aAa08aBC559a8C78Ae7E007C2e") qt.Assert(t, SIK(input).Valid(), qt.IsTrue) } + +func TestSIKDataRace(t *testing.T) { + c := qt.New(t) + // create a state for testing + dir := t.TempDir() + s, err := NewState(db.TypePebble, dir) + qt.Assert(t, err, qt.IsNil) + + // create some siks + addrs := []common.Address{} + siks := map[common.Address]SIK{} + for i := 0; i < 10000; i++ { + s := ethereum.NewSignKeys() + c.Assert(s.Generate(), qt.IsNil) + sik, err := s.AccountSIK(nil) + c.Assert(err, qt.IsNil) + addrs = append(addrs, s.Address()) + siks[s.Address()] = sik + } + + wg := &sync.WaitGroup{} + iterations := len(addrs) * 10 + + wg.Add(2) + go func() { + defer wg.Done() + + for i := 0; i < iterations; i++ { + idx := util.RandomInt(0, len(addrs)-1) + addr := addrs[idx] + if err := s.SetAddressSIK(addr, siks[addr]); err != nil { + c.Assert(errors.Is(err, ErrRegisteredValidSIK), qt.IsTrue) + } + } + }() + go func() { + defer wg.Done() + + idx := util.RandomInt(0, len(addrs)-1) + addr := addrs[idx] + if _, err := s.SIKFromAddress(addr); err != nil { + c.Assert(errors.Is(err, ErrSIKNotFound), qt.IsTrue) + return + } + c.Assert(s.InvalidateSIK(addr), qt.IsNil) + }() + wg.Wait() +} diff --git a/vochain/state/state.go b/vochain/state/state.go index 152f25bfc..e8f098dbd 100644 --- a/vochain/state/state.go +++ b/vochain/state/state.go @@ -348,12 +348,12 @@ func (v *State) Save() ([]byte, error) { // Commit the statedb tx // Note that we need to commit the tx after calling listeners, because // the listeners may need to get the previous (not committed) state. - v.tx.Lock() - defer v.tx.Unlock() // Update the SIK merkle-tree roots if err := v.UpdateSIKRoots(); err != nil { return nil, fmt.Errorf("cannot update SIK roots: %w", err) } + v.tx.Lock() + defer v.tx.Unlock() err := func() error { var err error if err := v.tx.SaveWithoutCommit(); err != nil {