Skip to content

Commit

Permalink
Merge pull request from GHSA-p228-4mrh-ww7r
Browse files Browse the repository at this point in the history
SearchFirst call when Peek does not found data for smart contract results
  • Loading branch information
iulianpascalau authored Nov 20, 2022
2 parents 49512c4 + 32abcc9 commit 8971ba7
Show file tree
Hide file tree
Showing 8 changed files with 540 additions and 100 deletions.
8 changes: 6 additions & 2 deletions process/block/preprocess/basePreProcess.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,11 @@ func (bpp *basePreProcess) computeExistingAndRequestMissing(
}

txShardInfoObject := &txShardInfo{senderShardID: miniBlock.SenderShardID, receiverShardID: miniBlock.ReceiverShardID}
searchFirst := miniBlock.Type == block.InvalidBlock
method := process.SearchMethodJustPeek
if miniBlock.Type == block.InvalidBlock {
method = process.SearchMethodSearchFirst
}

for j := 0; j < len(miniBlock.TxHashes); j++ {
txHash := miniBlock.TxHashes[j]

Expand All @@ -338,7 +342,7 @@ func (bpp *basePreProcess) computeExistingAndRequestMissing(
miniBlock.ReceiverShardID,
txHash,
txPool,
searchFirst)
method)

if err != nil {
txHashes = append(txHashes, txHash)
Expand Down
2 changes: 1 addition & 1 deletion process/block/preprocess/rewardTxPreProcessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ func (rtp *rewardTxPreprocessor) computeMissingRewardTxsForMiniBlock(miniBlock *
miniBlock.ReceiverShardID,
txHash,
rtp.rewardTxPool,
false,
process.SearchMethodJustPeek,
)

if tx == nil {
Expand Down
11 changes: 9 additions & 2 deletions process/block/preprocess/smartContractResults.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func (scr *smartContractResults) computeMissingScrsForMiniBlock(miniBlock *block
miniBlock.ReceiverShardID,
txHash,
scr.scrPool,
false)
process.SearchMethodPeekWithFallbackSearchFirst)

if check.IfNil(tx) {
missingSmartContractResults = append(missingSmartContractResults, txHash)
Expand Down Expand Up @@ -484,7 +484,14 @@ func (scr *smartContractResults) getAllScrsFromMiniBlock(

tmp, _ := txCache.Peek(txHash)
if tmp == nil {
return nil, nil, process.ErrNilSmartContractResult
tmp, _ = scr.scrPool.SearchFirstData(txHash)
if tmp == nil {
return nil, nil, process.ErrNilSmartContractResult
}

log.Debug("scr hash not found with Peek method but found with SearchFirstData",
"scr hash", txHash,
"strCache", strCache)
}

tx, ok := tmp.(*smartContractResult.SmartContractResult)
Expand Down
97 changes: 96 additions & 1 deletion process/block/preprocess/smartContractResults_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,13 @@ func TestScrsPreProcessor_GetTransactionFromPool(t *testing.T) {
)

txHash := []byte("tx1_hash")
tx, _ := process.GetTransactionHandlerFromPool(1, 1, txHash, tdp.UnsignedTransactions(), false)
tx, _ := process.GetTransactionHandlerFromPool(
1,
1,
txHash,
tdp.UnsignedTransactions(),
process.SearchMethodPeekWithFallbackSearchFirst,
)
assert.NotNil(t, txs)
assert.NotNil(t, tx)
assert.Equal(t, uint64(10), tx.(*smartContractResult.SmartContractResult).Nonce)
Expand Down Expand Up @@ -752,6 +758,95 @@ func TestScrsPreprocessor_GetAllTxsFromMiniBlockShouldWork(t *testing.T) {
}
}

func TestScrsPreprocessor_GetAllTxsFromMiniBlockShouldWorkEvenIfScrIsMisplaced(t *testing.T) {
t.Parallel()

hasher := &hashingMocks.HasherMock{}
marshalizer := &mock.MarshalizerMock{}
dataPool := dataRetrieverMock.NewPoolsHolderMock()
senderShardId := uint32(0)
destinationShardId := uint32(1)

txsSlice := []*smartContractResult.SmartContractResult{
{Nonce: 1},
{Nonce: 2},
{Nonce: 3},
}
transactionsHashes := make([][]byte, len(txsSlice))

// add defined transactions to sender-destination cacher
for idx, tx := range txsSlice {
transactionsHashes[idx] = computeHash(tx, marshalizer, hasher)

if idx < len(txsSlice)-1 {
// place the first scrs correctly in pool
dataPool.UnsignedTransactions().AddData(
transactionsHashes[idx],
tx,
tx.Size(),
process.ShardCacherIdentifier(senderShardId, destinationShardId),
)
} else {
// misplace the last one
dataPool.UnsignedTransactions().AddData(
transactionsHashes[idx],
tx,
tx.Size(),
process.ShardCacherIdentifier(senderShardId, senderShardId), // only in shard 0
)
}
}

// add some random data
txRandom := &smartContractResult.SmartContractResult{Nonce: 4}
dataPool.UnsignedTransactions().AddData(
computeHash(txRandom, marshalizer, hasher),
txRandom,
txRandom.Size(),
process.ShardCacherIdentifier(3, 4),
)

requestTransaction := func(shardID uint32, txHashes [][]byte) {}
txs, _ := NewSmartContractResultPreprocessor(
dataPool.UnsignedTransactions(),
&mock.ChainStorerMock{},
&hashingMocks.HasherMock{},
&mock.MarshalizerMock{},
&testscommon.TxProcessorMock{},
mock.NewMultiShardsCoordinatorMock(3),
&stateMock.AccountsStub{},
requestTransaction,
&testscommon.GasHandlerStub{},
feeHandlerMock(),
createMockPubkeyConverter(),
&testscommon.BlockSizeComputationStub{},
&testscommon.BalanceComputationStub{},
&epochNotifier.EpochNotifierStub{},
2,
&testscommon.ProcessedMiniBlocksTrackerStub{},
)

mb := &block.MiniBlock{
SenderShardID: senderShardId,
ReceiverShardID: destinationShardId,
TxHashes: transactionsHashes,
Type: block.SmartContractResultBlock,
}

txsRetrieved, txHashesRetrieved, err := txs.getAllScrsFromMiniBlock(mb, haveTimeTrue)

assert.Nil(t, err)
assert.Equal(t, len(txsSlice), len(txsRetrieved))
assert.Equal(t, len(txsSlice), len(txHashesRetrieved))

for idx, tx := range txsSlice {
// txReceived should be all txs in the same order
assert.Equal(t, txsRetrieved[idx], tx)
// verify corresponding transaction hashes
assert.Equal(t, txHashesRetrieved[idx], computeHash(tx, marshalizer, hasher))
}
}

func TestScrsPreprocessor_RemoveBlockDataFromPoolsNilBlockShouldErr(t *testing.T) {
t.Parallel()

Expand Down
14 changes: 10 additions & 4 deletions process/block/preprocess/transactions.go
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,10 @@ func (txs *transactions) AddTxsFromMiniBlocks(miniBlocks block.MiniBlockSlice) {
}

txShardInfoToSet := &txShardInfo{senderShardID: mb.SenderShardID, receiverShardID: mb.ReceiverShardID}
searchFirst := mb.Type == block.InvalidBlock
method := process.SearchMethodJustPeek
if mb.Type == block.InvalidBlock {
method = process.SearchMethodSearchFirst
}

for _, txHash := range mb.TxHashes {
tx, err := process.GetTransactionHandler(
Expand All @@ -810,7 +813,7 @@ func (txs *transactions) AddTxsFromMiniBlocks(miniBlocks block.MiniBlockSlice) {
txs.txPool,
txs.storage,
txs.marshalizer,
searchFirst,
method,
)
if err != nil {
log.Debug("transactions.AddTxsFromMiniBlocks: GetTransactionHandler", "tx hash", txHash, "error", err.Error())
Expand Down Expand Up @@ -948,15 +951,18 @@ func (txs *transactions) computeMissingTxsForMiniBlock(miniBlock *block.MiniBloc
}

missingTransactions := make([][]byte, 0, len(miniBlock.TxHashes))
searchFirst := txs.blockType == block.InvalidBlock
method := process.SearchMethodJustPeek
if txs.blockType == block.InvalidBlock {
method = process.SearchMethodSearchFirst
}

for _, txHash := range miniBlock.TxHashes {
tx, _ := process.GetTransactionHandlerFromPool(
miniBlock.SenderShardID,
miniBlock.ReceiverShardID,
txHash,
txs.txPool,
searchFirst)
method)

if tx == nil || tx.IsInterfaceNil() {
missingTransactions = append(missingTransactions, txHash)
Expand Down
7 changes: 6 additions & 1 deletion process/block/preprocess/transactions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,12 @@ func TestTxsPreProcessor_GetTransactionFromPool(t *testing.T) {
dataPool := initDataPool()
txs := createGoodPreprocessor(dataPool)
txHash := []byte("tx2_hash")
tx, _ := process.GetTransactionHandlerFromPool(1, 1, txHash, dataPool.Transactions(), false)
tx, _ := process.GetTransactionHandlerFromPool(
1,
1,
txHash,
dataPool.Transactions(),
process.SearchMethodJustPeek)
assert.NotNil(t, txs)
assert.NotNil(t, tx)
assert.Equal(t, uint64(10), tx.(*transaction.Transaction).Nonce)
Expand Down
84 changes: 69 additions & 15 deletions process/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,35 @@ import (

var log = logger.GetOrCreate("process")

// ShardedCacheSearchMethod defines the algorithm for searching through a sharded cache
type ShardedCacheSearchMethod byte

const (
// SearchMethodJustPeek will make the algorithm invoke just Peek method
SearchMethodJustPeek ShardedCacheSearchMethod = iota

// SearchMethodSearchFirst will make the algorithm invoke just SearchFirst method
SearchMethodSearchFirst

// SearchMethodPeekWithFallbackSearchFirst will first try a Peek method. If the data is not found will fall back
// to SearchFirst method
SearchMethodPeekWithFallbackSearchFirst
)

// ToString converts the ShardedCacheSearchMethod to its string representation
func (method ShardedCacheSearchMethod) ToString() string {
switch method {
case SearchMethodJustPeek:
return "just peek"
case SearchMethodSearchFirst:
return "search first"
case SearchMethodPeekWithFallbackSearchFirst:
return "peek with fallback to search first"
default:
return fmt.Sprintf("unknown method %d", method)
}
}

// GetShardHeader gets the header, which is associated with the given hash, from pool or storage
func GetShardHeader(
hash []byte,
Expand Down Expand Up @@ -361,15 +390,15 @@ func GetTransactionHandler(
shardedDataCacherNotifier dataRetriever.ShardedDataCacherNotifier,
storageService dataRetriever.StorageService,
marshalizer marshal.Marshalizer,
searchFirst bool,
method ShardedCacheSearchMethod,
) (data.TransactionHandler, error) {

err := checkGetTransactionParamsForNil(shardedDataCacherNotifier, storageService, marshalizer)
if err != nil {
return nil, err
}

tx, err := GetTransactionHandlerFromPool(senderShardID, destShardID, txHash, shardedDataCacherNotifier, searchFirst)
tx, err := GetTransactionHandlerFromPool(senderShardID, destShardID, txHash, shardedDataCacherNotifier, method)
if err != nil {
tx, err = GetTransactionHandlerFromStorage(txHash, storageService, marshalizer)
if err != nil {
Expand All @@ -386,30 +415,55 @@ func GetTransactionHandlerFromPool(
destShardID uint32,
txHash []byte,
shardedDataCacherNotifier dataRetriever.ShardedDataCacherNotifier,
searchFirst bool,
method ShardedCacheSearchMethod,
) (data.TransactionHandler, error) {

if shardedDataCacherNotifier == nil {
if check.IfNil(shardedDataCacherNotifier) {
return nil, ErrNilShardedDataCacherNotifier
}

return getTransactionHandlerFromPool(senderShardID, destShardID, txHash, shardedDataCacherNotifier, method)
}

func getTransactionHandlerFromPool(
senderShardID uint32,
destShardID uint32,
txHash []byte,
shardedDataCacherNotifier dataRetriever.ShardedDataCacherNotifier,
method ShardedCacheSearchMethod,
) (data.TransactionHandler, error) {
var val interface{}
ok := false
if searchFirst {
var ok bool

if method == SearchMethodSearchFirst {
val, ok = shardedDataCacherNotifier.SearchFirstData(txHash)
if !ok {
return nil, ErrTxNotFound
}
} else {
strCache := ShardCacherIdentifier(senderShardID, destShardID)
txStore := shardedDataCacherNotifier.ShardDataStore(strCache)
if txStore == nil {
return nil, ErrNilStorage
}

return castDataFromCacheAsTransactionHandler(val, ok)
}

strCache := ShardCacherIdentifier(senderShardID, destShardID)
txStore := shardedDataCacherNotifier.ShardDataStore(strCache)
if txStore == nil {
return nil, ErrNilStorage
}

switch method {
case SearchMethodJustPeek:
val, ok = txStore.Peek(txHash)
case SearchMethodPeekWithFallbackSearchFirst:
val, ok = txStore.Peek(txHash)
if !ok {
val, ok = shardedDataCacherNotifier.SearchFirstData(txHash)
}
default:
return nil, fmt.Errorf("%w for provided method: %s in getTransactionHandlerFromPool",
ErrInvalidValue, method.ToString())
}

return castDataFromCacheAsTransactionHandler(val, ok)
}

func castDataFromCacheAsTransactionHandler(val interface{}, ok bool) (data.TransactionHandler, error) {
if !ok {
return nil, ErrTxNotFound
}
Expand Down
Loading

0 comments on commit 8971ba7

Please sign in to comment.