diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index 0e68a73266d..688f688b7e2 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -951,3 +951,14 @@ # MaxRoundsOfInactivityAccepted defines the number of rounds missed by a main or higher level backup machine before # the current machine will take over and propose/sign blocks. Used in both single-key and multi-key modes. MaxRoundsOfInactivityAccepted = 3 + +# ConsensusGradualBroadcast defines how validators will broadcast the aggregated final info, based on their consensus index +[ConsensusGradualBroadcast] + GradualIndexBroadcastDelay = [ + # All validators will broadcast the message right away + { EndIndex = 0, DelayInMilliseconds = 0 }, + ] + +[InterceptedDataVerifier] + CacheSpanInSec = 30 + CacheExpiryInSec = 30 diff --git a/cmd/node/config/enableEpochs.toml b/cmd/node/config/enableEpochs.toml index 91d5963fcda..87f4a6b9a09 100644 --- a/cmd/node/config/enableEpochs.toml +++ b/cmd/node/config/enableEpochs.toml @@ -318,6 +318,12 @@ # CryptoOpcodesV2EnableEpoch represents the epoch when BLSMultiSig, Secp256r1 and other opcodes are enabled CryptoOpcodesV2EnableEpoch = 4 + # EquivalentMessagesEnableEpoch represents the epoch when the equivalent messages are enabled + EquivalentMessagesEnableEpoch = 4 + + # FixedOrderInConsensusEnableEpoch represents the epoch when the fixed order in consensus is enabled + FixedOrderInConsensusEnableEpoch = 4 + # BLSMultiSignerEnableEpoch represents the activation epoch for different types of BLS multi-signers BLSMultiSignerEnableEpoch = [ { EnableEpoch = 0, Type = "no-KOSK" }, diff --git a/cmd/node/factory/interface.go b/cmd/node/factory/interface.go index 21c74696087..8f90ce3ee89 100644 --- a/cmd/node/factory/interface.go +++ b/cmd/node/factory/interface.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/p2p" ) @@ -14,6 +15,8 @@ type HeaderSigVerifierHandler interface { VerifyLeaderSignature(header data.HeaderHandler) error VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error VerifySignature(header data.HeaderHandler) error + VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderProof(headerProof data.HeaderProofHandler) error IsInterfaceNil() bool } diff --git a/common/common.go b/common/common.go new file mode 100644 index 00000000000..d5624d7777a --- /dev/null +++ b/common/common.go @@ -0,0 +1,21 @@ +package common + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" +) + +// IsEpochChangeBlockForFlagActivation returns true if the provided header is the first one after the specified flag's activation +func IsEpochChangeBlockForFlagActivation(header data.HeaderHandler, enableEpochsHandler EnableEpochsHandler, flag core.EnableEpochFlag) bool { + isStartOfEpochBlock := header.IsStartOfEpochBlock() + isBlockInActivationEpoch := header.GetEpoch() == enableEpochsHandler.GetActivationEpoch(flag) + + return isStartOfEpochBlock && isBlockInActivationEpoch +} + +// IsFlagEnabledAfterEpochsStartBlock returns true if the flag is enabled for the header, but it is not the epoch start block +func IsFlagEnabledAfterEpochsStartBlock(header data.HeaderHandler, enableEpochsHandler EnableEpochsHandler, flag core.EnableEpochFlag) bool { + isFlagEnabled := enableEpochsHandler.IsFlagEnabledInEpoch(flag, header.GetEpoch()) + 
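For reference, a minimal standalone sketch of the decision encoded by the two helpers in common/common.go above, with the header and the EnableEpochsHandler collapsed to plain values. The function and parameter names below are illustrative only, and the `epoch >= activationEpoch` stand-in assumes the simple ">= activation epoch" semantics shown for the new flags in enableEpochsHandler.go; it is not the generic handler implementation.

package main

import "fmt"

// isEpochChangeBlockForFlagActivation mirrors the first helper: the block is the
// epoch-start block of the exact epoch in which the flag activates.
func isEpochChangeBlockForFlagActivation(epoch uint32, isStartOfEpochBlock bool, activationEpoch uint32) bool {
	return isStartOfEpochBlock && epoch == activationEpoch
}

// isFlagEnabledAfterEpochStartBlock mirrors the second helper: the flag is active
// in the block's epoch, but the block is not the activation epoch's start block.
func isFlagEnabledAfterEpochStartBlock(epoch uint32, isStartOfEpochBlock bool, activationEpoch uint32) bool {
	isFlagEnabled := epoch >= activationEpoch // simplified stand-in for IsFlagEnabledInEpoch
	isActivationStartBlock := isEpochChangeBlockForFlagActivation(epoch, isStartOfEpochBlock, activationEpoch)
	return isFlagEnabled && !isActivationStartBlock
}

func main() {
	const activationEpoch = 4
	fmt.Println(isFlagEnabledAfterEpochStartBlock(4, true, activationEpoch))  // false: activation epoch's start block
	fmt.Println(isFlagEnabledAfterEpochStartBlock(4, false, activationEpoch)) // true: later block in the activation epoch
	fmt.Println(isFlagEnabledAfterEpochStartBlock(5, true, activationEpoch))  // true: start block of a later epoch
}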
isEpochStartBlock := IsEpochChangeBlockForFlagActivation(header, enableEpochsHandler, flag) + return isFlagEnabled && !isEpochStartBlock +} diff --git a/common/constants.go b/common/constants.go index def14b6e316..4f9ac681316 100644 --- a/common/constants.go +++ b/common/constants.go @@ -94,6 +94,9 @@ const ConnectionTopic = "connection" // ValidatorInfoTopic is the topic used for validatorInfo signaling const ValidatorInfoTopic = "validatorInfo" +// EquivalentProofsTopic is the topic used for equivalent proofs +const EquivalentProofsTopic = "equivalentProofs" + // MetricCurrentRound is the metric for monitoring the current round of a node const MetricCurrentRound = "erd_current_round" @@ -840,8 +843,10 @@ const ( ChainParametersOrder // NodesCoordinatorOrder defines the order in which NodesCoordinator is notified of a start of epoch event NodesCoordinatorOrder - // ConsensusOrder defines the order in which Consensus is notified of a start of epoch event - ConsensusOrder + // ConsensusHandlerOrder defines the order in which ConsensusHandler is notified of a start of epoch event + ConsensusHandlerOrder + // ConsensusStartRoundOrder defines the order in which Consensus StartRound subround is notified of a start of epoch event + ConsensusStartRoundOrder // NetworkShardingOrder defines the order in which the network sharding subsystem is notified of a start of epoch event NetworkShardingOrder // IndexerOrder defines the order in which indexer is notified of a start of epoch event @@ -1222,5 +1227,7 @@ const ( DynamicESDTFlag core.EnableEpochFlag = "DynamicEsdtFlag" EGLDInESDTMultiTransferFlag core.EnableEpochFlag = "EGLDInESDTMultiTransferFlag" CryptoOpcodesV2Flag core.EnableEpochFlag = "CryptoOpcodesV2Flag" + EquivalentMessagesFlag core.EnableEpochFlag = "EquivalentMessagesFlag" + FixedOrderInConsensusFlag core.EnableEpochFlag = "FixedOrderInConsensusFlag" // all new flags must be added to createAllFlagsMap method, as part of enableEpochsHandler allFlagsDefined ) diff --git a/common/enablers/enableEpochsHandler.go b/common/enablers/enableEpochsHandler.go index 24962e09030..dac7e1aba6b 100644 --- a/common/enablers/enableEpochsHandler.go +++ b/common/enablers/enableEpochsHandler.go @@ -750,6 +750,18 @@ func (handler *enableEpochsHandler) createAllFlagsMap() { }, activationEpoch: handler.enableEpochsConfig.CryptoOpcodesV2EnableEpoch, }, + common.EquivalentMessagesFlag: { + isActiveInEpoch: func(epoch uint32) bool { + return epoch >= handler.enableEpochsConfig.EquivalentMessagesEnableEpoch + }, + activationEpoch: handler.enableEpochsConfig.EquivalentMessagesEnableEpoch, + }, + common.FixedOrderInConsensusFlag: { + isActiveInEpoch: func(epoch uint32) bool { + return epoch >= handler.enableEpochsConfig.FixedOrderInConsensusEnableEpoch + }, + activationEpoch: handler.enableEpochsConfig.FixedOrderInConsensusEnableEpoch, + }, } } diff --git a/common/enablers/enableEpochsHandler_test.go b/common/enablers/enableEpochsHandler_test.go index 524fe924771..a1b47200647 100644 --- a/common/enablers/enableEpochsHandler_test.go +++ b/common/enablers/enableEpochsHandler_test.go @@ -115,10 +115,12 @@ func createEnableEpochsConfig() config.EnableEpochs { StakingV4Step2EnableEpoch: 98, StakingV4Step3EnableEpoch: 99, AlwaysMergeContextsInEEIEnableEpoch: 100, - CleanupAuctionOnLowWaitingListEnableEpoch: 101, + CleanupAuctionOnLowWaitingListEnableEpoch: 101, DynamicESDTEnableEpoch: 102, EGLDInMultiTransferEnableEpoch: 103, CryptoOpcodesV2EnableEpoch: 104, + EquivalentMessagesEnableEpoch: 105, + 
FixedOrderInConsensusEnableEpoch: 106, } } @@ -319,6 +321,8 @@ func TestEnableEpochsHandler_IsFlagEnabled(t *testing.T) { require.True(t, handler.IsFlagEnabled(common.StakingV4StartedFlag)) require.True(t, handler.IsFlagEnabled(common.AlwaysMergeContextsInEEIFlag)) require.True(t, handler.IsFlagEnabled(common.DynamicESDTFlag)) + require.True(t, handler.IsFlagEnabled(common.EquivalentMessagesFlag)) + require.True(t, handler.IsFlagEnabled(common.FixedOrderInConsensusFlag)) } func TestEnableEpochsHandler_GetActivationEpoch(t *testing.T) { @@ -438,6 +442,8 @@ func TestEnableEpochsHandler_GetActivationEpoch(t *testing.T) { require.Equal(t, cfg.DynamicESDTEnableEpoch, handler.GetActivationEpoch(common.DynamicESDTFlag)) require.Equal(t, cfg.EGLDInMultiTransferEnableEpoch, handler.GetActivationEpoch(common.EGLDInESDTMultiTransferFlag)) require.Equal(t, cfg.CryptoOpcodesV2EnableEpoch, handler.GetActivationEpoch(common.CryptoOpcodesV2Flag)) + require.Equal(t, cfg.EquivalentMessagesEnableEpoch, handler.GetActivationEpoch(common.EquivalentMessagesFlag)) + require.Equal(t, cfg.FixedOrderInConsensusEnableEpoch, handler.GetActivationEpoch(common.FixedOrderInConsensusFlag)) } func TestEnableEpochsHandler_IsInterfaceNil(t *testing.T) { diff --git a/config/config.go b/config/config.go index 8d33b87830f..9607c9dc330 100644 --- a/config/config.go +++ b/config/config.go @@ -225,9 +225,12 @@ type Config struct { Requesters RequesterConfig VMOutputCacher CacheConfig - PeersRatingConfig PeersRatingConfig - PoolsCleanersConfig PoolsCleanersConfig - Redundancy RedundancyConfig + PeersRatingConfig PeersRatingConfig + PoolsCleanersConfig PoolsCleanersConfig + Redundancy RedundancyConfig + ConsensusGradualBroadcast ConsensusGradualBroadcastConfig + + InterceptedDataVerifier InterceptedDataVerifierConfig } // PeersRatingConfig will hold settings related to peers rating @@ -667,3 +670,20 @@ type ChainParametersByEpochConfig struct { MetachainMinNumNodes uint32 Adaptivity bool } + +// IndexBroadcastDelay holds a pair of starting consensus index and the delay the nodes should wait before broadcasting final info +type IndexBroadcastDelay struct { + EndIndex int + DelayInMilliseconds uint64 +} + +// ConsensusGradualBroadcastConfig holds the configuration for the consensus final info gradual broadcast +type ConsensusGradualBroadcastConfig struct { + GradualIndexBroadcastDelay []IndexBroadcastDelay +} + +// InterceptedDataVerifierConfig holds the configuration for the intercepted data verifier +type InterceptedDataVerifierConfig struct { + CacheSpanInSec uint64 + CacheExpiryInSec uint64 +} diff --git a/config/epochConfig.go b/config/epochConfig.go index f53f2078f6d..c197eb2e614 100644 --- a/config/epochConfig.go +++ b/config/epochConfig.go @@ -117,6 +117,8 @@ type EnableEpochs struct { DynamicESDTEnableEpoch uint32 EGLDInMultiTransferEnableEpoch uint32 CryptoOpcodesV2EnableEpoch uint32 + EquivalentMessagesEnableEpoch uint32 + FixedOrderInConsensusEnableEpoch uint32 BLSMultiSignerEnableEpoch []MultiSignerConfig } diff --git a/config/tomlConfig_test.go b/config/tomlConfig_test.go index 98a30eb9431..89f6fedbb8d 100644 --- a/config/tomlConfig_test.go +++ b/config/tomlConfig_test.go @@ -162,6 +162,14 @@ func TestTomlParser(t *testing.T) { Redundancy: RedundancyConfig{ MaxRoundsOfInactivityAccepted: 3, }, + ConsensusGradualBroadcast: ConsensusGradualBroadcastConfig{ + GradualIndexBroadcastDelay: []IndexBroadcastDelay{ + { + EndIndex: 0, + DelayInMilliseconds: 0, + }, + }, + }, } testString := ` [GeneralSettings] @@ -268,6 +276,13 
@@ func TestTomlParser(t *testing.T) { # MaxRoundsOfInactivityAccepted defines the number of rounds missed by a main or higher level backup machine before # the current machine will take over and propose/sign blocks. Used in both single-key and multi-key modes. MaxRoundsOfInactivityAccepted = 3 + +# ConsensusGradualBroadcast defines how validators will broadcast the aggregated final info, based on their consensus index +[ConsensusGradualBroadcast] + GradualIndexBroadcastDelay = [ + # All validators will broadcast the message right away + { EndIndex = 0, DelayInMilliseconds = 0 }, + ] ` cfg := Config{} @@ -890,6 +905,12 @@ func TestEnableEpochConfig(t *testing.T) { # CryptoOpcodesV2EnableEpoch represents the epoch when BLSMultiSig, Secp256r1 and other opcodes are enabled CryptoOpcodesV2EnableEpoch = 98 + # EquivalentMessagesEnableEpoch represents the epoch when the equivalent messages are enabled + EquivalentMessagesEnableEpoch = 99 + + # FixedOrderInConsensusEnableEpoch represents the epoch when the fixed order in consensus is enabled + FixedOrderInConsensusEnableEpoch = 100 + # MaxNodesChangeEnableEpoch holds configuration for changing the maximum number of nodes and the enabling epoch MaxNodesChangeEnableEpoch = [ { EpochEnable = 44, MaxNumNodes = 2169, NodesToShufflePerShard = 80 }, @@ -1006,6 +1027,8 @@ func TestEnableEpochConfig(t *testing.T) { DynamicESDTEnableEpoch: 96, EGLDInMultiTransferEnableEpoch: 97, CryptoOpcodesV2EnableEpoch: 98, + EquivalentMessagesEnableEpoch: 99, + FixedOrderInConsensusEnableEpoch: 100, MaxNodesChangeEnableEpoch: []MaxNodesChangeConfig{ { EpochEnable: 44, diff --git a/consensus/broadcast/commonMessenger.go b/consensus/broadcast/commonMessenger.go index 60c59e01145..a584897e50f 100644 --- a/consensus/broadcast/commonMessenger.go +++ b/consensus/broadcast/commonMessenger.go @@ -7,41 +7,29 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/partitioning" - "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/sharding" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("consensus/broadcast") -// delayedBroadcaster exposes functionality for handling the consensus members broadcasting of delay data -type delayedBroadcaster interface { - SetLeaderData(data *delayedBroadcastData) error - SetValidatorData(data *delayedBroadcastData) error - SetHeaderForValidator(vData *validatorHeaderBroadcastData) error - SetBroadcastHandlers( - mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, - txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, - headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, - ) error - Close() -} - type commonMessenger struct { marshalizer marshal.Marshalizer hasher hashing.Hasher messenger consensus.P2PMessenger shardCoordinator sharding.Coordinator peerSignatureHandler crypto.PeerSignatureHandler - delayedBlockBroadcaster delayedBroadcaster + 
delayedBlockBroadcaster DelayedBroadcaster keysHandler consensus.KeysHandler } @@ -58,6 +46,7 @@ type CommonMessengerArgs struct { MaxValidatorDelayCacheSize uint32 AlarmScheduler core.TimersScheduler KeysHandler consensus.KeysHandler + DelayedBroadcaster DelayedBroadcaster } func checkCommonMessengerNilParameters( @@ -93,6 +82,9 @@ func checkCommonMessengerNilParameters( if check.IfNil(args.KeysHandler) { return ErrNilKeysHandler } + if check.IfNil(args.DelayedBroadcaster) { + return ErrNilDelayedBroadcaster + } return nil } @@ -195,6 +187,18 @@ func (cm *commonMessenger) BroadcastBlockData( } } +// PrepareBroadcastEquivalentProof sets the proof into the delayed block broadcaster +func (cm *commonMessenger) PrepareBroadcastEquivalentProof( + proof *block.HeaderProof, + consensusIndex int, + pkBytes []byte, +) { + err := cm.delayedBlockBroadcaster.SetFinalProofForValidator(proof, consensusIndex, pkBytes) + if err != nil { + log.Error("commonMessenger.PrepareBroadcastEquivalentProof", "error", err) + } +} + func (cm *commonMessenger) extractMetaMiniBlocksAndTransactions( miniBlocks map[uint32][]byte, transactions map[string][][]byte, @@ -241,3 +245,18 @@ func (cm *commonMessenger) broadcast(topic string, data []byte, pkBytes []byte) cm.messenger.BroadcastUsingPrivateKey(topic, data, pid, skBytes) } + +func (cm *commonMessenger) broadcastEquivalentProof(proof *block.HeaderProof, pkBytes []byte, topic string) error { + if check.IfNilReflect(proof) { + return spos.ErrNilHeaderProof + } + + msgProof, err := cm.marshalizer.Marshal(proof) + if err != nil { + return err + } + + cm.broadcast(topic, msgProof, pkBytes) + + return nil +} diff --git a/consensus/broadcast/delayedBroadcast.go b/consensus/broadcast/delayedBroadcast.go index 955a81f0f73..a1c94cf33d7 100644 --- a/consensus/broadcast/delayedBroadcast.go +++ b/consensus/broadcast/delayedBroadcast.go @@ -11,8 +11,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" @@ -23,6 +26,7 @@ import ( const prefixHeaderAlarm = "header_" const prefixDelayDataAlarm = "delay_" +const prefixConsensusMessageAlarm = "message_" const sizeHeadersCache = 1000 // 1000 hashes in cache // ArgsDelayedBlockBroadcaster holds the arguments to create a delayed block broadcaster @@ -33,25 +37,7 @@ type ArgsDelayedBlockBroadcaster struct { LeaderCacheSize uint32 ValidatorCacheSize uint32 AlarmScheduler timersScheduler -} - -type validatorHeaderBroadcastData struct { - headerHash []byte - header data.HeaderHandler - metaMiniBlocksData map[uint32][]byte - metaTransactionsData map[string][][]byte - order uint32 - pkBytes []byte -} - -type delayedBroadcastData struct { - headerHash []byte - header data.HeaderHandler - miniBlocksData map[uint32][]byte - miniBlockHashes map[string]map[string]struct{} - transactions map[string][][]byte - order uint32 - pkBytes []byte + Config config.ConsensusGradualBroadcastConfig } // timersScheduler exposes functionality for scheduling multiple timers @@ -67,22 +53,33 @@ type headerDataForValidator struct { prevRandSeed []byte } +type validatorProof struct { + proof *block.HeaderProof + 
pkBytes []byte +} + type delayedBlockBroadcaster struct { alarm timersScheduler interceptorsContainer process.InterceptorsContainer shardCoordinator sharding.Coordinator headersSubscriber consensus.HeadersPoolSubscriber - valHeaderBroadcastData []*validatorHeaderBroadcastData - valBroadcastData []*delayedBroadcastData - delayedBroadcastData []*delayedBroadcastData + valHeaderBroadcastData []*shared.ValidatorHeaderBroadcastData + valBroadcastData []*shared.DelayedBroadcastData + delayedBroadcastData []*shared.DelayedBroadcastData maxDelayCacheSize uint32 maxValidatorDelayCacheSize uint32 mutDataForBroadcast sync.RWMutex broadcastMiniblocksData func(mbData map[uint32][]byte, pkBytes []byte) error broadcastTxsData func(txData map[string][][]byte, pkBytes []byte) error broadcastHeader func(header data.HeaderHandler, pkBytes []byte) error + broadcastEquivalentProof func(proof *block.HeaderProof, pkBytes []byte) error + broadcastConsensusMessage func(message *consensus.Message) error cacheHeaders storage.Cacher mutHeadersCache sync.RWMutex + config config.ConsensusGradualBroadcastConfig + mutBroadcastFinalProof sync.RWMutex + valBroadcastFinalProof map[string]*validatorProof + cacheConsensusMessages storage.Cacher } // NewDelayedBlockBroadcaster create a new instance of a delayed block data broadcaster @@ -105,19 +102,27 @@ func NewDelayedBlockBroadcaster(args *ArgsDelayedBlockBroadcaster) (*delayedBloc return nil, err } + cacheConsensusMessages, err := cache.NewLRUCache(sizeHeadersCache) + if err != nil { + return nil, err + } + dbb := &delayedBlockBroadcaster{ alarm: args.AlarmScheduler, shardCoordinator: args.ShardCoordinator, interceptorsContainer: args.InterceptorsContainer, headersSubscriber: args.HeadersSubscriber, - valHeaderBroadcastData: make([]*validatorHeaderBroadcastData, 0), - valBroadcastData: make([]*delayedBroadcastData, 0), - delayedBroadcastData: make([]*delayedBroadcastData, 0), + valHeaderBroadcastData: make([]*shared.ValidatorHeaderBroadcastData, 0), + valBroadcastData: make([]*shared.DelayedBroadcastData, 0), + delayedBroadcastData: make([]*shared.DelayedBroadcastData, 0), + valBroadcastFinalProof: make(map[string]*validatorProof, 0), maxDelayCacheSize: args.LeaderCacheSize, maxValidatorDelayCacheSize: args.ValidatorCacheSize, mutDataForBroadcast: sync.RWMutex{}, cacheHeaders: cacheHeaders, mutHeadersCache: sync.RWMutex{}, + config: args.Config, + cacheConsensusMessages: cacheConsensusMessages, } dbb.headersSubscriber.RegisterHandler(dbb.headerReceived) @@ -135,22 +140,22 @@ func NewDelayedBlockBroadcaster(args *ArgsDelayedBlockBroadcaster) (*delayedBloc } // SetLeaderData sets the data for consensus leader delayed broadcast -func (dbb *delayedBlockBroadcaster) SetLeaderData(broadcastData *delayedBroadcastData) error { +func (dbb *delayedBlockBroadcaster) SetLeaderData(broadcastData *shared.DelayedBroadcastData) error { if broadcastData == nil { return spos.ErrNilParameter } log.Trace("delayedBlockBroadcaster.SetLeaderData: setting leader delay data", - "headerHash", broadcastData.headerHash, + "headerHash", broadcastData.HeaderHash, ) - dataToBroadcast := make([]*delayedBroadcastData, 0) + dataToBroadcast := make([]*shared.DelayedBroadcastData, 0) dbb.mutDataForBroadcast.Lock() dbb.delayedBroadcastData = append(dbb.delayedBroadcastData, broadcastData) if len(dbb.delayedBroadcastData) > int(dbb.maxDelayCacheSize) { log.Debug("delayedBlockBroadcaster.SetLeaderData: leader broadcasts old data before alarm due to too much delay data", - "headerHash", 
dbb.delayedBroadcastData[0].headerHash, + "headerHash", dbb.delayedBroadcastData[0].HeaderHash, "nbDelayedData", len(dbb.delayedBroadcastData), "maxDelayCacheSize", dbb.maxDelayCacheSize, ) @@ -167,14 +172,17 @@ func (dbb *delayedBlockBroadcaster) SetLeaderData(broadcastData *delayedBroadcas } // SetHeaderForValidator sets the header to be broadcast by validator if leader fails to broadcast it -func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *validatorHeaderBroadcastData) error { - if check.IfNil(vData.header) { +func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *shared.ValidatorHeaderBroadcastData) error { + if check.IfNil(vData.Header) { return spos.ErrNilHeader } - if len(vData.headerHash) == 0 { + if len(vData.HeaderHash) == 0 { return spos.ErrNilHeaderHash } + dbb.mutDataForBroadcast.Lock() + defer dbb.mutDataForBroadcast.Unlock() + log.Trace("delayedBlockBroadcaster.SetHeaderForValidator", "nbDelayedBroadcastData", len(dbb.delayedBroadcastData), "nbValBroadcastData", len(dbb.valBroadcastData), @@ -182,25 +190,25 @@ func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *validatorHeader ) // set alarm only for validators that are aware that the block was finalized - if len(vData.header.GetSignature()) != 0 { - _, alreadyReceived := dbb.cacheHeaders.Get(vData.headerHash) + if len(vData.Header.GetSignature()) != 0 { + _, alreadyReceived := dbb.cacheHeaders.Get(vData.HeaderHash) if alreadyReceived { return nil } - duration := validatorDelayPerOrder * time.Duration(vData.order) + duration := validatorDelayPerOrder * time.Duration(vData.Order) dbb.valHeaderBroadcastData = append(dbb.valHeaderBroadcastData, vData) - alarmID := prefixHeaderAlarm + hex.EncodeToString(vData.headerHash) + alarmID := prefixHeaderAlarm + hex.EncodeToString(vData.HeaderHash) dbb.alarm.Add(dbb.headerAlarmExpired, duration, alarmID) log.Trace("delayedBlockBroadcaster.SetHeaderForValidator: header alarm has been set", - "validatorConsensusOrder", vData.order, - "headerHash", vData.headerHash, + "validatorConsensusOrder", vData.Order, + "headerHash", vData.HeaderHash, "alarmID", alarmID, "duration", duration, ) } else { log.Trace("delayedBlockBroadcaster.SetHeaderForValidator: header alarm has not been set", - "validatorConsensusOrder", vData.order, + "validatorConsensusOrder", vData.Order, ) } @@ -208,29 +216,29 @@ func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *validatorHeader } // SetValidatorData sets the data for consensus validator delayed broadcast -func (dbb *delayedBlockBroadcaster) SetValidatorData(broadcastData *delayedBroadcastData) error { +func (dbb *delayedBlockBroadcaster) SetValidatorData(broadcastData *shared.DelayedBroadcastData) error { if broadcastData == nil { return spos.ErrNilParameter } alarmIDsToCancel := make([]string, 0) log.Trace("delayedBlockBroadcaster.SetValidatorData: setting validator delay data", - "headerHash", broadcastData.headerHash, - "round", broadcastData.header.GetRound(), - "prevRandSeed", broadcastData.header.GetPrevRandSeed(), + "headerHash", broadcastData.HeaderHash, + "round", broadcastData.Header.GetRound(), + "prevRandSeed", broadcastData.Header.GetPrevRandSeed(), ) dbb.mutDataForBroadcast.Lock() - broadcastData.miniBlockHashes = dbb.extractMiniBlockHashesCrossFromMe(broadcastData.header) + broadcastData.MiniBlockHashes = dbb.extractMiniBlockHashesCrossFromMe(broadcastData.Header) dbb.valBroadcastData = append(dbb.valBroadcastData, broadcastData) if len(dbb.valBroadcastData) > int(dbb.maxValidatorDelayCacheSize) { 
- alarmHeaderID := prefixHeaderAlarm + hex.EncodeToString(dbb.valBroadcastData[0].headerHash) - alarmDelayID := prefixDelayDataAlarm + hex.EncodeToString(dbb.valBroadcastData[0].headerHash) + alarmHeaderID := prefixHeaderAlarm + hex.EncodeToString(dbb.valBroadcastData[0].HeaderHash) + alarmDelayID := prefixDelayDataAlarm + hex.EncodeToString(dbb.valBroadcastData[0].HeaderHash) alarmIDsToCancel = append(alarmIDsToCancel, alarmHeaderID, alarmDelayID) dbb.valBroadcastData = dbb.valBroadcastData[1:] log.Debug("delayedBlockBroadcaster.SetValidatorData: canceling old alarms (header and delay data) due to too much delay data", - "headerHash", dbb.valBroadcastData[0].headerHash, + "headerHash", dbb.valBroadcastData[0].HeaderHash, "alarmID-header", alarmHeaderID, "alarmID-delay", alarmDelayID, "nbDelayData", len(dbb.valBroadcastData), @@ -246,13 +254,63 @@ func (dbb *delayedBlockBroadcaster) SetValidatorData(broadcastData *delayedBroad return nil } +// SetFinalProofForValidator sets the header proof to be broadcast by validator when its turn comes +func (dbb *delayedBlockBroadcaster) SetFinalProofForValidator( + proof *block.HeaderProof, + consensusIndex int, + pkBytes []byte, +) error { + if proof == nil { + return spos.ErrNilHeaderProof + } + + // set alarm only for validators that are aware that the block was finalized + isProofValid := len(proof.AggregatedSignature) > 0 && + len(proof.PubKeysBitmap) > 0 && + len(proof.HeaderHash) > 0 + if !isProofValid { + log.Trace("delayedBlockBroadcaster.SetFinalProofForValidator: consensus message alarm has not been set", + "validatorConsensusOrder", consensusIndex, + ) + + return nil + } + + if dbb.cacheConsensusMessages.Has(proof.HeaderHash) { + return nil + } + + duration := dbb.getBroadcastDelayForIndex(consensusIndex) + alarmID := prefixConsensusMessageAlarm + hex.EncodeToString(proof.HeaderHash) + + vProof := &validatorProof{ + proof: proof, + pkBytes: pkBytes, + } + dbb.mutBroadcastFinalProof.Lock() + dbb.valBroadcastFinalProof[alarmID] = vProof + dbb.mutBroadcastFinalProof.Unlock() + + dbb.alarm.Add(dbb.finalProofAlarmExpired, duration, alarmID) + log.Trace("delayedBlockBroadcaster.SetFinalProofForValidator: final proof alarm has been set", + "validatorConsensusOrder", consensusIndex, + "headerHash", proof.HeaderHash, + "alarmID", alarmID, + "duration", duration, + ) + + return nil +} + // SetBroadcastHandlers sets the broadcast handlers for miniBlocks and transactions func (dbb *delayedBlockBroadcaster) SetBroadcastHandlers( mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + equivalentProofBroadcast func(proof *block.HeaderProof, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error, ) error { - if mbBroadcast == nil || txBroadcast == nil || headerBroadcast == nil { + if mbBroadcast == nil || txBroadcast == nil || headerBroadcast == nil || consensusMessageBroadcast == nil { return spos.ErrNilParameter } @@ -262,6 +320,8 @@ func (dbb *delayedBlockBroadcaster) SetBroadcastHandlers( dbb.broadcastMiniblocksData = mbBroadcast dbb.broadcastTxsData = txBroadcast dbb.broadcastHeader = headerBroadcast + dbb.broadcastEquivalentProof = equivalentProofBroadcast + dbb.broadcastConsensusMessage = consensusMessageBroadcast return nil } @@ -319,12 +379,12 @@ func (dbb *delayedBlockBroadcaster) broadcastDataForHeaders(headerHashes [][]byt 
time.Sleep(common.ExtraDelayForBroadcastBlockInfo) dbb.mutDataForBroadcast.Lock() - dataToBroadcast := make([]*delayedBroadcastData, 0) + dataToBroadcast := make([]*shared.DelayedBroadcastData, 0) OuterLoop: for i := len(dbb.delayedBroadcastData) - 1; i >= 0; i-- { for _, headerHash := range headerHashes { - if bytes.Equal(dbb.delayedBroadcastData[i].headerHash, headerHash) { + if bytes.Equal(dbb.delayedBroadcastData[i].HeaderHash, headerHash) { log.Debug("delayedBlockBroadcaster.broadcastDataForHeaders: leader broadcasts block data", "headerHash", headerHash, ) @@ -366,29 +426,29 @@ func (dbb *delayedBlockBroadcaster) scheduleValidatorBroadcast(dataForValidators log.Trace("delayedBlockBroadcaster.scheduleValidatorBroadcast: registered data for broadcast") for i := range dbb.valBroadcastData { log.Trace("delayedBlockBroadcaster.scheduleValidatorBroadcast", - "round", dbb.valBroadcastData[i].header.GetRound(), - "prevRandSeed", dbb.valBroadcastData[i].header.GetPrevRandSeed(), + "round", dbb.valBroadcastData[i].Header.GetRound(), + "prevRandSeed", dbb.valBroadcastData[i].Header.GetPrevRandSeed(), ) } for _, headerData := range dataForValidators { for _, broadcastData := range dbb.valBroadcastData { - sameRound := headerData.round == broadcastData.header.GetRound() - samePrevRandomness := bytes.Equal(headerData.prevRandSeed, broadcastData.header.GetPrevRandSeed()) + sameRound := headerData.round == broadcastData.Header.GetRound() + samePrevRandomness := bytes.Equal(headerData.prevRandSeed, broadcastData.Header.GetPrevRandSeed()) if sameRound && samePrevRandomness { - duration := validatorDelayPerOrder*time.Duration(broadcastData.order) + common.ExtraDelayForBroadcastBlockInfo - alarmID := prefixDelayDataAlarm + hex.EncodeToString(broadcastData.headerHash) + duration := validatorDelayPerOrder*time.Duration(broadcastData.Order) + common.ExtraDelayForBroadcastBlockInfo + alarmID := prefixDelayDataAlarm + hex.EncodeToString(broadcastData.HeaderHash) alarmsToAdd = append(alarmsToAdd, alarmParams{ id: alarmID, duration: duration, }) log.Trace("delayedBlockBroadcaster.scheduleValidatorBroadcast: scheduling delay data broadcast for notarized header", - "headerHash", broadcastData.headerHash, + "headerHash", broadcastData.HeaderHash, "alarmID", alarmID, "round", headerData.round, "prevRandSeed", headerData.prevRandSeed, - "consensusOrder", broadcastData.order, + "consensusOrder", broadcastData.Order, ) } } @@ -411,9 +471,9 @@ func (dbb *delayedBlockBroadcaster) alarmExpired(alarmID string) { } dbb.mutDataForBroadcast.Lock() - dataToBroadcast := make([]*delayedBroadcastData, 0) + dataToBroadcast := make([]*shared.DelayedBroadcastData, 0) for i, broadcastData := range dbb.valBroadcastData { - if bytes.Equal(broadcastData.headerHash, headerHash) { + if bytes.Equal(broadcastData.HeaderHash, headerHash) { log.Debug("delayedBlockBroadcaster.alarmExpired: validator broadcasts block data (with delay) instead of leader", "headerHash", headerHash, "alarmID", alarmID, @@ -440,9 +500,9 @@ func (dbb *delayedBlockBroadcaster) headerAlarmExpired(alarmID string) { } dbb.mutDataForBroadcast.Lock() - var vHeader *validatorHeaderBroadcastData + var vHeader *shared.ValidatorHeaderBroadcastData for i, broadcastData := range dbb.valHeaderBroadcastData { - if bytes.Equal(broadcastData.headerHash, headerHash) { + if bytes.Equal(broadcastData.HeaderHash, headerHash) { vHeader = broadcastData dbb.valHeaderBroadcastData = append(dbb.valHeaderBroadcastData[:i], dbb.valHeaderBroadcastData[i+1:]...) 
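The expiry callbacks around this point all recover the header hash from the alarm identifier. As a minimal sketch of that convention: only the three prefix constants and the prefix-plus-hex encoding come from this file; the helper names below are illustrative.

package main

import (
	"encoding/hex"
	"fmt"
	"strings"
)

// The three prefixes delayedBlockBroadcaster uses when scheduling alarms.
const (
	prefixHeaderAlarm           = "header_"
	prefixDelayDataAlarm        = "delay_"
	prefixConsensusMessageAlarm = "message_"
)

// buildAlarmID mirrors how alarm IDs are composed: prefix + hex(headerHash).
func buildAlarmID(prefix string, headerHash []byte) string {
	return prefix + hex.EncodeToString(headerHash)
}

// headerHashFromAlarmID mirrors the decoding done when an alarm expires.
func headerHashFromAlarmID(prefix string, alarmID string) ([]byte, error) {
	return hex.DecodeString(strings.TrimPrefix(alarmID, prefix))
}

func main() {
	hash := []byte("header hash")
	id := buildAlarmID(prefixConsensusMessageAlarm, hash)
	fmt.Println(id) // "message_" followed by the hex encoding of the hash

	decoded, err := headerHashFromAlarmID(prefixConsensusMessageAlarm, id)
	fmt.Println(string(decoded), err) // header hash <nil>
}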
break @@ -463,7 +523,7 @@ func (dbb *delayedBlockBroadcaster) headerAlarmExpired(alarmID string) { "alarmID", alarmID, ) // broadcast header - err = dbb.broadcastHeader(vHeader.header, vHeader.pkBytes) + err = dbb.broadcastHeader(vHeader.Header, vHeader.PkBytes) if err != nil { log.Warn("delayedBlockBroadcaster.headerAlarmExpired", "error", err.Error(), "headerHash", headerHash, @@ -477,15 +537,15 @@ func (dbb *delayedBlockBroadcaster) headerAlarmExpired(alarmID string) { "headerHash", headerHash, "alarmID", alarmID, ) - go dbb.broadcastBlockData(vHeader.metaMiniBlocksData, vHeader.metaTransactionsData, vHeader.pkBytes, common.ExtraDelayForBroadcastBlockInfo) + go dbb.broadcastBlockData(vHeader.MetaMiniBlocksData, vHeader.MetaTransactionsData, vHeader.PkBytes, common.ExtraDelayForBroadcastBlockInfo) } } -func (dbb *delayedBlockBroadcaster) broadcastDelayedData(broadcastData []*delayedBroadcastData) { +func (dbb *delayedBlockBroadcaster) broadcastDelayedData(broadcastData []*shared.DelayedBroadcastData) { for _, bData := range broadcastData { go func(miniBlocks map[uint32][]byte, transactions map[string][][]byte, pkBytes []byte) { dbb.broadcastBlockData(miniBlocks, transactions, pkBytes, 0) - }(bData.miniBlocksData, bData.transactions, bData.pkBytes) + }(bData.MiniBlocksData, bData.Transactions, bData.PkBytes) } } @@ -637,6 +697,19 @@ func (dbb *delayedBlockBroadcaster) interceptedHeader(_ string, headerHash []byt dbb.cacheHeaders.Put(headerHash, struct{}{}, 0) dbb.mutHeadersCache.Unlock() + // TODO: should be handled from interceptor + proof := headerHandler.GetPreviousProof() + var aggSig, bitmap []byte + if !check.IfNilReflect(proof) { + aggSig, bitmap = proof.GetAggregatedSignature(), proof.GetPubKeysBitmap() + } + + // TODO: add common check for verifying proof validity + isFinalInfo := len(aggSig) > 0 && len(bitmap) > 0 + if isFinalInfo { + dbb.cacheConsensusMessages.Put(headerHash, struct{}{}, 0) + } + log.Trace("delayedBlockBroadcaster.interceptedHeader", "headerHash", headerHash, "round", headerHandler.GetRound(), @@ -646,8 +719,8 @@ func (dbb *delayedBlockBroadcaster) interceptedHeader(_ string, headerHash []byt alarmsToCancel := make([]string, 0) dbb.mutDataForBroadcast.RLock() for i, broadcastData := range dbb.valHeaderBroadcastData { - samePrevRandSeed := bytes.Equal(broadcastData.header.GetPrevRandSeed(), headerHandler.GetPrevRandSeed()) - sameRound := broadcastData.header.GetRound() == headerHandler.GetRound() + samePrevRandSeed := bytes.Equal(broadcastData.Header.GetPrevRandSeed(), headerHandler.GetPrevRandSeed()) + sameRound := broadcastData.Header.GetRound() == headerHandler.GetRound() sameHeader := samePrevRandSeed && sameRound if sameHeader { @@ -676,24 +749,24 @@ func (dbb *delayedBlockBroadcaster) interceptedMiniBlockData(topic string, hash "topic", topic, ) - remainingValBroadcastData := make([]*delayedBroadcastData, 0) + remainingValBroadcastData := make([]*shared.DelayedBroadcastData, 0) alarmsToCancel := make([]string, 0) dbb.mutDataForBroadcast.Lock() for i, broadcastData := range dbb.valBroadcastData { - mbHashesMap := broadcastData.miniBlockHashes + mbHashesMap := broadcastData.MiniBlockHashes if len(mbHashesMap) > 0 && len(mbHashesMap[topic]) > 0 { - delete(broadcastData.miniBlockHashes[topic], string(hash)) + delete(broadcastData.MiniBlockHashes[topic], string(hash)) if len(mbHashesMap[topic]) == 0 { delete(mbHashesMap, topic) } } if len(mbHashesMap) == 0 { - alarmID := prefixDelayDataAlarm + hex.EncodeToString(broadcastData.headerHash) + alarmID := 
prefixDelayDataAlarm + hex.EncodeToString(broadcastData.HeaderHash) alarmsToCancel = append(alarmsToCancel, alarmID) log.Trace("delayedBlockBroadcaster.interceptedMiniBlockData: leader has broadcast block data, validator cancelling alarm", - "headerHash", broadcastData.headerHash, + "headerHash", broadcastData.HeaderHash, "alarmID", alarmID, ) } else { @@ -744,3 +817,51 @@ func (dbb *delayedBlockBroadcaster) extractMbsFromMeTo(header data.HeaderHandler return mbHashesForShard } + +func (dbb *delayedBlockBroadcaster) getBroadcastDelayForIndex(index int) time.Duration { + for i := 0; i < len(dbb.config.GradualIndexBroadcastDelay); i++ { + entry := dbb.config.GradualIndexBroadcastDelay[i] + if index > entry.EndIndex { + continue + } + + return time.Duration(entry.DelayInMilliseconds) * time.Millisecond + } + + return 0 +} + +func (dbb *delayedBlockBroadcaster) finalProofAlarmExpired(alarmID string) { + headerHash, err := hex.DecodeString(strings.TrimPrefix(alarmID, prefixConsensusMessageAlarm)) + if err != nil { + log.Error("delayedBlockBroadcaster.finalProofAlarmExpired", "error", err.Error(), + "headerHash", headerHash, + "alarmID", alarmID, + ) + return + } + + dbb.mutBroadcastFinalProof.Lock() + defer dbb.mutBroadcastFinalProof.Unlock() + if dbb.cacheConsensusMessages.Has(headerHash) { + delete(dbb.valBroadcastFinalProof, alarmID) + return + } + + vProof, ok := dbb.valBroadcastFinalProof[alarmID] + if !ok { + return + } + + err = dbb.broadcastEquivalentProof(vProof.proof, vProof.pkBytes) + if err != nil { + log.Error("finalProofAlarmExpired.broadcastEquivalentProof", "error", err) + } + + delete(dbb.valBroadcastFinalProof, alarmID) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (dbb *delayedBlockBroadcaster) IsInterfaceNil() bool { + return dbb == nil +} diff --git a/consensus/broadcast/delayedBroadcast_test.go b/consensus/broadcast/delayedBroadcast_test.go index 0f22e8a5157..da1402bd90a 100644 --- a/consensus/broadcast/delayedBroadcast_test.go +++ b/consensus/broadcast/delayedBroadcast_test.go @@ -13,14 +13,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/broadcast" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) type validatorDelayArgs struct { @@ -38,21 +42,24 @@ func createValidatorDelayArgs(index int) *validatorDelayArgs { iStr := strconv.Itoa(index) return &validatorDelayArgs{ headerHash: []byte("header hash" + iStr), - header: &block.Header{ - PrevRandSeed: []byte("prev rand seed" + iStr), - Round: uint64(0), - MiniBlockHeaders: []block.MiniBlockHeader{ - { - Hash: []byte("miniBlockHash0" + iStr), - SenderShardID: 0, - ReceiverShardID: 0, - }, - { - Hash: []byte("miniBlockHash1" + iStr), - SenderShardID: 0, - ReceiverShardID: 1, + header: &block.HeaderV2{ + Header: &block.Header{ + PrevRandSeed: []byte("prev rand seed" + iStr), + Round: uint64(0), + MiniBlockHeaders: []block.MiniBlockHeader{ + { + 
Hash: []byte("miniBlockHash0" + iStr), + SenderShardID: 0, + ReceiverShardID: 0, + }, + { + Hash: []byte("miniBlockHash1" + iStr), + SenderShardID: 0, + ReceiverShardID: 1, + }, }, }, + PreviousHeaderProof: &block.HeaderProof{}, }, miniBlocks: map[uint32][]byte{0: []byte("miniblock data sh0" + iStr), 1: []byte("miniblock data sh1" + iStr)}, miniBlockHashes: map[string]map[string]struct{}{"txBlockBodies_0": {"miniBlockHash0" + iStr: struct{}{}}, "txBlockBodies_0_1": {"miniBlockHash1" + iStr: struct{}{}}}, @@ -97,7 +104,7 @@ func createMetaBlock() *block.MetaBlock { } func createDefaultDelayedBroadcasterArgs() *broadcast.ArgsDelayedBlockBroadcaster { - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptorsContainer := createInterceptorContainer() dbbArgs := &broadcast.ArgsDelayedBlockBroadcaster{ ShardCoordinator: &mock.ShardCoordinatorMock{}, @@ -106,6 +113,9 @@ func createDefaultDelayedBroadcasterArgs() *broadcast.ArgsDelayedBlockBroadcaste LeaderCacheSize: 2, ValidatorCacheSize: 2, AlarmScheduler: alarm.NewAlarmScheduler(), + Config: config.ConsensusGradualBroadcastConfig{ + GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{}, + }, } return dbbArgs @@ -177,12 +187,18 @@ func TestDelayedBlockBroadcaster_HeaderReceivedNoDelayedDataRegistered(t *testin broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) metaBlock := createMetaBlock() @@ -210,12 +226,18 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForRegisteredDelayedDataShouldBro broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) headerHash, _, miniblocksData, transactionsData := createDelayData("1") @@ -256,12 +278,18 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForNotRegisteredDelayedDataShould broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = 
dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) headerHash, _, miniblocksData, transactionsData := createDelayData("1") @@ -301,12 +329,18 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForNextRegisteredDelayedDataShoul broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) headerHash, _, miniblocksData, transactionsData := createDelayData("1") @@ -424,12 +458,18 @@ func TestDelayedBlockBroadcaster_SetHeaderForValidatorShouldSetAlarmAndBroadcast headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -484,6 +524,12 @@ func TestDelayedBlockBroadcaster_SetValidatorDataFinalizedMetaHeaderShouldSetAla headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() delayBroadcasterArgs.ShardCoordinator = mock.ShardCoordinatorMock{ @@ -492,7 +538,7 @@ func TestDelayedBlockBroadcaster_SetValidatorDataFinalizedMetaHeaderShouldSetAla dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -552,6 +598,12 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarm(t *testing.T headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() delayBroadcasterArgs.ShardCoordinator = mock.ShardCoordinatorMock{ @@ -560,7 +612,7 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarm(t *testing.T dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = 
dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -621,6 +673,12 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarmForHeaderBroa headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() delayBroadcasterArgs.ShardCoordinator = mock.ShardCoordinatorMock{ @@ -629,7 +687,7 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarmForHeaderBroa dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -689,6 +747,12 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderInvalidOrDifferentShouldIgnore headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() delayBroadcasterArgs.ShardCoordinator = mock.ShardCoordinatorMock{ @@ -697,7 +761,7 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderInvalidOrDifferentShouldIgnore dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -802,12 +866,18 @@ func TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastDifferentHeaderRoundS broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -859,12 +929,18 @@ func TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastDifferentPrevRandShou broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) 
require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -919,12 +995,18 @@ func TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastSameRoundAndPrevRandS broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -979,12 +1061,18 @@ func TestDelayedBlockBroadcaster_AlarmExpiredShouldBroadcastTheDataForRegistered broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1032,12 +1120,18 @@ func TestDelayedBlockBroadcaster_AlarmExpiredShouldDoNothingForNotRegisteredData broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1180,12 +1274,18 @@ func TestDelayedBlockBroadcaster_InterceptedMiniBlockForNotSetValDataShouldBroad broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, 
broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1243,12 +1343,18 @@ func TestDelayedBlockBroadcaster_InterceptedMiniBlockOutOfManyForSetValDataShoul broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1307,12 +1413,18 @@ func TestDelayedBlockBroadcaster_InterceptedMiniBlockFinalForSetValDataShouldNot broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1371,12 +1483,18 @@ func TestDelayedBlockBroadcaster_Close(t *testing.T) { broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1413,3 +1531,189 @@ func TestDelayedBlockBroadcaster_Close(t *testing.T) { vbd = dbb.GetValidatorBroadcastData() require.Equal(t, 1, len(vbd)) } + +func TestDelayedBlockBroadcaster_SetFinalProofForValidator(t *testing.T) { + t.Parallel() + + t.Run("nil proof should error", func(t *testing.T) { + t.Parallel() + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.NoError(t, err) + + err = dbb.SetFinalProofForValidator(nil, 0, []byte("pk")) + require.Equal(t, spos.ErrNilHeaderProof, err) + }) + t.Run("empty aggregated sig should work", func(t *testing.T) { + t.Parallel() + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.NoError(t, err) + + proof := &block.HeaderProof{} + err = dbb.SetFinalProofForValidator(proof, 0, []byte("pk")) + require.NoError(t, err) + }) + 
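For context on the subtest above: SetFinalProofForValidator only schedules a delayed broadcast when the proof carries all three fields, which is why an empty proof returns nil without setting an alarm. A minimal sketch of that gate, with block.HeaderProof collapsed to the three byte slices it inspects; the type and helper names here are illustrative.

package main

import "fmt"

// headerProofFields collapses block.HeaderProof to the fields the validity gate inspects.
type headerProofFields struct {
	AggregatedSignature []byte
	PubKeysBitmap       []byte
	HeaderHash          []byte
}

// isProofCompleteForBroadcast mirrors the check in SetFinalProofForValidator:
// an empty signature, bitmap or header hash means the block is not known to be
// finalized yet, so no alarm is set and no error is returned.
func isProofCompleteForBroadcast(proof headerProofFields) bool {
	return len(proof.AggregatedSignature) > 0 &&
		len(proof.PubKeysBitmap) > 0 &&
		len(proof.HeaderHash) > 0
}

func main() {
	empty := headerProofFields{}
	full := headerProofFields{
		AggregatedSignature: []byte("agg sig"),
		PubKeysBitmap:       []byte("bitmap"),
		HeaderHash:          []byte("hdr hash"),
	}
	fmt.Println(isProofCompleteForBroadcast(empty)) // false -> SetFinalProofForValidator returns nil early
	fmt.Println(isProofCompleteForBroadcast(full))  // true  -> a delayed broadcast alarm is scheduled
}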
t.Run("header already received should early exit", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + require.Fail(t, "should have not panicked") + } + }() + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.NoError(t, err) + + providedHash := []byte("hdr hash") + dbb.InterceptedHeaderData("", providedHash, &block.HeaderV2{ + Header: &block.Header{}, + PreviousHeaderProof: &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("agg sig"), + HeaderHash: []byte("hash"), + }, + }) + proof := &block.HeaderProof{ + AggregatedSignature: []byte("agg sig"), + PubKeysBitmap: []byte("bitmap"), + HeaderHash: providedHash, + } + err = dbb.SetFinalProofForValidator(proof, 0, []byte("pk")) + require.NoError(t, err) + }) + t.Run("should work and fire alarm", func(t *testing.T) { + t.Parallel() + + type timestamps struct { + setTimestamp int64 + fireTimestamp int64 + } + firingMap := make(map[string]*timestamps, 3) + mutFiringMap := sync.RWMutex{} + + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { + require.Fail(t, "should have not been called") + return nil + } + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { + require.Fail(t, "should have not been called") + return nil + } + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { + require.Fail(t, "should have not been called") + return nil + } + broadcastConsensusMessage := func(message *consensus.Message) error { + mutFiringMap.Lock() + defer mutFiringMap.Unlock() + firingMap[string(message.BlockHeaderHash)].fireTimestamp = time.Now().UnixMilli() + + return nil + } + broadcastEquivalentProofs := func(proof *block.HeaderProof, pkBytes []byte) error { + mutFiringMap.Lock() + defer mutFiringMap.Unlock() + firingMap[string(proof.GetHeaderHash())].fireTimestamp = time.Now().UnixMilli() + + return nil + } + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + delayBroadcasterArgs.Config = config.ConsensusGradualBroadcastConfig{ + GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{ + { + EndIndex: 4, + DelayInMilliseconds: 0, + }, + { + EndIndex: 9, + DelayInMilliseconds: 100, + }, + { + EndIndex: 15, + DelayInMilliseconds: 200, + }, + }, + } + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastEquivalentProofs, broadcastConsensusMessage) + require.Nil(t, err) + + // idx 0 should fire the alarm after immediately + hashIdx0 := []byte("hash idx 0") + pkIdx0 := []byte("pk idx 0") + proofIdx0 := &block.HeaderProof{ + AggregatedSignature: []byte("sig"), + PubKeysBitmap: []byte("bitmap"), + HeaderHash: hashIdx0, + } + mutFiringMap.Lock() + firingMap[string(hashIdx0)] = ×tamps{ + setTimestamp: time.Now().UnixMilli(), + } + mutFiringMap.Unlock() + err = dbb.SetFinalProofForValidator(proofIdx0, 0, pkIdx0) + require.NoError(t, err) + + // idx 5 should fire the alarm after 100ms + hashIdx5 := []byte("hash idx 5") + pkIdx5 := []byte("pk idx 5") + proofIdx5 := &block.HeaderProof{ + AggregatedSignature: []byte("sig"), + PubKeysBitmap: []byte("bitmap"), + HeaderHash: hashIdx5, + } + mutFiringMap.Lock() + firingMap[string(hashIdx5)] = ×tamps{ + setTimestamp: time.Now().UnixMilli(), + } + mutFiringMap.Unlock() + err = dbb.SetFinalProofForValidator(proofIdx5, 5, pkIdx5) + 
require.NoError(t, err) + + // idx 10 should fire the alarm after 200ms + hashIdx10 := []byte("hash idx 10") + pkIdx10 := []byte("pk idx 10") + proofIdx10 := &block.HeaderProof{ + AggregatedSignature: []byte("sig"), + PubKeysBitmap: []byte("bitmap"), + HeaderHash: hashIdx10, + } + mutFiringMap.Lock() + firingMap[string(hashIdx10)] = &timestamps{ + setTimestamp: time.Now().UnixMilli(), + } + mutFiringMap.Unlock() + err = dbb.SetFinalProofForValidator(proofIdx10, 10, pkIdx10) + require.NoError(t, err) + + // wait for all alarms to fire + time.Sleep(time.Millisecond * 250) + + mutFiringMap.RLock() + defer mutFiringMap.RUnlock() + + resultIdx0 := firingMap[string(hashIdx0)] + timeDifIdx0 := resultIdx0.fireTimestamp - resultIdx0.setTimestamp + require.Less(t, timeDifIdx0, int64(5), "idx 0 should have fired the alarm immediately, but fired after %dms", timeDifIdx0) + require.GreaterOrEqual(t, timeDifIdx0, int64(0), "idx 0 should have fired the alarm immediately, but fired after %dms", timeDifIdx0) + + resultIdx5 := firingMap[string(hashIdx5)] + timeDifIdx5 := resultIdx5.fireTimestamp - resultIdx5.setTimestamp + require.Less(t, timeDifIdx5, int64(105), "idx 5 should have fired the alarm after 100ms, but fired after %dms", timeDifIdx5) + require.GreaterOrEqual(t, timeDifIdx5, int64(100), "idx 5 should have fired the alarm after 100ms, but fired after %dms", timeDifIdx5) + + resultIdx10 := firingMap[string(hashIdx10)] + timeDifIdx10 := resultIdx10.fireTimestamp - resultIdx10.setTimestamp + require.Less(t, timeDifIdx10, int64(205), "idx 10 should have fired the alarm after 200ms, but fired after %dms", timeDifIdx10) + require.GreaterOrEqual(t, timeDifIdx10, int64(200), "idx 10 should have fired the alarm after 200ms, but fired after %dms", timeDifIdx10) + }) +} diff --git a/consensus/broadcast/errors.go b/consensus/broadcast/errors.go index 86acef6937b..c16c878bc50 100644 --- a/consensus/broadcast/errors.go +++ b/consensus/broadcast/errors.go @@ -4,3 +4,6 @@ import "errors" // ErrNilKeysHandler signals that a nil keys handler was provided var ErrNilKeysHandler = errors.New("nil keys handler") + +// ErrNilDelayedBroadcaster signals that a nil delayed broadcaster was provided +var ErrNilDelayedBroadcaster = errors.New("nil delayed broadcaster") diff --git a/consensus/broadcast/export.go b/consensus/broadcast/export.go index e7b0e4dfa80..27bc721f332 100644 --- a/consensus/broadcast/export.go +++ b/consensus/broadcast/export.go @@ -6,7 +6,9 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/sharding" ) @@ -32,14 +34,14 @@ func CreateDelayBroadcastDataForValidator( miniBlockHashes map[string]map[string]struct{}, transactionsData map[string][][]byte, order uint32, -) *delayedBroadcastData { - return &delayedBroadcastData{ - headerHash: headerHash, - header: header, - miniBlocksData: miniblocksData, - miniBlockHashes: miniBlockHashes, - transactions: transactionsData, - order: order, +) *shared.DelayedBroadcastData { + return &shared.DelayedBroadcastData{ + HeaderHash: headerHash, + Header: header, + MiniBlocksData: miniblocksData, + MiniBlockHashes: miniBlockHashes, + Transactions: transactionsData, + Order: order, } } @@ -50,13 +52,13 @@ func CreateValidatorHeaderBroadcastData( headerHash []byte, header data.HeaderHandler, metaMiniBlocksData map[uint32][]byte, metaTransactionsData map[string][][]byte, order
uint32, -) *validatorHeaderBroadcastData { - return &validatorHeaderBroadcastData{ - headerHash: headerHash, - header: header, - metaMiniBlocksData: metaMiniBlocksData, - metaTransactionsData: metaTransactionsData, - order: order, +) *shared.ValidatorHeaderBroadcastData { + return &shared.ValidatorHeaderBroadcastData{ + HeaderHash: headerHash, + Header: header, + MetaMiniBlocksData: metaMiniBlocksData, + MetaTransactionsData: metaTransactionsData, + Order: order, } } @@ -65,11 +67,11 @@ func CreateDelayBroadcastDataForLeader( headerHash []byte, miniblocks map[uint32][]byte, transactions map[string][][]byte, -) *delayedBroadcastData { - return &delayedBroadcastData{ - headerHash: headerHash, - miniBlocksData: miniblocks, - transactions: transactions, +) *shared.DelayedBroadcastData { + return &shared.DelayedBroadcastData{ + HeaderHash: headerHash, + MiniBlocksData: miniblocks, + Transactions: transactions, } } @@ -80,9 +82,9 @@ func (dbb *delayedBlockBroadcaster) HeaderReceived(headerHandler data.HeaderHand } // GetValidatorBroadcastData returns the set validator delayed broadcast data -func (dbb *delayedBlockBroadcaster) GetValidatorBroadcastData() []*delayedBroadcastData { +func (dbb *delayedBlockBroadcaster) GetValidatorBroadcastData() []*shared.DelayedBroadcastData { dbb.mutDataForBroadcast.RLock() - copyValBroadcastData := make([]*delayedBroadcastData, len(dbb.valBroadcastData)) + copyValBroadcastData := make([]*shared.DelayedBroadcastData, len(dbb.valBroadcastData)) copy(copyValBroadcastData, dbb.valBroadcastData) dbb.mutDataForBroadcast.RUnlock() @@ -90,9 +92,9 @@ func (dbb *delayedBlockBroadcaster) GetValidatorBroadcastData() []*delayedBroadc } // GetValidatorHeaderBroadcastData - -func (dbb *delayedBlockBroadcaster) GetValidatorHeaderBroadcastData() []*validatorHeaderBroadcastData { +func (dbb *delayedBlockBroadcaster) GetValidatorHeaderBroadcastData() []*shared.ValidatorHeaderBroadcastData { dbb.mutDataForBroadcast.RLock() - copyValHeaderBroadcastData := make([]*validatorHeaderBroadcastData, len(dbb.valHeaderBroadcastData)) + copyValHeaderBroadcastData := make([]*shared.ValidatorHeaderBroadcastData, len(dbb.valHeaderBroadcastData)) copy(copyValHeaderBroadcastData, dbb.valHeaderBroadcastData) dbb.mutDataForBroadcast.RUnlock() @@ -100,9 +102,9 @@ func (dbb *delayedBlockBroadcaster) GetValidatorHeaderBroadcastData() []*validat } // GetLeaderBroadcastData returns the set leader delayed broadcast data -func (dbb *delayedBlockBroadcaster) GetLeaderBroadcastData() []*delayedBroadcastData { +func (dbb *delayedBlockBroadcaster) GetLeaderBroadcastData() []*shared.DelayedBroadcastData { dbb.mutDataForBroadcast.RLock() - copyDelayBroadcastData := make([]*delayedBroadcastData, len(dbb.delayedBroadcastData)) + copyDelayBroadcastData := make([]*shared.DelayedBroadcastData, len(dbb.delayedBroadcastData)) copy(copyDelayBroadcastData, dbb.delayedBroadcastData) dbb.mutDataForBroadcast.RUnlock() diff --git a/consensus/broadcast/export_test.go b/consensus/broadcast/export_test.go new file mode 100644 index 00000000000..646dfa9b161 --- /dev/null +++ b/consensus/broadcast/export_test.go @@ -0,0 +1,12 @@ +package broadcast + +import ( + "github.com/multiversx/mx-chain-core-go/marshal" +) + +// SetMarshalizerMeta sets the unexported marshaller +func (mcm *metaChainMessenger) SetMarshalizerMeta( + m marshal.Marshalizer, +) { + mcm.marshalizer = m +} diff --git a/consensus/broadcast/interface.go b/consensus/broadcast/interface.go new file mode 100644 index 00000000000..4708bab7827 --- /dev/null +++ 
b/consensus/broadcast/interface.go @@ -0,0 +1,25 @@ +package broadcast + +import ( + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" +) + +// DelayedBroadcaster exposes functionality for handling the consensus members broadcasting of delay data +type DelayedBroadcaster interface { + SetLeaderData(data *shared.DelayedBroadcastData) error + SetValidatorData(data *shared.DelayedBroadcastData) error + SetHeaderForValidator(vData *shared.ValidatorHeaderBroadcastData) error + SetBroadcastHandlers( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + equivalentProofBroadcast func(proof *block.HeaderProof, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error, + ) error + SetFinalProofForValidator(proof *block.HeaderProof, consensusIndex int, pkBytes []byte) error + Close() + IsInterfaceNil() bool +} diff --git a/consensus/broadcast/metaChainMessenger.go b/consensus/broadcast/metaChainMessenger.go index daca3b436a5..78490fb5d01 100644 --- a/consensus/broadcast/metaChainMessenger.go +++ b/consensus/broadcast/metaChainMessenger.go @@ -5,8 +5,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process/factory" ) @@ -32,27 +34,13 @@ func NewMetaChainMessenger( return nil, err } - dbbArgs := &ArgsDelayedBlockBroadcaster{ - InterceptorsContainer: args.InterceptorsContainer, - HeadersSubscriber: args.HeadersSubscriber, - LeaderCacheSize: args.MaxDelayCacheSize, - ValidatorCacheSize: args.MaxValidatorDelayCacheSize, - ShardCoordinator: args.ShardCoordinator, - AlarmScheduler: args.AlarmScheduler, - } - - dbb, err := NewDelayedBlockBroadcaster(dbbArgs) - if err != nil { - return nil, err - } - cm := &commonMessenger{ marshalizer: args.Marshalizer, hasher: args.Hasher, messenger: args.Messenger, shardCoordinator: args.ShardCoordinator, peerSignatureHandler: args.PeerSignatureHandler, - delayedBlockBroadcaster: dbb, + delayedBlockBroadcaster: args.DelayedBroadcaster, keysHandler: args.KeysHandler, } @@ -60,7 +48,12 @@ func NewMetaChainMessenger( commonMessenger: cm, } - err = dbb.SetBroadcastHandlers(mcm.BroadcastMiniBlocks, mcm.BroadcastTransactions, mcm.BroadcastHeader) + err = mcm.delayedBlockBroadcaster.SetBroadcastHandlers( + mcm.BroadcastMiniBlocks, + mcm.BroadcastTransactions, + mcm.BroadcastHeader, + mcm.BroadcastEquivalentProof, + mcm.BroadcastConsensusMessage) if err != nil { return nil, err } @@ -124,6 +117,14 @@ func (mcm *metaChainMessenger) BroadcastHeader(header data.HeaderHandler, pkByte return nil } +// BroadcastEquivalentProof will broadcast the proof for a header on the metachain common topic +func (mcm *metaChainMessenger) BroadcastEquivalentProof(proof *block.HeaderProof, pkBytes []byte) error { + identifierMetaAll := mcm.shardCoordinator.CommunicationIdentifier(core.AllShardId) + topic := common.EquivalentProofsTopic + identifierMetaAll + + return mcm.broadcastEquivalentProof(proof, 
pkBytes, topic) +} + // BroadcastBlockDataLeader broadcasts the block data as consensus group leader func (mcm *metaChainMessenger) BroadcastBlockDataLeader( _ data.HeaderHandler, @@ -154,13 +155,13 @@ func (mcm *metaChainMessenger) PrepareBroadcastHeaderValidator( return } - vData := &validatorHeaderBroadcastData{ - headerHash: headerHash, - header: header, - metaMiniBlocksData: miniBlocks, - metaTransactionsData: transactions, - order: uint32(idx), - pkBytes: pkBytes, + vData := &shared.ValidatorHeaderBroadcastData{ + HeaderHash: headerHash, + Header: header, + MetaMiniBlocksData: miniBlocks, + MetaTransactionsData: transactions, + Order: uint32(idx), + PkBytes: pkBytes, } err = mcm.delayedBlockBroadcaster.SetHeaderForValidator(vData) diff --git a/consensus/broadcast/metaChainMessenger_test.go b/consensus/broadcast/metaChainMessenger_test.go index 01cbb6a151d..3e89f546b79 100644 --- a/consensus/broadcast/metaChainMessenger_test.go +++ b/consensus/broadcast/metaChainMessenger_test.go @@ -7,16 +7,22 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/broadcast" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/testscommon" + consensusMock "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) var nodePkBytes = []byte("node public key bytes") @@ -27,10 +33,11 @@ func createDefaultMetaChainArgs() broadcast.MetaChainMessengerArgs { shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{} hasher := &hashingMocks.HasherMock{} - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptorsContainer := createInterceptorContainer() peerSigHandler := &mock.PeerSignatureHandler{Signer: singleSignerMock} - alarmScheduler := &mock.AlarmSchedulerStub{} + alarmScheduler := &testscommon.AlarmSchedulerStub{} + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{} return broadcast.MetaChainMessengerArgs{ CommonMessengerArgs: broadcast.CommonMessengerArgs{ @@ -45,6 +52,7 @@ func createDefaultMetaChainArgs() broadcast.MetaChainMessengerArgs { MaxDelayCacheSize: 2, AlarmScheduler: alarmScheduler, KeysHandler: &testscommon.KeysHandlerStub{}, + DelayedBroadcaster: delayedBroadcaster, }, } } @@ -94,6 +102,14 @@ func TestMetaChainMessenger_NilKeysHandlerShouldError(t *testing.T) { assert.Equal(t, broadcast.ErrNilKeysHandler, err) } +func TestMetaChainMessenger_NilDelayedBroadcasterShouldError(t *testing.T) { + args := createDefaultMetaChainArgs() + args.DelayedBroadcaster = nil + scm, err := broadcast.NewMetaChainMessenger(args) + + assert.Nil(t, scm) + assert.Equal(t, broadcast.ErrNilDelayedBroadcaster, err) +} func TestMetaChainMessenger_NewMetaChainMessengerShouldWork(t *testing.T) { args := createDefaultMetaChainArgs() mcm, err := broadcast.NewMetaChainMessenger(args) @@ -292,3 +308,115 @@ 
func TestMetaChainMessenger_BroadcastBlockDataLeader(t *testing.T) { assert.Equal(t, len(transactions), numBroadcast) }) } + +func TestMetaChainMessenger_Close(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + closeCalled := false + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + CloseCalled: func() { + closeCalled = true + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + mcm.Close() + assert.True(t, closeCalled) +} + +func TestMetaChainMessenger_PrepareBroadcastHeaderValidator(t *testing.T) { + t.Parallel() + + t.Run("Nil header", func(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + require.Fail(t, "SetHeaderForValidator should not be called") + return nil + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + mcm.PrepareBroadcastHeaderValidator(nil, make(map[uint32][]byte), make(map[string][][]byte), 0, make([]byte, 0)) + }) + t.Run("Err on core.CalculateHash", func(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + require.Fail(t, "SetHeaderForValidator should not be called") + return nil + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + header := &block.Header{} + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + mcm.SetMarshalizerMeta(nil) + mcm.PrepareBroadcastHeaderValidator(header, make(map[uint32][]byte), make(map[string][][]byte), 0, make([]byte, 0)) + }) + t.Run("Err on SetHeaderForValidator", func(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + checkVarModified := false + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + checkVarModified = true + return expectedErr + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + header := &block.Header{} + mcm.PrepareBroadcastHeaderValidator(header, make(map[uint32][]byte), make(map[string][][]byte), 0, make([]byte, 0)) + assert.True(t, checkVarModified) + }) +} + +func TestMetaChainMessenger_BroadcastBlock(t *testing.T) { + t.Parallel() + + t.Run("Err nil blockData", func(t *testing.T) { + args := createDefaultMetaChainArgs() + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + err := mcm.BroadcastBlock(nil, nil) + assert.NotNil(t, err) + }) +} + +func TestMetaChainMessenger_NewMetaChainMessengerFailSetBroadcast(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + varModified := false + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + SetBroadcastHandlersCalled: func( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + equivalentProofsBroadcast func(proof *block.HeaderProof, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error) error { + varModified = true + return expectedErr + }, + } + 
args.DelayedBroadcaster = delayedBroadcaster + + mcm, err := broadcast.NewMetaChainMessenger(args) + assert.Nil(t, mcm) + assert.NotNil(t, err) + assert.True(t, varModified) +} diff --git a/consensus/broadcast/shardChainMessenger.go b/consensus/broadcast/shardChainMessenger.go index ac7485a8d1f..f479cf3bc35 100644 --- a/consensus/broadcast/shardChainMessenger.go +++ b/consensus/broadcast/shardChainMessenger.go @@ -7,8 +7,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process/factory" ) @@ -37,35 +39,25 @@ func NewShardChainMessenger( } cm := &commonMessenger{ - marshalizer: args.Marshalizer, - hasher: args.Hasher, - messenger: args.Messenger, - shardCoordinator: args.ShardCoordinator, - peerSignatureHandler: args.PeerSignatureHandler, - keysHandler: args.KeysHandler, - } - - dbbArgs := &ArgsDelayedBlockBroadcaster{ - InterceptorsContainer: args.InterceptorsContainer, - HeadersSubscriber: args.HeadersSubscriber, - LeaderCacheSize: args.MaxDelayCacheSize, - ValidatorCacheSize: args.MaxValidatorDelayCacheSize, - ShardCoordinator: args.ShardCoordinator, - AlarmScheduler: args.AlarmScheduler, - } - - dbb, err := NewDelayedBlockBroadcaster(dbbArgs) - if err != nil { - return nil, err + marshalizer: args.Marshalizer, + hasher: args.Hasher, + messenger: args.Messenger, + shardCoordinator: args.ShardCoordinator, + peerSignatureHandler: args.PeerSignatureHandler, + keysHandler: args.KeysHandler, + delayedBlockBroadcaster: args.DelayedBroadcaster, } - cm.delayedBlockBroadcaster = dbb - scm := &shardChainMessenger{ commonMessenger: cm, } - err = dbb.SetBroadcastHandlers(scm.BroadcastMiniBlocks, scm.BroadcastTransactions, scm.BroadcastHeader) + err = scm.delayedBlockBroadcaster.SetBroadcastHandlers( + scm.BroadcastMiniBlocks, + scm.BroadcastTransactions, + scm.BroadcastHeader, + scm.BroadcastEquivalentProof, + scm.BroadcastConsensusMessage) if err != nil { return nil, err } @@ -136,6 +128,14 @@ func (scm *shardChainMessenger) BroadcastHeader(header data.HeaderHandler, pkByt return nil } +// BroadcastEquivalentProof will broadcast the proof for a header on the shard metachain common topic +func (scm *shardChainMessenger) BroadcastEquivalentProof(proof *block.HeaderProof, pkBytes []byte) error { + shardIdentifier := scm.shardCoordinator.CommunicationIdentifier(core.MetachainShardId) + topic := common.EquivalentProofsTopic + shardIdentifier + + return scm.broadcastEquivalentProof(proof, pkBytes, topic) +} + // BroadcastBlockDataLeader broadcasts the block data as consensus group leader func (scm *shardChainMessenger) BroadcastBlockDataLeader( header data.HeaderHandler, @@ -157,11 +157,11 @@ func (scm *shardChainMessenger) BroadcastBlockDataLeader( metaMiniBlocks, metaTransactions := scm.extractMetaMiniBlocksAndTransactions(miniBlocks, transactions) - broadcastData := &delayedBroadcastData{ - headerHash: headerHash, - miniBlocksData: miniBlocks, - transactions: transactions, - pkBytes: pkBytes, + broadcastData := &shared.DelayedBroadcastData{ + HeaderHash: headerHash, + MiniBlocksData: miniBlocks, + Transactions: transactions, + PkBytes: pkBytes, } err = scm.delayedBlockBroadcaster.SetLeaderData(broadcastData) @@ -192,11 +192,11 @@ func (scm 
*shardChainMessenger) PrepareBroadcastHeaderValidator( return } - vData := &validatorHeaderBroadcastData{ - headerHash: headerHash, - header: header, - order: uint32(idx), - pkBytes: pkBytes, + vData := &shared.ValidatorHeaderBroadcastData{ + HeaderHash: headerHash, + Header: header, + Order: uint32(idx), + PkBytes: pkBytes, } err = scm.delayedBlockBroadcaster.SetHeaderForValidator(vData) @@ -228,13 +228,13 @@ func (scm *shardChainMessenger) PrepareBroadcastBlockDataValidator( return } - broadcastData := &delayedBroadcastData{ - headerHash: headerHash, - header: header, - miniBlocksData: miniBlocks, - transactions: transactions, - order: uint32(idx), - pkBytes: pkBytes, + broadcastData := &shared.DelayedBroadcastData{ + HeaderHash: headerHash, + Header: header, + MiniBlocksData: miniBlocks, + Transactions: transactions, + Order: uint32(idx), + PkBytes: pkBytes, } err = scm.delayedBlockBroadcaster.SetValidatorData(broadcastData) diff --git a/consensus/broadcast/shardChainMessenger_test.go b/consensus/broadcast/shardChainMessenger_test.go index c81d2d98c28..3f0155a05ee 100644 --- a/consensus/broadcast/shardChainMessenger_test.go +++ b/consensus/broadcast/shardChainMessenger_test.go @@ -2,13 +2,24 @@ package broadcast_test import ( "bytes" + "errors" "testing" "time" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/consensus" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/pool" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus/broadcast" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/p2p" @@ -17,9 +28,10 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" ) +var expectedErr = errors.New("expected error") + func createDelayData(prefix string) ([]byte, *block.Header, map[uint32][]byte, map[string][][]byte) { miniblocks := make(map[uint32][]byte) receiverShardID := uint32(1) @@ -58,12 +70,13 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { messengerMock := &p2pmocks.MessengerStub{} shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{} - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptorsContainer := createInterceptorContainer() peerSigHandler := &mock.PeerSignatureHandler{ Signer: singleSignerMock, } - alarmScheduler := &mock.AlarmSchedulerStub{} + alarmScheduler := &testscommon.AlarmSchedulerStub{} + delayedBroadcaster := &testscommonConsensus.DelayedBroadcasterMock{} return broadcast.ShardChainMessengerArgs{ CommonMessengerArgs: broadcast.CommonMessengerArgs{ @@ -78,6 +91,20 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { MaxValidatorDelayCacheSize: 1, AlarmScheduler: alarmScheduler, KeysHandler: &testscommon.KeysHandlerStub{}, + DelayedBroadcaster: delayedBroadcaster, + }, + } +} + +func newBlockWithEmptyMiniblock() *block.Body { + return &block.Body{ + MiniBlocks: []*block.MiniBlock{ + { 
+ TxHashes: [][]byte{}, + ReceiverShardID: 0, + SenderShardID: 0, + Type: 0, + }, }, } } @@ -85,6 +112,7 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { func TestShardChainMessenger_NewShardChainMessengerNilMarshalizerShouldFail(t *testing.T) { args := createDefaultShardChainArgs() args.Marshalizer = nil + scm, err := broadcast.NewShardChainMessenger(args) assert.Nil(t, scm) @@ -136,6 +164,15 @@ func TestShardChainMessenger_NewShardChainMessengerNilHeadersSubscriberShouldFai assert.Equal(t, spos.ErrNilHeadersSubscriber, err) } +func TestShardChainMessenger_NilDelayedBroadcasterShouldError(t *testing.T) { + args := createDefaultShardChainArgs() + args.DelayedBroadcaster = nil + scm, err := broadcast.NewShardChainMessenger(args) + + assert.Nil(t, scm) + assert.Equal(t, broadcast.ErrNilDelayedBroadcaster, err) +} + func TestShardChainMessenger_NilKeysHandlerShouldError(t *testing.T) { args := createDefaultShardChainArgs() args.KeysHandler = nil @@ -154,6 +191,26 @@ func TestShardChainMessenger_NewShardChainMessengerShouldWork(t *testing.T) { assert.False(t, scm.IsInterfaceNil()) } +func TestShardChainMessenger_NewShardChainMessengerShouldErr(t *testing.T) { + + args := createDefaultShardChainArgs() + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetBroadcastHandlersCalled: func( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + equivalentProofsBroadcast func(proof *block.HeaderProof, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error, + ) error { + return expectedErr + }} + + _, err := broadcast.NewShardChainMessenger(args) + + assert.Equal(t, expectedErr, err) + +} + func TestShardChainMessenger_BroadcastBlockShouldErrNilBody(t *testing.T) { args := createDefaultShardChainArgs() scm, _ := broadcast.NewShardChainMessenger(args) @@ -170,6 +227,14 @@ func TestShardChainMessenger_BroadcastBlockShouldErrNilHeader(t *testing.T) { assert.Equal(t, spos.ErrNilHeader, err) } +func TestShardChainMessenger_BroadcastBlockShouldErrMiniBlockEmpty(t *testing.T) { + args := createDefaultShardChainArgs() + scm, _ := broadcast.NewShardChainMessenger(args) + + err := scm.BroadcastBlock(newBlockWithEmptyMiniblock(), &block.Header{}) + assert.Equal(t, data.ErrMiniBlockEmpty, err) +} + func TestShardChainMessenger_BroadcastBlockShouldErrMockMarshalizer(t *testing.T) { marshalizer := mock.MarshalizerMock{ Fail: true, @@ -363,6 +428,19 @@ func TestShardChainMessenger_BroadcastHeaderNilHeaderShouldErr(t *testing.T) { assert.Equal(t, spos.ErrNilHeader, err) } +func TestShardChainMessenger_BroadcastHeaderShouldErr(t *testing.T) { + marshalizer := mock.MarshalizerMock{ + Fail: true, + } + + args := createDefaultShardChainArgs() + args.Marshalizer = marshalizer + scm, _ := broadcast.NewShardChainMessenger(args) + + err := scm.BroadcastHeader(&block.MetaBlock{Nonce: 10}, []byte("pk bytes")) + assert.Equal(t, mock.ErrMockMarshalizer, err) +} + func TestShardChainMessenger_BroadcastHeaderShouldWork(t *testing.T) { channelBroadcastCalled := make(chan bool, 1) channelBroadcastUsingPrivateKeyCalled := make(chan bool, 1) @@ -439,6 +517,41 @@ func TestShardChainMessenger_BroadcastBlockDataLeaderNilMiniblocksShouldReturnNi assert.Nil(t, err) } +func TestShardChainMessenger_BroadcastBlockDataLeaderShouldErr(t *testing.T) { + marshalizer := mock.MarshalizerMock{ + Fail: true, + } 
+ + args := createDefaultShardChainArgs() + args.Marshalizer = marshalizer + + scm, _ := broadcast.NewShardChainMessenger(args) + + _, header, miniblocks, transactions := createDelayData("1") + + err := scm.BroadcastBlockDataLeader(header, miniblocks, transactions, []byte("pk bytes")) + assert.Equal(t, mock.ErrMockMarshalizer, err) +} + +func TestShardChainMessenger_BroadcastBlockDataLeaderShouldErrDelayedBroadcaster(t *testing.T) { + + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetLeaderDataCalled: func(data *shared.DelayedBroadcastData) error { + return expectedErr + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + _, header, miniblocks, transactions := createDelayData("1") + + err := scm.BroadcastBlockDataLeader(header, miniblocks, transactions, []byte("pk bytes")) + + assert.Equal(t, expectedErr, err) +} + func TestShardChainMessenger_BroadcastBlockDataLeaderShouldTriggerWaitingDelayedMessage(t *testing.T) { broadcastWasCalled := atomic.Flag{} broadcastUsingPrivateKeyWasCalled := atomic.Flag{} @@ -457,6 +570,18 @@ func TestShardChainMessenger_BroadcastBlockDataLeaderShouldTriggerWaitingDelayed return bytes.Equal(pkBytes, nodePkBytes) }, } + argsDelayedBroadcaster := broadcast.ArgsDelayedBlockBroadcaster{ + InterceptorsContainer: args.InterceptorsContainer, + HeadersSubscriber: args.HeadersSubscriber, + ShardCoordinator: args.ShardCoordinator, + LeaderCacheSize: args.MaxDelayCacheSize, + ValidatorCacheSize: args.MaxDelayCacheSize, + AlarmScheduler: args.AlarmScheduler, + } + + // Using real component in order to properly simulate the expected behavior + args.DelayedBroadcaster, _ = broadcast.NewDelayedBlockBroadcaster(&argsDelayedBroadcaster) + scm, _ := broadcast.NewShardChainMessenger(args) t.Run("original public key of the node", func(t *testing.T) { @@ -488,3 +613,190 @@ func TestShardChainMessenger_BroadcastBlockDataLeaderShouldTriggerWaitingDelayed assert.True(t, broadcastUsingPrivateKeyWasCalled.IsSet()) }) } + +func TestShardChainMessenger_PrepareBroadcastHeaderValidatorShouldFailHeaderNil(t *testing.T) { + + pkBytes := make([]byte, 32) + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + require.Fail(t, "SetHeaderForValidator should not be called") + return nil + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastHeaderValidator(nil, nil, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastHeaderValidatorShouldFailCalculateHashErr(t *testing.T) { + + pkBytes := make([]byte, 32) + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + require.Fail(t, "SetHeaderForValidator should not be called") + return nil + }} + + args.Marshalizer = &testscommon.MarshallerStub{MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastHeaderValidator(headerMock, nil, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastHeaderValidatorShouldWork(t *testing.T) { + + pkBytes := make([]byte, 32) + headerMock := 
&testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + varSetHeaderForValidatorCalled := false + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + varSetHeaderForValidatorCalled = true + return nil + }} + + args.Marshalizer = &testscommon.MarshallerStub{MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, nil + }} + args.Hasher = &testscommon.HasherStub{ComputeCalled: func(s string) []byte { + return nil + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastHeaderValidator(headerMock, nil, nil, 1, pkBytes) + + assert.True(t, varSetHeaderForValidatorCalled) +} + +func TestShardChainMessenger_PrepareBroadcastBlockDataValidatorShouldFailHeaderNil(t *testing.T) { + + pkBytes := make([]byte, 32) + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetValidatorDataCalled: func(data *shared.DelayedBroadcastData) error { + require.Fail(t, "SetValidatorData should not be called") + return nil + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastBlockDataValidator(nil, nil, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastBlockDataValidatorShouldFailMiniBlocksLenZero(t *testing.T) { + + pkBytes := make([]byte, 32) + miniBlocks := make(map[uint32][]byte) + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetValidatorDataCalled: func(data *shared.DelayedBroadcastData) error { + require.Fail(t, "SetValidatorData should not be called") + return nil + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastBlockDataValidator(headerMock, miniBlocks, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastBlockDataValidatorShouldFailCalculateHashErr(t *testing.T) { + + pkBytes := make([]byte, 32) + miniBlocks := map[uint32][]byte{1: {}} + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetValidatorDataCalled: func(data *shared.DelayedBroadcastData) error { + require.Fail(t, "SetValidatorData should not be called") + return nil + }} + + args.Marshalizer = &testscommon.MarshallerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }, + } + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastBlockDataValidator(headerMock, miniBlocks, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastBlockDataValidatorShouldWork(t *testing.T) { + + pkBytes := make([]byte, 32) + miniBlocks := map[uint32][]byte{1: {}} + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + varSetValidatorDataCalled := false + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetValidatorDataCalled: func(data *shared.DelayedBroadcastData) error { + varSetValidatorDataCalled = true + return nil + }} + + args.Marshalizer = &testscommon.MarshallerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, nil + }, + } + + args.Hasher = &testscommon.HasherStub{ + ComputeCalled: func(s string) []byte { + return nil + }, + } + + scm, _ := 
broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastBlockDataValidator(headerMock, miniBlocks, nil, 1, pkBytes) + + assert.True(t, varSetValidatorDataCalled) +} + +func TestShardChainMessenger_CloseShouldWork(t *testing.T) { + + args := createDefaultShardChainArgs() + + varCloseCalled := false + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + CloseCalled: func() { + varCloseCalled = true + }, + } + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.Close() + assert.True(t, varCloseCalled) + +} diff --git a/consensus/broadcast/shared/types.go b/consensus/broadcast/shared/types.go new file mode 100644 index 00000000000..216cd5987b8 --- /dev/null +++ b/consensus/broadcast/shared/types.go @@ -0,0 +1,26 @@ +package shared + +import ( + "github.com/multiversx/mx-chain-core-go/data" +) + +// DelayedBroadcastData is exported to be accessible in delayedBroadcasterMock +type DelayedBroadcastData struct { + HeaderHash []byte + Header data.HeaderHandler + MiniBlocksData map[uint32][]byte + MiniBlockHashes map[string]map[string]struct{} + Transactions map[string][][]byte + Order uint32 + PkBytes []byte +} + +// ValidatorHeaderBroadcastData is exported to be accessible in delayedBroadcasterMock +type ValidatorHeaderBroadcastData struct { + HeaderHash []byte + Header data.HeaderHandler + MetaMiniBlocksData map[uint32][]byte + MetaTransactionsData map[string][][]byte + Order uint32 + PkBytes []byte +} diff --git a/consensus/chronology/chronology.go b/consensus/chronology/chronology.go index 1b20bc1dc03..0c195c2e31a 100644 --- a/consensus/chronology/chronology.go +++ b/consensus/chronology/chronology.go @@ -10,10 +10,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/closing" "github.com/multiversx/mx-chain-core-go/display" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/ntp" - "github.com/multiversx/mx-chain-logger-go" ) var _ consensus.ChronologyHandler = (*chronology)(nil) diff --git a/consensus/chronology/chronology_test.go b/consensus/chronology/chronology_test.go index 978d898834c..c14a5be13e5 100644 --- a/consensus/chronology/chronology_test.go +++ b/consensus/chronology/chronology_test.go @@ -5,11 +5,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/chronology" "github.com/multiversx/mx-chain-go/consensus/mock" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func initSubroundHandlerMock() *mock.SubroundHandlerMock { @@ -115,7 +118,7 @@ func TestChronology_StartRoundShouldReturnWhenRoundIndexIsNegative(t *testing.T) t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} roundHandlerMock.IndexCalled = func() int64 { return -1 } @@ -149,7 +152,7 @@ func TestChronology_StartRoundShouldReturnWhenDoWorkReturnsFalse(t *testing.T) { t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} 
roundHandlerMock.UpdateRound(roundHandlerMock.TimeStamp(), roundHandlerMock.TimeStamp().Add(roundHandlerMock.TimeDuration())) arg.RoundHandler = roundHandlerMock chr, _ := chronology.NewChronology(arg) @@ -166,7 +169,7 @@ func TestChronology_StartRoundShouldWork(t *testing.T) { t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} roundHandlerMock.UpdateRound(roundHandlerMock.TimeStamp(), roundHandlerMock.TimeStamp().Add(roundHandlerMock.TimeDuration())) arg.RoundHandler = roundHandlerMock chr, _ := chronology.NewChronology(arg) @@ -219,7 +222,7 @@ func TestChronology_InitRoundShouldNotSetSubroundWhenRoundIndexIsNegative(t *tes t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} arg.RoundHandler = roundHandlerMock arg.GenesisTime = arg.SyncTimer.CurrentTime() chr, _ := chronology.NewChronology(arg) @@ -240,7 +243,7 @@ func TestChronology_InitRoundShouldSetSubroundWhenRoundIndexIsPositive(t *testin t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} roundHandlerMock.UpdateRound(roundHandlerMock.TimeStamp(), roundHandlerMock.TimeStamp().Add(roundHandlerMock.TimeDuration())) arg.RoundHandler = roundHandlerMock arg.GenesisTime = arg.SyncTimer.CurrentTime() @@ -257,7 +260,7 @@ func TestChronology_StartRoundShouldNotUpdateRoundWhenCurrentRoundIsNotFinished( t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} arg.RoundHandler = roundHandlerMock arg.GenesisTime = arg.SyncTimer.CurrentTime() chr, _ := chronology.NewChronology(arg) @@ -271,7 +274,7 @@ func TestChronology_StartRoundShouldNotUpdateRoundWhenCurrentRoundIsNotFinished( func TestChronology_StartRoundShouldUpdateRoundWhenCurrentRoundIsFinished(t *testing.T) { t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} arg.RoundHandler = roundHandlerMock arg.GenesisTime = arg.SyncTimer.CurrentTime() chr, _ := chronology.NewChronology(arg) @@ -315,9 +318,75 @@ func TestChronology_CheckIfStatusHandlerWorks(t *testing.T) { func getDefaultChronologyArg() chronology.ArgChronology { return chronology.ArgChronology{ GenesisTime: time.Now(), - RoundHandler: &mock.RoundHandlerMock{}, - SyncTimer: &mock.SyncTimerMock{}, + RoundHandler: &consensusMocks.RoundHandlerMock{}, + SyncTimer: &consensusMocks.SyncTimerMock{}, AppStatusHandler: statusHandlerMock.NewAppStatusHandlerMock(), Watchdog: &mock.WatchdogMock{}, } } + +func TestChronology_CloseWatchDogStop(t *testing.T) { + t.Parallel() + + arg := getDefaultChronologyArg() + stopCalled := false + arg.Watchdog = &mock.WatchdogMock{ + StopCalled: func(alarmID string) { + stopCalled = true + }, + } + + chr, err := chronology.NewChronology(arg) + require.Nil(t, err) + chr.SetCancelFunc(nil) + + err = chr.Close() + assert.Nil(t, err) + assert.True(t, stopCalled) +} + +func TestChronology_Close(t *testing.T) { + t.Parallel() + + arg := getDefaultChronologyArg() + stopCalled := false + arg.Watchdog = &mock.WatchdogMock{ + StopCalled: func(alarmID string) { + stopCalled = true + }, + } + + chr, err := chronology.NewChronology(arg) + require.Nil(t, err) + + cancelCalled := false + chr.SetCancelFunc(func() { + cancelCalled = true + }) + + err 
= chr.Close() + assert.Nil(t, err) + assert.True(t, stopCalled) + assert.True(t, cancelCalled) +} + +func TestChronology_StartRounds(t *testing.T) { + t.Parallel() + + arg := getDefaultChronologyArg() + + chr, err := chronology.NewChronology(arg) + require.Nil(t, err) + doneFuncCalled := false + + ctx := &mock.ContextMock{ + DoneFunc: func() <-chan struct{} { + done := make(chan struct{}) + close(done) + doneFuncCalled = true + return done + }, + } + chr.StartRoundsTest(ctx) + assert.True(t, doneFuncCalled) +} diff --git a/consensus/chronology/export_test.go b/consensus/chronology/export_test.go index 39ff4cab99f..b3a35131597 100644 --- a/consensus/chronology/export_test.go +++ b/consensus/chronology/export_test.go @@ -3,6 +3,8 @@ package chronology import ( "context" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/consensus" ) @@ -37,3 +39,18 @@ func (chr *chronology) UpdateRound() { func (chr *chronology) InitRound() { chr.initRound() } + +// StartRoundsTest calls the unexported startRounds function +func (chr *chronology) StartRoundsTest(ctx context.Context) { + chr.startRounds(ctx) +} + +// SetWatchdog sets the watchdog for chronology object +func (chr *chronology) SetWatchdog(watchdog core.WatchdogTimer) { + chr.watchdog = watchdog +} + +// SetCancelFunc sets cancelFunc for chronology object +func (chr *chronology) SetCancelFunc(cancelFunc func()) { + chr.cancelFunc = cancelFunc +} diff --git a/consensus/interface.go b/consensus/interface.go index aa8d9057bc4..8dfc1018172 100644 --- a/consensus/interface.go +++ b/consensus/interface.go @@ -6,7 +6,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/p2p" ) @@ -61,12 +63,14 @@ type ChronologyHandler interface { type BroadcastMessenger interface { BroadcastBlock(data.BodyHandler, data.HeaderHandler) error BroadcastHeader(data.HeaderHandler, []byte) error + BroadcastEquivalentProof(proof *block.HeaderProof, pkBytes []byte) error BroadcastMiniBlocks(map[uint32][]byte, []byte) error BroadcastTransactions(map[string][][]byte, []byte) error BroadcastConsensusMessage(*Message) error BroadcastBlockDataLeader(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, pkBytes []byte) error PrepareBroadcastHeaderValidator(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, pkBytes []byte) PrepareBroadcastBlockDataValidator(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, pkBytes []byte) + PrepareBroadcastEquivalentProof(proof *block.HeaderProof, consensusIndex int, pkBytes []byte) IsInterfaceNil() bool } @@ -122,11 +126,14 @@ type HeaderSigVerifier interface { VerifyRandSeed(header data.HeaderHandler) error VerifyLeaderSignature(header data.HeaderHandler) error VerifySignature(header data.HeaderHandler) error + VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderProof(headerProof data.HeaderProofHandler) error IsInterfaceNil() bool } // FallbackHeaderValidator defines the behaviour of a component able to signal when a fallback header validation could be applied type FallbackHeaderValidator interface { + ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) 
bool ShouldApplyFallbackValidation(headerHandler data.HeaderHandler) bool IsInterfaceNil() bool } @@ -193,3 +200,21 @@ type KeysHandler interface { GetRedundancyStepInReason() string IsInterfaceNil() bool } + +// EquivalentProofsPool defines the behaviour of a proofs pool components +type EquivalentProofsPool interface { + AddProof(headerProof data.HeaderProofHandler) error + GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + HasProof(shardID uint32, headerHash []byte) bool + IsInterfaceNil() bool +} + +// ProofHandler defines the interface for a proof handler +type ProofHandler interface { + GetPubKeysBitmap() []byte + GetAggregatedSignature() []byte + GetHeaderHash() []byte + GetHeaderEpoch() uint32 + GetHeaderNonce() uint64 + GetHeaderShardId() uint32 +} diff --git a/consensus/message.go b/consensus/message.go index f4396c05076..3e581673d17 100644 --- a/consensus/message.go +++ b/consensus/message.go @@ -1,7 +1,9 @@ //go:generate protoc -I=. -I=$GOPATH/src -I=$GOPATH/src/github.com/multiversx/protobuf/protobuf --gogoslick_out=. message.proto package consensus -import "github.com/multiversx/mx-chain-core-go/core" +import ( + "github.com/multiversx/mx-chain-core-go/core" +) // MessageType specifies what type of message was received type MessageType int diff --git a/consensus/mock/alarmSchedulerStub.go b/consensus/mock/alarmSchedulerStub.go deleted file mode 100644 index fe2e7597036..00000000000 --- a/consensus/mock/alarmSchedulerStub.go +++ /dev/null @@ -1,45 +0,0 @@ -package mock - -import ( - "time" -) - -type AlarmSchedulerStub struct { - AddCalled func(func(alarmID string), time.Duration, string) - CancelCalled func(string) - CloseCalled func() - ResetCalled func(string) -} - -// Add - -func (a *AlarmSchedulerStub) Add(callback func(alarmID string), duration time.Duration, alarmID string) { - if a.AddCalled != nil { - a.AddCalled(callback, duration, alarmID) - } -} - -// Cancel - -func (a *AlarmSchedulerStub) Cancel(alarmID string) { - if a.CancelCalled != nil { - a.CancelCalled(alarmID) - } -} - -// Close - -func (a *AlarmSchedulerStub) Close() { - if a.CloseCalled != nil { - a.CloseCalled() - } -} - -// Reset - -func (a *AlarmSchedulerStub) Reset(alarmID string) { - if a.ResetCalled != nil { - a.ResetCalled(alarmID) - } -} - -// IsInterfaceNil - -func (a *AlarmSchedulerStub) IsInterfaceNil() bool { - return a == nil -} diff --git a/consensus/mock/consensusStateMock.go b/consensus/mock/consensusStateMock.go deleted file mode 100644 index fb4fb708449..00000000000 --- a/consensus/mock/consensusStateMock.go +++ /dev/null @@ -1,137 +0,0 @@ -package mock - -import ( - "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" -) - -// ConsensusStateMock - -type ConsensusStateMock struct { - ResetConsensusStateCalled func() - IsNodeLeaderInCurrentRoundCalled func(node string) bool - IsSelfLeaderInCurrentRoundCalled func() bool - GetLeaderCalled func() (string, error) - GetNextConsensusGroupCalled func(randomSource string, vgs nodesCoordinator.NodesCoordinator) ([]string, error) - IsConsensusDataSetCalled func() bool - IsConsensusDataEqualCalled func(data []byte) bool - IsJobDoneCalled func(node string, currentSubroundId int) bool - IsSelfJobDoneCalled func(currentSubroundId int) bool - IsCurrentSubroundFinishedCalled func(currentSubroundId int) bool - IsNodeSelfCalled func(node string) bool - IsBlockBodyAlreadyReceivedCalled func() bool - IsHeaderAlreadyReceivedCalled func() bool - CanDoSubroundJobCalled 
func(currentSubroundId int) bool - CanProcessReceivedMessageCalled func(cnsDta consensus.Message, currentRoundIndex int32, currentSubroundId int) bool - GenerateBitmapCalled func(subroundId int) []byte - ProcessingBlockCalled func() bool - SetProcessingBlockCalled func(processingBlock bool) - ConsensusGroupSizeCalled func() int - SetThresholdCalled func(subroundId int, threshold int) -} - -// ResetConsensusState - -func (cnsm *ConsensusStateMock) ResetConsensusState() { - cnsm.ResetConsensusStateCalled() -} - -// IsNodeLeaderInCurrentRound - -func (cnsm *ConsensusStateMock) IsNodeLeaderInCurrentRound(node string) bool { - return cnsm.IsNodeLeaderInCurrentRoundCalled(node) -} - -// IsSelfLeaderInCurrentRound - -func (cnsm *ConsensusStateMock) IsSelfLeaderInCurrentRound() bool { - return cnsm.IsSelfLeaderInCurrentRoundCalled() -} - -// GetLeader - -func (cnsm *ConsensusStateMock) GetLeader() (string, error) { - return cnsm.GetLeaderCalled() -} - -// GetNextConsensusGroup - -func (cnsm *ConsensusStateMock) GetNextConsensusGroup( - randomSource string, - vgs nodesCoordinator.NodesCoordinator, -) ([]string, error) { - return cnsm.GetNextConsensusGroupCalled(randomSource, vgs) -} - -// IsConsensusDataSet - -func (cnsm *ConsensusStateMock) IsConsensusDataSet() bool { - return cnsm.IsConsensusDataSetCalled() -} - -// IsConsensusDataEqual - -func (cnsm *ConsensusStateMock) IsConsensusDataEqual(data []byte) bool { - return cnsm.IsConsensusDataEqualCalled(data) -} - -// IsJobDone - -func (cnsm *ConsensusStateMock) IsJobDone(node string, currentSubroundId int) bool { - return cnsm.IsJobDoneCalled(node, currentSubroundId) -} - -// IsSelfJobDone - -func (cnsm *ConsensusStateMock) IsSelfJobDone(currentSubroundId int) bool { - return cnsm.IsSelfJobDoneCalled(currentSubroundId) -} - -// IsCurrentSubroundFinished - -func (cnsm *ConsensusStateMock) IsCurrentSubroundFinished(currentSubroundId int) bool { - return cnsm.IsCurrentSubroundFinishedCalled(currentSubroundId) -} - -// IsNodeSelf - -func (cnsm *ConsensusStateMock) IsNodeSelf(node string) bool { - return cnsm.IsNodeSelfCalled(node) -} - -// IsBlockBodyAlreadyReceived - -func (cnsm *ConsensusStateMock) IsBlockBodyAlreadyReceived() bool { - return cnsm.IsBlockBodyAlreadyReceivedCalled() -} - -// IsHeaderAlreadyReceived - -func (cnsm *ConsensusStateMock) IsHeaderAlreadyReceived() bool { - return cnsm.IsHeaderAlreadyReceivedCalled() -} - -// CanDoSubroundJob - -func (cnsm *ConsensusStateMock) CanDoSubroundJob(currentSubroundId int) bool { - return cnsm.CanDoSubroundJobCalled(currentSubroundId) -} - -// CanProcessReceivedMessage - -func (cnsm *ConsensusStateMock) CanProcessReceivedMessage( - cnsDta consensus.Message, - currentRoundIndex int32, - currentSubroundId int, -) bool { - return cnsm.CanProcessReceivedMessageCalled(cnsDta, currentRoundIndex, currentSubroundId) -} - -// GenerateBitmap - -func (cnsm *ConsensusStateMock) GenerateBitmap(subroundId int) []byte { - return cnsm.GenerateBitmapCalled(subroundId) -} - -// ProcessingBlock - -func (cnsm *ConsensusStateMock) ProcessingBlock() bool { - return cnsm.ProcessingBlockCalled() -} - -// SetProcessingBlock - -func (cnsm *ConsensusStateMock) SetProcessingBlock(processingBlock bool) { - cnsm.SetProcessingBlockCalled(processingBlock) -} - -// ConsensusGroupSize - -func (cnsm *ConsensusStateMock) ConsensusGroupSize() int { - return cnsm.ConsensusGroupSizeCalled() -} - -// SetThreshold - -func (cnsm *ConsensusStateMock) SetThreshold(subroundId int, threshold int) { - cnsm.SetThresholdCalled(subroundId, 
threshold) -} diff --git a/consensus/mock/contextMock.go b/consensus/mock/contextMock.go new file mode 100644 index 00000000000..0cdab606821 --- /dev/null +++ b/consensus/mock/contextMock.go @@ -0,0 +1,45 @@ +package mock + +import ( + "time" +) + +// ContextMock - +type ContextMock struct { + DoneFunc func() <-chan struct{} + DeadlineFunc func() (time.Time, bool) + ErrFunc func() error + ValueFunc func(key interface{}) interface{} +} + +// Done - +func (c *ContextMock) Done() <-chan struct{} { + if c.DoneFunc != nil { + return c.DoneFunc() + } + return nil +} + +// Deadline - +func (c *ContextMock) Deadline() (time.Time, bool) { + if c.DeadlineFunc != nil { + return c.DeadlineFunc() + } + return time.Time{}, false +} + +// Err - +func (c *ContextMock) Err() error { + if c.ErrFunc != nil { + return c.ErrFunc() + } + return nil +} + +// Value - +func (c *ContextMock) Value(key interface{}) interface{} { + if c.ValueFunc != nil { + return c.ValueFunc(key) + } + return nil +} diff --git a/consensus/mock/epochStartNotifierStub.go b/consensus/mock/epochStartNotifierStub.go deleted file mode 100644 index a671e0f2ead..00000000000 --- a/consensus/mock/epochStartNotifierStub.go +++ /dev/null @@ -1,65 +0,0 @@ -package mock - -import ( - "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-go/epochStart" -) - -// EpochStartNotifierStub - -type EpochStartNotifierStub struct { - RegisterHandlerCalled func(handler epochStart.ActionHandler) - UnregisterHandlerCalled func(handler epochStart.ActionHandler) - NotifyAllCalled func(hdr data.HeaderHandler) - NotifyAllPrepareCalled func(hdr data.HeaderHandler, body data.BodyHandler) - epochStartHdls []epochStart.ActionHandler -} - -// RegisterHandler - -func (esnm *EpochStartNotifierStub) RegisterHandler(handler epochStart.ActionHandler) { - if esnm.RegisterHandlerCalled != nil { - esnm.RegisterHandlerCalled(handler) - } - - esnm.epochStartHdls = append(esnm.epochStartHdls, handler) -} - -// UnregisterHandler - -func (esnm *EpochStartNotifierStub) UnregisterHandler(handler epochStart.ActionHandler) { - if esnm.UnregisterHandlerCalled != nil { - esnm.UnregisterHandlerCalled(handler) - } - - for i, hdl := range esnm.epochStartHdls { - if hdl == handler { - esnm.epochStartHdls = append(esnm.epochStartHdls[:i], esnm.epochStartHdls[i+1:]...) 
- break - } - } -} - -// NotifyAllPrepare - -func (esnm *EpochStartNotifierStub) NotifyAllPrepare(metaHdr data.HeaderHandler, body data.BodyHandler) { - if esnm.NotifyAllPrepareCalled != nil { - esnm.NotifyAllPrepareCalled(metaHdr, body) - } - - for _, hdl := range esnm.epochStartHdls { - hdl.EpochStartPrepare(metaHdr, body) - } -} - -// NotifyAll - -func (esnm *EpochStartNotifierStub) NotifyAll(hdr data.HeaderHandler) { - if esnm.NotifyAllCalled != nil { - esnm.NotifyAllCalled(hdr) - } - - for _, hdl := range esnm.epochStartHdls { - hdl.EpochStartAction(hdr) - } -} - -// IsInterfaceNil - -func (esnm *EpochStartNotifierStub) IsInterfaceNil() bool { - return esnm == nil -} diff --git a/consensus/mock/forkDetectorMock.go b/consensus/mock/forkDetectorMock.go deleted file mode 100644 index 6c1a4f70d5e..00000000000 --- a/consensus/mock/forkDetectorMock.go +++ /dev/null @@ -1,93 +0,0 @@ -package mock - -import ( - "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-go/process" -) - -// ForkDetectorMock - -type ForkDetectorMock struct { - AddHeaderCalled func(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error - RemoveHeaderCalled func(nonce uint64, hash []byte) - CheckForkCalled func() *process.ForkInfo - GetHighestFinalBlockNonceCalled func() uint64 - GetHighestFinalBlockHashCalled func() []byte - ProbableHighestNonceCalled func() uint64 - ResetForkCalled func() - GetNotarizedHeaderHashCalled func(nonce uint64) []byte - SetRollBackNonceCalled func(nonce uint64) - RestoreToGenesisCalled func() - ResetProbableHighestNonceCalled func() - SetFinalToLastCheckpointCalled func() -} - -// RestoreToGenesis - -func (fdm *ForkDetectorMock) RestoreToGenesis() { - fdm.RestoreToGenesisCalled() -} - -// AddHeader - -func (fdm *ForkDetectorMock) AddHeader(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { - return fdm.AddHeaderCalled(header, hash, state, selfNotarizedHeaders, selfNotarizedHeadersHashes) -} - -// RemoveHeader - -func (fdm *ForkDetectorMock) RemoveHeader(nonce uint64, hash []byte) { - fdm.RemoveHeaderCalled(nonce, hash) -} - -// CheckFork - -func (fdm *ForkDetectorMock) CheckFork() *process.ForkInfo { - return fdm.CheckForkCalled() -} - -// GetHighestFinalBlockNonce - -func (fdm *ForkDetectorMock) GetHighestFinalBlockNonce() uint64 { - return fdm.GetHighestFinalBlockNonceCalled() -} - -// GetHighestFinalBlockHash - -func (fdm *ForkDetectorMock) GetHighestFinalBlockHash() []byte { - return fdm.GetHighestFinalBlockHashCalled() -} - -// ProbableHighestNonce - -func (fdm *ForkDetectorMock) ProbableHighestNonce() uint64 { - return fdm.ProbableHighestNonceCalled() -} - -// SetRollBackNonce - -func (fdm *ForkDetectorMock) SetRollBackNonce(nonce uint64) { - if fdm.SetRollBackNonceCalled != nil { - fdm.SetRollBackNonceCalled(nonce) - } -} - -// ResetFork - -func (fdm *ForkDetectorMock) ResetFork() { - fdm.ResetForkCalled() -} - -// GetNotarizedHeaderHash - -func (fdm *ForkDetectorMock) GetNotarizedHeaderHash(nonce uint64) []byte { - return fdm.GetNotarizedHeaderHashCalled(nonce) -} - -// ResetProbableHighestNonce - -func (fdm *ForkDetectorMock) ResetProbableHighestNonce() { - if fdm.ResetProbableHighestNonceCalled != nil { - fdm.ResetProbableHighestNonceCalled() - } -} - -// SetFinalToLastCheckpoint - -func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { - if 
fdm.SetFinalToLastCheckpointCalled != nil { - fdm.SetFinalToLastCheckpointCalled() - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (fdm *ForkDetectorMock) IsInterfaceNil() bool { - return fdm == nil -} diff --git a/consensus/mock/headerIntegrityVerifierStub.go b/consensus/mock/headerIntegrityVerifierStub.go deleted file mode 100644 index 3d793b89924..00000000000 --- a/consensus/mock/headerIntegrityVerifierStub.go +++ /dev/null @@ -1,32 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderIntegrityVerifierStub - -type HeaderIntegrityVerifierStub struct { - VerifyCalled func(header data.HeaderHandler) error - GetVersionCalled func(epoch uint32) string -} - -// Verify - -func (h *HeaderIntegrityVerifierStub) Verify(header data.HeaderHandler) error { - if h.VerifyCalled != nil { - return h.VerifyCalled(header) - } - - return nil -} - -// GetVersion - -func (h *HeaderIntegrityVerifierStub) GetVersion(epoch uint32) string { - if h.GetVersionCalled != nil { - return h.GetVersionCalled(epoch) - } - - return "version" -} - -// IsInterfaceNil - -func (h *HeaderIntegrityVerifierStub) IsInterfaceNil() bool { - return h == nil -} diff --git a/consensus/mock/headerSigVerifierStub.go b/consensus/mock/headerSigVerifierStub.go deleted file mode 100644 index b75b5615a12..00000000000 --- a/consensus/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,52 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/consensus/mock/headersCacherStub.go b/consensus/mock/headersCacherStub.go deleted file mode 100644 index bc458a8235f..00000000000 --- a/consensus/mock/headersCacherStub.go +++ /dev/null @@ -1,105 +0,0 @@ -package mock - -import ( - "errors" - - "github.com/multiversx/mx-chain-core-go/data" -) - -// HeadersCacherStub - -type HeadersCacherStub struct { - AddCalled func(headerHash []byte, header data.HeaderHandler) - RemoveHeaderByHashCalled func(headerHash []byte) - RemoveHeaderByNonceAndShardIdCalled func(hdrNonce uint64, shardId uint32) - GetHeaderByNonceAndShardIdCalled func(hdrNonce uint64, shardId uint32) 
([]data.HeaderHandler, [][]byte, error) - GetHeaderByHashCalled func(hash []byte) (data.HeaderHandler, error) - ClearCalled func() - RegisterHandlerCalled func(handler func(header data.HeaderHandler, shardHeaderHash []byte)) - NoncesCalled func(shardId uint32) []uint64 - LenCalled func() int - MaxSizeCalled func() int - GetNumHeadersCalled func(shardId uint32) int -} - -// AddHeader - -func (hcs *HeadersCacherStub) AddHeader(headerHash []byte, header data.HeaderHandler) { - if hcs.AddCalled != nil { - hcs.AddCalled(headerHash, header) - } -} - -// RemoveHeaderByHash - -func (hcs *HeadersCacherStub) RemoveHeaderByHash(headerHash []byte) { - if hcs.RemoveHeaderByHashCalled != nil { - hcs.RemoveHeaderByHashCalled(headerHash) - } -} - -// RemoveHeaderByNonceAndShardId - -func (hcs *HeadersCacherStub) RemoveHeaderByNonceAndShardId(hdrNonce uint64, shardId uint32) { - if hcs.RemoveHeaderByNonceAndShardIdCalled != nil { - hcs.RemoveHeaderByNonceAndShardIdCalled(hdrNonce, shardId) - } -} - -// GetHeadersByNonceAndShardId - -func (hcs *HeadersCacherStub) GetHeadersByNonceAndShardId(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { - if hcs.GetHeaderByNonceAndShardIdCalled != nil { - return hcs.GetHeaderByNonceAndShardIdCalled(hdrNonce, shardId) - } - return nil, nil, errors.New("err") -} - -// GetHeaderByHash - -func (hcs *HeadersCacherStub) GetHeaderByHash(hash []byte) (data.HeaderHandler, error) { - if hcs.GetHeaderByHashCalled != nil { - return hcs.GetHeaderByHashCalled(hash) - } - return nil, nil -} - -// Clear - -func (hcs *HeadersCacherStub) Clear() { - if hcs.ClearCalled != nil { - hcs.ClearCalled() - } -} - -// RegisterHandler - -func (hcs *HeadersCacherStub) RegisterHandler(handler func(header data.HeaderHandler, shardHeaderHash []byte)) { - if hcs.RegisterHandlerCalled != nil { - hcs.RegisterHandlerCalled(handler) - } -} - -// Nonces - -func (hcs *HeadersCacherStub) Nonces(shardId uint32) []uint64 { - if hcs.NoncesCalled != nil { - return hcs.NoncesCalled(shardId) - } - return nil -} - -// Len - -func (hcs *HeadersCacherStub) Len() int { - return 0 -} - -// MaxSize - -func (hcs *HeadersCacherStub) MaxSize() int { - return 100 -} - -// IsInterfaceNil - -func (hcs *HeadersCacherStub) IsInterfaceNil() bool { - return hcs == nil -} - -// GetNumHeaders - -func (hcs *HeadersCacherStub) GetNumHeaders(shardId uint32) int { - if hcs.GetNumHeadersCalled != nil { - return hcs.GetNumHeadersCalled(shardId) - } - - return 0 -} diff --git a/consensus/mock/watchdogMock.go b/consensus/mock/watchdogMock.go index 15a153f50a0..1c026b4e8c4 100644 --- a/consensus/mock/watchdogMock.go +++ b/consensus/mock/watchdogMock.go @@ -6,10 +6,15 @@ import ( // WatchdogMock - type WatchdogMock struct { + SetCalled func(callback func(alarmID string), duration time.Duration, alarmID string) + StopCalled func(alarmID string) } // Set - func (w *WatchdogMock) Set(callback func(alarmID string), duration time.Duration, alarmID string) { + if w.SetCalled != nil { + w.SetCalled(callback, duration, alarmID) + } } // SetDefault - @@ -18,6 +23,9 @@ func (w *WatchdogMock) SetDefault(duration time.Duration, alarmID string) { // Stop - func (w *WatchdogMock) Stop(alarmID string) { + if w.StopCalled != nil { + w.StopCalled(alarmID) + } } // Reset - diff --git a/consensus/round/round_test.go b/consensus/round/round_test.go index ede509d7176..ec1f08ec82d 100644 --- a/consensus/round/round_test.go +++ b/consensus/round/round_test.go @@ -5,8 +5,10 @@ import ( "time" 
"github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/round" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/stretchr/testify/assert" ) @@ -28,7 +30,7 @@ func TestRound_NewRoundShouldWork(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, err := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) @@ -41,7 +43,7 @@ func TestRound_UpdateRoundShouldNotChangeAnything(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) oldIndex := rnd.Index() @@ -61,7 +63,7 @@ func TestRound_UpdateRoundShouldAdvanceOneRound(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) oldIndex := rnd.Index() @@ -76,7 +78,7 @@ func TestRound_IndexShouldReturnFirstIndex(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) rnd.UpdateRound(genesisTime, genesisTime.Add(roundTimeDuration/2)) @@ -90,7 +92,7 @@ func TestRound_TimeStampShouldReturnTimeStampOfTheNextRound(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) rnd.UpdateRound(genesisTime, genesisTime.Add(roundTimeDuration+roundTimeDuration/2)) @@ -104,7 +106,7 @@ func TestRound_TimeDurationShouldReturnTheDurationOfOneRound(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) timeDuration := rnd.TimeDuration() @@ -117,7 +119,7 @@ func TestRound_RemainingTimeInCurrentRoundShouldReturnPositiveValue(t *testing.T genesisTime := time.Unix(0, 0) - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} timeElapsed := int64(roundTimeDuration - 1) @@ -138,7 +140,7 @@ func TestRound_RemainingTimeInCurrentRoundShouldReturnNegativeValue(t *testing.T genesisTime := time.Unix(0, 0) - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} timeElapsed := int64(roundTimeDuration + 1) diff --git a/consensus/spos/bls/blsWorker.go b/consensus/spos/bls/blsWorker.go index 456d4e8b1d8..b8ceffe9122 100644 --- a/consensus/spos/bls/blsWorker.go +++ b/consensus/spos/bls/blsWorker.go @@ -5,7 +5,7 @@ import ( "github.com/multiversx/mx-chain-go/consensus/spos" ) -// peerMaxMessagesPerSec defines how many messages can be propagated by a pid in a round. The value was chosen by +// PeerMaxMessagesPerSec defines how many messages can be propagated by a pid in a round. The value was chosen by // following the next premises: // 1. a leader can propagate as maximum as 3 messages per round: proposed header block + proposed body + final info; // 2. 
due to the fact that a delayed signature of the proposer (from previous round) can be received in the current round @@ -16,15 +16,15 @@ import ( // // Validators only send one signature message in a round, treating the edge case of a delayed message, will need at most // 2 messages per round (which is ok as it is below the set value of 5) -const peerMaxMessagesPerSec = uint32(6) +const PeerMaxMessagesPerSec = uint32(6) -// defaultMaxNumOfMessageTypeAccepted represents the maximum number of the same message type accepted in one round to be +// DefaultMaxNumOfMessageTypeAccepted represents the maximum number of the same message type accepted in one round to be // received from the same public key for the default message types -const defaultMaxNumOfMessageTypeAccepted = uint32(1) +const DefaultMaxNumOfMessageTypeAccepted = uint32(1) -// maxNumOfMessageTypeSignatureAccepted represents the maximum number of the signature message type accepted in one round to be +// MaxNumOfMessageTypeSignatureAccepted represents the maximum number of the signature message type accepted in one round to be // received from the same public key -const maxNumOfMessageTypeSignatureAccepted = uint32(2) +const MaxNumOfMessageTypeSignatureAccepted = uint32(2) // worker defines the data needed by spos to communicate between nodes which are in the validators group type worker struct { @@ -52,17 +52,17 @@ func (wrk *worker) InitReceivedMessages() map[consensus.MessageType][]*consensus // GetMaxMessagesInARoundPerPeer returns the maximum number of messages a peer can send per round for BLS func (wrk *worker) GetMaxMessagesInARoundPerPeer() uint32 { - return peerMaxMessagesPerSec + return PeerMaxMessagesPerSec } // GetStringValue gets the name of the messageType func (wrk *worker) GetStringValue(messageType consensus.MessageType) string { - return getStringValue(messageType) + return GetStringValue(messageType) } // GetSubroundName gets the subround name for the subround id provided func (wrk *worker) GetSubroundName(subroundId int) string { - return getSubroundName(subroundId) + return GetSubroundName(subroundId) } // IsMessageWithBlockBodyAndHeader returns if the current messageType is about block body and header @@ -151,10 +151,10 @@ func (wrk *worker) CanProceed(consensusState *spos.ConsensusState, msgType conse // GetMaxNumOfMessageTypeAccepted returns the maximum number of accepted consensus message types per round, per public key func (wrk *worker) GetMaxNumOfMessageTypeAccepted(msgType consensus.MessageType) uint32 { if msgType == MtSignature { - return maxNumOfMessageTypeSignatureAccepted + return MaxNumOfMessageTypeSignatureAccepted } - return defaultMaxNumOfMessageTypeAccepted + return DefaultMaxNumOfMessageTypeAccepted } // IsInterfaceNil returns true if there is no value under the interface diff --git a/consensus/spos/bls/blsWorker_test.go b/consensus/spos/bls/blsWorker_test.go index 6786b96cde8..8d39b02e5f1 100644 --- a/consensus/spos/bls/blsWorker_test.go +++ b/consensus/spos/bls/blsWorker_test.go @@ -4,68 +4,14 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" - "github.com/multiversx/mx-chain-go/testscommon" - "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" ) -func createEligibleList(size int) []string { - eligibleList := 
make([]string, 0) - for i := 0; i < size; i++ { - eligibleList = append(eligibleList, string([]byte{byte(i + 65)})) - } - return eligibleList -} - -func initConsensusState() *spos.ConsensusState { - return initConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) -} - -func initConsensusStateWithKeysHandler(keysHandler consensus.KeysHandler) *spos.ConsensusState { - consensusGroupSize := 9 - eligibleList := createEligibleList(consensusGroupSize) - - eligibleNodesPubKeys := make(map[string]struct{}) - for _, key := range eligibleList { - eligibleNodesPubKeys[key] = struct{}{} - } - - indexLeader := 1 - rcns, _ := spos.NewRoundConsensus( - eligibleNodesPubKeys, - consensusGroupSize, - eligibleList[indexLeader], - keysHandler, - ) - - rcns.SetConsensusGroup(eligibleList) - rcns.ResetRoundState() - - pBFTThreshold := consensusGroupSize*2/3 + 1 - pBFTFallbackThreshold := consensusGroupSize*1/2 + 1 - - rthr := spos.NewRoundThreshold() - rthr.SetThreshold(1, 1) - rthr.SetThreshold(2, pBFTThreshold) - rthr.SetFallbackThreshold(1, 1) - rthr.SetFallbackThreshold(2, pBFTFallbackThreshold) - - rstatus := spos.NewRoundStatus() - rstatus.ResetRoundStatus() - - cns := spos.NewConsensusState( - rcns, - rthr, - rstatus, - ) - - cns.Data = []byte("X") - cns.RoundIndex = 0 - return cns -} - func TestWorker_NewConsensusServiceShouldWork(t *testing.T) { t.Parallel() @@ -121,7 +67,7 @@ func TestWorker_CanProceedWithSrStartRoundFinishedForMtBlockBodyAndHeaderShouldW blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockBodyAndHeader) @@ -133,7 +79,7 @@ func TestWorker_CanProceedWithSrStartRoundNotFinishedForMtBlockBodyAndHeaderShou blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockBodyAndHeader) @@ -145,7 +91,7 @@ func TestWorker_CanProceedWithSrStartRoundFinishedForMtBlockBodyShouldWork(t *te blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockBody) @@ -157,7 +103,7 @@ func TestWorker_CanProceedWithSrStartRoundNotFinishedForMtBlockBodyShouldNotWork blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockBody) @@ -169,7 +115,7 @@ func TestWorker_CanProceedWithSrStartRoundFinishedForMtBlockHeaderShouldWork(t * blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockHeader) @@ -181,7 +127,7 @@ func TestWorker_CanProceedWithSrStartRoundNotFinishedForMtBlockHeaderShouldNotWo blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsNotFinished) canProceed := 
blsService.CanProceed(consensusState, bls.MtBlockHeader) @@ -193,7 +139,7 @@ func TestWorker_CanProceedWithSrBlockFinishedForMtBlockHeaderShouldWork(t *testi blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrBlock, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtSignature) @@ -205,7 +151,7 @@ func TestWorker_CanProceedWithSrBlockRoundNotFinishedForMtBlockHeaderShouldNotWo blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrBlock, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtSignature) @@ -217,7 +163,7 @@ func TestWorker_CanProceedWithSrSignatureFinishedForMtBlockHeaderFinalInfoShould blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrSignature, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockHeaderFinalInfo) @@ -229,7 +175,7 @@ func TestWorker_CanProceedWithSrSignatureRoundNotFinishedForMtBlockHeaderFinalIn blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrSignature, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockHeaderFinalInfo) @@ -240,7 +186,7 @@ func TestWorker_CanProceedWitUnkownMessageTypeShouldNotWork(t *testing.T) { t.Parallel() blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() canProceed := blsService.CanProceed(consensusState, -1) assert.False(t, canProceed) diff --git a/consensus/spos/bls/constants.go b/consensus/spos/bls/constants.go index 166abe70b65..88667da3003 100644 --- a/consensus/spos/bls/constants.go +++ b/consensus/spos/bls/constants.go @@ -2,11 +2,8 @@ package bls import ( "github.com/multiversx/mx-chain-go/consensus" - logger "github.com/multiversx/mx-chain-logger-go" ) -var log = logger.GetOrCreate("consensus/spos/bls") - const ( // SrStartRound defines ID of Subround "Start round" SrStartRound = iota @@ -36,36 +33,6 @@ const ( MtInvalidSigners ) -// waitingAllSigsMaxTimeThreshold specifies the max allocated time for waiting all signatures from the total time of the subround signature -const waitingAllSigsMaxTimeThreshold = 0.5 - -// processingThresholdPercent specifies the max allocated time for processing the block as a percentage of the total time of the round -const processingThresholdPercent = 85 - -// srStartStartTime specifies the start time, from the total time of the round, of Subround Start -const srStartStartTime = 0.0 - -// srEndStartTime specifies the end time, from the total time of the round, of Subround Start -const srStartEndTime = 0.05 - -// srBlockStartTime specifies the start time, from the total time of the round, of Subround Block -const srBlockStartTime = 0.05 - -// srBlockEndTime specifies the end time, from the total time of the round, of Subround Block -const srBlockEndTime = 0.25 - -// srSignatureStartTime specifies the start time, from the total time of the round, of Subround Signature -const srSignatureStartTime = 0.25 - -// srSignatureEndTime specifies the end time, from the total time of the round, of Subround Signature -const srSignatureEndTime = 0.85 - -// 
srEndStartTime specifies the start time, from the total time of the round, of Subround End -const srEndStartTime = 0.85 - -// srEndEndTime specifies the end time, from the total time of the round, of Subround End -const srEndEndTime = 0.95 - const ( // BlockBodyAndHeaderStringValue represents the string to be used to identify a block body and a block header BlockBodyAndHeaderStringValue = "(BLOCK_BODY_AND_HEADER)" @@ -89,7 +56,8 @@ const ( BlockDefaultStringValue = "Undefined message type" ) -func getStringValue(msgType consensus.MessageType) string { +// GetStringValue returns the string value of a given MessageType +func GetStringValue(msgType consensus.MessageType) string { switch msgType { case MtBlockBodyAndHeader: return BlockBodyAndHeaderStringValue @@ -108,8 +76,8 @@ func getStringValue(msgType consensus.MessageType) string { } } -// getSubroundName returns the name of each Subround from a given Subround ID -func getSubroundName(subroundId int) string { +// GetSubroundName returns the name of each Subround from a given Subround ID +func GetSubroundName(subroundId int) string { switch subroundId { case SrStartRound: return "(START_ROUND)" diff --git a/consensus/spos/bls/proxy/errors.go b/consensus/spos/bls/proxy/errors.go new file mode 100644 index 00000000000..4036ecf1c63 --- /dev/null +++ b/consensus/spos/bls/proxy/errors.go @@ -0,0 +1,38 @@ +package proxy + +import ( + "errors" +) + +// ErrNilChronologyHandler is the error returned when the chronology handler is nil +var ErrNilChronologyHandler = errors.New("nil chronology handler") + +// ErrNilConsensusCoreHandler is the error returned when the consensus core handler is nil +var ErrNilConsensusCoreHandler = errors.New("nil consensus core handler") + +// ErrNilConsensusState is the error returned when the consensus state is nil +var ErrNilConsensusState = errors.New("nil consensus state") + +// ErrNilWorker is the error returned when the worker is nil +var ErrNilWorker = errors.New("nil worker") + +// ErrNilSignatureThrottler is the error returned when the signature throttler is nil +var ErrNilSignatureThrottler = errors.New("nil signature throttler") + +// ErrNilAppStatusHandler is the error returned when the app status handler is nil +var ErrNilAppStatusHandler = errors.New("nil app status handler") + +// ErrNilOutportHandler is the error returned when the outport handler is nil +var ErrNilOutportHandler = errors.New("nil outport handler") + +// ErrNilSentSignatureTracker is the error returned when the sent signature tracker is nil +var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") + +// ErrNilChainID is the error returned when the chain ID is nil +var ErrNilChainID = errors.New("nil chain ID") + +// ErrNilCurrentPid is the error returned when the current PID is nil +var ErrNilCurrentPid = errors.New("nil current PID") + +// ErrNilEnableEpochsHandler is the error returned when the enable epochs handler is nil +var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") diff --git a/consensus/spos/bls/proxy/subroundsHandler.go b/consensus/spos/bls/proxy/subroundsHandler.go new file mode 100644 index 00000000000..2b284db5144 --- /dev/null +++ b/consensus/spos/bls/proxy/subroundsHandler.go @@ -0,0 +1,217 @@ +package proxy + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" + 
"github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + "github.com/multiversx/mx-chain-go/factory" + "github.com/multiversx/mx-chain-go/outport" +) + +var log = logger.GetOrCreate("consensus/spos/bls/proxy") + +// SubroundsHandlerArgs struct contains the needed data for the SubroundsHandler +type SubroundsHandlerArgs struct { + Chronology consensus.ChronologyHandler + ConsensusCoreHandler spos.ConsensusCoreHandler + ConsensusState spos.ConsensusStateHandler + Worker factory.ConsensusWorker + SignatureThrottler core.Throttler + AppStatusHandler core.AppStatusHandler + OutportHandler outport.OutportHandler + SentSignatureTracker spos.SentSignaturesTracker + EnableEpochsHandler core.EnableEpochsHandler + ChainID []byte + CurrentPid core.PeerID +} + +// subroundsFactory defines the methods needed to generate the subrounds +type subroundsFactory interface { + GenerateSubrounds() error + SetOutportHandler(driver outport.OutportHandler) + IsInterfaceNil() bool +} + +type consensusStateMachineType int + +// SubroundsHandler struct contains the needed data for the SubroundsHandler +type SubroundsHandler struct { + chronology consensus.ChronologyHandler + consensusCoreHandler spos.ConsensusCoreHandler + consensusState spos.ConsensusStateHandler + worker factory.ConsensusWorker + signatureThrottler core.Throttler + appStatusHandler core.AppStatusHandler + outportHandler outport.OutportHandler + sentSignatureTracker spos.SentSignaturesTracker + enableEpochsHandler core.EnableEpochsHandler + chainID []byte + currentPid core.PeerID + currentConsensusType consensusStateMachineType +} + +const ( + consensusNone consensusStateMachineType = iota + consensusV1 + consensusV2 +) + +// NewSubroundsHandler creates a new SubroundsHandler object +func NewSubroundsHandler(args *SubroundsHandlerArgs) (*SubroundsHandler, error) { + err := checkArgs(args) + if err != nil { + return nil, err + } + + subroundHandler := &SubroundsHandler{ + chronology: args.Chronology, + consensusCoreHandler: args.ConsensusCoreHandler, + consensusState: args.ConsensusState, + worker: args.Worker, + signatureThrottler: args.SignatureThrottler, + appStatusHandler: args.AppStatusHandler, + outportHandler: args.OutportHandler, + sentSignatureTracker: args.SentSignatureTracker, + enableEpochsHandler: args.EnableEpochsHandler, + chainID: args.ChainID, + currentPid: args.CurrentPid, + currentConsensusType: consensusNone, + } + + subroundHandler.consensusCoreHandler.EpochStartRegistrationHandler().RegisterHandler(subroundHandler) + + return subroundHandler, nil +} + +func checkArgs(args *SubroundsHandlerArgs) error { + if check.IfNil(args.Chronology) { + return ErrNilChronologyHandler + } + if check.IfNil(args.ConsensusCoreHandler) { + return ErrNilConsensusCoreHandler + } + if check.IfNil(args.ConsensusState) { + return ErrNilConsensusState + } + if check.IfNil(args.Worker) { + return ErrNilWorker + } + if check.IfNil(args.SignatureThrottler) { + return ErrNilSignatureThrottler + } + if check.IfNil(args.AppStatusHandler) { + return ErrNilAppStatusHandler + } + if check.IfNil(args.OutportHandler) { + return ErrNilOutportHandler + } + if check.IfNil(args.SentSignatureTracker) { + return ErrNilSentSignatureTracker + } + if check.IfNil(args.EnableEpochsHandler) { + return ErrNilEnableEpochsHandler + } + if args.ChainID == nil { + return ErrNilChainID + } + if len(args.CurrentPid) 
== 0 { + return ErrNilCurrentPid + } + // outport handler can be nil if not configured so no need to check it + + return nil +} + +// Start starts the sub-rounds handler +func (s *SubroundsHandler) Start(epoch uint32) error { + return s.initSubroundsForEpoch(epoch) +} + +func (s *SubroundsHandler) initSubroundsForEpoch(epoch uint32) error { + var err error + var fct subroundsFactory + if s.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, epoch) { + if s.currentConsensusType == consensusV2 { + return nil + } + + s.currentConsensusType = consensusV2 + fct, err = v2.NewSubroundsFactory( + s.consensusCoreHandler, + s.consensusState, + s.worker, + s.chainID, + s.currentPid, + s.appStatusHandler, + s.sentSignatureTracker, + s.signatureThrottler, + s.outportHandler, + ) + } else { + if s.currentConsensusType == consensusV1 { + return nil + } + + s.currentConsensusType = consensusV1 + fct, err = v1.NewSubroundsFactory( + s.consensusCoreHandler, + s.consensusState, + s.worker, + s.chainID, + s.currentPid, + s.appStatusHandler, + s.sentSignatureTracker, + s.outportHandler, + ) + } + if err != nil { + return err + } + + err = s.chronology.Close() + if err != nil { + log.Warn("SubroundsHandler.initSubroundsForEpoch: cannot close the chronology", "error", err) + } + + err = fct.GenerateSubrounds() + if err != nil { + return err + } + + s.chronology.StartRounds() + return nil +} + +// EpochStartAction is called when the epoch starts +func (s *SubroundsHandler) EpochStartAction(hdr data.HeaderHandler) { + if check.IfNil(hdr) { + log.Error("SubroundsHandler.EpochStartAction: nil header") + return + } + + err := s.initSubroundsForEpoch(hdr.GetEpoch()) + if err != nil { + log.Error("SubroundsHandler.EpochStartAction: cannot initialize subrounds", "error", err) + } +} + +// EpochStartPrepare prepares the subrounds handler for the epoch start +func (s *SubroundsHandler) EpochStartPrepare(_ data.HeaderHandler, _ data.BodyHandler) { +} + +// NotifyOrder returns the order of the subrounds handler +func (s *SubroundsHandler) NotifyOrder() uint32 { + return common.ConsensusHandlerOrder +} + +// IsInterfaceNil returns true if there is no value under the interface +func (s *SubroundsHandler) IsInterfaceNil() bool { + return s == nil +} diff --git a/consensus/spos/bls/proxy/subroundsHandler_test.go b/consensus/spos/bls/proxy/subroundsHandler_test.go new file mode 100644 index 00000000000..403dc2c7826 --- /dev/null +++ b/consensus/spos/bls/proxy/subroundsHandler_test.go @@ -0,0 +1,468 @@ +package proxy + +import ( + "sync/atomic" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/require" + + chainCommon "github.com/multiversx/mx-chain-go/common" + mock2 "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" + "github.com/multiversx/mx-chain-go/testscommon/common" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + mock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" + outportStub "github.com/multiversx/mx-chain-go/testscommon/outport" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + 
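[Editor's note] The tests that follow exercise the version-switching behaviour implemented by initSubroundsForEpoch above: the proxy keeps the running subrounds when the required consensus version is unchanged, and only closes the chronology and regenerates the subrounds when the EquivalentMessagesFlag activation moves it between v1 and v2. A minimal, self-contained sketch of that selection logic, using hypothetical stand-in types (consensusVersion, switcher, flagEnabled) rather than the real mx-chain-go ones, could look like this:

package main

import "fmt"

type consensusVersion int

const (
	versionNone consensusVersion = iota
	versionV1
	versionV2
)

// flagEnabled stands in for enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, epoch).
type flagEnabled func(epoch uint32) bool

// switcher mimics the small state machine kept in SubroundsHandler.currentConsensusType.
type switcher struct {
	current   consensusVersion
	isEnabled flagEnabled
}

// selectVersion mirrors the early-return behaviour of initSubroundsForEpoch: it returns the
// version required for the given epoch and whether the subrounds would need to be regenerated.
func (s *switcher) selectVersion(epoch uint32) (consensusVersion, bool) {
	wanted := versionV1
	if s.isEnabled(epoch) {
		wanted = versionV2
	}
	if s.current == wanted {
		// same version as the one already running: keep the existing chronology and subrounds
		return wanted, false
	}
	s.current = wanted
	// the real handler would now close the chronology, generate the new subrounds and restart the rounds
	return wanted, true
}

func main() {
	// assume the flag activates at epoch 4 (an arbitrary choice for this example)
	s := &switcher{current: versionNone, isEnabled: func(epoch uint32) bool { return epoch >= 4 }}
	for _, epoch := range []uint32{0, 3, 4, 7} {
		v, regenerate := s.selectVersion(epoch)
		fmt.Printf("epoch %d -> consensus v%d, regenerate subrounds: %v\n", epoch, v, regenerate)
	}
}

This matches the expectations asserted in the tests below, where the chronology's StartRounds is counted only when the consensus type actually changes for the new epoch.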
+func getDefaultArgumentsSubroundHandler() (*SubroundsHandlerArgs, *consensus.ConsensusCoreMock) { + x := make(chan bool) + chronology := &consensus.ChronologyHandlerMock{} + epochsEnable := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} + epochStartNotifier := &mock.EpochStartNotifierStub{} + consensusState := &consensus.ConsensusStateMock{} + worker := &consensus.SposWorkerMock{ + RemoveAllReceivedMessagesCallsCalled: func() {}, + GetConsensusStateChangedChannelsCalled: func() chan bool { + return x + }, + } + antiFloodHandler := &mock2.P2PAntifloodHandlerStub{} + handlerArgs := &SubroundsHandlerArgs{ + Chronology: chronology, + ConsensusState: consensusState, + Worker: worker, + SignatureThrottler: &common.ThrottlerStub{}, + AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, + OutportHandler: &outportStub.OutportStub{}, + SentSignatureTracker: &testscommon.SentSignatureTrackerStub{}, + EnableEpochsHandler: epochsEnable, + ChainID: []byte("chainID"), + CurrentPid: "peerID", + } + + consensusCore := &consensus.ConsensusCoreMock{} + consensusCore.SetEpochStartNotifier(epochStartNotifier) + consensusCore.SetBlockchain(&testscommon.ChainHandlerStub{}) + consensusCore.SetBlockProcessor(&testscommon.BlockProcessorStub{}) + consensusCore.SetBootStrapper(&bootstrapperStubs.BootstrapperStub{}) + consensusCore.SetBroadcastMessenger(&consensus.BroadcastMessengerMock{}) + consensusCore.SetChronology(chronology) + consensusCore.SetAntifloodHandler(antiFloodHandler) + consensusCore.SetHasher(&testscommon.HasherStub{}) + consensusCore.SetMarshalizer(&testscommon.MarshallerStub{}) + consensusCore.SetMultiSignerContainer(&cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return &cryptoMocks.MultisignerMock{}, nil + }, + }) + consensusCore.SetRoundHandler(&consensus.RoundHandlerMock{}) + consensusCore.SetShardCoordinator(&testscommon.ShardsCoordinatorMock{}) + consensusCore.SetSyncTimer(&testscommon.SyncTimerStub{}) + consensusCore.SetValidatorGroupSelector(&shardingMocks.NodesCoordinatorMock{}) + consensusCore.SetPeerHonestyHandler(&testscommon.PeerHonestyHandlerStub{}) + consensusCore.SetHeaderSigVerifier(&consensus.HeaderSigVerifierMock{}) + consensusCore.SetFallbackHeaderValidator(&testscommon.FallBackHeaderValidatorStub{}) + consensusCore.SetNodeRedundancyHandler(&mock2.NodeRedundancyHandlerStub{}) + consensusCore.SetScheduledProcessor(&consensus.ScheduledProcessorStub{}) + consensusCore.SetMessageSigningHandler(&mock2.MessageSigningHandlerStub{}) + consensusCore.SetPeerBlacklistHandler(&mock2.PeerBlacklistHandlerStub{}) + consensusCore.SetSigningHandler(&consensus.SigningHandlerStub{}) + consensusCore.SetEnableEpochsHandler(epochsEnable) + consensusCore.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{}) + handlerArgs.ConsensusCoreHandler = consensusCore + + return handlerArgs, consensusCore +} + +func TestNewSubroundsHandler(t *testing.T) { + t.Parallel() + + t.Run("nil chronology should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.Chronology = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilChronologyHandler, err) + require.Nil(t, sh) + }) + t.Run("nil consensus core should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.ConsensusCoreHandler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilConsensusCoreHandler, err) + 
require.Nil(t, sh) + }) + t.Run("nil consensus state should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.ConsensusState = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilConsensusState, err) + require.Nil(t, sh) + }) + t.Run("nil worker should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.Worker = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilWorker, err) + require.Nil(t, sh) + }) + t.Run("nil signature throttler should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.SignatureThrottler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilSignatureThrottler, err) + require.Nil(t, sh) + }) + t.Run("nil app status handler should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.AppStatusHandler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilAppStatusHandler, err) + require.Nil(t, sh) + }) + t.Run("nil outport handler should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.OutportHandler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilOutportHandler, err) + require.Nil(t, sh) + }) + t.Run("nil sent signature tracker should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.SentSignatureTracker = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilSentSignatureTracker, err) + require.Nil(t, sh) + }) + t.Run("nil enable epochs handler should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.EnableEpochsHandler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilEnableEpochsHandler, err) + require.Nil(t, sh) + }) + t.Run("nil chain ID should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.ChainID = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilChainID, err) + require.Nil(t, sh) + }) + t.Run("empty current PID should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.CurrentPid = "" + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilCurrentPid, err) + require.Nil(t, sh) + }) + t.Run("OK", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + }) +} + +func TestSubroundsHandler_initSubroundsForEpoch(t *testing.T) { + t.Parallel() + + t.Run("equivalent messages not enabled, with previous consensus type not consensusV1", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + 
consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + sh.currentConsensusType = consensusNone + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV1, sh.currentConsensusType) + require.Equal(t, int32(1), startCalled.Load()) + }) + t.Run("equivalent messages not enabled, with previous consensus type consensusV1", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + sh.currentConsensusType = consensusV1 + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV1, sh.currentConsensusType) + require.Equal(t, int32(0), startCalled.Load()) + }) + t.Run("equivalent messages enabled, with previous consensus type consensusNone", func(t *testing.T) { + t.Parallel() + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + sh.currentConsensusType = consensusNone + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV2, sh.currentConsensusType) + require.Equal(t, int32(1), startCalled.Load()) + }) + t.Run("equivalent messages enabled, with previous consensus type consensusV1", func(t *testing.T) { + t.Parallel() + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + sh.currentConsensusType = consensusV1 + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV2, sh.currentConsensusType) + require.Equal(t, int32(1), startCalled.Load()) + }) + t.Run("equivalent messages enabled, with previous consensus type consensusV2", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := 
getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + sh.currentConsensusType = consensusV2 + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV2, sh.currentConsensusType) + require.Equal(t, int32(0), startCalled.Load()) + }) +} + +func TestSubroundsHandler_Start(t *testing.T) { + t.Parallel() + + // the Start is tested via initSubroundsForEpoch, adding one of the test cases here as well + t.Run("equivalent messages not enabled, with previous consensus type not consensusV1", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + sh.currentConsensusType = consensusNone + + err = sh.Start(0) + require.Nil(t, err) + require.Equal(t, consensusV1, sh.currentConsensusType) + require.Equal(t, int32(1), startCalled.Load()) + }) +} + +func TestSubroundsHandler_NotifyOrder(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + + order := sh.NotifyOrder() + require.Equal(t, uint32(chainCommon.ConsensusHandlerOrder), order) +} + +func TestSubroundsHandler_IsInterfaceNil(t *testing.T) { + t.Parallel() + + t.Run("nil handler", func(t *testing.T) { + t.Parallel() + + var sh *SubroundsHandler + require.True(t, sh.IsInterfaceNil()) + }) + t.Run("not nil handler", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + + require.False(t, sh.IsInterfaceNil()) + }) +} + +func TestSubroundsHandler_EpochStartAction(t *testing.T) { + t.Parallel() + + t.Run("nil handler does not panic", func(t *testing.T) { + t.Parallel() + + defer func() { + if r := recover(); r != nil { + t.Errorf("The code panicked") + } + }() + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + sh.EpochStartAction(&testscommon.HeaderHandlerStub{}) + }) + + // tested through initSubroundsForEpoch + t.Run("OK", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := 
&enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + + sh.currentConsensusType = consensusNone + sh.EpochStartAction(&testscommon.HeaderHandlerStub{}) + require.Nil(t, err) + require.Equal(t, consensusV1, sh.currentConsensusType) + require.Equal(t, int32(1), startCalled.Load()) + }) +} diff --git a/consensus/spos/bls/blsSubroundsFactory.go b/consensus/spos/bls/v1/blsSubroundsFactory.go similarity index 80% rename from consensus/spos/bls/blsSubroundsFactory.go rename to consensus/spos/bls/v1/blsSubroundsFactory.go index aeb64a5775a..70915c5f30b 100644 --- a/consensus/spos/bls/blsSubroundsFactory.go +++ b/consensus/spos/bls/v1/blsSubroundsFactory.go @@ -1,11 +1,13 @@ -package bls +package v1 import ( "time" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/outport" ) @@ -13,7 +15,7 @@ import ( // functionality type factory struct { consensusCore spos.ConsensusCoreHandler - consensusState *spos.ConsensusState + consensusState spos.ConsensusStateHandler worker spos.WorkerHandler appStatusHandler core.AppStatusHandler @@ -26,13 +28,15 @@ type factory struct { // NewSubroundsFactory creates a new consensusState object func NewSubroundsFactory( consensusDataContainer spos.ConsensusCoreHandler, - consensusState *spos.ConsensusState, + consensusState spos.ConsensusStateHandler, worker spos.WorkerHandler, chainID []byte, currentPid core.PeerID, appStatusHandler core.AppStatusHandler, sentSignaturesTracker spos.SentSignaturesTracker, + outportHandler outport.OutportHandler, ) (*factory, error) { + // no need to check the outportHandler, it can be nil err := checkNewFactoryParams( consensusDataContainer, consensusState, @@ -53,6 +57,7 @@ func NewSubroundsFactory( chainID: chainID, currentPid: currentPid, sentSignaturesTracker: sentSignaturesTracker, + outportHandler: outportHandler, } return &fct, nil @@ -60,7 +65,7 @@ func NewSubroundsFactory( func checkNewFactoryParams( container spos.ConsensusCoreHandler, - state *spos.ConsensusState, + state spos.ConsensusStateHandler, worker spos.WorkerHandler, chainID []byte, appStatusHandler core.AppStatusHandler, @@ -70,7 +75,7 @@ func checkNewFactoryParams( if err != nil { return err } - if state == nil { + if check.IfNil(state) { return spos.ErrNilConsensusState } if check.IfNil(worker) { @@ -130,11 +135,11 @@ func (fct *factory) getTimeDuration() time.Duration { func (fct *factory) generateStartRoundSubround() error { subround, err := spos.NewSubround( -1, - SrStartRound, - SrBlock, + bls.SrStartRound, + bls.SrBlock, int64(float64(fct.getTimeDuration())*srStartStartTime), int64(float64(fct.getTimeDuration())*srStartEndTime), - getSubroundName(SrStartRound), + bls.GetSubroundName(bls.SrStartRound), fct.consensusState, fct.worker.GetConsensusStateChangedChannel(), fct.worker.ExecuteStoredMessages, @@ -171,12 +176,12 @@ func (fct *factory) generateStartRoundSubround() error { func (fct *factory) generateBlockSubround() error { subround, err := spos.NewSubround( - SrStartRound, - SrBlock, - 
SrSignature, + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, int64(float64(fct.getTimeDuration())*srBlockStartTime), int64(float64(fct.getTimeDuration())*srBlockEndTime), - getSubroundName(SrBlock), + bls.GetSubroundName(bls.SrBlock), fct.consensusState, fct.worker.GetConsensusStateChangedChannel(), fct.worker.ExecuteStoredMessages, @@ -198,9 +203,9 @@ func (fct *factory) generateBlockSubround() error { return err } - fct.worker.AddReceivedMessageCall(MtBlockBodyAndHeader, subroundBlockInstance.receivedBlockBodyAndHeader) - fct.worker.AddReceivedMessageCall(MtBlockBody, subroundBlockInstance.receivedBlockBody) - fct.worker.AddReceivedMessageCall(MtBlockHeader, subroundBlockInstance.receivedBlockHeader) + fct.worker.AddReceivedMessageCall(bls.MtBlockBodyAndHeader, subroundBlockInstance.receivedBlockBodyAndHeader) + fct.worker.AddReceivedMessageCall(bls.MtBlockBody, subroundBlockInstance.receivedBlockBody) + fct.worker.AddReceivedMessageCall(bls.MtBlockHeader, subroundBlockInstance.receivedBlockHeader) fct.consensusCore.Chronology().AddSubround(subroundBlockInstance) return nil @@ -208,12 +213,12 @@ func (fct *factory) generateBlockSubround() error { func (fct *factory) generateSignatureSubround() error { subround, err := spos.NewSubround( - SrBlock, - SrSignature, - SrEndRound, + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, int64(float64(fct.getTimeDuration())*srSignatureStartTime), int64(float64(fct.getTimeDuration())*srSignatureEndTime), - getSubroundName(SrSignature), + bls.GetSubroundName(bls.SrSignature), fct.consensusState, fct.worker.GetConsensusStateChangedChannel(), fct.worker.ExecuteStoredMessages, @@ -236,7 +241,7 @@ func (fct *factory) generateSignatureSubround() error { return err } - fct.worker.AddReceivedMessageCall(MtSignature, subroundSignatureObject.receivedSignature) + fct.worker.AddReceivedMessageCall(bls.MtSignature, subroundSignatureObject.receivedSignature) fct.consensusCore.Chronology().AddSubround(subroundSignatureObject) return nil @@ -244,12 +249,12 @@ func (fct *factory) generateSignatureSubround() error { func (fct *factory) generateEndRoundSubround() error { subround, err := spos.NewSubround( - SrSignature, - SrEndRound, + bls.SrSignature, + bls.SrEndRound, -1, int64(float64(fct.getTimeDuration())*srEndStartTime), int64(float64(fct.getTimeDuration())*srEndEndTime), - getSubroundName(SrEndRound), + bls.GetSubroundName(bls.SrEndRound), fct.consensusState, fct.worker.GetConsensusStateChangedChannel(), fct.worker.ExecuteStoredMessages, @@ -274,8 +279,8 @@ func (fct *factory) generateEndRoundSubround() error { return err } - fct.worker.AddReceivedMessageCall(MtBlockHeaderFinalInfo, subroundEndRoundObject.receivedBlockHeaderFinalInfo) - fct.worker.AddReceivedMessageCall(MtInvalidSigners, subroundEndRoundObject.receivedInvalidSignersInfo) + fct.worker.AddReceivedMessageCall(bls.MtBlockHeaderFinalInfo, subroundEndRoundObject.receivedBlockHeaderFinalInfo) + fct.worker.AddReceivedMessageCall(bls.MtInvalidSigners, subroundEndRoundObject.receivedInvalidSignersInfo) fct.worker.AddReceivedHeaderHandler(subroundEndRoundObject.receivedHeader) fct.consensusCore.Chronology().AddSubround(subroundEndRoundObject) @@ -285,10 +290,10 @@ func (fct *factory) generateEndRoundSubround() error { func (fct *factory) initConsensusThreshold() { pBFTThreshold := core.GetPBFTThreshold(fct.consensusState.ConsensusGroupSize()) pBFTFallbackThreshold := core.GetPBFTFallbackThreshold(fct.consensusState.ConsensusGroupSize()) - fct.consensusState.SetThreshold(SrBlock, 1) - 
fct.consensusState.SetThreshold(SrSignature, pBFTThreshold) - fct.consensusState.SetFallbackThreshold(SrBlock, 1) - fct.consensusState.SetFallbackThreshold(SrSignature, pBFTFallbackThreshold) + fct.consensusState.SetThreshold(bls.SrBlock, 1) + fct.consensusState.SetThreshold(bls.SrSignature, pBFTThreshold) + fct.consensusState.SetFallbackThreshold(bls.SrBlock, 1) + fct.consensusState.SetFallbackThreshold(bls.SrSignature, pBFTFallbackThreshold) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/consensus/spos/bls/blsSubroundsFactory_test.go b/consensus/spos/bls/v1/blsSubroundsFactory_test.go similarity index 74% rename from consensus/spos/bls/blsSubroundsFactory_test.go rename to consensus/spos/bls/v1/blsSubroundsFactory_test.go index af3267a78cc..f057daae16f 100644 --- a/consensus/spos/bls/blsSubroundsFactory_test.go +++ b/consensus/spos/bls/v1/blsSubroundsFactory_test.go @@ -1,4 +1,4 @@ -package bls_test +package v1_test import ( "context" @@ -8,15 +8,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/outport" "github.com/multiversx/mx-chain-go/testscommon" + consensusMock "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" testscommonOutport "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) var chainID = []byte("chain ID") @@ -40,8 +43,8 @@ func executeStoredMessages() { func resetConsensusMessages() { } -func initRoundHandlerMock() *mock.RoundHandlerMock { - return &mock.RoundHandlerMock{ +func initRoundHandlerMock() *consensusMock.RoundHandlerMock { + return &consensusMock.RoundHandlerMock{ RoundIndex: 0, TimeStampCalled: func() time.Time { return time.Unix(0, 0) @@ -53,7 +56,7 @@ func initRoundHandlerMock() *mock.RoundHandlerMock { } func initWorker() spos.WorkerHandler { - sposWorker := &mock.SposWorkerMock{} + sposWorker := &consensusMock.SposWorkerMock{} sposWorker.GetConsensusStateChangedChannelsCalled = func() chan bool { return make(chan bool) } @@ -66,11 +69,11 @@ func initWorker() spos.WorkerHandler { return sposWorker } -func initFactoryWithContainer(container *mock.ConsensusCoreMock) bls.Factory { +func initFactoryWithContainer(container *consensusMock.ConsensusCoreMock) v1.Factory { worker := initWorker() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() - fct, _ := bls.NewSubroundsFactory( + fct, _ := v1.NewSubroundsFactory( container, consensusState, worker, @@ -78,13 +81,14 @@ func initFactoryWithContainer(container *mock.ConsensusCoreMock) bls.Factory { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) return fct } -func initFactory() bls.Factory { - container := mock.InitConsensusCore() +func initFactory() v1.Factory { + container := consensusMock.InitConsensusCore() return initFactoryWithContainer(container) } @@ -116,10 +120,10 @@ func TestFactory_GetMessageTypeName(t *testing.T) { func TestFactory_NewFactoryNilContainerShouldFail(t *testing.T) { 
t.Parallel() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( nil, consensusState, worker, @@ -127,6 +131,7 @@ func TestFactory_NewFactoryNilContainerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -136,10 +141,10 @@ func TestFactory_NewFactoryNilContainerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, nil, worker, @@ -147,6 +152,7 @@ func TestFactory_NewFactoryNilConsensusStateShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -156,12 +162,12 @@ func TestFactory_NewFactoryNilConsensusStateShouldFail(t *testing.T) { func TestFactory_NewFactoryNilBlockchainShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetBlockchain(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -169,6 +175,7 @@ func TestFactory_NewFactoryNilBlockchainShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -178,12 +185,12 @@ func TestFactory_NewFactoryNilBlockchainShouldFail(t *testing.T) { func TestFactory_NewFactoryNilBlockProcessorShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetBlockProcessor(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -191,6 +198,7 @@ func TestFactory_NewFactoryNilBlockProcessorShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -200,12 +208,12 @@ func TestFactory_NewFactoryNilBlockProcessorShouldFail(t *testing.T) { func TestFactory_NewFactoryNilBootstrapperShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetBootStrapper(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -213,6 +221,7 @@ func TestFactory_NewFactoryNilBootstrapperShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -222,12 +231,12 @@ func TestFactory_NewFactoryNilBootstrapperShouldFail(t *testing.T) { func TestFactory_NewFactoryNilChronologyHandlerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + 
container := consensusMock.InitConsensusCore() worker := initWorker() container.SetChronology(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -235,6 +244,7 @@ func TestFactory_NewFactoryNilChronologyHandlerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -244,12 +254,12 @@ func TestFactory_NewFactoryNilChronologyHandlerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilHasherShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetHasher(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -257,6 +267,7 @@ func TestFactory_NewFactoryNilHasherShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -266,12 +277,12 @@ func TestFactory_NewFactoryNilHasherShouldFail(t *testing.T) { func TestFactory_NewFactoryNilMarshalizerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetMarshalizer(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -279,6 +290,7 @@ func TestFactory_NewFactoryNilMarshalizerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -288,12 +300,12 @@ func TestFactory_NewFactoryNilMarshalizerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetMultiSignerContainer(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -301,6 +313,7 @@ func TestFactory_NewFactoryNilMultiSignerContainerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -310,12 +323,12 @@ func TestFactory_NewFactoryNilMultiSignerContainerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetRoundHandler(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -323,6 +336,7 @@ func TestFactory_NewFactoryNilRoundHandlerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -332,12 +346,12 @@ func TestFactory_NewFactoryNilRoundHandlerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilShardCoordinatorShouldFail(t *testing.T) { t.Parallel() - 
consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetShardCoordinator(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -345,6 +359,7 @@ func TestFactory_NewFactoryNilShardCoordinatorShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -354,12 +369,12 @@ func TestFactory_NewFactoryNilShardCoordinatorShouldFail(t *testing.T) { func TestFactory_NewFactoryNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetSyncTimer(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -367,6 +382,7 @@ func TestFactory_NewFactoryNilSyncTimerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -376,12 +392,12 @@ func TestFactory_NewFactoryNilSyncTimerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilValidatorGroupSelectorShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetValidatorGroupSelector(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -389,6 +405,7 @@ func TestFactory_NewFactoryNilValidatorGroupSelectorShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -398,10 +415,10 @@ func TestFactory_NewFactoryNilValidatorGroupSelectorShouldFail(t *testing.T) { func TestFactory_NewFactoryNilWorkerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, nil, @@ -409,6 +426,7 @@ func TestFactory_NewFactoryNilWorkerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -418,11 +436,11 @@ func TestFactory_NewFactoryNilWorkerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilAppStatusHandlerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -430,6 +448,7 @@ func TestFactory_NewFactoryNilAppStatusHandlerShouldFail(t *testing.T) { currentPid, nil, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -439,11 +458,11 @@ func TestFactory_NewFactoryNilAppStatusHandlerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilSignaturesTrackerShouldFail(t 
*testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -451,10 +470,11 @@ func TestFactory_NewFactoryNilSignaturesTrackerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, nil, + nil, ) assert.Nil(t, fct) - assert.Equal(t, bls.ErrNilSentSignatureTracker, err) + assert.Equal(t, v1.ErrNilSentSignatureTracker, err) } func TestFactory_NewFactoryShouldWork(t *testing.T) { @@ -468,11 +488,11 @@ func TestFactory_NewFactoryShouldWork(t *testing.T) { func TestFactory_NewFactoryEmptyChainIDShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -480,6 +500,7 @@ func TestFactory_NewFactoryEmptyChainIDShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -490,7 +511,7 @@ func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundFail(t *test t.Parallel() fct := *initFactory() - fct.Worker().(*mock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + fct.Worker().(*consensusMock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { return nil } @@ -502,7 +523,7 @@ func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundFail(t *test func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundStartRoundFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) container.SetSyncTimer(nil) @@ -515,7 +536,7 @@ func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundFail(t *testing.T t.Parallel() fct := *initFactory() - fct.Worker().(*mock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + fct.Worker().(*consensusMock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { return nil } @@ -527,7 +548,7 @@ func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundFail(t *testing.T func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundBlockFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) container.SetSyncTimer(nil) @@ -540,7 +561,7 @@ func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundFail(t *testi t.Parallel() fct := *initFactory() - fct.Worker().(*mock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + fct.Worker().(*consensusMock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { return nil } @@ -552,7 +573,7 @@ func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundFail(t *testi func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundSignatureFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) container.SetSyncTimer(nil) @@ -565,7 
+586,7 @@ func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundFail(t *testin t.Parallel() fct := *initFactory() - fct.Worker().(*mock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + fct.Worker().(*consensusMock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { return nil } @@ -577,7 +598,7 @@ func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundFail(t *testin func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundEndRoundFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) container.SetSyncTimer(nil) @@ -591,11 +612,11 @@ func TestFactory_GenerateSubroundsShouldWork(t *testing.T) { subroundHandlers := 0 - chrm := &mock.ChronologyHandlerMock{} + chrm := &consensusMock.ChronologyHandlerMock{} chrm.AddSubroundCalled = func(subroundHandler consensus.SubroundHandler) { subroundHandlers++ } - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() container.SetChronology(chrm) fct := *initFactoryWithContainer(container) fct.SetOutportHandler(&testscommonOutport.OutportStub{}) @@ -609,7 +630,7 @@ func TestFactory_GenerateSubroundsShouldWork(t *testing.T) { func TestFactory_GenerateSubroundsNilOutportShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) err := fct.GenerateSubrounds() @@ -619,7 +640,7 @@ func TestFactory_GenerateSubroundsNilOutportShouldFail(t *testing.T) { func TestFactory_SetIndexerShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) outportHandler := &testscommonOutport.OutportStub{} diff --git a/consensus/spos/bls/v1/constants.go b/consensus/spos/bls/v1/constants.go new file mode 100644 index 00000000000..fc35333e15a --- /dev/null +++ b/consensus/spos/bls/v1/constants.go @@ -0,0 +1,37 @@ +package v1 + +import ( + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("consensus/spos/bls/v1") + +// waitingAllSigsMaxTimeThreshold specifies the max allocated time for waiting all signatures from the total time of the subround signature +const waitingAllSigsMaxTimeThreshold = 0.5 + +// processingThresholdPercent specifies the max allocated time for processing the block as a percentage of the total time of the round +const processingThresholdPercent = 85 + +// srStartStartTime specifies the start time, from the total time of the round, of Subround Start +const srStartStartTime = 0.0 + +// srStartEndTime specifies the end time, from the total time of the round, of Subround Start +const srStartEndTime = 0.05 + +// srBlockStartTime specifies the start time, from the total time of the round, of Subround Block +const srBlockStartTime = 0.05 + +// srBlockEndTime specifies the end time, from the total time of the round, of Subround Block +const srBlockEndTime = 0.25 + +// srSignatureStartTime specifies the start time, from the total time of the round, of Subround Signature +const srSignatureStartTime = 0.25 + +// srSignatureEndTime specifies the end time, from the total time of the round, of Subround Signature +const srSignatureEndTime = 0.85 + +// srEndStartTime specifies the start time, from the total time of the round, of Subround End +const srEndStartTime = 0.85 + +// srEndEndTime 
specifies the end time, from the total time of the round, of Subround End +const srEndEndTime = 0.95 diff --git a/consensus/spos/bls/errors.go b/consensus/spos/bls/v1/errors.go similarity index 93% rename from consensus/spos/bls/errors.go rename to consensus/spos/bls/v1/errors.go index b840f9e2c85..05c55b9592c 100644 --- a/consensus/spos/bls/errors.go +++ b/consensus/spos/bls/v1/errors.go @@ -1,4 +1,4 @@ -package bls +package v1 import "errors" diff --git a/consensus/spos/bls/export_test.go b/consensus/spos/bls/v1/export_test.go similarity index 94% rename from consensus/spos/bls/export_test.go rename to consensus/spos/bls/v1/export_test.go index 71d3cfc8348..4a386a57933 100644 --- a/consensus/spos/bls/export_test.go +++ b/consensus/spos/bls/v1/export_test.go @@ -1,4 +1,4 @@ -package bls +package v1 import ( "context" @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" @@ -18,9 +19,8 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) +// ProcessingThresholdPercent exports the internal processingThresholdPercent const ProcessingThresholdPercent = processingThresholdPercent -const DefaultMaxNumOfMessageTypeAccepted = defaultMaxNumOfMessageTypeAccepted -const MaxNumOfMessageTypeSignatureAccepted = maxNumOfMessageTypeSignatureAccepted // factory @@ -48,7 +48,7 @@ func (fct *factory) ChronologyHandler() consensus.ChronologyHandler { } // ConsensusState gets the consensus state struct pointer -func (fct *factory) ConsensusState() *spos.ConsensusState { +func (fct *factory) ConsensusState() spos.ConsensusStateHandler { return fct.consensusState } @@ -129,8 +129,8 @@ func (fct *factory) Outport() outport.OutportHandler { // subroundStartRound -// SubroundStartRound defines a type for the subroundStartRound structure -type SubroundStartRound *subroundStartRound +// SubroundStartRound defines an alias to the subroundStartRound structure +type SubroundStartRound = *subroundStartRound // DoStartRoundJob method does the job of the subround StartRound func (sr *subroundStartRound) DoStartRoundJob() bool { @@ -160,7 +160,7 @@ func (sr *subroundStartRound) GetSentSignatureTracker() spos.SentSignaturesTrack // subroundBlock // SubroundBlock defines a type for the subroundBlock structure -type SubroundBlock *subroundBlock +type SubroundBlock = *subroundBlock // Blockchain gets the ChainHandler stored in the ConsensusCore func (sr *subroundBlock) BlockChain() data.ChainHandler { @@ -229,8 +229,8 @@ func (sr *subroundBlock) ReceivedBlockBodyAndHeader(cnsDta *consensus.Message) b // subroundSignature -// SubroundSignature defines a type for the subroundSignature structure -type SubroundSignature *subroundSignature +// SubroundSignature defines an alias for the subroundSignature structure +type SubroundSignature = *subroundSignature // DoSignatureJob method does the job of the subround Signature func (sr *subroundSignature) DoSignatureJob() bool { @@ -254,8 +254,8 @@ func (sr *subroundSignature) AreSignaturesCollected(threshold int) (bool, int) { // subroundEndRound -// SubroundEndRound defines a type for the subroundEndRound structure -type SubroundEndRound *subroundEndRound +// SubroundEndRound defines an alias for the subroundEndRound structure +type SubroundEndRound = *subroundEndRound // DoEndRoundJob method 
does the job of the subround EndRound func (sr *subroundEndRound) DoEndRoundJob() bool { @@ -351,8 +351,3 @@ func (sr *subroundEndRound) GetFullMessagesForInvalidSigners(invalidPubKeys []st func (sr *subroundEndRound) GetSentSignatureTracker() spos.SentSignaturesTracker { return sr.sentSignatureTracker } - -// GetStringValue calls the unexported getStringValue function -func GetStringValue(messageType consensus.MessageType) string { - return getStringValue(messageType) -} diff --git a/consensus/spos/bls/subroundBlock.go b/consensus/spos/bls/v1/subroundBlock.go similarity index 90% rename from consensus/spos/bls/subroundBlock.go rename to consensus/spos/bls/v1/subroundBlock.go index a83969721b8..504cb82a180 100644 --- a/consensus/spos/bls/subroundBlock.go +++ b/consensus/spos/bls/v1/subroundBlock.go @@ -1,4 +1,4 @@ -package bls +package v1 import ( "context" @@ -7,9 +7,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" ) // maxAllowedSizeInBytes defines how many bytes are allowed as payload in a message @@ -52,7 +54,7 @@ func checkNewSubroundBlockParams( return spos.ErrNilSubround } - if baseSubround.ConsensusState == nil { + if check.IfNil(baseSubround.ConsensusStateHandler) { return spos.ErrNilConsensusState } @@ -114,7 +116,7 @@ func (sr *subroundBlock) doBlockJob(ctx context.Context) bool { // placeholder for subroundBlock.doBlockJob script - sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(header, body, sr.RoundTimeStamp) + sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(header, body, sr.GetRoundTimeStamp()) return true } @@ -163,7 +165,7 @@ func (sr *subroundBlock) couldBeSentTogether(marshalizedBody []byte, marshalized } func (sr *subroundBlock) createBlock(header data.HeaderHandler) (data.HeaderHandler, data.BodyHandler, error) { - startTime := sr.RoundTimeStamp + startTime := sr.GetRoundTimeStamp() maxTime := time.Duration(sr.EndTime()) haveTimeInCurrentSubround := func() bool { return sr.RoundHandler().RemainingTime(startTime, maxTime) > 0 @@ -202,7 +204,7 @@ func (sr *subroundBlock) sendHeaderAndBlockBody( marshalizedHeader, []byte(leader), nil, - int(MtBlockBodyAndHeader), + int(bls.MtBlockBodyAndHeader), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -222,9 +224,9 @@ func (sr *subroundBlock) sendHeaderAndBlockBody( "nonce", headerHandler.GetNonce(), "hash", headerHash) - sr.Data = headerHash - sr.Body = bodyHandler - sr.Header = headerHandler + sr.SetData(headerHash) + sr.SetBody(bodyHandler) + sr.SetHeader(headerHandler) return true } @@ -244,7 +246,7 @@ func (sr *subroundBlock) sendBlockBody(bodyHandler data.BodyHandler, marshalized nil, []byte(leader), nil, - int(MtBlockBody), + int(bls.MtBlockBody), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -262,7 +264,7 @@ func (sr *subroundBlock) sendBlockBody(bodyHandler data.BodyHandler, marshalized log.Debug("step 1: block body has been sent") - sr.Body = bodyHandler + sr.SetBody(bodyHandler) return true } @@ -284,7 +286,7 @@ func (sr *subroundBlock) sendBlockHeader(headerHandler data.HeaderHandler, marsh marshalizedHeader, []byte(leader), nil, - int(MtBlockHeader), + int(bls.MtBlockHeader), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -304,8 +306,8 @@ func (sr *subroundBlock) 
sendBlockHeader(headerHandler data.HeaderHandler, marsh "nonce", headerHandler.GetNonce(), "hash", headerHash) - sr.Data = headerHash - sr.Header = headerHandler + sr.SetData(headerHash) + sr.SetHeader(headerHandler) return true } @@ -413,17 +415,22 @@ func (sr *subroundBlock) receivedBlockBodyAndHeader(ctx context.Context, cnsDta return false } - sr.Data = cnsDta.BlockHeaderHash - sr.Body = sr.BlockProcessor().DecodeBlockBody(cnsDta.Body) - sr.Header = sr.BlockProcessor().DecodeBlockHeader(cnsDta.Header) + header := sr.BlockProcessor().DecodeBlockHeader(cnsDta.Header) + if headerHasProof(header) { + return false + } + + sr.SetData(cnsDta.BlockHeaderHash) + sr.SetBody(sr.BlockProcessor().DecodeBlockBody(cnsDta.Body)) + sr.SetHeader(header) - isInvalidData := check.IfNil(sr.Body) || sr.isInvalidHeaderOrData() + isInvalidData := check.IfNil(sr.GetBody()) || sr.isInvalidHeaderOrData() if isInvalidData { return false } log.Debug("step 1: block body and header have been received", - "nonce", sr.Header.GetNonce(), + "nonce", sr.GetHeader().GetNonce(), "hash", cnsDta.BlockHeaderHash) sw.Start("processReceivedBlock") @@ -440,7 +447,7 @@ func (sr *subroundBlock) receivedBlockBodyAndHeader(ctx context.Context, cnsDta } func (sr *subroundBlock) isInvalidHeaderOrData() bool { - return sr.Data == nil || check.IfNil(sr.Header) || sr.Header.CheckFieldsForNil() != nil + return sr.GetData() == nil || check.IfNil(sr.GetHeader()) || sr.GetHeader().CheckFieldsForNil() != nil } // receivedBlockBody method is called when a block body is received through the block body channel @@ -465,9 +472,9 @@ func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensu return false } - sr.Body = sr.BlockProcessor().DecodeBlockBody(cnsDta.Body) + sr.SetBody(sr.BlockProcessor().DecodeBlockBody(cnsDta.Body)) - if check.IfNil(sr.Body) { + if check.IfNil(sr.GetBody()) { return false } @@ -512,15 +519,20 @@ func (sr *subroundBlock) receivedBlockHeader(ctx context.Context, cnsDta *consen return false } - sr.Data = cnsDta.BlockHeaderHash - sr.Header = sr.BlockProcessor().DecodeBlockHeader(cnsDta.Header) + header := sr.BlockProcessor().DecodeBlockHeader(cnsDta.Header) + if headerHasProof(header) { + return false + } + + sr.SetData(cnsDta.BlockHeaderHash) + sr.SetHeader(header) if sr.isInvalidHeaderOrData() { return false } log.Debug("step 1: block header has been received", - "nonce", sr.Header.GetNonce(), + "nonce", sr.GetHeader().GetNonce(), "hash", cnsDta.BlockHeaderHash) blockProcessedWithSuccess := sr.processReceivedBlock(ctx, cnsDta) @@ -533,11 +545,18 @@ func (sr *subroundBlock) receivedBlockHeader(ctx context.Context, cnsDta *consen return blockProcessedWithSuccess } +func headerHasProof(headerHandler data.HeaderHandler) bool { + if check.IfNil(headerHandler) { + return false + } + return !check.IfNilReflect(headerHandler.GetPreviousProof()) +} + func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *consensus.Message) bool { - if check.IfNil(sr.Body) { + if check.IfNil(sr.GetBody()) { return false } - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { return false } @@ -547,20 +566,20 @@ func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *conse sr.SetProcessingBlock(true) - shouldNotProcessBlock := sr.ExtendedCalled || cnsDta.RoundIndex < sr.RoundHandler().Index() + shouldNotProcessBlock := sr.GetExtendedCalled() || cnsDta.RoundIndex < sr.RoundHandler().Index() if shouldNotProcessBlock { log.Debug("canceled round, extended has been called or round 
index has been changed", "round", sr.RoundHandler().Index(), "subround", sr.Name(), "cnsDta round", cnsDta.RoundIndex, - "extended called", sr.ExtendedCalled, + "extended called", sr.GetExtendedCalled(), ) return false } node := string(cnsDta.PubKey) - startTime := sr.RoundTimeStamp + startTime := sr.GetRoundTimeStamp() maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 remainingTimeInCurrentRound := func() time.Duration { return sr.RoundHandler().RemainingTime(startTime, maxTime) @@ -570,8 +589,8 @@ func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *conse defer sr.computeSubroundProcessingMetric(metricStatTime, common.MetricProcessedProposedBlock) err := sr.BlockProcessor().ProcessBlock( - sr.Header, - sr.Body, + sr.GetHeader(), + sr.GetBody(), remainingTimeInCurrentRound, ) @@ -586,7 +605,7 @@ func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *conse if err != nil { sr.printCancelRoundLogMessage(ctx, err) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } @@ -597,7 +616,7 @@ func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *conse return false } - sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(sr.Header, sr.Body, sr.RoundTimeStamp) + sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(sr.GetHeader(), sr.GetBody(), sr.GetRoundTimeStamp()) return true } @@ -627,7 +646,7 @@ func (sr *subroundBlock) computeSubroundProcessingMetric(startTime time.Time, me // doBlockConsensusCheck method checks if the consensus in the subround Block is achieved func (sr *subroundBlock) doBlockConsensusCheck() bool { - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } diff --git a/consensus/spos/bls/subroundBlock_test.go b/consensus/spos/bls/v1/subroundBlock_test.go similarity index 74% rename from consensus/spos/bls/subroundBlock_test.go rename to consensus/spos/bls/v1/subroundBlock_test.go index 2354ab92b11..e0d4690021d 100644 --- a/consensus/spos/bls/subroundBlock_test.go +++ b/consensus/spos/bls/v1/subroundBlock_test.go @@ -1,4 +1,4 @@ -package bls_test +package v1_test import ( "errors" @@ -10,19 +10,23 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/testscommon" + consensusMock "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func defaultSubroundForSRBlock(consensusState *spos.ConsensusState, ch chan bool, - container *mock.ConsensusCoreMock, appStatusHandler core.AppStatusHandler) (*spos.Subround, error) { + container *consensusMock.ConsensusCoreMock, appStatusHandler core.AppStatusHandler) (*spos.Subround, error) { return spos.NewSubround( bls.SrStartRound, bls.SrBlock, @@ -55,21 
+59,21 @@ func createDefaultHeader() *block.Header { } } -func defaultSubroundBlockFromSubround(sr *spos.Subround) (bls.SubroundBlock, error) { - srBlock, err := bls.NewSubroundBlock( +func defaultSubroundBlockFromSubround(sr *spos.Subround) (v1.SubroundBlock, error) { + srBlock, err := v1.NewSubroundBlock( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, ) return srBlock, err } -func defaultSubroundBlockWithoutErrorFromSubround(sr *spos.Subround) bls.SubroundBlock { - srBlock, _ := bls.NewSubroundBlock( +func defaultSubroundBlockWithoutErrorFromSubround(sr *spos.Subround) v1.SubroundBlock { + srBlock, _ := v1.NewSubroundBlock( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, ) return srBlock @@ -77,9 +81,9 @@ func defaultSubroundBlockWithoutErrorFromSubround(sr *spos.Subround) bls.Subroun func initSubroundBlock( blockChain data.ChainHandler, - container *mock.ConsensusCoreMock, + container *consensusMock.ConsensusCoreMock, appStatusHandler core.AppStatusHandler, -) bls.SubroundBlock { +) v1.SubroundBlock { if blockChain == nil { blockChain = &testscommon.ChainHandlerStub{ GetCurrentBlockHeaderCalled: func() data.HeaderHandler { @@ -98,7 +102,7 @@ func initSubroundBlock( } } - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) container.SetBlockchain(blockChain) @@ -108,19 +112,19 @@ func initSubroundBlock( return srBlock } -func createConsensusContainers() []*mock.ConsensusCoreMock { - consensusContainers := make([]*mock.ConsensusCoreMock, 0) - container := mock.InitConsensusCore() +func createConsensusContainers() []*consensusMock.ConsensusCoreMock { + consensusContainers := make([]*consensusMock.ConsensusCoreMock, 0) + container := consensusMock.InitConsensusCore() consensusContainers = append(consensusContainers, container) - container = mock.InitConsensusCoreHeaderV2() + container = consensusMock.InitConsensusCoreHeaderV2() consensusContainers = append(consensusContainers, container) return consensusContainers } func initSubroundBlockWithBlockProcessor( bp *testscommon.BlockProcessorStub, - container *mock.ConsensusCoreMock, -) bls.SubroundBlock { + container *consensusMock.ConsensusCoreMock, +) v1.SubroundBlock { blockChain := &testscommon.ChainHandlerStub{ GetGenesisHeaderCalled: func() data.HeaderHandler { return &block.Header{ @@ -136,7 +140,7 @@ func initSubroundBlockWithBlockProcessor( container.SetBlockchain(blockChain) container.SetBlockProcessor(blockProcessorMock) - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -147,10 +151,10 @@ func initSubroundBlockWithBlockProcessor( func TestSubroundBlock_NewSubroundBlockNilSubroundShouldFail(t *testing.T) { t.Parallel() - srBlock, err := bls.NewSubroundBlock( + srBlock, err := v1.NewSubroundBlock( nil, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, ) assert.Nil(t, srBlock) assert.Equal(t, spos.ErrNilSubround, err) @@ -158,9 +162,9 @@ func TestSubroundBlock_NewSubroundBlockNilSubroundShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilBlockchainShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) 
sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -174,9 +178,9 @@ func TestSubroundBlock_NewSubroundBlockNilBlockchainShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilBlockProcessorShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -190,12 +194,12 @@ func TestSubroundBlock_NewSubroundBlockNilBlockProcessorShouldFail(t *testing.T) func TestSubroundBlock_NewSubroundBlockNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMock.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) - sr.ConsensusState = nil + sr.ConsensusStateHandler = nil srBlock, err := defaultSubroundBlockFromSubround(sr) assert.Nil(t, srBlock) @@ -204,9 +208,9 @@ func TestSubroundBlock_NewSubroundBlockNilConsensusStateShouldFail(t *testing.T) func TestSubroundBlock_NewSubroundBlockNilHasherShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -219,9 +223,9 @@ func TestSubroundBlock_NewSubroundBlockNilHasherShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilMarshalizerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -234,9 +238,9 @@ func TestSubroundBlock_NewSubroundBlockNilMarshalizerShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -249,9 +253,9 @@ func TestSubroundBlock_NewSubroundBlockNilMultiSignerContainerShouldFail(t *test func TestSubroundBlock_NewSubroundBlockNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -264,9 +268,9 @@ func TestSubroundBlock_NewSubroundBlockNilRoundHandlerShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilShardCoordinatorShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := 
consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -279,9 +283,9 @@ func TestSubroundBlock_NewSubroundBlockNilShardCoordinatorShouldFail(t *testing. func TestSubroundBlock_NewSubroundBlockNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -294,9 +298,9 @@ func TestSubroundBlock_NewSubroundBlockNilSyncTimerShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) srBlock, err := defaultSubroundBlockFromSubround(sr) @@ -306,12 +310,12 @@ func TestSubroundBlock_NewSubroundBlockShouldWork(t *testing.T) { func TestSubroundBlock_DoBlockJob(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) r := sr.DoBlockJob() assert.False(t, r) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + sr.SetSelfPubKey(sr.Leader()) _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrBlock, true) r = sr.DoBlockJob() assert.False(t, r) @@ -331,34 +335,34 @@ func TestSubroundBlock_DoBlockJob(t *testing.T) { r = sr.DoBlockJob() assert.False(t, r) - bpm = mock.InitBlockProcessorMock(container.Marshalizer()) + bpm = consensusMock.InitBlockProcessorMock(container.Marshalizer()) container.SetBlockProcessor(bpm) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMock.BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { return nil }, } container.SetBroadcastMessenger(bm) - container.SetRoundHandler(&mock.RoundHandlerMock{ + container.SetRoundHandler(&consensusMock.RoundHandlerMock{ RoundIndex: 1, }) r = sr.DoBlockJob() assert.True(t, r) - assert.Equal(t, uint64(1), sr.Header.GetNonce()) + assert.Equal(t, uint64(1), sr.GetHeader().GetNonce()) } func TestSubroundBlock_ReceivedBlockBodyAndHeaderDataAlreadySet(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = []byte("some data") + sr.SetData([]byte("some data")) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) } @@ -366,15 +370,15 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderDataAlreadySet(t *testing.T) { func 
TestSubroundBlock_ReceivedBlockBodyAndHeaderNodeNotLeaderInCurrentRound(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[1]), bls.MtBlockBodyAndHeader) - sr.Data = nil + sr.SetData(nil) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) } @@ -382,16 +386,16 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderNodeNotLeaderInCurrentRound(t * func TestSubroundBlock_ReceivedBlockBodyAndHeaderCannotProcessJobDone(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = nil - _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrBlock, true) + sr.SetData(nil) + _ = sr.SetJobDone(sr.Leader(), bls.SrBlock, true) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) @@ -400,22 +404,22 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderCannotProcessJobDone(t *testing func TestSubroundBlock_ReceivedBlockBodyAndHeaderErrorDecoding(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - blProc := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + blProc := consensusMock.InitBlockProcessorMock(container.Marshalizer()) blProc.DecodeBlockHeaderCalled = func(dta []byte) data.HeaderHandler { // error decoding so return nil return nil } container.SetBlockProcessor(blProc) - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = nil + sr.SetData(nil) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) @@ -424,16 +428,16 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderErrorDecoding(t *testing.T) { func TestSubroundBlock_ReceivedBlockBodyAndHeaderBodyAlreadyReceived(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = nil - sr.Body = &block.Body{} + sr.SetData(nil) + sr.SetBody(&block.Body{}) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) @@ -442,16 +446,16 @@ func 
TestSubroundBlock_ReceivedBlockBodyAndHeaderBodyAlreadyReceived(t *testing. func TestSubroundBlock_ReceivedBlockBodyAndHeaderHeaderAlreadyReceived(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = nil - sr.Header = &block.Header{Nonce: 1} + sr.SetData(nil) + sr.SetHeader(&block.Header{Nonce: 1}) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) } @@ -459,14 +463,16 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderHeaderAlreadyReceived(t *testin func TestSubroundBlock_ReceivedBlockBodyAndHeaderOK(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) t.Run("block is valid", func(t *testing.T) { hdr := createDefaultHeader() blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) - sr.Data = nil + leader, err := sr.GetLeader() + require.Nil(t, err) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(leader), bls.MtBlockBodyAndHeader) + sr.SetData(nil) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.True(t, r) }) @@ -475,15 +481,17 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderOK(t *testing.T) { Nonce: 1, } blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) - sr.Data = nil + leader, err := sr.GetLeader() + require.Nil(t, err) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(leader), bls.MtBlockBodyAndHeader) + sr.SetData(nil) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) }) } func createConsensusMessage(header *block.Header, body *block.Body, leader []byte, topic consensus.MessageType) *consensus.Message { - marshaller := &mock.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} hdrStr, _ := marshaller.Marshal(header) @@ -510,17 +518,19 @@ func createConsensusMessage(header *block.Header, body *block.Body, leader []byt func TestSubroundBlock_ReceivedBlock(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) - blockProcessorMock := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + blockProcessorMock := consensusMock.InitBlockProcessorMock(container.Marshalizer()) blkBody := &block.Body{} - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, err := sr.GetLeader() + assert.Nil(t, err) cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -531,11 +541,11 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { 
currentPid, nil, ) - sr.Body = &block.Body{} + sr.SetBody(&block.Body{}) r := sr.ReceivedBlockBody(cnsMsg) assert.False(t, r) - sr.Body = nil + sr.SetBody(nil) cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) r = sr.ReceivedBlockBody(cnsMsg) assert.False(t, r) @@ -558,7 +568,7 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { nil, nil, hdrStr, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockHeader), 0, @@ -572,12 +582,12 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { r = sr.ReceivedBlockHeader(cnsMsg) assert.False(t, r) - sr.Data = nil - sr.Header = hdr + sr.SetData(nil) + sr.SetHeader(hdr) r = sr.ReceivedBlockHeader(cnsMsg) assert.False(t, r) - sr.Header = nil + sr.SetHeader(nil) cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) r = sr.ReceivedBlockHeader(cnsMsg) assert.False(t, r) @@ -589,11 +599,11 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { sr.SetStatus(bls.SrBlock, spos.SsNotFinished) container.SetBlockProcessor(blockProcessorMock) - sr.Data = nil - sr.Header = nil + sr.SetData(nil) + sr.SetHeader(nil) hdr = createDefaultHeader() hdr.Nonce = 1 - hdrStr, _ = mock.MarshalizerMock{}.Marshal(hdr) + hdrStr, _ = marshallerMock.MarshalizerMock{}.Marshal(hdr) hdrHash = (&hashingMocks.HasherMock{}).Compute(string(hdrStr)) cnsMsg.BlockHeaderHash = hdrHash cnsMsg.Header = hdrStr @@ -603,14 +613,15 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenBodyAndHeaderAreNotSet(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + leader, _ := sr.GetLeader() cnsMsg := consensus.NewConsensusMessage( nil, nil, nil, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBodyAndHeader), 0, @@ -626,9 +637,9 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenBodyAndHeaderAre func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockFails(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) - blProcMock := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + blProcMock := consensusMock.InitBlockProcessorMock(container.Marshalizer()) err := errors.New("error process block") blProcMock.ProcessBlockCalled = func(data.HeaderHandler, data.BodyHandler, func() time.Duration) error { return err @@ -636,13 +647,14 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockFail container.SetBlockProcessor(blProcMock) hdr := &block.Header{} blkBody := &block.Body{} - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -653,24 +665,25 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockFail currentPid, nil, ) - sr.Header = hdr - sr.Body = blkBody + sr.SetHeader(hdr) + sr.SetBody(blkBody) assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) } func 
TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockReturnsInNextRound(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{} blkBody := &block.Body{} - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -681,14 +694,14 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockRetu currentPid, nil, ) - sr.Header = hdr - sr.Body = blkBody - blockProcessorMock := mock.InitBlockProcessorMock(container.Marshalizer()) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + blockProcessorMock := consensusMock.InitBlockProcessorMock(container.Marshalizer()) blockProcessorMock.ProcessBlockCalled = func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { return errors.New("error") } container.SetBlockProcessor(blockProcessorMock) - container.SetRoundHandler(&mock.RoundHandlerMock{RoundIndex: 1}) + container.SetRoundHandler(&consensusMock.RoundHandlerMock{RoundIndex: 1}) assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) } @@ -697,17 +710,18 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnTrue(t *testing.T) { consensusContainers := createConsensusContainers() for _, container := range consensusContainers { - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr, _ := container.BlockProcessor().CreateNewHeader(1, 1) hdr, blkBody, _ := container.BlockProcessor().CreateBlock(hdr, func() bool { return true }) - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -718,19 +732,19 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnTrue(t *testing.T) { currentPid, nil, ) - sr.Header = hdr - sr.Body = blkBody + sr.SetHeader(hdr) + sr.SetBody(blkBody) assert.True(t, sr.ProcessReceivedBlock(cnsMsg)) } } func TestSubroundBlock_RemainingTimeShouldReturnNegativeValue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() roundHandlerMock := initRoundHandlerMock() container.SetRoundHandler(roundHandlerMock) - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) remainingTimeInThisRound := func() time.Duration { roundStartTime := sr.RoundHandler().TimeStamp() currentTime := sr.SyncTimer().CurrentTime() @@ -739,19 +753,19 @@ func TestSubroundBlock_RemainingTimeShouldReturnNegativeValue(t *testing.T) { return remainingTime } - container.SetSyncTimer(&mock.SyncTimerMock{CurrentTimeCalled: func() time.Time { + container.SetSyncTimer(&consensusMock.SyncTimerMock{CurrentTimeCalled: func() time.Time { return time.Unix(0, 0).Add(roundTimeDuration * 84 / 100) }}) ret := remainingTimeInThisRound() 
assert.True(t, ret > 0) - container.SetSyncTimer(&mock.SyncTimerMock{CurrentTimeCalled: func() time.Time { + container.SetSyncTimer(&consensusMock.SyncTimerMock{CurrentTimeCalled: func() time.Time { return time.Unix(0, 0).Add(roundTimeDuration * 85 / 100) }}) ret = remainingTimeInThisRound() assert.True(t, ret == 0) - container.SetSyncTimer(&mock.SyncTimerMock{CurrentTimeCalled: func() time.Time { + container.SetSyncTimer(&consensusMock.SyncTimerMock{CurrentTimeCalled: func() time.Time { return time.Unix(0, 0).Add(roundTimeDuration * 86 / 100) }}) ret = remainingTimeInThisRound() @@ -760,24 +774,24 @@ func TestSubroundBlock_RemainingTimeShouldReturnNegativeValue(t *testing.T) { func TestSubroundBlock_DoBlockConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) - sr.RoundCanceled = true + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) assert.False(t, sr.DoBlockConsensusCheck()) } func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenSubroundIsFinished(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) sr.SetStatus(bls.SrBlock, spos.SsFinished) assert.True(t, sr.DoBlockConsensusCheck()) } func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenBlockIsReceivedReturnTrue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) for i := 0; i < sr.Threshold(bls.SrBlock); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, true) } @@ -786,15 +800,15 @@ func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenBlockIsReceivedR func TestSubroundBlock_DoBlockConsensusCheckShouldReturnFalseWhenBlockIsReceivedReturnFalse(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) assert.False(t, sr.DoBlockConsensusCheck()) } func TestSubroundBlock_IsBlockReceived(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) for i := 0; i < len(sr.ConsensusGroup()); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, false) _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, false) @@ -815,8 +829,8 @@ func TestSubroundBlock_IsBlockReceived(t *testing.T) { func TestSubroundBlock_HaveTimeInCurrentSubroundShouldReturnTrue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) haveTimeInCurrentSubound := 
func() bool { roundStartTime := sr.RoundHandler().TimeStamp() currentTime := sr.SyncTimer().CurrentTime() @@ -825,14 +839,14 @@ func TestSubroundBlock_HaveTimeInCurrentSubroundShouldReturnTrue(t *testing.T) { return time.Duration(remainingTime) > 0 } - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMock.RoundHandlerMock{} roundHandlerMock.TimeDurationCalled = func() time.Duration { return 4000 * time.Millisecond } roundHandlerMock.TimeStampCalled = func() time.Time { return time.Unix(0, 0) } - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMock.SyncTimerMock{} timeElapsed := sr.EndTime() - 1 syncTimerMock.CurrentTimeCalled = func() time.Time { return time.Unix(0, timeElapsed) @@ -845,8 +859,8 @@ func TestSubroundBlock_HaveTimeInCurrentSubroundShouldReturnTrue(t *testing.T) { func TestSubroundBlock_HaveTimeInCurrentSuboundShouldReturnFalse(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) haveTimeInCurrentSubound := func() bool { roundStartTime := sr.RoundHandler().TimeStamp() currentTime := sr.SyncTimer().CurrentTime() @@ -855,14 +869,14 @@ func TestSubroundBlock_HaveTimeInCurrentSuboundShouldReturnFalse(t *testing.T) { return time.Duration(remainingTime) > 0 } - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMock.RoundHandlerMock{} roundHandlerMock.TimeDurationCalled = func() time.Duration { return 4000 * time.Millisecond } roundHandlerMock.TimeStampCalled = func() time.Time { return time.Unix(0, 0) } - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMock.SyncTimerMock{} timeElapsed := sr.EndTime() + 1 syncTimerMock.CurrentTimeCalled = func() time.Time { return time.Unix(0, timeElapsed) @@ -892,7 +906,7 @@ func TestSubroundBlock_CreateHeaderNilCurrentHeader(t *testing.T) { consensusContainers := createConsensusContainers() for _, container := range consensusContainers { - sr := *initSubroundBlock(blockChain, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(blockChain, container, &statusHandler.AppStatusHandlerStub{}) _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(nil, nil) header, _ := sr.CreateHeader() header, body, _ := sr.CreateBlock(header) @@ -923,7 +937,7 @@ func TestSubroundBlock_CreateHeaderNilCurrentHeader(t *testing.T) { func TestSubroundBlock_CreateHeaderNotNilCurrentHeader(t *testing.T) { consensusContainers := createConsensusContainers() for _, container := range consensusContainers { - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(&block.Header{ Nonce: 1, }, []byte("root hash")) @@ -967,8 +981,8 @@ func TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { } }, } - container := mock.InitConsensusCore() - bp := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + bp := consensusMock.InitBlockProcessorMock(container.Marshalizer()) bp.CreateBlockCalled = func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { shardHeader, _ := header.(*block.Header) shardHeader.MiniBlockHeaders = mbHeaders @@ -976,7 +990,7 @@ func 
TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { return shardHeader, &block.Body{}, nil } - sr := *initSubroundBlockWithBlockProcessor(bp, container) + sr := initSubroundBlockWithBlockProcessor(bp, container) container.SetBlockchain(&blockChainMock) header, _ := sr.CreateHeader() @@ -1002,12 +1016,12 @@ func TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { func TestSubroundBlock_CreateHeaderNilMiniBlocks(t *testing.T) { expectedErr := errors.New("nil mini blocks") - container := mock.InitConsensusCore() - bp := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + bp := consensusMock.InitBlockProcessorMock(container.Marshalizer()) bp.CreateBlockCalled = func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { return nil, nil, expectedErr } - sr := *initSubroundBlockWithBlockProcessor(bp, container) + sr := initSubroundBlockWithBlockProcessor(bp, container) _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(&block.Header{ Nonce: 1, }, []byte("root hash")) @@ -1059,7 +1073,7 @@ func TestSubroundBlock_ReceivedBlockComputeProcessDuration(t *testing.T) { srDuration := srEndTime - srStartTime delay := srDuration * 430 / 1000 - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() receivedValue := uint64(0) container.SetBlockProcessor(&testscommon.BlockProcessorStub{ ProcessBlockCalled: func(_ data.HeaderHandler, _ data.BodyHandler, _ func() time.Duration) error { @@ -1067,20 +1081,22 @@ func TestSubroundBlock_ReceivedBlockComputeProcessDuration(t *testing.T) { return nil }, }) - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{ + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{ SetUInt64ValueHandler: func(key string, value uint64) { receivedValue = value }}) hdr := &block.Header{} blkBody := &block.Body{} - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, err := sr.GetLeader() + assert.Nil(t, err) cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -1091,8 +1107,8 @@ func TestSubroundBlock_ReceivedBlockComputeProcessDuration(t *testing.T) { currentPid, nil, ) - sr.Header = hdr - sr.Body = blkBody + sr.SetHeader(hdr) + sr.SetBody(blkBody) minimumExpectedValue := uint64(delay * 100 / srDuration) _ = sr.ProcessReceivedBlock(cnsMsg) @@ -1113,13 +1129,13 @@ func TestSubroundBlock_ReceivedBlockComputeProcessDurationWithZeroDurationShould } }() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) - srBlock := *defaultSubroundBlockWithoutErrorFromSubround(sr) + srBlock := defaultSubroundBlockWithoutErrorFromSubround(sr) srBlock.ComputeSubroundProcessingMetric(time.Now(), "dummy") } diff --git a/consensus/spos/bls/subroundEndRound.go b/consensus/spos/bls/v1/subroundEndRound.go similarity index 90% rename from consensus/spos/bls/subroundEndRound.go rename to consensus/spos/bls/v1/subroundEndRound.go index 21675715f39..c591c736aca 100644 --- a/consensus/spos/bls/subroundEndRound.go +++ b/consensus/spos/bls/v1/subroundEndRound.go @@ -1,4 +1,4 
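// An illustrative, self-contained sketch of the percentage checked by the
// ReceivedBlockComputeProcessDuration test above: the time spent processing a received
// block expressed as a share of the subround window (the test derives its minimum
// expected value as delay * 100 / srDuration). The names below are assumptions for the
// sketch, not the repository's API.
package main

import (
	"fmt"
	"time"
)

// processingPercent reports how much of the subround window has been consumed so far.
func processingPercent(start time.Time, subroundDuration time.Duration) uint64 {
	if subroundDuration <= 0 {
		return 0 // mirrors the "zero duration should not panic" expectation
	}
	elapsed := time.Since(start)
	return uint64(elapsed * 100 / subroundDuration)
}

func main() {
	subround := 400 * time.Millisecond
	start := time.Now().Add(-subround * 430 / 1000) // roughly the 43% delay used in the test
	fmt.Println(processingPercent(start, subround) >= 43)
}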
@@ -package bls +package v1 import ( "bytes" @@ -11,9 +11,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/display" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process/headerCheck" ) @@ -73,7 +75,7 @@ func checkNewSubroundEndRoundParams( if baseSubround == nil { return spos.ErrNilSubround } - if baseSubround.ConsensusState == nil { + if check.IfNil(baseSubround.ConsensusStateHandler) { return spos.ErrNilConsensusState } @@ -131,11 +133,11 @@ func (sr *subroundEndRound) receivedBlockHeaderFinalInfo(_ context.Context, cnsD } func (sr *subroundEndRound) isBlockHeaderFinalInfoValid(cnsDta *consensus.Message) bool { - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { return false } - header := sr.Header.ShallowClone() + header := sr.GetHeader().ShallowClone() err := header.SetPubKeysBitmap(cnsDta.PubKeysBitmap) if err != nil { log.Debug("isBlockHeaderFinalInfoValid.SetPubKeysBitmap", "error", err.Error()) @@ -293,14 +295,15 @@ func (sr *subroundEndRound) doEndRoundJob(_ context.Context) bool { } func (sr *subroundEndRound) doEndRoundJobByLeader() bool { - bitmap := sr.GenerateBitmap(SrSignature) + bitmap := sr.GenerateBitmap(bls.SrSignature) err := sr.checkSignaturesValidity(bitmap) if err != nil { log.Debug("doEndRoundJobByLeader.checkSignaturesValidity", "error", err.Error()) return false } - if check.IfNil(sr.Header) { + header := sr.GetHeader() + if check.IfNil(header) { log.Error("doEndRoundJobByLeader.CheckNilHeader", "error", spos.ErrNilHeader) return false } @@ -312,13 +315,13 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { return false } - err = sr.Header.SetPubKeysBitmap(bitmap) + err = header.SetPubKeysBitmap(bitmap) if err != nil { log.Debug("doEndRoundJobByLeader.SetPubKeysBitmap", "error", err.Error()) return false } - err = sr.Header.SetSignature(sig) + err = header.SetSignature(sig) if err != nil { log.Debug("doEndRoundJobByLeader.SetSignature", "error", err.Error()) return false @@ -331,7 +334,7 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { return false } - err = sr.Header.SetLeaderSignature(leaderSignature) + err = header.SetLeaderSignature(leaderSignature) if err != nil { log.Debug("doEndRoundJobByLeader.SetLeaderSignature", "error", err.Error()) return false @@ -362,13 +365,13 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { } // broadcast header - err = sr.BroadcastMessenger().BroadcastHeader(sr.Header, []byte(leader)) + err = sr.BroadcastMessenger().BroadcastHeader(header, []byte(leader)) if err != nil { log.Debug("doEndRoundJobByLeader.BroadcastHeader", "error", err.Error()) } startTime := time.Now() - err = sr.BlockProcessor().CommitBlock(sr.Header, sr.Body) + err = sr.BlockProcessor().CommitBlock(header, sr.GetBody()) elapsedTime := time.Since(startTime) if elapsedTime >= common.CommitMaxTime { log.Warn("doEndRoundJobByLeader.CommitBlock", "elapsed time", elapsedTime) @@ -393,7 +396,7 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { log.Debug("doEndRoundJobByLeader.broadcastBlockDataLeader", "error", err.Error()) } - msg := fmt.Sprintf("Added proposed block with nonce %d in blockchain", sr.Header.GetNonce()) + msg := fmt.Sprintf("Added proposed block with nonce %d in 
blockchain", header.GetNonce()) log.Debug(display.Headline(msg, sr.SyncTimer().FormattedCurrentTime(), "+")) sr.updateMetricsForLeader() @@ -402,7 +405,8 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { } func (sr *subroundEndRound) aggregateSigsAndHandleInvalidSigners(bitmap []byte) ([]byte, []byte, error) { - sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.Header.GetEpoch()) + header := sr.GetHeader() + sig, err := sr.SigningHandler().AggregateSigs(bitmap, header.GetEpoch()) if err != nil { log.Debug("doEndRoundJobByLeader.AggregateSigs", "error", err.Error()) @@ -415,7 +419,7 @@ func (sr *subroundEndRound) aggregateSigsAndHandleInvalidSigners(bitmap []byte) return nil, nil, err } - err = sr.SigningHandler().Verify(sr.GetData(), bitmap, sr.Header.GetEpoch()) + err = sr.SigningHandler().Verify(sr.GetData(), bitmap, header.GetEpoch()) if err != nil { log.Debug("doEndRoundJobByLeader.Verify", "error", err.Error()) @@ -429,12 +433,13 @@ func (sr *subroundEndRound) verifyNodesOnAggSigFail() ([]string, error) { invalidPubKeys := make([]string, 0) pubKeys := sr.ConsensusGroup() - if check.IfNil(sr.Header) { + header := sr.GetHeader() + if check.IfNil(header) { return nil, spos.ErrNilHeader } for i, pk := range pubKeys { - isJobDone, err := sr.JobDone(pk, SrSignature) + isJobDone, err := sr.JobDone(pk, bls.SrSignature) if err != nil || !isJobDone { continue } @@ -445,11 +450,11 @@ func (sr *subroundEndRound) verifyNodesOnAggSigFail() ([]string, error) { } isSuccessfull := true - err = sr.SigningHandler().VerifySignatureShare(uint16(i), sigShare, sr.GetData(), sr.Header.GetEpoch()) + err = sr.SigningHandler().VerifySignatureShare(uint16(i), sigShare, sr.GetData(), header.GetEpoch()) if err != nil { isSuccessfull = false - err = sr.SetJobDone(pk, SrSignature, false) + err = sr.SetJobDone(pk, bls.SrSignature, false) if err != nil { return nil, err } @@ -520,9 +525,10 @@ func (sr *subroundEndRound) handleInvalidSignersOnAggSigFail() ([]byte, []byte, func (sr *subroundEndRound) computeAggSigOnValidNodes() ([]byte, []byte, error) { threshold := sr.Threshold(sr.Current()) - numValidSigShares := sr.ComputeSize(SrSignature) + numValidSigShares := sr.ComputeSize(bls.SrSignature) - if check.IfNil(sr.Header) { + header := sr.GetHeader() + if check.IfNil(header) { return nil, nil, spos.ErrNilHeader } @@ -531,13 +537,13 @@ func (sr *subroundEndRound) computeAggSigOnValidNodes() ([]byte, []byte, error) spos.ErrInvalidNumSigShares, numValidSigShares, threshold) } - bitmap := sr.GenerateBitmap(SrSignature) + bitmap := sr.GenerateBitmap(bls.SrSignature) err := sr.checkSignaturesValidity(bitmap) if err != nil { return nil, nil, err } - sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.Header.GetEpoch()) + sig, err := sr.SigningHandler().AggregateSigs(bitmap, header.GetEpoch()) if err != nil { return nil, nil, err } @@ -557,6 +563,7 @@ func (sr *subroundEndRound) createAndBroadcastHeaderFinalInfo() { return } + header := sr.GetHeader() cnsMsg := consensus.NewConsensusMessage( sr.GetData(), nil, @@ -564,12 +571,12 @@ func (sr *subroundEndRound) createAndBroadcastHeaderFinalInfo() { nil, []byte(leader), nil, - int(MtBlockHeaderFinalInfo), + int(bls.MtBlockHeaderFinalInfo), sr.RoundHandler().Index(), sr.ChainID(), - sr.Header.GetPubKeysBitmap(), - sr.Header.GetSignature(), - sr.Header.GetLeaderSignature(), + header.GetPubKeysBitmap(), + header.GetSignature(), + header.GetLeaderSignature(), sr.GetAssociatedPid([]byte(leader)), nil, ) @@ -581,9 +588,9 @@ func (sr *subroundEndRound) 
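// A condensed sketch of the verifyNodesOnAggSigFail flow shown above: when the
// aggregated signature cannot be verified, every signer that claimed a signature share
// is re-checked individually, its SrSignature job flag is cleared on failure, and its
// public key is collected as invalid. The sigVerifier interface and the map-based job
// tracking are simplifications made only for this sketch.
package main

import "fmt"

type sigVerifier interface {
	SignatureShare(index uint16) ([]byte, error)
	VerifySignatureShare(index uint16, sig []byte, msg []byte, epoch uint32) error
}

func collectInvalidSigners(
	pubKeys []string,
	jobDone map[string]bool,
	verifier sigVerifier,
	msg []byte,
	epoch uint32,
) []string {
	invalid := make([]string, 0)
	for i, pk := range pubKeys {
		if !jobDone[pk] {
			continue // the node never reported a signature, nothing to re-verify
		}
		share, err := verifier.SignatureShare(uint16(i))
		if err != nil {
			continue
		}
		if err := verifier.VerifySignatureShare(uint16(i), share, msg, epoch); err != nil {
			jobDone[pk] = false // mirrors SetJobDone(pk, bls.SrSignature, false)
			invalid = append(invalid, pk)
		}
	}
	return invalid
}

type stubVerifier struct{ fail bool }

func (s stubVerifier) SignatureShare(index uint16) ([]byte, error) { return []byte("share"), nil }

func (s stubVerifier) VerifySignatureShare(index uint16, sig []byte, msg []byte, epoch uint32) error {
	if s.fail {
		return fmt.Errorf("invalid share %d", index)
	}
	return nil
}

func main() {
	jobs := map[string]bool{"A": true, "B": false}
	fmt.Println(collectInvalidSigners([]string{"A", "B"}, jobs, stubVerifier{fail: true}, []byte("hdr hash"), 0)) // [A]
}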
createAndBroadcastHeaderFinalInfo() { } log.Debug("step 3: block header final info has been sent", - "PubKeysBitmap", sr.Header.GetPubKeysBitmap(), - "AggregateSignature", sr.Header.GetSignature(), - "LeaderSignature", sr.Header.GetLeaderSignature()) + "PubKeysBitmap", header.GetPubKeysBitmap(), + "AggregateSignature", header.GetSignature(), + "LeaderSignature", header.GetLeaderSignature()) } func (sr *subroundEndRound) createAndBroadcastInvalidSigners(invalidSigners []byte) { @@ -605,7 +612,7 @@ func (sr *subroundEndRound) createAndBroadcastInvalidSigners(invalidSigners []by nil, []byte(leader), nil, - int(MtInvalidSigners), + int(bls.MtInvalidSigners), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -628,7 +635,7 @@ func (sr *subroundEndRound) doEndRoundJobByParticipant(cnsDta *consensus.Message sr.mutProcessingEndRound.Lock() defer sr.mutProcessingEndRound.Unlock() - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } if !sr.IsConsensusDataSet() { @@ -652,13 +659,13 @@ func (sr *subroundEndRound) doEndRoundJobByParticipant(cnsDta *consensus.Message sr.SetProcessingBlock(true) - shouldNotCommitBlock := sr.ExtendedCalled || int64(header.GetRound()) < sr.RoundHandler().Index() + shouldNotCommitBlock := sr.GetExtendedCalled() || int64(header.GetRound()) < sr.RoundHandler().Index() if shouldNotCommitBlock { log.Debug("canceled round, extended has been called or round index has been changed", "round", sr.RoundHandler().Index(), "subround", sr.Name(), "header round", header.GetRound(), - "extended called", sr.ExtendedCalled, + "extended called", sr.GetExtendedCalled(), ) return false } @@ -673,7 +680,7 @@ func (sr *subroundEndRound) doEndRoundJobByParticipant(cnsDta *consensus.Message } startTime := time.Now() - err := sr.BlockProcessor().CommitBlock(header, sr.Body) + err := sr.BlockProcessor().CommitBlock(header, sr.GetBody()) elapsedTime := time.Since(startTime) if elapsedTime >= common.CommitMaxTime { log.Warn("doEndRoundJobByParticipant.CommitBlock", "elapsed time", elapsedTime) @@ -715,11 +722,11 @@ func (sr *subroundEndRound) haveConsensusHeaderWithFullInfo(cnsDta *consensus.Me return sr.isConsensusHeaderReceived() } - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { return false, nil } - header := sr.Header.ShallowClone() + header := sr.GetHeader().ShallowClone() err := header.SetPubKeysBitmap(cnsDta.PubKeysBitmap) if err != nil { return false, nil @@ -739,11 +746,11 @@ func (sr *subroundEndRound) haveConsensusHeaderWithFullInfo(cnsDta *consensus.Me } func (sr *subroundEndRound) isConsensusHeaderReceived() (bool, data.HeaderHandler) { - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { return false, nil } - consensusHeaderHash, err := core.CalculateHash(sr.Marshalizer(), sr.Hasher(), sr.Header) + consensusHeaderHash, err := core.CalculateHash(sr.Marshalizer(), sr.Hasher(), sr.GetHeader()) if err != nil { log.Debug("isConsensusHeaderReceived: calculate consensus header hash", "error", err.Error()) return false, nil @@ -787,7 +794,7 @@ func (sr *subroundEndRound) isConsensusHeaderReceived() (bool, data.HeaderHandle } func (sr *subroundEndRound) signBlockHeader() ([]byte, error) { - headerClone := sr.Header.ShallowClone() + headerClone := sr.GetHeader().ShallowClone() err := headerClone.SetLeaderSignature(nil) if err != nil { return nil, err @@ -813,7 +820,7 @@ func (sr *subroundEndRound) updateMetricsForLeader() { } func (sr *subroundEndRound) broadcastBlockDataLeader() error { - miniBlocks, transactions, err := 
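// These hunks (and the test hunks that follow) consistently replace direct field access
// such as sr.Header, sr.Body, sr.RoundCanceled and sr.RoundTimeStamp with accessors on
// the ConsensusStateHandler (GetHeader/SetHeader, GetBody/SetBody, and so on). The
// snippet below only illustrates the accessor pattern with a lock-guarded holder; the
// real handler implementation is not part of this diff, so the RWMutex shown here is an
// assumption about why the getters and setters exist.
package main

import (
	"fmt"
	"sync"
)

type roundState struct {
	mut           sync.RWMutex
	header        string // stands in for data.HeaderHandler in this sketch
	roundCanceled bool
}

func (s *roundState) SetHeader(h string) {
	s.mut.Lock()
	defer s.mut.Unlock()
	s.header = h
}

func (s *roundState) GetHeader() string {
	s.mut.RLock()
	defer s.mut.RUnlock()
	return s.header
}

func (s *roundState) SetRoundCanceled(v bool) {
	s.mut.Lock()
	defer s.mut.Unlock()
	s.roundCanceled = v
}

func (s *roundState) GetRoundCanceled() bool {
	s.mut.RLock()
	defer s.mut.RUnlock()
	return s.roundCanceled
}

func main() {
	s := &roundState{}
	s.SetHeader("header for round 37")
	s.SetRoundCanceled(true)
	fmt.Println(s.GetHeader(), s.GetRoundCanceled())
}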
sr.BlockProcessor().MarshalizedDataToBroadcast(sr.Header, sr.Body) + miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.GetHeader(), sr.GetBody()) if err != nil { return err } @@ -824,7 +831,7 @@ func (sr *subroundEndRound) broadcastBlockDataLeader() error { return errGetLeader } - return sr.BroadcastMessenger().BroadcastBlockDataLeader(sr.Header, miniBlocks, transactions, []byte(leader)) + return sr.BroadcastMessenger().BroadcastBlockDataLeader(sr.GetHeader(), miniBlocks, transactions, []byte(leader)) } func (sr *subroundEndRound) setHeaderForValidator(header data.HeaderHandler) error { @@ -844,14 +851,14 @@ func (sr *subroundEndRound) prepareBroadcastBlockDataForValidator() error { return err } - go sr.BroadcastMessenger().PrepareBroadcastBlockDataValidator(sr.Header, miniBlocks, transactions, idx, pk) + go sr.BroadcastMessenger().PrepareBroadcastBlockDataValidator(sr.GetHeader(), miniBlocks, transactions, idx, pk) return nil } // doEndRoundConsensusCheck method checks if the consensus is achieved func (sr *subroundEndRound) doEndRoundConsensusCheck() bool { - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } @@ -866,7 +873,7 @@ func (sr *subroundEndRound) checkSignaturesValidity(bitmap []byte) error { consensusGroup := sr.ConsensusGroup() signers := headerCheck.ComputeSignersPublicKeys(consensusGroup, bitmap) for _, pubKey := range signers { - isSigJobDone, err := sr.JobDone(pubKey, SrSignature) + isSigJobDone, err := sr.JobDone(pubKey, bls.SrSignature) if err != nil { return err } @@ -880,14 +887,14 @@ func (sr *subroundEndRound) checkSignaturesValidity(bitmap []byte) error { } func (sr *subroundEndRound) isOutOfTime() bool { - startTime := sr.RoundTimeStamp + startTime := sr.GetRoundTimeStamp() maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 if sr.RoundHandler().RemainingTime(startTime, maxTime) < 0 { log.Debug("canceled round, time is out", "round", sr.SyncTimer().FormattedCurrentTime(), sr.RoundHandler().Index(), "subround", sr.Name()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return true } @@ -908,7 +915,7 @@ func (sr *subroundEndRound) getIndexPkAndDataToBroadcast() (int, []byte, map[uin return -1, nil, nil, nil, err } - miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.Header, sr.Body) + miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.GetHeader(), sr.GetBody()) if err != nil { return -1, nil, nil, nil, err } @@ -923,7 +930,7 @@ func (sr *subroundEndRound) getMinConsensusGroupIndexOfManagedKeys() int { minIdx := sr.ConsensusGroupSize() for idx, validator := range sr.ConsensusGroup() { - if !sr.IsKeyManagedByCurrentNode([]byte(validator)) { + if !sr.IsKeyManagedBySelf([]byte(validator)) { continue } diff --git a/consensus/spos/bls/subroundEndRound_test.go b/consensus/spos/bls/v1/subroundEndRound_test.go similarity index 77% rename from consensus/spos/bls/subroundEndRound_test.go rename to consensus/spos/bls/v1/subroundEndRound_test.go index 725513b8cb2..c3388302557 100644 --- a/consensus/spos/bls/subroundEndRound_test.go +++ b/consensus/spos/bls/v1/subroundEndRound_test.go @@ -1,4 +1,4 @@ -package bls_test +package v1_test import ( "bytes" @@ -12,27 +12,30 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + 
"github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/p2p/factory" "github.com/multiversx/mx-chain-go/testscommon" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func initSubroundEndRoundWithContainer( - container *mock.ConsensusCoreMock, + container *consensusMocks.ConsensusCoreMock, appStatusHandler core.AppStatusHandler, -) bls.SubroundEndRound { +) v1.SubroundEndRound { ch := make(chan bool, 1) - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() sr, _ := spos.NewSubround( bls.SrSignature, bls.SrEndRound, @@ -49,10 +52,10 @@ func initSubroundEndRoundWithContainer( appStatusHandler, ) - srEndRound, _ := bls.NewSubroundEndRound( + srEndRound, _ := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, appStatusHandler, &testscommon.SentSignatureTrackerStub{}, @@ -61,16 +64,16 @@ func initSubroundEndRoundWithContainer( return srEndRound } -func initSubroundEndRound(appStatusHandler core.AppStatusHandler) bls.SubroundEndRound { - container := mock.InitConsensusCore() +func initSubroundEndRound(appStatusHandler core.AppStatusHandler) v1.SubroundEndRound { + container := consensusMocks.InitConsensusCore() return initSubroundEndRoundWithContainer(container, appStatusHandler) } func TestNewSubroundEndRound(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( bls.SrSignature, @@ -91,10 +94,10 @@ func TestNewSubroundEndRound(t *testing.T) { t.Run("nil subround should error", func(t *testing.T) { t.Parallel() - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( nil, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -106,10 +109,10 @@ func TestNewSubroundEndRound(t *testing.T) { t.Run("nil extend function handler should error", func(t *testing.T) { t.Parallel() - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, nil, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -121,10 +124,10 @@ func TestNewSubroundEndRound(t *testing.T) { t.Run("nil app status handler should error", func(t *testing.T) { t.Parallel() - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, nil, &testscommon.SentSignatureTrackerStub{}, @@ -136,25 +139,25 @@ func TestNewSubroundEndRound(t *testing.T) { t.Run("nil sent 
signatures tracker should error", func(t *testing.T) { t.Parallel() - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, nil, ) assert.Nil(t, srEndRound) - assert.Equal(t, bls.ErrNilSentSignatureTracker, err) + assert.Equal(t, v1.ErrNilSentSignatureTracker, err) }) } func TestSubroundEndRound_NewSubroundEndRoundNilBlockChainShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -173,10 +176,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilBlockChainShouldFail(t *testing. &statusHandler.AppStatusHandlerStub{}, ) container.SetBlockchain(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -189,8 +192,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilBlockChainShouldFail(t *testing. func TestSubroundEndRound_NewSubroundEndRoundNilBlockProcessorShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -209,10 +212,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilBlockProcessorShouldFail(t *test &statusHandler.AppStatusHandlerStub{}, ) container.SetBlockProcessor(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -225,8 +228,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilBlockProcessorShouldFail(t *test func TestSubroundEndRound_NewSubroundEndRoundNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -245,11 +248,11 @@ func TestSubroundEndRound_NewSubroundEndRoundNilConsensusStateShouldFail(t *test &statusHandler.AppStatusHandlerStub{}, ) - sr.ConsensusState = nil - srEndRound, err := bls.NewSubroundEndRound( + sr.ConsensusStateHandler = nil + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -262,8 +265,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilConsensusStateShouldFail(t *test func TestSubroundEndRound_NewSubroundEndRoundNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -282,10 +285,10 @@ func 
TestSubroundEndRound_NewSubroundEndRoundNilMultiSignerContainerShouldFail(t &statusHandler.AppStatusHandlerStub{}, ) container.SetMultiSignerContainer(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -298,8 +301,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilMultiSignerContainerShouldFail(t func TestSubroundEndRound_NewSubroundEndRoundNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -318,10 +321,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilRoundHandlerShouldFail(t *testin &statusHandler.AppStatusHandlerStub{}, ) container.SetRoundHandler(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -334,8 +337,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilRoundHandlerShouldFail(t *testin func TestSubroundEndRound_NewSubroundEndRoundNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -354,10 +357,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilSyncTimerShouldFail(t *testing.T &statusHandler.AppStatusHandlerStub{}, ) container.SetSyncTimer(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -370,8 +373,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilSyncTimerShouldFail(t *testing.T func TestSubroundEndRound_NewSubroundEndRoundShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -390,10 +393,10 @@ func TestSubroundEndRound_NewSubroundEndRoundShouldWork(t *testing.T) { &statusHandler.AppStatusHandlerStub{}, ) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -405,8 +408,8 @@ func TestSubroundEndRound_NewSubroundEndRoundShouldWork(t *testing.T) { func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) signingHandler := &consensusMocks.SigningHandlerStub{ AggregateSigsCalled: 
func(bitmap []byte, epoch uint32) ([]byte, error) { @@ -415,9 +418,10 @@ func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) } container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") assert.True(t, sr.IsSelfLeaderInCurrentRound()) r := sr.DoEndRoundJob() @@ -427,11 +431,12 @@ func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) func TestSubroundEndRound_DoEndRoundJobErrCommitBlockShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - blProcMock := mock.InitBlockProcessorMock(container.Marshalizer()) + blProcMock := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) blProcMock.CommitBlockCalled = func( header data.HeaderHandler, body data.BodyHandler, @@ -440,7 +445,7 @@ func TestSubroundEndRound_DoEndRoundJobErrCommitBlockShouldFail(t *testing.T) { } container.SetBlockProcessor(blProcMock) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.False(t, r) @@ -449,19 +454,20 @@ func TestSubroundEndRound_DoEndRoundJobErrCommitBlockShouldFail(t *testing.T) { func TestSubroundEndRound_DoEndRoundJobErrTimeIsOutShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") remainingTime := time.Millisecond - roundHandlerMock := &mock.RoundHandlerMock{ + roundHandlerMock := &consensusMocks.RoundHandlerMock{ RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { return remainingTime }, } container.SetRoundHandler(roundHandlerMock) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -475,17 +481,18 @@ func TestSubroundEndRound_DoEndRoundJobErrTimeIsOutShouldFail(t *testing.T) { func TestSubroundEndRound_DoEndRoundJobErrBroadcastBlockOK(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - bm := &mock.BroadcastMessengerMock{ + container := consensusMocks.InitConsensusCore() + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return errors.New("error") }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -495,16 +502,16 @@ func TestSubroundEndRound_DoEndRoundJobErrMarshalizedDataToBroadcastOK(t *testin t.Parallel() err := errors.New("") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - bpm := mock.InitBlockProcessorMock(container.Marshalizer()) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) bpm.MarshalizedDataToBroadcastCalled = func(header 
data.HeaderHandler, body data.BodyHandler) (map[uint32][]byte, map[string][][]byte, error) { err = errors.New("error marshalized data to broadcast") return make(map[uint32][]byte), make(map[string][][]byte), err } container.SetBlockProcessor(bpm) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, @@ -516,10 +523,11 @@ func TestSubroundEndRound_DoEndRoundJobErrMarshalizedDataToBroadcastOK(t *testin }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -530,15 +538,15 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastMiniBlocksOK(t *testing.T) { t.Parallel() err := errors.New("") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - bpm := mock.InitBlockProcessorMock(container.Marshalizer()) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) bpm.MarshalizedDataToBroadcastCalled = func(header data.HeaderHandler, body data.BodyHandler) (map[uint32][]byte, map[string][][]byte, error) { return make(map[uint32][]byte), make(map[string][][]byte), nil } container.SetBlockProcessor(bpm) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, @@ -551,10 +559,11 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastMiniBlocksOK(t *testing.T) { }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -566,15 +575,15 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastTransactionsOK(t *testing.T) t.Parallel() err := errors.New("") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - bpm := mock.InitBlockProcessorMock(container.Marshalizer()) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) bpm.MarshalizedDataToBroadcastCalled = func(header data.HeaderHandler, body data.BodyHandler) (map[uint32][]byte, map[string][][]byte, error) { return make(map[uint32][]byte), make(map[string][][]byte), nil } container.SetBlockProcessor(bpm) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, @@ -587,10 +596,11 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastTransactionsOK(t *testing.T) }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -601,17 +611,18 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastTransactionsOK(t *testing.T) func 
TestSubroundEndRound_DoEndRoundJobAllOK(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - bm := &mock.BroadcastMessengerMock{ + container := consensusMocks.InitConsensusCore() + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return errors.New("error") }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -621,7 +632,7 @@ func TestSubroundEndRound_CheckIfSignatureIsFilled(t *testing.T) { t.Parallel() expectedSignature := []byte("signature") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() signingHandler := &consensusMocks.SigningHandlerStub{ CreateSignatureForPublicKeyCalled: func(publicKeyBytes []byte, msg []byte) ([]byte, error) { var receivedHdr block.Header @@ -630,27 +641,28 @@ func TestSubroundEndRound_CheckIfSignatureIsFilled(t *testing.T) { }, } container.SetSigningHandler(signingHandler) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return errors.New("error") }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{Nonce: 5} + sr.SetHeader(&block.Header{Nonce: 5}) r := sr.DoEndRoundJob() assert.True(t, r) - assert.Equal(t, expectedSignature, sr.Header.GetLeaderSignature()) + assert.Equal(t, expectedSignature, sr.GetHeader().GetLeaderSignature()) } func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.RoundCanceled = true + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) ok := sr.DoEndRoundConsensusCheck() assert.False(t, ok) @@ -659,7 +671,7 @@ func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsCa func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnTrueWhenRoundIsFinished(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) sr.SetStatus(bls.SrEndRound, spos.SsFinished) ok := sr.DoEndRoundConsensusCheck() @@ -669,7 +681,7 @@ func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnTrueWhenRoundIsFin func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsNotFinished(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) ok := sr.DoEndRoundConsensusCheck() assert.False(t, ok) @@ -678,7 +690,7 @@ func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsNo func TestSubroundEndRound_CheckSignaturesValidityShouldErrNilSignature(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := 
initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) err := sr.CheckSignaturesValidity([]byte{2}) assert.Equal(t, spos.ErrNilSignature, err) @@ -687,7 +699,7 @@ func TestSubroundEndRound_CheckSignaturesValidityShouldErrNilSignature(t *testin func TestSubroundEndRound_CheckSignaturesValidityShouldReturnNil(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) @@ -698,8 +710,8 @@ func TestSubroundEndRound_CheckSignaturesValidityShouldReturnNil(t *testing.T) { func TestSubroundEndRound_DoEndRoundJobByParticipant_RoundCanceledShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.RoundCanceled = true + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) cnsData := consensus.Message{} res := sr.DoEndRoundJobByParticipant(&cnsData) @@ -709,8 +721,8 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_RoundCanceledShouldReturnFa func TestSubroundEndRound_DoEndRoundJobByParticipant_ConsensusDataNotSetShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Data = nil + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetData(nil) cnsData := consensus.Message{} res := sr.DoEndRoundJobByParticipant(&cnsData) @@ -720,7 +732,7 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_ConsensusDataNotSetShouldRe func TestSubroundEndRound_DoEndRoundJobByParticipant_PreviousSubroundNotFinishedShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) sr.SetStatus(2, spos.SsNotFinished) cnsData := consensus.Message{} res := sr.DoEndRoundJobByParticipant(&cnsData) @@ -730,7 +742,7 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_PreviousSubroundNotFinished func TestSubroundEndRound_DoEndRoundJobByParticipant_CurrentSubroundFinishedShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) // set previous as finished sr.SetStatus(2, spos.SsFinished) @@ -746,7 +758,7 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_CurrentSubroundFinishedShou func TestSubroundEndRound_DoEndRoundJobByParticipant_ConsensusHeaderNotReceivedShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) // set previous as finished sr.SetStatus(2, spos.SsFinished) @@ -763,8 +775,8 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_ShouldReturnTrue(t *testing t.Parallel() hdr := &block.Header{Nonce: 37} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) sr.AddReceivedHeader(hdr) // set previous as finished @@ -782,8 +794,8 @@ func TestSubroundEndRound_IsConsensusHeaderReceived_NoReceivedHeadersShouldRetur t.Parallel() hdr := &block.Header{Nonce: 37} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) res, retHdr := 
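// The IsConsensusHeaderReceived tests in this area exercise the pattern of hashing the
// locally tracked header and looking for the same hash among the headers received so
// far (the production code goes through core.CalculateHash with the node's marshaller
// and hasher). This standalone sketch uses JSON plus sha256 as stand-ins; the types and
// helper names are assumptions for illustration only.
package main

import (
	"bytes"
	"crypto/sha256"
	"encoding/json"
	"fmt"
)

type header struct {
	Nonce uint64
}

func hashHeader(h header) []byte {
	raw, _ := json.Marshal(h) // stand-in for the Marshalizer + Hasher pair
	sum := sha256.Sum256(raw)
	return sum[:]
}

// findConsensusHeader returns the received header whose hash matches the local one.
func findConsensusHeader(local header, received []header) (bool, *header) {
	want := hashHeader(local)
	for i := range received {
		if bytes.Equal(hashHeader(received[i]), want) {
			return true, &received[i]
		}
	}
	return false, nil
}

func main() {
	local := header{Nonce: 37}
	ok, _ := findConsensusHeader(local, []header{{Nonce: 38}})
	fmt.Println(ok) // false: only a different header was received
	ok, hdr := findConsensusHeader(local, []header{{Nonce: 37}})
	fmt.Println(ok, hdr.Nonce) // true 37
}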
sr.IsConsensusHeaderReceived() assert.False(t, res) @@ -795,9 +807,9 @@ func TestSubroundEndRound_IsConsensusHeaderReceived_HeaderNotReceivedShouldRetur hdr := &block.Header{Nonce: 37} hdrToSearchFor := &block.Header{Nonce: 38} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) sr.AddReceivedHeader(hdr) - sr.Header = hdrToSearchFor + sr.SetHeader(hdrToSearchFor) res, retHdr := sr.IsConsensusHeaderReceived() assert.False(t, res) @@ -808,8 +820,8 @@ func TestSubroundEndRound_IsConsensusHeaderReceivedShouldReturnTrue(t *testing.T t.Parallel() hdr := &block.Header{Nonce: 37} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) sr.AddReceivedHeader(hdr) res, retHdr := sr.IsConsensusHeaderReceived() @@ -820,7 +832,7 @@ func TestSubroundEndRound_IsConsensusHeaderReceivedShouldReturnTrue(t *testing.T func TestSubroundEndRound_HaveConsensusHeaderWithFullInfoNilHdrShouldNotWork(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{} @@ -843,8 +855,8 @@ func TestSubroundEndRound_HaveConsensusHeaderWithFullInfoShouldWork(t *testing.T Signature: originalSig, LeaderSignature: originalLeaderSig, } - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = &hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&hdr) cnsData := consensus.Message{ PubKeysBitmap: newPubKeyBitMap, @@ -864,8 +876,8 @@ func TestSubroundEndRound_CreateAndBroadcastHeaderFinalInfoBroadcastShouldBeCall chanRcv := make(chan bool, 1) leaderSigInHdr := []byte("leader sig") - container := mock.InitConsensusCore() - messenger := &mock.BroadcastMessengerMock{ + container := consensusMocks.InitConsensusCore() + messenger := &consensusMocks.BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { chanRcv <- true assert.Equal(t, message.LeaderSignature, leaderSigInHdr) @@ -873,8 +885,8 @@ func TestSubroundEndRound_CreateAndBroadcastHeaderFinalInfoBroadcastShouldBeCall }, } container.SetBroadcastMessenger(messenger) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.Header = &block.Header{LeaderSignature: leaderSigInHdr} + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{LeaderSignature: leaderSigInHdr}) sr.CreateAndBroadcastHeaderFinalInfo() @@ -889,8 +901,8 @@ func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldWork(t *testing.T) { t.Parallel() hdr := &block.Header{Nonce: 37} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) sr.AddReceivedHeader(hdr) sr.SetStatus(2, spos.SsFinished) @@ -909,9 +921,9 @@ func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldWork(t *testing.T) { func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldReturnFalseWhenFinalInfoIsNotValid(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - headerSigVerifier := &mock.HeaderSigVerifierStub{ + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ VerifyLeaderSignatureCalled: func(header 
data.HeaderHandler) error { return errors.New("error") }, @@ -921,12 +933,12 @@ func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldReturnFalseWhenFinal } container.SetHeaderSigVerifier(headerSigVerifier) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), PubKey: []byte("A"), } - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) res := sr.ReceivedBlockHeaderFinalInfo(&cnsData) assert.False(t, res) } @@ -934,7 +946,7 @@ func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldReturnFalseWhenFinal func TestSubroundEndRound_IsOutOfTimeShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) res := sr.IsOutOfTime() assert.False(t, res) @@ -944,8 +956,8 @@ func TestSubroundEndRound_IsOutOfTimeShouldReturnTrue(t *testing.T) { t.Parallel() // update roundHandler's mock, so it will calculate for real the duration - container := mock.InitConsensusCore() - roundHandler := mock.RoundHandlerMock{RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { + container := consensusMocks.InitConsensusCore() + roundHandler := consensusMocks.RoundHandlerMock{RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { currentTime := time.Now() elapsedTime := currentTime.Sub(startTime) remainingTime := maxTime - elapsedTime @@ -953,9 +965,9 @@ func TestSubroundEndRound_IsOutOfTimeShouldReturnTrue(t *testing.T) { return remainingTime }} container.SetRoundHandler(&roundHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.RoundTimeStamp = time.Now().AddDate(0, 0, -1) + sr.SetRoundTimeStamp(time.Now().AddDate(0, 0, -1)) res := sr.IsOutOfTime() assert.True(t, res) @@ -964,9 +976,9 @@ func TestSubroundEndRound_IsOutOfTimeShouldReturnTrue(t *testing.T) { func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerifyLeaderSignatureFails(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - headerSigVerifier := &mock.HeaderSigVerifierStub{ + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { return errors.New("error") }, @@ -976,9 +988,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerify } container.SetHeaderSigVerifier(headerSigVerifier) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsDta := &consensus.Message{} - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) isValid := sr.IsBlockHeaderFinalInfoValid(cnsDta) assert.False(t, isValid) } @@ -986,9 +998,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerify func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerifySignatureFails(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - headerSigVerifier := &mock.HeaderSigVerifierStub{ + headerSigVerifier := 
&consensusMocks.HeaderSigVerifierMock{ VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { return nil }, @@ -998,9 +1010,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerify } container.SetHeaderSigVerifier(headerSigVerifier) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsDta := &consensus.Message{} - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) isValid := sr.IsBlockHeaderFinalInfoValid(cnsDta) assert.False(t, isValid) } @@ -1008,9 +1020,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerify func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnTrue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - headerSigVerifier := &mock.HeaderSigVerifierStub{ + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { return nil }, @@ -1020,9 +1032,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnTrue(t *testing } container.SetHeaderSigVerifier(headerSigVerifier) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsDta := &consensus.Message{} - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) isValid := sr.IsBlockHeaderFinalInfoValid(cnsDta) assert.True(t, isValid) } @@ -1033,8 +1045,8 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { t.Run("fail to get signature share", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -1045,7 +1057,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _, err := sr.VerifyNodesOnAggSigFail() @@ -1055,8 +1067,8 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { t.Run("fail to verify signature share, job done will be set to false", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -1068,7 +1080,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { }, } - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) container.SetSigningHandler(signingHandler) @@ -1083,8 +1095,8 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container 
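// The IsBlockHeaderFinalInfoValid tests above swap in a HeaderSigVerifierMock whose
// leader-signature and aggregated-signature checks can fail independently. This sketch
// shows that two-step validation shape on a header copy populated from the consensus
// message; the types and helper names are assumptions for illustration, not the
// repository's real API.
package main

import (
	"errors"
	"fmt"
)

type headerCopy struct {
	PubKeysBitmap   []byte
	Signature       []byte
	LeaderSignature []byte
}

type headerSigVerifier struct {
	verifyLeaderSig func(headerCopy) error
	verifySignature func(headerCopy) error
}

// isFinalInfoValid applies the message's bitmap and signatures to a clone of the header,
// then checks the leader signature first and the aggregated signature second.
func isFinalInfoValid(bitmap, aggSig, leaderSig []byte, v headerSigVerifier) bool {
	clone := headerCopy{PubKeysBitmap: bitmap, Signature: aggSig, LeaderSignature: leaderSig}
	if err := v.verifyLeaderSig(clone); err != nil {
		return false
	}
	if err := v.verifySignature(clone); err != nil {
		return false
	}
	return true
}

func main() {
	ok := headerSigVerifier{
		verifyLeaderSig: func(headerCopy) error { return nil },
		verifySignature: func(headerCopy) error { return nil },
	}
	badLeader := headerSigVerifier{
		verifyLeaderSig: func(headerCopy) error { return errors.New("error") },
		verifySignature: func(headerCopy) error { return nil },
	}
	fmt.Println(isFinalInfoValid([]byte{1}, []byte("agg"), []byte("leader"), ok))        // true
	fmt.Println(isFinalInfoValid([]byte{1}, []byte("agg"), []byte("leader"), badLeader)) // false
}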
:= consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) signingHandler := &consensusMocks.SigningHandlerStub{ SignatureShareCalled: func(index uint16) ([]byte, error) { return nil, nil @@ -1098,7 +1110,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { } container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) @@ -1114,9 +1126,9 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { t.Run("invalid number of valid sig shares", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.Header = &block.Header{} + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) sr.SetThreshold(bls.SrEndRound, 2) _, _, err := sr.ComputeAggSigOnValidNodes() @@ -1126,8 +1138,8 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { t.Run("fail to created aggregated sig", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -1137,7 +1149,7 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { } container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _, _, err := sr.ComputeAggSigOnValidNodes() @@ -1147,8 +1159,8 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { t.Run("fail to set aggregated sig", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -1157,7 +1169,7 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { }, } container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _, _, err := sr.ComputeAggSigOnValidNodes() @@ -1167,9 +1179,9 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.Header = &block.Header{} + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) bitmap, sig, err := sr.ComputeAggSigOnValidNodes() @@ -1185,8 +1197,8 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { t.Run("not enough valid signature shares", func(t *testing.T) { t.Parallel() - container := 
mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) verifySigShareNumCalls := 0 verifyFirstCall := true @@ -1220,7 +1232,7 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJobByLeader() require.False(t, r) @@ -1232,8 +1244,8 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) verifySigShareNumCalls := 0 verifyFirstCall := true @@ -1268,7 +1280,7 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) _ = sr.SetJobDone(sr.ConsensusGroup()[2], bls.SrSignature, true) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJobByLeader() require.True(t, r) @@ -1284,10 +1296,10 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("consensus data is not set", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.ConsensusState.Data = nil + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.ConsensusStateHandler.SetData(nil) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1301,9 +1313,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("received message node is not leader in current round", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1317,10 +1329,11 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("received message from self leader should return false", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1334,14 +1347,14 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("received message from self multikey leader should return false", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{ IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { return 
string(pkBytes) == "A" }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) sr, _ := spos.NewSubround( bls.SrSignature, bls.SrEndRound, @@ -1358,10 +1371,10 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { &statusHandler.AppStatusHandlerStub{}, ) - srEndRound, _ := bls.NewSubroundEndRound( + srEndRound, _ := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -1381,9 +1394,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("received hash does not match the hash from current consensus state", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("Y"), @@ -1397,9 +1410,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("process received message verification failed, different round index", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1414,9 +1427,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("empty invalid signers", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), PubKey: []byte("A"), @@ -1437,10 +1450,10 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { }, } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), PubKey: []byte("A"), @@ -1454,9 +1467,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1475,7 +1488,7 @@ func TestVerifyInvalidSigners(t *testing.T) { t.Run("failed to deserialize invalidSigners field, should error", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() expectedErr := errors.New("expected err") 
messageSigningHandler := &mock.MessageSigningHandlerStub{ @@ -1486,7 +1499,7 @@ func TestVerifyInvalidSigners(t *testing.T) { container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) err := sr.VerifyInvalidSigners([]byte{}) require.Equal(t, expectedErr, err) @@ -1495,7 +1508,7 @@ func TestVerifyInvalidSigners(t *testing.T) { t.Run("failed to verify low level p2p message, should error", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() invalidSigners := []p2p.MessageP2P{&factory.Message{ FromField: []byte("from"), @@ -1515,7 +1528,7 @@ func TestVerifyInvalidSigners(t *testing.T) { container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) err := sr.VerifyInvalidSigners(invalidSignersBytes) require.Equal(t, expectedErr, err) @@ -1524,7 +1537,7 @@ func TestVerifyInvalidSigners(t *testing.T) { t.Run("failed to verify signature share", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() pubKey := []byte("A") // it's in consensus @@ -1557,7 +1570,7 @@ func TestVerifyInvalidSigners(t *testing.T) { container.SetSigningHandler(signingHandler) container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) err := sr.VerifyInvalidSigners(invalidSignersBytes) require.Nil(t, err) @@ -1567,7 +1580,7 @@ func TestVerifyInvalidSigners(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() pubKey := []byte("A") // it's in consensus @@ -1585,7 +1598,7 @@ func TestVerifyInvalidSigners(t *testing.T) { messageSigningHandler := &mock.MessageSignerMock{} container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) err := sr.VerifyInvalidSigners(invalidSignersBytes) require.Nil(t, err) @@ -1600,7 +1613,7 @@ func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { expectedInvalidSigners := []byte("invalid signers") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() nodeRedundancy := &mock.NodeRedundancyHandlerStub{ IsRedundancyNodeCalled: func() bool { return true @@ -1610,14 +1623,14 @@ func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { }, } container.SetNodeRedundancyHandler(nodeRedundancy) - messenger := &mock.BroadcastMessengerMock{ + messenger := &consensusMocks.BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { assert.Fail(t, "should have not been called") return nil }, } container.SetBroadcastMessenger(messenger) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) 
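// Illustrative sketch (hypothetical names, not the patch's API): the mock wired up
// above turns any broadcast attempt into a test failure, which is how the subtest
// asserts that a redundancy (backup) node stays silent. The same "must not be called"
// stub pattern in a self-contained form:
package sketch

import "testing"

// broadcaster stands in for the broadcast dependency of the component under test.
type broadcaster struct {
	broadcastCalled func(msg []byte) error
}

func (b *broadcaster) Broadcast(msg []byte) error {
	return b.broadcastCalled(msg)
}

// newMustNotBroadcast returns a broadcaster that fails the test when invoked,
// mirroring the BroadcastConsensusMessageCalled stub in the subtest above.
func newMustNotBroadcast(t *testing.T) *broadcaster {
	return &broadcaster{
		broadcastCalled: func(_ []byte) error {
			t.Error("broadcast should not have been called for a redundancy node")
			return nil
		},
	}
}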
sr.CreateAndBroadcastInvalidSigners(expectedInvalidSigners) }) @@ -1630,8 +1643,8 @@ func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { expectedInvalidSigners := []byte("invalid signers") wasCalled := false - container := mock.InitConsensusCore() - messenger := &mock.BroadcastMessengerMock{ + container := consensusMocks.InitConsensusCore() + messenger := &consensusMocks.BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { assert.Equal(t, expectedInvalidSigners, message.InvalidSigners) wasCalled = true @@ -1640,8 +1653,9 @@ func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { }, } container.SetBroadcastMessenger(messenger) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") sr.CreateAndBroadcastInvalidSigners(expectedInvalidSigners) @@ -1657,7 +1671,7 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { t.Run("empty p2p messages slice if not in state", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() messageSigningHandler := &mock.MessageSigningHandlerStub{ SerializeCalled: func(messages []p2p.MessageP2P) ([]byte, error) { @@ -1669,7 +1683,7 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) invalidSigners := []string{"B", "C"} invalidSignersBytes, err := sr.GetFullMessagesForInvalidSigners(invalidSigners) @@ -1680,7 +1694,7 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() expectedInvalidSigners := []byte("expectedInvalidSigners") @@ -1694,7 +1708,7 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.AddMessageWithSignature("B", &p2pmocks.P2PMessageMock{}) sr.AddMessageWithSignature("C", &p2pmocks.P2PMessageMock{}) @@ -1709,10 +1723,10 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { func TestSubroundEndRound_getMinConsensusGroupIndexOfManagedKeys(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) sr, _ := spos.NewSubround( bls.SrSignature, bls.SrEndRound, @@ -1729,10 +1743,10 @@ func TestSubroundEndRound_getMinConsensusGroupIndexOfManagedKeys(t *testing.T) { &statusHandler.AppStatusHandlerStub{}, ) - srEndRound, _ := bls.NewSubroundEndRound( + srEndRound, _ := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, diff --git 
a/consensus/spos/bls/subroundSignature.go b/consensus/spos/bls/v1/subroundSignature.go similarity index 94% rename from consensus/spos/bls/subroundSignature.go rename to consensus/spos/bls/v1/subroundSignature.go index ac06cc72fdd..1d71ac59420 100644 --- a/consensus/spos/bls/subroundSignature.go +++ b/consensus/spos/bls/v1/subroundSignature.go @@ -1,4 +1,4 @@ -package bls +package v1 import ( "context" @@ -8,9 +8,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" ) type subroundSignature struct { @@ -60,7 +62,7 @@ func checkNewSubroundSignatureParams( if baseSubround == nil { return spos.ErrNilSubround } - if baseSubround.ConsensusState == nil { + if check.IfNil(baseSubround.ConsensusStateHandler) { return spos.ErrNilConsensusState } @@ -74,7 +76,7 @@ func (sr *subroundSignature) doSignatureJob(_ context.Context) bool { if !sr.CanDoSubroundJob(sr.Current()) { return false } - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { log.Error("doSignatureJob", "error", spos.ErrNilHeader) return false } @@ -92,7 +94,7 @@ func (sr *subroundSignature) doSignatureJob(_ context.Context) bool { signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( sr.GetData(), uint16(selfIndex), - sr.Header.GetEpoch(), + sr.GetHeader().GetEpoch(), []byte(sr.SelfPubKey()), ) if err != nil { @@ -125,7 +127,7 @@ func (sr *subroundSignature) createAndSendSignatureMessage(signatureShare []byte nil, pkBytes, nil, - int(MtSignature), + int(bls.MtSignature), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -236,7 +238,7 @@ func (sr *subroundSignature) receivedSignature(_ context.Context, cnsDta *consen // doSignatureConsensusCheck method checks if the consensus in the subround Signature is achieved func (sr *subroundSignature) doSignatureConsensusCheck() bool { - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } @@ -250,7 +252,7 @@ func (sr *subroundSignature) doSignatureConsensusCheck() bool { isSelfInConsensusGroup := sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || sr.IsMultiKeyInConsensusGroup() threshold := sr.Threshold(sr.Current()) - if sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.Header) { + if sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.GetHeader()) { threshold = sr.FallbackThreshold(sr.Current()) log.Warn("subroundSignature.doSignatureConsensusCheck: fallback validation has been applied", "minimum number of signatures required", threshold, @@ -261,7 +263,7 @@ func (sr *subroundSignature) doSignatureConsensusCheck() bool { areSignaturesCollected, numSigs := sr.areSignaturesCollected(threshold) areAllSignaturesCollected := numSigs == sr.ConsensusGroupSize() - isJobDoneByLeader := isSelfLeader && (areAllSignaturesCollected || (areSignaturesCollected && sr.WaitingAllSignaturesTimeOut)) + isJobDoneByLeader := isSelfLeader && (areAllSignaturesCollected || (areSignaturesCollected && sr.GetWaitingAllSignaturesTimeOut())) selfJobDone := true if sr.IsNodeInConsensusGroup(sr.SelfPubKey()) { @@ -332,7 +334,7 @@ func (sr *subroundSignature) waitAllSignatures() { return } - sr.WaitingAllSignaturesTimeOut = true + sr.SetWaitingAllSignaturesTimeOut(true) select { case sr.ConsensusChannel() <- true: @@ -352,12 +354,12 @@ func (sr *subroundSignature) doSignatureJobForManagedKeys() bool { 
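// Illustrative sketch (hypothetical names, not the patch's API): the hunk below
// replaces the positional leader test (idx == spos.IndexOfLeaderInConsensusGroup)
// with a comparison against the public key returned by GetLeader(), so each managed
// key is checked by value rather than by its slot in the consensus group. The same
// idea in a self-contained form:
package sketch

// isLeaderKey reports whether the managed public key pk belongs to the elected leader.
func isLeaderKey(pk string, leader string) bool {
	return pk == leader
}

// completeJobs mirrors the loop's intent: the key matching the leader completes the
// subround as leader, every other managed key completes it as a regular validator.
func completeJobs(consensusGroup []string, leader string) map[string]bool {
	isLeaderByKey := make(map[string]bool, len(consensusGroup))
	for _, pk := range consensusGroup {
		isLeaderByKey[pk] = isLeaderKey(pk, leader)
	}
	return isLeaderByKey
}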
isMultiKeyLeader := sr.IsMultiKeyLeaderInCurrentRound() numMultiKeysSignaturesSent := 0 - for idx, pk := range sr.ConsensusGroup() { + for _, pk := range sr.ConsensusGroup() { pkBytes := []byte(pk) if sr.IsJobDone(pk, sr.Current()) { continue } - if !sr.IsKeyManagedByCurrentNode(pkBytes) { + if !sr.IsKeyManagedBySelf(pkBytes) { continue } @@ -370,7 +372,7 @@ func (sr *subroundSignature) doSignatureJobForManagedKeys() bool { signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( sr.GetData(), uint16(selfIndex), - sr.Header.GetEpoch(), + sr.GetHeader().GetEpoch(), pkBytes, ) if err != nil { @@ -387,8 +389,13 @@ func (sr *subroundSignature) doSignatureJobForManagedKeys() bool { numMultiKeysSignaturesSent++ } sr.sentSignatureTracker.SignatureSent(pkBytes) + leader, err := sr.GetLeader() + if err != nil { + log.Debug("doSignatureJobForManagedKeys.GetLeader", "error", err.Error()) + return false + } - isLeader := idx == spos.IndexOfLeaderInConsensusGroup + isLeader := pk == leader ok := sr.completeSignatureSubRound(pk, isLeader) if !ok { return false diff --git a/consensus/spos/bls/subroundSignature_test.go b/consensus/spos/bls/v1/subroundSignature_test.go similarity index 78% rename from consensus/spos/bls/subroundSignature_test.go rename to consensus/spos/bls/v1/subroundSignature_test.go index 9ee8a03ba19..73d765cb67b 100644 --- a/consensus/spos/bls/subroundSignature_test.go +++ b/consensus/spos/bls/v1/subroundSignature_test.go @@ -1,4 +1,4 @@ -package bls_test +package v1_test import ( "testing" @@ -6,19 +6,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/testscommon" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" ) -func initSubroundSignatureWithContainer(container *mock.ConsensusCoreMock) bls.SubroundSignature { - consensusState := initConsensusState() +func initSubroundSignatureWithContainer(container *consensusMocks.ConsensusCoreMock) v1.SubroundSignature { + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -37,7 +39,7 @@ func initSubroundSignatureWithContainer(container *mock.ConsensusCoreMock) bls.S &statusHandler.AppStatusHandlerStub{}, ) - srSignature, _ := bls.NewSubroundSignature( + srSignature, _ := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -47,16 +49,16 @@ func initSubroundSignatureWithContainer(container *mock.ConsensusCoreMock) bls.S return srSignature } -func initSubroundSignature() bls.SubroundSignature { - container := mock.InitConsensusCore() +func initSubroundSignature() v1.SubroundSignature { + container := consensusMocks.InitConsensusCore() return initSubroundSignatureWithContainer(container) } func TestNewSubroundSignature(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := 
consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -78,7 +80,7 @@ func TestNewSubroundSignature(t *testing.T) { t.Run("nil subround should error", func(t *testing.T) { t.Parallel() - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( nil, extend, &statusHandler.AppStatusHandlerStub{}, @@ -91,7 +93,7 @@ func TestNewSubroundSignature(t *testing.T) { t.Run("nil extend function handler should error", func(t *testing.T) { t.Parallel() - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, nil, &statusHandler.AppStatusHandlerStub{}, @@ -104,7 +106,7 @@ func TestNewSubroundSignature(t *testing.T) { t.Run("nil app status handler should error", func(t *testing.T) { t.Parallel() - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, nil, @@ -117,7 +119,7 @@ func TestNewSubroundSignature(t *testing.T) { t.Run("nil sent signatures tracker should error", func(t *testing.T) { t.Parallel() - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -125,15 +127,15 @@ func TestNewSubroundSignature(t *testing.T) { ) assert.Nil(t, srSignature) - assert.Equal(t, bls.ErrNilSentSignatureTracker, err) + assert.Equal(t, v1.ErrNilSentSignatureTracker, err) }) } func TestSubroundSignature_NewSubroundSignatureNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -152,8 +154,8 @@ func TestSubroundSignature_NewSubroundSignatureNilConsensusStateShouldFail(t *te &statusHandler.AppStatusHandlerStub{}, ) - sr.ConsensusState = nil - srSignature, err := bls.NewSubroundSignature( + sr.ConsensusStateHandler = nil + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -167,8 +169,8 @@ func TestSubroundSignature_NewSubroundSignatureNilConsensusStateShouldFail(t *te func TestSubroundSignature_NewSubroundSignatureNilHasherShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -187,7 +189,7 @@ func TestSubroundSignature_NewSubroundSignatureNilHasherShouldFail(t *testing.T) &statusHandler.AppStatusHandlerStub{}, ) container.SetHasher(nil) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -201,8 +203,8 @@ func TestSubroundSignature_NewSubroundSignatureNilHasherShouldFail(t *testing.T) func TestSubroundSignature_NewSubroundSignatureNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -221,7 +223,7 @@ func TestSubroundSignature_NewSubroundSignatureNilMultiSignerContainerShouldFail &statusHandler.AppStatusHandlerStub{}, ) container.SetMultiSignerContainer(nil) - srSignature, 
err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -235,8 +237,8 @@ func TestSubroundSignature_NewSubroundSignatureNilMultiSignerContainerShouldFail func TestSubroundSignature_NewSubroundSignatureNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -256,7 +258,7 @@ func TestSubroundSignature_NewSubroundSignatureNilRoundHandlerShouldFail(t *test ) container.SetRoundHandler(nil) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -270,8 +272,8 @@ func TestSubroundSignature_NewSubroundSignatureNilRoundHandlerShouldFail(t *test func TestSubroundSignature_NewSubroundSignatureNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -290,7 +292,7 @@ func TestSubroundSignature_NewSubroundSignatureNilSyncTimerShouldFail(t *testing &statusHandler.AppStatusHandlerStub{}, ) container.SetSyncTimer(nil) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -304,8 +306,8 @@ func TestSubroundSignature_NewSubroundSignatureNilSyncTimerShouldFail(t *testing func TestSubroundSignature_NewSubroundSignatureShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -324,7 +326,7 @@ func TestSubroundSignature_NewSubroundSignatureShouldWork(t *testing.T) { &statusHandler.AppStatusHandlerStub{}, ) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -338,15 +340,15 @@ func TestSubroundSignature_NewSubroundSignatureShouldWork(t *testing.T) { func TestSubroundSignature_DoSignatureJob(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundSignatureWithContainer(container) + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) - sr.Header = &block.Header{} - sr.Data = nil + sr.SetHeader(&block.Header{}) + sr.SetData(nil) r := sr.DoSignatureJob() assert.False(t, r) - sr.Data = []byte("X") + sr.SetData([]byte("X")) err := errors.New("create signature share error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -370,18 +372,21 @@ func TestSubroundSignature_DoSignatureJob(t *testing.T) { assert.True(t, r) _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrSignature, false) - sr.RoundCanceled = false - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + sr.SetRoundCanceled(false) + leader, err := sr.GetLeader() + assert.Nil(t, err) + + sr.SetSelfPubKey(leader) r = sr.DoSignatureJob() assert.True(t, r) - assert.False(t, sr.RoundCanceled) + assert.False(t, sr.GetRoundCanceled()) } func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - 
consensusState := initConsensusStateWithKeysHandler( + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusStateWithKeysHandler( &testscommon.KeysHandlerStub{ IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { return true @@ -407,7 +412,7 @@ func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { ) signatureSentForPks := make(map[string]struct{}) - srSignature, _ := bls.NewSubroundSignature( + srSignature, _ := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -418,12 +423,12 @@ func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { }, ) - srSignature.Header = &block.Header{} - srSignature.Data = nil + srSignature.SetHeader(&block.Header{}) + srSignature.SetData(nil) r := srSignature.DoSignatureJob() assert.False(t, r) - sr.Data = []byte("X") + sr.SetData([]byte("X")) err := errors.New("create signature share error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -447,11 +452,15 @@ func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { assert.True(t, r) _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrSignature, false) - sr.RoundCanceled = false - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + sr.SetRoundCanceled(false) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + + sr.SetSelfPubKey(leader) r = srSignature.DoSignatureJob() assert.True(t, r) - assert.False(t, sr.RoundCanceled) + assert.False(t, sr.GetRoundCanceled()) expectedMap := map[string]struct{}{ "A": {}, "B": {}, @@ -469,10 +478,10 @@ func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { func TestSubroundSignature_ReceivedSignature(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() signature := []byte("signature") cnsMsg := consensus.NewConsensusMessage( - sr.Data, + sr.GetData(), signature, nil, nil, @@ -488,20 +497,22 @@ func TestSubroundSignature_ReceivedSignature(t *testing.T) { nil, ) - sr.Header = &block.Header{} - sr.Data = nil + sr.SetHeader(&block.Header{}) + sr.SetData(nil) r := sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.Data = []byte("Y") + sr.SetData([]byte("Y")) r = sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.Data = []byte("X") + sr.SetData([]byte("X")) r = sr.ReceivedSignature(cnsMsg) assert.False(t, r) + leader, err := sr.GetLeader() + assert.Nil(t, err) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + sr.SetSelfPubKey(leader) cnsMsg.PubKey = []byte("X") r = sr.ReceivedSignature(cnsMsg) @@ -538,14 +549,14 @@ func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { }, } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetSigningHandler(signingHandler) - sr := *initSubroundSignatureWithContainer(container) - sr.Header = &block.Header{} + sr := initSubroundSignatureWithContainer(container) + sr.SetHeader(&block.Header{}) signature := []byte("signature") cnsMsg := consensus.NewConsensusMessage( - sr.Data, + sr.GetData(), signature, nil, nil, @@ -561,19 +572,21 @@ func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { nil, ) - sr.Data = nil + sr.SetData(nil) r := sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.Data = []byte("Y") + sr.SetData([]byte("Y")) r = sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.Data = []byte("X") + sr.SetData([]byte("X")) r = sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + 
sr.SetSelfPubKey(leader) cnsMsg.PubKey = []byte("X") r = sr.ReceivedSignature(cnsMsg) @@ -599,7 +612,7 @@ func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { func TestSubroundSignature_SignaturesCollected(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() for i := 0; i < len(sr.ConsensusGroup()); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, false) @@ -628,15 +641,15 @@ func TestSubroundSignature_SignaturesCollected(t *testing.T) { func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() - sr.RoundCanceled = true + sr := initSubroundSignature() + sr.SetRoundCanceled(true) assert.False(t, sr.DoSignatureConsensusCheck()) } func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSubroundIsFinished(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() sr.SetStatus(bls.SrSignature, spos.SsFinished) assert.True(t, sr.DoSignatureConsensusCheck()) } @@ -644,7 +657,7 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSubround func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSignaturesCollectedReturnTrue(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() for i := 0; i < sr.Threshold(bls.SrSignature); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -656,18 +669,20 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSignatur func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenSignaturesCollectedReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() assert.False(t, sr.DoSignatureConsensusCheck()) } func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenNotAllSignaturesCollectedAndTimeIsNotOut(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = false + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(false) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) for i := 0; i < sr.Threshold(bls.SrSignature); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -679,11 +694,13 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenNotAllS func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenAllSignaturesCollected(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = false + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(false) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) for i := 0; i < sr.ConsensusGroupSize(); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -695,11 +712,13 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenAllSigna func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenEnoughButNotAllSignaturesCollectedAndTimeIsOut(t *testing.T) { t.Parallel() - container := 
mock.InitConsensusCore() - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = true + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(true) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) for i := 0; i < sr.Threshold(bls.SrSignature); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -711,14 +730,14 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenEnoughBu func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenFallbackThresholdCouldNotBeApplied(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetFallbackHeaderValidator(&testscommon.FallBackHeaderValidatorStub{ ShouldApplyFallbackValidationCalled: func(headerHandler data.HeaderHandler) bool { return false }, }) - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = false + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(false) sr.SetSelfPubKey(sr.ConsensusGroup()[0]) @@ -732,16 +751,18 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenFallbac func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenFallbackThresholdCouldBeApplied(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetFallbackHeaderValidator(&testscommon.FallBackHeaderValidatorStub{ ShouldApplyFallbackValidationCalled: func(headerHandler data.HeaderHandler) bool { return true }, }) - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = true + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(true) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) for i := 0; i < sr.FallbackThreshold(bls.SrSignature); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -753,14 +774,16 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenFallback func TestSubroundSignature_ReceivedSignatureReturnFalseWhenConsensusDataIsNotEqual(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() + leader, err := sr.GetLeader() + assert.Nil(t, err) cnsMsg := consensus.NewConsensusMessage( - append(sr.Data, []byte("X")...), + append(sr.GetData(), []byte("X")...), []byte("signature"), nil, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtSignature), 0, diff --git a/consensus/spos/bls/subroundStartRound.go b/consensus/spos/bls/v1/subroundStartRound.go similarity index 93% rename from consensus/spos/bls/subroundStartRound.go rename to consensus/spos/bls/v1/subroundStartRound.go index 0898e039c8b..a47d9235cd2 100644 --- a/consensus/spos/bls/subroundStartRound.go +++ b/consensus/spos/bls/v1/subroundStartRound.go @@ -1,4 +1,4 @@ -package bls +package v1 import ( "context" @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" outportcore "github.com/multiversx/mx-chain-core-go/data/outport" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/outport" @@ -80,7 +81,7 @@ func 
checkNewSubroundStartRoundParams( if baseSubround == nil { return spos.ErrNilSubround } - if baseSubround.ConsensusState == nil { + if check.IfNil(baseSubround.ConsensusStateHandler) { return spos.ErrNilConsensusState } @@ -105,8 +106,8 @@ func (sr *subroundStartRound) SetOutportHandler(outportHandler outport.OutportHa // doStartRoundJob method does the job of the subround StartRound func (sr *subroundStartRound) doStartRoundJob(_ context.Context) bool { sr.ResetConsensusState() - sr.RoundIndex = sr.RoundHandler().Index() - sr.RoundTimeStamp = sr.RoundHandler().TimeStamp() + sr.SetRoundIndex(sr.RoundHandler().Index()) + sr.SetRoundTimeStamp(sr.RoundHandler().TimeStamp()) topic := spos.GetConsensusTopicID(sr.ShardCoordinator()) sr.GetAntiFloodHandler().ResetForTopic(topic) sr.resetConsensusMessages() @@ -115,7 +116,7 @@ func (sr *subroundStartRound) doStartRoundJob(_ context.Context) bool { // doStartRoundConsensusCheck method checks if the consensus is achieved in the subround StartRound func (sr *subroundStartRound) doStartRoundConsensusCheck() bool { - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } @@ -144,7 +145,7 @@ func (sr *subroundStartRound) initCurrentRound() bool { "round index", sr.RoundHandler().Index(), "error", err.Error()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } @@ -163,13 +164,13 @@ func (sr *subroundStartRound) initCurrentRound() bool { if err != nil { log.Debug("initCurrentRound.GetLeader", "error", err.Error()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } msg := "" - if sr.IsKeyManagedByCurrentNode([]byte(leader)) { + if sr.IsKeyManagedBySelf([]byte(leader)) { msg = " (my turn in multi-key)" } if leader == sr.SelfPubKey() && sr.ShouldConsiderSelfKeyInConsensus() { @@ -192,7 +193,7 @@ func (sr *subroundStartRound) initCurrentRound() bool { sr.indexRoundIfNeeded(pubKeys) isSingleKeyLeader := leader == sr.SelfPubKey() && sr.ShouldConsiderSelfKeyInConsensus() - isLeader := isSingleKeyLeader || sr.IsKeyManagedByCurrentNode([]byte(leader)) + isLeader := isSingleKeyLeader || sr.IsKeyManagedBySelf([]byte(leader)) isSelfInConsensus := sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || numMultiKeysInConsensusGroup > 0 if !isSelfInConsensus { log.Debug("not in consensus group") @@ -208,19 +209,19 @@ func (sr *subroundStartRound) initCurrentRound() bool { if err != nil { log.Debug("initCurrentRound.Reset", "error", err.Error()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } - startTime := sr.RoundTimeStamp + startTime := sr.GetRoundTimeStamp() maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 if sr.RoundHandler().RemainingTime(startTime, maxTime) < 0 { log.Debug("canceled round, time is out", "round", sr.SyncTimer().FormattedCurrentTime(), sr.RoundHandler().Index(), "subround", sr.Name()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } @@ -237,7 +238,7 @@ func (sr *subroundStartRound) computeNumManagedKeysInConsensusGroup(pubKeys []st numMultiKeysInConsensusGroup := 0 for _, pk := range pubKeys { pkBytes := []byte(pk) - if sr.IsKeyManagedByCurrentNode(pkBytes) { + if sr.IsKeyManagedBySelf(pkBytes) { numMultiKeysInConsensusGroup++ log.Trace("in consensus group with multi key", "pk", core.GetTrimmedPk(hex.EncodeToString(pkBytes))) @@ -297,7 +298,7 @@ func (sr *subroundStartRound) indexRoundIfNeeded(pubKeys []string) { BlockWasProposed: false, ShardId: shardId, Epoch: epoch, - Timestamp: uint64(sr.RoundTimeStamp.Unix()), + 
Timestamp: uint64(sr.GetRoundTimeStamp().Unix()), } roundsInfo := &outportcore.RoundsInfo{ ShardID: shardId, @@ -322,9 +323,9 @@ func (sr *subroundStartRound) generateNextConsensusGroup(roundIndex int64) error shardId := sr.ShardCoordinator().SelfId() - nextConsensusGroup, err := sr.GetNextConsensusGroup( + leader, nextConsensusGroup, err := sr.GetNextConsensusGroup( randomSeed, - uint64(sr.RoundIndex), + uint64(sr.GetRoundIndex()), shardId, sr.NodesCoordinator(), currentHeader.GetEpoch(), @@ -341,6 +342,7 @@ func (sr *subroundStartRound) generateNextConsensusGroup(roundIndex int64) error } sr.SetConsensusGroup(nextConsensusGroup) + sr.SetLeader(leader) consensusGroupSizeForEpoch := sr.NodesCoordinator().ConsensusGroupSizeForShardAndEpoch(shardId, currentHeader.GetEpoch()) sr.SetConsensusGroupSize(consensusGroupSizeForEpoch) @@ -372,5 +374,5 @@ func (sr *subroundStartRound) changeEpoch(currentEpoch uint32) { // NotifyOrder returns the notification order for a start of epoch event func (sr *subroundStartRound) NotifyOrder() uint32 { - return common.ConsensusOrder + return common.ConsensusStartRoundOrder } diff --git a/consensus/spos/bls/subroundStartRound_test.go b/consensus/spos/bls/v1/subroundStartRound_test.go similarity index 74% rename from consensus/spos/bls/subroundStartRound_test.go rename to consensus/spos/bls/v1/subroundStartRound_test.go index 2f5c21d2659..5ab4523bf94 100644 --- a/consensus/spos/bls/subroundStartRound_test.go +++ b/consensus/spos/bls/v1/subroundStartRound_test.go @@ -1,26 +1,31 @@ -package bls_test +package v1_test import ( "errors" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) -func defaultSubroundStartRoundFromSubround(sr *spos.Subround) (bls.SubroundStartRound, error) { - startRound, err := bls.NewSubroundStartRound( +func defaultSubroundStartRoundFromSubround(sr *spos.Subround) (v1.SubroundStartRound, error) { + startRound, err := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -29,11 +34,11 @@ func defaultSubroundStartRoundFromSubround(sr *spos.Subround) (bls.SubroundStart return startRound, err } -func defaultWithoutErrorSubroundStartRoundFromSubround(sr *spos.Subround) bls.SubroundStartRound { - startRound, _ := bls.NewSubroundStartRound( +func defaultWithoutErrorSubroundStartRoundFromSubround(sr *spos.Subround) v1.SubroundStartRound { + startRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -65,14 +70,14 @@ func defaultSubround( ) } -func initSubroundStartRoundWithContainer(container 
spos.ConsensusCoreHandler) bls.SubroundStartRound { - consensusState := initConsensusState() +func initSubroundStartRoundWithContainer(container spos.ConsensusCoreHandler) v1.SubroundStartRound { + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -81,8 +86,8 @@ func initSubroundStartRoundWithContainer(container spos.ConsensusCoreHandler) bl return srStartRound } -func initSubroundStartRound() bls.SubroundStartRound { - container := mock.InitConsensusCore() +func initSubroundStartRound() v1.SubroundStartRound { + container := consensusMocks.InitConsensusCore() return initSubroundStartRoundWithContainer(container) } @@ -90,8 +95,8 @@ func TestNewSubroundStartRound(t *testing.T) { t.Parallel() ch := make(chan bool, 1) - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMocks.InitConsensusCore() sr, _ := spos.NewSubround( -1, bls.SrStartRound, @@ -111,10 +116,10 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil subround should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( nil, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -126,10 +131,10 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil extend function handler should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( sr, nil, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -142,10 +147,10 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil executeStoredMessages function handler should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, nil, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -158,10 +163,10 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil resetConsensusMessages function handler should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, nil, &testscommon.SentSignatureTrackerStub{}, @@ -174,26 +179,26 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil sent signatures tracker should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, nil, ) assert.Nil(t, srStartRound) - assert.Equal(t, bls.ErrNilSentSignatureTracker, err) + assert.Equal(t, v1.ErrNilSentSignatureTracker, err) }) } func 
TestSubroundStartRound_NewSubroundStartRoundNilBlockChainShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -207,9 +212,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilBlockChainShouldFail(t *test func TestSubroundStartRound_NewSubroundStartRoundNilBootstrapperShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -223,13 +228,13 @@ func TestSubroundStartRound_NewSubroundStartRoundNilBootstrapperShouldFail(t *te func TestSubroundStartRound_NewSubroundStartRoundNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - sr.ConsensusState = nil + sr.ConsensusStateHandler = nil srStartRound, err := defaultSubroundStartRoundFromSubround(sr) assert.Nil(t, srStartRound) @@ -239,9 +244,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilConsensusStateShouldFail(t * func TestSubroundStartRound_NewSubroundStartRoundNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -255,9 +260,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilMultiSignerContainerShouldFa func TestSubroundStartRound_NewSubroundStartRoundNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -271,9 +276,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilRoundHandlerShouldFail(t *te func TestSubroundStartRound_NewSubroundStartRoundNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -287,9 +292,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilSyncTimerShouldFail(t *testi func TestSubroundStartRound_NewSubroundStartRoundNilValidatorGroupSelectorShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -303,9 +308,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilValidatorGroupSelectorShould func TestSubroundStartRound_NewSubroundStartRoundShouldWork(t *testing.T) { t.Parallel() - 
container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -319,14 +324,14 @@ func TestSubroundStartRound_NewSubroundStartRoundShouldWork(t *testing.T) { func TestSubroundStartRound_DoStartRoundShouldReturnTrue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - srStartRound := *defaultWithoutErrorSubroundStartRoundFromSubround(sr) + srStartRound := defaultWithoutErrorSubroundStartRoundFromSubround(sr) r := srStartRound.DoStartRoundJob() assert.True(t, r) @@ -335,9 +340,9 @@ func TestSubroundStartRound_DoStartRoundShouldReturnTrue(t *testing.T) { func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { t.Parallel() - sr := *initSubroundStartRound() + sr := initSubroundStartRound() - sr.RoundCanceled = true + sr.SetRoundCanceled(true) ok := sr.DoStartRoundConsensusCheck() assert.False(t, ok) @@ -346,7 +351,7 @@ func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenRound func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenRoundIsFinished(t *testing.T) { t.Parallel() - sr := *initSubroundStartRound() + sr := initSubroundStartRound() sr.SetStatus(bls.SrStartRound, spos.SsFinished) @@ -357,14 +362,14 @@ func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenRoundI func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenInitCurrentRoundReturnTrue(t *testing.T) { t.Parallel() - bootstrapperMock := &mock.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { return common.NsSynchronized }} - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetBootStrapper(bootstrapperMock) - sr := *initSubroundStartRoundWithContainer(container) + sr := initSubroundStartRoundWithContainer(container) sentTrackerInterface := sr.GetSentSignatureTracker() sentTracker := sentTrackerInterface.(*testscommon.SentSignatureTrackerStub) startRoundCalled := false @@ -380,15 +385,15 @@ func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenInitCu func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenInitCurrentRoundReturnFalse(t *testing.T) { t.Parallel() - bootstrapperMock := &mock.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { return common.NsNotSynchronized }} - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetBootStrapper(bootstrapperMock) container.SetRoundHandler(initRoundHandlerMock()) - sr := *initSubroundStartRoundWithContainer(container) + sr := initSubroundStartRoundWithContainer(container) ok := sr.DoStartRoundConsensusCheck() assert.False(t, ok) @@ -397,15 +402,15 @@ func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenInitC func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetNodeStateNotReturnSynchronized(t *testing.T) { t.Parallel() - 
bootstrapperMock := &mock.BootstrapperStub{} + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} bootstrapperMock.GetNodeStateCalled = func() common.NodeState { return common.NsNotSynchronized } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetBootStrapper(bootstrapperMock) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.False(t, r) @@ -416,13 +421,13 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGenerateNextCon validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} err := errors.New("error") - validatorGroupSelector.ComputeValidatorsGroupCalled = func(bytes []byte, round uint64, shardId uint32, epoch uint32) ([]nodesCoordinator.Validator, error) { - return nil, err + validatorGroupSelector.ComputeValidatorsGroupCalled = func(bytes []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, err } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetValidatorGroupSelector(validatorGroupSelector) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.False(t, r) @@ -436,10 +441,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnTrueWhenMainMachineIsAct return true }, } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetNodeRedundancyHandler(nodeRedundancyMock) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.True(t, r) @@ -449,19 +454,24 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetLeaderErr(t t.Parallel() validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + leader := &shardingMocks.ValidatorMock{PubKeyCalled: func() []byte { + return []byte("leader") + }} + validatorGroupSelector.ComputeValidatorsGroupCalled = func( bytes []byte, round uint64, shardId uint32, epoch uint32, - ) ([]nodesCoordinator.Validator, error) { - return make([]nodesCoordinator.Validator, 0), nil + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + // will cause an error in GetLeader because of empty consensus group + return leader, []nodesCoordinator.Validator{}, nil } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetValidatorGroupSelector(validatorGroupSelector) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.False(t, r) @@ -470,14 +480,14 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetLeaderErr(t func TestSubroundStartRound_InitCurrentRoundShouldReturnTrueWhenIsNotInTheConsensusGroup(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() consensusState.SetSelfPubKey(consensusState.SelfPubKey() + "X") ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - srStartRound := *defaultWithoutErrorSubroundStartRoundFromSubround(sr) + 
srStartRound := defaultWithoutErrorSubroundStartRoundFromSubround(sr) r := srStartRound.InitCurrentRound() assert.True(t, r) @@ -492,10 +502,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenTimeIsOut(t *te return time.Duration(-1) } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetRoundHandler(roundHandlerMock) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.False(t, r) @@ -504,16 +514,16 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenTimeIsOut(t *te func TestSubroundStartRound_InitCurrentRoundShouldReturnTrue(t *testing.T) { t.Parallel() - bootstrapperMock := &mock.BootstrapperStub{} + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} bootstrapperMock.GetNodeStateCalled = func() common.NodeState { return common.NsSynchronized } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetBootStrapper(bootstrapperMock) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.True(t, r) @@ -526,18 +536,18 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { t.Parallel() wasCalled := false - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} appStatusHandler := &statusHandler.AppStatusHandlerStub{ SetStringValueHandler: func(key string, value string) { if key == common.MetricConsensusState { wasCalled = true - assert.Equal(t, value, "not in consensus group") + assert.Equal(t, "not in consensus group", value) } }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) consensusState.SetSelfPubKey("not in consensus") sr, _ := spos.NewSubround( -1, @@ -555,10 +565,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -571,7 +581,7 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { wasCalled := false wasIncrementCalled := false - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{ IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { return string(pkBytes) == "B" @@ -591,7 +601,7 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) consensusState.SetSelfPubKey("B") sr, _ := spos.NewSubround( -1, @@ -609,10 +619,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -626,13 +636,13 
@@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { wasCalled := false wasIncrementCalled := false - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} appStatusHandler := &statusHandler.AppStatusHandlerStub{ SetStringValueHandler: func(key string, value string) { if key == common.MetricConsensusState { wasCalled = true - assert.Equal(t, value, "participant") + assert.Equal(t, "participant", value) } }, IncrementHandler: func(key string) { @@ -642,7 +652,8 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + consensusState.SetSelfPubKey("B") keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { return string(pkBytes) == consensusState.SelfPubKey() } @@ -662,10 +673,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -680,21 +691,21 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { wasMetricConsensusStateCalled := false wasMetricCountLeaderCalled := false cntMetricConsensusRoundStateCalled := 0 - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} appStatusHandler := &statusHandler.AppStatusHandlerStub{ SetStringValueHandler: func(key string, value string) { if key == common.MetricConsensusState { wasMetricConsensusStateCalled = true - assert.Equal(t, value, "proposer") + assert.Equal(t, "proposer", value) } if key == common.MetricConsensusRoundState { cntMetricConsensusRoundStateCalled++ switch cntMetricConsensusRoundStateCalled { case 1: - assert.Equal(t, value, "") + assert.Equal(t, "", value) case 2: - assert.Equal(t, value, "proposed") + assert.Equal(t, "proposed", value) default: assert.Fail(t, "should have been called only twice") } @@ -707,9 +718,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) leader, _ := consensusState.GetLeader() consensusState.SetSelfPubKey(leader) + sr, _ := spos.NewSubround( -1, bls.SrStartRound, @@ -726,10 +738,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -745,21 +757,21 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { wasMetricConsensusStateCalled := false wasMetricCountLeaderCalled := false cntMetricConsensusRoundStateCalled := 0 - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} appStatusHandler := &statusHandler.AppStatusHandlerStub{ SetStringValueHandler: func(key string, value string) { if key 
== common.MetricConsensusState { wasMetricConsensusStateCalled = true - assert.Equal(t, value, "proposer") + assert.Equal(t, "proposer", value) } if key == common.MetricConsensusRoundState { cntMetricConsensusRoundStateCalled++ switch cntMetricConsensusRoundStateCalled { case 1: - assert.Equal(t, value, "") + assert.Equal(t, "", value) case 2: - assert.Equal(t, value, "proposed") + assert.Equal(t, "proposed", value) default: assert.Fail(t, "should have been called only twice") } @@ -772,7 +784,7 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) leader, _ := consensusState.GetLeader() consensusState.SetSelfPubKey(leader) keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { @@ -794,10 +806,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -820,13 +832,13 @@ func TestSubroundStartRound_GenerateNextConsensusGroupShouldReturnErr(t *testing round uint64, shardId uint32, epoch uint32, - ) ([]nodesCoordinator.Validator, error) { - return nil, err + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, err } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetValidatorGroupSelector(validatorGroupSelector) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) err2 := srStartRound.GenerateNextConsensusGroup(0) diff --git a/consensus/spos/bls/v2/benchmark_test.go b/consensus/spos/bls/v2/benchmark_test.go new file mode 100644 index 00000000000..b7c4b962071 --- /dev/null +++ b/consensus/spos/bls/v2/benchmark_test.go @@ -0,0 +1,138 @@ +package v2_test + +import ( + "context" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/block" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/signing" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + mclMultiSig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" + "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + cryptoFactory "github.com/multiversx/mx-chain-go/factory/crypto" + "github.com/multiversx/mx-chain-go/testscommon" + nodeMock "github.com/multiversx/mx-chain-go/testscommon/common" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +func BenchmarkSubroundSignature_doSignatureJobForManagedKeys63(b *testing.B) { + benchmarkSubroundSignatureDoSignatureJobForManagedKeys(b, 63) +} + +func 
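The metrics assertions in these tests pin down how subround StartRound reports the node's role: a key outside the consensus group sets MetricConsensusState to "not in consensus group", a non-leader member sets "participant", and the leader sets "proposer" (additionally incrementing MetricCountLeader and moving MetricConsensusRoundState from "" to "proposed"). The following helper is hypothetical, not part of the consensus package; it only summarizes the role-to-value mapping exercised above.

// consensusStateMetric is an illustrative helper only; the real value is set
// during initCurrentRound through the AppStatusHandler, as the tests above verify.
func consensusStateMetric(isInConsensusGroup bool, isLeader bool) string {
	switch {
	case !isInConsensusGroup:
		return "not in consensus group"
	case isLeader:
		return "proposer"
	default:
		return "participant"
	}
}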
BenchmarkSubroundSignature_doSignatureJobForManagedKeys400(b *testing.B) { + benchmarkSubroundSignatureDoSignatureJobForManagedKeys(b, 400) +} + +func createMultiSignerSetup(grSize uint16, suite crypto.Suite) (crypto.KeyGenerator, map[string]crypto.PrivateKey) { + kg := signing.NewKeyGenerator(suite) + mapKeys := make(map[string]crypto.PrivateKey) + + for i := uint16(0); i < grSize; i++ { + sk, pk := kg.GeneratePair() + + pubKey, _ := pk.ToByteArray() + mapKeys[string(pubKey)] = sk + } + return kg, mapKeys +} + +func benchmarkSubroundSignatureDoSignatureJobForManagedKeys(b *testing.B, numberOfKeys int) { + container := consensus.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + llSigner := &mclMultiSig.BlsMultiSignerKOSK{} + + suite := mcl.NewSuiteBLS12() + kg, mapKeys := createMultiSignerSetup(uint16(numberOfKeys), suite) + + multiSigHandler, _ := multisig.NewBLSMultisig(llSigner, kg) + + keysHandlerMock := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + GetHandledPrivateKeyCalled: func(pkBytes []byte) crypto.PrivateKey { + return mapKeys[string(pkBytes)] + }, + } + + args := cryptoFactory.ArgsSigningHandler{ + PubKeys: initializers.CreateEligibleListFromMap(mapKeys), + MultiSignerContainer: &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return multiSigHandler, nil + }}, + SingleSigner: &cryptoMocks.SingleSignerStub{}, + KeyGenerator: kg, + KeysHandler: keysHandlerMock, + } + signingHandler, err := cryptoFactory.NewSigningHandler(args) + require.Nil(b, err) + + container.SetSigningHandler(signingHandler) + consensusState := initializers.InitConsensusStateWithArgs(keysHandlerMock, mapKeys) + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + signatureSentForPks := make(map[string]struct{}) + mutex := sync.Mutex{} + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + mutex.Lock() + signatureSentForPks[string(pkBytes)] = struct{}{} + mutex.Unlock() + }, + }, + &consensus.SposWorkerMock{}, + &nodeMock.ThrottlerStub{}, + ) + + sr.SetHeader(&block.Header{}) + sr.SetSelfPubKey("OTHER") + + b.ResetTimer() + b.StopTimer() + + for i := 0; i < b.N; i++ { + b.StartTimer() + r := srSignature.DoSignatureJobForManagedKeys(context.TODO()) + b.StopTimer() + + require.True(b, r) + } +} diff --git a/consensus/spos/bls/v2/benchmark_verify_signatures_test.go b/consensus/spos/bls/v2/benchmark_verify_signatures_test.go new file mode 100644 index 00000000000..09a276dc3a3 --- /dev/null +++ b/consensus/spos/bls/v2/benchmark_verify_signatures_test.go @@ -0,0 +1,123 @@ +package v2_test + +import ( + "context" + "sort" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-crypto-go/signing" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + "github.com/stretchr/testify/require" + + crypto 
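The benchmark above calls b.ResetTimer and b.StopTimer before the loop and wraps only DoSignatureJobForManagedKeys in b.StartTimer/b.StopTimer, so key generation and signing-handler construction never count towards the reported ns/op. A minimal, self-contained sketch of the same timing pattern, using only the standard library (all names here are illustrative):

package timing

import (
	"crypto/sha256"
	"testing"
)

// expensiveSetup stands in for building keys, mocks and consensus state.
func expensiveSetup() []byte {
	return make([]byte, 1<<20)
}

func BenchmarkMeasuredCallOnly(b *testing.B) {
	payload := expensiveSetup()

	b.ResetTimer()
	b.StopTimer()

	for i := 0; i < b.N; i++ {
		b.StartTimer()
		_ = sha256.Sum256(payload) // only this call contributes to ns/op
		b.StopTimer()
	}
}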
"github.com/multiversx/mx-chain-crypto-go" + mclMultisig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" + "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + + "github.com/multiversx/mx-chain-go/common" + factoryCrypto "github.com/multiversx/mx-chain-go/factory/crypto" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +const benchmarkKeyPairsCardinal = 400 + +// createListFromMapKeys make a predictable iteration on keys from a map of keys +func createListFromMapKeys(mapKeys map[string]crypto.PrivateKey) []string { + keys := make([]string, 0, len(mapKeys)) + + for key := range mapKeys { + keys = append(keys, key) + } + + sort.Strings(keys) + + return keys +} + +// generateKeyPairs generates benchmarkKeyPairsCardinal number of pairs(public key & private key) +func generateKeyPairs(kg crypto.KeyGenerator) map[string]crypto.PrivateKey { + mapKeys := make(map[string]crypto.PrivateKey) + + for i := uint16(0); i < benchmarkKeyPairsCardinal; i++ { + sk, pk := kg.GeneratePair() + + pubKey, _ := pk.ToByteArray() + mapKeys[string(pubKey)] = sk + } + return mapKeys +} + +// BenchmarkSubroundEndRound_VerifyNodesOnAggSigFailTime measure time needed to verify signatures +func BenchmarkSubroundEndRound_VerifyNodesOnAggSigFailTime(b *testing.B) { + + b.ResetTimer() + b.StopTimer() + ctx, cancel := context.WithCancel(context.TODO()) + + defer func() { + cancel() + }() + + container := consensus.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + llSigner := &mclMultisig.BlsMultiSignerKOSK{} + suite := mcl.NewSuiteBLS12() + kg := signing.NewKeyGenerator(suite) + + multiSigHandler, _ := multisig.NewBLSMultisig(llSigner, kg) + + mapKeys := generateKeyPairs(kg) + + keysHandlerMock := &testscommon.KeysHandlerStub{ + GetHandledPrivateKeyCalled: func(pkBytes []byte) crypto.PrivateKey { + return mapKeys[string(pkBytes)] + }, + } + keys := createListFromMapKeys(mapKeys) + args := factoryCrypto.ArgsSigningHandler{ + PubKeys: keys, + MultiSignerContainer: &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return multiSigHandler, nil + }, + }, + SingleSigner: &cryptoMocks.SingleSignerStub{}, + KeyGenerator: kg, + KeysHandler: keysHandlerMock, + } + + signingHandler, err := factoryCrypto.NewSigningHandler(args) + require.Nil(b, err) + + container.SetSigningHandler(signingHandler) + consensusState := initializers.InitConsensusStateWithArgsVerifySignature(keysHandlerMock, keys) + dataToBeSigned := []byte("message") + consensusState.Data = dataToBeSigned + + sr := initSubroundEndRoundWithContainerAndConsensusState(container, &statusHandler.AppStatusHandlerStub{}, consensusState, &dataRetrieverMocks.ThrottlerStub{}) + for i := 0; i < len(sr.ConsensusGroup()); i++ { + _, err := 
sr.SigningHandler().CreateSignatureShareForPublicKey(dataToBeSigned, uint16(i), sr.EnableEpochsHandler().GetCurrentEpoch(), []byte(keys[i])) + require.Nil(b, err) + _ = sr.SetJobDone(keys[i], bls.SrSignature, true) + } + for i := 0; i < b.N; i++ { + b.StartTimer() + invalidSigners, err := sr.VerifyNodesOnAggSigFail(ctx) + b.StopTimer() + require.Nil(b, err) + require.NotNil(b, invalidSigners) + } +} diff --git a/consensus/spos/bls/v2/blsSubroundsFactory.go b/consensus/spos/bls/v2/blsSubroundsFactory.go new file mode 100644 index 00000000000..2c9ade325a0 --- /dev/null +++ b/consensus/spos/bls/v2/blsSubroundsFactory.go @@ -0,0 +1,307 @@ +package v2 + +import ( + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/outport" +) + +// factory defines the data needed by this factory to create all the subrounds and give them their specific +// functionality +type factory struct { + consensusCore spos.ConsensusCoreHandler + consensusState spos.ConsensusStateHandler + worker spos.WorkerHandler + + appStatusHandler core.AppStatusHandler + outportHandler outport.OutportHandler + sentSignaturesTracker spos.SentSignaturesTracker + chainID []byte + currentPid core.PeerID + signatureThrottler core.Throttler +} + +// NewSubroundsFactory creates a new consensusState object +func NewSubroundsFactory( + consensusDataContainer spos.ConsensusCoreHandler, + consensusState spos.ConsensusStateHandler, + worker spos.WorkerHandler, + chainID []byte, + currentPid core.PeerID, + appStatusHandler core.AppStatusHandler, + sentSignaturesTracker spos.SentSignaturesTracker, + signatureThrottler core.Throttler, + outportHandler outport.OutportHandler, +) (*factory, error) { + // no need to check the outport handler, it can be nil + err := checkNewFactoryParams( + consensusDataContainer, + consensusState, + worker, + chainID, + appStatusHandler, + sentSignaturesTracker, + signatureThrottler, + ) + if err != nil { + return nil, err + } + + fct := factory{ + consensusCore: consensusDataContainer, + consensusState: consensusState, + worker: worker, + appStatusHandler: appStatusHandler, + chainID: chainID, + currentPid: currentPid, + sentSignaturesTracker: sentSignaturesTracker, + signatureThrottler: signatureThrottler, + outportHandler: outportHandler, + } + + return &fct, nil +} + +func checkNewFactoryParams( + container spos.ConsensusCoreHandler, + state spos.ConsensusStateHandler, + worker spos.WorkerHandler, + chainID []byte, + appStatusHandler core.AppStatusHandler, + sentSignaturesTracker spos.SentSignaturesTracker, + signatureThrottler core.Throttler, +) error { + err := spos.ValidateConsensusCore(container) + if err != nil { + return err + } + if state == nil { + return spos.ErrNilConsensusState + } + if check.IfNil(worker) { + return spos.ErrNilWorker + } + if check.IfNil(appStatusHandler) { + return spos.ErrNilAppStatusHandler + } + if check.IfNil(sentSignaturesTracker) { + return ErrNilSentSignatureTracker + } + if check.IfNil(signatureThrottler) { + return spos.ErrNilThrottler + } + if len(chainID) == 0 { + return spos.ErrInvalidChainID + } + + return nil +} + +// SetOutportHandler method will update the value of the factory's outport +func (fct *factory) SetOutportHandler(driver outport.OutportHandler) { + fct.outportHandler = driver +} + +// GenerateSubrounds will generate the subrounds used in BLS Cns +func 
(fct *factory) GenerateSubrounds() error { + fct.initConsensusThreshold() + fct.consensusCore.Chronology().RemoveAllSubrounds() + fct.worker.RemoveAllReceivedMessagesCalls() + + err := fct.generateStartRoundSubround() + if err != nil { + return err + } + + err = fct.generateBlockSubround() + if err != nil { + return err + } + + err = fct.generateSignatureSubround() + if err != nil { + return err + } + + err = fct.generateEndRoundSubround() + if err != nil { + return err + } + + return nil +} + +func (fct *factory) getTimeDuration() time.Duration { + return fct.consensusCore.RoundHandler().TimeDuration() +} + +func (fct *factory) generateStartRoundSubround() error { + subround, err := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(float64(fct.getTimeDuration())*srStartStartTime), + int64(float64(fct.getTimeDuration())*srStartEndTime), + bls.GetSubroundName(bls.SrStartRound), + fct.consensusState, + fct.worker.GetConsensusStateChangedChannel(), + fct.worker.ExecuteStoredMessages, + fct.consensusCore, + fct.chainID, + fct.currentPid, + fct.appStatusHandler, + ) + if err != nil { + return err + } + + subroundStartRoundInstance, err := NewSubroundStartRound( + subround, + processingThresholdPercent, + fct.sentSignaturesTracker, + fct.worker, + ) + if err != nil { + return err + } + + err = subroundStartRoundInstance.SetOutportHandler(fct.outportHandler) + if err != nil { + return err + } + + fct.consensusCore.Chronology().AddSubround(subroundStartRoundInstance) + + return nil +} + +func (fct *factory) generateBlockSubround() error { + subround, err := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(float64(fct.getTimeDuration())*srBlockStartTime), + int64(float64(fct.getTimeDuration())*srBlockEndTime), + bls.GetSubroundName(bls.SrBlock), + fct.consensusState, + fct.worker.GetConsensusStateChangedChannel(), + fct.worker.ExecuteStoredMessages, + fct.consensusCore, + fct.chainID, + fct.currentPid, + fct.appStatusHandler, + ) + if err != nil { + return err + } + + subroundBlockInstance, err := NewSubroundBlock( + subround, + processingThresholdPercent, + fct.worker, + ) + if err != nil { + return err + } + + fct.worker.AddReceivedMessageCall(bls.MtBlockBody, subroundBlockInstance.receivedBlockBody) + fct.worker.AddReceivedHeaderHandler(subroundBlockInstance.receivedBlockHeader) + fct.consensusCore.Chronology().AddSubround(subroundBlockInstance) + + return nil +} + +func (fct *factory) generateSignatureSubround() error { + subround, err := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(float64(fct.getTimeDuration())*srSignatureStartTime), + int64(float64(fct.getTimeDuration())*srSignatureEndTime), + bls.GetSubroundName(bls.SrSignature), + fct.consensusState, + fct.worker.GetConsensusStateChangedChannel(), + fct.worker.ExecuteStoredMessages, + fct.consensusCore, + fct.chainID, + fct.currentPid, + fct.appStatusHandler, + ) + if err != nil { + return err + } + + subroundSignatureObject, err := NewSubroundSignature( + subround, + fct.appStatusHandler, + fct.sentSignaturesTracker, + fct.worker, + fct.signatureThrottler, + ) + if err != nil { + return err + } + + fct.consensusCore.Chronology().AddSubround(subroundSignatureObject) + + return nil +} + +func (fct *factory) generateEndRoundSubround() error { + subround, err := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(float64(fct.getTimeDuration())*srEndStartTime), + int64(float64(fct.getTimeDuration())*srEndEndTime), + 
bls.GetSubroundName(bls.SrEndRound), + fct.consensusState, + fct.worker.GetConsensusStateChangedChannel(), + fct.worker.ExecuteStoredMessages, + fct.consensusCore, + fct.chainID, + fct.currentPid, + fct.appStatusHandler, + ) + if err != nil { + return err + } + + subroundEndRoundObject, err := NewSubroundEndRound( + subround, + spos.MaxThresholdPercent, + fct.appStatusHandler, + fct.sentSignaturesTracker, + fct.worker, + fct.signatureThrottler, + ) + if err != nil { + return err + } + + fct.worker.AddReceivedProofHandler(subroundEndRoundObject.receivedProof) + fct.worker.AddReceivedMessageCall(bls.MtInvalidSigners, subroundEndRoundObject.receivedInvalidSignersInfo) + fct.worker.AddReceivedMessageCall(bls.MtSignature, subroundEndRoundObject.receivedSignature) + fct.consensusCore.Chronology().AddSubround(subroundEndRoundObject) + + return nil +} + +func (fct *factory) initConsensusThreshold() { + pBFTThreshold := core.GetPBFTThreshold(fct.consensusState.ConsensusGroupSize()) + pBFTFallbackThreshold := core.GetPBFTFallbackThreshold(fct.consensusState.ConsensusGroupSize()) + fct.consensusState.SetThreshold(bls.SrBlock, 1) + fct.consensusState.SetThreshold(bls.SrSignature, pBFTThreshold) + fct.consensusState.SetFallbackThreshold(bls.SrBlock, 1) + fct.consensusState.SetFallbackThreshold(bls.SrSignature, pBFTFallbackThreshold) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (fct *factory) IsInterfaceNil() bool { + return fct == nil +} diff --git a/consensus/spos/bls/v2/blsSubroundsFactory_test.go b/consensus/spos/bls/v2/blsSubroundsFactory_test.go new file mode 100644 index 00000000000..bfafd967169 --- /dev/null +++ b/consensus/spos/bls/v2/blsSubroundsFactory_test.go @@ -0,0 +1,680 @@ +package v2_test + +import ( + "context" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/outport" + "github.com/multiversx/mx-chain-go/testscommon" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + testscommonOutport "github.com/multiversx/mx-chain-go/testscommon/outport" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +var chainID = []byte("chain ID") + +const currentPid = core.PeerID("pid") + +const roundTimeDuration = 100 * time.Millisecond + +// executeStoredMessages tries to execute all the messages received which are valid for execution +func executeStoredMessages() { +} + +func initRoundHandlerMock() *testscommonConsensus.RoundHandlerMock { + return &testscommonConsensus.RoundHandlerMock{ + RoundIndex: 0, + TimeStampCalled: func() time.Time { + return time.Unix(0, 0) + }, + TimeDurationCalled: func() time.Duration { + return roundTimeDuration + }, + } +} + +func initWorker() spos.WorkerHandler { + sposWorker := &testscommonConsensus.SposWorkerMock{} + sposWorker.GetConsensusStateChangedChannelsCalled = func() chan bool { + return make(chan bool) + } + sposWorker.RemoveAllReceivedMessagesCallsCalled = func() {} + + sposWorker.AddReceivedMessageCallCalled = + func(messageType consensus.MessageType, 
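GenerateSubrounds registers the four subrounds with the chronology in order (StartRound, Block, Signature, EndRound), wires the received-message callbacks on the worker, and sets the Block threshold to 1 while the Signature threshold uses the PBFT threshold for the current consensus group size. A usage sketch, assuming the dependencies are already built (the test helpers below construct mock versions of all of them); every constructor and exported method used here is defined in this diff:

package example

import (
	"github.com/multiversx/mx-chain-core-go/core"

	"github.com/multiversx/mx-chain-go/consensus/spos"
	v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2"
	"github.com/multiversx/mx-chain-go/outport"
)

// wireSubrounds is a sketch only; the dependencies are assumed to be provided by the caller.
func wireSubrounds(
	container spos.ConsensusCoreHandler,
	state spos.ConsensusStateHandler,
	worker spos.WorkerHandler,
	chainID []byte,
	pid core.PeerID,
	appStatusHandler core.AppStatusHandler,
	tracker spos.SentSignaturesTracker,
	throttler core.Throttler,
	outportDriver outport.OutportHandler,
) error {
	fct, err := v2.NewSubroundsFactory(container, state, worker, chainID, pid, appStatusHandler, tracker, throttler, nil)
	if err != nil {
		return err
	}

	// the outport handler may be nil at construction time, but it must be set
	// before GenerateSubrounds; the tests below expect outport.ErrNilDriver otherwise
	fct.SetOutportHandler(outportDriver)

	// StartRound -> Block -> Signature -> EndRound are added to the chronology
	return fct.GenerateSubrounds()
}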
receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) { + } + + return sposWorker +} + +func initFactoryWithContainer(container *testscommonConsensus.ConsensusCoreMock) v2.Factory { + worker := initWorker() + consensusState := initializers.InitConsensusState() + + fct, _ := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + return fct +} + +func initFactory() v2.Factory { + container := testscommonConsensus.InitConsensusCore() + return initFactoryWithContainer(container) +} + +func TestFactory_GetMessageTypeName(t *testing.T) { + t.Parallel() + + r := bls.GetStringValue(bls.MtBlockBodyAndHeader) + assert.Equal(t, "(BLOCK_BODY_AND_HEADER)", r) + + r = bls.GetStringValue(bls.MtBlockBody) + assert.Equal(t, "(BLOCK_BODY)", r) + + r = bls.GetStringValue(bls.MtBlockHeader) + assert.Equal(t, "(BLOCK_HEADER)", r) + + r = bls.GetStringValue(bls.MtSignature) + assert.Equal(t, "(SIGNATURE)", r) + + r = bls.GetStringValue(bls.MtBlockHeaderFinalInfo) + assert.Equal(t, "(FINAL_INFO)", r) + + r = bls.GetStringValue(bls.MtUnknown) + assert.Equal(t, "(UNKNOWN)", r) + + r = bls.GetStringValue(consensus.MessageType(-1)) + assert.Equal(t, "Undefined message type", r) +} + +func TestFactory_NewFactoryNilContainerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + nil, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilConsensusCore, err) +} + +func TestFactory_NewFactoryNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + nil, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestFactory_NewFactoryNilBlockchainShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetBlockchain(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilBlockChain, err) +} + +func TestFactory_NewFactoryNilBlockProcessorShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetBlockProcessor(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilBlockProcessor, err) +} + +func TestFactory_NewFactoryNilBootstrapperShouldFail(t *testing.T) { + t.Parallel() + + 
consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetBootStrapper(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilBootstrapper, err) +} + +func TestFactory_NewFactoryNilChronologyHandlerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetChronology(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilChronologyHandler, err) +} + +func TestFactory_NewFactoryNilHasherShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetHasher(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilHasher, err) +} + +func TestFactory_NewFactoryNilMarshalizerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetMarshalizer(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilMarshalizer, err) +} + +func TestFactory_NewFactoryNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetMultiSignerContainer(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestFactory_NewFactoryNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetRoundHandler(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestFactory_NewFactoryNilShardCoordinatorShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + 
worker := initWorker() + container.SetShardCoordinator(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilShardCoordinator, err) +} + +func TestFactory_NewFactoryNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetSyncTimer(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_NewFactoryNilValidatorGroupSelectorShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetValidatorGroupSelector(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilNodesCoordinator, err) +} + +func TestFactory_NewFactoryNilWorkerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + nil, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilWorker, err) +} + +func TestFactory_NewFactoryNilAppStatusHandlerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + nil, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilAppStatusHandler, err) +} + +func TestFactory_NewFactoryNilSignaturesTrackerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + nil, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, v2.ErrNilSentSignatureTracker, err) +} + +func TestFactory_NewFactoryNilThrottlerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + nil, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilThrottler, err) 
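The constructor tests above all share the same shape: start from a valid mock container, nil out exactly one dependency, and assert the matching error from NewSubroundsFactory. The same coverage can also be expressed table-driven; a sketch using only helpers already defined in this test file, shown with three of the cases:

func TestFactory_NewFactoryNilDependencies(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name        string
		mutate      func(c *testscommonConsensus.ConsensusCoreMock)
		expectedErr error
	}{
		{"nil blockchain", func(c *testscommonConsensus.ConsensusCoreMock) { c.SetBlockchain(nil) }, spos.ErrNilBlockChain},
		{"nil block processor", func(c *testscommonConsensus.ConsensusCoreMock) { c.SetBlockProcessor(nil) }, spos.ErrNilBlockProcessor},
		{"nil sync timer", func(c *testscommonConsensus.ConsensusCoreMock) { c.SetSyncTimer(nil) }, spos.ErrNilSyncTimer},
	}

	for _, tc := range testCases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			container := testscommonConsensus.InitConsensusCore()
			tc.mutate(container)

			fct, err := v2.NewSubroundsFactory(
				container,
				initializers.InitConsensusState(),
				initWorker(),
				chainID,
				currentPid,
				&statusHandler.AppStatusHandlerStub{},
				&testscommon.SentSignatureTrackerStub{},
				&dataRetrieverMocks.ThrottlerStub{},
				nil,
			)

			assert.Nil(t, fct)
			assert.Equal(t, tc.expectedErr, err)
		})
	}
}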
+} + +func TestFactory_NewFactoryShouldWork(t *testing.T) { + t.Parallel() + + fct := *initFactory() + + assert.False(t, check.IfNil(&fct)) +} + +func TestFactory_NewFactoryEmptyChainIDShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + nil, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrInvalidChainID, err) +} + +func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundFail(t *testing.T) { + t.Parallel() + + fct := *initFactory() + fct.Worker().(*testscommonConsensus.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + return nil + } + + err := fct.GenerateStartRoundSubround() + + assert.Equal(t, spos.ErrNilChannel, err) +} + +func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundStartRoundFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + container.SetSyncTimer(nil) + + err := fct.GenerateStartRoundSubround() + + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundFail(t *testing.T) { + t.Parallel() + + fct := *initFactory() + fct.Worker().(*testscommonConsensus.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + return nil + } + + err := fct.GenerateBlockSubround() + + assert.Equal(t, spos.ErrNilChannel, err) +} + +func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundBlockFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + container.SetSyncTimer(nil) + + err := fct.GenerateBlockSubround() + + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundFail(t *testing.T) { + t.Parallel() + + fct := *initFactory() + fct.Worker().(*testscommonConsensus.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + return nil + } + + err := fct.GenerateSignatureSubround() + + assert.Equal(t, spos.ErrNilChannel, err) +} + +func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundSignatureFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + container.SetSyncTimer(nil) + + err := fct.GenerateSignatureSubround() + + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundFail(t *testing.T) { + t.Parallel() + + fct := *initFactory() + fct.Worker().(*testscommonConsensus.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + return nil + } + + err := fct.GenerateEndRoundSubround() + + assert.Equal(t, spos.ErrNilChannel, err) +} + +func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundEndRoundFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + container.SetSyncTimer(nil) + + err := fct.GenerateEndRoundSubround() + + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_GenerateSubroundsShouldWork(t *testing.T) { + t.Parallel() + + subroundHandlers := 0 + + chrm := 
&testscommonConsensus.ChronologyHandlerMock{} + chrm.AddSubroundCalled = func(subroundHandler consensus.SubroundHandler) { + subroundHandlers++ + } + container := testscommonConsensus.InitConsensusCore() + container.SetChronology(chrm) + fct := *initFactoryWithContainer(container) + fct.SetOutportHandler(&testscommonOutport.OutportStub{}) + + err := fct.GenerateSubrounds() + assert.Nil(t, err) + + assert.Equal(t, 4, subroundHandlers) +} + +func TestFactory_GenerateSubroundsNilOutportShouldFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + + err := fct.GenerateSubrounds() + assert.Equal(t, outport.ErrNilDriver, err) +} + +func TestFactory_SetIndexerShouldWork(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + + outportHandler := &testscommonOutport.OutportStub{} + fct.SetOutportHandler(outportHandler) + + assert.Equal(t, outportHandler, fct.Outport()) +} diff --git a/consensus/spos/bls/v2/constants.go b/consensus/spos/bls/v2/constants.go new file mode 100644 index 00000000000..93856652b39 --- /dev/null +++ b/consensus/spos/bls/v2/constants.go @@ -0,0 +1,37 @@ +package v2 + +import ( + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("consensus/spos/bls/v2") + +// waitingAllSigsMaxTimeThreshold specifies the max allocated time for waiting all signatures from the total time of the subround signature +const waitingAllSigsMaxTimeThreshold = 0.5 + +// processingThresholdPercent specifies the max allocated time for processing the block as a percentage of the total time of the round +const processingThresholdPercent = 85 + +// srStartStartTime specifies the start time, from the total time of the round, of Subround Start +const srStartStartTime = 0.0 + +// srEndStartTime specifies the end time, from the total time of the round, of Subround Start +const srStartEndTime = 0.05 + +// srBlockStartTime specifies the start time, from the total time of the round, of Subround Block +const srBlockStartTime = 0.05 + +// srBlockEndTime specifies the end time, from the total time of the round, of Subround Block +const srBlockEndTime = 0.25 + +// srSignatureStartTime specifies the start time, from the total time of the round, of Subround Signature +const srSignatureStartTime = 0.25 + +// srSignatureEndTime specifies the end time, from the total time of the round, of Subround Signature +const srSignatureEndTime = 0.85 + +// srEndStartTime specifies the start time, from the total time of the round, of Subround End +const srEndStartTime = 0.85 + +// srEndEndTime specifies the end time, from the total time of the round, of Subround End +const srEndEndTime = 0.95 diff --git a/consensus/spos/bls/v2/errors.go b/consensus/spos/bls/v2/errors.go new file mode 100644 index 00000000000..6e6e6bf5400 --- /dev/null +++ b/consensus/spos/bls/v2/errors.go @@ -0,0 +1,12 @@ +package v2 + +import "errors" + +// ErrNilSentSignatureTracker defines the error for setting a nil SentSignatureTracker +var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") + +// ErrWrongSizeBitmap defines the error for wrong size bitmap +var ErrWrongSizeBitmap = errors.New("wrong size bitmap") + +// ErrNotEnoughSignatures defines the error for not enough signatures +var ErrNotEnoughSignatures = errors.New("not enough signatures") diff --git a/consensus/spos/bls/v2/export_test.go b/consensus/spos/bls/v2/export_test.go new file 
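The sr*StartTime/sr*EndTime constants above are fractions of the round duration, so the absolute windows depend only on the configured round time. For the 100ms roundTimeDuration used by the v2 tests this gives StartRound 0-5ms, Block 5-25ms, Signature 25-85ms and EndRound 85-95ms; a small arithmetic sketch:

package main

import (
	"fmt"
	"time"
)

func main() {
	round := 100 * time.Millisecond // roundTimeDuration used by the v2 tests

	windows := []struct {
		name       string
		start, end float64 // the sr*StartTime / sr*EndTime constants above
	}{
		{"StartRound", 0.0, 0.05},
		{"Block", 0.05, 0.25},
		{"Signature", 0.25, 0.85},
		{"EndRound", 0.85, 0.95},
	}

	for _, w := range windows {
		start := time.Duration(float64(round) * w.start)
		end := time.Duration(float64(round) * w.end)
		// prints: StartRound 0s -> 5ms, Block 5ms -> 25ms, Signature 25ms -> 85ms, EndRound 85ms -> 95ms
		fmt.Printf("%-10s %v -> %v\n", w.name, start, end)
	}
}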
mode 100644 index 00000000000..84ab13e2016 --- /dev/null +++ b/consensus/spos/bls/v2/export_test.go @@ -0,0 +1,342 @@ +package v2 + +import ( + "context" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + + cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/ntp" + "github.com/multiversx/mx-chain-go/outport" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/sharding" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" +) + +const ProcessingThresholdPercent = processingThresholdPercent + +// factory + +// Factory defines a type for the factory structure +type Factory *factory + +// BlockChain gets the chain handler object +func (fct *factory) BlockChain() data.ChainHandler { + return fct.consensusCore.Blockchain() +} + +// BlockProcessor gets the block processor object +func (fct *factory) BlockProcessor() process.BlockProcessor { + return fct.consensusCore.BlockProcessor() +} + +// Bootstrapper gets the bootstrapper object +func (fct *factory) Bootstrapper() process.Bootstrapper { + return fct.consensusCore.BootStrapper() +} + +// ChronologyHandler gets the chronology handler object +func (fct *factory) ChronologyHandler() consensus.ChronologyHandler { + return fct.consensusCore.Chronology() +} + +// ConsensusState gets the consensus state struct pointer +func (fct *factory) ConsensusState() spos.ConsensusStateHandler { + return fct.consensusState +} + +// Hasher gets the hasher object +func (fct *factory) Hasher() hashing.Hasher { + return fct.consensusCore.Hasher() +} + +// Marshalizer gets the marshalizer object +func (fct *factory) Marshalizer() marshal.Marshalizer { + return fct.consensusCore.Marshalizer() +} + +// MultiSigner gets the multi signer object +func (fct *factory) MultiSignerContainer() cryptoCommon.MultiSignerContainer { + return fct.consensusCore.MultiSignerContainer() +} + +// RoundHandler gets the roundHandler object +func (fct *factory) RoundHandler() consensus.RoundHandler { + return fct.consensusCore.RoundHandler() +} + +// ShardCoordinator gets the shard coordinator object +func (fct *factory) ShardCoordinator() sharding.Coordinator { + return fct.consensusCore.ShardCoordinator() +} + +// SyncTimer gets the sync timer object +func (fct *factory) SyncTimer() ntp.SyncTimer { + return fct.consensusCore.SyncTimer() +} + +// NodesCoordinator gets the nodes coordinator object +func (fct *factory) NodesCoordinator() nodesCoordinator.NodesCoordinator { + return fct.consensusCore.NodesCoordinator() +} + +// Worker gets the worker object +func (fct *factory) Worker() spos.WorkerHandler { + return fct.worker +} + +// SetWorker sets the worker object +func (fct *factory) SetWorker(worker spos.WorkerHandler) { + fct.worker = worker +} + +// GenerateStartRoundSubround generates the instance of subround StartRound and added it to the chronology subrounds list +func (fct *factory) GenerateStartRoundSubround() error { + return fct.generateStartRoundSubround() +} + +// GenerateBlockSubround generates the instance of subround Block and added it to the chronology subrounds list +func (fct *factory) GenerateBlockSubround() error { + return fct.generateBlockSubround() +} + +// GenerateSignatureSubround generates the instance of subround 
Signature and added it to the chronology subrounds list +func (fct *factory) GenerateSignatureSubround() error { + return fct.generateSignatureSubround() +} + +// GenerateEndRoundSubround generates the instance of subround EndRound and added it to the chronology subrounds list +func (fct *factory) GenerateEndRoundSubround() error { + return fct.generateEndRoundSubround() +} + +// AppStatusHandler gets the app status handler object +func (fct *factory) AppStatusHandler() core.AppStatusHandler { + return fct.appStatusHandler +} + +// Outport gets the outport object +func (fct *factory) Outport() outport.OutportHandler { + return fct.outportHandler +} + +// subroundStartRound + +// SubroundStartRound defines an alias for the subroundStartRound structure +type SubroundStartRound = *subroundStartRound + +// DoStartRoundJob method does the job of the subround StartRound +func (sr *subroundStartRound) DoStartRoundJob() bool { + return sr.doStartRoundJob(context.Background()) +} + +// DoStartRoundConsensusCheck method checks if the consensus is achieved in the subround StartRound +func (sr *subroundStartRound) DoStartRoundConsensusCheck() bool { + return sr.doStartRoundConsensusCheck() +} + +// GenerateNextConsensusGroup generates the next consensu group based on current (random seed, shard id and round) +func (sr *subroundStartRound) GenerateNextConsensusGroup(roundIndex int64) error { + return sr.generateNextConsensusGroup(roundIndex) +} + +// InitCurrentRound inits all the stuff needed in the current round +func (sr *subroundStartRound) InitCurrentRound() bool { + return sr.initCurrentRound() +} + +// GetSentSignatureTracker returns the subroundStartRound's SentSignaturesTracker instance +func (sr *subroundStartRound) GetSentSignatureTracker() spos.SentSignaturesTracker { + return sr.sentSignatureTracker +} + +// subroundBlock + +// SubroundBlock defines an alias for the subroundBlock structure +type SubroundBlock = *subroundBlock + +// Blockchain gets the ChainHandler stored in the ConsensusCore +func (sr *subroundBlock) BlockChain() data.ChainHandler { + return sr.Blockchain() +} + +// DoBlockJob method does the job of the subround Block +func (sr *subroundBlock) DoBlockJob() bool { + return sr.doBlockJob(context.Background()) +} + +// ProcessReceivedBlock method processes the received proposed block in the subround Block +func (sr *subroundBlock) ProcessReceivedBlock(cnsDta *consensus.Message) bool { + return sr.processReceivedBlock(context.Background(), cnsDta.RoundIndex, cnsDta.PubKey) +} + +// DoBlockConsensusCheck method checks if the consensus in the subround Block is achieved +func (sr *subroundBlock) DoBlockConsensusCheck() bool { + return sr.doBlockConsensusCheck() +} + +// IsBlockReceived method checks if the block was received from the leader in the current round +func (sr *subroundBlock) IsBlockReceived(threshold int) bool { + return sr.isBlockReceived(threshold) +} + +// CreateHeader method creates the proposed block header in the subround Block +func (sr *subroundBlock) CreateHeader() (data.HeaderHandler, error) { + return sr.createHeader() +} + +// CreateBody method creates the proposed block body in the subround Block +func (sr *subroundBlock) CreateBlock(hdr data.HeaderHandler) (data.HeaderHandler, data.BodyHandler, error) { + return sr.createBlock(hdr) +} + +// SendBlockBody method sends the proposed block body in the subround Block +func (sr *subroundBlock) SendBlockBody(body data.BodyHandler, marshalizedBody []byte) bool { + return sr.sendBlockBody(body, marshalizedBody) 
+} + +// SendBlockHeader method sends the proposed block header in the subround Block +func (sr *subroundBlock) SendBlockHeader(header data.HeaderHandler, marshalizedHeader []byte) bool { + return sr.sendBlockHeader(header, marshalizedHeader) +} + +// ComputeSubroundProcessingMetric computes processing metric related to the subround Block +func (sr *subroundBlock) ComputeSubroundProcessingMetric(startTime time.Time, metric string) { + sr.computeSubroundProcessingMetric(startTime, metric) +} + +// ReceivedBlockBody method is called when a block body is received through the block body channel +func (sr *subroundBlock) ReceivedBlockBody(cnsDta *consensus.Message) bool { + return sr.receivedBlockBody(context.Background(), cnsDta) +} + +// ReceivedBlockHeader method is called when a block header is received through the block header channel +func (sr *subroundBlock) ReceivedBlockHeader(header data.HeaderHandler) { + sr.receivedBlockHeader(header) +} + +// subroundSignature + +// SubroundSignature defines an alias to the subroundSignature structure +type SubroundSignature = *subroundSignature + +// DoSignatureJob method does the job of the subround Signature +func (sr *subroundSignature) DoSignatureJob() bool { + return sr.doSignatureJob(context.Background()) +} + +// DoSignatureConsensusCheck method checks if the consensus in the subround Signature is achieved +func (sr *subroundSignature) DoSignatureConsensusCheck() bool { + return sr.doSignatureConsensusCheck() +} + +// subroundEndRound + +// SubroundEndRound defines a type for the subroundEndRound structure +type SubroundEndRound = *subroundEndRound + +// DoEndRoundJob method does the job of the subround EndRound +func (sr *subroundEndRound) DoEndRoundJob() bool { + return sr.doEndRoundJob(context.Background()) +} + +// DoEndRoundConsensusCheck method checks if the consensus is achieved +func (sr *subroundEndRound) DoEndRoundConsensusCheck() bool { + return sr.doEndRoundConsensusCheck() +} + +// CheckSignaturesValidity method checks the signature validity for the nodes included in bitmap +func (sr *subroundEndRound) CheckSignaturesValidity(bitmap []byte) error { + return sr.checkSignaturesValidity(bitmap) +} + +// DoEndRoundJobByLeader calls the unexported doEndRoundJobByNode function +func (sr *subroundEndRound) DoEndRoundJobByNode() bool { + return sr.doEndRoundJobByNode() +} + +// CreateAndBroadcastProof calls the unexported createAndBroadcastHeaderFinalInfo function +func (sr *subroundEndRound) CreateAndBroadcastProof(signature []byte, bitmap []byte) { + sr.createAndBroadcastProof(signature, bitmap) +} + +// ReceivedProof calls the unexported receivedProof function +func (sr *subroundEndRound) ReceivedProof(proof consensus.ProofHandler) { + sr.receivedProof(proof) +} + +// IsOutOfTime calls the unexported isOutOfTime function +func (sr *subroundEndRound) IsOutOfTime() bool { + return sr.isOutOfTime() +} + +// VerifyNodesOnAggSigFail calls the unexported verifyNodesOnAggSigFail function +func (sr *subroundEndRound) VerifyNodesOnAggSigFail(ctx context.Context) ([]string, error) { + return sr.verifyNodesOnAggSigFail(ctx) +} + +// ComputeAggSigOnValidNodes calls the unexported computeAggSigOnValidNodes function +func (sr *subroundEndRound) ComputeAggSigOnValidNodes() ([]byte, []byte, error) { + return sr.computeAggSigOnValidNodes() +} + +// ReceivedInvalidSignersInfo calls the unexported receivedInvalidSignersInfo function +func (sr *subroundEndRound) ReceivedInvalidSignersInfo(cnsDta *consensus.Message) bool { + return 
sr.receivedInvalidSignersInfo(context.Background(), cnsDta) +} + +// VerifyInvalidSigners calls the unexported verifyInvalidSigners function +func (sr *subroundEndRound) VerifyInvalidSigners(invalidSigners []byte) error { + return sr.verifyInvalidSigners(invalidSigners) +} + +// GetMinConsensusGroupIndexOfManagedKeys calls the unexported getMinConsensusGroupIndexOfManagedKeys function +func (sr *subroundEndRound) GetMinConsensusGroupIndexOfManagedKeys() int { + return sr.getMinConsensusGroupIndexOfManagedKeys() +} + +// CreateAndBroadcastInvalidSigners calls the unexported createAndBroadcastInvalidSigners function +func (sr *subroundEndRound) CreateAndBroadcastInvalidSigners(invalidSigners []byte) { + sr.createAndBroadcastInvalidSigners(invalidSigners) +} + +// GetFullMessagesForInvalidSigners calls the unexported getFullMessagesForInvalidSigners function +func (sr *subroundEndRound) GetFullMessagesForInvalidSigners(invalidPubKeys []string) ([]byte, error) { + return sr.getFullMessagesForInvalidSigners(invalidPubKeys) +} + +// GetSentSignatureTracker returns the subroundEndRound's SentSignaturesTracker instance +func (sr *subroundEndRound) GetSentSignatureTracker() spos.SentSignaturesTracker { + return sr.sentSignatureTracker +} + +// ChangeEpoch calls the unexported changeEpoch function +func (sr *subroundStartRound) ChangeEpoch(epoch uint32) { + sr.changeEpoch(epoch) +} + +// IndexRoundIfNeeded calls the unexported indexRoundIfNeeded function +func (sr *subroundStartRound) IndexRoundIfNeeded(pubKeys []string) { + sr.indexRoundIfNeeded(pubKeys) +} + +// SendSignatureForManagedKey calls the unexported sendSignatureForManagedKey function +func (sr *subroundSignature) SendSignatureForManagedKey(idx int, pk string) bool { + return sr.sendSignatureForManagedKey(idx, pk) +} + +// DoSignatureJobForManagedKeys calls the unexported doSignatureJobForManagedKeys function +func (sr *subroundSignature) DoSignatureJobForManagedKeys(ctx context.Context) bool { + return sr.doSignatureJobForManagedKeys(ctx) +} + +// ReceivedSignature method is called when a signature is received through the signature channel +func (sr *subroundEndRound) ReceivedSignature(cnsDta *consensus.Message) bool { + return sr.receivedSignature(context.Background(), cnsDta) +} diff --git a/consensus/spos/bls/v2/subroundBlock.go b/consensus/spos/bls/v2/subroundBlock.go new file mode 100644 index 00000000000..2454ad3643e --- /dev/null +++ b/consensus/spos/bls/v2/subroundBlock.go @@ -0,0 +1,738 @@ +package v2 + +import ( + "bytes" + "context" + "encoding/hex" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" +) + +// maxAllowedSizeInBytes defines how many bytes are allowed as payload in a message +const maxAllowedSizeInBytes = uint32(core.MegabyteSize * 95 / 100) + +// subroundBlock defines the data needed by the subround Block +type subroundBlock struct { + *spos.Subround + + processingThresholdPercentage int + worker spos.WorkerHandler +} + +// NewSubroundBlock creates a subroundBlock object +func NewSubroundBlock( + baseSubround *spos.Subround, + processingThresholdPercentage int, + worker spos.WorkerHandler, +) (*subroundBlock, error) { + err := checkNewSubroundBlockParams(baseSubround) + if err != nil { + return nil, err 
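export_test.go relies on a standard Go testing idiom: the file is declared in package v2 (not v2_test), so it can alias unexported types (for example SubroundStartRound = *subroundStartRound) and forward to unexported methods, while the _test.go suffix keeps it out of production builds. A generic two-file sketch of the pattern, with hypothetical names:

// counter.go (production code, everything unexported on purpose)
package counter

type counter struct{ n int }

func (c *counter) bump() int {
	c.n++
	return c.n
}

// export_test.go (compiled only by "go test"; lets the external counter_test package reach the internals)
package counter

// Counter is a test-only alias for the unexported pointer type
type Counter = *counter

// Bump forwards to the unexported method
func (c *counter) Bump() int {
	return c.bump()
}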
+ } + + if check.IfNil(worker) { + return nil, spos.ErrNilWorker + } + + srBlock := subroundBlock{ + Subround: baseSubround, + processingThresholdPercentage: processingThresholdPercentage, + worker: worker, + } + + srBlock.Job = srBlock.doBlockJob + srBlock.Check = srBlock.doBlockConsensusCheck + srBlock.Extend = srBlock.worker.Extend + + return &srBlock, nil +} + +func checkNewSubroundBlockParams( + baseSubround *spos.Subround, +) error { + if baseSubround == nil { + return spos.ErrNilSubround + } + + if check.IfNil(baseSubround.ConsensusStateHandler) { + return spos.ErrNilConsensusState + } + + err := spos.ValidateConsensusCore(baseSubround.ConsensusCoreHandler) + + return err +} + +// doBlockJob method does the job of the subround Block +func (sr *subroundBlock) doBlockJob(ctx context.Context) bool { + isSelfLeader := sr.IsSelfLeader() && sr.ShouldConsiderSelfKeyInConsensus() + if !isSelfLeader { // is NOT self leader in this round? + return false + } + + if sr.RoundHandler().Index() <= sr.getRoundInLastCommittedBlock() { + return false + } + + if sr.IsLeaderJobDone(sr.Current()) { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return false + } + + metricStatTime := time.Now() + defer sr.computeSubroundProcessingMetric(metricStatTime, common.MetricCreatedProposedBlock) + + header, err := sr.createHeader() + if err != nil { + printLogMessage(ctx, "doBlockJob.createHeader", err) + return false + } + + header, body, err := sr.createBlock(header) + if err != nil { + printLogMessage(ctx, "doBlockJob.createBlock", err) + return false + } + + // This must be done after createBlock, in order to have the proper epoch set + wasProofAdded := sr.addProofOnHeader(header) + if !wasProofAdded { + return false + } + + // block proof verification should be done over the header that contains the leader signature + leaderSignature, err := sr.signBlockHeader(header) + if err != nil { + printLogMessage(ctx, "doBlockJob.signBlockHeader", err) + return false + } + + err = header.SetLeaderSignature(leaderSignature) + if err != nil { + printLogMessage(ctx, "doBlockJob.SetLeaderSignature", err) + return false + } + + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("doBlockJob.GetLeader", "error", errGetLeader) + return false + } + + sentWithSuccess := sr.sendBlock(header, body, leader) + if !sentWithSuccess { + return false + } + + err = sr.SetJobDone(leader, sr.Current(), true) + if err != nil { + log.Debug("doBlockJob.SetSelfJobDone", "error", err.Error()) + return false + } + + // placeholder for subroundBlock.doBlockJob script + + sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(header, body, sr.GetRoundTimeStamp()) + + return true +} + +func (sr *subroundBlock) signBlockHeader(header data.HeaderHandler) ([]byte, error) { + headerClone := header.ShallowClone() + err := headerClone.SetLeaderSignature(nil) + if err != nil { + return nil, err + } + + marshalledHdr, err := sr.Marshalizer().Marshal(headerClone) + if err != nil { + return nil, err + } + + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + return nil, errGetLeader + } + + return sr.SigningHandler().CreateSignatureForPublicKey(marshalledHdr, []byte(leader)) +} + +func printLogMessage(ctx context.Context, baseMessage string, err error) { + if common.IsContextDone(ctx) { + log.Debug(baseMessage + " context is closing") + return + } + + log.Debug(baseMessage, "error", err.Error()) +} + +func (sr *subroundBlock) sendBlock(header data.HeaderHandler, body 
data.BodyHandler, _ string) bool { + marshalledBody, err := sr.Marshalizer().Marshal(body) + if err != nil { + log.Debug("sendBlock.Marshal: body", "error", err.Error()) + return false + } + + marshalledHeader, err := sr.Marshalizer().Marshal(header) + if err != nil { + log.Debug("sendBlock.Marshal: header", "error", err.Error()) + return false + } + + sr.logBlockSize(marshalledBody, marshalledHeader) + if !sr.sendBlockBody(body, marshalledBody) || !sr.sendBlockHeader(header, marshalledHeader) { + return false + } + + return true +} + +func (sr *subroundBlock) logBlockSize(marshalledBody []byte, marshalledHeader []byte) { + bodyAndHeaderSize := uint32(len(marshalledBody) + len(marshalledHeader)) + log.Debug("logBlockSize", + "body size", len(marshalledBody), + "header size", len(marshalledHeader), + "body and header size", bodyAndHeaderSize, + "max allowed size in bytes", maxAllowedSizeInBytes) +} + +func (sr *subroundBlock) createBlock(header data.HeaderHandler) (data.HeaderHandler, data.BodyHandler, error) { + startTime := sr.GetRoundTimeStamp() + maxTime := time.Duration(sr.EndTime()) + haveTimeInCurrentSubround := func() bool { + return sr.RoundHandler().RemainingTime(startTime, maxTime) > 0 + } + + finalHeader, blockBody, err := sr.BlockProcessor().CreateBlock( + header, + haveTimeInCurrentSubround, + ) + if err != nil { + return nil, nil, err + } + + return finalHeader, blockBody, nil +} + +// sendBlockBody method sends the proposed block body in the subround Block +func (sr *subroundBlock) sendBlockBody( + bodyHandler data.BodyHandler, + marshalizedBody []byte, +) bool { + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("sendBlockBody.GetLeader", "error", errGetLeader) + return false + } + + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + marshalizedBody, + nil, + []byte(leader), + nil, + int(bls.MtBlockBody), + sr.RoundHandler().Index(), + sr.ChainID(), + nil, + nil, + nil, + sr.GetAssociatedPid([]byte(leader)), + nil, + ) + + err := sr.BroadcastMessenger().BroadcastConsensusMessage(cnsMsg) + if err != nil { + log.Debug("sendBlockBody.BroadcastConsensusMessage", "error", err.Error()) + return false + } + + log.Debug("step 1: block body has been sent") + + sr.SetBody(bodyHandler) + + return true +} + +// sendBlockHeader method sends the proposed block header in the subround Block +func (sr *subroundBlock) sendBlockHeader( + headerHandler data.HeaderHandler, + marshalledHeader []byte, +) bool { + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("sendBlockHeader.GetLeader", "error", errGetLeader) + return false + } + + err := sr.BroadcastMessenger().BroadcastHeader(headerHandler, []byte(leader)) + if err != nil { + log.Warn("sendBlockHeader.BroadcastHeader", "error", err.Error()) + return false + } + + headerHash := sr.Hasher().Compute(string(marshalledHeader)) + + log.Debug("step 1: block header has been sent", + "nonce", headerHandler.GetNonce(), + "hash", headerHash) + + sr.SetData(headerHash) + sr.SetHeader(headerHandler) + + return true +} + +func (sr *subroundBlock) getPrevHeaderAndHash() (data.HeaderHandler, []byte) { + prevHeader := sr.Blockchain().GetCurrentBlockHeader() + prevHeaderHash := sr.Blockchain().GetCurrentBlockHeaderHash() + if check.IfNil(prevHeader) { + prevHeader = sr.Blockchain().GetGenesisHeader() + prevHeaderHash = sr.Blockchain().GetGenesisHeaderHash() + } + + return prevHeader, prevHeaderHash +} + +func (sr *subroundBlock) createHeader() (data.HeaderHandler, error) { + prevHeader, prevHash 
:= sr.getPrevHeaderAndHash()
+	nonce := prevHeader.GetNonce() + 1
+	prevRandSeed := prevHeader.GetRandSeed()
+
+	round := uint64(sr.RoundHandler().Index())
+	hdr, err := sr.BlockProcessor().CreateNewHeader(round, nonce)
+	if err != nil {
+		return nil, err
+	}
+
+	err = hdr.SetPrevHash(prevHash)
+	if err != nil {
+		return nil, err
+	}
+
+	leader, errGetLeader := sr.GetLeader()
+	if errGetLeader != nil {
+		return nil, errGetLeader
+	}
+
+	randSeed, err := sr.SigningHandler().CreateSignatureForPublicKey(prevRandSeed, []byte(leader))
+	if err != nil {
+		return nil, err
+	}
+
+	err = hdr.SetShardID(sr.ShardCoordinator().SelfId())
+	if err != nil {
+		return nil, err
+	}
+
+	err = hdr.SetTimeStamp(uint64(sr.RoundHandler().TimeStamp().Unix()))
+	if err != nil {
+		return nil, err
+	}
+
+	err = hdr.SetPrevRandSeed(prevRandSeed)
+	if err != nil {
+		return nil, err
+	}
+
+	err = hdr.SetRandSeed(randSeed)
+	if err != nil {
+		return nil, err
+	}
+
+	err = hdr.SetChainID(sr.ChainID())
+	if err != nil {
+		return nil, err
+	}
+
+	return hdr, nil
+}
+
+func (sr *subroundBlock) addProofOnHeader(header data.HeaderHandler) bool {
+	prevBlockProof, err := sr.EquivalentProofsPool().GetProof(sr.ShardCoordinator().SelfId(), header.GetPrevHash())
+	if err != nil {
+		// for the first block after activation we won't add the proof
+		// TODO: fix this on verifications as well
+		return common.IsEpochChangeBlockForFlagActivation(header, sr.EnableEpochsHandler(), common.EquivalentMessagesFlag)
+	}
+
+	if !isProofEmpty(prevBlockProof) {
+		header.SetPreviousProof(prevBlockProof)
+		return true
+	}
+
+	hash, err := core.CalculateHash(sr.Marshalizer(), sr.Hasher(), header)
+	if err != nil {
+		hash = []byte("")
+	}
+
+	log.Debug("addProofOnHeader: no proof found", "header hash", hex.EncodeToString(hash))
+
+	return false
+}
+
+func isProofEmpty(proof data.HeaderProofHandler) bool {
+	return len(proof.GetAggregatedSignature()) == 0 ||
+		len(proof.GetPubKeysBitmap()) == 0 ||
+		len(proof.GetHeaderHash()) == 0
+}
+
+func (sr *subroundBlock) saveProofForPreviousHeaderIfNeeded(header data.HeaderHandler) {
+	hasProof := sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), header.GetPrevHash())
+	if hasProof {
+		log.Debug("saveProofForPreviousHeaderIfNeeded: no need to set proof since it is already saved")
+		return
+	}
+
+	proof := header.GetPreviousProof()
+	err := sr.EquivalentProofsPool().AddProof(proof)
+	if err != nil {
+		log.Debug("saveProofForPreviousHeaderIfNeeded: failed to add proof", "error", err.Error())
+		return
+	}
+}
+
+// receivedBlockBody method is called when a block body is received through the block body channel
+func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensus.Message) bool {
+	node := string(cnsDta.PubKey)
+
+	if !sr.IsNodeLeaderInCurrentRound(node) { // is NOT this node leader in current round?
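+		// the sender of this block body is not the leader of the current round,
+		// so its peer honesty score is decreased and the message is ignored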
+ sr.PeerHonestyHandler().ChangeScore( + node, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyDecreaseFactor, + ) + + return false + } + + if sr.IsBlockBodyAlreadyReceived() { + return false + } + + if !sr.CanProcessReceivedMessage(cnsDta, sr.RoundHandler().Index(), sr.Current()) { + return false + } + + sr.SetBody(sr.BlockProcessor().DecodeBlockBody(cnsDta.Body)) + + if check.IfNil(sr.GetBody()) { + return false + } + + log.Debug("step 1: block body has been received") + + blockProcessedWithSuccess := sr.processReceivedBlock(ctx, cnsDta.RoundIndex, cnsDta.PubKey) + + sr.PeerHonestyHandler().ChangeScore( + node, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyIncreaseFactor, + ) + + return blockProcessedWithSuccess +} + +func (sr *subroundBlock) isHeaderForCurrentConsensus(header data.HeaderHandler) bool { + if check.IfNil(header) { + return false + } + if header.GetShardID() != sr.ShardCoordinator().SelfId() { + return false + } + if header.GetRound() != uint64(sr.RoundHandler().Index()) { + return false + } + + prevHeader, prevHash := sr.getPrevHeaderAndHash() + if check.IfNil(prevHeader) { + return false + } + if !bytes.Equal(header.GetPrevHash(), prevHash) { + return false + } + if header.GetNonce() != prevHeader.GetNonce()+1 { + return false + } + prevRandSeed := prevHeader.GetRandSeed() + + return bytes.Equal(header.GetPrevRandSeed(), prevRandSeed) +} + +func (sr *subroundBlock) getLeaderForHeader(headerHandler data.HeaderHandler) ([]byte, error) { + nc := sr.NodesCoordinator() + leader, _, err := nc.ComputeConsensusGroup( + headerHandler.GetPrevRandSeed(), + headerHandler.GetRound(), + headerHandler.GetShardID(), + headerHandler.GetEpoch(), + ) + if err != nil { + return nil, err + } + + return leader.PubKey(), err +} + +func (sr *subroundBlock) receivedBlockHeader(headerHandler data.HeaderHandler) { + if check.IfNil(headerHandler) { + return + } + + if headerHandler.CheckFieldsForNil() != nil { + return + } + + if !sr.isHeaderForCurrentConsensus(headerHandler) { + return + } + + isLeader := sr.IsSelfLeader() + if sr.ConsensusGroup() == nil || isLeader { + return + } + + if sr.IsConsensusDataSet() { + return + } + + headerLeader, err := sr.getLeaderForHeader(headerHandler) + if err != nil { + return + } + + if !sr.IsNodeLeaderInCurrentRound(string(headerLeader)) { + sr.PeerHonestyHandler().ChangeScore( + string(headerLeader), + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyDecreaseFactor, + ) + + return + } + + if sr.IsHeaderAlreadyReceived() { + return + } + + if !sr.CanProcessReceivedHeader(string(headerLeader)) { + return + } + + marshalledHeader, err := sr.Marshalizer().Marshal(headerHandler) + if err != nil { + return + } + + sr.SetData(sr.Hasher().Compute(string(marshalledHeader))) + sr.SetHeader(headerHandler) + + sr.saveProofForPreviousHeaderIfNeeded(headerHandler) + + log.Debug("step 1: block header has been received", + "nonce", sr.GetHeader().GetNonce(), + "hash", sr.GetData()) + + sr.AddReceivedHeader(headerHandler) + + ctx, cancel := context.WithTimeout(context.Background(), sr.RoundHandler().TimeDuration()) + defer cancel() + + _ = sr.processReceivedBlock(ctx, int64(headerHandler.GetRound()), []byte(sr.Leader())) + sr.PeerHonestyHandler().ChangeScore( + sr.Leader(), + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyIncreaseFactor, + ) +} + +// CanProcessReceivedHeader method returns true if the received header can be processed and false otherwise +func (sr 
*subroundBlock) CanProcessReceivedHeader(headerLeader string) bool { + if sr.IsNodeSelf(headerLeader) { + return false + } + + if sr.IsJobDone(headerLeader, sr.Current()) { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return false + } + + return true +} + +func (sr *subroundBlock) processReceivedBlock( + ctx context.Context, + round int64, + senderPK []byte, +) bool { + if check.IfNil(sr.GetBody()) { + return false + } + if check.IfNil(sr.GetHeader()) { + return false + } + + defer func() { + sr.SetProcessingBlock(false) + }() + + sr.SetProcessingBlock(true) + + shouldNotProcessBlock := sr.GetExtendedCalled() || round < sr.RoundHandler().Index() + if shouldNotProcessBlock { + log.Debug("canceled round, extended has been called or round index has been changed", + "round", sr.RoundHandler().Index(), + "subround", sr.Name(), + "cnsDta round", round, + "extended called", sr.GetExtendedCalled(), + ) + return false + } + + return sr.processBlock(ctx, round, senderPK) +} + +func (sr *subroundBlock) processBlock( + ctx context.Context, + roundIndex int64, + pubkey []byte, +) bool { + startTime := sr.GetRoundTimeStamp() + maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 + remainingTimeInCurrentRound := func() time.Duration { + return sr.RoundHandler().RemainingTime(startTime, maxTime) + } + + metricStatTime := time.Now() + defer sr.computeSubroundProcessingMetric(metricStatTime, common.MetricProcessedProposedBlock) + + err := sr.BlockProcessor().ProcessBlock( + sr.GetHeader(), + sr.GetBody(), + remainingTimeInCurrentRound, + ) + + if roundIndex < sr.RoundHandler().Index() { + log.Debug("canceled round, round index has been changed", + "round", sr.RoundHandler().Index(), + "subround", sr.Name(), + "cnsDta round", roundIndex, + ) + return false + } + + if err != nil { + sr.printCancelRoundLogMessage(ctx, err) + sr.SetRoundCanceled(true) + + return false + } + + node := string(pubkey) + err = sr.SetJobDone(node, sr.Current(), true) + if err != nil { + sr.printCancelRoundLogMessage(ctx, err) + return false + } + + sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(sr.GetHeader(), sr.GetBody(), sr.GetRoundTimeStamp()) + + return true +} + +func (sr *subroundBlock) printCancelRoundLogMessage(ctx context.Context, err error) { + if common.IsContextDone(ctx) { + log.Debug("canceled round as the context is closing") + return + } + + log.Debug("canceled round", + "round", sr.RoundHandler().Index(), + "subround", sr.Name(), + "error", err.Error()) +} + +func (sr *subroundBlock) computeSubroundProcessingMetric(startTime time.Time, metric string) { + subRoundDuration := sr.EndTime() - sr.StartTime() + if subRoundDuration == 0 { + // can not do division by 0 + return + } + + percent := uint64(time.Since(startTime)) * 100 / uint64(subRoundDuration) + sr.AppStatusHandler().SetUInt64Value(metric, percent) +} + +// doBlockConsensusCheck method checks if the consensus in the subround Block is achieved +func (sr *subroundBlock) doBlockConsensusCheck() bool { + if sr.GetRoundCanceled() { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return true + } + + threshold := sr.Threshold(sr.Current()) + if sr.isBlockReceived(threshold) { + log.Debug("step 1: subround has been finished", + "subround", sr.Name()) + sr.SetStatus(sr.Current(), spos.SsFinished) + return true + } + + return false +} + +// isBlockReceived method checks if the block was received from the leader in the current round +func (sr *subroundBlock) 
isBlockReceived(threshold int) bool { + n := 0 + + for i := 0; i < len(sr.ConsensusGroup()); i++ { + node := sr.ConsensusGroup()[i] + isJobDone, err := sr.JobDone(node, sr.Current()) + if err != nil { + log.Debug("isBlockReceived.JobDone", + "node", node, + "subround", sr.Name(), + "error", err.Error()) + continue + } + + if isJobDone { + n++ + } + } + + return n >= threshold +} + +func (sr *subroundBlock) getRoundInLastCommittedBlock() int64 { + roundInLastCommittedBlock := int64(0) + currentHeader := sr.Blockchain().GetCurrentBlockHeader() + if !check.IfNil(currentHeader) { + roundInLastCommittedBlock = int64(currentHeader.GetRound()) + } + + return roundInLastCommittedBlock +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sr *subroundBlock) IsInterfaceNil() bool { + return sr == nil +} diff --git a/consensus/spos/bls/v2/subroundBlock_test.go b/consensus/spos/bls/v2/subroundBlock_test.go new file mode 100644 index 00000000000..d22d5e2f1ca --- /dev/null +++ b/consensus/spos/bls/v2/subroundBlock_test.go @@ -0,0 +1,1197 @@ +package v2_test + +import ( + "errors" + "fmt" + "math/big" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + "github.com/multiversx/mx-chain-go/testscommon" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +var expectedErr = errors.New("expected error") + +func defaultSubroundForSRBlock(consensusState *spos.ConsensusState, ch chan bool, + container *consensusMocks.ConsensusCoreMock, appStatusHandler core.AppStatusHandler) (*spos.Subround, error) { + return spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) +} + +func createDefaultHeader() *block.Header { + return &block.Header{ + Nonce: 1, + PrevHash: []byte("prev hash"), + PrevRandSeed: []byte("prev rand seed"), + RandSeed: []byte("rand seed"), + RootHash: []byte("roothash"), + TxCount: 0, + ChainID: []byte("chain ID"), + SoftwareVersion: []byte("software version"), + AccumulatedFees: big.NewInt(0), + DeveloperFees: big.NewInt(0), + } +} + +func defaultSubroundBlockFromSubround(sr *spos.Subround) (v2.SubroundBlock, error) { + srBlock, err := v2.NewSubroundBlock( + sr, + v2.ProcessingThresholdPercent, + &consensusMocks.SposWorkerMock{}, + ) + + return srBlock, err +} + +func defaultSubroundBlockWithoutErrorFromSubround(sr *spos.Subround) v2.SubroundBlock { + srBlock, _ := v2.NewSubroundBlock( + sr, + v2.ProcessingThresholdPercent, + &consensusMocks.SposWorkerMock{}, + ) + + return srBlock +} + +func initSubroundBlock( + blockChain data.ChainHandler, + 
container *consensusMocks.ConsensusCoreMock, + appStatusHandler core.AppStatusHandler, +) v2.SubroundBlock { + if blockChain == nil { + blockChain = &testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &block.Header{} + }, + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{ + Nonce: uint64(0), + Signature: []byte("genesis signature"), + RandSeed: []byte{0}, + } + }, + GetGenesisHeaderHashCalled: func() []byte { + return []byte("genesis header hash") + }, + } + } + + consensusState := initializers.InitConsensusStateWithNodesCoordinator(container.NodesCoordinator()) + ch := make(chan bool, 1) + + container.SetBlockchain(blockChain) + + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, appStatusHandler) + srBlock, _ := defaultSubroundBlockFromSubround(sr) + return srBlock +} + +func createConsensusContainers() []*consensusMocks.ConsensusCoreMock { + consensusContainers := make([]*consensusMocks.ConsensusCoreMock, 0) + container := consensusMocks.InitConsensusCore() + consensusContainers = append(consensusContainers, container) + container = consensusMocks.InitConsensusCoreHeaderV2() + consensusContainers = append(consensusContainers, container) + return consensusContainers +} + +func initSubroundBlockWithBlockProcessor( + bp *testscommon.BlockProcessorStub, + container *consensusMocks.ConsensusCoreMock, +) v2.SubroundBlock { + blockChain := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{ + Nonce: uint64(0), + Signature: []byte("genesis signature"), + } + }, + GetGenesisHeaderHashCalled: func() []byte { + return []byte("genesis header hash") + }, + } + blockProcessorMock := bp + + container.SetBlockchain(blockChain) + container.SetBlockProcessor(blockProcessorMock) + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + srBlock, _ := defaultSubroundBlockFromSubround(sr) + return srBlock +} + +func TestSubroundBlock_NewSubroundBlockNilSubroundShouldFail(t *testing.T) { + t.Parallel() + + srBlock, err := v2.NewSubroundBlock( + nil, + v2.ProcessingThresholdPercent, + &consensusMocks.SposWorkerMock{}, + ) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilSubround, err) +} + +func TestSubroundBlock_NewSubroundBlockNilBlockchainShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetBlockchain(nil) + + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilBlockChain, err) +} + +func TestSubroundBlock_NewSubroundBlockNilBlockProcessorShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetBlockProcessor(nil) + + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilBlockProcessor, err) +} + +func TestSubroundBlock_NewSubroundBlockNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + container := 
consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + sr.ConsensusStateHandler = nil + + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestSubroundBlock_NewSubroundBlockNilHasherShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetHasher(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilHasher, err) +} + +func TestSubroundBlock_NewSubroundBlockNilMarshalizerShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetMarshalizer(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilMarshalizer, err) +} + +func TestSubroundBlock_NewSubroundBlockNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetMultiSignerContainer(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestSubroundBlock_NewSubroundBlockNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestSubroundBlock_NewSubroundBlockNilShardCoordinatorShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetShardCoordinator(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilShardCoordinator, err) +} + +func TestSubroundBlock_NewSubroundBlockNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetSyncTimer(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestSubroundBlock_NewSubroundBlockNilWorkerShouldFail(t *testing.T) { + t.Parallel() + container := 
consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + srBlock, err := v2.NewSubroundBlock( + sr, + v2.ProcessingThresholdPercent, + nil, + ) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilWorker, err) +} + +func TestSubroundBlock_NewSubroundBlockShouldWork(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.NotNil(t, srBlock) + assert.Nil(t, err) +} + +func TestSubroundBlock_DoBlockJob(t *testing.T) { + t.Parallel() + + t.Run("not leader should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("round index lower than last committed block should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrBlock, true) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("leader job done should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrBlock, true) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("subround finished should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrBlock, false) + sr.SetStatus(bls.SrBlock, spos.SsFinished) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("create header error should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + sr.SetStatus(bls.SrBlock, spos.SsNotFinished) + bpm := &testscommon.BlockProcessorStub{} + + bpm.CreateNewHeaderCalled = func(round uint64, nonce uint64) (data.HeaderHandler, error) { + return nil, expectedErr + } + container.SetBlockProcessor(bpm) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("create block error should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, 
&statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + sr.SetStatus(bls.SrBlock, spos.SsNotFinished) + bpm := &testscommon.BlockProcessorStub{} + bpm.CreateBlockCalled = func(header data.HeaderHandler, remainingTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { + return header, nil, expectedErr + } + bpm.CreateNewHeaderCalled = func(round uint64, nonce uint64) (data.HeaderHandler, error) { + return &block.Header{}, nil + } + container.SetBlockProcessor(bpm) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("send block error should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + container.SetBlockProcessor(bpm) + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return expectedErr + }, + } + container.SetBroadcastMessenger(bm) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("should work, equivalent messages flag enabled", func(t *testing.T) { + t.Parallel() + + providedSignature := []byte("provided signature") + providedBitmap := []byte("provided bitmap") + providedHash := []byte("provided hash") + providedHeadr := &block.HeaderV2{ + Header: &block.Header{ + Signature: []byte("signature"), + PubKeysBitmap: []byte("bitmap"), + }, + } + + container := consensusMocks.InitConsensusCore() + chainHandler := &testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return providedHeadr + }, + GetCurrentBlockHeaderHashCalled: func() []byte { + return providedHash + }, + } + container.SetBlockchain(chainHandler) + + consensusState := initializers.InitConsensusStateWithNodesCoordinator(container.NodesCoordinator()) + ch := make(chan bool, 1) + + baseSr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + srBlock, _ := v2.NewSubroundBlock( + baseSr, + v2.ProcessingThresholdPercent, + &consensusMocks.SposWorkerMock{}, + ) + sr := *srBlock + + providedLeaderSignature := []byte("leader signature") + container.SetSigningHandler(&consensusMocks.SigningHandlerStub{ + CreateSignatureForPublicKeyCalled: func(message []byte, publicKeyBytes []byte) ([]byte, error) { + return providedLeaderSignature, nil + }, + VerifySignatureShareCalled: func(index uint16, sig []byte, msg []byte, epoch uint32) error { + assert.Fail(t, "should have not been called for leader") + return nil + }, + }) + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + + sr.SetSelfPubKey(leader) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + container.SetBlockProcessor(bpm) + 
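+		// the stubbed CreateNewHeader below returns a HeaderV2 carrying the given round and nonce,
+		// which the nonce assertion at the end of this test relies on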
bpm.CreateNewHeaderCalled = func(round uint64, nonce uint64) (data.HeaderHandler, error) { + return &block.HeaderV2{ + Header: &block.Header{ + Round: round, + Nonce: nonce, + }, + }, nil + } + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return nil + }, + } + container.SetBroadcastMessenger(bm) + container.SetRoundHandler(&consensusMocks.RoundHandlerMock{ + RoundIndex: 1, + }) + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + return &block.HeaderProof{ + HeaderHash: headerHash, + AggregatedSignature: providedSignature, + PubKeysBitmap: providedBitmap, + }, nil + }, + }) + + r := sr.DoBlockJob() + assert.True(t, r) + assert.Equal(t, uint64(1), sr.GetHeader().GetNonce()) + + proof := sr.GetHeader().GetPreviousProof() + assert.Equal(t, providedSignature, proof.GetAggregatedSignature()) + assert.Equal(t, providedBitmap, proof.GetPubKeysBitmap()) + }) +} + +func TestSubroundBlock_ReceivedBlock(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + blkBody := &block.Body{} + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + leader, err := sr.GetLeader() + assert.Nil(t, err) + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetBody(&block.Body{}) + r := sr.ReceivedBlockBody(cnsMsg) + assert.False(t, r) + + sr.SetBody(nil) + cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) + r = sr.ReceivedBlockBody(cnsMsg) + assert.False(t, r) + + cnsMsg.PubKey = []byte(sr.ConsensusGroup()[0]) + sr.SetStatus(bls.SrBlock, spos.SsFinished) + r = sr.ReceivedBlockBody(cnsMsg) + assert.False(t, r) + + sr.SetStatus(bls.SrBlock, spos.SsNotFinished) + r = sr.ReceivedBlockBody(cnsMsg) + assert.False(t, r) +} + +func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenBodyAndHeaderAreNotSet(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + leader, _ := sr.GetLeader() + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + nil, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBodyAndHeader), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) +} + +func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockFails(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + blProcMock := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + err := errors.New("error process block") + blProcMock.ProcessBlockCalled = func(data.HeaderHandler, data.BodyHandler, func() time.Duration) error { + return err + } + container.SetBlockProcessor(blProcMock) + hdr := &block.Header{} + blkBody := &block.Body{} + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + assert.False(t, 
sr.ProcessReceivedBlock(cnsMsg)) +} + +func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockReturnsInNextRound(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + hdr := &block.Header{} + blkBody := &block.Body{} + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + blockProcessorMock := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + blockProcessorMock.ProcessBlockCalled = func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { + return expectedErr + } + container.SetBlockProcessor(blockProcessorMock) + container.SetRoundHandler(&consensusMocks.RoundHandlerMock{RoundIndex: 1}) + assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) +} + +func TestSubroundBlock_ProcessReceivedBlockShouldReturnTrue(t *testing.T) { + t.Parallel() + + consensusContainers := createConsensusContainers() + for _, container := range consensusContainers { + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + hdr, _ := container.BlockProcessor().CreateNewHeader(1, 1) + hdr, blkBody, _ := container.BlockProcessor().CreateBlock(hdr, func() bool { return true }) + + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + assert.True(t, sr.ProcessReceivedBlock(cnsMsg)) + } +} + +func TestSubroundBlock_RemainingTimeShouldReturnNegativeValue(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + roundHandlerMock := initRoundHandlerMock() + container.SetRoundHandler(roundHandlerMock) + + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + remainingTimeInThisRound := func() time.Duration { + roundStartTime := sr.RoundHandler().TimeStamp() + currentTime := sr.SyncTimer().CurrentTime() + elapsedTime := currentTime.Sub(roundStartTime) + remainingTime := sr.RoundHandler().TimeDuration()*85/100 - elapsedTime + + return remainingTime + } + container.SetSyncTimer(&consensusMocks.SyncTimerMock{CurrentTimeCalled: func() time.Time { + return time.Unix(0, 0).Add(roundTimeDuration * 84 / 100) + }}) + ret := remainingTimeInThisRound() + assert.True(t, ret > 0) + + container.SetSyncTimer(&consensusMocks.SyncTimerMock{CurrentTimeCalled: func() time.Time { + return time.Unix(0, 0).Add(roundTimeDuration * 85 / 100) + }}) + ret = remainingTimeInThisRound() + assert.True(t, ret == 0) + + container.SetSyncTimer(&consensusMocks.SyncTimerMock{CurrentTimeCalled: func() time.Time { + return time.Unix(0, 0).Add(roundTimeDuration * 86 / 100) + }}) + ret = remainingTimeInThisRound() + assert.True(t, ret < 0) +} + +func TestSubroundBlock_DoBlockConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) + assert.False(t, sr.DoBlockConsensusCheck()) +} + +func 
TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenSubroundIsFinished(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr.SetStatus(bls.SrBlock, spos.SsFinished) + assert.True(t, sr.DoBlockConsensusCheck()) +} + +func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenBlockIsReceivedReturnTrue(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + for i := 0; i < sr.Threshold(bls.SrBlock); i++ { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, true) + } + assert.True(t, sr.DoBlockConsensusCheck()) +} + +func TestSubroundBlock_DoBlockConsensusCheckShouldReturnFalseWhenBlockIsReceivedReturnFalse(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + assert.False(t, sr.DoBlockConsensusCheck()) +} + +func TestSubroundBlock_IsBlockReceived(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + for i := 0; i < len(sr.ConsensusGroup()); i++ { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, false) + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, false) + } + ok := sr.IsBlockReceived(1) + assert.False(t, ok) + + _ = sr.SetJobDone("A", bls.SrBlock, true) + isJobDone, _ := sr.JobDone("A", bls.SrBlock) + assert.True(t, isJobDone) + + ok = sr.IsBlockReceived(1) + assert.True(t, ok) + + ok = sr.IsBlockReceived(2) + assert.False(t, ok) +} + +func TestSubroundBlock_HaveTimeInCurrentSubroundShouldReturnTrue(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + haveTimeInCurrentSubound := func() bool { + roundStartTime := sr.RoundHandler().TimeStamp() + currentTime := sr.SyncTimer().CurrentTime() + elapsedTime := currentTime.Sub(roundStartTime) + remainingTime := sr.EndTime() - int64(elapsedTime) + + return time.Duration(remainingTime) > 0 + } + roundHandlerMock := &consensusMocks.RoundHandlerMock{} + roundHandlerMock.TimeDurationCalled = func() time.Duration { + return 4000 * time.Millisecond + } + roundHandlerMock.TimeStampCalled = func() time.Time { + return time.Unix(0, 0) + } + syncTimerMock := &consensusMocks.SyncTimerMock{} + timeElapsed := sr.EndTime() - 1 + syncTimerMock.CurrentTimeCalled = func() time.Time { + return time.Unix(0, timeElapsed) + } + container.SetRoundHandler(roundHandlerMock) + container.SetSyncTimer(syncTimerMock) + + assert.True(t, haveTimeInCurrentSubound()) +} + +func TestSubroundBlock_HaveTimeInCurrentSuboundShouldReturnFalse(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + haveTimeInCurrentSubound := func() bool { + roundStartTime := sr.RoundHandler().TimeStamp() + currentTime := sr.SyncTimer().CurrentTime() + elapsedTime := currentTime.Sub(roundStartTime) + remainingTime := sr.EndTime() - int64(elapsedTime) + + return time.Duration(remainingTime) > 0 + } + roundHandlerMock := &consensusMocks.RoundHandlerMock{} + roundHandlerMock.TimeDurationCalled = func() time.Duration { + return 4000 * time.Millisecond + } + roundHandlerMock.TimeStampCalled = func() time.Time { + return 
time.Unix(0, 0) + } + syncTimerMock := &consensusMocks.SyncTimerMock{} + timeElapsed := sr.EndTime() + 1 + syncTimerMock.CurrentTimeCalled = func() time.Time { + return time.Unix(0, timeElapsed) + } + container.SetRoundHandler(roundHandlerMock) + container.SetSyncTimer(syncTimerMock) + + assert.False(t, haveTimeInCurrentSubound()) +} + +func TestSubroundBlock_CreateHeaderNilCurrentHeader(t *testing.T) { + blockChain := &testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return nil + }, + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{ + Nonce: uint64(0), + Signature: []byte("genesis signature"), + RandSeed: []byte{0}, + } + }, + GetGenesisHeaderHashCalled: func() []byte { + return []byte("genesis header hash") + }, + } + + consensusContainers := createConsensusContainers() + for _, container := range consensusContainers { + sr := initSubroundBlock(blockChain, container, &statusHandler.AppStatusHandlerStub{}) + _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(nil, nil) + header, _ := sr.CreateHeader() + header, body, _ := sr.CreateBlock(header) + marshalizedBody, _ := sr.Marshalizer().Marshal(body) + marshalizedHeader, _ := sr.Marshalizer().Marshal(header) + _ = sr.SendBlockBody(body, marshalizedBody) + _ = sr.SendBlockHeader(header, marshalizedHeader) + + expectedHeader, _ := container.BlockProcessor().CreateNewHeader(uint64(sr.RoundHandler().Index()), uint64(1)) + err := expectedHeader.SetTimeStamp(uint64(sr.RoundHandler().TimeStamp().Unix())) + require.Nil(t, err) + err = expectedHeader.SetRootHash([]byte{}) + require.Nil(t, err) + err = expectedHeader.SetPrevHash(sr.BlockChain().GetGenesisHeaderHash()) + require.Nil(t, err) + err = expectedHeader.SetPrevRandSeed(sr.BlockChain().GetGenesisHeader().GetRandSeed()) + require.Nil(t, err) + err = expectedHeader.SetRandSeed(make([]byte, 0)) + require.Nil(t, err) + err = expectedHeader.SetMiniBlockHeaderHandlers(header.GetMiniBlockHeaderHandlers()) + require.Nil(t, err) + err = expectedHeader.SetChainID(chainID) + require.Nil(t, err) + require.Equal(t, expectedHeader, header) + } +} + +func TestSubroundBlock_CreateHeaderNotNilCurrentHeader(t *testing.T) { + consensusContainers := createConsensusContainers() + for _, container := range consensusContainers { + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(&block.Header{ + Nonce: 1, + }, []byte("root hash")) + + header, _ := sr.CreateHeader() + header, body, _ := sr.CreateBlock(header) + marshalizedBody, _ := sr.Marshalizer().Marshal(body) + marshalizedHeader, _ := sr.Marshalizer().Marshal(header) + _ = sr.SendBlockBody(body, marshalizedBody) + _ = sr.SendBlockHeader(header, marshalizedHeader) + + expectedHeader, _ := container.BlockProcessor().CreateNewHeader( + uint64(sr.RoundHandler().Index()), + sr.BlockChain().GetCurrentBlockHeader().GetNonce()+1) + err := expectedHeader.SetTimeStamp(uint64(sr.RoundHandler().TimeStamp().Unix())) + require.Nil(t, err) + err = expectedHeader.SetRootHash([]byte{}) + require.Nil(t, err) + err = expectedHeader.SetPrevHash(sr.BlockChain().GetCurrentBlockHeaderHash()) + require.Nil(t, err) + err = expectedHeader.SetRandSeed(make([]byte, 0)) + require.Nil(t, err) + err = expectedHeader.SetMiniBlockHeaderHandlers(header.GetMiniBlockHeaderHandlers()) + require.Nil(t, err) + err = expectedHeader.SetChainID(chainID) + require.Nil(t, err) + require.Equal(t, expectedHeader, header) + } +} + +func 
TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { + mbHeaders := []block.MiniBlockHeader{ + {Hash: []byte("mb1"), SenderShardID: 1, ReceiverShardID: 1}, + {Hash: []byte("mb2"), SenderShardID: 1, ReceiverShardID: 2}, + {Hash: []byte("mb3"), SenderShardID: 2, ReceiverShardID: 3}, + } + blockChainMock := testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &block.Header{ + Nonce: 1, + } + }, + } + container := consensusMocks.InitConsensusCore() + bp := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + bp.CreateBlockCalled = func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { + shardHeader, _ := header.(*block.Header) + shardHeader.MiniBlockHeaders = mbHeaders + shardHeader.RootHash = []byte{} + + return shardHeader, &block.Body{}, nil + } + sr := initSubroundBlockWithBlockProcessor(bp, container) + container.SetBlockchain(&blockChainMock) + + header, _ := sr.CreateHeader() + header, body, _ := sr.CreateBlock(header) + marshalizedBody, _ := sr.Marshalizer().Marshal(body) + marshalizedHeader, _ := sr.Marshalizer().Marshal(header) + _ = sr.SendBlockBody(body, marshalizedBody) + _ = sr.SendBlockHeader(header, marshalizedHeader) + + expectedHeader := &block.Header{ + Round: uint64(sr.RoundHandler().Index()), + TimeStamp: uint64(sr.RoundHandler().TimeStamp().Unix()), + RootHash: []byte{}, + Nonce: sr.BlockChain().GetCurrentBlockHeader().GetNonce() + 1, + PrevHash: sr.BlockChain().GetCurrentBlockHeaderHash(), + RandSeed: make([]byte, 0), + MiniBlockHeaders: mbHeaders, + ChainID: chainID, + } + + assert.Equal(t, expectedHeader, header) +} + +func TestSubroundBlock_CreateHeaderNilMiniBlocks(t *testing.T) { + expectedErr := errors.New("nil mini blocks") + container := consensusMocks.InitConsensusCore() + bp := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + bp.CreateBlockCalled = func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { + return nil, nil, expectedErr + } + sr := initSubroundBlockWithBlockProcessor(bp, container) + _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(&block.Header{ + Nonce: 1, + }, []byte("root hash")) + header, _ := sr.CreateHeader() + _, _, err := sr.CreateBlock(header) + assert.Equal(t, expectedErr, err) +} + +func TestSubroundBlock_CallFuncRemainingTimeWithStructShouldWork(t *testing.T) { + roundStartTime := time.Now() + maxTime := 100 * time.Millisecond + newRoundStartTime := roundStartTime + remainingTimeInCurrentRound := func() time.Duration { + return RemainingTimeWithStruct(newRoundStartTime, maxTime) + } + assert.True(t, remainingTimeInCurrentRound() > 0) + + time.Sleep(200 * time.Millisecond) + assert.True(t, remainingTimeInCurrentRound() < 0) +} + +func TestSubroundBlock_CallFuncRemainingTimeWithStructShouldNotWork(t *testing.T) { + roundStartTime := time.Now() + maxTime := 100 * time.Millisecond + remainingTimeInCurrentRound := func() time.Duration { + return RemainingTimeWithStruct(roundStartTime, maxTime) + } + assert.True(t, remainingTimeInCurrentRound() > 0) + + time.Sleep(200 * time.Millisecond) + assert.True(t, remainingTimeInCurrentRound() < 0) + + roundStartTime = roundStartTime.Add(500 * time.Millisecond) + assert.False(t, remainingTimeInCurrentRound() < 0) +} + +func RemainingTimeWithStruct(startTime time.Time, maxTime time.Duration) time.Duration { + currentTime := time.Now() + elapsedTime := currentTime.Sub(startTime) + remainingTime := maxTime - elapsedTime + 
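+	// a negative result means the maximum allotted time has already elapsed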
return remainingTime +} + +func TestSubroundBlock_ReceivedBlockComputeProcessDuration(t *testing.T) { + t.Parallel() + + srStartTime := int64(5 * roundTimeDuration / 100) + srEndTime := int64(25 * roundTimeDuration / 100) + srDuration := srEndTime - srStartTime + delay := srDuration * 430 / 1000 + + container := consensusMocks.InitConsensusCore() + receivedValue := uint64(0) + container.SetBlockProcessor(&testscommon.BlockProcessorStub{ + ProcessBlockCalled: func(_ data.HeaderHandler, _ data.BodyHandler, _ func() time.Duration) error { + time.Sleep(time.Duration(delay)) + return nil + }, + }) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{ + SetUInt64ValueHandler: func(key string, value uint64) { + receivedValue = value + }}) + hdr := &block.Header{} + blkBody := &block.Body{} + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + + minimumExpectedValue := uint64(delay * 100 / srDuration) + _ = sr.ProcessReceivedBlock(cnsMsg) + + assert.True(t, + receivedValue >= minimumExpectedValue, + fmt.Sprintf("minimum expected was %d, got %d", minimumExpectedValue, receivedValue), + ) +} + +func TestSubroundBlock_ReceivedBlockComputeProcessDurationWithZeroDurationShouldNotPanic(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + assert.Fail(t, "should not have paniced", r) + } + }() + + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + srBlock := defaultSubroundBlockWithoutErrorFromSubround(sr) + + srBlock.ComputeSubroundProcessingMetric(time.Now(), "dummy") +} + +func TestSubroundBlock_ReceivedBlockHeader(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + // nil header + sr.ReceivedBlockHeader(nil) + + // flag not active + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + + container.SetEnableEpochsHandler(&enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + }) + + // nil fields on header + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{ + CheckFieldsForNilCalled: func() error { + return expectedErr + }, + }) + + // leader + defaultLeader := sr.Leader() + sr.SetLeader(sr.SelfPubKey()) + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + sr.SetLeader(defaultLeader) + + // consensus data already set + sr.SetData([]byte("some data")) + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + sr.SetData(nil) + + // header already received + sr.SetHeader(&testscommon.HeaderHandlerStub{}) + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + sr.SetHeader(nil) + + // self job already done + _ = sr.SetJobDone(sr.SelfPubKey(), sr.Current(), true) + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + _ = sr.SetJobDone(sr.SelfPubKey(), sr.Current(), false) + + // subround already finished + sr.SetStatus(sr.Current(), spos.SsFinished) + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + 
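+	// reset the subround status so the remaining scenarios start from a non-finished state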
sr.SetStatus(sr.Current(), spos.SsNotFinished) + + // marshal error + container.SetMarshalizer(&testscommon.MarshallerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }, + }) + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + container.SetMarshalizer(&testscommon.MarshallerStub{}) + + // should work + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) +} diff --git a/consensus/spos/bls/v2/subroundEndRound.go b/consensus/spos/bls/v2/subroundEndRound.go new file mode 100644 index 00000000000..b5e6440685f --- /dev/null +++ b/consensus/spos/bls/v2/subroundEndRound.go @@ -0,0 +1,883 @@ +package v2 + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "math/bits" + "sync" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-core-go/display" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/process/headerCheck" +) + +const timeBetweenSignaturesChecks = time.Millisecond * 5 + +type subroundEndRound struct { + *spos.Subround + processingThresholdPercentage int + appStatusHandler core.AppStatusHandler + mutProcessingEndRound sync.Mutex + sentSignatureTracker spos.SentSignaturesTracker + worker spos.WorkerHandler + signatureThrottler core.Throttler +} + +// NewSubroundEndRound creates a subroundEndRound object +func NewSubroundEndRound( + baseSubround *spos.Subround, + processingThresholdPercentage int, + appStatusHandler core.AppStatusHandler, + sentSignatureTracker spos.SentSignaturesTracker, + worker spos.WorkerHandler, + signatureThrottler core.Throttler, +) (*subroundEndRound, error) { + err := checkNewSubroundEndRoundParams(baseSubround) + if err != nil { + return nil, err + } + if check.IfNil(appStatusHandler) { + return nil, spos.ErrNilAppStatusHandler + } + if check.IfNil(sentSignatureTracker) { + return nil, ErrNilSentSignatureTracker + } + if check.IfNil(worker) { + return nil, spos.ErrNilWorker + } + if check.IfNil(signatureThrottler) { + return nil, spos.ErrNilThrottler + } + + srEndRound := subroundEndRound{ + Subround: baseSubround, + processingThresholdPercentage: processingThresholdPercentage, + appStatusHandler: appStatusHandler, + mutProcessingEndRound: sync.Mutex{}, + sentSignatureTracker: sentSignatureTracker, + worker: worker, + signatureThrottler: signatureThrottler, + } + srEndRound.Job = srEndRound.doEndRoundJob + srEndRound.Check = srEndRound.doEndRoundConsensusCheck + srEndRound.Extend = worker.Extend + + return &srEndRound, nil +} + +func checkNewSubroundEndRoundParams( + baseSubround *spos.Subround, +) error { + if baseSubround == nil { + return spos.ErrNilSubround + } + if check.IfNil(baseSubround.ConsensusStateHandler) { + return spos.ErrNilConsensusState + } + + err := spos.ValidateConsensusCore(baseSubround.ConsensusCoreHandler) + + return err +} + +func (sr *subroundEndRound) isProofForCurrentConsensus(proof consensus.ProofHandler) bool { + return bytes.Equal(sr.GetData(), proof.GetHeaderHash()) +} + +// receivedProof method is called when a block header final info is received +func (sr *subroundEndRound) receivedProof(proof consensus.ProofHandler) { + 
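+	// the end round mutex is taken so that a proof received from the network is not
+	// processed concurrently with doEndRoundJob, which locks the same mutex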
sr.mutProcessingEndRound.Lock()
+	defer sr.mutProcessingEndRound.Unlock()
+
+	if sr.IsJobDone(sr.SelfPubKey(), sr.Current()) {
+		return
+	}
+	if !sr.IsConsensusDataSet() {
+		return
+	}
+	if check.IfNil(sr.GetHeader()) {
+		return
+	}
+	if !sr.isProofForCurrentConsensus(proof) {
+		return
+	}
+
+	// no need to re-verify the proof since it was already verified when it was added to the proofs pool
+	log.Debug("step 3: block header final info has been received",
+		"PubKeysBitmap", proof.GetPubKeysBitmap(),
+		"AggregateSignature", proof.GetAggregatedSignature(),
+		"HeaderHash", proof.GetHeaderHash())
+
+	sr.doEndRoundJobByNode()
+}
+
+// receivedInvalidSignersInfo method is called when a message with invalid signers has been received
+func (sr *subroundEndRound) receivedInvalidSignersInfo(_ context.Context, cnsDta *consensus.Message) bool {
+	messageSender := string(cnsDta.PubKey)
+
+	if !sr.IsConsensusDataSet() {
+		return false
+	}
+	if check.IfNil(sr.GetHeader()) {
+		return false
+	}
+
+	isSelfSender := sr.IsNodeSelf(messageSender) || sr.IsKeyManagedBySelf([]byte(messageSender))
+	if isSelfSender {
+		return false
+	}
+
+	if !sr.IsConsensusDataEqual(cnsDta.BlockHeaderHash) {
+		return false
+	}
+
+	if !sr.CanProcessReceivedMessage(cnsDta, sr.RoundHandler().Index(), sr.Current()) {
+		return false
+	}
+
+	if len(cnsDta.InvalidSigners) == 0 {
+		return false
+	}
+
+	err := sr.verifyInvalidSigners(cnsDta.InvalidSigners)
+	if err != nil {
+		log.Trace("receivedInvalidSignersInfo.verifyInvalidSigners", "error", err.Error())
+		return false
+	}
+
+	log.Debug("step 3: invalid signers info has been evaluated")
+
+	sr.PeerHonestyHandler().ChangeScore(
+		messageSender,
+		spos.GetConsensusTopicID(sr.ShardCoordinator()),
+		spos.LeaderPeerHonestyIncreaseFactor,
+	)
+
+	return true
+}
+
+func (sr *subroundEndRound) verifyInvalidSigners(invalidSigners []byte) error {
+	messages, err := sr.MessageSigningHandler().Deserialize(invalidSigners)
+	if err != nil {
+		return err
+	}
+
+	for _, msg := range messages {
+		err = sr.verifyInvalidSigner(msg)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (sr *subroundEndRound) verifyInvalidSigner(msg p2p.MessageP2P) error {
+	err := sr.MessageSigningHandler().Verify(msg)
+	if err != nil {
+		return err
+	}
+
+	cnsMsg := &consensus.Message{}
+	err = sr.Marshalizer().Unmarshal(cnsMsg, msg.Data())
+	if err != nil {
+		return err
+	}
+
+	err = sr.SigningHandler().VerifySingleSignature(cnsMsg.PubKey, cnsMsg.BlockHeaderHash, cnsMsg.SignatureShare)
+	if err != nil {
+		log.Debug("verifyInvalidSigner: confirmed that node provided invalid signature",
+			"pubKey", cnsMsg.PubKey,
+			"blockHeaderHash", cnsMsg.BlockHeaderHash,
+			"error", err.Error(),
+		)
+		sr.applyBlacklistOnNode(msg.Peer())
+	}
+
+	return nil
+}
+
+func (sr *subroundEndRound) applyBlacklistOnNode(peer core.PeerID) {
+	sr.PeerBlacklistHandler().BlacklistPeer(peer, common.InvalidSigningBlacklistDuration)
+}
+
+// doEndRoundJob method does the job of the subround EndRound
+func (sr *subroundEndRound) doEndRoundJob(_ context.Context) bool {
+	if check.IfNil(sr.GetHeader()) {
+		return false
+	}
+
+	sr.mutProcessingEndRound.Lock()
+	defer sr.mutProcessingEndRound.Unlock()
+
+	return sr.doEndRoundJobByNode()
+}
+
+func (sr *subroundEndRound) commitBlock() error {
+	startTime := time.Now()
+	err := sr.BlockProcessor().CommitBlock(sr.GetHeader(), sr.GetBody())
+	elapsedTime := time.Since(startTime)
+	if elapsedTime >= common.CommitMaxTime {
+		log.Warn("doEndRoundJobByNode.CommitBlock", "elapsed time", elapsedTime)
+	}
else { + log.Debug("elapsed time to commit block", "time [s]", elapsedTime) + } + if err != nil { + log.Debug("doEndRoundJobByNode.CommitBlock", "error", err) + return err + } + + return nil +} + +func (sr *subroundEndRound) doEndRoundJobByNode() bool { + if sr.shouldSendProof() { + if !sr.waitForSignalSync() { + return false + } + } + + proof, ok := sr.sendProof() + if !ok { + return false + } + + err := sr.commitBlock() + if err != nil { + return false + } + + // if proof not nil, it was created and broadcasted so it has to be added to the pool + if proof != nil { + err = sr.EquivalentProofsPool().AddProof(proof) + if err != nil { + log.Debug("doEndRoundJobByNode.AddProof", "error", err) + return false + } + } + + sr.SetStatus(sr.Current(), spos.SsFinished) + + sr.worker.DisplayStatistics() + + log.Debug("step 3: Body and Header have been committed") + + msg := fmt.Sprintf("Added proposed block with nonce %d in blockchain", sr.GetHeader().GetNonce()) + log.Debug(display.Headline(msg, sr.SyncTimer().FormattedCurrentTime(), "+")) + + sr.updateMetricsForLeader() + + return true +} + +func (sr *subroundEndRound) sendProof() (data.HeaderProofHandler, bool) { + if !sr.shouldSendProof() { + return nil, true + } + + bitmap := sr.GenerateBitmap(bls.SrSignature) + err := sr.checkSignaturesValidity(bitmap) + if err != nil { + log.Debug("sendProof.checkSignaturesValidity", "error", err.Error()) + return nil, false + } + + // Aggregate signatures, handle invalid signers and send final info if needed + bitmap, sig, err := sr.aggregateSigsAndHandleInvalidSigners(bitmap) + if err != nil { + log.Debug("sendProof.aggregateSigsAndHandleInvalidSigners", "error", err.Error()) + return nil, false + } + + ok := sr.ScheduledProcessor().IsProcessedOKWithTimeout() + // placeholder for subroundEndRound.doEndRoundJobByLeader script + if !ok { + return nil, false + } + + roundHandler := sr.RoundHandler() + if roundHandler.RemainingTime(roundHandler.TimeStamp(), roundHandler.TimeDuration()) < 0 { + log.Debug("sendProof: time is out -> cancel broadcasting final info and header", + "round time stamp", roundHandler.TimeStamp(), + "current time", time.Now()) + return nil, false + } + + // broadcast header proof + proof, err := sr.createAndBroadcastProof(sig, bitmap) + return proof, err == nil +} + +func (sr *subroundEndRound) shouldSendProof() bool { + if sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), sr.GetData()) { + log.Debug("shouldSendProof: equivalent message already processed") + return false + } + + return true +} + +func (sr *subroundEndRound) aggregateSigsAndHandleInvalidSigners(bitmap []byte) ([]byte, []byte, error) { + sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.GetHeader().GetEpoch()) + if err != nil { + log.Debug("doEndRoundJobByNode.AggregateSigs", "error", err.Error()) + + return sr.handleInvalidSignersOnAggSigFail() + } + + err = sr.SigningHandler().SetAggregatedSig(sig) + if err != nil { + log.Debug("doEndRoundJobByNode.SetAggregatedSig", "error", err.Error()) + return nil, nil, err + } + + // the header (hash) verified here is with leader signature on it + err = sr.SigningHandler().Verify(sr.GetData(), bitmap, sr.GetHeader().GetEpoch()) + if err != nil { + log.Debug("doEndRoundJobByNode.Verify", "error", err.Error()) + + return sr.handleInvalidSignersOnAggSigFail() + } + + return bitmap, sig, nil +} + +func (sr *subroundEndRound) checkGoRoutinesThrottler(ctx context.Context) error { + for { + if sr.signatureThrottler.CanProcess() { + break + } + + select { + case 
<-time.After(time.Millisecond): + continue + case <-ctx.Done(): + return spos.ErrTimeIsOut + } + } + return nil +} + +// verifySignature implements parallel signature verification +func (sr *subroundEndRound) verifySignature(i int, pk string, sigShare []byte) error { + err := sr.SigningHandler().VerifySignatureShare(uint16(i), sigShare, sr.GetData(), sr.GetHeader().GetEpoch()) + if err != nil { + log.Trace("VerifySignatureShare returned an error: ", err) + errSetJob := sr.SetJobDone(pk, bls.SrSignature, false) + if errSetJob != nil { + return errSetJob + } + + decreaseFactor := -spos.ValidatorPeerHonestyIncreaseFactor + spos.ValidatorPeerHonestyDecreaseFactor + + sr.PeerHonestyHandler().ChangeScore( + pk, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + decreaseFactor, + ) + return err + } + + log.Trace("verifyNodesOnAggSigVerificationFail: verifying signature share", "public key", pk) + + return nil +} + +func (sr *subroundEndRound) verifyNodesOnAggSigFail(ctx context.Context) ([]string, error) { + wg := &sync.WaitGroup{} + mutex := &sync.Mutex{} + invalidPubKeys := make([]string, 0) + pubKeys := sr.ConsensusGroup() + + if check.IfNil(sr.GetHeader()) { + return nil, spos.ErrNilHeader + } + + for i, pk := range pubKeys { + isJobDone, err := sr.JobDone(pk, bls.SrSignature) + if err != nil || !isJobDone { + continue + } + + sigShare, err := sr.SigningHandler().SignatureShare(uint16(i)) + if err != nil { + return nil, err + } + + err = sr.checkGoRoutinesThrottler(ctx) + if err != nil { + return nil, err + } + + sr.signatureThrottler.StartProcessing() + + wg.Add(1) + + go func(i int, pk string, wg *sync.WaitGroup, sigShare []byte) { + defer func() { + sr.signatureThrottler.EndProcessing() + wg.Done() + }() + errSigVerification := sr.verifySignature(i, pk, sigShare) + if errSigVerification != nil { + mutex.Lock() + invalidPubKeys = append(invalidPubKeys, pk) + mutex.Unlock() + } + }(i, pk, wg, sigShare) + } + wg.Wait() + + return invalidPubKeys, nil +} + +func (sr *subroundEndRound) getFullMessagesForInvalidSigners(invalidPubKeys []string) ([]byte, error) { + p2pMessages := make([]p2p.MessageP2P, 0) + + for _, pk := range invalidPubKeys { + p2pMsg, ok := sr.GetMessageWithSignature(pk) + if !ok { + log.Trace("message not found in state for invalid signer", "pubkey", pk) + continue + } + + p2pMessages = append(p2pMessages, p2pMsg) + } + + invalidSigners, err := sr.MessageSigningHandler().Serialize(p2pMessages) + if err != nil { + return nil, err + } + + return invalidSigners, nil +} + +func (sr *subroundEndRound) handleInvalidSignersOnAggSigFail() ([]byte, []byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), sr.RoundHandler().TimeDuration()) + invalidPubKeys, err := sr.verifyNodesOnAggSigFail(ctx) + cancel() + if err != nil { + log.Debug("doEndRoundJobByNode.verifyNodesOnAggSigFail", "error", err.Error()) + return nil, nil, err + } + + _, err = sr.getFullMessagesForInvalidSigners(invalidPubKeys) + if err != nil { + log.Debug("doEndRoundJobByNode.getFullMessagesForInvalidSigners", "error", err.Error()) + return nil, nil, err + } + + // TODO: handle invalid signers broadcast without flooding the network + // if len(invalidSigners) > 0 { + // sr.createAndBroadcastInvalidSigners(invalidSigners) + // } + + bitmap, sig, err := sr.computeAggSigOnValidNodes() + if err != nil { + log.Debug("doEndRoundJobByNode.computeAggSigOnValidNodes", "error", err.Error()) + return nil, nil, err + } + + return bitmap, sig, nil +} + +func (sr *subroundEndRound) computeAggSigOnValidNodes() 
([]byte, []byte, error) { + threshold := sr.Threshold(bls.SrSignature) + numValidSigShares := sr.ComputeSize(bls.SrSignature) + + if check.IfNil(sr.GetHeader()) { + return nil, nil, spos.ErrNilHeader + } + + if numValidSigShares < threshold { + return nil, nil, fmt.Errorf("%w: number of valid sig shares lower than threshold, numSigShares: %d, threshold: %d", + spos.ErrInvalidNumSigShares, numValidSigShares, threshold) + } + + bitmap := sr.GenerateBitmap(bls.SrSignature) + err := sr.checkSignaturesValidity(bitmap) + if err != nil { + return nil, nil, err + } + + sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.GetHeader().GetEpoch()) + if err != nil { + return nil, nil, err + } + + err = sr.SigningHandler().SetAggregatedSig(sig) + if err != nil { + return nil, nil, err + } + + return bitmap, sig, nil +} + +func (sr *subroundEndRound) createAndBroadcastProof(signature []byte, bitmap []byte) (*block.HeaderProof, error) { + headerProof := &block.HeaderProof{ + PubKeysBitmap: bitmap, + AggregatedSignature: signature, + HeaderHash: sr.GetData(), + HeaderEpoch: sr.GetHeader().GetEpoch(), + HeaderNonce: sr.GetHeader().GetNonce(), + HeaderShardId: sr.GetHeader().GetShardID(), + } + + err := sr.BroadcastMessenger().BroadcastEquivalentProof(headerProof, []byte(sr.SelfPubKey())) + if err != nil { + return nil, err + } + + log.Debug("step 3: block header proof has been sent", + "PubKeysBitmap", bitmap, + "AggregateSignature", signature) + + return headerProof, nil +} + +func (sr *subroundEndRound) createAndBroadcastInvalidSigners(invalidSigners []byte) { + if !sr.ShouldConsiderSelfKeyInConsensus() { + return + } + + sender, err := sr.GetLeader() + if err != nil { + log.Debug("createAndBroadcastInvalidSigners.getSender", "error", err) + return + } + + cnsMsg := consensus.NewConsensusMessage( + sr.GetData(), + nil, + nil, + nil, + []byte(sender), + nil, + int(bls.MtInvalidSigners), + sr.RoundHandler().Index(), + sr.ChainID(), + nil, + nil, + nil, + sr.GetAssociatedPid([]byte(sender)), + invalidSigners, + ) + + err = sr.BroadcastMessenger().BroadcastConsensusMessage(cnsMsg) + if err != nil { + log.Debug("doEndRoundJob.BroadcastConsensusMessage", "error", err.Error()) + return + } + + log.Debug("step 3: invalid signers info has been sent") +} + +func (sr *subroundEndRound) updateMetricsForLeader() { + // TODO: decide if we keep these metrics the same way + sr.appStatusHandler.Increment(common.MetricCountAcceptedBlocks) + sr.appStatusHandler.SetStringValue(common.MetricConsensusRoundState, + fmt.Sprintf("valid block produced in %f sec", time.Since(sr.RoundHandler().TimeStamp()).Seconds())) +} + +// doEndRoundConsensusCheck method checks if the consensus is achieved +func (sr *subroundEndRound) doEndRoundConsensusCheck() bool { + if sr.GetRoundCanceled() { + return false + } + + return sr.IsSubroundFinished(sr.Current()) +} + +// IsBitmapInvalid checks if the bitmap is valid +// TODO: remove duplicated code and use the header sig verifier instead +func (sr *subroundEndRound) IsBitmapInvalid(bitmap []byte, consensusPubKeys []string) error { + consensusSize := len(consensusPubKeys) + + expectedBitmapSize := consensusSize / 8 + if consensusSize%8 != 0 { + expectedBitmapSize++ + } + if len(bitmap) != expectedBitmapSize { + log.Debug("wrong size bitmap", + "expected number of bytes", expectedBitmapSize, + "actual", len(bitmap)) + return ErrWrongSizeBitmap + } + + numOfOnesInBitmap := 0 + for index := range bitmap { + numOfOnesInBitmap += bits.OnesCount8(bitmap[index]) + } + + minNumRequiredSignatures := 
core.GetPBFTThreshold(consensusSize) + if sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.GetHeader()) { + minNumRequiredSignatures = core.GetPBFTFallbackThreshold(consensusSize) + log.Warn("HeaderSigVerifier.verifyConsensusSize: fallback validation has been applied", + "minimum number of signatures required", minNumRequiredSignatures, + "actual number of signatures in bitmap", numOfOnesInBitmap, + ) + } + + if numOfOnesInBitmap >= minNumRequiredSignatures { + return nil + } + + log.Debug("not enough signatures", + "minimum expected", minNumRequiredSignatures, + "actual", numOfOnesInBitmap) + + return ErrNotEnoughSignatures +} + +func (sr *subroundEndRound) checkSignaturesValidity(bitmap []byte) error { + consensusGroup := sr.ConsensusGroup() + err := sr.IsBitmapInvalid(bitmap, consensusGroup) + if err != nil { + return err + } + + signers := headerCheck.ComputeSignersPublicKeys(consensusGroup, bitmap) + for _, pubKey := range signers { + isSigJobDone, err := sr.JobDone(pubKey, bls.SrSignature) + if err != nil { + return err + } + + if !isSigJobDone { + return spos.ErrNilSignature + } + } + + return nil +} + +func (sr *subroundEndRound) isOutOfTime() bool { + startTime := sr.GetRoundTimeStamp() + maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 + if sr.RoundHandler().RemainingTime(startTime, maxTime) < 0 { + log.Debug("canceled round, time is out", + "round", sr.SyncTimer().FormattedCurrentTime(), sr.RoundHandler().Index(), + "subround", sr.Name()) + + sr.SetRoundCanceled(true) + return true + } + + return false +} + +func (sr *subroundEndRound) getMinConsensusGroupIndexOfManagedKeys() int { + minIdx := sr.ConsensusGroupSize() + + for idx, validator := range sr.ConsensusGroup() { + if !sr.IsKeyManagedBySelf([]byte(validator)) { + continue + } + + if idx < minIdx { + minIdx = idx + } + } + + return minIdx +} + +func (sr *subroundEndRound) waitForSignalSync() bool { + if sr.IsSubroundFinished(sr.Current()) { + return true + } + + if sr.checkReceivedSignatures() { + return true + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go sr.waitSignatures(ctx) + timerBetweenStatusChecks := time.NewTimer(timeBetweenSignaturesChecks) + + remainingSRTime := sr.remainingTime() + timeout := time.NewTimer(remainingSRTime) + for { + select { + case <-timerBetweenStatusChecks.C: + if sr.IsSubroundFinished(sr.Current()) { + log.Trace("subround already finished", "subround", sr.Name()) + return true + } + + if sr.checkReceivedSignatures() { + return true + } + timerBetweenStatusChecks.Reset(timeBetweenSignaturesChecks) + case <-timeout.C: + log.Debug("timeout while waiting for signatures or final info", "subround", sr.Name()) + return false + } + } +} + +func (sr *subroundEndRound) waitSignatures(ctx context.Context) { + remainingTime := sr.remainingTime() + if sr.IsSubroundFinished(sr.Current()) { + return + } + sr.SetWaitingAllSignaturesTimeOut(true) + + select { + case <-time.After(remainingTime): + case <-ctx.Done(): + } + sr.ConsensusChannel() <- true +} + +// maximum time to wait for signatures +func (sr *subroundEndRound) remainingTime() time.Duration { + startTime := sr.RoundHandler().TimeStamp() + maxTime := time.Duration(float64(sr.StartTime()) + float64(sr.EndTime()-sr.StartTime())*waitingAllSigsMaxTimeThreshold) + remainingTime := sr.RoundHandler().RemainingTime(startTime, maxTime) + + return remainingTime +} + +// receivedSignature method is called when a signature is received through the signature 
channel. +// If the signature is valid, then the jobDone map corresponding to the node which sent it, +// is set on true for the subround Signature +func (sr *subroundEndRound) receivedSignature(_ context.Context, cnsDta *consensus.Message) bool { + node := string(cnsDta.PubKey) + pkForLogs := core.GetTrimmedPk(hex.EncodeToString(cnsDta.PubKey)) + + if !sr.IsConsensusDataSet() { + return false + } + + if !sr.IsNodeInConsensusGroup(node) { + sr.PeerHonestyHandler().ChangeScore( + node, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.ValidatorPeerHonestyDecreaseFactor, + ) + + return false + } + + if !sr.IsConsensusDataEqual(cnsDta.BlockHeaderHash) { + return false + } + + if !sr.CanProcessReceivedMessage(cnsDta, sr.RoundHandler().Index(), sr.Current()) { + return false + } + + index, err := sr.ConsensusGroupIndex(node) + if err != nil { + log.Debug("receivedSignature.ConsensusGroupIndex", + "node", pkForLogs, + "error", err.Error()) + return false + } + + err = sr.SigningHandler().StoreSignatureShare(uint16(index), cnsDta.SignatureShare) + if err != nil { + log.Debug("receivedSignature.StoreSignatureShare", + "node", pkForLogs, + "index", index, + "error", err.Error()) + return false + } + + err = sr.SetJobDone(node, bls.SrSignature, true) + if err != nil { + log.Debug("receivedSignature.SetJobDone", + "node", pkForLogs, + "subround", sr.Name(), + "error", err.Error()) + return false + } + + sr.PeerHonestyHandler().ChangeScore( + node, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.ValidatorPeerHonestyIncreaseFactor, + ) + + return true +} + +func (sr *subroundEndRound) checkReceivedSignatures() bool { + threshold := sr.Threshold(bls.SrSignature) + if sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.GetHeader()) { + threshold = sr.FallbackThreshold(bls.SrSignature) + log.Warn("subroundEndRound.checkReceivedSignatures: fallback validation has been applied", + "minimum number of signatures required", threshold, + "actual number of signatures received", sr.getNumOfSignaturesCollected(), + ) + } + + areSignaturesCollected, numSigs := sr.areSignaturesCollected(threshold) + areAllSignaturesCollected := numSigs == sr.ConsensusGroupSize() + + isSignatureCollectionDone := areAllSignaturesCollected || (areSignaturesCollected && sr.GetWaitingAllSignaturesTimeOut()) + + isSelfJobDone := sr.IsSelfJobDone(bls.SrSignature) + + shouldStopWaitingSignatures := isSelfJobDone && isSignatureCollectionDone + if shouldStopWaitingSignatures { + log.Debug("step 2: signatures collection done", + "subround", sr.Name(), + "signatures received", numSigs, + "total signatures", len(sr.ConsensusGroup())) + + return true + } + + return false +} + +func (sr *subroundEndRound) getNumOfSignaturesCollected() int { + n := 0 + + for i := 0; i < len(sr.ConsensusGroup()); i++ { + node := sr.ConsensusGroup()[i] + + isSignJobDone, err := sr.JobDone(node, bls.SrSignature) + if err != nil { + log.Debug("getNumOfSignaturesCollected.JobDone", + "node", node, + "subround", sr.Name(), + "error", err.Error()) + continue + } + + if isSignJobDone { + n++ + } + } + + return n +} + +// areSignaturesCollected method checks if the signatures received from the nodes, belonging to the current +// jobDone group, are more than the necessary given threshold +func (sr *subroundEndRound) areSignaturesCollected(threshold int) (bool, int) { + n := sr.getNumOfSignaturesCollected() + return n >= threshold, n +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sr *subroundEndRound) 
IsInterfaceNil() bool { + return sr == nil +} diff --git a/consensus/spos/bls/v2/subroundEndRound_test.go b/consensus/spos/bls/v2/subroundEndRound_test.go new file mode 100644 index 00000000000..1db112cfff5 --- /dev/null +++ b/consensus/spos/bls/v2/subroundEndRound_test.go @@ -0,0 +1,1936 @@ +package v2_test + +import ( + "bytes" + "context" + "errors" + "math/big" + "sync" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/p2p/factory" + "github.com/multiversx/mx-chain-go/testscommon" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +func initSubroundEndRoundWithContainer( + container *consensusMocks.ConsensusCoreMock, + appStatusHandler core.AppStatusHandler, +) v2.SubroundEndRound { + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithNodesCoordinator(container.NodesCoordinator()) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + appStatusHandler, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + return srEndRound +} + +func initSubroundEndRoundWithContainerAndConsensusState( + container *consensusMocks.ConsensusCoreMock, + appStatusHandler core.AppStatusHandler, + consensusState *spos.ConsensusState, + signatureThrottler core.Throttler, +) v2.SubroundEndRound { + ch := make(chan bool, 1) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + appStatusHandler, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + signatureThrottler, + ) + + return srEndRound +} + +func initSubroundEndRound(appStatusHandler core.AppStatusHandler) 
v2.SubroundEndRound { + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, appStatusHandler) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + return sr +} + +func TestNewSubroundEndRound(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + t.Run("nil subround should error", func(t *testing.T) { + t.Parallel() + + srEndRound, err := v2.NewSubroundEndRound( + nil, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.Nil(t, srEndRound) + assert.Equal(t, spos.ErrNilSubround, err) + }) + t.Run("nil app status handler should error", func(t *testing.T) { + t.Parallel() + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + nil, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.Nil(t, srEndRound) + assert.Equal(t, spos.ErrNilAppStatusHandler, err) + }) + t.Run("nil sent signatures tracker should error", func(t *testing.T) { + t.Parallel() + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + nil, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.Nil(t, srEndRound) + assert.Equal(t, v2.ErrNilSentSignatureTracker, err) + }) + t.Run("nil worker should error", func(t *testing.T) { + t.Parallel() + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + nil, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.Nil(t, srEndRound) + assert.Equal(t, spos.ErrNilWorker, err) + }) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilBlockChainShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetBlockchain(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilBlockChain, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilBlockProcessorShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + 
"(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetBlockProcessor(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilBlockProcessor, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + sr.ConsensusStateHandler = nil + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetMultiSignerContainer(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetRoundHandler(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + 
bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetSyncTimer(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilThrottlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + nil, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, err, spos.ErrNilThrottler) +} + +func TestSubroundEndRound_NewSubroundEndRoundShouldWork(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.False(t, check.IfNil(srEndRound)) + assert.Nil(t, err) +} + +func TestSubroundEndRound_DoEndRoundJobNilHeaderShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(nil) + + r := sr.DoEndRoundJob() + assert.False(t, r) +} + +func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { + return nil, crypto.ErrNilHasher + }, + } + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + + sr.SetSelfPubKey("A") + + assert.True(t, sr.IsSelfLeader()) + r := sr.DoEndRoundJob() + assert.False(t, r) +} + +func TestSubroundEndRound_DoEndRoundJobErrCommitBlockShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + blProcMock := 
consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + blProcMock.CommitBlockCalled = func( + header data.HeaderHandler, + body data.BodyHandler, + ) error { + return blockchain.ErrHeaderUnitNil + } + + container.SetBlockProcessor(blProcMock) + sr.SetHeader(&block.Header{}) + + r := sr.DoEndRoundJob() + assert.False(t, r) +} + +func TestSubroundEndRound_DoEndRoundJobErrTimeIsOutShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + remainingTime := -time.Millisecond + roundHandlerMock := &consensusMocks.RoundHandlerMock{ + RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { + return remainingTime + }, + } + + container.SetRoundHandler(roundHandlerMock) + sr.SetHeader(&block.Header{}) + + r := sr.DoEndRoundJob() + assert.False(t, r) +} + +func TestSubroundEndRound_DoEndRoundJobAllOK(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + sr.SetHeader(&block.Header{}) + + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + r := sr.DoEndRoundJob() + assert.True(t, r) +} + +func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) + + ok := sr.DoEndRoundConsensusCheck() + assert.False(t, ok) +} + +func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnTrueWhenRoundIsFinished(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetStatus(bls.SrEndRound, spos.SsFinished) + + ok := sr.DoEndRoundConsensusCheck() + assert.True(t, ok) +} + +func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsNotFinished(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + + ok := sr.DoEndRoundConsensusCheck() + assert.False(t, ok) +} + +func TestSubroundEndRound_CheckSignaturesValidityShouldErrNilSignature(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + + bitmap := make([]byte, len(sr.ConsensusGroup())/8+1) + bitmap[0] = 0x77 + bitmap[1] = 0x01 + err := sr.CheckSignaturesValidity(bitmap) + + assert.Equal(t, spos.ErrNilSignature, err) +} + +func TestSubroundEndRound_CheckSignaturesValidityShouldReturnNil(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + + for _, pubKey := range sr.ConsensusGroup() { + _ = sr.SetJobDone(pubKey, bls.SrSignature, true) + } + + bitmap := make([]byte, len(sr.ConsensusGroup())/8+1) + bitmap[0] = 0x77 + bitmap[1] = 0x01 + + err := sr.CheckSignaturesValidity(bitmap) + require.Nil(t, err) +} + +func TestSubroundEndRound_CreateAndBroadcastProofShouldBeCalled(t *testing.T) { + t.Parallel() + + chanRcv := make(chan bool, 1) + leaderSigInHdr := []byte("leader sig") + container := consensusMocks.InitConsensusCore() + messenger := &consensusMocks.BroadcastMessengerMock{ + BroadcastEquivalentProofCalled: func(proof data.HeaderProofHandler, pkBytes []byte) error { + chanRcv <- true + return nil + }, + } + container.SetBroadcastMessenger(messenger) + sr := initSubroundEndRoundWithContainer(container, 
&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{LeaderSignature: leaderSigInHdr}) + sr.CreateAndBroadcastProof([]byte("sig"), []byte("bitmap")) + + select { + case <-chanRcv: + case <-time.After(100 * time.Millisecond): + assert.Fail(t, "broadcast not called") + } +} + +func TestSubroundEndRound_ReceivedProof(t *testing.T) { + t.Parallel() + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + hdr := &block.Header{Nonce: 37} + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) + sr.AddReceivedHeader(hdr) + + sr.SetStatus(2, spos.SsFinished) + sr.SetStatus(3, spos.SsNotFinished) + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should work with equivalent messages flag on", func(t *testing.T) { + t.Parallel() + + providedPrevSig := []byte("prev sig") + providedPrevBitmap := []byte{1, 1, 1, 1} + hdr := &block.HeaderV2{ + Header: createDefaultHeader(), + ScheduledRootHash: []byte("sch root hash"), + ScheduledAccumulatedFees: big.NewInt(0), + ScheduledDeveloperFees: big.NewInt(0), + PreviousHeaderProof: nil, + } + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + container.SetBlockchain(&testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.HeaderV2{} + }, + }) + + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + assert.Equal(t, hdr.GetPrevHash(), headerHash) + return &block.HeaderProof{ + HeaderHash: headerHash, + AggregatedSignature: providedPrevSig, + PubKeysBitmap: providedPrevBitmap, + }, nil + }, + }) + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + srEndRound.SetHeader(hdr) + srEndRound.AddReceivedHeader(hdr) + + srEndRound.SetStatus(2, spos.SsFinished) + srEndRound.SetStatus(3, spos.SsNotFinished) + + proof := &block.HeaderProof{} + srEndRound.ReceivedProof(proof) + }) + t.Run("should return false when header is nil", func(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(nil) + + proof := &block.HeaderProof{} + + sr.ReceivedProof(proof) + }) + t.Run("should return false when final info is not valid", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ + VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { + return errors.New("error") + }, + VerifySignatureCalled: func(header data.HeaderHandler) error { + return errors.New("error") + }, + } + + container.SetHeaderSigVerifier(headerSigVerifier) + sr := 
initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should return false when consensus data is not set", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetData(nil) + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should return false when sender is not in consensus group", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should return false when sender is self", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should return false when different data is received", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetData([]byte("Y")) + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should return true when final info already received", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + }) + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + proof := &block.HeaderProof{} + srEndRound.ReceivedProof(proof) + }) +} + +func TestSubroundEndRound_IsOutOfTimeShouldReturnFalse(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + + res := sr.IsOutOfTime() + assert.False(t, res) +} + +func TestSubroundEndRound_IsOutOfTimeShouldReturnTrue(t *testing.T) { + t.Parallel() + + // update roundHandler's mock, so it will calculate for real the duration + container := consensusMocks.InitConsensusCore() + roundHandler := consensusMocks.RoundHandlerMock{RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { + currentTime := time.Now() + elapsedTime := currentTime.Sub(startTime) + remainingTime := maxTime - elapsedTime + + return remainingTime + }} + container.SetRoundHandler(&roundHandler) + sr := 
initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + sr.SetRoundTimeStamp(time.Now().AddDate(0, 0, -1)) + + res := sr.IsOutOfTime() + assert.True(t, res) +} + +func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { + t.Parallel() + + t.Run("fail to get signature share", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + return nil, expectedErr + }, + } + + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + leader, err := sr.GetLeader() + require.Nil(t, err) + _ = sr.SetJobDone(leader, bls.SrSignature, true) + + _, err = sr.VerifyNodesOnAggSigFail(context.TODO()) + require.Equal(t, expectedErr, err) + }) + + t.Run("fail to verify signature share, job done will be set to false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + return nil, nil + }, + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + return expectedErr + }, + } + + sr.SetHeader(&block.Header{}) + leader, err := sr.GetLeader() + require.Nil(t, err) + _ = sr.SetJobDone(leader, bls.SrSignature, true) + container.SetSigningHandler(signingHandler) + _, err = sr.VerifyNodesOnAggSigFail(context.TODO()) + require.Nil(t, err) + + isJobDone, err := sr.JobDone(leader, bls.SrSignature) + require.Nil(t, err) + require.False(t, isJobDone) + }) + + t.Run("fail to verify signature share, an element will return an error on SignatureShare, should not panic", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + if index < 8 { + return nil, nil + } + return nil, expectedErr + }, + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + time.Sleep(100 * time.Millisecond) + return expectedErr + }, + VerifyCalled: func(msg, bitmap []byte, epoch uint32) error { + return nil + }, + } + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[2], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[3], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[4], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[5], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[6], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[7], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[8], bls.SrSignature, true) + go func() { + defer func() { + if r := recover(); r != nil { + t.Error("Should not panic") + } + }() + invalidSigners, err := sr.VerifyNodesOnAggSigFail(context.TODO()) + time.Sleep(200 * time.Millisecond) + require.Equal(t, err, expectedErr) + require.Nil(t, invalidSigners) + }() + time.Sleep(time.Second) + + }) + + t.Run("should 
work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + return nil, nil + }, + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + return nil + }, + VerifyCalled: func(msg, bitmap []byte, epoch uint32) error { + return nil + }, + } + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) + invalidSigners, err := sr.VerifyNodesOnAggSigFail(context.TODO()) + require.Nil(t, err) + require.NotNil(t, invalidSigners) + }) +} + +func TestComputeAddSigOnValidNodes(t *testing.T) { + t.Parallel() + + t.Run("invalid number of valid sig shares", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) + sr.SetThreshold(bls.SrEndRound, 2) + + _, _, err := sr.ComputeAggSigOnValidNodes() + require.True(t, errors.Is(err, spos.ErrInvalidNumSigShares)) + }) + + t.Run("fail to created aggregated sig", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { + return nil, expectedErr + }, + } + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + _, _, err := sr.ComputeAggSigOnValidNodes() + require.Equal(t, expectedErr, err) + }) + + t.Run("fail to set aggregated sig", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + SetAggregatedSigCalled: func(_ []byte) error { + return expectedErr + }, + } + container.SetSigningHandler(signingHandler) + sr.SetHeader(&block.Header{}) + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + _, _, err := sr.ComputeAggSigOnValidNodes() + require.Equal(t, expectedErr, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + bitmap, sig, err := sr.ComputeAggSigOnValidNodes() + require.NotNil(t, bitmap) + require.NotNil(t, sig) + require.Nil(t, err) + }) +} + +func TestSubroundEndRound_DoEndRoundJobByNode(t *testing.T) { + t.Parallel() + + t.Run("equivalent messages flag enabled and message already received", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + 
return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + wasHasEquivalentProofCalled := false + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + wasHasEquivalentProofCalled = true + return true + }, + }) + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + srEndRound.SetThreshold(bls.SrSignature, 2) + + for _, participant := range srEndRound.ConsensusGroup() { + _ = srEndRound.SetJobDone(participant, bls.SrSignature, true) + } + + r := srEndRound.DoEndRoundJobByNode() + require.True(t, r) + require.True(t, wasHasEquivalentProofCalled) + }) + + t.Run("should work without equivalent messages flag active", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + verifySigShareNumCalls := 0 + mutex := &sync.Mutex{} + verifyFirstCall := true + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + return nil, nil + }, + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + mutex.Lock() + defer mutex.Unlock() + if verifySigShareNumCalls == 0 { + verifySigShareNumCalls++ + return expectedErr + } + + verifySigShareNumCalls++ + return nil + }, + VerifyCalled: func(msg, bitmap []byte, epoch uint32) error { + mutex.Lock() + defer mutex.Unlock() + if verifyFirstCall { + verifyFirstCall = false + return expectedErr + } + + return nil + }, + } + + container.SetSigningHandler(signingHandler) + + sr.SetThreshold(bls.SrEndRound, 2) + + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + sr.SetHeader(&block.Header{}) + + r := sr.DoEndRoundJobByNode() + require.True(t, r) + + assert.False(t, verifyFirstCall) + assert.Equal(t, 9, verifySigShareNumCalls) + }) + t.Run("should work with equivalent messages flag active", func(t *testing.T) { + t.Parallel() + + providedPrevSig := []byte("prev sig") + providedPrevBitmap := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1} + container := consensusMocks.InitConsensusCore() + container.SetBlockchain(&testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.HeaderV2{} + }, + }) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + wasSetCurrentHeaderProofCalled := false + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + AddProofCalled: func(headerProof data.HeaderProofHandler) error { + wasSetCurrentHeaderProofCalled = true + require.NotEqual(t, providedPrevSig, 
headerProof.GetAggregatedSignature()) + require.NotEqual(t, providedPrevBitmap, headerProof.GetPubKeysBitmap()) + return nil + }, + }) + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + srEndRound.SetThreshold(bls.SrEndRound, 2) + + for _, participant := range srEndRound.ConsensusGroup() { + _ = srEndRound.SetJobDone(participant, bls.SrSignature, true) + } + + srEndRound.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + ScheduledRootHash: []byte("sch root hash"), + ScheduledAccumulatedFees: big.NewInt(0), + ScheduledDeveloperFees: big.NewInt(0), + PreviousHeaderProof: nil, + }) + + r := srEndRound.DoEndRoundJobByNode() + require.True(t, r) + require.True(t, wasSetCurrentHeaderProofCalled) + }) +} + +func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { + t.Parallel() + + t.Run("consensus data is not set", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.ConsensusStateHandler.SetData(nil) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("consensus header is not set", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(nil) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("received message node is not leader in current round", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("other node"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + + t.Run("received message from self leader should return false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + + t.Run("received message from self multikey leader should return false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return string(pkBytes) == "A" + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + sr, _ := spos.NewSubround( + 
bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + srEndRound.SetSelfPubKey("A") + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + } + + res := srEndRound.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + + t.Run("received hash does not match the hash from current consensus state", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("Y"), + PubKey: []byte("A"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("process received message verification failed, different round index", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + RoundIndex: 1, + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("empty invalid signers", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + InvalidSigners: []byte{}, + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("invalid signers data", func(t *testing.T) { + t.Parallel() + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) ([]p2p.MessageP2P, error) { + return nil, expectedErr + }, + } + + container := consensusMocks.InitConsensusCore() + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + InvalidSigners: []byte("invalid data"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + InvalidSigners: []byte("invalidSignersData"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.True(t, res) + }) +} + +func TestVerifyInvalidSigners(t *testing.T) { + t.Parallel() + + t.Run("failed to deserialize invalidSigners field, should error", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + expectedErr := errors.New("expected err") + messageSigningHandler := &mock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) 
([]p2p.MessageP2P, error) { + return nil, expectedErr + }, + } + + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + err := sr.VerifyInvalidSigners([]byte{}) + require.Equal(t, expectedErr, err) + }) + + t.Run("failed to verify low level p2p message, should error", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + invalidSigners := []p2p.MessageP2P{&factory.Message{ + FromField: []byte("from"), + }} + invalidSignersBytes, _ := container.Marshalizer().Marshal(invalidSigners) + + expectedErr := errors.New("expected err") + messageSigningHandler := &mock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) ([]p2p.MessageP2P, error) { + require.Equal(t, invalidSignersBytes, messagesBytes) + return invalidSigners, nil + }, + VerifyCalled: func(message p2p.MessageP2P) error { + return expectedErr + }, + } + + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + err := sr.VerifyInvalidSigners(invalidSignersBytes) + require.Equal(t, expectedErr, err) + }) + + t.Run("failed to verify signature share", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + pubKey := []byte("A") // it's in consensus + + consensusMsg := &consensus.Message{ + PubKey: pubKey, + } + consensusMsgBytes, _ := container.Marshalizer().Marshal(consensusMsg) + + invalidSigners := []p2p.MessageP2P{&factory.Message{ + FromField: []byte("from"), + DataField: consensusMsgBytes, + }} + invalidSignersBytes, _ := container.Marshalizer().Marshal(invalidSigners) + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) ([]p2p.MessageP2P, error) { + require.Equal(t, invalidSignersBytes, messagesBytes) + return invalidSigners, nil + }, + } + + wasCalled := false + signingHandler := &consensusMocks.SigningHandlerStub{ + VerifySingleSignatureCalled: func(publicKeyBytes []byte, message []byte, signature []byte) error { + wasCalled = true + return errors.New("expected err") + }, + } + + container.SetSigningHandler(signingHandler) + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + err := sr.VerifyInvalidSigners(invalidSignersBytes) + require.Nil(t, err) + require.True(t, wasCalled) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + pubKey := []byte("A") // it's in consensus + + consensusMsg := &consensus.Message{ + PubKey: pubKey, + } + consensusMsgBytes, _ := container.Marshalizer().Marshal(consensusMsg) + + invalidSigners := []p2p.MessageP2P{&factory.Message{ + FromField: []byte("from"), + DataField: consensusMsgBytes, + }} + invalidSignersBytes, _ := container.Marshalizer().Marshal(invalidSigners) + + messageSigningHandler := &mock.MessageSignerMock{} + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + err := sr.VerifyInvalidSigners(invalidSignersBytes) + require.Nil(t, err) + }) +} + +func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { + t.Parallel() + + t.Run("redundancy node should not send while main is active", func(t *testing.T) { + t.Parallel() + + expectedInvalidSigners := 
[]byte("invalid signers") + + container := consensusMocks.InitConsensusCore() + nodeRedundancy := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return true + }, + IsMainMachineActiveCalled: func() bool { + return true + }, + } + container.SetNodeRedundancyHandler(nodeRedundancy) + messenger := &consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + assert.Fail(t, "should have not been called") + return nil + }, + } + container.SetBroadcastMessenger(messenger) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + sr.CreateAndBroadcastInvalidSigners(expectedInvalidSigners) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + wg := &sync.WaitGroup{} + wg.Add(1) + + expectedInvalidSigners := []byte("invalid signers") + + wasCalled := false + container := consensusMocks.InitConsensusCore() + messenger := &consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + assert.Equal(t, expectedInvalidSigners, message.InvalidSigners) + wasCalled = true + wg.Done() + return nil + }, + } + container.SetBroadcastMessenger(messenger) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + sr.CreateAndBroadcastInvalidSigners(expectedInvalidSigners) + + wg.Wait() + + require.True(t, wasCalled) + }) +} + +func TestGetFullMessagesForInvalidSigners(t *testing.T) { + t.Parallel() + + t.Run("empty p2p messages slice if not in state", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + SerializeCalled: func(messages []p2p.MessageP2P) ([]byte, error) { + require.Equal(t, 0, len(messages)) + + return []byte{}, nil + }, + } + + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + invalidSigners := []string{"B", "C"} + + invalidSignersBytes, err := sr.GetFullMessagesForInvalidSigners(invalidSigners) + require.Nil(t, err) + require.Equal(t, []byte{}, invalidSignersBytes) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + expectedInvalidSigners := []byte("expectedInvalidSigners") + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + SerializeCalled: func(messages []p2p.MessageP2P) ([]byte, error) { + require.Equal(t, 2, len(messages)) + + return expectedInvalidSigners, nil + }, + } + + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.AddMessageWithSignature("B", &p2pmocks.P2PMessageMock{}) + sr.AddMessageWithSignature("C", &p2pmocks.P2PMessageMock{}) + + invalidSigners := []string{"B", "C"} + + invalidSignersBytes, err := sr.GetFullMessagesForInvalidSigners(invalidSigners) + require.Nil(t, err) + require.Equal(t, expectedInvalidSigners, invalidSignersBytes) + }) +} + +func TestSubroundEndRound_getMinConsensusGroupIndexOfManagedKeys(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + 
int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + t.Run("no managed keys from consensus group", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return false + } + + assert.Equal(t, 9, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("first managed key in consensus group should return 0", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("A"), pkBytes) + } + + assert.Equal(t, 0, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("third managed key in consensus group should return 2", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("C"), pkBytes) + } + + assert.Equal(t, 2, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("last managed key in consensus group should return 8", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("I"), pkBytes) + } + + assert.Equal(t, 8, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) +} + +func TestSubroundSignature_ReceivedSignature(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + signature := []byte("signature") + cnsMsg := consensus.NewConsensusMessage( + sr.GetData(), + signature, + nil, + nil, + []byte(sr.ConsensusGroup()[1]), + []byte("sig"), + int(bls.MtSignature), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + + sr.SetHeader(&block.Header{}) + sr.SetData(nil) + r := sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + sr.SetData([]byte("Y")) + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + sr.SetData([]byte("X")) + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + leader, err := sr.GetLeader() + assert.Nil(t, err) + + sr.SetSelfPubKey(leader) + + cnsMsg.PubKey = []byte("X") + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) + maxCount := len(sr.ConsensusGroup()) * 2 / 3 + count := 0 + for i := 0; i < len(sr.ConsensusGroup()); i++ { + if sr.ConsensusGroup()[i] != string(cnsMsg.PubKey) { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) + count++ + if count == maxCount { + break + } + } + } + r = sr.ReceivedSignature(cnsMsg) + assert.True(t, r) +} + +func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { + t.Parallel() + + errStore := errors.New("signature share store failed") + storeSigShareCalled := false + signingHandler := &consensusMocks.SigningHandlerStub{ + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + return nil + }, + StoreSignatureShareCalled: func(index uint16, sig []byte) error { + storeSigShareCalled = true + return errStore + }, + } + + container := consensusMocks.InitConsensusCore() + container.SetSigningHandler(signingHandler) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) + + signature := []byte("signature") + cnsMsg := 
consensus.NewConsensusMessage( + sr.GetData(), + signature, + nil, + nil, + []byte(sr.ConsensusGroup()[1]), + []byte("sig"), + int(bls.MtSignature), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + + sr.SetData(nil) + r := sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + sr.SetData([]byte("Y")) + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + sr.SetData([]byte("X")) + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + + cnsMsg.PubKey = []byte("X") + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) + maxCount := len(sr.ConsensusGroup()) * 2 / 3 + count := 0 + for i := 0; i < len(sr.ConsensusGroup()); i++ { + if sr.ConsensusGroup()[i] != string(cnsMsg.PubKey) { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) + count++ + if count == maxCount { + break + } + } + } + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + assert.True(t, storeSigShareCalled) +} diff --git a/consensus/spos/bls/v2/subroundSignature.go b/consensus/spos/bls/v2/subroundSignature.go new file mode 100644 index 00000000000..3c273437e41 --- /dev/null +++ b/consensus/spos/bls/v2/subroundSignature.go @@ -0,0 +1,306 @@ +package v2 + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + atomicCore "github.com/multiversx/mx-chain-core-go/core/atomic" + "github.com/multiversx/mx-chain-core-go/core/check" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" +) + +const timeSpentBetweenChecks = time.Millisecond + +type subroundSignature struct { + *spos.Subround + appStatusHandler core.AppStatusHandler + sentSignatureTracker spos.SentSignaturesTracker + signatureThrottler core.Throttler +} + +// NewSubroundSignature creates a subroundSignature object +func NewSubroundSignature( + baseSubround *spos.Subround, + appStatusHandler core.AppStatusHandler, + sentSignatureTracker spos.SentSignaturesTracker, + worker spos.WorkerHandler, + signatureThrottler core.Throttler, +) (*subroundSignature, error) { + err := checkNewSubroundSignatureParams( + baseSubround, + ) + if err != nil { + return nil, err + } + if check.IfNil(appStatusHandler) { + return nil, spos.ErrNilAppStatusHandler + } + if check.IfNil(sentSignatureTracker) { + return nil, ErrNilSentSignatureTracker + } + if check.IfNil(worker) { + return nil, spos.ErrNilWorker + } + if check.IfNil(signatureThrottler) { + return nil, spos.ErrNilThrottler + } + + srSignature := subroundSignature{ + Subround: baseSubround, + appStatusHandler: appStatusHandler, + sentSignatureTracker: sentSignatureTracker, + signatureThrottler: signatureThrottler, + } + srSignature.Job = srSignature.doSignatureJob + srSignature.Check = srSignature.doSignatureConsensusCheck + srSignature.Extend = worker.Extend + + return &srSignature, nil +} + +func checkNewSubroundSignatureParams( + baseSubround *spos.Subround, +) error { + if baseSubround == nil { + return spos.ErrNilSubround + } + if check.IfNil(baseSubround.ConsensusStateHandler) { + return spos.ErrNilConsensusState + } + + err := spos.ValidateConsensusCore(baseSubround.ConsensusCoreHandler) + + return err +} + +// doSignatureJob method does the job of the subround Signature +func (sr *subroundSignature) doSignatureJob(ctx context.Context) bool { + if 
!sr.CanDoSubroundJob(sr.Current()) { + return false + } + if check.IfNil(sr.GetHeader()) { + log.Error("doSignatureJob", "error", spos.ErrNilHeader) + return false + } + + isSelfSingleKeyInConsensusGroup := sr.IsNodeInConsensusGroup(sr.SelfPubKey()) && sr.ShouldConsiderSelfKeyInConsensus() + if isSelfSingleKeyInConsensusGroup { + if !sr.doSignatureJobForSingleKey() { + return false + } + } + + if !sr.doSignatureJobForManagedKeys(ctx) { + return false + } + + sr.SetStatus(sr.Current(), spos.SsFinished) + log.Debug("step 2: subround has been finished", + "subround", sr.Name()) + + return true +} + +func (sr *subroundSignature) createAndSendSignatureMessage(signatureShare []byte, pkBytes []byte) bool { + cnsMsg := consensus.NewConsensusMessage( + sr.GetData(), + signatureShare, + nil, + nil, + pkBytes, + nil, + int(bls.MtSignature), + sr.RoundHandler().Index(), + sr.ChainID(), + nil, + nil, + nil, + sr.GetAssociatedPid(pkBytes), + nil, + ) + + err := sr.BroadcastMessenger().BroadcastConsensusMessage(cnsMsg) + if err != nil { + log.Debug("createAndSendSignatureMessage.BroadcastConsensusMessage", + "error", err.Error(), "pk", pkBytes) + return false + } + + log.Debug("step 2: signature has been sent", "pk", pkBytes) + + return true +} + +func (sr *subroundSignature) completeSignatureSubRound(pk string) bool { + err := sr.SetJobDone(pk, sr.Current(), true) + if err != nil { + log.Debug("doSignatureJob.SetSelfJobDone", + "subround", sr.Name(), + "error", err.Error(), + "pk", []byte(pk), + ) + return false + } + + return true +} + +// doSignatureConsensusCheck method checks if the consensus in the subround Signature is achieved +func (sr *subroundSignature) doSignatureConsensusCheck() bool { + if sr.GetRoundCanceled() { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return true + } + + if check.IfNil(sr.GetHeader()) { + return false + } + + isSelfInConsensusGroup := sr.IsSelfInConsensusGroup() + if !isSelfInConsensusGroup { + log.Debug("step 2: subround has been finished", + "subround", sr.Name()) + sr.SetStatus(sr.Current(), spos.SsFinished) + + return true + } + + if sr.IsSelfJobDone(sr.Current()) { + log.Debug("step 2: subround has been finished", + "subround", sr.Name()) + sr.SetStatus(sr.Current(), spos.SsFinished) + sr.appStatusHandler.SetStringValue(common.MetricConsensusRoundState, "signed") + + return true + } + + return false +} + +func (sr *subroundSignature) doSignatureJobForManagedKeys(ctx context.Context) bool { + numMultiKeysSignaturesSent := int32(0) + sentSigForAllKeys := atomicCore.Flag{} + sentSigForAllKeys.SetValue(true) + + wg := sync.WaitGroup{} + + for idx, pk := range sr.ConsensusGroup() { + pkBytes := []byte(pk) + if !sr.IsKeyManagedBySelf(pkBytes) { + continue + } + + if sr.IsJobDone(pk, sr.Current()) { + continue + } + + err := sr.checkGoRoutinesThrottler(ctx) + if err != nil { + return false + } + sr.signatureThrottler.StartProcessing() + wg.Add(1) + + go func(idx int, pk string) { + defer sr.signatureThrottler.EndProcessing() + + signatureSent := sr.sendSignatureForManagedKey(idx, pk) + if signatureSent { + atomic.AddInt32(&numMultiKeysSignaturesSent, 1) + } else { + sentSigForAllKeys.SetValue(false) + } + wg.Done() + }(idx, pk) + } + + wg.Wait() + + if numMultiKeysSignaturesSent > 0 { + log.Debug("step 2: multi keys signatures have been sent", "num", numMultiKeysSignaturesSent) + } + + return sentSigForAllKeys.IsSet() +} + +func (sr *subroundSignature) sendSignatureForManagedKey(idx int, pk string) bool { + pkBytes := []byte(pk) + + signatureShare, 
err := sr.SigningHandler().CreateSignatureShareForPublicKey( + sr.GetData(), + uint16(idx), + sr.GetHeader().GetEpoch(), + pkBytes, + ) + if err != nil { + log.Debug("sendSignatureForManagedKey.CreateSignatureShareForPublicKey", "error", err.Error()) + return false + } + + // with the equivalent messages feature on, signatures from all managed keys must be broadcast, as the aggregation is done by any participant + ok := sr.createAndSendSignatureMessage(signatureShare, pkBytes) + if !ok { + return false + } + sr.sentSignatureTracker.SignatureSent(pkBytes) + + return sr.completeSignatureSubRound(pk) +} + +func (sr *subroundSignature) checkGoRoutinesThrottler(ctx context.Context) error { + for { + if sr.signatureThrottler.CanProcess() { + break + } + select { + case <-time.After(timeSpentBetweenChecks): + continue + case <-ctx.Done(): + return fmt.Errorf("%w while checking the throttler", spos.ErrTimeIsOut) + } + } + return nil +} + +func (sr *subroundSignature) doSignatureJobForSingleKey() bool { + selfIndex, err := sr.SelfConsensusGroupIndex() + if err != nil { + log.Debug("doSignatureJobForSingleKey.SelfConsensusGroupIndex: not in consensus group") + return false + } + + signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( + sr.GetData(), + uint16(selfIndex), + sr.GetHeader().GetEpoch(), + []byte(sr.SelfPubKey()), + ) + if err != nil { + log.Debug("doSignatureJobForSingleKey.CreateSignatureShareForPublicKey", "error", err.Error()) + return false + } + + // leader also sends his signature here + ok := sr.createAndSendSignatureMessage(signatureShare, []byte(sr.SelfPubKey())) + if !ok { + return false + } + + return sr.completeSignatureSubRound(sr.SelfPubKey()) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sr *subroundSignature) IsInterfaceNil() bool { + return sr == nil +} diff --git a/consensus/spos/bls/v2/subroundSignature_test.go b/consensus/spos/bls/v2/subroundSignature_test.go new file mode 100644 index 00000000000..0a7a2ce7ffd --- /dev/null +++ b/consensus/spos/bls/v2/subroundSignature_test.go @@ -0,0 +1,982 @@ +package v2_test + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + dataRetrieverMock "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/testscommon" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +const setThresholdJobsDone = "threshold" + +func initSubroundSignatureWithContainer(container *consensusMocks.ConsensusCoreMock) v2.SubroundSignature { + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + 
currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + return srSignature +} + +func initSubroundSignature() v2.SubroundSignature { + container := consensusMocks.InitConsensusCore() + return initSubroundSignatureWithContainer(container) +} + +func TestNewSubroundSignature(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + t.Run("nil subround should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + nil, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, spos.ErrNilSubround, err) + }) + t.Run("nil worker should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + nil, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, spos.ErrNilWorker, err) + }) + t.Run("nil app status handler should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + sr, + nil, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, spos.ErrNilAppStatusHandler, err) + }) + t.Run("nil sent signatures tracker should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + nil, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, v2.ErrNilSentSignatureTracker, err) + }) + + t.Run("nil signatureThrottler should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + nil, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, spos.ErrNilThrottler, err) + }) +} + +func TestSubroundSignature_NewSubroundSignatureNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + sr.ConsensusStateHandler = nil + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, 
check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilHasherShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetHasher(nil) + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilHasher, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetMultiSignerContainer(nil) + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetRoundHandler(nil) + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetSyncTimer(nil) + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, 
check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilAppStatusHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srSignature, err := v2.NewSubroundSignature( + sr, + nil, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilAppStatusHandler, err) +} + +func TestSubroundSignature_NewSubroundSignatureShouldWork(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.False(t, check.IfNil(srSignature)) + assert.Nil(t, err) +} + +func TestSubroundSignature_DoSignatureJob(t *testing.T) { + t.Parallel() + t.Run("with equivalent messages flag active should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + sr := initSubroundSignatureWithContainer(container) + + sr.SetHeader(&block.Header{}) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + if string(message.PubKey) != leader || message.MsgType != int64(bls.MtSignature) { + assert.Fail(t, "should have not been called") + } + return nil + }, + }) + r := sr.DoSignatureJob() + assert.True(t, r) + + assert.False(t, sr.GetRoundCanceled()) + assert.Nil(t, err) + leaderJobDone, err := sr.JobDone(leader, bls.SrSignature) + assert.NoError(t, err) + assert.True(t, leaderJobDone) + assert.True(t, sr.IsSubroundFinished(bls.SrSignature)) + }) +} + +func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { + t.Run("with equivalent messages flag active should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + signingHandler := &consensusMocks.SigningHandlerStub{ + 
CreateSignatureShareForPublicKeyCalled: func(msg []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return []byte("SIG"), nil + }, + } + container.SetSigningHandler(signingHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + signatureSentForPks := make(map[string]struct{}) + mutex := sync.Mutex{} + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + mutex.Lock() + signatureSentForPks[string(pkBytes)] = struct{}{} + mutex.Unlock() + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + sr.SetHeader(&block.Header{}) + signaturesBroadcast := make(map[string]int) + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + mutex.Lock() + signaturesBroadcast[string(message.PubKey)]++ + mutex.Unlock() + return nil + }, + }) + + sr.SetSelfPubKey("OTHER") + + r := srSignature.DoSignatureJob() + assert.True(t, r) + + assert.False(t, sr.GetRoundCanceled()) + assert.True(t, sr.IsSubroundFinished(bls.SrSignature)) + + for _, pk := range sr.ConsensusGroup() { + isJobDone, err := sr.JobDone(pk, bls.SrSignature) + assert.NoError(t, err) + assert.True(t, isJobDone) + } + + expectedMap := map[string]struct{}{"A": {}, "B": {}, "C": {}, "D": {}, "E": {}, "F": {}, "G": {}, "H": {}, "I": {}} + assert.Equal(t, expectedMap, signatureSentForPks) + + // leader also sends his signature + expectedBroadcastMap := map[string]int{"A": 1, "B": 1, "C": 1, "D": 1, "E": 1, "F": 1, "G": 1, "H": 1, "I": 1} + assert.Equal(t, expectedBroadcastMap, signaturesBroadcast) + }) +} + +func TestSubroundSignature_SendSignature(t *testing.T) { + t.Parallel() + + t.Run("sendSignatureForManagedKey will return false because of error", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + container.SetSigningHandler(&consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return make([]byte, 0), expErr + }, + }) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.Header{}) + + signatureSentForPks := make(map[string]struct{}) + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + signatureSentForPks[string(pkBytes)] = struct{}{} + }, + }, 
+ &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + r := srSignature.SendSignatureForManagedKey(0, "a") + + assert.False(t, r) + }) + + t.Run("sendSignatureForManagedKey should be false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + container.SetSigningHandler(&consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return []byte("SIG"), nil + }, + }) + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return fmt.Errorf("error") + }, + }) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.Header{}) + + signatureSentForPks := make(map[string]struct{}) + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + signatureSentForPks[string(pkBytes)] = struct{}{} + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + r := srSignature.SendSignatureForManagedKey(1, "a") + + assert.False(t, r) + }) + + t.Run("SentSignature should be called", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + container.SetSigningHandler(&consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return []byte("SIG"), nil + }, + }) + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return nil + }, + }) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.Header{}) + + signatureSentForPks := make(map[string]struct{}) + varCalled := false + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + 
&testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + signatureSentForPks[string(pkBytes)] = struct{}{} + varCalled = true + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + _ = srSignature.SendSignatureForManagedKey(1, "a") + + assert.True(t, varCalled) + }) +} + +func TestSubroundSignature_DoSignatureJobForManagedKeys(t *testing.T) { + t.Parallel() + + t.Run("should work", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + signingHandler := &consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(msg []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return []byte("SIG"), nil + }, + } + container.SetSigningHandler(signingHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + signatureSentForPks := make(map[string]struct{}) + mutex := sync.Mutex{} + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + mutex.Lock() + signatureSentForPks[string(pkBytes)] = struct{}{} + mutex.Unlock() + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + sr.SetHeader(&block.Header{}) + signaturesBroadcast := make(map[string]int) + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + mutex.Lock() + signaturesBroadcast[string(message.PubKey)]++ + mutex.Unlock() + return nil + }, + }) + + sr.SetSelfPubKey("OTHER") + + r := srSignature.DoSignatureJobForManagedKeys(context.TODO()) + assert.True(t, r) + + for _, pk := range sr.ConsensusGroup() { + isJobDone, err := sr.JobDone(pk, bls.SrSignature) + assert.NoError(t, err) + assert.True(t, isJobDone) + } + + expectedMap := map[string]struct{}{"A": {}, "B": {}, "C": {}, "D": {}, "E": {}, "F": {}, "G": {}, "H": {}, "I": {}} + assert.Equal(t, expectedMap, signatureSentForPks) + + expectedBroadcastMap := map[string]int{"A": 1, "B": 1, "C": 1, "D": 1, "E": 1, "F": 1, "G": 1, "H": 1, "I": 1} + assert.Equal(t, expectedBroadcastMap, signaturesBroadcast) + }) + + t.Run("should fail", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) 
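For context, the failure path that the "should fail" case set up here (continued just below) exercises comes from the throttler wait loop in subroundSignature.go earlier in this diff (checkGoRoutinesThrottler): before spawning each per-key signing goroutine, doSignatureJobForManagedKeys waits for the signature throttler to grant a slot and gives up once the context is canceled. The following standalone sketch only mirrors that shape; the names throttler, waitForFreeSlot and blockedThrottler are illustrative stand-ins, not the production types.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// throttler is a stand-in for the core.Throttler dependency used by the subround.
type throttler interface {
	CanProcess() bool
}

var errTimeIsOut = errors.New("time is out")

// waitForFreeSlot polls the throttler until a processing slot frees up, or gives up
// when the context is canceled, the same shape as checkGoRoutinesThrottler above.
func waitForFreeSlot(ctx context.Context, th throttler) error {
	for {
		if th.CanProcess() {
			return nil
		}
		select {
		case <-time.After(time.Millisecond):
		case <-ctx.Done():
			return fmt.Errorf("%w while checking the throttler", errTimeIsOut)
		}
	}
}

// blockedThrottler never grants a slot, like the ThrottlerStub used in the "should fail" case.
type blockedThrottler struct{}

func (blockedThrottler) CanProcess() bool { return false }

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // canceled context plus a saturated throttler: the signing job must abort

	fmt.Println(waitForFreeSlot(ctx, blockedThrottler{}))
}

With a throttler that eventually reports a free slot, the loop exits normally and the goroutine is started, which is why the test only fails the job when both the slot never frees and the context is canceled.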
+ ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{ + CanProcessCalled: func() bool { + return false + }, + }, + ) + + sr.SetHeader(&block.Header{}) + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + r := srSignature.DoSignatureJobForManagedKeys(ctx) + assert.False(t, r) + }) +} + +func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetRoundCanceled(true) + assert.False(t, sr.DoSignatureConsensusCheck()) +} + +func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSubroundIsFinished(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetStatus(bls.SrSignature, spos.SsFinished) + assert.True(t, sr.DoSignatureConsensusCheck()) +} + +func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSignaturesCollectedReturnTrue(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + + for i := 0; i < sr.Threshold(bls.SrSignature); i++ { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) + } + + sr.SetHeader(&block.HeaderV2{}) + assert.True(t, sr.DoSignatureConsensusCheck()) +} + +func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenSignaturesCollectedReturnFalse(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetHeader(&block.HeaderV2{Header: createDefaultHeader()}) + assert.False(t, sr.DoSignatureConsensusCheck()) +} + +func TestSubroundSignature_DoSignatureConsensusCheckNotAllSignaturesCollectedAndTimeIsNotOut(t *testing.T) { + t.Parallel() + + t.Run("with flag active, should return true", testSubroundSignatureDoSignatureConsensusCheck(argTestSubroundSignatureDoSignatureConsensusCheck{ + flagActive: true, + jobsDone: setThresholdJobsDone, + expectedResult: true, + })) +} + +func TestSubroundSignature_DoSignatureConsensusCheckAllSignaturesCollected(t *testing.T) { + t.Parallel() + t.Run("with flag active, should return true", testSubroundSignatureDoSignatureConsensusCheck(argTestSubroundSignatureDoSignatureConsensusCheck{ + flagActive: true, + jobsDone: "all", + expectedResult: true, + })) +} + +func TestSubroundSignature_DoSignatureConsensusCheckEnoughButNotAllSignaturesCollectedAndTimeIsOut(t *testing.T) { + t.Parallel() + + t.Run("with flag active, should return true", testSubroundSignatureDoSignatureConsensusCheck(argTestSubroundSignatureDoSignatureConsensusCheck{ + flagActive: true, + jobsDone: setThresholdJobsDone, + expectedResult: true, + })) +} + +type argTestSubroundSignatureDoSignatureConsensusCheck struct { + flagActive bool + jobsDone string + expectedResult bool +} + +func testSubroundSignatureDoSignatureConsensusCheck(args argTestSubroundSignatureDoSignatureConsensusCheck) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + container.SetEnableEpochsHandler(&enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + if flag == 
common.EquivalentMessagesFlag { + return args.flagActive + } + return false + }, + }) + sr := initSubroundSignatureWithContainer(container) + + if !args.flagActive { + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + } + + numberOfJobsDone := sr.ConsensusGroupSize() + if args.jobsDone == setThresholdJobsDone { + numberOfJobsDone = sr.Threshold(bls.SrSignature) + } + for i := 0; i < numberOfJobsDone; i++ { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) + } + + sr.SetHeader(&block.HeaderV2{}) + assert.Equal(t, args.expectedResult, sr.DoSignatureConsensusCheck()) + } +} diff --git a/consensus/spos/bls/v2/subroundStartRound.go b/consensus/spos/bls/v2/subroundStartRound.go new file mode 100644 index 00000000000..17c4a890ecf --- /dev/null +++ b/consensus/spos/bls/v2/subroundStartRound.go @@ -0,0 +1,359 @@ +package v2 + +import ( + "context" + "encoding/hex" + "fmt" + "sync" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + outportcore "github.com/multiversx/mx-chain-core-go/data/outport" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/outport" + "github.com/multiversx/mx-chain-go/outport/disabled" +) + +// subroundStartRound defines the data needed by the subround StartRound +type subroundStartRound struct { + *spos.Subround + processingThresholdPercentage int + + sentSignatureTracker spos.SentSignaturesTracker + worker spos.WorkerHandler + outportHandler outport.OutportHandler + outportMutex sync.RWMutex +} + +// NewSubroundStartRound creates a subroundStartRound object +func NewSubroundStartRound( + baseSubround *spos.Subround, + processingThresholdPercentage int, + sentSignatureTracker spos.SentSignaturesTracker, + worker spos.WorkerHandler, +) (*subroundStartRound, error) { + err := checkNewSubroundStartRoundParams( + baseSubround, + ) + if err != nil { + return nil, err + } + if check.IfNil(sentSignatureTracker) { + return nil, ErrNilSentSignatureTracker + } + if check.IfNil(worker) { + return nil, spos.ErrNilWorker + } + + srStartRound := subroundStartRound{ + Subround: baseSubround, + processingThresholdPercentage: processingThresholdPercentage, + sentSignatureTracker: sentSignatureTracker, + worker: worker, + outportHandler: disabled.NewDisabledOutport(), + outportMutex: sync.RWMutex{}, + } + srStartRound.Job = srStartRound.doStartRoundJob + srStartRound.Check = srStartRound.doStartRoundConsensusCheck + srStartRound.Extend = worker.Extend + baseSubround.EpochStartRegistrationHandler().RegisterHandler(&srStartRound) + + return &srStartRound, nil +} + +func checkNewSubroundStartRoundParams( + baseSubround *spos.Subround, +) error { + if baseSubround == nil { + return spos.ErrNilSubround + } + if check.IfNil(baseSubround.ConsensusStateHandler) { + return spos.ErrNilConsensusState + } + + err := spos.ValidateConsensusCore(baseSubround.ConsensusCoreHandler) + + return err +} + +// SetOutportHandler method sets outport handler +func (sr *subroundStartRound) SetOutportHandler(outportHandler outport.OutportHandler) error { + if check.IfNil(outportHandler) { + return outport.ErrNilDriver + } + + sr.outportMutex.Lock() + sr.outportHandler = outportHandler + sr.outportMutex.Unlock() + + return nil +} + +// doStartRoundJob method does the job of the subround StartRound +func (sr *subroundStartRound) doStartRoundJob(_ context.Context) bool { + 
sr.ResetConsensusState() + sr.SetRoundIndex(sr.RoundHandler().Index()) + sr.SetRoundTimeStamp(sr.RoundHandler().TimeStamp()) + topic := spos.GetConsensusTopicID(sr.ShardCoordinator()) + sr.GetAntiFloodHandler().ResetForTopic(topic) + sr.worker.ResetConsensusMessages() + + return true +} + +// doStartRoundConsensusCheck method checks if the consensus is achieved in the subround StartRound +func (sr *subroundStartRound) doStartRoundConsensusCheck() bool { + if sr.GetRoundCanceled() { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return true + } + + if sr.initCurrentRound() { + return true + } + + return false +} + +func (sr *subroundStartRound) initCurrentRound() bool { + nodeState := sr.BootStrapper().GetNodeState() + if nodeState != common.NsSynchronized { // if node is not synchronized yet, it has to continue the bootstrapping mechanism + return false + } + + sr.AppStatusHandler().SetStringValue(common.MetricConsensusRoundState, "") + + err := sr.generateNextConsensusGroup(sr.RoundHandler().Index()) + if err != nil { + log.Debug("initCurrentRound.generateNextConsensusGroup", + "round index", sr.RoundHandler().Index(), + "error", err.Error()) + + sr.SetRoundCanceled(true) + + return false + } + + if sr.NodeRedundancyHandler().IsRedundancyNode() { + sr.NodeRedundancyHandler().AdjustInactivityIfNeeded( + sr.SelfPubKey(), + sr.ConsensusGroup(), + sr.RoundHandler().Index(), + ) + // we should not return here, the multikey redundancy system relies on it + // the NodeRedundancyHandler "thinks" it is in redundancy mode even if we use the multikey redundancy system + } + + leader, err := sr.GetLeader() + if err != nil { + log.Debug("initCurrentRound.GetLeader", "error", err.Error()) + + sr.SetRoundCanceled(true) + + return false + } + + msg := sr.GetLeaderStartRoundMessage() + if len(msg) != 0 { + sr.AppStatusHandler().Increment(common.MetricCountLeader) + sr.AppStatusHandler().SetStringValue(common.MetricConsensusRoundState, "proposed") + sr.AppStatusHandler().SetStringValue(common.MetricConsensusState, "proposer") + } + + log.Debug("step 0: preparing the round", + "leader", core.GetTrimmedPk(hex.EncodeToString([]byte(leader))), + "message", msg) + sr.sentSignatureTracker.StartRound() + + pubKeys := sr.ConsensusGroup() + numMultiKeysInConsensusGroup := sr.computeNumManagedKeysInConsensusGroup(pubKeys) + if numMultiKeysInConsensusGroup > 0 { + log.Debug("in consensus group with multi keys identities", "num", numMultiKeysInConsensusGroup) + } + + sr.indexRoundIfNeeded(pubKeys) + + if !sr.IsSelfInConsensusGroup() { + log.Debug("not in consensus group") + sr.AppStatusHandler().SetStringValue(common.MetricConsensusState, "not in consensus group") + } else { + if !sr.IsSelfLeader() { + sr.AppStatusHandler().Increment(common.MetricCountConsensus) + sr.AppStatusHandler().SetStringValue(common.MetricConsensusState, "participant") + } + } + + err = sr.SigningHandler().Reset(pubKeys) + if err != nil { + log.Debug("initCurrentRound.Reset", "error", err.Error()) + + sr.SetRoundCanceled(true) + + return false + } + + startTime := sr.GetRoundTimeStamp() + maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 + if sr.RoundHandler().RemainingTime(startTime, maxTime) < 0 { + log.Debug("canceled round, time is out", + "round", sr.SyncTimer().FormattedCurrentTime(), sr.RoundHandler().Index(), + "subround", sr.Name()) + + sr.SetRoundCanceled(true) + + return false + } + + sr.SetStatus(sr.Current(), spos.SsFinished) + + // execute stored messages which
were received in this new round but before this initialisation + go sr.worker.ExecuteStoredMessages() + + return true +} + +func (sr *subroundStartRound) computeNumManagedKeysInConsensusGroup(pubKeys []string) int { + numMultiKeysInConsensusGroup := 0 + for _, pk := range pubKeys { + pkBytes := []byte(pk) + if sr.IsKeyManagedBySelf(pkBytes) { + numMultiKeysInConsensusGroup++ + log.Trace("in consensus group with multi key", + "pk", core.GetTrimmedPk(hex.EncodeToString(pkBytes))) + } + sr.IncrementRoundsWithoutReceivedMessages(pkBytes) + } + + return numMultiKeysInConsensusGroup +} + +func (sr *subroundStartRound) indexRoundIfNeeded(pubKeys []string) { + sr.outportMutex.RLock() + defer sr.outportMutex.RUnlock() + + if !sr.outportHandler.HasDrivers() { + return + } + + currentHeader := sr.Blockchain().GetCurrentBlockHeader() + if check.IfNil(currentHeader) { + currentHeader = sr.Blockchain().GetGenesisHeader() + } + + epoch := currentHeader.GetEpoch() + shardId := sr.ShardCoordinator().SelfId() + nodesCoordinatorShardID, err := sr.NodesCoordinator().ShardIdForEpoch(epoch) + if err != nil { + log.Debug("initCurrentRound.ShardIdForEpoch", + "epoch", epoch, + "error", err.Error()) + return + } + + if shardId != nodesCoordinatorShardID { + log.Debug("initCurrentRound.ShardIdForEpoch", + "epoch", epoch, + "shardCoordinator.ShardID", shardId, + "nodesCoordinator.ShardID", nodesCoordinatorShardID) + return + } + + signersIndexes, err := sr.NodesCoordinator().GetValidatorsIndexes(pubKeys, epoch) + if err != nil { + log.Error(err.Error()) + return + } + + round := sr.RoundHandler().Index() + + roundInfo := &outportcore.RoundInfo{ + Round: uint64(round), + SignersIndexes: signersIndexes, + BlockWasProposed: false, + ShardId: shardId, + Epoch: epoch, + Timestamp: uint64(sr.GetRoundTimeStamp().Unix()), + } + roundsInfo := &outportcore.RoundsInfo{ + ShardID: shardId, + RoundsInfo: []*outportcore.RoundInfo{roundInfo}, + } + sr.outportHandler.SaveRoundsInfo(roundsInfo) +} + +func (sr *subroundStartRound) generateNextConsensusGroup(roundIndex int64) error { + currentHeader := sr.Blockchain().GetCurrentBlockHeader() + if check.IfNil(currentHeader) { + currentHeader = sr.Blockchain().GetGenesisHeader() + if check.IfNil(currentHeader) { + return spos.ErrNilHeader + } + } + + randomSeed := currentHeader.GetRandSeed() + + log.Debug("random source for the next consensus group", + "rand", randomSeed) + + shardId := sr.ShardCoordinator().SelfId() + + leader, nextConsensusGroup, err := sr.GetNextConsensusGroup( + randomSeed, + uint64(sr.GetRoundIndex()), + shardId, + sr.NodesCoordinator(), + currentHeader.GetEpoch(), + ) + if err != nil { + return err + } + + log.Trace("consensus group is formed by next validators:", + "round", roundIndex) + + for i := 0; i < len(nextConsensusGroup); i++ { + log.Trace(core.GetTrimmedPk(hex.EncodeToString([]byte(nextConsensusGroup[i])))) + } + + sr.SetConsensusGroup(nextConsensusGroup) + sr.SetLeader(leader) + + consensusGroupSizeForEpoch := sr.NodesCoordinator().ConsensusGroupSizeForShardAndEpoch(shardId, currentHeader.GetEpoch()) + sr.SetConsensusGroupSize(consensusGroupSizeForEpoch) + + return nil +} + +// EpochStartPrepare is called when an epoch start event is observed, but not yet confirmed/committed.
+// Some components may need to do initialisation on this event +func (sr *subroundStartRound) EpochStartPrepare(metaHdr data.HeaderHandler, _ data.BodyHandler) { + log.Trace(fmt.Sprintf("epoch %d start prepare in consensus", metaHdr.GetEpoch())) +} + +// EpochStartAction is called upon a start of epoch event. +func (sr *subroundStartRound) EpochStartAction(hdr data.HeaderHandler) { + log.Trace(fmt.Sprintf("epoch %d start action in consensus", hdr.GetEpoch())) + + sr.changeEpoch(hdr.GetEpoch()) +} + +func (sr *subroundStartRound) changeEpoch(currentEpoch uint32) { + epochNodes, err := sr.NodesCoordinator().GetConsensusWhitelistedNodes(currentEpoch) + if err != nil { + panic(fmt.Sprintf("consensus changing epoch failed with error %s", err.Error())) + } + + sr.SetEligibleList(epochNodes) +} + +// NotifyOrder returns the notification order for a start of epoch event +func (sr *subroundStartRound) NotifyOrder() uint32 { + return common.ConsensusStartRoundOrder +} diff --git a/consensus/spos/bls/v2/subroundStartRound_test.go b/consensus/spos/bls/v2/subroundStartRound_test.go new file mode 100644 index 00000000000..28f063277c0 --- /dev/null +++ b/consensus/spos/bls/v2/subroundStartRound_test.go @@ -0,0 +1,1115 @@ +package v2_test + +import ( + "fmt" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/data" + outportcore "github.com/multiversx/mx-chain-core-go/data/outport" + "github.com/stretchr/testify/require" + + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + processMock "github.com/multiversx/mx-chain-go/process/mock" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/outport" + + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +var expErr = fmt.Errorf("expected error") + +func defaultSubroundStartRoundFromSubround(sr *spos.Subround) (v2.SubroundStartRound, error) { + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + + return startRound, err +} + +func defaultWithoutErrorSubroundStartRoundFromSubround(sr *spos.Subround) v2.SubroundStartRound { + startRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + + return startRound +} + +func defaultSubround( + consensusState *spos.ConsensusState, + ch chan bool, + container spos.ConsensusCoreHandler, +) (*spos.Subround, error) { + + return spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(0*roundTimeDuration/100), + int64(5*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) +} + +func initSubroundStartRoundWithContainer(container spos.ConsensusCoreHandler) v2.SubroundStartRound { + consensusState := initializers.InitConsensusState() + ch 
:= make(chan bool, 1) + sr, _ := defaultSubround(consensusState, ch, container) + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + + return srStartRound +} + +func initSubroundStartRound() v2.SubroundStartRound { + container := consensus.InitConsensusCore() + return initSubroundStartRoundWithContainer(container) +} + +func TestNewSubroundStartRound(t *testing.T) { + t.Parallel() + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + container := consensus.InitConsensusCore() + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + t.Run("nil subround should error", func(t *testing.T) { + t.Parallel() + + srStartRound, err := v2.NewSubroundStartRound( + nil, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilSubround, err) + }) + t.Run("nil sent signatures tracker should error", func(t *testing.T) { + t.Parallel() + + srStartRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + nil, + &consensus.SposWorkerMock{}, + ) + + assert.Nil(t, srStartRound) + assert.Equal(t, v2.ErrNilSentSignatureTracker, err) + }) + t.Run("nil worker should error", func(t *testing.T) { + t.Parallel() + + srStartRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + nil, + ) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilWorker, err) + }) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilBlockChainShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetBlockchain(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilBlockChain, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilBootstrapperShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetBootStrapper(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilBootstrapper, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + + sr.ConsensusStateHandler = nil + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := 
defaultSubround(consensusState, ch, container) + container.SetMultiSignerContainer(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetRoundHandler(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetSyncTimer(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilValidatorGroupSelectorShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetValidatorGroupSelector(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilNodesCoordinator, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundShouldWork(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.NotNil(t, srStartRound) + assert.Nil(t, err) +} + +func TestSubroundStartRound_DoStartRoundShouldReturnTrue(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + + srStartRound := defaultWithoutErrorSubroundStartRoundFromSubround(sr) + + r := srStartRound.DoStartRoundJob() + assert.True(t, r) +} + +func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { + t.Parallel() + + sr := initSubroundStartRound() + + sr.SetRoundCanceled(true) + + ok := sr.DoStartRoundConsensusCheck() + assert.False(t, ok) +} + +func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenRoundIsFinished(t *testing.T) { + t.Parallel() + + sr := initSubroundStartRound() + + sr.SetStatus(bls.SrStartRound, spos.SsFinished) + + ok := sr.DoStartRoundConsensusCheck() + assert.True(t, ok) +} + +func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenInitCurrentRoundReturnTrue(t *testing.T) { + t.Parallel() + + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { + return common.NsSynchronized + }} + + container := consensus.InitConsensusCore() + container.SetBootStrapper(bootstrapperMock) + + sr := initSubroundStartRoundWithContainer(container) + sentTrackerInterface := sr.GetSentSignatureTracker() + 
sentTracker := sentTrackerInterface.(*testscommon.SentSignatureTrackerStub) + startRoundCalled := false + sentTracker.StartRoundCalled = func() { + startRoundCalled = true + } + + ok := sr.DoStartRoundConsensusCheck() + assert.True(t, ok) + assert.True(t, startRoundCalled) +} + +func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenInitCurrentRoundReturnFalse(t *testing.T) { + t.Parallel() + + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { + return common.NsNotSynchronized + }} + + container := consensus.InitConsensusCore() + container.SetBootStrapper(bootstrapperMock) + container.SetRoundHandler(initRoundHandlerMock()) + + sr := initSubroundStartRoundWithContainer(container) + + ok := sr.DoStartRoundConsensusCheck() + assert.False(t, ok) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetNodeStateNotReturnSynchronized(t *testing.T) { + t.Parallel() + + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} + + bootstrapperMock.GetNodeStateCalled = func() common.NodeState { + return common.NsNotSynchronized + } + container := consensus.InitConsensusCore() + container.SetBootStrapper(bootstrapperMock) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.False(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGenerateNextConsensusGroupErr(t *testing.T) { + t.Parallel() + + validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + + validatorGroupSelector.ComputeValidatorsGroupCalled = func(bytes []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, expErr + } + container := consensus.InitConsensusCore() + + container.SetValidatorGroupSelector(validatorGroupSelector) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.False(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnTrueWhenMainMachineIsActive(t *testing.T) { + t.Parallel() + + nodeRedundancyMock := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return true + }, + } + container := consensus.InitConsensusCore() + container.SetNodeRedundancyHandler(nodeRedundancyMock) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.True(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetLeaderErr(t *testing.T) { + t.Parallel() + + validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + leader := &shardingMocks.ValidatorMock{PubKeyCalled: func() []byte { + return []byte("leader") + }} + + validatorGroupSelector.ComputeValidatorsGroupCalled = func( + bytes []byte, + round uint64, + shardId uint32, + epoch uint32, + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + // will cause an error in GetLeader because of empty consensus group + return leader, []nodesCoordinator.Validator{}, nil + } + + container := consensus.InitConsensusCore() + container.SetValidatorGroupSelector(validatorGroupSelector) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.False(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnTrueWhenIsNotInTheConsensusGroup(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + consensusState := 
initializers.InitConsensusState() + consensusState.SetSelfPubKey(consensusState.SelfPubKey() + "X") + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + + srStartRound := defaultWithoutErrorSubroundStartRoundFromSubround(sr) + + r := srStartRound.InitCurrentRound() + assert.True(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenTimeIsOut(t *testing.T) { + t.Parallel() + + roundHandlerMock := initRoundHandlerMock() + + roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { + return time.Duration(-1) + } + + container := consensus.InitConsensusCore() + container.SetRoundHandler(roundHandlerMock) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.False(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnTrue(t *testing.T) { + t.Parallel() + + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} + + bootstrapperMock.GetNodeStateCalled = func() common.NodeState { + return common.NsSynchronized + } + + container := consensus.InitConsensusCore() + container.SetBootStrapper(bootstrapperMock) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.True(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { + t.Parallel() + + t.Run("not in consensus node", func(t *testing.T) { + t.Parallel() + + wasCalled := false + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasCalled = true + assert.Equal(t, "not in consensus group", value) + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + consensusState.SetSelfPubKey("not in consensus") + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasCalled) + }) + t.Run("main key participant", func(t *testing.T) { + t.Parallel() + + wasCalled := false + wasIncrementCalled := false + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return string(pkBytes) == "B" + }, + } + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasCalled = true + assert.Equal(t, "participant", value) + } + }, + IncrementHandler: func(key string) { + if key == common.MetricCountConsensus { + wasIncrementCalled = true + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + consensusState.SetSelfPubKey("B") + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + 
appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasCalled) + assert.True(t, wasIncrementCalled) + }) + t.Run("multi key participant", func(t *testing.T) { + t.Parallel() + + wasCalled := false + wasIncrementCalled := false + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasCalled = true + assert.Equal(t, "participant", value) + } + }, + IncrementHandler: func(key string) { + if key == common.MetricCountConsensus { + wasIncrementCalled = true + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return string(pkBytes) == consensusState.SelfPubKey() + } + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasCalled) + assert.True(t, wasIncrementCalled) + }) + t.Run("main key leader", func(t *testing.T) { + t.Parallel() + + wasMetricConsensusStateCalled := false + wasMetricCountLeaderCalled := false + cntMetricConsensusRoundStateCalled := 0 + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasMetricConsensusStateCalled = true + assert.Equal(t, "proposer", value) + } + if key == common.MetricConsensusRoundState { + cntMetricConsensusRoundStateCalled++ + switch cntMetricConsensusRoundStateCalled { + case 1: + assert.Equal(t, "", value) + case 2: + assert.Equal(t, "proposed", value) + default: + assert.Fail(t, "should have been called only twice") + } + } + }, + IncrementHandler: func(key string) { + if key == common.MetricCountLeader { + wasMetricCountLeaderCalled = true + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + leader, _ := consensusState.GetLeader() + consensusState.SetSelfPubKey(leader) + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasMetricConsensusStateCalled) + assert.True(t, wasMetricCountLeaderCalled) + assert.Equal(t, 2, cntMetricConsensusRoundStateCalled) + }) + t.Run("managed key leader", func(t *testing.T) { + t.Parallel() + + wasMetricConsensusStateCalled := false + wasMetricCountLeaderCalled := false + cntMetricConsensusRoundStateCalled := 0 + 
container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasMetricConsensusStateCalled = true + assert.Equal(t, "proposer", value) + } + if key == common.MetricConsensusRoundState { + cntMetricConsensusRoundStateCalled++ + switch cntMetricConsensusRoundStateCalled { + case 1: + assert.Equal(t, "", value) + case 2: + assert.Equal(t, "proposed", value) + default: + assert.Fail(t, "should have been called only twice") + } + } + }, + IncrementHandler: func(key string) { + if key == common.MetricCountLeader { + wasMetricCountLeaderCalled = true + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + leader, _ := consensusState.GetLeader() + consensusState.SetSelfPubKey(leader) + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return string(pkBytes) == leader + } + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasMetricConsensusStateCalled) + assert.True(t, wasMetricCountLeaderCalled) + assert.Equal(t, 2, cntMetricConsensusRoundStateCalled) + }) +} + +func buildDefaultSubround(container spos.ConsensusCoreHandler) *spos.Subround { + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + return sr +} + +func TestSubroundStartRound_GenerateNextConsensusGroupShouldErrNilHeader(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + chainHandlerMock := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return nil + }, + } + + container.SetBlockchain(chainHandlerMock) + + sr := buildDefaultSubround(container) + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + err = startRound.GenerateNextConsensusGroup(0) + + assert.Equal(t, spos.ErrNilHeader, err) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenResetErr(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + signingHandlerMock := &consensus.SigningHandlerStub{ + ResetCalled: func(pubKeys []string) error { + return expErr + }, + } + + container.SetSigningHandler(signingHandlerMock) + + sr := buildDefaultSubround(container) + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + r := startRound.InitCurrentRound() + + assert.False(t, r) +} + +func TestSubroundStartRound_IndexRoundIfNeededFailShardIdForEpoch(t *testing.T) { + + pubKeys := []string{"testKey1", 
"testKey2"} + + container := consensus.InitConsensusCore() + + idVar := 0 + + container.SetShardCoordinator(&processMock.CoordinatorStub{ + SelfIdCalled: func() uint32 { + return uint32(idVar) + }, + }) + + container.SetValidatorGroupSelector( + &shardingMocks.NodesCoordinatorStub{ + ShardIdForEpochCalled: func(epoch uint32) (uint32, error) { + return 0, expErr + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + _ = startRound.SetOutportHandler(&outport.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveRoundsInfoCalled: func(roundsInfo *outportcore.RoundsInfo) { + require.Fail(t, "SaveRoundsInfo should not be called") + }, + }) + + startRound.IndexRoundIfNeeded(pubKeys) + +} + +func TestSubroundStartRound_IndexRoundIfNeededFailGetValidatorsIndexes(t *testing.T) { + + pubKeys := []string{"testKey1", "testKey2"} + + container := consensus.InitConsensusCore() + + idVar := 0 + + container.SetShardCoordinator(&processMock.CoordinatorStub{ + SelfIdCalled: func() uint32 { + return uint32(idVar) + }, + }) + + container.SetValidatorGroupSelector( + &shardingMocks.NodesCoordinatorStub{ + GetValidatorsIndexesCalled: func(pubKeys []string, epoch uint32) ([]uint64, error) { + return nil, expErr + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + _ = startRound.SetOutportHandler(&outport.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveRoundsInfoCalled: func(roundsInfo *outportcore.RoundsInfo) { + require.Fail(t, "SaveRoundsInfo should not be called") + }, + }) + + startRound.IndexRoundIfNeeded(pubKeys) + +} + +func TestSubroundStartRound_IndexRoundIfNeededShouldFullyWork(t *testing.T) { + + pubKeys := []string{"testKey1", "testKey2"} + + container := consensus.InitConsensusCore() + + idVar := 0 + + saveRoundInfoCalled := false + + container.SetShardCoordinator(&processMock.CoordinatorStub{ + SelfIdCalled: func() uint32 { + return uint32(idVar) + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + _ = startRound.SetOutportHandler(&outport.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveRoundsInfoCalled: func(roundsInfo *outportcore.RoundsInfo) { + saveRoundInfoCalled = true + }}) + + startRound.IndexRoundIfNeeded(pubKeys) + + assert.True(t, saveRoundInfoCalled) + +} + +func TestSubroundStartRound_IndexRoundIfNeededDifferentShardIdFail(t *testing.T) { + + pubKeys := []string{"testKey1", "testKey2"} + + container := consensus.InitConsensusCore() + + shardID := 1 + container.SetShardCoordinator(&processMock.CoordinatorStub{ + SelfIdCalled: func() uint32 { + return uint32(shardID) + }, + }) + + container.SetValidatorGroupSelector(&shardingMocks.NodesCoordinatorStub{ + ShardIdForEpochCalled: func(epoch uint32) (uint32, error) { + return 0, nil + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, 
err) + + _ = startRound.SetOutportHandler(&outport.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveRoundsInfoCalled: func(roundsInfo *outportcore.RoundsInfo) { + require.Fail(t, "SaveRoundsInfo should not be called") + }, + }) + + startRound.IndexRoundIfNeeded(pubKeys) + +} + +func TestSubroundStartRound_changeEpoch(t *testing.T) { + t.Parallel() + + expectPanic := func() { + if recover() == nil { + require.Fail(t, "expected panic") + } + } + + expectNoPanic := func() { + if recover() != nil { + require.Fail(t, "expected no panic") + } + } + + t.Run("error returned by nodes coordinator should panic", func(t *testing.T) { + t.Parallel() + + defer expectPanic() + + container := consensus.InitConsensusCore() + exErr := fmt.Errorf("expected error") + container.SetValidatorGroupSelector( + &shardingMocks.NodesCoordinatorStub{ + GetConsensusWhitelistedNodesCalled: func(epoch uint32) (map[string]struct{}, error) { + return nil, exErr + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + startRound.ChangeEpoch(1) + }) + t.Run("success - no panic", func(t *testing.T) { + t.Parallel() + + defer expectNoPanic() + + container := consensus.InitConsensusCore() + expectedKeys := map[string]struct{}{ + "aaa": {}, + "bbb": {}, + } + + container.SetValidatorGroupSelector( + &shardingMocks.NodesCoordinatorStub{ + GetConsensusWhitelistedNodesCalled: func(epoch uint32) (map[string]struct{}, error) { + return expectedKeys, nil + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + startRound.ChangeEpoch(1) + }) +} + +func TestSubroundStartRound_GenerateNextConsensusGroupShouldReturnErr(t *testing.T) { + t.Parallel() + + validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + + validatorGroupSelector.ComputeValidatorsGroupCalled = func( + bytes []byte, + round uint64, + shardId uint32, + epoch uint32, + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, expErr + } + container := consensus.InitConsensusCore() + container.SetValidatorGroupSelector(validatorGroupSelector) + + srStartRound := initSubroundStartRoundWithContainer(container) + + err2 := srStartRound.GenerateNextConsensusGroup(0) + + assert.Equal(t, expErr, err2) +} diff --git a/consensus/spos/consensusCore.go b/consensus/spos/consensusCore.go index 2cf7ca369d6..1f263a0af9d 100644 --- a/consensus/spos/consensusCore.go +++ b/consensus/spos/consensusCore.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/epochStart" @@ -38,6 +39,8 @@ type ConsensusCore struct { messageSigningHandler consensus.P2PSigningHandler peerBlacklistHandler consensus.PeerBlacklistHandler signingHandler consensus.SigningHandler + enableEpochsHandler common.EnableEpochsHandler + equivalentProofsPool consensus.EquivalentProofsPool } // ConsensusCoreArgs store all arguments that are needed to create a ConsensusCore object @@ -64,6 +67,8 @@
type ConsensusCoreArgs struct { MessageSigningHandler consensus.P2PSigningHandler PeerBlacklistHandler consensus.PeerBlacklistHandler SigningHandler consensus.SigningHandler + EnableEpochsHandler common.EnableEpochsHandler + EquivalentProofsPool consensus.EquivalentProofsPool } // NewConsensusCore creates a new ConsensusCore instance @@ -93,6 +98,8 @@ func NewConsensusCore( messageSigningHandler: args.MessageSigningHandler, peerBlacklistHandler: args.PeerBlacklistHandler, signingHandler: args.SigningHandler, + enableEpochsHandler: args.EnableEpochsHandler, + equivalentProofsPool: args.EquivalentProofsPool, } err := ValidateConsensusCore(consensusCore) @@ -213,6 +220,16 @@ func (cc *ConsensusCore) SigningHandler() consensus.SigningHandler { return cc.signingHandler } +// EnableEpochsHandler returns the enable epochs handler component +func (cc *ConsensusCore) EnableEpochsHandler() common.EnableEpochsHandler { + return cc.enableEpochsHandler +} + +// EquivalentProofsPool returns the equivalent proofs component +func (cc *ConsensusCore) EquivalentProofsPool() consensus.EquivalentProofsPool { + return cc.equivalentProofsPool +} + // IsInterfaceNil returns true if there is no value under the interface func (cc *ConsensusCore) IsInterfaceNil() bool { return cc == nil diff --git a/consensus/spos/consensusCoreValidator.go b/consensus/spos/consensusCoreValidator.go index 239c762f6d3..0eee3039007 100644 --- a/consensus/spos/consensusCoreValidator.go +++ b/consensus/spos/consensusCoreValidator.go @@ -74,6 +74,12 @@ func ValidateConsensusCore(container ConsensusCoreHandler) error { if check.IfNil(container.SigningHandler()) { return ErrNilSigningHandler } + if check.IfNil(container.EnableEpochsHandler()) { + return ErrNilEnableEpochsHandler + } + if check.IfNil(container.EquivalentProofsPool()) { + return ErrNilEquivalentProofPool + } return nil } diff --git a/consensus/spos/consensusCoreValidator_test.go b/consensus/spos/consensusCoreValidator_test.go index acdc008cbe8..5594b831311 100644 --- a/consensus/spos/consensusCoreValidator_test.go +++ b/consensus/spos/consensusCoreValidator_test.go @@ -3,31 +3,35 @@ package spos import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) func initConsensusDataContainer() *ConsensusCore { marshalizerMock := mock.MarshalizerMock{} blockChain := &testscommon.ChainHandlerStub{} - blockProcessorMock := mock.InitBlockProcessorMock(marshalizerMock) - bootstrapperMock := &mock.BootstrapperStub{} - broadcastMessengerMock := &mock.BroadcastMessengerMock{} - chronologyHandlerMock := mock.InitChronologyHandlerMock() + blockProcessorMock := consensusMocks.InitBlockProcessorMock(marshalizerMock) + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} + broadcastMessengerMock := &consensusMocks.BroadcastMessengerMock{} + chronologyHandlerMock := consensusMocks.InitChronologyHandlerMock() multiSignerMock := cryptoMocks.NewMultiSigner() hasherMock := &hashingMocks.HasherMock{} - 
roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} shardCoordinatorMock := mock.ShardCoordinatorMock{} - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} antifloodHandler := &mock.P2PAntifloodHandlerStub{} peerHonestyHandler := &testscommon.PeerHonestyHandlerStub{} - headerSigVerifier := &mock.HeaderSigVerifierStub{} + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{} fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{} nodeRedundancyHandler := &mock.NodeRedundancyHandlerStub{} scheduledProcessor := &consensusMocks.ScheduledProcessorStub{} @@ -35,6 +39,8 @@ func initConsensusDataContainer() *ConsensusCore { peerBlacklistHandler := &mock.PeerBlacklistHandlerStub{} multiSignerContainer := cryptoMocks.NewMultiSignerContainerMock(multiSignerMock) signingHandler := &consensusMocks.SigningHandlerStub{} + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} + proofsPool := &dataRetriever.ProofsPoolMock{} return &ConsensusCore{ blockChain: blockChain, @@ -58,6 +64,8 @@ func initConsensusDataContainer() *ConsensusCore { messageSigningHandler: messageSigningHandler, peerBlacklistHandler: peerBlacklistHandler, signingHandler: signingHandler, + enableEpochsHandler: enableEpochsHandler, + equivalentProofsPool: proofsPool, } } @@ -259,6 +267,17 @@ func TestConsensusContainerValidator_ValidateNilSignatureHandlerShouldFail(t *te assert.Equal(t, ErrNilSigningHandler, err) } +func TestConsensusContainerValidator_ValidateNilEnableEpochsHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.enableEpochsHandler = nil + + err := ValidateConsensusCore(container) + + assert.Equal(t, ErrNilEnableEpochsHandler, err) +} + func TestConsensusContainerValidator_ShouldWork(t *testing.T) { t.Parallel() diff --git a/consensus/spos/consensusCore_test.go b/consensus/spos/consensusCore_test.go index 2fd67a2cb63..ef860956152 100644 --- a/consensus/spos/consensusCore_test.go +++ b/consensus/spos/consensusCore_test.go @@ -3,15 +3,15 @@ package spos_test import ( "testing" - "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" - "github.com/stretchr/testify/assert" ) func createDefaultConsensusCoreArgs() *spos.ConsensusCoreArgs { - consensusCoreMock := mock.InitConsensusCore() + consensusCoreMock := consensus.InitConsensusCore() scheduledProcessor := &consensus.ScheduledProcessorStub{} @@ -38,6 +38,8 @@ func createDefaultConsensusCoreArgs() *spos.ConsensusCoreArgs { MessageSigningHandler: consensusCoreMock.MessageSigningHandler(), PeerBlacklistHandler: consensusCoreMock.PeerBlacklistHandler(), SigningHandler: consensusCoreMock.SigningHandler(), + EnableEpochsHandler: consensusCoreMock.EnableEpochsHandler(), + EquivalentProofsPool: consensusCoreMock.EquivalentProofsPool(), } return args } @@ -334,6 +336,20 @@ func TestConsensusCore_WithNilPeerBlacklistHandlerShouldFail(t *testing.T) { assert.Equal(t, spos.ErrNilPeerBlacklistHandler, err) } +func TestConsensusCore_WithNilEnableEpochsHandlerShouldFail(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusCoreArgs() + args.EnableEpochsHandler = nil + + consensusCore, err := spos.NewConsensusCore( + args, + ) + + 
assert.Nil(t, consensusCore) + assert.Equal(t, spos.ErrNilEnableEpochsHandler, err) +} + func TestConsensusCore_CreateConsensusCoreShouldWork(t *testing.T) { t.Parallel() diff --git a/consensus/spos/consensusMessageValidator.go b/consensus/spos/consensusMessageValidator.go index 67fa9616e07..cdcf507cbbf 100644 --- a/consensus/spos/consensusMessageValidator.go +++ b/consensus/spos/consensusMessageValidator.go @@ -7,9 +7,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/sharding" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -17,6 +21,9 @@ type consensusMessageValidator struct { consensusState *ConsensusState consensusService ConsensusService peerSignatureHandler crypto.PeerSignatureHandler + enableEpochsHandler common.EnableEpochsHandler + marshaller marshal.Marshalizer + shardCoordinator sharding.Coordinator signatureSize int publicKeySize int @@ -33,6 +40,9 @@ type ArgsConsensusMessageValidator struct { ConsensusState *ConsensusState ConsensusService ConsensusService PeerSignatureHandler crypto.PeerSignatureHandler + EnableEpochsHandler common.EnableEpochsHandler + Marshaller marshal.Marshalizer + ShardCoordinator sharding.Coordinator SignatureSize int PublicKeySize int HeaderHashSize int @@ -50,6 +60,9 @@ func NewConsensusMessageValidator(args ArgsConsensusMessageValidator) (*consensu consensusState: args.ConsensusState, consensusService: args.ConsensusService, peerSignatureHandler: args.PeerSignatureHandler, + enableEpochsHandler: args.EnableEpochsHandler, + marshaller: args.Marshaller, + shardCoordinator: args.ShardCoordinator, signatureSize: args.SignatureSize, publicKeySize: args.PublicKeySize, chainID: args.ChainID, @@ -69,6 +82,15 @@ func checkArgsConsensusMessageValidator(args ArgsConsensusMessageValidator) erro if check.IfNil(args.PeerSignatureHandler) { return ErrNilPeerSignatureHandler } + if check.IfNil(args.EnableEpochsHandler) { + return ErrNilEnableEpochsHandler + } + if check.IfNil(args.Marshaller) { + return ErrNilMarshalizer + } + if check.IfNil(args.ShardCoordinator) { + return ErrNilShardCoordinator + } if args.ConsensusState == nil { return ErrNilConsensusState } @@ -137,13 +159,13 @@ func (cmv *consensusMessageValidator) checkConsensusMessageValidity(cnsMsg *cons msgType := consensus.MessageType(cnsMsg.MsgType) - if cmv.consensusState.RoundIndex+1 < cnsMsg.RoundIndex { + if cmv.consensusState.GetRoundIndex()+1 < cnsMsg.RoundIndex { log.Trace("received message from consensus topic has a future round", "msg type", cmv.consensusService.GetStringValue(msgType), "from", cnsMsg.PubKey, "header hash", cnsMsg.BlockHeaderHash, "msg round", cnsMsg.RoundIndex, - "round", cmv.consensusState.RoundIndex, + "round", cmv.consensusState.GetRoundIndex(), ) return fmt.Errorf("%w : received message from consensus topic has a future round: %d", @@ -151,13 +173,13 @@ func (cmv *consensusMessageValidator) checkConsensusMessageValidity(cnsMsg *cons cnsMsg.RoundIndex) } - if cmv.consensusState.RoundIndex > cnsMsg.RoundIndex { + if cmv.consensusState.GetRoundIndex() > cnsMsg.RoundIndex { log.Trace("received message from consensus topic has a past round", "msg type", cmv.consensusService.GetStringValue(msgType), 
"from", cnsMsg.PubKey, "header hash", cnsMsg.BlockHeaderHash, "msg round", cnsMsg.RoundIndex, - "round", cmv.consensusState.RoundIndex, + "round", cmv.consensusState.GetRoundIndex(), ) return fmt.Errorf("%w : received message from consensus topic has a past round: %d", @@ -239,7 +261,19 @@ func (cmv *consensusMessageValidator) checkConsensusMessageValidityForMessageTyp } func (cmv *consensusMessageValidator) checkMessageWithBlockBodyAndHeaderValidity(cnsMsg *consensus.Message) error { - isMessageInvalid := cnsMsg.SignatureShare != nil || + // TODO[cleanup cns finality]: remove this + isInvalidSigShare := cnsMsg.SignatureShare != nil + + header, err := process.UnmarshalHeader(cmv.shardCoordinator.SelfId(), cmv.marshaller, cnsMsg.Header) + if err != nil { + return err + } + + if cmv.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + isInvalidSigShare = cnsMsg.SignatureShare == nil + } + + isMessageInvalid := isInvalidSigShare || cnsMsg.PubKeysBitmap != nil || cnsMsg.AggregateSignature != nil || cnsMsg.LeaderSignature != nil || @@ -306,8 +340,19 @@ func (cmv *consensusMessageValidator) checkMessageWithBlockBodyValidity(cnsMsg * } func (cmv *consensusMessageValidator) checkMessageWithBlockHeaderValidity(cnsMsg *consensus.Message) error { + // TODO[cleanup cns finality]: remove this + isInvalidSigShare := cnsMsg.SignatureShare != nil + + header, err := process.UnmarshalHeader(cmv.shardCoordinator.SelfId(), cmv.marshaller, cnsMsg.Header) + if err != nil { + return err + } + + if cmv.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + isInvalidSigShare = cnsMsg.SignatureShare == nil + } isMessageInvalid := cnsMsg.Body != nil || - cnsMsg.SignatureShare != nil || + isInvalidSigShare || cnsMsg.PubKeysBitmap != nil || cnsMsg.AggregateSignature != nil || cnsMsg.LeaderSignature != nil || @@ -398,6 +443,11 @@ func (cmv *consensusMessageValidator) checkMessageWithFinalInfoValidity(cnsMsg * len(cnsMsg.AggregateSignature)) } + // TODO[cleanup cns finality]: remove this + if cmv.shouldNotVerifyLeaderSignature() { + return nil + } + if len(cnsMsg.LeaderSignature) != cmv.signatureSize { return fmt.Errorf("%w : received leader signature from consensus topic has an invalid size: %d", ErrInvalidSignatureSize, @@ -407,6 +457,16 @@ func (cmv *consensusMessageValidator) checkMessageWithFinalInfoValidity(cnsMsg * return nil } +func (cmv *consensusMessageValidator) shouldNotVerifyLeaderSignature() bool { + // TODO: this check needs to be removed when equivalent messages are sent separately from the final info + if check.IfNil(cmv.consensusState.Header) { + return true + } + + return cmv.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, cmv.consensusState.Header.GetEpoch()) + +} + func (cmv *consensusMessageValidator) checkMessageWithInvalidSingersValidity(cnsMsg *consensus.Message) error { isMessageInvalid := cnsMsg.SignatureShare != nil || cnsMsg.Body != nil || diff --git a/consensus/spos/consensusMessageValidator_test.go b/consensus/spos/consensusMessageValidator_test.go index 33c37ea4e70..ef46fc9b75e 100644 --- a/consensus/spos/consensusMessageValidator_test.go +++ b/consensus/spos/consensusMessageValidator_test.go @@ -6,13 +6,20 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common" 
"github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/testscommon" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" ) func createDefaultConsensusMessageValidatorArgs() spos.ArgsConsensusMessageValidator { @@ -26,7 +33,7 @@ func createDefaultConsensusMessageValidatorArgs() spos.ArgsConsensusMessageValid return nil }, } - keyGeneratorMock, _, _ := mock.InitKeys() + keyGeneratorMock, _, _ := testscommonConsensus.InitKeys() peerSigHandler := &mock.PeerSignatureHandler{Signer: singleSignerMock, KeyGen: keyGeneratorMock} hasher := &hashingMocks.HasherMock{} @@ -34,6 +41,9 @@ func createDefaultConsensusMessageValidatorArgs() spos.ArgsConsensusMessageValid ConsensusState: consensusState, ConsensusService: blsService, PeerSignatureHandler: peerSigHandler, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + Marshaller: &marshallerMock.MarshalizerStub{}, + ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, SignatureSize: SignatureSize, PublicKeySize: PublicKeySize, HeaderHashSize: hasher.Size(), @@ -64,6 +74,36 @@ func TestNewConsensusMessageValidator(t *testing.T) { assert.Nil(t, validator) assert.Equal(t, spos.ErrNilPeerSignatureHandler, err) }) + t.Run("nil EnableEpochsHandler", func(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusMessageValidatorArgs() + args.EnableEpochsHandler = nil + validator, err := spos.NewConsensusMessageValidator(args) + + assert.Nil(t, validator) + assert.Equal(t, spos.ErrNilEnableEpochsHandler, err) + }) + t.Run("nil Marshaller", func(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusMessageValidatorArgs() + args.Marshaller = nil + validator, err := spos.NewConsensusMessageValidator(args) + + assert.Nil(t, validator) + assert.Equal(t, spos.ErrNilMarshalizer, err) + }) + t.Run("nil ShardCoordinator", func(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusMessageValidatorArgs() + args.ShardCoordinator = nil + validator, err := spos.NewConsensusMessageValidator(args) + + assert.Nil(t, validator) + assert.Equal(t, spos.ErrNilShardCoordinator, err) + }) t.Run("nil ConsensusState", func(t *testing.T) { t.Parallel() @@ -179,17 +219,55 @@ func TestCheckMessageWithFinalInfoValidity_InvalidAggregateSignatureSize(t *test assert.True(t, errors.Is(err, spos.ErrInvalidSignatureSize)) } -func TestCheckMessageWithFinalInfoValidity_InvalidLeaderSignatureSize(t *testing.T) { +func TestCheckMessageWithFinalInfo_LeaderSignatureCheck(t *testing.T) { t.Parallel() - consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() - cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + t.Run("should fail", func(t *testing.T) { + t.Parallel() - sig := make([]byte, SignatureSize) - _, _ = rand.Read(sig) - cnsMsg := &consensus.Message{PubKeysBitmap: []byte("01"), AggregateSignature: sig, LeaderSignature: []byte("0")} - err := cmv.CheckMessageWithFinalInfoValidity(cnsMsg) - assert.True(t, errors.Is(err, spos.ErrInvalidSignatureSize)) + consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() + 
consensusMessageValidatorArgs.ConsensusState.Header = &block.Header{Epoch: 2} + + sigSize := SignatureSize + consensusMessageValidatorArgs.SignatureSize = sigSize // expected signature size; the leader signature below is intentionally one byte shorter + + cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + + cnsMsg := &consensus.Message{ + MsgType: int64(bls.MtBlockHeaderFinalInfo), + AggregateSignature: make([]byte, SignatureSize), + LeaderSignature: make([]byte, SignatureSize-1), + PubKeysBitmap: []byte("11"), + } + err := cmv.CheckConsensusMessageValidityForMessageType(cnsMsg) + assert.NotNil(t, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() + consensusMessageValidatorArgs.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + consensusMessageValidatorArgs.ConsensusState.Header = &block.Header{Epoch: 2} + + sigSize := SignatureSize + consensusMessageValidatorArgs.SignatureSize = sigSize // expected signature size; the leader signature size check is skipped once the equivalent messages flag is enabled + + cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + + cnsMsg := &consensus.Message{ + MsgType: int64(bls.MtBlockHeaderFinalInfo), + AggregateSignature: make([]byte, SignatureSize), + LeaderSignature: make([]byte, SignatureSize-1), + PubKeysBitmap: []byte("11"), + } + err := cmv.CheckConsensusMessageValidityForMessageType(cnsMsg) + assert.Nil(t, err) + }) +} func TestCheckMessageWithFinalInfoValidity_ShouldWork(t *testing.T) { @@ -337,6 +415,22 @@ func TestCheckMessageWithBlockBodyValidity_ShouldWork(t *testing.T) { assert.Nil(t, err) } +func TestCheckMessageWithBlockBodyAndHeaderValidity_NilSigShareAfterActivation(t *testing.T) { + t.Parallel() + + consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() + consensusMessageValidatorArgs.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + + cnsMsg := &consensus.Message{SignatureShare: nil} + err := cmv.CheckMessageWithBlockBodyAndHeaderValidity(cnsMsg) + assert.True(t, errors.Is(err, spos.ErrInvalidMessage)) +} + func TestCheckMessageWithBlockBodyAndHeaderValidity_InvalidMessage(t *testing.T) { t.Parallel() @@ -420,6 +514,22 @@ func TestCheckConsensusMessageValidityForMessageType_MessageWithBlockHeaderInval assert.True(t, errors.Is(err, spos.ErrInvalidMessage)) } +func TestCheckConsensusMessageValidityForMessageType_MessageWithBlockHeaderInvalidAfterFlag(t *testing.T) { + t.Parallel() + + consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() + consensusMessageValidatorArgs.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + + cnsMsg := &consensus.Message{MsgType: int64(bls.MtBlockHeader), SignatureShare: nil} + err := cmv.CheckConsensusMessageValidityForMessageType(cnsMsg) + assert.True(t, errors.Is(err, spos.ErrInvalidMessage)) +} + func TestCheckConsensusMessageValidityForMessageType_MessageWithSignatureInvalid(t *testing.T) { t.Parallel() @@ -655,7 +765,7 @@ func
TestCheckConsensusMessageValidity_ErrMessageForPastRound(t *testing.T) { t.Parallel() consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() - consensusMessageValidatorArgs.ConsensusState.RoundIndex = 100 + consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(100) cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) headerBytes := make([]byte, 100) @@ -678,7 +788,7 @@ func TestCheckConsensusMessageValidity_ErrMessageTypeLimitReached(t *testing.T) t.Parallel() consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() - consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10 + consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10) cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) pubKey := []byte(consensusMessageValidatorArgs.ConsensusState.ConsensusGroup()[0]) @@ -724,7 +834,7 @@ func createMockConsensusMessage(args spos.ArgsConsensusMessageValidator, pubKey MsgType: int64(msgType), PubKey: pubKey, Signature: createDummyByteSlice(SignatureSize), - RoundIndex: args.ConsensusState.RoundIndex, + RoundIndex: args.ConsensusState.GetRoundIndex(), BlockHeaderHash: createDummyByteSlice(args.HeaderHashSize), } } @@ -743,7 +853,7 @@ func TestCheckConsensusMessageValidity_InvalidSignature(t *testing.T) { consensusMessageValidatorArgs.PeerSignatureHandler = &mock.PeerSignatureHandler{ Signer: signer, } - consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10 + consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10) cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) headerBytes := make([]byte, 100) @@ -766,7 +876,7 @@ func TestCheckConsensusMessageValidity_Ok(t *testing.T) { t.Parallel() consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() - consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10 + consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10) cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) headerBytes := make([]byte, 100) diff --git a/consensus/spos/consensusState.go b/consensus/spos/consensusState.go index 564b3def852..8904717b7ea 100644 --- a/consensus/spos/consensusState.go +++ b/consensus/spos/consensusState.go @@ -7,15 +7,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - logger "github.com/multiversx/mx-chain-logger-go" ) -// IndexOfLeaderInConsensusGroup represents the index of the leader in the consensus group -const IndexOfLeaderInConsensusGroup = 0 - var log = logger.GetOrCreate("consensus/spos") // ConsensusState defines the data needed by spos to do the consensus in each round @@ -44,6 +42,8 @@ type ConsensusState struct { *roundConsensus *roundThreshold *roundStatus + + mutState sync.RWMutex } // NewConsensusState creates a new ConsensusState object @@ -136,11 +136,6 @@ func (cns *ConsensusState) IsNodeLeaderInCurrentRound(node string) bool { return leader == node } -// IsSelfLeaderInCurrentRound method checks if the current node is leader in the current round -func (cns *ConsensusState) IsSelfLeaderInCurrentRound() bool { - return cns.IsNodeLeaderInCurrentRound(cns.selfPubKey) -} - // GetLeader method gets the leader of the current round func (cns *ConsensusState) GetLeader() (string, error) { if cns.consensusGroup == 
nil { @@ -151,7 +146,7 @@ func (cns *ConsensusState) GetLeader() (string, error) { return "", ErrEmptyConsensusGroup } - return cns.consensusGroup[IndexOfLeaderInConsensusGroup], nil + return cns.Leader(), nil } // GetNextConsensusGroup gets the new consensus group for the current round based on current eligible list and a random @@ -162,8 +157,8 @@ func (cns *ConsensusState) GetNextConsensusGroup( shardId uint32, nodesCoordinator nodesCoordinator.NodesCoordinator, epoch uint32, -) ([]string, error) { - validatorsGroup, err := nodesCoordinator.ComputeConsensusGroup(randomSource, round, shardId, epoch) +) (string, []string, error) { + leader, validatorsGroup, err := nodesCoordinator.ComputeConsensusGroup(randomSource, round, shardId, epoch) if err != nil { log.Debug( "compute consensus group", @@ -173,7 +168,7 @@ func (cns *ConsensusState) GetNextConsensusGroup( "shardId", shardId, "epoch", epoch, ) - return nil, err + return "", nil, err } consensusSize := len(validatorsGroup) @@ -183,7 +178,7 @@ func (cns *ConsensusState) GetNextConsensusGroup( newConsensusGroup[i] = string(validatorsGroup[i].PubKey()) } - return newConsensusGroup, nil + return string(leader.PubKey()), newConsensusGroup, nil } // IsConsensusDataSet method returns true if the consensus data for the current round is set and false otherwise @@ -212,11 +207,6 @@ func (cns *ConsensusState) IsJobDone(node string, currentSubroundId int) bool { return jobDone } -// IsSelfJobDone method returns true if self job for the current subround is done and false otherwise -func (cns *ConsensusState) IsSelfJobDone(currentSubroundId int) bool { - return cns.IsJobDone(cns.selfPubKey, currentSubroundId) -} - // IsSubroundFinished method returns true if the current subround is finished and false otherwise func (cns *ConsensusState) IsSubroundFinished(subroundID int) bool { isSubroundFinished := cns.Status(subroundID) == SsFinished @@ -251,16 +241,7 @@ func (cns *ConsensusState) CanDoSubroundJob(currentSubroundId int) bool { return false } - selfJobDone := true - if cns.IsNodeInConsensusGroup(cns.SelfPubKey()) { - selfJobDone = cns.IsSelfJobDone(currentSubroundId) - } - multiKeyJobDone := true - if cns.IsMultiKeyInConsensusGroup() { - multiKeyJobDone = cns.IsMultiKeyJobDone(currentSubroundId) - } - - if selfJobDone && multiKeyJobDone { + if cns.IsSelfJobDone(currentSubroundId) { return false } @@ -341,6 +322,11 @@ func (cns *ConsensusState) GetData() []byte { return cns.Data } +// SetData sets the Data of the consensusState +func (cns *ConsensusState) SetData(data []byte) { + cns.Data = data +} + // IsMultiKeyLeaderInCurrentRound method checks if one of the nodes which are controlled by this instance // is leader in the current round func (cns *ConsensusState) IsMultiKeyLeaderInCurrentRound() bool { @@ -350,7 +336,7 @@ func (cns *ConsensusState) IsMultiKeyLeaderInCurrentRound() bool { return false } - return cns.IsKeyManagedByCurrentNode([]byte(leader)) + return cns.IsKeyManagedBySelf([]byte(leader)) } // IsLeaderJobDone method returns true if the leader job for the current subround is done and false otherwise @@ -380,6 +366,21 @@ func (cns *ConsensusState) IsMultiKeyJobDone(currentSubroundId int) bool { return true } +// IsSelfJobDone method returns true if self job for the current subround is done and false otherwise +func (cns *ConsensusState) IsSelfJobDone(currentSubroundID int) bool { + selfJobDone := true + if cns.IsNodeInConsensusGroup(cns.SelfPubKey()) { + selfJobDone = cns.IsJobDone(cns.SelfPubKey(), currentSubroundID) + } + + 
multiKeyJobDone := true + if cns.IsMultiKeyInConsensusGroup() { + multiKeyJobDone = cns.IsMultiKeyJobDone(currentSubroundID) + } + + return selfJobDone && multiKeyJobDone +} + // GetMultikeyRedundancyStepInReason returns the reason if the current node stepped in as a multikey redundancy node func (cns *ConsensusState) GetMultikeyRedundancyStepInReason() string { return cns.keysHandler.GetRedundancyStepInReason() @@ -390,3 +391,96 @@ func (cns *ConsensusState) GetMultikeyRedundancyStepInReason() string { func (cns *ConsensusState) ResetRoundsWithoutReceivedMessages(pkBytes []byte, pid core.PeerID) { cns.keysHandler.ResetRoundsWithoutReceivedMessages(pkBytes, pid) } + +// GetRoundCanceled returns the state of the current round +func (cns *ConsensusState) GetRoundCanceled() bool { + cns.mutState.RLock() + defer cns.mutState.RUnlock() + + return cns.RoundCanceled +} + +// SetRoundCanceled sets the state of the current round +func (cns *ConsensusState) SetRoundCanceled(roundCanceled bool) { + cns.mutState.Lock() + defer cns.mutState.Unlock() + + cns.RoundCanceled = roundCanceled +} + +// GetRoundIndex returns the index of the current round +func (cns *ConsensusState) GetRoundIndex() int64 { + cns.mutState.RLock() + defer cns.mutState.RUnlock() + + return cns.RoundIndex +} + +// SetRoundIndex sets the index of the current round +func (cns *ConsensusState) SetRoundIndex(roundIndex int64) { + cns.mutState.Lock() + defer cns.mutState.Unlock() + + cns.RoundIndex = roundIndex +} + +// GetRoundTimeStamp returns the time stamp of the current round +func (cns *ConsensusState) GetRoundTimeStamp() time.Time { + return cns.RoundTimeStamp +} + +// SetRoundTimeStamp sets the time stamp of the current round +func (cns *ConsensusState) SetRoundTimeStamp(roundTimeStamp time.Time) { + cns.RoundTimeStamp = roundTimeStamp +} + +// GetExtendedCalled returns the state of the extended called +func (cns *ConsensusState) GetExtendedCalled() bool { + return cns.ExtendedCalled +} + +// SetExtendedCalled sets the state of the extended called +func (cns *ConsensusState) SetExtendedCalled(extendedCalled bool) { + cns.ExtendedCalled = extendedCalled +} + +// GetBody returns the body of the current round +func (cns *ConsensusState) GetBody() data.BodyHandler { + return cns.Body +} + +// SetBody sets the body of the current round +func (cns *ConsensusState) SetBody(body data.BodyHandler) { + cns.Body = body +} + +// GetHeader returns the header of the current round +func (cns *ConsensusState) GetHeader() data.HeaderHandler { + return cns.Header +} + +// GetWaitingAllSignaturesTimeOut returns the state of the waiting all signatures time out +func (cns *ConsensusState) GetWaitingAllSignaturesTimeOut() bool { + cns.mutState.RLock() + defer cns.mutState.RUnlock() + + return cns.WaitingAllSignaturesTimeOut +} + +// SetWaitingAllSignaturesTimeOut sets the state of the waiting all signatures time out +func (cns *ConsensusState) SetWaitingAllSignaturesTimeOut(waitingAllSignaturesTimeOut bool) { + cns.mutState.Lock() + defer cns.mutState.Unlock() + + cns.WaitingAllSignaturesTimeOut = waitingAllSignaturesTimeOut +} + +// SetHeader sets the header of the current round +func (cns *ConsensusState) SetHeader(header data.HeaderHandler) { + cns.Header = header +} + +// IsInterfaceNil returns true if there is no value under the interface +func (cns *ConsensusState) IsInterfaceNil() bool { + return cns == nil +} diff --git a/consensus/spos/consensusState_test.go b/consensus/spos/consensusState_test.go index 554c9c0c755..6125c4091c4 100644 --- 
a/consensus/spos/consensusState_test.go +++ b/consensus/spos/consensusState_test.go @@ -7,13 +7,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) func internalInitConsensusState() *spos.ConsensusState { @@ -36,6 +37,7 @@ func internalInitConsensusStateWithKeysHandler(keysHandler consensus.KeysHandler ) rcns.SetConsensusGroup(eligibleList) + rcns.SetLeader(eligibleList[0]) rcns.ResetRoundState() rthr := spos.NewRoundThreshold() @@ -68,12 +70,12 @@ func TestConsensusState_ResetConsensusStateShouldWork(t *testing.T) { t.Parallel() cns := internalInitConsensusState() - cns.RoundCanceled = true - cns.ExtendedCalled = true - cns.WaitingAllSignaturesTimeOut = true + cns.SetRoundCanceled(true) + cns.SetExtendedCalled(true) + cns.SetWaitingAllSignaturesTimeOut(true) cns.ResetConsensusState() assert.False(t, cns.RoundCanceled) - assert.False(t, cns.ExtendedCalled) + assert.False(t, cns.GetExtendedCalled()) assert.False(t, cns.WaitingAllSignaturesTimeOut) } @@ -102,22 +104,6 @@ func TestConsensusState_IsNodeLeaderInCurrentRoundShouldReturnTrue(t *testing.T) assert.Equal(t, true, cns.IsNodeLeaderInCurrentRound("1")) } -func TestConsensusState_IsSelfLeaderInCurrentRoundShouldReturnFalse(t *testing.T) { - t.Parallel() - - cns := internalInitConsensusState() - - assert.False(t, cns.IsSelfLeaderInCurrentRound()) -} - -func TestConsensusState_IsSelfLeaderInCurrentRoundShouldReturnTrue(t *testing.T) { - t.Parallel() - - cns := internalInitConsensusState() - - assert.False(t, cns.IsSelfLeaderInCurrentRound()) -} - func TestConsensusState_GetLeaderShoudErrNilConsensusGroup(t *testing.T) { t.Parallel() @@ -162,11 +148,11 @@ func TestConsensusState_GetNextConsensusGroupShouldFailWhenComputeValidatorsGrou round uint64, shardId uint32, epoch uint32, - ) ([]nodesCoordinator.Validator, error) { - return nil, err + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, err } - _, err2 := cns.GetNextConsensusGroup([]byte(""), 0, 0, nodesCoord, 0) + _, _, err2 := cns.GetNextConsensusGroup([]byte(""), 0, 0, nodesCoord, 0) assert.Equal(t, err, err2) } @@ -176,10 +162,11 @@ func TestConsensusState_GetNextConsensusGroupShouldWork(t *testing.T) { cns := internalInitConsensusState() nodesCoord := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]nodesCoordinator.Validator, error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { defaultSelectionChances := uint32(1) - return []nodesCoordinator.Validator{ - shardingMocks.NewValidatorMock([]byte("A"), 1, defaultSelectionChances), + leader := shardingMocks.NewValidatorMock([]byte("A"), 1, defaultSelectionChances) + return leader, []nodesCoordinator.Validator{ + leader, shardingMocks.NewValidatorMock([]byte("B"), 1, defaultSelectionChances), shardingMocks.NewValidatorMock([]byte("C"), 1, defaultSelectionChances), shardingMocks.NewValidatorMock([]byte("D"), 1, 
defaultSelectionChances), @@ -192,9 +179,10 @@ func TestConsensusState_GetNextConsensusGroupShouldWork(t *testing.T) { }, } - nextConsensusGroup, err := cns.GetNextConsensusGroup(nil, 0, 0, nodesCoord, 0) + leader, nextConsensusGroup, err := cns.GetNextConsensusGroup(nil, 0, 0, nodesCoord, 0) assert.Nil(t, err) assert.NotNil(t, nextConsensusGroup) + assert.NotEmpty(t, leader) } func TestConsensusState_IsConsensusDataSetShouldReturnTrue(t *testing.T) { diff --git a/consensus/spos/errors.go b/consensus/spos/errors.go index 3aeac029da3..62f9c23ad17 100644 --- a/consensus/spos/errors.go +++ b/consensus/spos/errors.go @@ -243,3 +243,30 @@ var ErrNilFunctionHandler = errors.New("nil function handler") // ErrWrongHashForHeader signals that the hash of the header is not the expected one var ErrWrongHashForHeader = errors.New("wrong hash for header") + +// ErrNilSentSignatureTracker defines the error for setting a nil SentSignatureTracker +var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") + +// ErrEquivalentMessageAlreadyReceived signals that an equivalent message has been already received +var ErrEquivalentMessageAlreadyReceived = errors.New("equivalent message already received") + +// ErrNilEnableEpochsHandler signals that a nil enable epochs handler has been provided +var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") + +// ErrNilThrottler signals that a nil throttler has been provided +var ErrNilThrottler = errors.New("nil throttler") + +// ErrTimeIsOut signals that time is out +var ErrTimeIsOut = errors.New("time is out") + +// ErrNilEquivalentProofPool signals that a nil proof pool has been provided +var ErrNilEquivalentProofPool = errors.New("nil equivalent proof pool") + +// ErrNilHeaderProof signals that a nil header proof has been provided +var ErrNilHeaderProof = errors.New("nil header proof") + +// ErrHeaderProofNotExpected signals that a header proof was not expected +var ErrHeaderProofNotExpected = errors.New("header proof not expected") + +// ErrConsensusMessageNotExpected signals that a consensus message was not expected +var ErrConsensusMessageNotExpected = errors.New("consensus message not expected") diff --git a/consensus/spos/export_test.go b/consensus/spos/export_test.go index 39d19de6e30..1ad0bbc67d5 100644 --- a/consensus/spos/export_test.go +++ b/consensus/spos/export_test.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" ) @@ -13,6 +14,12 @@ import ( // RedundancySingleKeySteppedIn exposes the redundancySingleKeySteppedIn constant const RedundancySingleKeySteppedIn = redundancySingleKeySteppedIn +// LeaderSingleKeyStartMsg - +const LeaderSingleKeyStartMsg = singleKeyStartMsg + +// LeaderMultiKeyStartMsg - +const LeaderMultiKeyStartMsg = multiKeyStartMsg + type RoundConsensus struct { *roundConsensus } @@ -142,17 +149,17 @@ func (wrk *Worker) NilReceivedMessages() { } // ReceivedMessagesCalls - -func (wrk *Worker) ReceivedMessagesCalls() map[consensus.MessageType]func(context.Context, *consensus.Message) bool { +func (wrk *Worker) ReceivedMessagesCalls() map[consensus.MessageType][]func(context.Context, *consensus.Message) bool { wrk.mutReceivedMessagesCalls.RLock() defer wrk.mutReceivedMessagesCalls.RUnlock() return wrk.receivedMessagesCalls } -// SetReceivedMessagesCalls - -func (wrk *Worker) SetReceivedMessagesCalls(messageType consensus.MessageType, f 
func(context.Context, *consensus.Message) bool) { +// AppendReceivedMessagesCalls - +func (wrk *Worker) AppendReceivedMessagesCalls(messageType consensus.MessageType, f func(context.Context, *consensus.Message) bool) { wrk.mutReceivedMessagesCalls.Lock() - wrk.receivedMessagesCalls[messageType] = f + wrk.receivedMessagesCalls[messageType] = append(wrk.receivedMessagesCalls[messageType], f) wrk.mutReceivedMessagesCalls.Unlock() } diff --git a/consensus/spos/interface.go b/consensus/spos/interface.go index 0ca771d30e5..d85c94f2b7a 100644 --- a/consensus/spos/interface.go +++ b/consensus/spos/interface.go @@ -9,6 +9,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/outport" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/epochStart" @@ -21,51 +23,30 @@ import ( // ConsensusCoreHandler encapsulates all needed data for the Consensus type ConsensusCoreHandler interface { - // Blockchain gets the ChainHandler stored in the ConsensusCore Blockchain() data.ChainHandler - // BlockProcessor gets the BlockProcessor stored in the ConsensusCore BlockProcessor() process.BlockProcessor - // BootStrapper gets the Bootstrapper stored in the ConsensusCore BootStrapper() process.Bootstrapper - // BroadcastMessenger gets the BroadcastMessenger stored in ConsensusCore BroadcastMessenger() consensus.BroadcastMessenger - // Chronology gets the ChronologyHandler stored in the ConsensusCore Chronology() consensus.ChronologyHandler - // GetAntiFloodHandler returns the antiflood handler which will be used in subrounds GetAntiFloodHandler() consensus.P2PAntifloodHandler - // Hasher gets the Hasher stored in the ConsensusCore Hasher() hashing.Hasher - // Marshalizer gets the Marshalizer stored in the ConsensusCore Marshalizer() marshal.Marshalizer - // MultiSignerContainer gets the MultiSigner container from the ConsensusCore MultiSignerContainer() cryptoCommon.MultiSignerContainer - // RoundHandler gets the RoundHandler stored in the ConsensusCore RoundHandler() consensus.RoundHandler - // ShardCoordinator gets the ShardCoordinator stored in the ConsensusCore ShardCoordinator() sharding.Coordinator - // SyncTimer gets the SyncTimer stored in the ConsensusCore SyncTimer() ntp.SyncTimer - // NodesCoordinator gets the NodesCoordinator stored in the ConsensusCore NodesCoordinator() nodesCoordinator.NodesCoordinator - // EpochStartRegistrationHandler gets the RegistrationHandler stored in the ConsensusCore EpochStartRegistrationHandler() epochStart.RegistrationHandler - // PeerHonestyHandler returns the peer honesty handler which will be used in subrounds PeerHonestyHandler() consensus.PeerHonestyHandler - // HeaderSigVerifier returns the sig verifier handler which will be used in subrounds HeaderSigVerifier() consensus.HeaderSigVerifier - // FallbackHeaderValidator returns the fallback header validator handler which will be used in subrounds FallbackHeaderValidator() consensus.FallbackHeaderValidator - // NodeRedundancyHandler returns the node redundancy handler which will be used in subrounds NodeRedundancyHandler() consensus.NodeRedundancyHandler - // ScheduledProcessor returns the scheduled txs processor ScheduledProcessor() consensus.ScheduledProcessor - // MessageSigningHandler returns the p2p signing handler MessageSigningHandler() consensus.P2PSigningHandler - // 
PeerBlacklistHandler return the peer blacklist handler PeerBlacklistHandler() consensus.PeerBlacklistHandler - // SigningHandler returns the signing handler component SigningHandler() consensus.SigningHandler - // IsInterfaceNil returns true if there is no value under the interface + EnableEpochsHandler() common.EnableEpochsHandler + EquivalentProofsPool() consensus.EquivalentProofsPool IsInterfaceNil() bool } @@ -123,6 +104,8 @@ type WorkerHandler interface { AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) // AddReceivedHeaderHandler adds a new handler function for a received header AddReceivedHeaderHandler(handler func(data.HeaderHandler)) + // AddReceivedProofHandler adds a new handler function for a received proof + AddReceivedProofHandler(handler func(consensus.ProofHandler)) // RemoveAllReceivedMessagesCalls removes all the functions handlers RemoveAllReceivedMessagesCalls() // ProcessReceivedMessage method redirects the received message to the channel which should handle it @@ -137,7 +120,7 @@ type WorkerHandler interface { DisplayStatistics() // ReceivedHeader method is a wired method through which worker will receive headers from network ReceivedHeader(headerHandler data.HeaderHandler, headerHash []byte) - // ResetConsensusMessages resets at the start of each round all the previous consensus messages received + // ResetConsensusMessages resets at the start of each round all the previous consensus messages received and equivalent messages, keeping the provided proofs ResetConsensusMessages() // IsInterfaceNil returns true if there is no value under the interface IsInterfaceNil() bool @@ -154,6 +137,9 @@ type HeaderSigVerifier interface { VerifyRandSeed(header data.HeaderHandler) error VerifyLeaderSignature(header data.HeaderHandler) error VerifySignature(header data.HeaderHandler) error + VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderWithProof(header data.HeaderHandler) error + VerifyHeaderProof(headerProof data.HeaderProofHandler) error IsInterfaceNil() bool } @@ -177,3 +163,99 @@ type SentSignaturesTracker interface { SignatureSent(pkBytes []byte) IsInterfaceNil() bool } + +// ConsensusStateHandler encapsulates all needed data for the Consensus +type ConsensusStateHandler interface { + ResetConsensusState() + AddReceivedHeader(headerHandler data.HeaderHandler) + GetReceivedHeaders() []data.HeaderHandler + AddMessageWithSignature(key string, message p2p.MessageP2P) + GetMessageWithSignature(key string) (p2p.MessageP2P, bool) + IsNodeLeaderInCurrentRound(node string) bool + GetLeader() (string, error) + GetNextConsensusGroup( + randomSource []byte, + round uint64, + shardId uint32, + nodesCoordinator nodesCoordinator.NodesCoordinator, + epoch uint32, + ) (string, []string, error) + IsConsensusDataSet() bool + IsConsensusDataEqual(data []byte) bool + IsJobDone(node string, currentSubroundId int) bool + IsSubroundFinished(subroundID int) bool + IsNodeSelf(node string) bool + IsBlockBodyAlreadyReceived() bool + IsHeaderAlreadyReceived() bool + CanDoSubroundJob(currentSubroundId int) bool + CanProcessReceivedMessage(cnsDta *consensus.Message, currentRoundIndex int64, currentSubroundId int) bool + GenerateBitmap(subroundId int) []byte + ProcessingBlock() bool + SetProcessingBlock(processingBlock bool) + GetData() []byte + SetData(data []byte) + IsMultiKeyLeaderInCurrentRound() bool + IsLeaderJobDone(currentSubroundId int) 
bool + IsMultiKeyJobDone(currentSubroundId int) bool + IsSelfJobDone(currentSubroundID int) bool + GetMultikeyRedundancyStepInReason() string + ResetRoundsWithoutReceivedMessages(pkBytes []byte, pid core.PeerID) + GetRoundCanceled() bool + SetRoundCanceled(state bool) + GetRoundIndex() int64 + SetRoundIndex(roundIndex int64) + GetRoundTimeStamp() time.Time + SetRoundTimeStamp(roundTimeStamp time.Time) + GetExtendedCalled() bool + GetBody() data.BodyHandler + SetBody(body data.BodyHandler) + GetHeader() data.HeaderHandler + SetHeader(header data.HeaderHandler) + GetWaitingAllSignaturesTimeOut() bool + SetWaitingAllSignaturesTimeOut(bool) + RoundConsensusHandler + RoundStatusHandler + RoundThresholdHandler + IsInterfaceNil() bool +} + +// RoundConsensusHandler encapsulates the methods needed for a consensus round +type RoundConsensusHandler interface { + ConsensusGroupIndex(pubKey string) (int, error) + SelfConsensusGroupIndex() (int, error) + SetEligibleList(eligibleList map[string]struct{}) + ConsensusGroup() []string + SetConsensusGroup(consensusGroup []string) + SetLeader(leader string) + ConsensusGroupSize() int + SetConsensusGroupSize(consensusGroupSize int) + SelfPubKey() string + SetSelfPubKey(selfPubKey string) + JobDone(key string, subroundId int) (bool, error) + SetJobDone(key string, subroundId int, value bool) error + SelfJobDone(subroundId int) (bool, error) + IsNodeInConsensusGroup(node string) bool + IsNodeInEligibleList(node string) bool + ComputeSize(subroundId int) int + ResetRoundState() + IsMultiKeyInConsensusGroup() bool + IsKeyManagedBySelf(pkBytes []byte) bool + IncrementRoundsWithoutReceivedMessages(pkBytes []byte) + GetKeysHandler() consensus.KeysHandler + Leader() string +} + +// RoundStatusHandler encapsulates the methods needed for the status of a subround +type RoundStatusHandler interface { + Status(subroundId int) SubroundStatus + SetStatus(subroundId int, subroundStatus SubroundStatus) + ResetRoundStatus() +} + +// RoundThresholdHandler encapsulates the methods needed for the round consensus threshold +type RoundThresholdHandler interface { + Threshold(subroundId int) int + SetThreshold(subroundId int, threshold int) + FallbackThreshold(subroundId int) int + SetFallbackThreshold(subroundId int, threshold int) +} diff --git a/consensus/spos/roundConsensus.go b/consensus/spos/roundConsensus.go index 73e87242b63..dfe6eb88d29 100644 --- a/consensus/spos/roundConsensus.go +++ b/consensus/spos/roundConsensus.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/consensus" ) @@ -12,6 +13,7 @@ type roundConsensus struct { eligibleNodes map[string]struct{} mutEligible sync.RWMutex consensusGroup []string + leader string consensusGroupSize int selfPubKey string validatorRoundStates map[string]*roundState @@ -64,15 +66,18 @@ func (rcns *roundConsensus) SetEligibleList(eligibleList map[string]struct{}) { // ConsensusGroup returns the consensus group ID's func (rcns *roundConsensus) ConsensusGroup() []string { + rcns.mut.RLock() + defer rcns.mut.RUnlock() + return rcns.consensusGroup } // SetConsensusGroup sets the consensus group ID's func (rcns *roundConsensus) SetConsensusGroup(consensusGroup []string) { - rcns.consensusGroup = consensusGroup - rcns.mut.Lock() + rcns.consensusGroup = consensusGroup + rcns.validatorRoundStates = make(map[string]*roundState) for i := 0; i < len(consensusGroup); i++ { @@ -82,6 +87,22 @@ func (rcns *roundConsensus) SetConsensusGroup(consensusGroup []string) { 
rcns.mut.Unlock() } +// Leader returns the leader for the current consensus +func (rcns *roundConsensus) Leader() string { + rcns.mut.RLock() + defer rcns.mut.RUnlock() + + return rcns.leader +} + +// SetLeader sets the leader for the current consensus +func (rcns *roundConsensus) SetLeader(leader string) { + rcns.mut.Lock() + defer rcns.mut.Unlock() + + rcns.leader = leader +} + // ConsensusGroupSize returns the consensus group size func (rcns *roundConsensus) ConsensusGroupSize() int { return rcns.consensusGroupSize @@ -144,6 +165,9 @@ func (rcns *roundConsensus) SelfJobDone(subroundId int) (bool, error) { // IsNodeInConsensusGroup method checks if the node is part of consensus group of the current round func (rcns *roundConsensus) IsNodeInConsensusGroup(node string) bool { + rcns.mut.RLock() + defer rcns.mut.RUnlock() + for i := 0; i < len(rcns.consensusGroup); i++ { if rcns.consensusGroup[i] == node { return true @@ -165,6 +189,9 @@ func (rcns *roundConsensus) IsNodeInEligibleList(node string) bool { // ComputeSize method returns the number of messages received from the nodes belonging to the current jobDone group // related to this subround func (rcns *roundConsensus) ComputeSize(subroundId int) int { + rcns.mut.RLock() + defer rcns.mut.RUnlock() + n := 0 for i := 0; i < len(rcns.consensusGroup); i++ { @@ -205,7 +232,7 @@ func (rcns *roundConsensus) ResetRoundState() { // is in consensus group in the current round func (rcns *roundConsensus) IsMultiKeyInConsensusGroup() bool { for i := 0; i < len(rcns.consensusGroup); i++ { - if rcns.IsKeyManagedByCurrentNode([]byte(rcns.consensusGroup[i])) { + if rcns.IsKeyManagedBySelf([]byte(rcns.consensusGroup[i])) { return true } } @@ -213,8 +240,8 @@ func (rcns *roundConsensus) IsMultiKeyInConsensusGroup() bool { return false } -// IsKeyManagedByCurrentNode returns true if the key is managed by the current node -func (rcns *roundConsensus) IsKeyManagedByCurrentNode(pkBytes []byte) bool { +// IsKeyManagedBySelf returns true if the key is managed by the current node +func (rcns *roundConsensus) IsKeyManagedBySelf(pkBytes []byte) bool { return rcns.keysHandler.IsKeyManagedByCurrentNode(pkBytes) } @@ -222,3 +249,8 @@ func (rcns *roundConsensus) IsKeyManagedByCurrentNode(pkBytes []byte) bool { func (rcns *roundConsensus) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { rcns.keysHandler.IncrementRoundsWithoutReceivedMessages(pkBytes) } + +// GetKeysHandler returns the keysHandler instance +func (rcns *roundConsensus) GetKeysHandler() consensus.KeysHandler { + return rcns.keysHandler +} diff --git a/consensus/spos/roundConsensus_test.go b/consensus/spos/roundConsensus_test.go index 4ba8f7e47fe..36c8e5ad8ab 100644 --- a/consensus/spos/roundConsensus_test.go +++ b/consensus/spos/roundConsensus_test.go @@ -296,23 +296,6 @@ func TestRoundConsensus_IsMultiKeyInConsensusGroup(t *testing.T) { }) } -func TestRoundConsensus_IsKeyManagedByCurrentNode(t *testing.T) { - t.Parallel() - - managedPkBytes := []byte("managed pk bytes") - wasCalled := false - keysHandler := &testscommon.KeysHandlerStub{ - IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { - assert.Equal(t, managedPkBytes, pkBytes) - wasCalled = true - return true - }, - } - roundConsensus := initRoundConsensusWithKeysHandler(keysHandler) - assert.True(t, roundConsensus.IsKeyManagedByCurrentNode(managedPkBytes)) - assert.True(t, wasCalled) -} - func TestRoundConsensus_IncrementRoundsWithoutReceivedMessages(t *testing.T) { t.Parallel() diff --git a/consensus/spos/roundStatus.go 
b/consensus/spos/roundStatus.go index 8517396904a..7d3b67fdc15 100644 --- a/consensus/spos/roundStatus.go +++ b/consensus/spos/roundStatus.go @@ -5,7 +5,7 @@ import ( ) // SubroundStatus defines the type used to refer the state of the current subround -type SubroundStatus int +type SubroundStatus = int const ( // SsNotFinished defines the un-finished state of the subround diff --git a/consensus/spos/scheduledProcessor_test.go b/consensus/spos/scheduledProcessor_test.go index 7316209921b..ed1f95287a2 100644 --- a/consensus/spos/scheduledProcessor_test.go +++ b/consensus/spos/scheduledProcessor_test.go @@ -8,9 +8,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" - "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/stretchr/testify/require" ) @@ -30,7 +32,7 @@ func TestNewScheduledProcessorWrapper_NilSyncTimerShouldErr(t *testing.T) { args := ScheduledProcessorWrapperArgs{ SyncTimer: nil, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } sp, err := NewScheduledProcessorWrapper(args) @@ -42,9 +44,9 @@ func TestNewScheduledProcessorWrapper_NilBlockProcessorShouldErr(t *testing.T) { t.Parallel() args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + SyncTimer: &consensus.SyncTimerMock{}, Processor: nil, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } sp, err := NewScheduledProcessorWrapper(args) @@ -56,7 +58,7 @@ func TestNewScheduledProcessorWrapper_NilRoundTimeDurationHandlerShouldErr(t *te t.Parallel() args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{}, RoundTimeDurationHandler: nil, } @@ -70,9 +72,9 @@ func TestNewScheduledProcessorWrapper_NilBlockProcessorOK(t *testing.T) { t.Parallel() args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } sp, err := NewScheduledProcessorWrapper(args) @@ -85,14 +87,14 @@ func TestScheduledProcessorWrapper_IsProcessedOKEarlyExit(t *testing.T) { called := atomic.Flag{} args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{ + SyncTimer: &consensus.SyncTimerMock{ CurrentTimeCalled: func() time.Time { called.SetValue(true) return time.Now() }, }, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } sp, err := NewScheduledProcessorWrapper(args) @@ -112,13 +114,13 @@ func TestScheduledProcessorWrapper_IsProcessedOKEarlyExit(t *testing.T) { func defaultScheduledProcessorWrapperArgs() ScheduledProcessorWrapperArgs { return ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{ + SyncTimer: &consensus.SyncTimerMock{ CurrentTimeCalled: func() time.Time { return time.Now() }, }, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, 
} } @@ -227,9 +229,9 @@ func TestScheduledProcessorWrapper_StatusGetterAndSetter(t *testing.T) { t.Parallel() args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } sp, _ := NewScheduledProcessorWrapper(args) @@ -250,14 +252,14 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV1ProcessingOK( processScheduledCalled := atomic.Flag{} args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{ ProcessScheduledBlockCalled: func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { processScheduledCalled.SetValue(true) return nil }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } sp, _ := NewScheduledProcessorWrapper(args) @@ -276,14 +278,14 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ProcessingWit processScheduledCalled := atomic.Flag{} args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{ ProcessScheduledBlockCalled: func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { processScheduledCalled.SetValue(true) return errors.New("processing error") }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } sp, _ := NewScheduledProcessorWrapper(args) @@ -304,14 +306,14 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ProcessingOK( processScheduledCalled := atomic.Flag{} args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{ ProcessScheduledBlockCalled: func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { processScheduledCalled.SetValue(true) return nil }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } sp, _ := NewScheduledProcessorWrapper(args) @@ -333,7 +335,7 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopped( processScheduledCalled := atomic.Flag{} args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{ + SyncTimer: &consensus.SyncTimerMock{ CurrentTimeCalled: func() time.Time { return time.Now() }, @@ -350,7 +352,7 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopped( } }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } spw, err := NewScheduledProcessorWrapper(args) @@ -374,7 +376,7 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopAfte processScheduledCalled := atomic.Flag{} args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{ + SyncTimer: &consensus.SyncTimerMock{ CurrentTimeCalled: func() time.Time { return time.Now() }, @@ -386,7 +388,7 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopAfte return nil }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } spw, err := NewScheduledProcessorWrapper(args) 
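Note (illustrative sketch, not part of the patch): the sposFactory change that follows threads a config.ConsensusGradualBroadcastConfig into GetBroadcastMessenger and uses it to build the delayed block broadcaster. A minimal example of populating such a config value is shown below; the EndIndex and DelayInMilliseconds field names of config.IndexBroadcastDelay are an assumption, since this diff only shows the empty slice form used in the tests.

package main

import (
	"fmt"

	"github.com/multiversx/mx-chain-go/config"
)

func main() {
	// Hypothetical gradual broadcast setup: validators whose consensus index falls
	// in a bucket wait the bucket's delay before broadcasting (field names assumed).
	gradualCfg := config.ConsensusGradualBroadcastConfig{
		GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{
			{EndIndex: 0, DelayInMilliseconds: 0},   // assumed fields: first index broadcasts immediately
			{EndIndex: 7, DelayInMilliseconds: 100}, // assumed fields: later indexes wait briefly
		},
	}
	fmt.Printf("gradual broadcast config: %+v\n", gradualCfg)
}

A value of this shape is what the updated GetBroadcastMessenger signature, and the updated tests further below, take as their last argument.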
diff --git a/consensus/spos/sposFactory/sposFactory.go b/consensus/spos/sposFactory/sposFactory.go index 84faafe53e6..bb2d409a97f 100644 --- a/consensus/spos/sposFactory/sposFactory.go +++ b/consensus/spos/sposFactory/sposFactory.go @@ -6,50 +6,16 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-crypto-go" + + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/broadcast" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" - "github.com/multiversx/mx-chain-go/outport" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) -// GetSubroundsFactory returns a subrounds factory depending on the given parameter -func GetSubroundsFactory( - consensusDataContainer spos.ConsensusCoreHandler, - consensusState *spos.ConsensusState, - worker spos.WorkerHandler, - consensusType string, - appStatusHandler core.AppStatusHandler, - outportHandler outport.OutportHandler, - sentSignatureTracker spos.SentSignaturesTracker, - chainID []byte, - currentPid core.PeerID, -) (spos.SubroundsFactory, error) { - switch consensusType { - case blsConsensusType: - subRoundFactoryBls, err := bls.NewSubroundsFactory( - consensusDataContainer, - consensusState, - worker, - chainID, - currentPid, - appStatusHandler, - sentSignatureTracker, - ) - if err != nil { - return nil, err - } - - subRoundFactoryBls.SetOutportHandler(outportHandler) - - return subRoundFactoryBls, nil - default: - return nil, ErrInvalidConsensusType - } -} - // GetConsensusCoreFactory returns a consensus service depending on the given parameter func GetConsensusCoreFactory(consensusType string) (spos.ConsensusService, error) { switch consensusType { @@ -71,12 +37,28 @@ func GetBroadcastMessenger( interceptorsContainer process.InterceptorsContainer, alarmScheduler core.TimersScheduler, keysHandler consensus.KeysHandler, + config config.ConsensusGradualBroadcastConfig, ) (consensus.BroadcastMessenger, error) { if check.IfNil(shardCoordinator) { return nil, spos.ErrNilShardCoordinator } + dbbArgs := &broadcast.ArgsDelayedBlockBroadcaster{ + InterceptorsContainer: interceptorsContainer, + HeadersSubscriber: headersSubscriber, + ShardCoordinator: shardCoordinator, + LeaderCacheSize: maxDelayCacheSize, + ValidatorCacheSize: maxDelayCacheSize, + AlarmScheduler: alarmScheduler, + Config: config, + } + + delayedBroadcaster, err := broadcast.NewDelayedBlockBroadcaster(dbbArgs) + if err != nil { + return nil, err + } + commonMessengerArgs := broadcast.CommonMessengerArgs{ Marshalizer: marshalizer, Hasher: hasher, @@ -89,6 +71,7 @@ func GetBroadcastMessenger( InterceptorsContainer: interceptorsContainer, AlarmScheduler: alarmScheduler, KeysHandler: keysHandler, + DelayedBroadcaster: delayedBroadcaster, } if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { diff --git a/consensus/spos/sposFactory/sposFactory_test.go b/consensus/spos/sposFactory/sposFactory_test.go index 4a672a3343f..3a39dc943aa 100644 --- a/consensus/spos/sposFactory/sposFactory_test.go +++ b/consensus/spos/sposFactory/sposFactory_test.go @@ -5,20 +5,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" 
"github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/sposFactory" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) -var currentPid = core.PeerID("pid") - func TestGetConsensusCoreFactory_InvalidTypeShouldErr(t *testing.T) { t.Parallel() @@ -37,98 +36,6 @@ func TestGetConsensusCoreFactory_BlsShouldWork(t *testing.T) { assert.False(t, check.IfNil(csf)) } -func TestGetSubroundsFactory_BlsNilConsensusCoreShouldErr(t *testing.T) { - t.Parallel() - - worker := &mock.SposWorkerMock{} - consensusType := consensus.BlsConsensusType - statusHandler := statusHandlerMock.NewAppStatusHandlerMock() - chainID := []byte("chain-id") - indexer := &outport.OutportStub{} - sf, err := sposFactory.GetSubroundsFactory( - nil, - &spos.ConsensusState{}, - worker, - consensusType, - statusHandler, - indexer, - &testscommon.SentSignatureTrackerStub{}, - chainID, - currentPid, - ) - - assert.Nil(t, sf) - assert.Equal(t, spos.ErrNilConsensusCore, err) -} - -func TestGetSubroundsFactory_BlsNilStatusHandlerShouldErr(t *testing.T) { - t.Parallel() - - consensusCore := mock.InitConsensusCore() - worker := &mock.SposWorkerMock{} - consensusType := consensus.BlsConsensusType - chainID := []byte("chain-id") - indexer := &outport.OutportStub{} - sf, err := sposFactory.GetSubroundsFactory( - consensusCore, - &spos.ConsensusState{}, - worker, - consensusType, - nil, - indexer, - &testscommon.SentSignatureTrackerStub{}, - chainID, - currentPid, - ) - - assert.Nil(t, sf) - assert.Equal(t, spos.ErrNilAppStatusHandler, err) -} - -func TestGetSubroundsFactory_BlsShouldWork(t *testing.T) { - t.Parallel() - - consensusCore := mock.InitConsensusCore() - worker := &mock.SposWorkerMock{} - consensusType := consensus.BlsConsensusType - statusHandler := statusHandlerMock.NewAppStatusHandlerMock() - chainID := []byte("chain-id") - indexer := &outport.OutportStub{} - sf, err := sposFactory.GetSubroundsFactory( - consensusCore, - &spos.ConsensusState{}, - worker, - consensusType, - statusHandler, - indexer, - &testscommon.SentSignatureTrackerStub{}, - chainID, - currentPid, - ) - assert.Nil(t, err) - assert.False(t, check.IfNil(sf)) -} - -func TestGetSubroundsFactory_InvalidConsensusTypeShouldErr(t *testing.T) { - t.Parallel() - - consensusType := "invalid" - sf, err := sposFactory.GetSubroundsFactory( - nil, - nil, - nil, - consensusType, - nil, - nil, - nil, - nil, - currentPid, - ) - - assert.Nil(t, sf) - assert.Equal(t, sposFactory.ErrInvalidConsensusType, err) -} - func TestGetBroadcastMessenger_ShardShouldWork(t *testing.T) { t.Parallel() @@ -140,9 +47,9 @@ func TestGetBroadcastMessenger_ShardShouldWork(t *testing.T) { return 0 } peerSigHandler := &mock.PeerSignatureHandler{} - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} - alarmSchedulerStub := &mock.AlarmSchedulerStub{} + alarmSchedulerStub := &testscommon.AlarmSchedulerStub{} bm, err := sposFactory.GetBroadcastMessenger( marshalizer, @@ -154,6 +61,9 @@ func TestGetBroadcastMessenger_ShardShouldWork(t *testing.T) { 
interceptosContainer, alarmSchedulerStub, &testscommon.KeysHandlerStub{}, + config.ConsensusGradualBroadcastConfig{ + GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{}, + }, ) assert.Nil(t, err) @@ -171,9 +81,9 @@ func TestGetBroadcastMessenger_MetachainShouldWork(t *testing.T) { return core.MetachainShardId } peerSigHandler := &mock.PeerSignatureHandler{} - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} - alarmSchedulerStub := &mock.AlarmSchedulerStub{} + alarmSchedulerStub := &testscommon.AlarmSchedulerStub{} bm, err := sposFactory.GetBroadcastMessenger( marshalizer, @@ -185,6 +95,9 @@ func TestGetBroadcastMessenger_MetachainShouldWork(t *testing.T) { interceptosContainer, alarmSchedulerStub, &testscommon.KeysHandlerStub{}, + config.ConsensusGradualBroadcastConfig{ + GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{}, + }, ) assert.Nil(t, err) @@ -194,9 +107,9 @@ func TestGetBroadcastMessenger_MetachainShouldWork(t *testing.T) { func TestGetBroadcastMessenger_NilShardCoordinatorShouldErr(t *testing.T) { t.Parallel() - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} - alarmSchedulerStub := &mock.AlarmSchedulerStub{} + alarmSchedulerStub := &testscommon.AlarmSchedulerStub{} bm, err := sposFactory.GetBroadcastMessenger( nil, @@ -208,6 +121,9 @@ func TestGetBroadcastMessenger_NilShardCoordinatorShouldErr(t *testing.T) { interceptosContainer, alarmSchedulerStub, &testscommon.KeysHandlerStub{}, + config.ConsensusGradualBroadcastConfig{ + GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{}, + }, ) assert.Nil(t, bm) @@ -221,9 +137,9 @@ func TestGetBroadcastMessenger_InvalidShardIdShouldErr(t *testing.T) { shardCoord.SelfIDCalled = func() uint32 { return 37 } - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} - alarmSchedulerStub := &mock.AlarmSchedulerStub{} + alarmSchedulerStub := &testscommon.AlarmSchedulerStub{} bm, err := sposFactory.GetBroadcastMessenger( nil, @@ -235,6 +151,9 @@ func TestGetBroadcastMessenger_InvalidShardIdShouldErr(t *testing.T) { interceptosContainer, alarmSchedulerStub, &testscommon.KeysHandlerStub{}, + config.ConsensusGradualBroadcastConfig{ + GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{}, + }, ) assert.Nil(t, bm) diff --git a/consensus/spos/subround.go b/consensus/spos/subround.go index 1d1b07589a6..00b2c55fe6c 100644 --- a/consensus/spos/subround.go +++ b/consensus/spos/subround.go @@ -6,18 +6,24 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/consensus" ) var _ consensus.SubroundHandler = (*Subround)(nil) +const ( + singleKeyStartMsg = " (my turn)" + multiKeyStartMsg = " (my turn in multi-key)" +) + // Subround struct contains the needed data for one Subround and the Subround properties. It defines a Subround // with its properties (its ID, next Subround ID, its duration, its name) and also it has some handler functions // which should be set. 
Job function will be the main function of this Subround, Extend function will handle the overtime // situation of the Subround and Check function will decide if in this Subround the consensus is achieved type Subround struct { ConsensusCoreHandler - *ConsensusState + ConsensusStateHandler previous int current int @@ -45,7 +51,7 @@ func NewSubround( startTime int64, endTime int64, name string, - consensusState *ConsensusState, + consensusState ConsensusStateHandler, consensusStateChangedChannel chan bool, executeStoredMessages func(), container ConsensusCoreHandler, @@ -67,7 +73,7 @@ func NewSubround( sr := Subround{ ConsensusCoreHandler: container, - ConsensusState: consensusState, + ConsensusStateHandler: consensusState, previous: previous, current: current, next: next, @@ -88,7 +94,7 @@ } func checkNewSubroundParams( - state *ConsensusState, + state ConsensusStateHandler, consensusStateChangedChannel chan bool, executeStoredMessages func(), container ConsensusCoreHandler, @@ -145,7 +151,7 @@ func (sr *Subround) DoWork(ctx context.Context, roundHandler consensus.RoundHand } case <-time.After(roundHandler.RemainingTime(startTime, maxTime)): if sr.Extend != nil { - sr.RoundCanceled = true + sr.SetRoundCanceled(true) sr.Extend(sr.current) } @@ -206,7 +212,7 @@ func (sr *Subround) ConsensusChannel() chan bool { // GetAssociatedPid returns the associated PeerID to the provided public key bytes func (sr *Subround) GetAssociatedPid(pkBytes []byte) core.PeerID { - return sr.keysHandler.GetAssociatedPid(pkBytes) + return sr.GetKeysHandler().GetAssociatedPid(pkBytes) } // ShouldConsiderSelfKeyInConsensus returns true if current machine is the main one, or it is a backup machine but the main @@ -221,6 +227,36 @@ func (sr *Subround) ShouldConsiderSelfKeyInConsensus() bool { return isMainMachineInactive } +// IsSelfInConsensusGroup returns true if the current node is in consensus group in single +// key or in multi-key mode +func (sr *Subround) IsSelfInConsensusGroup() bool { + return sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || sr.IsMultiKeyInConsensusGroup() +} + +// IsSelfLeader returns true if the current node is leader in single key or in +// multi-key mode +func (sr *Subround) IsSelfLeader() bool { + return sr.IsSelfLeaderInCurrentRound() || sr.IsMultiKeyLeaderInCurrentRound() +} + +// IsSelfLeaderInCurrentRound method checks if the current node is leader in the current round +func (sr *Subround) IsSelfLeaderInCurrentRound() bool { + return sr.IsNodeLeaderInCurrentRound(sr.SelfPubKey()) && sr.ShouldConsiderSelfKeyInConsensus() +} + +// GetLeaderStartRoundMessage returns the leader start round message based on single key +// or multi-key node type +func (sr *Subround) GetLeaderStartRoundMessage() string { + if sr.IsMultiKeyLeaderInCurrentRound() { + return multiKeyStartMsg + } + if sr.IsSelfLeaderInCurrentRound() { + return singleKeyStartMsg + } + + return "" +} + // IsInterfaceNil returns true if there is no value under the interface func (sr *Subround) IsInterfaceNil() bool { return sr == nil diff --git a/consensus/spos/subround_test.go b/consensus/spos/subround_test.go index 202899e1a24..8eb3e8e568d 100644 --- a/consensus/spos/subround_test.go +++ b/consensus/spos/subround_test.go @@ -1,19 +1,24 @@ package spos_test import ( + "bytes" "context" "sync" "testing" "time" "github.com/multiversx/mx-chain-core-go/core" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/consensus/mock"
"github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) var chainID = []byte("chain ID") @@ -57,6 +62,7 @@ func initConsensusState() *spos.ConsensusState { ) rcns.SetConsensusGroup(eligibleList) + rcns.SetLeader(eligibleList[indexLeader]) rcns.ResetRoundState() pBFTThreshold := consensusGroupSize*2/3 + 1 @@ -84,14 +90,14 @@ func initConsensusState() *spos.ConsensusState { ) cns.Data = []byte("X") - cns.RoundIndex = 0 + cns.SetRoundIndex(0) return cns } func TestSubround_NewSubroundNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() ch := make(chan bool, 1) sr, err := spos.NewSubround( @@ -118,7 +124,7 @@ func TestSubround_NewSubroundNilChannelShouldFail(t *testing.T) { t.Parallel() consensusState := initConsensusState() - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, err := spos.NewSubround( -1, @@ -144,7 +150,7 @@ func TestSubround_NewSubroundNilExecuteStoredMessagesShouldFail(t *testing.T) { t.Parallel() consensusState := initConsensusState() - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() ch := make(chan bool, 1) sr, err := spos.NewSubround( @@ -198,7 +204,7 @@ func TestSubround_NilContainerBlockchainShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetBlockchain(nil) sr, err := spos.NewSubround( @@ -226,7 +232,7 @@ func TestSubround_NilContainerBlockprocessorShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetBlockProcessor(nil) sr, err := spos.NewSubround( @@ -254,7 +260,7 @@ func TestSubround_NilContainerBootstrapperShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetBootStrapper(nil) sr, err := spos.NewSubround( @@ -282,7 +288,7 @@ func TestSubround_NilContainerChronologyShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetChronology(nil) sr, err := spos.NewSubround( @@ -310,7 +316,7 @@ func TestSubround_NilContainerHasherShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetHasher(nil) sr, err := spos.NewSubround( @@ -338,7 +344,7 @@ func TestSubround_NilContainerMarshalizerShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetMarshalizer(nil) sr, err := spos.NewSubround( @@ -366,7 +372,7 @@ func TestSubround_NilContainerMultiSignerShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := 
consensus.InitConsensusCore() container.SetMultiSignerContainer(cryptoMocks.NewMultiSignerContainerMock(nil)) sr, err := spos.NewSubround( @@ -394,7 +400,7 @@ func TestSubround_NilContainerRoundHandlerShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetRoundHandler(nil) sr, err := spos.NewSubround( @@ -422,7 +428,7 @@ func TestSubround_NilContainerShardCoordinatorShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetShardCoordinator(nil) sr, err := spos.NewSubround( @@ -450,7 +456,7 @@ func TestSubround_NilContainerSyncTimerShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetSyncTimer(nil) sr, err := spos.NewSubround( @@ -478,7 +484,7 @@ func TestSubround_NilContainerValidatorGroupSelectorShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetValidatorGroupSelector(nil) sr, err := spos.NewSubround( @@ -506,7 +512,7 @@ func TestSubround_EmptyChainIDShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, err := spos.NewSubround( -1, bls.SrStartRound, @@ -532,7 +538,7 @@ func TestSubround_NewSubroundShouldWork(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, err := spos.NewSubround( -1, bls.SrStartRound, @@ -566,7 +572,7 @@ func TestSubround_DoWorkShouldReturnFalseWhenJobFunctionIsNotSet(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( -1, @@ -589,7 +595,7 @@ func TestSubround_DoWorkShouldReturnFalseWhenJobFunctionIsNotSet(t *testing.T) { } maxTime := time.Now().Add(100 * time.Millisecond) - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensus.RoundHandlerMock{} roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { return time.Until(maxTime) } @@ -604,7 +610,7 @@ func TestSubround_DoWorkShouldReturnFalseWhenCheckFunctionIsNotSet(t *testing.T) consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( -1, @@ -627,7 +633,7 @@ func TestSubround_DoWorkShouldReturnFalseWhenCheckFunctionIsNotSet(t *testing.T) sr.Check = nil maxTime := time.Now().Add(100 * time.Millisecond) - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensus.RoundHandlerMock{} roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { return time.Until(maxTime) } @@ -651,7 +657,7 @@ func TestSubround_DoWorkShouldReturnTrueWhenJobAndConsensusAreDone(t *testing.T) func testDoWork(t *testing.T, checkDone bool, shouldWork bool) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := 
spos.NewSubround( -1, @@ -676,7 +682,7 @@ func testDoWork(t *testing.T, checkDone bool, shouldWork bool) { } maxTime := time.Now().Add(100 * time.Millisecond) - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensus.RoundHandlerMock{} roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { return time.Until(maxTime) } @@ -690,7 +696,7 @@ func TestSubround_DoWorkShouldReturnTrueWhenJobIsDoneAndConsensusIsDoneAfterAWhi consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( -1, @@ -723,7 +729,7 @@ func TestSubround_DoWorkShouldReturnTrueWhenJobIsDoneAndConsensusIsDoneAfterAWhi } maxTime := time.Now().Add(2000 * time.Millisecond) - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensus.RoundHandlerMock{} roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { return time.Until(maxTime) } @@ -748,7 +754,7 @@ func TestSubround_Previous(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( bls.SrStartRound, @@ -780,7 +786,7 @@ func TestSubround_Current(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( bls.SrStartRound, @@ -812,7 +818,7 @@ func TestSubround_Next(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( bls.SrStartRound, @@ -844,7 +850,7 @@ func TestSubround_StartTime(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetRoundHandler(initRoundHandlerMock()) sr, _ := spos.NewSubround( bls.SrBlock, @@ -876,7 +882,7 @@ func TestSubround_EndTime(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetRoundHandler(initRoundHandlerMock()) sr, _ := spos.NewSubround( bls.SrStartRound, @@ -908,7 +914,7 @@ func TestSubround_Name(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( bls.SrStartRound, @@ -941,7 +947,7 @@ func TestSubround_GetAssociatedPid(t *testing.T) { keysHandler := &testscommon.KeysHandlerStub{} consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() subround, _ := spos.NewSubround( bls.SrStartRound, @@ -971,3 +977,183 @@ func TestSubround_GetAssociatedPid(t *testing.T) { assert.Equal(t, pid, subround.GetAssociatedPid(providedPkBytes)) assert.True(t, wasCalled) } + +func TestSubround_ShouldConsiderSelfKeyInConsensus(t *testing.T) { + t.Parallel() + + t.Run("is main machine active, should return true", func(t *testing.T) { + t.Parallel() + + consensusState := initConsensusState() + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + redundancyHandler := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool 
{ + return false + }, + IsMainMachineActiveCalled: func() bool { + return true + }, + } + container.SetNodeRedundancyHandler(redundancyHandler) + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + require.True(t, sr.ShouldConsiderSelfKeyInConsensus()) + }) + + t.Run("is redundancy node machine active, should return true", func(t *testing.T) { + t.Parallel() + + consensusState := initConsensusState() + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + redundancyHandler := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return true + }, + IsMainMachineActiveCalled: func() bool { + return false + }, + } + container.SetNodeRedundancyHandler(redundancyHandler) + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + require.True(t, sr.ShouldConsiderSelfKeyInConsensus()) + }) + + t.Run("is redundancy node machine but inactive, should return false", func(t *testing.T) { + t.Parallel() + + consensusState := initConsensusState() + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + redundancyHandler := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return true + }, + IsMainMachineActiveCalled: func() bool { + return true + }, + } + container.SetNodeRedundancyHandler(redundancyHandler) + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + require.False(t, sr.ShouldConsiderSelfKeyInConsensus()) + }) +} + +func TestSubround_GetLeaderStartRoundMessage(t *testing.T) { + t.Parallel() + + t.Run("should work with multi key node", func(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal([]byte("1"), pkBytes) + }, + } + consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetSelfPubKey("1") + + require.Equal(t, spos.LeaderMultiKeyStartMsg, sr.GetLeaderStartRoundMessage()) + }) + + t.Run("should work with single key node", func(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal([]byte("2"), pkBytes) + }, + } + consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + 
int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetSelfPubKey("1") + + require.Equal(t, spos.LeaderSingleKeyStartMsg, sr.GetLeaderStartRoundMessage()) + }) +} diff --git a/consensus/spos/worker.go b/consensus/spos/worker.go index f11e40d3089..e539071331d 100644 --- a/consensus/spos/worker.go +++ b/consensus/spos/worker.go @@ -54,11 +54,12 @@ type Worker struct { headerSigVerifier HeaderSigVerifier headerIntegrityVerifier process.HeaderIntegrityVerifier appStatusHandler core.AppStatusHandler + enableEpochsHandler common.EnableEpochsHandler networkShardingCollector consensus.NetworkShardingCollector receivedMessages map[consensus.MessageType][]*consensus.Message - receivedMessagesCalls map[consensus.MessageType]func(ctx context.Context, msg *consensus.Message) bool + receivedMessagesCalls map[consensus.MessageType][]func(ctx context.Context, msg *consensus.Message) bool executeMessageChannel chan *consensus.Message consensusStateChangedChannel chan bool @@ -72,6 +73,9 @@ type Worker struct { receivedHeadersHandlers []func(headerHandler data.HeaderHandler) mutReceivedHeadersHandler sync.RWMutex + receivedProofHandlers []func(proofHandler consensus.ProofHandler) + mutReceivedProofHandler sync.RWMutex + antifloodHandler consensus.P2PAntifloodHandler poolAdder PoolAdder @@ -109,6 +113,7 @@ type WorkerArgs struct { AppStatusHandler core.AppStatusHandler NodeRedundancyHandler consensus.NodeRedundancyHandler PeerBlacklistHandler consensus.PeerBlacklistHandler + EnableEpochsHandler common.EnableEpochsHandler } // NewWorker creates a new Worker object @@ -122,6 +127,9 @@ func NewWorker(args *WorkerArgs) (*Worker, error) { ConsensusState: args.ConsensusState, ConsensusService: args.ConsensusService, PeerSignatureHandler: args.PeerSignatureHandler, + EnableEpochsHandler: args.EnableEpochsHandler, + Marshaller: args.Marshalizer, + ShardCoordinator: args.ShardCoordinator, SignatureSize: args.SignatureSize, PublicKeySize: args.PublicKeySize, HeaderHashSize: args.Hasher.Size(), @@ -157,11 +165,12 @@ func NewWorker(args *WorkerArgs) (*Worker, error) { nodeRedundancyHandler: args.NodeRedundancyHandler, peerBlacklistHandler: args.PeerBlacklistHandler, closer: closing.NewSafeChanCloser(), + enableEpochsHandler: args.EnableEpochsHandler, } wrk.consensusMessageValidator = consensusMessageValidatorObj wrk.executeMessageChannel = make(chan *consensus.Message) - wrk.receivedMessagesCalls = make(map[consensus.MessageType]func(context.Context, *consensus.Message) bool) + wrk.receivedMessagesCalls = make(map[consensus.MessageType][]func(context.Context, *consensus.Message) bool) wrk.receivedHeadersHandlers = make([]func(data.HeaderHandler), 0) wrk.consensusStateChangedChannel = make(chan bool, 1) wrk.bootstrapper.AddSyncStateListener(wrk.receivedSyncState) @@ -257,6 +266,9 @@ func checkNewWorkerParams(args *WorkerArgs) error { if check.IfNil(args.PeerBlacklistHandler) { return ErrNilPeerBlacklistHandler } + if check.IfNil(args.EnableEpochsHandler) { + return ErrNilEnableEpochsHandler + } return nil } @@ -298,23 +310,46 @@ func (wrk *Worker) AddReceivedHeaderHandler(handler func(data.HeaderHandler)) { wrk.mutReceivedHeadersHandler.Unlock() } +// ReceivedProof process the received proof, calling each received proof handler registered in worker instance +func (wrk *Worker) ReceivedProof(proofHandler consensus.ProofHandler) { + if check.IfNilReflect(proofHandler) { + 
log.Trace("ReceivedProof: nil proof handler") + return + } + + log.Trace("ReceivedProof:", "proof header", proofHandler.GetHeaderHash()) + + wrk.mutReceivedProofHandler.RLock() + for _, handler := range wrk.receivedProofHandlers { + handler(proofHandler) + } + wrk.mutReceivedProofHandler.RUnlock() +} + +// AddReceivedProofHandler adds a new handler function for a received proof +func (wrk *Worker) AddReceivedProofHandler(handler func(proofHandler consensus.ProofHandler)) { + wrk.mutReceivedProofHandler.Lock() + wrk.receivedProofHandlers = append(wrk.receivedProofHandlers, handler) + wrk.mutReceivedProofHandler.Unlock() +} + func (wrk *Worker) initReceivedMessages() { wrk.mutReceivedMessages.Lock() wrk.receivedMessages = wrk.consensusService.InitReceivedMessages() wrk.mutReceivedMessages.Unlock() } -// AddReceivedMessageCall adds a new handler function for a received messege type +// AddReceivedMessageCall adds a new handler function for a received message type func (wrk *Worker) AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) { wrk.mutReceivedMessagesCalls.Lock() - wrk.receivedMessagesCalls[messageType] = receivedMessageCall + wrk.receivedMessagesCalls[messageType] = append(wrk.receivedMessagesCalls[messageType], receivedMessageCall) wrk.mutReceivedMessagesCalls.Unlock() } // RemoveAllReceivedMessagesCalls removes all the functions handlers func (wrk *Worker) RemoveAllReceivedMessagesCalls() { wrk.mutReceivedMessagesCalls.Lock() - wrk.receivedMessagesCalls = make(map[consensus.MessageType]func(context.Context, *consensus.Message) bool) + wrk.receivedMessagesCalls = make(map[consensus.MessageType][]func(context.Context, *consensus.Message) bool) wrk.mutReceivedMessagesCalls.Unlock() } @@ -389,23 +424,14 @@ func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedP ) } - msgType := consensus.MessageType(cnsMsg.MsgType) - - log.Trace("received message from consensus topic", - "msg type", wrk.consensusService.GetStringValue(msgType), - "from", cnsMsg.PubKey, - "header hash", cnsMsg.BlockHeaderHash, - "round", cnsMsg.RoundIndex, - "size", len(message.Data()), - ) - - err = wrk.consensusMessageValidator.checkConsensusMessageValidity(cnsMsg, message.Peer()) + err = wrk.checkValidityAndProcessFinalInfo(cnsMsg, message) if err != nil { return err } wrk.networkShardingCollector.UpdatePeerIDInfo(message.Peer(), cnsMsg.PubKey, wrk.shardCoordinator.SelfId()) + msgType := consensus.MessageType(cnsMsg.MsgType) isMessageWithBlockBody := wrk.consensusService.IsMessageWithBlockBody(msgType) isMessageWithBlockHeader := wrk.consensusService.IsMessageWithBlockHeader(msgType) isMessageWithBlockBodyAndHeader := wrk.consensusService.IsMessageWithBlockBodyAndHeader(msgType) @@ -446,7 +472,8 @@ func (wrk *Worker) shouldBlacklistPeer(err error) bool { errors.Is(err, errorsErd.ErrPIDMismatch) || errors.Is(err, errorsErd.ErrSignatureMismatch) || errors.Is(err, nodesCoordinator.ErrEpochNodesConfigDoesNotExist) || - errors.Is(err, ErrMessageTypeLimitReached) { + errors.Is(err, ErrMessageTypeLimitReached) || + errors.Is(err, ErrEquivalentMessageAlreadyReceived) { return false } @@ -503,6 +530,11 @@ func (wrk *Worker) doJobOnMessageWithHeader(cnsMsg *consensus.Message) error { err) } + err = wrk.checkHeaderPreviousProof(header) + if err != nil { + return err + } + wrk.processReceivedHeaderMetric(cnsMsg) errNotCritical := wrk.forkDetector.AddHeader(header, headerHash, process.BHProposed, nil, nil) @@ -516,6 
+548,18 @@ func (wrk *Worker) doJobOnMessageWithHeader(cnsMsg *consensus.Message) error { return nil } +func (wrk *Worker) checkHeaderPreviousProof(header data.HeaderHandler) error { + if wrk.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + return fmt.Errorf("%w : received header on consensus topic after equivalent messages activation", ErrConsensusMessageNotExpected) + } + + if !check.IfNilReflect(header.GetPreviousProof()) { + return fmt.Errorf("%w : received header from consensus topic has previous proof", ErrHeaderProofNotExpected) + } + + return nil +} + func (wrk *Worker) verifyHeaderHash(hash []byte, marshalledHeader []byte) bool { computedHash := wrk.hasher.Compute(string(marshalledHeader)) return bytes.Equal(hash, computedHash) @@ -580,7 +624,7 @@ func (wrk *Worker) checkSelfState(cnsDta *consensus.Message) error { return ErrMessageFromItself } - if wrk.consensusState.RoundCanceled && wrk.consensusState.RoundIndex == cnsDta.RoundIndex { + if wrk.consensusState.GetRoundCanceled() && wrk.consensusState.GetRoundIndex() == cnsDta.RoundIndex { return ErrRoundCanceled } @@ -616,7 +660,7 @@ func (wrk *Worker) executeMessage(cnsDtaList []*consensus.Message) { if cnsDta == nil { continue } - if wrk.consensusState.RoundIndex != cnsDta.RoundIndex { + if wrk.consensusState.GetRoundIndex() != cnsDta.RoundIndex { continue } @@ -652,11 +696,13 @@ func (wrk *Worker) checkChannels(ctx context.Context) { msgType := consensus.MessageType(rcvDta.MsgType) - if callReceivedMessage, exist := wrk.receivedMessagesCalls[msgType]; exist { - if callReceivedMessage(ctx, rcvDta) { - select { - case wrk.consensusStateChangedChannel <- true: - default: + if receivedMessageCallbacks, exist := wrk.receivedMessagesCalls[msgType]; exist { + for _, callReceivedMessage := range receivedMessageCallbacks { + if callReceivedMessage(ctx, rcvDta) { + select { + case wrk.consensusStateChangedChannel <- true: + default: + } } } } @@ -665,7 +711,7 @@ func (wrk *Worker) checkChannels(ctx context.Context) { // Extend does an extension for the subround with subroundId func (wrk *Worker) Extend(subroundId int) { - wrk.consensusState.ExtendedCalled = true + wrk.consensusState.SetExtendedCalled(true) log.Debug("extend function is called", "subround", wrk.consensusService.GetSubroundName(subroundId)) @@ -732,11 +778,25 @@ func (wrk *Worker) Close() error { return nil } -// ResetConsensusMessages resets at the start of each round all the previous consensus messages received +// ResetConsensusMessages resets at the start of each round all the previous consensus messages received and equivalent messages, keeping the provided proofs func (wrk *Worker) ResetConsensusMessages() { wrk.consensusMessageValidator.resetConsensusMessages() } +func (wrk *Worker) checkValidityAndProcessFinalInfo(cnsMsg *consensus.Message, p2pMessage p2p.MessageP2P) error { + msgType := consensus.MessageType(cnsMsg.MsgType) + + log.Trace("received message from consensus topic", + "msg type", wrk.consensusService.GetStringValue(msgType), + "from", cnsMsg.PubKey, + "header hash", cnsMsg.BlockHeaderHash, + "round", cnsMsg.RoundIndex, + "size", len(p2pMessage.Data()), + ) + + return wrk.consensusMessageValidator.checkConsensusMessageValidity(cnsMsg, p2pMessage.Peer()) +} + // IsInterfaceNil returns true if there is no value under the interface func (wrk *Worker) IsInterfaceNil() bool { return wrk == nil diff --git a/consensus/spos/worker_test.go b/consensus/spos/worker_test.go index b179fdf0db8..ef00af26c2e 100644 --- 
a/consensus/spos/worker_test.go +++ b/consensus/spos/worker_test.go @@ -27,8 +27,13 @@ import ( "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" + "github.com/multiversx/mx-chain-go/testscommon/cache" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/multiversx/mx-chain-go/testscommon/processMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" ) @@ -58,14 +63,14 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke return nil }, } - bootstrapperMock := &mock.BootstrapperStub{} - broadcastMessengerMock := &mock.BroadcastMessengerMock{} + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} + broadcastMessengerMock := &consensusMocks.BroadcastMessengerMock{} consensusState := initConsensusState() - forkDetectorMock := &mock.ForkDetectorMock{} + forkDetectorMock := &processMocks.ForkDetectorStub{} forkDetectorMock.AddHeaderCalled = func(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { return nil } - keyGeneratorMock, _, _ := mock.InitKeys() + keyGeneratorMock, _, _ := consensusMocks.InitKeys() marshalizerMock := mock.MarshalizerMock{} roundHandlerMock := initRoundHandlerMock() shardCoordinatorMock := mock.ShardCoordinatorMock{} @@ -77,10 +82,10 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke return nil }, } - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} hasher := &hashingMocks.HasherMock{} blsService, _ := bls.NewConsensusService() - poolAdder := testscommon.NewCacherMock() + poolAdder := cache.NewCacherMock() scheduledProcessorArgs := spos.ScheduledProcessorWrapperArgs{ SyncTimer: syncTimerMock, @@ -90,6 +95,7 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke scheduledProcessor, _ := spos.NewScheduledProcessorWrapper(scheduledProcessorArgs) peerSigHandler := &mock.PeerSignatureHandler{Signer: singleSignerMock, KeyGen: keyGeneratorMock} + workerArgs := &spos.WorkerArgs{ ConsensusService: blsService, BlockChain: blockchainMock, @@ -105,8 +111,8 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke ShardCoordinator: shardCoordinatorMock, PeerSignatureHandler: peerSigHandler, SyncTimer: syncTimerMock, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + HeaderSigVerifier: &consensusMocks.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &testscommon.HeaderVersionHandlerStub{}, ChainID: chainID, NetworkShardingCollector: &p2pmocks.NetworkShardingCollectorStub{}, AntifloodHandler: createMockP2PAntifloodHandler(), @@ -116,6 +122,7 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke AppStatusHandler: appStatusHandler, NodeRedundancyHandler: &mock.NodeRedundancyHandlerStub{}, PeerBlacklistHandler: &mock.PeerBlacklistHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } return workerArgs @@ -136,11 +143,13 @@ func initWorker(appStatusHandler core.AppStatusHandler) 
*spos.Worker { workerArgs := createDefaultWorkerArgs(appStatusHandler) sposWorker, _ := spos.NewWorker(workerArgs) + sposWorker.ConsensusState().Header = &block.HeaderV2{} + return sposWorker } -func initRoundHandlerMock() *mock.RoundHandlerMock { - return &mock.RoundHandlerMock{ +func initRoundHandlerMock() *consensusMocks.RoundHandlerMock { + return &consensusMocks.RoundHandlerMock{ RoundIndex: 0, TimeStampCalled: func() time.Time { return time.Unix(0, 0) @@ -370,6 +379,17 @@ func TestWorker_NewWorkerNodeRedundancyHandlerShouldFail(t *testing.T) { assert.Equal(t, spos.ErrNilNodeRedundancyHandler, err) } +func TestWorker_NewWorkerPoolEnableEpochsHandlerNilShouldFail(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + workerArgs.EnableEpochsHandler = nil + wrk, err := spos.NewWorker(workerArgs) + + assert.Nil(t, wrk) + assert.Equal(t, spos.ErrNilEnableEpochsHandler, err) +} + func TestWorker_NewWorkerShouldWork(t *testing.T) { t.Parallel() @@ -466,7 +486,7 @@ func TestWorker_AddReceivedMessageCallShouldWork(t *testing.T) { assert.Equal(t, 1, len(receivedMessageCalls)) assert.NotNil(t, receivedMessageCalls[bls.MtBlockBody]) - assert.True(t, receivedMessageCalls[bls.MtBlockBody](context.Background(), nil)) + assert.True(t, receivedMessageCalls[bls.MtBlockBody][0](context.Background(), nil)) } func TestWorker_RemoveAllReceivedMessageCallsShouldWork(t *testing.T) { @@ -480,7 +500,7 @@ func TestWorker_RemoveAllReceivedMessageCallsShouldWork(t *testing.T) { assert.Equal(t, 1, len(receivedMessageCalls)) assert.NotNil(t, receivedMessageCalls[bls.MtBlockBody]) - assert.True(t, receivedMessageCalls[bls.MtBlockBody](context.Background(), nil)) + assert.True(t, receivedMessageCalls[bls.MtBlockBody][0](context.Background(), nil)) wrk.RemoveAllReceivedMessagesCalls() receivedMessageCalls = wrk.ReceivedMessagesCalls() @@ -765,7 +785,7 @@ func testWorkerProcessReceivedMessageComputeReceivedProposedBlockMetric( }, }) - wrk.SetRoundHandler(&mock.RoundHandlerMock{ + wrk.SetRoundHandler(&consensusMocks.RoundHandlerMock{ RoundIndex: 0, TimeDurationCalled: func() time.Duration { return roundDuration @@ -793,7 +813,7 @@ func testWorkerProcessReceivedMessageComputeReceivedProposedBlockMetric( nil, nil, hdrStr, - []byte(wrk.ConsensusState().ConsensusGroup()[0]), + []byte(wrk.ConsensusState().Leader()), signature, int(bls.MtBlockHeader), 0, @@ -1258,6 +1278,7 @@ func TestWorker_ProcessReceivedMessageWithHeaderAndWrongHash(t *testing.T) { workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().Header = &block.HeaderV2{} wrk.SetBlockProcessor( &testscommon.BlockProcessorStub{ @@ -1327,6 +1348,7 @@ func TestWorker_ProcessReceivedMessageOkValsShouldWork(t *testing.T) { }, } wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().Header = &block.HeaderV2{} wrk.SetBlockProcessor( &testscommon.BlockProcessorStub{ @@ -1671,7 +1693,7 @@ func TestWorker_CheckChannelsShouldWork(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) wrk.StartWorking() - wrk.SetReceivedMessagesCalls(bls.MtBlockHeader, func(ctx context.Context, cnsMsg *consensus.Message) bool { + wrk.AppendReceivedMessagesCalls(bls.MtBlockHeader, func(ctx context.Context, cnsMsg *consensus.Message) bool { _ = wrk.ConsensusState().SetJobDone(wrk.ConsensusState().ConsensusGroup()[0], bls.SrBlock, true) return true }) @@ -1713,7 +1735,7 @@ func 
TestWorker_ExtendShouldReturnWhenRoundIsCanceled(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) executed := false - bootstrapperMock := &mock.BootstrapperStub{ + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{ GetNodeStateCalled: func() common.NodeState { return common.NsNotSynchronized }, @@ -1733,7 +1755,7 @@ func TestWorker_ExtendShouldReturnWhenGetNodeStateNotReturnSynchronized(t *testi t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) executed := false - bootstrapperMock := &mock.BootstrapperStub{ + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{ GetNodeStateCalled: func() common.NodeState { return common.NsNotSynchronized }, @@ -1752,14 +1774,14 @@ func TestWorker_ExtendShouldReturnWhenCreateEmptyBlockFail(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) executed := false - bmm := &mock.BroadcastMessengerMock{ + bmm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { executed = true return nil }, } wrk.SetBroadcastMessenger(bmm) - bootstrapperMock := &mock.BootstrapperStub{ + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{ CreateAndCommitEmptyBlockCalled: func(shardForCurrentNode uint32) (data.BodyHandler, data.HeaderHandler, error) { return nil, nil, errors.New("error") }} @@ -1863,13 +1885,14 @@ func TestWorker_ProcessReceivedMessageWrongHeaderShouldErr(t *testing.T) { t.Parallel() workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) - headerSigVerifier := &mock.HeaderSigVerifierStub{} + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{} headerSigVerifier.VerifyRandSeedCalled = func(header data.HeaderHandler) error { return process.ErrRandSeedDoesNotMatch } workerArgs.HeaderSigVerifier = headerSigVerifier wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().Header = &block.HeaderV2{} hdr := &block.Header{} hdr.Nonce = 1 @@ -1911,6 +1934,7 @@ func TestWorker_ProcessReceivedMessageWithSignature(t *testing.T) { workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().Header = &block.HeaderV2{} hdr := &block.Header{} hdr.Nonce = 1 diff --git a/dataRetriever/blockchain/baseBlockchain_test.go b/dataRetriever/blockchain/baseBlockchain_test.go index 3f6121b6a07..69a49304db0 100644 --- a/dataRetriever/blockchain/baseBlockchain_test.go +++ b/dataRetriever/blockchain/baseBlockchain_test.go @@ -8,6 +8,8 @@ import ( ) func TestBaseBlockchain_SetAndGetSetFinalBlockInfo(t *testing.T) { + t.Parallel() + base := &baseBlockChain{ appStatusHandler: &mock.AppStatusHandlerStub{}, finalBlockInfo: &blockInfo{}, @@ -26,6 +28,8 @@ func TestBaseBlockchain_SetAndGetSetFinalBlockInfo(t *testing.T) { } func TestBaseBlockchain_SetAndGetSetFinalBlockInfoWorksWithNilValues(t *testing.T) { + t.Parallel() + base := &baseBlockChain{ appStatusHandler: &mock.AppStatusHandlerStub{}, finalBlockInfo: &blockInfo{}, diff --git a/dataRetriever/dataPool/dataPool.go b/dataRetriever/dataPool/dataPool.go index 67b55cbfaee..be759b15b43 100644 --- a/dataRetriever/dataPool/dataPool.go +++ b/dataRetriever/dataPool/dataPool.go @@ -26,6 +26,7 @@ type dataPool struct { peerAuthentications storage.Cacher heartbeats storage.Cacher validatorsInfo dataRetriever.ShardedDataCacherNotifier + proofs dataRetriever.ProofsPool } // DataPoolArgs represents the data pool's constructor structure @@ -44,6 
+45,7 @@ type DataPoolArgs struct { PeerAuthentications storage.Cacher Heartbeats storage.Cacher ValidatorsInfo dataRetriever.ShardedDataCacherNotifier + Proofs dataRetriever.ProofsPool } // NewDataPool creates a data pools holder object @@ -90,6 +92,9 @@ func NewDataPool(args DataPoolArgs) (*dataPool, error) { if check.IfNil(args.ValidatorsInfo) { return nil, dataRetriever.ErrNilValidatorInfoPool } + if check.IfNil(args.Proofs) { + return nil, dataRetriever.ErrNilProofsPool + } return &dataPool{ transactions: args.Transactions, @@ -106,6 +111,7 @@ func NewDataPool(args DataPoolArgs) (*dataPool, error) { peerAuthentications: args.PeerAuthentications, heartbeats: args.Heartbeats, validatorsInfo: args.ValidatorsInfo, + proofs: args.Proofs, }, nil } @@ -179,6 +185,11 @@ func (dp *dataPool) ValidatorsInfo() dataRetriever.ShardedDataCacherNotifier { return dp.validatorsInfo } +// Proofs returns the holder for equivalent proofs +func (dp *dataPool) Proofs() dataRetriever.ProofsPool { + return dp.proofs +} + // Close closes all the components func (dp *dataPool) Close() error { var lastError error diff --git a/dataRetriever/dataPool/dataPool_test.go b/dataRetriever/dataPool/dataPool_test.go index b948b7f2d44..9a8f17181e3 100644 --- a/dataRetriever/dataPool/dataPool_test.go +++ b/dataRetriever/dataPool/dataPool_test.go @@ -8,11 +8,14 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -//------- NewDataPool +// ------- NewDataPool func createMockDataPoolArgs() dataPool.DataPoolArgs { return dataPool.DataPoolArgs{ @@ -20,16 +23,17 @@ func createMockDataPoolArgs() dataPool.DataPoolArgs { UnsignedTransactions: testscommon.NewShardedDataStub(), RewardTransactions: testscommon.NewShardedDataStub(), Headers: &mock.HeadersCacherStub{}, - MiniBlocks: testscommon.NewCacherStub(), - PeerChangesBlocks: testscommon.NewCacherStub(), - TrieNodes: testscommon.NewCacherStub(), - TrieNodesChunks: testscommon.NewCacherStub(), + MiniBlocks: cache.NewCacherStub(), + PeerChangesBlocks: cache.NewCacherStub(), + TrieNodes: cache.NewCacherStub(), + TrieNodesChunks: cache.NewCacherStub(), CurrentBlockTransactions: &mock.TxForCurrentBlockStub{}, CurrentEpochValidatorInfo: &mock.ValidatorInfoForCurrentEpochStub{}, - SmartContracts: testscommon.NewCacherStub(), - PeerAuthentications: testscommon.NewCacherStub(), - Heartbeats: testscommon.NewCacherStub(), + SmartContracts: cache.NewCacherStub(), + PeerAuthentications: cache.NewCacherStub(), + Heartbeats: cache.NewCacherStub(), ValidatorsInfo: testscommon.NewShardedDataStub(), + Proofs: &dataRetrieverMocks.ProofsPoolMock{}, } } @@ -195,7 +199,7 @@ func TestNewDataPool_OkValsShouldWork(t *testing.T) { assert.Nil(t, err) require.False(t, tdp.IsInterfaceNil()) - //pointer checking + // pointer checking assert.True(t, args.Transactions == tdp.Transactions()) assert.True(t, args.UnsignedTransactions == tdp.UnsignedTransactions()) assert.True(t, args.RewardTransactions == tdp.RewardTransactions()) @@ -220,7 +224,7 @@ func TestNewDataPool_Close(t *testing.T) { t.Parallel() args := createMockDataPoolArgs() - args.TrieNodes = &testscommon.CacherStub{ + args.TrieNodes = &cache.CacherStub{ CloseCalled: func() error { return expectedErr }, @@ -234,7 
+238,7 @@ func TestNewDataPool_Close(t *testing.T) { t.Parallel() args := createMockDataPoolArgs() - args.PeerAuthentications = &testscommon.CacherStub{ + args.PeerAuthentications = &cache.CacherStub{ CloseCalled: func() error { return expectedErr }, @@ -251,13 +255,13 @@ func TestNewDataPool_Close(t *testing.T) { paExpectedErr := errors.New("pa expected error") args := createMockDataPoolArgs() tnCalled, paCalled := false, false - args.TrieNodes = &testscommon.CacherStub{ + args.TrieNodes = &cache.CacherStub{ CloseCalled: func() error { tnCalled = true return tnExpectedErr }, } - args.PeerAuthentications = &testscommon.CacherStub{ + args.PeerAuthentications = &cache.CacherStub{ CloseCalled: func() error { paCalled = true return paExpectedErr @@ -275,13 +279,13 @@ func TestNewDataPool_Close(t *testing.T) { args := createMockDataPoolArgs() tnCalled, paCalled := false, false - args.TrieNodes = &testscommon.CacherStub{ + args.TrieNodes = &cache.CacherStub{ CloseCalled: func() error { tnCalled = true return nil }, } - args.PeerAuthentications = &testscommon.CacherStub{ + args.PeerAuthentications = &cache.CacherStub{ CloseCalled: func() error { paCalled = true return nil diff --git a/dataRetriever/dataPool/proofsCache/errors.go b/dataRetriever/dataPool/proofsCache/errors.go new file mode 100644 index 00000000000..630dd8cc394 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/errors.go @@ -0,0 +1,12 @@ +package proofscache + +import "errors" + +// ErrMissingProof signals that the proof is missing +var ErrMissingProof = errors.New("missing proof") + +// ErrNilProof signals that a nil proof has been provided +var ErrNilProof = errors.New("nil proof provided") + +// ErrAlreadyExistingEquivalentProof signals that the provided proof already exists in the pool +var ErrAlreadyExistingEquivalentProof = errors.New("already existing equivalent proof") diff --git a/dataRetriever/dataPool/proofsCache/proofsCache.go b/dataRetriever/dataPool/proofsCache/proofsCache.go new file mode 100644 index 00000000000..2bce293b034 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsCache.go @@ -0,0 +1,81 @@ +package proofscache + +import ( + "sort" + "sync" + + "github.com/multiversx/mx-chain-core-go/data" +) + +type proofNonceMapping struct { + headerHash string + nonce uint64 +} + +type proofsCache struct { + mutProofsCache sync.RWMutex + proofsByNonce []*proofNonceMapping + proofsByHash map[string]data.HeaderProofHandler +} + +func newProofsCache() *proofsCache { + return &proofsCache{ + mutProofsCache: sync.RWMutex{}, + proofsByNonce: make([]*proofNonceMapping, 0), + proofsByHash: make(map[string]data.HeaderProofHandler), + } +} + +func (pc *proofsCache) getProofByHash(headerHash []byte) (data.HeaderProofHandler, error) { + pc.mutProofsCache.RLock() + defer pc.mutProofsCache.RUnlock() + + proof, ok := pc.proofsByHash[string(headerHash)] + if !ok { + return nil, ErrMissingProof + } + + return proof, nil +} + +func (pc *proofsCache) addProof(proof data.HeaderProofHandler) { + if proof == nil { + return + } + + pc.mutProofsCache.Lock() + defer pc.mutProofsCache.Unlock() + + pc.proofsByNonce = append(pc.proofsByNonce, &proofNonceMapping{ + headerHash: string(proof.GetHeaderHash()), + nonce: proof.GetHeaderNonce(), + }) + + sort.Slice(pc.proofsByNonce, func(i, j int) bool { + return pc.proofsByNonce[i].nonce < pc.proofsByNonce[j].nonce + }) + + pc.proofsByHash[string(proof.GetHeaderHash())] = proof +} + +func (pc *proofsCache) cleanupProofsBehindNonce(nonce uint64) { + if nonce == 0 { + return + } + + 
pc.mutProofsCache.Lock() + defer pc.mutProofsCache.Unlock() + + proofsByNonce := make([]*proofNonceMapping, 0) + + for _, proofInfo := range pc.proofsByNonce { + if proofInfo.nonce < nonce { + delete(pc.proofsByHash, proofInfo.headerHash) + continue + } + + proofsByNonce = append(proofsByNonce, proofInfo) + } + + pc.proofsByNonce = proofsByNonce +} diff --git a/dataRetriever/dataPool/proofsCache/proofsPool.go b/dataRetriever/dataPool/proofsCache/proofsPool.go new file mode 100644 index 00000000000..a412794a6db --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsPool.go @@ -0,0 +1,153 @@ +package proofscache + +import ( + "encoding/hex" + "fmt" + "sync" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("dataRetriever/proofscache") + +type proofsPool struct { + mutCache sync.RWMutex + cache map[uint32]*proofsCache + + mutAddedProofSubscribers sync.RWMutex + addedProofSubscribers []func(headerProof data.HeaderProofHandler) +} + +// NewProofsPool creates a new proofs pool component +func NewProofsPool() *proofsPool { + return &proofsPool{ + cache: make(map[uint32]*proofsCache), + addedProofSubscribers: make([]func(headerProof data.HeaderProofHandler), 0), + } +} + +// AddProof will add the provided proof to the pool +func (pp *proofsPool) AddProof( + headerProof data.HeaderProofHandler, +) error { + if check.IfNilReflect(headerProof) { + return ErrNilProof + } + + shardID := headerProof.GetHeaderShardId() + headerHash := headerProof.GetHeaderHash() + + hasProof := pp.HasProof(shardID, headerHash) + if hasProof { + return fmt.Errorf("%w, headerHash: %s", ErrAlreadyExistingEquivalentProof, hex.EncodeToString(headerHash)) + } + + pp.mutCache.Lock() + defer pp.mutCache.Unlock() + + proofsPerShard, ok := pp.cache[shardID] + if !ok { + proofsPerShard = newProofsCache() + pp.cache[shardID] = proofsPerShard + } + + log.Trace("added proof to pool", + "header hash", headerProof.GetHeaderHash(), + "epoch", headerProof.GetHeaderEpoch(), + "nonce", headerProof.GetHeaderNonce(), + "shardID", headerProof.GetHeaderShardId(), + "pubKeys bitmap", headerProof.GetPubKeysBitmap(), + ) + + proofsPerShard.addProof(headerProof) + + pp.callAddedProofSubscribers(headerProof) + + return nil +} + +func (pp *proofsPool) callAddedProofSubscribers(headerProof data.HeaderProofHandler) { + pp.mutAddedProofSubscribers.RLock() + defer pp.mutAddedProofSubscribers.RUnlock() + + for _, handler := range pp.addedProofSubscribers { + go handler(headerProof) + } +} + +// CleanupProofsBehindNonce will cleanup proofs from pool based on nonce +func (pp *proofsPool) CleanupProofsBehindNonce(shardID uint32, nonce uint64) error { + if nonce == 0 { + return nil + } + + pp.mutCache.RLock() + defer pp.mutCache.RUnlock() + + proofsPerShard, ok := pp.cache[shardID] + if !ok { + return fmt.Errorf("%w: proofs cache per shard not found, shard ID: %d", ErrMissingProof, shardID) + } + + log.Trace("cleanup proofs behind nonce", + "nonce", nonce, + "shardID", shardID, + ) + + proofsPerShard.cleanupProofsBehindNonce(nonce) + + return nil +} + +// GetProof will get the proof from pool +func (pp *proofsPool) GetProof( + shardID uint32, + headerHash []byte, +) (data.HeaderProofHandler, error) { + if headerHash == nil { + return nil, fmt.Errorf("nil header hash") + } + + pp.mutCache.RLock() + defer pp.mutCache.RUnlock() + + log.Trace("trying to get proof", + "headerHash", headerHash, + "shardID", shardID, + ) 
+ + proofsPerShard, ok := pp.cache[shardID] + if !ok { + return nil, fmt.Errorf("%w: proofs cache per shard not found, shard ID: %d", ErrMissingProof, shardID) + } + + return proofsPerShard.getProofByHash(headerHash) +} + +// HasProof will check if there is a proof for the provided hash +func (pp *proofsPool) HasProof( + shardID uint32, + headerHash []byte, +) bool { + _, err := pp.GetProof(shardID, headerHash) + return err == nil +} + +// RegisterHandler registers a new handler to be called when a new data is added +func (pp *proofsPool) RegisterHandler(handler func(headerProof data.HeaderProofHandler)) { + if handler == nil { + log.Error("attempt to register a nil handler to proofs pool") + return + } + + pp.mutAddedProofSubscribers.Lock() + pp.addedProofSubscribers = append(pp.addedProofSubscribers, handler) + pp.mutAddedProofSubscribers.Unlock() +} + +// IsInterfaceNil returns true if there is no value under the interface +func (pp *proofsPool) IsInterfaceNil() bool { + return pp == nil +} diff --git a/dataRetriever/dataPool/proofsCache/proofsPool_test.go b/dataRetriever/dataPool/proofsCache/proofsPool_test.go new file mode 100644 index 00000000000..c4e373eeba7 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsPool_test.go @@ -0,0 +1,177 @@ +package proofscache_test + +import ( + "crypto/rand" + "errors" + "math/big" + "sync" + "sync/atomic" + "testing" + + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewProofsPool(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool() + require.False(t, pp.IsInterfaceNil()) +} + +func TestProofsPool_ShouldWork(t *testing.T) { + t.Parallel() + + shardID := uint32(1) + + pp := proofscache.NewProofsPool() + + proof1 := &block.HeaderProof{ + PubKeysBitmap: []byte("pubKeysBitmap1"), + AggregatedSignature: []byte("aggSig1"), + HeaderHash: []byte("hash1"), + HeaderEpoch: 1, + HeaderNonce: 1, + HeaderShardId: shardID, + } + proof2 := &block.HeaderProof{ + PubKeysBitmap: []byte("pubKeysBitmap2"), + AggregatedSignature: []byte("aggSig2"), + HeaderHash: []byte("hash2"), + HeaderEpoch: 1, + HeaderNonce: 2, + HeaderShardId: shardID, + } + proof3 := &block.HeaderProof{ + PubKeysBitmap: []byte("pubKeysBitmap3"), + AggregatedSignature: []byte("aggSig3"), + HeaderHash: []byte("hash3"), + HeaderEpoch: 1, + HeaderNonce: 3, + HeaderShardId: shardID, + } + proof4 := &block.HeaderProof{ + PubKeysBitmap: []byte("pubKeysBitmap4"), + AggregatedSignature: []byte("aggSig4"), + HeaderHash: []byte("hash4"), + HeaderEpoch: 1, + HeaderNonce: 4, + HeaderShardId: shardID, + } + _ = pp.AddProof(proof1) + _ = pp.AddProof(proof2) + _ = pp.AddProof(proof3) + _ = pp.AddProof(proof4) + + err := pp.AddProof(proof4) + require.True(t, errors.Is(err, proofscache.ErrAlreadyExistingEquivalentProof)) + + proof, err := pp.GetProof(shardID, []byte("hash3")) + require.Nil(t, err) + require.Equal(t, proof3, proof) + + err = pp.CleanupProofsBehindNonce(shardID, 4) + require.Nil(t, err) + + proof, err = pp.GetProof(shardID, []byte("hash3")) + require.Equal(t, proofscache.ErrMissingProof, err) + require.Nil(t, proof) + + proof, err = pp.GetProof(shardID, []byte("hash4")) + require.Nil(t, err) + require.Equal(t, proof4, proof) +} + +func TestProofsPool_RegisterHandler(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool() + + wasCalled 
:= false + wg := sync.WaitGroup{} + wg.Add(1) + handler := func(proof data.HeaderProofHandler) { + wasCalled = true + wg.Done() + } + pp.RegisterHandler(nil) + pp.RegisterHandler(handler) + + _ = pp.AddProof(generateProof()) + + wg.Wait() + + assert.True(t, wasCalled) +} + +func TestProofsPool_Concurrency(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool() + + numOperations := 1000 + + wg := sync.WaitGroup{} + wg.Add(numOperations) + + cnt := uint32(0) + + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx % 6 { + case 0, 1, 2: + _ = pp.AddProof(generateProof()) + case 3: + _, err := pp.GetProof(generateRandomShardID(), generateRandomHash()) + if errors.Is(err, proofscache.ErrMissingProof) { + atomic.AddUint32(&cnt, 1) + } + case 4: + _ = pp.CleanupProofsBehindNonce(generateRandomShardID(), generateRandomNonce()) + case 5: + handler := func(proof data.HeaderProofHandler) { + } + pp.RegisterHandler(handler) + default: + assert.Fail(t, "should not have been called") + } + + wg.Done() + }(i) + } + + wg.Wait() + + require.GreaterOrEqual(t, uint32(numOperations/3), atomic.LoadUint32(&cnt)) +} + +func generateProof() *block.HeaderProof { + return &block.HeaderProof{ + HeaderHash: generateRandomHash(), + HeaderEpoch: 1, + HeaderNonce: generateRandomNonce(), + HeaderShardId: generateRandomShardID(), + } +} + +func generateRandomHash() []byte { + hashSuffix := generateRandomInt(100) + hash := []byte("hash_" + hashSuffix.String()) + return hash +} + +func generateRandomNonce() uint64 { + val := generateRandomInt(3) + return val.Uint64() +} + +func generateRandomShardID() uint32 { + val := generateRandomInt(3) + return uint32(val.Uint64()) +} + +func generateRandomInt(max int64) *big.Int { + rantInt, _ := rand.Int(rand.Reader, big.NewInt(max)) + return rantInt +} diff --git a/dataRetriever/errors.go b/dataRetriever/errors.go index a015e6e10ed..21465bf26c7 100644 --- a/dataRetriever/errors.go +++ b/dataRetriever/errors.go @@ -265,3 +265,6 @@ var ErrNilValidatorInfoStorage = errors.New("nil validator info storage") // ErrValidatorInfoNotFound signals that no validator info was found var ErrValidatorInfoNotFound = errors.New("validator info not found") + +// ErrNilProofsPool signals that a nil proofs pool has been provided +var ErrNilProofsPool = errors.New("nil proofs pool") diff --git a/dataRetriever/factory/dataPoolFactory.go b/dataRetriever/factory/dataPoolFactory.go index 6e1415ddfd8..b9651bf3d6a 100644 --- a/dataRetriever/factory/dataPoolFactory.go +++ b/dataRetriever/factory/dataPoolFactory.go @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/headersCache" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/dataRetriever/shardedData" "github.com/multiversx/mx-chain-go/dataRetriever/txpool" "github.com/multiversx/mx-chain-go/process" @@ -146,8 +147,10 @@ func NewDataPoolFromConfig(args ArgsDataPool) (dataRetriever.PoolsHolder, error) return nil, fmt.Errorf("%w while creating the cache for the validator info results", err) } + proofsPool := proofscache.NewProofsPool() currBlockTransactions := dataPool.NewCurrentBlockTransactionsPool() currEpochValidatorInfo := dataPool.NewCurrentEpochValidatorInfoPool() + dataPoolArgs := dataPool.DataPoolArgs{ Transactions: txPool, UnsignedTransactions: uTxPool, @@ -163,6 +166,7 @@ func NewDataPoolFromConfig(args ArgsDataPool) 
(dataRetriever.PoolsHolder, error) PeerAuthentications: peerAuthPool, Heartbeats: heartbeatPool, ValidatorsInfo: validatorsInfo, + Proofs: proofsPool, } return dataPool.NewDataPool(dataPoolArgs) } diff --git a/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go b/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go index 755672384cd..2891e3f8888 100644 --- a/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go +++ b/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" @@ -15,11 +17,11 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" - "github.com/stretchr/testify/assert" ) func createStubMessengerForMeta(matchStrToErrOnCreate string, matchStrToErrOnRegister string) p2p.Messenger { @@ -56,7 +58,7 @@ func createDataPoolsForMeta() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() diff --git a/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go b/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go index ca97015f3ae..4c144ebb034 100644 --- a/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go +++ b/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" @@ -15,11 +17,11 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" - "github.com/stretchr/testify/assert" ) var errExpected = errors.New("expected error") @@ -63,10 +65,10 @@ func createDataPoolsForShard() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} } pools.MiniBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.PeerChangesBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.UnsignedTransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return 
testscommon.NewShardedDataStub() diff --git a/dataRetriever/interface.go b/dataRetriever/interface.go index 930b6aca124..ade580bd985 100644 --- a/dataRetriever/interface.go +++ b/dataRetriever/interface.go @@ -240,6 +240,7 @@ type PoolsHolder interface { PeerAuthentications() storage.Cacher Heartbeats() storage.Cacher ValidatorsInfo() ShardedDataCacherNotifier + Proofs() ProofsPool Close() error IsInterfaceNil() bool } @@ -357,3 +358,12 @@ type PeerAuthenticationPayloadValidator interface { ValidateTimestamp(payloadTimestamp int64) error IsInterfaceNil() bool } + +// ProofsPool defines the behaviour of a proofs pool component +type ProofsPool interface { + AddProof(headerProof data.HeaderProofHandler) error + CleanupProofsBehindNonce(shardID uint32, nonce uint64) error + GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + HasProof(shardID uint32, headerHash []byte) bool + IsInterfaceNil() bool +} diff --git a/dataRetriever/provider/miniBlocks_test.go b/dataRetriever/provider/miniBlocks_test.go index dc0e4f206e8..271d8ef55e6 100644 --- a/dataRetriever/provider/miniBlocks_test.go +++ b/dataRetriever/provider/miniBlocks_test.go @@ -8,14 +8,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/provider" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockMiniblockProviderArgs( @@ -37,7 +38,7 @@ func createMockMiniblockProviderArgs( return nil, fmt.Errorf("not found") }, }, - MiniBlockPool: &testscommon.CacherStub{ + MiniBlockPool: &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if isByteSliceInSlice(key, dataPoolExistingHashes) { return &dataBlock.MiniBlock{}, true @@ -105,7 +106,7 @@ func TestNewMiniBlockProvider_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- GetMiniBlocksFromPool +// ------- GetMiniBlocksFromPool func TestMiniBlockProvider_GetMiniBlocksFromPoolFoundInPoolShouldReturn(t *testing.T) { t.Parallel() @@ -140,7 +141,7 @@ func TestMiniBlockProvider_GetMiniBlocksFromPoolWrongTypeInPoolShouldNotReturn(t hashes := [][]byte{[]byte("hash1"), []byte("hash2")} arg := createMockMiniblockProviderArgs(hashes, nil) - arg.MiniBlockPool = &testscommon.CacherStub{ + arg.MiniBlockPool = &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return "not a miniblock", true }, @@ -153,7 +154,7 @@ func TestMiniBlockProvider_GetMiniBlocksFromPoolWrongTypeInPoolShouldNotReturn(t assert.Equal(t, hashes, missingHashes) } -//------- GetMiniBlocks +// ------- GetMiniBlocks func TestMiniBlockProvider_GetMiniBlocksFoundInPoolShouldReturn(t *testing.T) { t.Parallel() diff --git a/dataRetriever/resolvers/miniblockResolver_test.go b/dataRetriever/resolvers/miniblockResolver_test.go index 35588e9d6a9..6bacadd6861 100644 --- a/dataRetriever/resolvers/miniblockResolver_test.go +++ b/dataRetriever/resolvers/miniblockResolver_test.go @@ -9,14 +9,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" 
"github.com/multiversx/mx-chain-core-go/data/batch" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/resolvers" "github.com/multiversx/mx-chain-go/p2p" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" ) var fromConnectedPeerId = core.PeerID("from connected peer Id") @@ -24,7 +25,7 @@ var fromConnectedPeerId = core.PeerID("from connected peer Id") func createMockArgMiniblockResolver() resolvers.ArgMiniblockResolver { return resolvers.ArgMiniblockResolver{ ArgBaseResolver: createMockArgBaseResolver(), - MiniBlockPool: testscommon.NewCacherStub(), + MiniBlockPool: cache.NewCacherStub(), MiniBlockStorage: &storageStubs.StorerStub{}, DataPacker: &mock.DataPackerStub{}, } @@ -173,7 +174,7 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolShouldRetValAndSend( wasResolved := false wasSent := false - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, mbHash) { wasResolved = true @@ -232,7 +233,7 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolMarshalizerFailShoul assert.Nil(t, merr) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, mbHash) { return &block.MiniBlock{}, true @@ -286,7 +287,7 @@ func TestMiniblockResolver_ProcessReceivedMessageUnmarshalFails(t *testing.T) { assert.Nil(t, merr) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -331,7 +332,7 @@ func TestMiniblockResolver_ProcessReceivedMessagePackDataInChunksFails(t *testin assert.Nil(t, merr) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -375,7 +376,7 @@ func TestMiniblockResolver_ProcessReceivedMessageSendFails(t *testing.T) { assert.Nil(t, merr) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -420,7 +421,7 @@ func TestMiniblockResolver_ProcessReceivedMessageNotFoundInPoolShouldRetFromStor wasResolved := false wasSend := false - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -467,7 +468,7 @@ func TestMiniblockResolver_ProcessReceivedMessageMarshalFails(t *testing.T) { wasResolved := false - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -519,7 +520,7 @@ func TestMiniblockResolver_ProcessReceivedMessageMissingDataShouldNotSend(t *tes wasSent := false - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } diff --git a/dataRetriever/resolvers/peerAuthenticationResolver_test.go 
b/dataRetriever/resolvers/peerAuthenticationResolver_test.go index 188c29d7e3f..8d6df446772 100644 --- a/dataRetriever/resolvers/peerAuthenticationResolver_test.go +++ b/dataRetriever/resolvers/peerAuthenticationResolver_test.go @@ -13,15 +13,17 @@ import ( "github.com/multiversx/mx-chain-core-go/core/partitioning" "github.com/multiversx/mx-chain-core-go/data/batch" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/resolvers" "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var expectedErr = errors.New("expected error") @@ -57,7 +59,7 @@ func createMockPeerAuthenticationObject() interface{} { func createMockArgPeerAuthenticationResolver() resolvers.ArgPeerAuthenticationResolver { return resolvers.ArgPeerAuthenticationResolver{ ArgBaseResolver: createMockArgBaseResolver(), - PeerAuthenticationPool: testscommon.NewCacherStub(), + PeerAuthenticationPool: cache.NewCacherStub(), DataPacker: &mock.DataPackerStub{}, PayloadValidator: &testscommon.PeerAuthenticationPayloadValidatorStub{}, } @@ -233,7 +235,7 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { t.Run("resolveMultipleHashesRequest: all hashes missing from cache should error", func(t *testing.T) { t.Parallel() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -262,7 +264,7 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { t.Run("resolveMultipleHashesRequest: all hashes will return wrong objects should error", func(t *testing.T) { t.Parallel() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return "wrong object", true } @@ -292,7 +294,7 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { t.Parallel() arg := createMockArgPeerAuthenticationResolver() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return createMockPeerAuthenticationObject(), true } @@ -349,7 +351,7 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { providedHashes, err := arg.Marshaller.Marshal(batch.Batch{Data: hashes}) assert.Nil(t, err) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { val, ok := providedKeys[string(key)] return val, ok @@ -394,7 +396,7 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { t.Run("resolveMultipleHashesRequest: PackDataInChunks returns error", func(t *testing.T) { t.Parallel() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return createMockPeerAuthenticationObject(), true } @@ -419,7 +421,7 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { t.Run("resolveMultipleHashesRequest: Send returns error", func(t 
*testing.T) { t.Parallel() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return createMockPeerAuthenticationObject(), true } @@ -446,7 +448,7 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { providedKeys := getKeysSlice() expectedLen := len(providedKeys) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { for _, pk := range providedKeys { if bytes.Equal(pk, key) { diff --git a/epochStart/bootstrap/common.go b/epochStart/bootstrap/common.go index da6e99fda1b..a6621f86ed8 100644 --- a/epochStart/bootstrap/common.go +++ b/epochStart/bootstrap/common.go @@ -123,6 +123,9 @@ func checkArguments(args ArgsEpochStartBootstrap) error { if check.IfNil(args.NodesCoordinatorRegistryFactory) { return fmt.Errorf("%s: %w", baseErrorMessage, nodesCoordinator.ErrNilNodesCoordinatorRegistryFactory) } + if check.IfNil(args.EnableEpochsHandler) { + return fmt.Errorf("%s: %w", baseErrorMessage, epochStart.ErrNilEnableEpochsHandler) + } return nil } diff --git a/epochStart/bootstrap/disabled/disabledHeaderSigVerifier.go b/epochStart/bootstrap/disabled/disabledHeaderSigVerifier.go index d5de2e34380..e4c4bb14a25 100644 --- a/epochStart/bootstrap/disabled/disabledHeaderSigVerifier.go +++ b/epochStart/bootstrap/disabled/disabledHeaderSigVerifier.go @@ -2,6 +2,7 @@ package disabled import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -15,27 +16,42 @@ func NewHeaderSigVerifier() *headerSigVerifier { return &headerSigVerifier{} } -// VerifyRandSeed - +// VerifyRandSeed returns nil as it is disabled func (h *headerSigVerifier) VerifyRandSeed(_ data.HeaderHandler) error { return nil } -// VerifyLeaderSignature - +// VerifyLeaderSignature returns nil as it is disabled func (h *headerSigVerifier) VerifyLeaderSignature(_ data.HeaderHandler) error { return nil } -// VerifyRandSeedAndLeaderSignature - +// VerifyRandSeedAndLeaderSignature returns nil as it is disabled func (h *headerSigVerifier) VerifyRandSeedAndLeaderSignature(_ data.HeaderHandler) error { return nil } -// VerifySignature - +// VerifySignature returns nil as it is disabled func (h *headerSigVerifier) VerifySignature(_ data.HeaderHandler) error { return nil } -// IsInterfaceNil - +// VerifySignatureForHash returns nil as it is disabled +func (h *headerSigVerifier) VerifySignatureForHash(_ data.HeaderHandler, _ []byte, _ []byte, _ []byte) error { + return nil +} + +// VerifyHeaderWithProof returns nil as it is disabled +func (h *headerSigVerifier) VerifyHeaderWithProof(_ data.HeaderHandler) error { + return nil +} + +// VerifyHeaderProof returns nil as it is disabled +func (h *headerSigVerifier) VerifyHeaderProof(_ data.HeaderProofHandler) error { + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface func (h *headerSigVerifier) IsInterfaceNil() bool { return h == nil } diff --git a/epochStart/bootstrap/disabled/disabledNodesCoordinator.go b/epochStart/bootstrap/disabled/disabledNodesCoordinator.go index e204aec7cc8..16c2dd104be 100644 --- a/epochStart/bootstrap/disabled/disabledNodesCoordinator.go +++ b/epochStart/bootstrap/disabled/disabledNodesCoordinator.go @@ -44,6 +44,11 @@ func (n *nodesCoordinator) GetAllEligibleValidatorsPublicKeys(_ uint32) (map[uin return nil, nil } +// GetAllEligibleValidatorsPublicKeysForShard - +func (n *nodesCoordinator) 
GetAllEligibleValidatorsPublicKeysForShard(_ uint32, _ uint32) ([]string, error) { + return nil, nil +} + // GetAllWaitingValidatorsPublicKeys - func (n *nodesCoordinator) GetAllWaitingValidatorsPublicKeys(_ uint32) (map[uint32][][]byte, error) { return nil, nil @@ -60,8 +65,8 @@ func (n *nodesCoordinator) GetShuffledOutToAuctionValidatorsPublicKeys(_ uint32) } // GetConsensusValidatorsPublicKeys - -func (n *nodesCoordinator) GetConsensusValidatorsPublicKeys(_ []byte, _ uint64, _ uint32, _ uint32) ([]string, error) { - return nil, nil +func (n *nodesCoordinator) GetConsensusValidatorsPublicKeys(_ []byte, _ uint64, _ uint32, _ uint32) (string, []string, error) { + return "", nil, nil } // GetOwnPublicKey - @@ -70,8 +75,8 @@ func (n *nodesCoordinator) GetOwnPublicKey() []byte { } // ComputeConsensusGroup - -func (n *nodesCoordinator) ComputeConsensusGroup(_ []byte, _ uint64, _ uint32, _ uint32) (validatorsGroup []nodesCoord.Validator, err error) { - return nil, nil +func (n *nodesCoordinator) ComputeConsensusGroup(_ []byte, _ uint64, _ uint32, _ uint32) (leader nodesCoord.Validator, validatorsGroup []nodesCoord.Validator, err error) { + return nil, nil, nil } // GetValidatorWithPublicKey - diff --git a/epochStart/bootstrap/epochStartMetaBlockProcessor.go b/epochStart/bootstrap/epochStartMetaBlockProcessor.go index ff1a4370ad7..8ee40232287 100644 --- a/epochStart/bootstrap/epochStartMetaBlockProcessor.go +++ b/epochStart/bootstrap/epochStartMetaBlockProcessor.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" @@ -26,14 +27,18 @@ const minNumConnectedPeers = 1 var _ process.InterceptorProcessor = (*epochStartMetaBlockProcessor)(nil) type epochStartMetaBlockProcessor struct { - messenger Messenger - requestHandler RequestHandler - marshalizer marshal.Marshalizer - hasher hashing.Hasher - mutReceivedMetaBlocks sync.RWMutex - mapReceivedMetaBlocks map[string]data.MetaHeaderHandler - mapMetaBlocksFromPeers map[string][]core.PeerID - chanConsensusReached chan bool + messenger Messenger + requestHandler RequestHandler + marshalizer marshal.Marshalizer + hasher hashing.Hasher + enableEpochsHandler common.EnableEpochsHandler + + mutReceivedMetaBlocks sync.RWMutex + mapReceivedMetaBlocks map[string]data.MetaHeaderHandler + mapMetaBlocksFromPeers map[string][]core.PeerID + + chanConfMetaBlockReached chan bool + chanMetaBlockReached chan bool metaBlock data.MetaHeaderHandler peerCountTarget int minNumConnectedPeers int @@ -49,6 +54,7 @@ func NewEpochStartMetaBlockProcessor( consensusPercentage uint8, minNumConnectedPeersConfig int, minNumOfPeersToConsiderBlockValidConfig int, + enableEpochsHandler common.EnableEpochsHandler, ) (*epochStartMetaBlockProcessor, error) { if check.IfNil(messenger) { return nil, epochStart.ErrNilMessenger @@ -71,6 +77,9 @@ func NewEpochStartMetaBlockProcessor( if minNumOfPeersToConsiderBlockValidConfig < minNumPeersToConsiderMetaBlockValid { return nil, epochStart.ErrNotEnoughNumOfPeersToConsiderBlockValid } + if check.IfNil(enableEpochsHandler) { + return nil, epochStart.ErrNilEnableEpochsHandler + } processor := &epochStartMetaBlockProcessor{ messenger: messenger, @@ -79,10 +88,12 @@ func NewEpochStartMetaBlockProcessor( hasher: hasher, minNumConnectedPeers: 
minNumConnectedPeersConfig, minNumOfPeersToConsiderBlockValid: minNumOfPeersToConsiderBlockValidConfig, + enableEpochsHandler: enableEpochsHandler, mutReceivedMetaBlocks: sync.RWMutex{}, mapReceivedMetaBlocks: make(map[string]data.MetaHeaderHandler), mapMetaBlocksFromPeers: make(map[string][]core.PeerID), - chanConsensusReached: make(chan bool, 1), + chanConfMetaBlockReached: make(chan bool, 1), + chanMetaBlockReached: make(chan bool, 1), } processor.waitForEnoughNumConnectedPeers(messenger) @@ -136,22 +147,47 @@ func (e *epochStartMetaBlockProcessor) Save(data process.InterceptedData, fromCo return nil } - if !metaBlock.IsStartOfEpochBlock() { - log.Debug("received metablock is not of type epoch start", "error", epochStart.ErrNotEpochStartBlock) + mbHash := interceptedHdr.Hash() + + if metaBlock.IsStartOfEpochBlock() { + log.Debug("received epoch start meta", "epoch", metaBlock.GetEpoch(), "from peer", fromConnectedPeer.Pretty()) + e.mutReceivedMetaBlocks.Lock() + e.mapReceivedMetaBlocks[string(mbHash)] = metaBlock + e.addToPeerList(string(mbHash), fromConnectedPeer) + e.mutReceivedMetaBlocks.Unlock() + return nil } - mbHash := interceptedHdr.Hash() + if e.isEpochStartConfirmationBlockWithEquivalentMessages(metaBlock) { + log.Debug("received epoch start confirmation meta", "epoch", metaBlock.GetEpoch(), "from peer", fromConnectedPeer.Pretty()) + e.chanConfMetaBlockReached <- true - log.Debug("received epoch start meta", "epoch", metaBlock.GetEpoch(), "from peer", fromConnectedPeer.Pretty()) - e.mutReceivedMetaBlocks.Lock() - e.mapReceivedMetaBlocks[string(mbHash)] = metaBlock - e.addToPeerList(string(mbHash), fromConnectedPeer) - e.mutReceivedMetaBlocks.Unlock() + return nil + } + + log.Debug("received metablock is not of type epoch start", "error", epochStart.ErrNotEpochStartBlock) return nil } +func (e *epochStartMetaBlockProcessor) isEpochStartConfirmationBlockWithEquivalentMessages(metaBlock data.HeaderHandler) bool { + if !e.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, metaBlock.GetEpoch()) { + return false + } + + startOfEpochMetaBlock, err := e.getMostReceivedMetaBlock() + if err != nil { + return false + } + + if startOfEpochMetaBlock.GetNonce() != metaBlock.GetNonce()-1 { + return false + } + + return true +} + // this func should be called under mutex protection func (e *epochStartMetaBlockProcessor) addToPeerList(hash string, peer core.PeerID) { peersListForHash := e.mapMetaBlocksFromPeers[hash] @@ -180,16 +216,33 @@ func (e *epochStartMetaBlockProcessor) GetEpochStartMetaBlock(ctx context.Contex } }() - err = e.requestMetaBlock() + metaBlock, err := e.waitForMetaBlock(ctx) + if err != nil { + return nil, err + } + + if e.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, metaBlock.GetEpoch()) { + err = e.waitForConfMetaBlock(ctx, metaBlock) + if err != nil { + return nil, err + } + } + + return metaBlock, nil +} + +func (e *epochStartMetaBlockProcessor) waitForMetaBlock(ctx context.Context) (data.MetaHeaderHandler, error) { + err := e.requestMetaBlock() if err != nil { return nil, err } chanRequests := time.After(durationBetweenReRequests) chanCheckMaps := time.After(durationBetweenChecks) + for { select { - case <-e.chanConsensusReached: + case <-e.chanMetaBlockReached: return e.metaBlock, nil case <-ctx.Done(): return e.getMostReceivedMetaBlock() @@ -200,12 +253,40 @@ func (e *epochStartMetaBlockProcessor) GetEpochStartMetaBlock(ctx context.Contex } chanRequests = time.After(durationBetweenReRequests) case <-chanCheckMaps: - 
e.checkMaps() + e.checkMetaBlockMaps() chanCheckMaps = time.After(durationBetweenChecks) } } } +func (e *epochStartMetaBlockProcessor) waitForConfMetaBlock(ctx context.Context, metaBlock data.MetaHeaderHandler) error { + if check.IfNil(metaBlock) { + return epochStart.ErrNilMetaBlock + } + + err := e.requestConfirmationMetaBlock(metaBlock.GetNonce()) + if err != nil { + return err + } + + chanRequests := time.After(durationBetweenReRequests) + + for { + select { + case <-e.chanConfMetaBlockReached: + return nil + case <-ctx.Done(): + return epochStart.ErrTimeoutWaitingForMetaBlock + case <-chanRequests: + err = e.requestConfirmationMetaBlock(metaBlock.GetNonce()) + if err != nil { + return err + } + chanRequests = time.After(durationBetweenReRequests) + } + } +} + func (e *epochStartMetaBlockProcessor) getMostReceivedMetaBlock() (data.MetaHeaderHandler, error) { e.mutReceivedMetaBlocks.RLock() defer e.mutReceivedMetaBlocks.RUnlock() @@ -238,27 +319,48 @@ func (e *epochStartMetaBlockProcessor) requestMetaBlock() error { return nil } -func (e *epochStartMetaBlockProcessor) checkMaps() { +func (e *epochStartMetaBlockProcessor) requestConfirmationMetaBlock(nonce uint64) error { + numConnectedPeers := len(e.messenger.ConnectedPeers()) + err := e.requestHandler.SetNumPeersToQuery(factory.MetachainBlocksTopic, numConnectedPeers, numConnectedPeers) + if err != nil { + return err + } + + e.requestHandler.RequestMetaHeaderByNonce(nonce) + + return nil +} + +func (e *epochStartMetaBlockProcessor) checkMetaBlockMaps() { e.mutReceivedMetaBlocks.RLock() defer e.mutReceivedMetaBlocks.RUnlock() - for hash, peersList := range e.mapMetaBlocksFromPeers { + hash, metaBlockFound := e.checkReceivedMetaBlock(e.mapMetaBlocksFromPeers) + if metaBlockFound { + e.metaBlock = e.mapReceivedMetaBlocks[hash] + e.chanMetaBlockReached <- true + } +} + +func (e *epochStartMetaBlockProcessor) checkReceivedMetaBlock(blocksFromPeers map[string][]core.PeerID) (string, bool) { + for hash, peersList := range blocksFromPeers { log.Debug("metablock from peers", "num peers", len(peersList), "target", e.peerCountTarget, "hash", []byte(hash)) - found := e.processEntry(peersList, hash) - if found { - break + + metaBlockFound := e.processMetaBlockEntry(peersList, hash) + if metaBlockFound { + return hash, true } } + + return "", false } -func (e *epochStartMetaBlockProcessor) processEntry( +func (e *epochStartMetaBlockProcessor) processMetaBlockEntry( peersList []core.PeerID, hash string, ) bool { if len(peersList) >= e.peerCountTarget { log.Info("got consensus for epoch start metablock", "len", len(peersList)) - e.metaBlock = e.mapReceivedMetaBlocks[hash] - e.chanConsensusReached <- true return true } diff --git a/epochStart/bootstrap/epochStartMetaBlockProcessor_test.go b/epochStart/bootstrap/epochStartMetaBlockProcessor_test.go index 1741c63a25c..200c3f408a3 100644 --- a/epochStart/bootstrap/epochStartMetaBlockProcessor_test.go +++ b/epochStart/bootstrap/epochStartMetaBlockProcessor_test.go @@ -9,9 +9,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" 
"github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" @@ -28,6 +30,7 @@ func TestNewEpochStartMetaBlockProcessor_NilMessengerShouldErr(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Equal(t, epochStart.ErrNilMessenger, err) @@ -45,6 +48,7 @@ func TestNewEpochStartMetaBlockProcessor_NilRequestHandlerShouldErr(t *testing.T 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Equal(t, epochStart.ErrNilRequestHandler, err) @@ -62,6 +66,7 @@ func TestNewEpochStartMetaBlockProcessor_NilMarshalizerShouldErr(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Equal(t, epochStart.ErrNilMarshalizer, err) @@ -79,6 +84,7 @@ func TestNewEpochStartMetaBlockProcessor_NilHasherShouldErr(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Equal(t, epochStart.ErrNilHasher, err) @@ -96,6 +102,7 @@ func TestNewEpochStartMetaBlockProcessor_InvalidConsensusPercentageShouldErr(t * 101, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Equal(t, epochStart.ErrInvalidConsensusThreshold, err) @@ -116,6 +123,7 @@ func TestNewEpochStartMetaBlockProcessorOkValsShouldWork(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.NoError(t, err) @@ -152,6 +160,7 @@ func TestNewEpochStartMetaBlockProcessorOkValsShouldWorkAfterMoreTriesWaitingFor 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.NoError(t, err) @@ -172,6 +181,7 @@ func TestEpochStartMetaBlockProcessor_Validate(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Nil(t, esmbp.Validate(nil, "")) @@ -191,6 +201,7 @@ func TestEpochStartMetaBlockProcessor_SaveNilInterceptedDataShouldNotReturnError 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) err := esmbp.Save(nil, "peer0", "") @@ -212,6 +223,7 @@ func TestEpochStartMetaBlockProcessor_SaveOkInterceptedDataShouldWork(t *testing 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) assert.Zero(t, len(esmbp.GetMapMetaBlock())) @@ -241,6 +253,7 @@ func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldTimeOut(t *tes 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) @@ -264,21 +277,30 @@ func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldReturnMostRece &hashingMocks.HasherMock{}, 99, 3, - 3, + 5, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) expectedMetaBlock := &block.MetaBlock{ Nonce: 10, EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{Round: 1}}}, } + confirmationMetaBlock := &block.MetaBlock{ + Nonce: 11, + } intData := mock.NewInterceptedMetaBlockMock(expectedMetaBlock, []byte("hash")) + intData2 := mock.NewInterceptedMetaBlockMock(confirmationMetaBlock, []byte("hash2")) for i := 0; i < esmbp.minNumOfPeersToConsiderBlockValid; i++ { _ = esmbp.Save(intData, core.PeerID(fmt.Sprintf("peer_%d", i)), "") } + for i := 0; i < esmbp.minNumOfPeersToConsiderBlockValid; i++ { + _ = esmbp.Save(intData2, core.PeerID(fmt.Sprintf("peer_%d", i)), "") + } + // we need a slightly more time than 1 second in order to also properly test the select branches - timeout := time.Second + time.Millisecond*500 + timeout := 2*time.Second + time.Millisecond*500 ctx, cancel := context.WithTimeout(context.Background(), timeout) mb, err := esmbp.GetEpochStartMetaBlock(ctx) 
cancel() @@ -301,18 +323,27 @@ func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldWorkFromFirstT 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ) expectedMetaBlock := &block.MetaBlock{ Nonce: 10, EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{Round: 1}}}, } + confirmationMetaBlock := &block.MetaBlock{ + Nonce: 11, + } intData := mock.NewInterceptedMetaBlockMock(expectedMetaBlock, []byte("hash")) + intData2 := mock.NewInterceptedMetaBlockMock(confirmationMetaBlock, []byte("hash2")) for i := 0; i < 6; i++ { _ = esmbp.Save(intData, core.PeerID(fmt.Sprintf("peer_%d", i)), "") } + for i := 0; i < 6; i++ { + _ = esmbp.Save(intData2, core.PeerID(fmt.Sprintf("peer_%d", i)), "") + } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) mb, err := esmbp.GetEpochStartMetaBlock(ctx) cancel() @@ -320,19 +351,53 @@ func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldWorkFromFirstT assert.Equal(t, expectedMetaBlock, mb) } -func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldWorkAfterMultipleTries(t *testing.T) { +func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlock_BeforeEquivalentMessages(t *testing.T) { t.Parallel() - testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t, durationBetweenChecks-10*time.Millisecond) + tts := durationBetweenChecks - 10*time.Millisecond + + esmbp, _ := NewEpochStartMetaBlockProcessor( + &p2pmocks.MessengerStub{ + ConnectedPeersCalled: func() []core.PeerID { + return []core.PeerID{"peer_0", "peer_1", "peer_2", "peer_3", "peer_4", "peer_5"} + }, + }, + &testscommon.RequestHandlerStub{}, + &mock.MarshalizerMock{}, + &hashingMocks.HasherMock{}, + 64, + 3, + 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ) + expectedMetaBlock := &block.MetaBlock{ + Nonce: 10, + EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{Round: 1}}}, + } + intData := mock.NewInterceptedMetaBlockMock(expectedMetaBlock, []byte("hash")) + + go func() { + index := 0 + for { + time.Sleep(tts) + _ = esmbp.Save(intData, core.PeerID(fmt.Sprintf("peer_%d", index)), "") + _ = esmbp.Save(intData, core.PeerID(fmt.Sprintf("peer_%d", index+1)), "") + index += 2 + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + mb, err := esmbp.GetEpochStartMetaBlock(ctx) + cancel() + assert.NoError(t, err) + assert.Equal(t, expectedMetaBlock, mb) } -func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldWorkAfterMultipleRequests(t *testing.T) { +func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlock_AfterEquivalentMessages(t *testing.T) { t.Parallel() - testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t, durationBetweenChecks-10*time.Millisecond) -} + tts := durationBetweenChecks - 10*time.Millisecond -func testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t *testing.T, tts time.Duration) { esmbp, _ := NewEpochStartMetaBlockProcessor( &p2pmocks.MessengerStub{ ConnectedPeersCalled: func() []core.PeerID { @@ -345,12 +410,23 @@ func testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t *testing.T, tt 64, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + }, ) expectedMetaBlock := &block.MetaBlock{ Nonce: 10, EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{Round: 1}}}, } intData := 
mock.NewInterceptedMetaBlockMock(expectedMetaBlock, []byte("hash")) + + confirmationMetaBlock := &block.MetaBlock{ + Nonce: 11, + } + intData2 := mock.NewInterceptedMetaBlockMock(confirmationMetaBlock, []byte("hash2")) + go func() { index := 0 for { @@ -360,6 +436,17 @@ func testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t *testing.T, tt index += 2 } }() + + go func() { + index := 0 + for { + time.Sleep(tts) + _ = esmbp.Save(intData2, core.PeerID(fmt.Sprintf("peer_%d", index)), "") + _ = esmbp.Save(intData2, core.PeerID(fmt.Sprintf("peer_%d", index+1)), "") + index += 2 + } + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) mb, err := esmbp.GetEpochStartMetaBlock(ctx) cancel() diff --git a/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go b/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go index d659989896b..8700b1daa24 100644 --- a/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go +++ b/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/typeConverters" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -25,23 +26,24 @@ const timeSpanForBadHeaders = time.Minute // ArgsEpochStartInterceptorContainer holds the arguments needed for creating a new epoch start interceptors // container factory type ArgsEpochStartInterceptorContainer struct { - CoreComponents process.CoreComponentsHolder - CryptoComponents process.CryptoComponentsHolder - Config config.Config - ShardCoordinator sharding.Coordinator - MainMessenger process.TopicHandler - FullArchiveMessenger process.TopicHandler - DataPool dataRetriever.PoolsHolder - WhiteListHandler update.WhiteListHandler - WhiteListerVerifiedTxs update.WhiteListHandler - AddressPubkeyConv core.PubkeyConverter - NonceConverter typeConverters.Uint64ByteSliceConverter - ChainID []byte - ArgumentsParser process.ArgumentsParser - HeaderIntegrityVerifier process.HeaderIntegrityVerifier - RequestHandler process.RequestHandler - SignaturesHandler process.SignaturesHandler - NodeOperationMode common.NodeOperation + CoreComponents process.CoreComponentsHolder + CryptoComponents process.CryptoComponentsHolder + Config config.Config + ShardCoordinator sharding.Coordinator + MainMessenger process.TopicHandler + FullArchiveMessenger process.TopicHandler + DataPool dataRetriever.PoolsHolder + WhiteListHandler update.WhiteListHandler + WhiteListerVerifiedTxs update.WhiteListHandler + AddressPubkeyConv core.PubkeyConverter + NonceConverter typeConverters.Uint64ByteSliceConverter + ChainID []byte + ArgumentsParser process.ArgumentsParser + HeaderIntegrityVerifier process.HeaderIntegrityVerifier + RequestHandler process.RequestHandler + SignaturesHandler process.SignaturesHandler + NodeOperationMode common.NodeOperation + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewEpochStartInterceptorsContainer will return a real interceptors container factory, but with many disabled components @@ -78,36 +80,37 @@ func NewEpochStartInterceptorsContainer(args ArgsEpochStartInterceptorContainer) hardforkTrigger := disabledFactory.HardforkTrigger() containerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: args.CoreComponents, - CryptoComponents: 
cryptoComponents, - Accounts: accountsAdapter, - ShardCoordinator: args.ShardCoordinator, - NodesCoordinator: nodesCoordinator, - MainMessenger: args.MainMessenger, - FullArchiveMessenger: args.FullArchiveMessenger, - Store: storer, - DataPool: args.DataPool, - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: feeHandler, - BlockBlackList: blackListHandler, - HeaderSigVerifier: headerSigVerifier, - HeaderIntegrityVerifier: args.HeaderIntegrityVerifier, - ValidityAttester: validityAttester, - EpochStartTrigger: epochStartTrigger, - WhiteListHandler: args.WhiteListHandler, - WhiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, - AntifloodHandler: antiFloodHandler, - ArgumentsParser: args.ArgumentsParser, - PreferredPeersHolder: disabled.NewPreferredPeersHolder(), - SizeCheckDelta: uint32(sizeCheckDelta), - RequestHandler: args.RequestHandler, - PeerSignatureHandler: cryptoComponents.PeerSignatureHandler(), - SignaturesHandler: args.SignaturesHandler, - HeartbeatExpiryTimespanInSec: args.Config.HeartbeatV2.HeartbeatExpiryTimespanInSec, - MainPeerShardMapper: peerShardMapper, - FullArchivePeerShardMapper: fullArchivePeerShardMapper, - HardforkTrigger: hardforkTrigger, - NodeOperationMode: args.NodeOperationMode, + CoreComponents: args.CoreComponents, + CryptoComponents: cryptoComponents, + Accounts: accountsAdapter, + ShardCoordinator: args.ShardCoordinator, + NodesCoordinator: nodesCoordinator, + MainMessenger: args.MainMessenger, + FullArchiveMessenger: args.FullArchiveMessenger, + Store: storer, + DataPool: args.DataPool, + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: feeHandler, + BlockBlackList: blackListHandler, + HeaderSigVerifier: headerSigVerifier, + HeaderIntegrityVerifier: args.HeaderIntegrityVerifier, + ValidityAttester: validityAttester, + EpochStartTrigger: epochStartTrigger, + WhiteListHandler: args.WhiteListHandler, + WhiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, + AntifloodHandler: antiFloodHandler, + ArgumentsParser: args.ArgumentsParser, + PreferredPeersHolder: disabled.NewPreferredPeersHolder(), + SizeCheckDelta: uint32(sizeCheckDelta), + RequestHandler: args.RequestHandler, + PeerSignatureHandler: cryptoComponents.PeerSignatureHandler(), + SignaturesHandler: args.SignaturesHandler, + HeartbeatExpiryTimespanInSec: args.Config.HeartbeatV2.HeartbeatExpiryTimespanInSec, + MainPeerShardMapper: peerShardMapper, + FullArchivePeerShardMapper: fullArchivePeerShardMapper, + HardforkTrigger: hardforkTrigger, + NodeOperationMode: args.NodeOperationMode, + InterceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } interceptorsContainerFactory, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(containerFactoryArgs) diff --git a/epochStart/bootstrap/interface.go b/epochStart/bootstrap/interface.go index bfc293032ee..7a128098059 100644 --- a/epochStart/bootstrap/interface.go +++ b/epochStart/bootstrap/interface.go @@ -49,6 +49,7 @@ type Messenger interface { // RequestHandler defines which methods a request handler should implement type RequestHandler interface { RequestStartOfEpochMetaBlock(epoch uint32) + RequestMetaHeaderByNonce(nonce uint64) SetNumPeersToQuery(topic string, intra int, cross int) error GetNumPeersToQuery(topic string) (int, int, error) IsInterfaceNil() bool diff --git a/epochStart/bootstrap/process.go b/epochStart/bootstrap/process.go index 27fc5011cb5..91d40db1a8d 100644 --- a/epochStart/bootstrap/process.go +++ b/epochStart/bootstrap/process.go @@ -14,6 +14,8 @@ import ( 
"github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" disabledCommon "github.com/multiversx/mx-chain-go/common/disabled" "github.com/multiversx/mx-chain-go/common/ordering" @@ -52,7 +54,6 @@ import ( "github.com/multiversx/mx-chain-go/trie/storageMarker" "github.com/multiversx/mx-chain-go/update" updateSync "github.com/multiversx/mx-chain-go/update/sync" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("epochStart/bootstrap") @@ -121,6 +122,8 @@ type epochStartBootstrap struct { nodeProcessingMode common.NodeProcessingMode nodeOperationMode common.NodeOperation stateStatsHandler common.StateStatisticsHandler + enableEpochsHandler common.EnableEpochsHandler + // created components requestHandler process.RequestHandler mainInterceptorContainer process.InterceptorsContainer @@ -152,6 +155,8 @@ type epochStartBootstrap struct { nodeType core.NodeType startEpoch uint32 shuffledOut bool + + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } type baseDataInStorage struct { @@ -190,6 +195,8 @@ type ArgsEpochStartBootstrap struct { NodeProcessingMode common.NodeProcessingMode StateStatsHandler common.StateStatisticsHandler NodesCoordinatorRegistryFactory nodesCoordinator.NodesCoordinatorRegistryFactory + EnableEpochsHandler common.EnableEpochsHandler + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } type dataToSync struct { @@ -242,6 +249,8 @@ func NewEpochStartBootstrap(args ArgsEpochStartBootstrap) (*epochStartBootstrap, stateStatsHandler: args.StateStatsHandler, startEpoch: args.GeneralConfig.EpochStartConfig.GenesisEpoch, nodesCoordinatorRegistryFactory: args.NodesCoordinatorRegistryFactory, + enableEpochsHandler: args.EnableEpochsHandler, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } if epochStartProvider.prefsConfig.FullArchive { @@ -547,22 +556,24 @@ func (e *epochStartBootstrap) prepareComponentsToSyncFromNetwork() error { thresholdForConsideringMetaBlockCorrect, epochStartConfig.MinNumConnectedPeersToStart, epochStartConfig.MinNumOfPeersToConsiderBlockValid, + e.enableEpochsHandler, ) if err != nil { return err } argsEpochStartSyncer := ArgsNewEpochStartMetaSyncer{ - CoreComponentsHolder: e.coreComponentsHolder, - CryptoComponentsHolder: e.cryptoComponentsHolder, - RequestHandler: e.requestHandler, - Messenger: e.mainMessenger, - ShardCoordinator: e.shardCoordinator, - EconomicsData: e.economicsData, - WhitelistHandler: e.whiteListHandler, - StartInEpochConfig: epochStartConfig, - HeaderIntegrityVerifier: e.headerIntegrityVerifier, - MetaBlockProcessor: metaBlockProcessor, + CoreComponentsHolder: e.coreComponentsHolder, + CryptoComponentsHolder: e.cryptoComponentsHolder, + RequestHandler: e.requestHandler, + Messenger: e.mainMessenger, + ShardCoordinator: e.shardCoordinator, + EconomicsData: e.economicsData, + WhitelistHandler: e.whiteListHandler, + StartInEpochConfig: epochStartConfig, + HeaderIntegrityVerifier: e.headerIntegrityVerifier, + MetaBlockProcessor: metaBlockProcessor, + InterceptedDataVerifierFactory: e.interceptedDataVerifierFactory, } e.epochStartMetaBlockSyncer, err = NewEpochStartMetaSyncer(argsEpochStartSyncer) if err != nil { @@ -575,20 +586,21 @@ func (e *epochStartBootstrap) prepareComponentsToSyncFromNetwork() error { func (e *epochStartBootstrap) 
createSyncers() error { var err error args := factoryInterceptors.ArgsEpochStartInterceptorContainer{ - CoreComponents: e.coreComponentsHolder, - CryptoComponents: e.cryptoComponentsHolder, - Config: e.generalConfig, - ShardCoordinator: e.shardCoordinator, - MainMessenger: e.mainMessenger, - FullArchiveMessenger: e.fullArchiveMessenger, - DataPool: e.dataPool, - WhiteListHandler: e.whiteListHandler, - WhiteListerVerifiedTxs: e.whiteListerVerifiedTxs, - ArgumentsParser: e.argumentsParser, - HeaderIntegrityVerifier: e.headerIntegrityVerifier, - RequestHandler: e.requestHandler, - SignaturesHandler: e.mainMessenger, - NodeOperationMode: e.nodeOperationMode, + CoreComponents: e.coreComponentsHolder, + CryptoComponents: e.cryptoComponentsHolder, + Config: e.generalConfig, + ShardCoordinator: e.shardCoordinator, + MainMessenger: e.mainMessenger, + FullArchiveMessenger: e.fullArchiveMessenger, + DataPool: e.dataPool, + WhiteListHandler: e.whiteListHandler, + WhiteListerVerifiedTxs: e.whiteListerVerifiedTxs, + ArgumentsParser: e.argumentsParser, + HeaderIntegrityVerifier: e.headerIntegrityVerifier, + RequestHandler: e.requestHandler, + SignaturesHandler: e.mainMessenger, + NodeOperationMode: e.nodeOperationMode, + InterceptedDataVerifierFactory: e.interceptedDataVerifierFactory, } e.mainInterceptorContainer, e.fullArchiveInterceptorContainer, err = factoryInterceptors.NewEpochStartInterceptorsContainer(args) @@ -666,7 +678,7 @@ func (e *epochStartBootstrap) syncHeadersFrom(meta data.MetaHeaderHandler) (map[ return syncedHeaders, nil } -// Bootstrap will handle requesting and receiving the needed information the node will bootstrap from +// requestAndProcessing will handle requesting and receiving the needed information the node will bootstrap from func (e *epochStartBootstrap) requestAndProcessing() (Parameters, error) { var err error e.baseData.numberOfShards = uint32(len(e.epochStartMeta.GetEpochStartHandler().GetLastFinalizedHeaderHandlers())) @@ -759,20 +771,20 @@ func (e *epochStartBootstrap) processNodesConfig(pubKey []byte) ([]*block.MiniBl shardId = e.genesisShardCoordinator.SelfId() } argsNewValidatorStatusSyncers := ArgsNewSyncValidatorStatus{ - DataPool: e.dataPool, - Marshalizer: e.coreComponentsHolder.InternalMarshalizer(), - RequestHandler: e.requestHandler, - ChanceComputer: e.rater, - GenesisNodesConfig: e.genesisNodesConfig, - ChainParametersHandler: e.coreComponentsHolder.ChainParametersHandler(), - NodeShuffler: e.nodeShuffler, - Hasher: e.coreComponentsHolder.Hasher(), - PubKey: pubKey, - ShardIdAsObserver: shardId, - ChanNodeStop: e.coreComponentsHolder.ChanStopNodeProcess(), - NodeTypeProvider: e.coreComponentsHolder.NodeTypeProvider(), - IsFullArchive: e.prefsConfig.FullArchive, - EnableEpochsHandler: e.coreComponentsHolder.EnableEpochsHandler(), + DataPool: e.dataPool, + Marshalizer: e.coreComponentsHolder.InternalMarshalizer(), + RequestHandler: e.requestHandler, + ChanceComputer: e.rater, + GenesisNodesConfig: e.genesisNodesConfig, + ChainParametersHandler: e.coreComponentsHolder.ChainParametersHandler(), + NodeShuffler: e.nodeShuffler, + Hasher: e.coreComponentsHolder.Hasher(), + PubKey: pubKey, + ShardIdAsObserver: shardId, + ChanNodeStop: e.coreComponentsHolder.ChanStopNodeProcess(), + NodeTypeProvider: e.coreComponentsHolder.NodeTypeProvider(), + IsFullArchive: e.prefsConfig.FullArchive, + EnableEpochsHandler: e.coreComponentsHolder.EnableEpochsHandler(), NodesCoordinatorRegistryFactory: e.nodesCoordinatorRegistryFactory, } diff --git 
a/epochStart/bootstrap/process_test.go b/epochStart/bootstrap/process_test.go index edcf0a0a495..e38737a7a3e 100644 --- a/epochStart/bootstrap/process_test.go +++ b/epochStart/bootstrap/process_test.go @@ -19,6 +19,9 @@ import ( dataBatch "github.com/multiversx/mx-chain-core-go/data/batch" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/statistics" disabledStatistics "github.com/multiversx/mx-chain-go/common/statistics/disabled" @@ -29,12 +32,14 @@ import ( "github.com/multiversx/mx-chain-go/epochStart/bootstrap/types" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/process" + processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" epochStartMocks "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks/epochStart" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" @@ -54,8 +59,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/syncer" validatorInfoCacherStub "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" "github.com/multiversx/mx-chain-go/trie/factory" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createPkBytes(numShards uint32) map[uint32][]byte { @@ -249,8 +252,10 @@ func createMockEpochStartBootstrapArgs( FlagsConfig: config.ContextFlagsConfig{ ForceStartFromNetwork: false, }, - TrieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, - StateStatsHandler: disabledStatistics.NewStateStatistics(), + TrieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, + StateStatsHandler: disabledStatistics.NewStateStatistics(), + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + InterceptedDataVerifierFactory: &processMock.InterceptedDataVerifierFactoryMock{}, } } @@ -976,22 +981,26 @@ func TestCreateSyncers(t *testing.T) { return testscommon.NewShardedDataStub() }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, PeerAuthenticationsCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, HeartbeatsCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() + }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} }, } epochStartProvider.whiteListHandler = &testscommon.WhiteListHandlerStub{} epochStartProvider.whiteListerVerifiedTxs = &testscommon.WhiteListHandlerStub{} epochStartProvider.requestHandler = &testscommon.RequestHandlerStub{} epochStartProvider.storageService = &storageMocks.ChainStorerStub{} + epochStartProvider.interceptedDataVerifierFactory = &processMock.InterceptedDataVerifierFactoryMock{} err := epochStartProvider.createSyncers() 
assert.Nil(t, err) @@ -1042,7 +1051,7 @@ func TestSyncValidatorAccountsState_NilRequestHandlerErr(t *testing.T) { epochStartProvider, _ := NewEpochStartBootstrap(args) epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1088,7 +1097,7 @@ func TestSyncUserAccountsState(t *testing.T) { epochStartProvider.shardCoordinator = mock.NewMultipleShardsCoordinatorMock() epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1341,7 +1350,7 @@ func TestRequestAndProcessForShard_ShouldFail(t *testing.T) { } epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1455,7 +1464,7 @@ func TestRequestAndProcessForMeta_ShouldFail(t *testing.T) { } epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1820,10 +1829,10 @@ func TestRequestAndProcessing(t *testing.T) { } epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1890,10 +1899,10 @@ func TestRequestAndProcessing(t *testing.T) { } epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -2055,10 +2064,10 @@ func TestEpochStartBootstrap_WithDisabledShardIDAsObserver(t *testing.T) { return testscommon.NewShardedDataStub() }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &validatorInfoCacherStub.ValidatorInfoCacherStub{} @@ -2391,16 +2400,19 @@ func TestSyncSetGuardianTransaction(t *testing.T) { return testscommon.NewShardedDataStub() }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, PeerAuthenticationsCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, HeartbeatsCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() + }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} }, } 
epochStartProvider.whiteListHandler = &testscommon.WhiteListHandlerStub{ diff --git a/epochStart/bootstrap/storageProcess.go b/epochStart/bootstrap/storageProcess.go index a7fff35f193..0ec16f6548d 100644 --- a/epochStart/bootstrap/storageProcess.go +++ b/epochStart/bootstrap/storageProcess.go @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -177,16 +178,17 @@ func (sesb *storageEpochStartBootstrap) prepareComponentsToSync() error { } argsEpochStartSyncer := ArgsNewEpochStartMetaSyncer{ - CoreComponentsHolder: sesb.coreComponentsHolder, - CryptoComponentsHolder: sesb.cryptoComponentsHolder, - RequestHandler: sesb.requestHandler, - Messenger: sesb.mainMessenger, - ShardCoordinator: sesb.shardCoordinator, - EconomicsData: sesb.economicsData, - WhitelistHandler: sesb.whiteListHandler, - StartInEpochConfig: sesb.generalConfig.EpochStartConfig, - HeaderIntegrityVerifier: sesb.headerIntegrityVerifier, - MetaBlockProcessor: metablockProcessor, + CoreComponentsHolder: sesb.coreComponentsHolder, + CryptoComponentsHolder: sesb.cryptoComponentsHolder, + RequestHandler: sesb.requestHandler, + Messenger: sesb.mainMessenger, + ShardCoordinator: sesb.shardCoordinator, + EconomicsData: sesb.economicsData, + WhitelistHandler: sesb.whiteListHandler, + StartInEpochConfig: sesb.generalConfig.EpochStartConfig, + HeaderIntegrityVerifier: sesb.headerIntegrityVerifier, + MetaBlockProcessor: metablockProcessor, + InterceptedDataVerifierFactory: sesb.interceptedDataVerifierFactory, } sesb.epochStartMetaBlockSyncer, err = NewEpochStartMetaSyncer(argsEpochStartSyncer) @@ -404,20 +406,20 @@ func (sesb *storageEpochStartBootstrap) processNodesConfig(pubKey []byte) error shardId = sesb.genesisShardCoordinator.SelfId() } argsNewValidatorStatusSyncers := ArgsNewSyncValidatorStatus{ - DataPool: sesb.dataPool, - Marshalizer: sesb.coreComponentsHolder.InternalMarshalizer(), - RequestHandler: sesb.requestHandler, - ChanceComputer: sesb.rater, - GenesisNodesConfig: sesb.genesisNodesConfig, - ChainParametersHandler: sesb.coreComponentsHolder.ChainParametersHandler(), - NodeShuffler: sesb.nodeShuffler, - Hasher: sesb.coreComponentsHolder.Hasher(), - PubKey: pubKey, - ShardIdAsObserver: shardId, - ChanNodeStop: sesb.coreComponentsHolder.ChanStopNodeProcess(), - NodeTypeProvider: sesb.coreComponentsHolder.NodeTypeProvider(), - IsFullArchive: sesb.prefsConfig.FullArchive, - EnableEpochsHandler: sesb.coreComponentsHolder.EnableEpochsHandler(), + DataPool: sesb.dataPool, + Marshalizer: sesb.coreComponentsHolder.InternalMarshalizer(), + RequestHandler: sesb.requestHandler, + ChanceComputer: sesb.rater, + GenesisNodesConfig: sesb.genesisNodesConfig, + ChainParametersHandler: sesb.coreComponentsHolder.ChainParametersHandler(), + NodeShuffler: sesb.nodeShuffler, + Hasher: sesb.coreComponentsHolder.Hasher(), + PubKey: pubKey, + ShardIdAsObserver: shardId, + ChanNodeStop: sesb.coreComponentsHolder.ChanStopNodeProcess(), + NodeTypeProvider: sesb.coreComponentsHolder.NodeTypeProvider(), + IsFullArchive: sesb.prefsConfig.FullArchive, + EnableEpochsHandler: sesb.coreComponentsHolder.EnableEpochsHandler(), NodesCoordinatorRegistryFactory: sesb.nodesCoordinatorRegistryFactory, } sesb.nodesConfigHandler, err = NewSyncValidatorStatus(argsNewValidatorStatusSyncers) diff --git 
a/epochStart/bootstrap/storageProcess_test.go b/epochStart/bootstrap/storageProcess_test.go index a59b0d125f2..64708040acd 100644 --- a/epochStart/bootstrap/storageProcess_test.go +++ b/epochStart/bootstrap/storageProcess_test.go @@ -11,11 +11,14 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/process" + processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" @@ -23,7 +26,6 @@ import ( dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" - "github.com/stretchr/testify/assert" ) func createMockStorageEpochStartBootstrapArgs( @@ -127,6 +129,7 @@ func TestStorageEpochStartBootstrap_BootstrapMetablockNotFound(t *testing.T) { } args.GeneralConfig = testscommon.GetGeneralConfig() args.GeneralConfig.EpochStartConfig.RoundsPerEpoch = roundsPerEpoch + args.InterceptedDataVerifierFactory = &processMock.InterceptedDataVerifierFactoryMock{} sesb, _ := NewStorageEpochStartBootstrap(args) params, err := sesb.Bootstrap() diff --git a/epochStart/bootstrap/syncEpochStartMeta.go b/epochStart/bootstrap/syncEpochStartMeta.go index fa764a04c4a..b550a25911a 100644 --- a/epochStart/bootstrap/syncEpochStartMeta.go +++ b/epochStart/bootstrap/syncEpochStartMeta.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart" @@ -22,27 +23,29 @@ import ( var _ epochStart.StartOfEpochMetaSyncer = (*epochStartMetaSyncer)(nil) type epochStartMetaSyncer struct { - requestHandler RequestHandler - messenger Messenger - marshalizer marshal.Marshalizer - hasher hashing.Hasher - singleDataInterceptor process.Interceptor - metaBlockProcessor EpochStartMetaBlockInterceptorProcessor + requestHandler RequestHandler + messenger Messenger + marshalizer marshal.Marshalizer + hasher hashing.Hasher + singleDataInterceptor process.Interceptor + metaBlockProcessor EpochStartMetaBlockInterceptorProcessor + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // ArgsNewEpochStartMetaSyncer - type ArgsNewEpochStartMetaSyncer struct { - CoreComponentsHolder process.CoreComponentsHolder - CryptoComponentsHolder process.CryptoComponentsHolder - RequestHandler RequestHandler - Messenger Messenger - ShardCoordinator sharding.Coordinator - EconomicsData process.EconomicsDataHandler - WhitelistHandler process.WhiteListHandler - StartInEpochConfig config.EpochStartConfig - ArgsParser process.ArgumentsParser - HeaderIntegrityVerifier process.HeaderIntegrityVerifier - MetaBlockProcessor EpochStartMetaBlockInterceptorProcessor + CoreComponentsHolder process.CoreComponentsHolder + CryptoComponentsHolder process.CryptoComponentsHolder + RequestHandler RequestHandler + Messenger Messenger + ShardCoordinator 
sharding.Coordinator + EconomicsData process.EconomicsDataHandler + WhitelistHandler process.WhiteListHandler + StartInEpochConfig config.EpochStartConfig + ArgsParser process.ArgumentsParser + HeaderIntegrityVerifier process.HeaderIntegrityVerifier + MetaBlockProcessor EpochStartMetaBlockInterceptorProcessor + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewEpochStartMetaSyncer will return a new instance of epochStartMetaSyncer @@ -62,13 +65,17 @@ func NewEpochStartMetaSyncer(args ArgsNewEpochStartMetaSyncer) (*epochStartMetaS if check.IfNil(args.MetaBlockProcessor) { return nil, epochStart.ErrNilMetablockProcessor } + if check.IfNil(args.InterceptedDataVerifierFactory) { + return nil, epochStart.ErrNilInterceptedDataVerifierFactory + } e := &epochStartMetaSyncer{ - requestHandler: args.RequestHandler, - messenger: args.Messenger, - marshalizer: args.CoreComponentsHolder.InternalMarshalizer(), - hasher: args.CoreComponentsHolder.Hasher(), - metaBlockProcessor: args.MetaBlockProcessor, + requestHandler: args.RequestHandler, + messenger: args.Messenger, + marshalizer: args.CoreComponentsHolder.InternalMarshalizer(), + hasher: args.CoreComponentsHolder.Hasher(), + metaBlockProcessor: args.MetaBlockProcessor, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } argsInterceptedDataFactory := interceptorsFactory.ArgInterceptedDataFactory{ @@ -89,16 +96,22 @@ func NewEpochStartMetaSyncer(args ArgsNewEpochStartMetaSyncer) (*epochStartMetaS return nil, err } + interceptedDataVerifier, err := e.interceptedDataVerifierFactory.Create(factory.MetachainBlocksTopic) + if err != nil { + return nil, err + } + e.singleDataInterceptor, err = interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: factory.MetachainBlocksTopic, - DataFactory: interceptedMetaHdrDataFactory, - Processor: args.MetaBlockProcessor, - Throttler: disabled.NewThrottler(), - AntifloodHandler: disabled.NewAntiFloodHandler(), - WhiteListRequest: args.WhitelistHandler, - CurrentPeerId: args.Messenger.ID(), - PreferredPeersHolder: disabled.NewPreferredPeersHolder(), + Topic: factory.MetachainBlocksTopic, + DataFactory: interceptedMetaHdrDataFactory, + Processor: args.MetaBlockProcessor, + Throttler: disabled.NewThrottler(), + AntifloodHandler: disabled.NewAntiFloodHandler(), + WhiteListRequest: args.WhitelistHandler, + CurrentPeerId: args.Messenger.ID(), + PreferredPeersHolder: disabled.NewPreferredPeersHolder(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { diff --git a/epochStart/bootstrap/syncEpochStartMeta_test.go b/epochStart/bootstrap/syncEpochStartMeta_test.go index 169b20a656e..ac05d2ba977 100644 --- a/epochStart/bootstrap/syncEpochStartMeta_test.go +++ b/epochStart/bootstrap/syncEpochStartMeta_test.go @@ -9,17 +9,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/p2p" + processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" 
"github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestNewEpochStartMetaSyncer_NilsShouldError(t *testing.T) { @@ -48,6 +50,12 @@ func TestNewEpochStartMetaSyncer_NilsShouldError(t *testing.T) { ess, err = NewEpochStartMetaSyncer(args) assert.True(t, check.IfNil(ess)) assert.Equal(t, epochStart.ErrNilMetablockProcessor, err) + + args = getEpochStartSyncerArgs() + args.InterceptedDataVerifierFactory = nil + ess, err = NewEpochStartMetaSyncer(args) + assert.True(t, check.IfNil(ess)) + assert.Equal(t, epochStart.ErrNilInterceptedDataVerifierFactory, err) } func TestNewEpochStartMetaSyncer_ShouldWork(t *testing.T) { @@ -71,7 +79,8 @@ func TestEpochStartMetaSyncer_SyncEpochStartMetaRegisterMessengerProcessorFailsS }, } args.Messenger = messenger - ess, _ := NewEpochStartMetaSyncer(args) + ess, err := NewEpochStartMetaSyncer(args) + require.NoError(t, err) mb, err := ess.SyncEpochStartMeta(time.Second) require.Equal(t, expectedErr, err) @@ -159,7 +168,8 @@ func getEpochStartSyncerArgs() ArgsNewEpochStartMetaSyncer { MinNumConnectedPeersToStart: 2, MinNumOfPeersToConsiderBlockValid: 2, }, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, - MetaBlockProcessor: &mock.EpochStartMetaBlockProcessorStub{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + MetaBlockProcessor: &mock.EpochStartMetaBlockProcessorStub{}, + InterceptedDataVerifierFactory: &processMock.InterceptedDataVerifierFactoryMock{}, } } diff --git a/epochStart/bootstrap/syncValidatorStatus_test.go b/epochStart/bootstrap/syncValidatorStatus_test.go index 2579596ed51..ee8b7c02dae 100644 --- a/epochStart/bootstrap/syncValidatorStatus_test.go +++ b/epochStart/bootstrap/syncValidatorStatus_test.go @@ -9,12 +9,16 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" epochStartMocks "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks/epochStart" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/chainParameters" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -23,8 +27,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const initRating = uint32(50) @@ -256,7 +258,7 @@ func getSyncValidatorStatusArgs() ArgsNewSyncValidatorStatus { return ArgsNewSyncValidatorStatus{ DataPool: &dataRetrieverMock.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} diff --git a/epochStart/errors.go b/epochStart/errors.go index ca115e939f4..e022064c472 100644 --- a/epochStart/errors.go +++ 
b/epochStart/errors.go @@ -239,6 +239,9 @@ var ErrNilEpochNotifier = errors.New("nil EpochNotifier") // ErrNilMetablockProcessor signals that a nil metablock processor was provided var ErrNilMetablockProcessor = errors.New("nil metablock processor") +// ErrNilInterceptedDataVerifierFactory signals that a nil intercepted data verifier factory was provided +var ErrNilInterceptedDataVerifierFactory = errors.New("nil intercepted data verifier factory") + // ErrCouldNotInitDelegationSystemSC signals that delegation system sc init failed var ErrCouldNotInitDelegationSystemSC = errors.New("could not init delegation system sc") diff --git a/epochStart/metachain/validators_test.go b/epochStart/metachain/validators_test.go index 662b0192044..2ece21d91d7 100644 --- a/epochStart/metachain/validators_test.go +++ b/epochStart/metachain/validators_test.go @@ -15,6 +15,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" @@ -22,12 +25,11 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" vics "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockValidatorInfo() state.ShardValidatorsInfoMapHandler { @@ -128,7 +130,7 @@ func createMockEpochValidatorInfoCreatorsArguments() ArgsNewValidatorInfoCreator Marshalizer: &mock.MarshalizerMock{}, DataPool: &dataRetrieverMock.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ RemoveCalled: func(key []byte) {}, } }, diff --git a/epochStart/shardchain/peerMiniBlocksSyncer_test.go b/epochStart/shardchain/peerMiniBlocksSyncer_test.go index f58ef588a0d..3e131fa7074 100644 --- a/epochStart/shardchain/peerMiniBlocksSyncer_test.go +++ b/epochStart/shardchain/peerMiniBlocksSyncer_test.go @@ -9,18 +9,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func createDefaultArguments() ArgPeerMiniBlockSyncer { defaultArgs := ArgPeerMiniBlockSyncer{ - MiniBlocksPool: testscommon.NewCacherStub(), + MiniBlocksPool: cache.NewCacherStub(), ValidatorsInfoPool: testscommon.NewShardedDataStub(), RequestHandler: &testscommon.RequestHandlerStub{}, } @@ -63,7 +66,7 @@ func TestNewValidatorInfoProcessor_NilRequestHandlerShouldErr(t *testing.T) { func TestValidatorInfoProcessor_IsInterfaceNil(t *testing.T) { args := createDefaultArguments() - 
args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, } @@ -76,7 +79,7 @@ func TestValidatorInfoProcessor_IsInterfaceNil(t *testing.T) { func TestValidatorInfoProcessor_ShouldWork(t *testing.T) { args := createDefaultArguments() - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, } @@ -89,7 +92,7 @@ func TestValidatorInfoProcessor_ShouldWork(t *testing.T) { func TestValidatorInfoProcessor_ProcessMetaBlockThatIsNoStartOfEpochShouldWork(t *testing.T) { args := createDefaultArguments() - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, } @@ -104,7 +107,7 @@ func TestValidatorInfoProcessor_ProcessMetaBlockThatIsNoStartOfEpochShouldWork(t func TestValidatorInfoProcessor_ProcesStartOfEpochWithNoPeerMiniblocksShouldWork(t *testing.T) { args := createDefaultArguments() - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, } @@ -120,7 +123,7 @@ func TestValidatorInfoProcessor_ProcesStartOfEpochWithNoPeerMiniblocksShouldWork epochStartHeader.MiniBlockHeaders = []block.MiniBlockHeader{miniBlockHeader} peekCalled := false - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, @@ -182,7 +185,7 @@ func TestValidatorInfoProcessor_ProcesStartOfEpochWithPeerMiniblocksInPoolShould epochStartHeader.EpochStart.LastFinalizedHeaders = []block.EpochStartShardData{{ShardID: 0, RootHash: hash, HeaderHash: hash}} epochStartHeader.MiniBlockHeaders = []block.MiniBlockHeader{miniBlockHeader} - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, @@ -245,7 +248,7 @@ func TestValidatorInfoProcessor_ProcesStartOfEpochWithMissinPeerMiniblocksShould epochStartHeader.MiniBlockHeaders = []block.MiniBlockHeader{miniBlockHeader} var receivedMiniblock func(key []byte, value interface{}) - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { receivedMiniblock = f }, @@ -309,7 +312,7 @@ func TestValidatorInfoProcessor_ProcesStartOfEpochWithMissinPeerMiniblocksTimeou epochStartHeader.MiniBlockHeaders = []block.MiniBlockHeader{miniBlockHeader} var receivedMiniblock func(key []byte, value interface{}) - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { receivedMiniblock = f }, diff --git a/epochStart/shardchain/trigger_test.go b/epochStart/shardchain/trigger_test.go index 8a08dffc5c2..fcb7edc0ad2 100644 --- a/epochStart/shardchain/trigger_test.go +++ b/epochStart/shardchain/trigger_test.go @@ -12,20 +12,22 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" 
"github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockShardEpochStartTriggerArguments() *ArgsShardEpochStartTrigger { @@ -43,7 +45,7 @@ func createMockShardEpochStartTriggerArguments() *ArgsShardEpochStartTrigger { return &mock.HeadersCacherStub{} }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} @@ -207,7 +209,7 @@ func TestNewEpochStartTrigger_NilHeadersPoolShouldErr(t *testing.T) { return nil }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, } epochStartTrigger, err := NewEpochStartTrigger(args) @@ -376,7 +378,7 @@ func TestTrigger_ReceivedHeaderIsEpochStartTrueWithPeerMiniblocks(t *testing.T) } }, MiniBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, peerMiniBlockHash) { return peerMiniblock, true @@ -679,7 +681,7 @@ func TestTrigger_UpdateMissingValidatorsInfo(t *testing.T) { return &mock.HeadersCacherStub{} }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} diff --git a/factory/block/headerVersionHandler_test.go b/factory/block/headerVersionHandler_test.go index 9de5238810b..4a17cb291a2 100644 --- a/factory/block/headerVersionHandler_test.go +++ b/factory/block/headerVersionHandler_test.go @@ -10,10 +10,12 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -50,7 +52,7 @@ func TestNewHeaderIntegrityVerifierr_InvalidVersionElementOnEpochValuesEqualShou }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrInvalidVersionOnEpochValues)) @@ -67,7 +69,7 @@ func TestNewHeaderIntegrityVerifier_InvalidVersionElementOnStringTooLongShouldEr }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrInvalidVersionStringTooLong)) @@ -79,7 +81,7 @@ func TestNewHeaderIntegrityVerifierr_InvalidDefaultVersionShouldErr(t *testing.T hdrIntVer, err := NewHeaderVersionHandler( 
versionsCorrectlyConstructed, "", - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrInvalidSoftwareVersion)) @@ -103,7 +105,7 @@ func TestNewHeaderIntegrityVerifier_EmptyListShouldErr(t *testing.T) { hdrIntVer, err := NewHeaderVersionHandler( make([]config.VersionByEpochs, 0), defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrEmptyVersionsByEpochsList)) @@ -120,7 +122,7 @@ func TestNewHeaderIntegrityVerifier_ZerothElementIsNotOnEpochZeroShouldErr(t *te }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrInvalidVersionOnEpochValues)) @@ -132,7 +134,7 @@ func TestNewHeaderIntegrityVerifier_ShouldWork(t *testing.T) { hdrIntVer, err := NewHeaderVersionHandler( versionsCorrectlyConstructed, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.False(t, check.IfNil(hdrIntVer)) require.NoError(t, err) @@ -147,7 +149,7 @@ func TestHeaderIntegrityVerifier_PopulatedReservedShouldErr(t *testing.T) { hdrIntVer, _ := NewHeaderVersionHandler( make([]config.VersionByEpochs, 0), defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) err := hdrIntVer.Verify(hdr) require.Equal(t, process.ErrReservedFieldInvalid, err) @@ -159,7 +161,7 @@ func TestHeaderIntegrityVerifier_VerifySoftwareVersionEmptyVersionInHeaderShould hdrIntVer, _ := NewHeaderVersionHandler( make([]config.VersionByEpochs, 0), defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) err := hdrIntVer.Verify(&block.MetaBlock{}) require.True(t, errors.Is(err, ErrInvalidSoftwareVersion)) @@ -180,7 +182,7 @@ func TestHeaderIntegrityVerifierr_VerifySoftwareVersionWrongVersionShouldErr(t * }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) err := hdrIntVer.Verify( &block.MetaBlock{ @@ -207,7 +209,7 @@ func TestHeaderIntegrityVerifier_VerifySoftwareVersionWildcardShouldWork(t *test }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) err := hdrIntVer.Verify( &block.MetaBlock{ @@ -227,7 +229,7 @@ func TestHeaderIntegrityVerifier_VerifyShouldWork(t *testing.T) { hdrIntVer, _ := NewHeaderVersionHandler( versionsCorrectlyConstructed, "software", - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) mb := &block.MetaBlock{ SoftwareVersion: []byte("software"), @@ -243,7 +245,7 @@ func TestHeaderIntegrityVerifier_VerifyNotWildcardShouldWork(t *testing.T) { hdrIntVer, _ := NewHeaderVersionHandler( versionsCorrectlyConstructed, "software", - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) mb := &block.MetaBlock{ SoftwareVersion: []byte("v1"), @@ -260,7 +262,7 @@ func TestHeaderIntegrityVerifier_GetVersionShouldWork(t *testing.T) { hdrIntVer, _ := NewHeaderVersionHandler( versionsCorrectlyConstructed, defaultVersion, - &testscommon.CacherStub{ + &cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) bool { atomic.AddUint32(&numPutCalls, 1) epoch := binary.BigEndian.Uint32(key) @@ -311,7 +313,7 @@ func TestHeaderIntegrityVerifier_ExistsInInternalCacheShouldReturn(t *testing.T) hdrIntVer, _ := NewHeaderVersionHandler( versionsCorrectlyConstructed, defaultVersion, - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return cachedVersion, true }, diff --git a/factory/bootstrap/bootstrapComponents.go 
b/factory/bootstrap/bootstrapComponents.go index a9ef7851ccb..af6dc1aafa3 100644 --- a/factory/bootstrap/bootstrapComponents.go +++ b/factory/bootstrap/bootstrapComponents.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + logger "github.com/multiversx/mx-chain-logger-go" + nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" @@ -24,23 +26,23 @@ import ( storageFactory "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/latestData" "github.com/multiversx/mx-chain-go/storage/storageunit" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("factory") // BootstrapComponentsFactoryArgs holds the arguments needed to create a bootstrap components factory type BootstrapComponentsFactoryArgs struct { - Config config.Config - RoundConfig config.RoundConfig - PrefConfig config.Preferences - ImportDbConfig config.ImportDbConfig - FlagsConfig config.ContextFlagsConfig - WorkingDir string - CoreComponents factory.CoreComponentsHolder - CryptoComponents factory.CryptoComponentsHolder - NetworkComponents factory.NetworkComponentsHolder - StatusCoreComponents factory.StatusCoreComponentsHolder + Config config.Config + RoundConfig config.RoundConfig + PrefConfig config.Preferences + ImportDbConfig config.ImportDbConfig + FlagsConfig config.ContextFlagsConfig + WorkingDir string + CoreComponents factory.CoreComponentsHolder + CryptoComponents factory.CryptoComponentsHolder + NetworkComponents factory.NetworkComponentsHolder + StatusCoreComponents factory.StatusCoreComponentsHolder + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } type bootstrapComponentsFactory struct { @@ -224,6 +226,7 @@ func (bcf *bootstrapComponentsFactory) Create() (*bootstrapComponents, error) { NodeProcessingMode: common.GetNodeProcessingMode(&bcf.importDbConfig), StateStatsHandler: bcf.statusCoreComponents.StateStatsHandler(), NodesCoordinatorRegistryFactory: nodesCoordinatorRegistryFactory, + EnableEpochsHandler: bcf.coreComponents.EnableEpochsHandler(), } var epochStartBootstrapper factory.EpochStartBootstrapper diff --git a/factory/consensus/consensusComponents.go b/factory/consensus/consensusComponents.go index c031744f12d..170638a7268 100644 --- a/factory/consensus/consensusComponents.go +++ b/factory/consensus/consensusComponents.go @@ -9,6 +9,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/core/watchdog" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-storage-go/timecache" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/disabled" "github.com/multiversx/mx-chain-go/config" @@ -16,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/consensus/blacklist" "github.com/multiversx/mx-chain-go/consensus/chronology" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls/proxy" "github.com/multiversx/mx-chain-go/consensus/spos/sposFactory" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/errors" @@ -29,14 +33,14 @@ import ( "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/multiversx/mx-chain-go/update" - logger 
"github.com/multiversx/mx-chain-logger-go" - "github.com/multiversx/mx-chain-storage-go/timecache" ) var log = logger.GetOrCreate("factory") const defaultSpan = 300 * time.Second +const numSignatureGoRoutinesThrottler = 30 + // ConsensusComponentsFactoryArgs holds the arguments needed to create a consensus components factory type ConsensusComponentsFactoryArgs struct { Config config.Config @@ -160,6 +164,7 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { ccf.processComponents.InterceptorsContainer(), ccf.coreComponents.AlarmScheduler(), ccf.cryptoComponents.KeysHandler(), + ccf.config.ConsensusGradualBroadcast, ) if err != nil { return nil, err @@ -202,6 +207,7 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { AppStatusHandler: ccf.statusCoreComponents.AppStatusHandler(), NodeRedundancyHandler: ccf.processComponents.NodeRedundancyHandler(), PeerBlacklistHandler: cc.peerBlacklistHandler, + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), } cc.worker, err = spos.NewWorker(workerArgs) @@ -249,6 +255,8 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { MessageSigningHandler: p2pSigningHandler, PeerBlacklistHandler: cc.peerBlacklistHandler, SigningHandler: ccf.cryptoComponents.ConsensusSigningHandler(), + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), + EquivalentProofsPool: ccf.dataComponents.Datapool().Proofs(), } consensusDataContainer, err := spos.NewConsensusCore( @@ -257,28 +265,34 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { if err != nil { return nil, err } - - fct, err := sposFactory.GetSubroundsFactory( - consensusDataContainer, - consensusState, - cc.worker, - ccf.config.Consensus.Type, - ccf.statusCoreComponents.AppStatusHandler(), - ccf.statusComponents.OutportHandler(), - ccf.processComponents.SentSignaturesTracker(), - []byte(ccf.coreComponents.ChainID()), - ccf.networkComponents.NetworkMessenger().ID(), - ) + signatureThrottler, err := throttler.NewNumGoRoutinesThrottler(numSignatureGoRoutinesThrottler) if err != nil { return nil, err } - err = fct.GenerateSubrounds() + subroundsHandlerArgs := &proxy.SubroundsHandlerArgs{ + Chronology: cc.chronology, + ConsensusCoreHandler: consensusDataContainer, + ConsensusState: consensusState, + Worker: cc.worker, + SignatureThrottler: signatureThrottler, + AppStatusHandler: ccf.statusCoreComponents.AppStatusHandler(), + OutportHandler: ccf.statusComponents.OutportHandler(), + SentSignatureTracker: ccf.processComponents.SentSignaturesTracker(), + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), + ChainID: []byte(ccf.coreComponents.ChainID()), + CurrentPid: ccf.networkComponents.NetworkMessenger().ID(), + } + + subroundsHandler, err := proxy.NewSubroundsHandler(subroundsHandlerArgs) if err != nil { return nil, err } - cc.chronology.StartRounds() + err = subroundsHandler.Start(epoch) + if err != nil { + return nil, err + } err = ccf.addCloserInstances(cc.chronology, cc.bootstrapper, cc.worker, ccf.coreComponents.SyncTimer()) if err != nil { @@ -485,6 +499,7 @@ func (ccf *consensusComponentsFactory) createShardBootstrapper() (process.Bootst ScheduledTxsExecutionHandler: ccf.processComponents.ScheduledTxsExecutionHandler(), ProcessWaitTime: time.Duration(ccf.config.GeneralSettings.SyncProcessTimeInMillis) * time.Millisecond, RepopulateTokensSupplies: ccf.flagsConfig.RepopulateTokensSupplies, + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), } 
argsShardBootstrapper := sync.ArgShardBootstrapper{ @@ -615,6 +630,7 @@ func (ccf *consensusComponentsFactory) createMetaChainBootstrapper() (process.Bo ScheduledTxsExecutionHandler: ccf.processComponents.ScheduledTxsExecutionHandler(), ProcessWaitTime: time.Duration(ccf.config.GeneralSettings.SyncProcessTimeInMillis) * time.Millisecond, RepopulateTokensSupplies: ccf.flagsConfig.RepopulateTokensSupplies, + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), } argsMetaBootstrapper := sync.ArgMetaBootstrapper{ diff --git a/factory/consensus/consensusComponents_test.go b/factory/consensus/consensusComponents_test.go index a8f175c0b52..d13318ba2b5 100644 --- a/factory/consensus/consensusComponents_test.go +++ b/factory/consensus/consensusComponents_test.go @@ -8,7 +8,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-crypto-go" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" retriever "github.com/multiversx/mx-chain-go/dataRetriever" @@ -21,9 +23,11 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" @@ -38,7 +42,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/multiversx/mx-chain-go/update" - "github.com/stretchr/testify/require" ) func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponentsFactoryArgs { @@ -91,14 +94,17 @@ func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponent DataComponents: &testsMocks.DataComponentsStub{ DataPool: &dataRetriever.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, HeadersCalled: func() retriever.HeadersPool { return &testsMocks.HeadersCacherStub{} }, + ProofsCalled: func() retriever.ProofsPool { + return &dataRetrieverMocks.ProofsPoolMock{} + }, }, BlockChain: &testscommon.ChainHandlerStub{ GetGenesisHeaderHashCalled: func() []byte { @@ -137,7 +143,7 @@ func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponent CurrentEpochProviderInternal: &testsMocks.CurrentNetworkEpochProviderStub{}, HistoryRepositoryInternal: &dblookupext.HistoryRepositoryStub{}, IntContainer: &testscommon.InterceptorsContainerStub{}, - HeaderSigVerif: &testsMocks.HeaderSigVerifierStub{}, + HeaderSigVerif: &consensusMocks.HeaderSigVerifierMock{}, HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, FallbackHdrValidator: &testscommon.FallBackHeaderValidatorStub{}, SentSignaturesTrackerInternal: &testscommon.SentSignatureTrackerStub{}, @@ -743,7 +749,7 @@ func 
TestConsensusComponentsFactory_Create(t *testing.T) { cnt := 0 processCompStub.ShardCoordinatorCalled = func() sharding.Coordinator { cnt++ - if cnt > 9 { + if cnt >= 10 { return nil // createConsensusTopic fails } return testscommon.NewMultiShardsCoordinatorMock(2) @@ -834,28 +840,6 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { require.True(t, strings.Contains(err.Error(), "signing handler")) require.Nil(t, cc) }) - t.Run("GetSubroundsFactory failure should error", func(t *testing.T) { - t.Parallel() - - args := createMockConsensusComponentsFactoryArgs() - statusCoreCompStub, ok := args.StatusCoreComponents.(*factoryMocks.StatusCoreComponentsStub) - require.True(t, ok) - cnt := 0 - statusCoreCompStub.AppStatusHandlerCalled = func() core.AppStatusHandler { - cnt++ - if cnt > 4 { - return nil - } - return &statusHandler.AppStatusHandlerStub{} - } - ccf, _ := consensusComp.NewConsensusComponentsFactory(args) - require.NotNil(t, ccf) - - cc, err := ccf.Create() - require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "AppStatusHandler")) - require.Nil(t, cc) - }) t.Run("addCloserInstances failure should error", func(t *testing.T) { t.Parallel() diff --git a/factory/heartbeat/heartbeatV2Components_test.go b/factory/heartbeat/heartbeatV2Components_test.go index 9a0eb3b14e3..f605bc67b9c 100644 --- a/factory/heartbeat/heartbeatV2Components_test.go +++ b/factory/heartbeat/heartbeatV2Components_test.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" errorsMx "github.com/multiversx/mx-chain-go/errors" @@ -14,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" + "github.com/multiversx/mx-chain-go/testscommon/cache" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" @@ -23,7 +26,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func createMockHeartbeatV2ComponentsFactoryArgs() heartbeatComp.ArgHeartbeatV2ComponentsFactory { @@ -54,10 +56,10 @@ func createMockHeartbeatV2ComponentsFactoryArgs() heartbeatComp.ArgHeartbeatV2Co DataComponents: &testsMocks.DataComponentsStub{ DataPool: &dataRetriever.PoolsHolderStub{ PeerAuthenticationsCalled: func() storage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, HeartbeatsCalled: func() storage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, }, BlockChain: &testscommon.ChainHandlerStub{}, diff --git a/factory/interface.go b/factory/interface.go index 4c66676fe48..85331045ecc 100644 --- a/factory/interface.go +++ b/factory/interface.go @@ -14,6 +14,8 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" @@ -37,7 +39,6 @@ 
import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/update" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // EpochStartNotifier defines which actions should be done for handling new epoch's events @@ -385,6 +386,8 @@ type ConsensusWorker interface { AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) // AddReceivedHeaderHandler adds a new handler function for a received header AddReceivedHeaderHandler(handler func(data.HeaderHandler)) + // AddReceivedProofHandler adds a new handler function for a received proof + AddReceivedProofHandler(handler func(proofHandler consensus.ProofHandler)) // RemoveAllReceivedMessagesCalls removes all the functions handlers RemoveAllReceivedMessagesCalls() // ProcessReceivedMessage method redirects the received message to the channel which should handle it @@ -397,10 +400,12 @@ type ConsensusWorker interface { ExecuteStoredMessages() // DisplayStatistics method displays statistics of worker at the end of the round DisplayStatistics() - // ResetConsensusMessages resets at the start of each round all the previous consensus messages received + // ResetConsensusMessages resets at the start of each round all the previous consensus messages received and equivalent messages, keeping the provided proofs ResetConsensusMessages() // ReceivedHeader method is a wired method through which worker will receive headers from network ReceivedHeader(headerHandler data.HeaderHandler, headerHash []byte) + // ReceivedProof will handle a received proof in consensus worker + ReceivedProof(proofHandler consensus.ProofHandler) // IsInterfaceNil returns true if there is no value under the interface IsInterfaceNil() bool } diff --git a/factory/mock/headerSigVerifierStub.go b/factory/mock/headerSigVerifierStub.go deleted file mode 100644 index 03a7e9b2658..00000000000 --- a/factory/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,49 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git 
a/factory/peerSignatureHandler/peerSignatureHandler_test.go b/factory/peerSignatureHandler/peerSignatureHandler_test.go index 15395f65379..9f01857b73d 100644 --- a/factory/peerSignatureHandler/peerSignatureHandler_test.go +++ b/factory/peerSignatureHandler/peerSignatureHandler_test.go @@ -7,11 +7,12 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + errorsErd "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/factory/peerSignatureHandler" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" - "github.com/stretchr/testify/assert" ) func TestNewPeerSignatureHandler_NilCacherShouldErr(t *testing.T) { @@ -31,7 +32,7 @@ func TestNewPeerSignatureHandler_NilSingleSignerShouldErr(t *testing.T) { t.Parallel() peerSigHandler, err := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), nil, &cryptoMocks.KeyGenStub{}, ) @@ -44,7 +45,7 @@ func TestNewPeerSignatureHandler_NilKeyGenShouldErr(t *testing.T) { t.Parallel() peerSigHandler, err := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, nil, ) @@ -57,7 +58,7 @@ func TestNewPeerSignatureHandler_OkParamsShouldWork(t *testing.T) { t.Parallel() peerSigHandler, err := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -70,7 +71,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureInvalidPk(t *testing.T) { t.Parallel() peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -83,7 +84,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureInvalidPID(t *testing.T) { t.Parallel() peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -96,7 +97,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureInvalidSignature(t *testing.T) t.Parallel() peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -116,7 +117,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureCantGetPubKeyBytes(t *testing.T } peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, keyGen, ) @@ -133,7 +134,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureSigNotFoundInCache(t *testing.T pid := "dummy peer" sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -179,7 +180,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureWrongEntryInCache(t *testing.T) pid := "dummy peer" sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() cache.Put(pk, wrongType, len(wrongType)) keyGen := &cryptoMocks.KeyGenStub{ @@ -228,7 +229,7 @@ func 
TestPeerSignatureHandler_VerifyPeerSignatureNewPidAndSig(t *testing.T) { newPid := core.PeerID("new dummy peer") newSig := []byte("new sig") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -277,7 +278,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureDifferentPid(t *testing.T) { sig := []byte("signature") newPid := core.PeerID("new dummy peer") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -317,7 +318,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureDifferentSig(t *testing.T) { sig := []byte("signature") newSig := []byte("new signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -356,7 +357,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureGetFromCache(t *testing.T) { pid := core.PeerID("dummy peer") sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -399,7 +400,7 @@ func TestPeerSignatureHandler_GetPeerSignatureErrInConvertingPrivateKeyToByteArr pid := []byte("dummy peer") peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -422,7 +423,7 @@ func TestPeerSignatureHandler_GetPeerSignatureNotPresentInCache(t *testing.T) { pid := []byte("dummy peer") sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() singleSigner := &cryptoMocks.SingleSignerStub{ SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { signCalled = true @@ -465,7 +466,7 @@ func TestPeerSignatureHandler_GetPeerSignatureWrongEntryInCache(t *testing.T) { sig := []byte("signature") wrongEntry := []byte("wrong entry") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() singleSigner := &cryptoMocks.SingleSignerStub{ SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { signCalled = true @@ -511,7 +512,7 @@ func TestPeerSignatureHandler_GetPeerSignatureDifferentPidInCache(t *testing.T) sig := []byte("signature") newSig := []byte("new signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() singleSigner := &cryptoMocks.SingleSignerStub{ SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { signCalled = true @@ -555,7 +556,7 @@ func TestPeerSignatureHandler_GetPeerSignatureGetFromCache(t *testing.T) { pid := []byte("dummy peer") sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() singleSigner := &cryptoMocks.SingleSignerStub{ SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { return nil, nil diff --git a/factory/processing/blockProcessorCreator_test.go b/factory/processing/blockProcessorCreator_test.go index 099fec4a82d..8b01c44c8f8 100644 --- a/factory/processing/blockProcessorCreator_test.go +++ b/factory/processing/blockProcessorCreator_test.go @@ -8,6 +8,9 @@ import ( 
"github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" dataComp "github.com/multiversx/mx-chain-go/factory/data" @@ -26,8 +29,6 @@ import ( storageManager "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) func Test_newBlockProcessorCreatorForShard(t *testing.T) { diff --git a/factory/processing/processComponents.go b/factory/processing/processComponents.go index 482343bbadf..dd5075d5dfd 100644 --- a/factory/processing/processComponents.go +++ b/factory/processing/processComponents.go @@ -57,6 +57,7 @@ import ( "github.com/multiversx/mx-chain-go/process/factory/interceptorscontainer" "github.com/multiversx/mx-chain-go/process/headerCheck" "github.com/multiversx/mx-chain-go/process/heartbeat/validator" + interceptorFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" "github.com/multiversx/mx-chain-go/process/peer" "github.com/multiversx/mx-chain-go/process/receipts" "github.com/multiversx/mx-chain-go/process/smartContract" @@ -133,6 +134,7 @@ type processComponents struct { receiptsRepository mainFactory.ReceiptsRepository sentSignaturesTracker process.SentSignaturesTracker epochSystemSCProcessor process.EpochStartSystemSCProcessor + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // ProcessComponentsFactoryArgs holds the arguments needed to create a process components factory @@ -208,6 +210,8 @@ type processComponentsFactory struct { genesisNonce uint64 genesisRound uint64 + + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewProcessComponentsFactory will return a new instance of processComponentsFactory @@ -217,37 +221,43 @@ func NewProcessComponentsFactory(args ProcessComponentsFactoryArgs) (*processCom return nil, err } + interceptedDataVerifierFactory := interceptorFactory.NewInterceptedDataVerifierFactory(interceptorFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Duration(args.Config.InterceptedDataVerifier.CacheSpanInSec) * time.Second, + CacheExpiry: time.Duration(args.Config.InterceptedDataVerifier.CacheExpiryInSec) * time.Second, + }) + return &processComponentsFactory{ - config: args.Config, - epochConfig: args.EpochConfig, - prefConfigs: args.PrefConfigs, - importDBConfig: args.ImportDBConfig, - economicsConfig: args.EconomicsConfig, - accountsParser: args.AccountsParser, - smartContractParser: args.SmartContractParser, - gasSchedule: args.GasSchedule, - nodesCoordinator: args.NodesCoordinator, - data: args.Data, - coreData: args.CoreData, - crypto: args.Crypto, - state: args.State, - network: args.Network, - bootstrapComponents: args.BootstrapComponents, - statusComponents: args.StatusComponents, - requestedItemsHandler: args.RequestedItemsHandler, - whiteListHandler: args.WhiteListHandler, - whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, - maxRating: args.MaxRating, - systemSCConfig: args.SystemSCConfig, - importStartHandler: args.ImportStartHandler, - historyRepo: args.HistoryRepo, - epochNotifier: args.CoreData.EpochNotifier(), - statusCoreComponents: args.StatusCoreComponents, - flagsConfig: 
args.FlagsConfig, - txExecutionOrderHandler: args.TxExecutionOrderHandler, - genesisNonce: args.GenesisNonce, - genesisRound: args.GenesisRound, - roundConfig: args.RoundConfig, + config: args.Config, + epochConfig: args.EpochConfig, + prefConfigs: args.PrefConfigs, + importDBConfig: args.ImportDBConfig, + economicsConfig: args.EconomicsConfig, + accountsParser: args.AccountsParser, + smartContractParser: args.SmartContractParser, + gasSchedule: args.GasSchedule, + nodesCoordinator: args.NodesCoordinator, + data: args.Data, + coreData: args.CoreData, + crypto: args.Crypto, + state: args.State, + network: args.Network, + bootstrapComponents: args.BootstrapComponents, + statusComponents: args.StatusComponents, + requestedItemsHandler: args.RequestedItemsHandler, + whiteListHandler: args.WhiteListHandler, + whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, + maxRating: args.MaxRating, + systemSCConfig: args.SystemSCConfig, + importStartHandler: args.ImportStartHandler, + historyRepo: args.HistoryRepo, + epochNotifier: args.CoreData.EpochNotifier(), + statusCoreComponents: args.StatusCoreComponents, + flagsConfig: args.FlagsConfig, + txExecutionOrderHandler: args.TxExecutionOrderHandler, + genesisNonce: args.GenesisNonce, + genesisRound: args.GenesisRound, + roundConfig: args.RoundConfig, + interceptedDataVerifierFactory: interceptedDataVerifierFactory, }, nil } @@ -284,6 +294,8 @@ func (pcf *processComponentsFactory) Create() (*processComponents, error) { SingleSigVerifier: pcf.crypto.BlockSigner(), KeyGen: pcf.crypto.BlockSignKeyGen(), FallbackHeaderValidator: fallbackHeaderValidator, + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), + HeadersPool: pcf.data.Datapool().Headers(), } headerSigVerifier, err := headerCheck.NewHeaderSigVerifier(argsHeaderSig) if err != nil { @@ -762,6 +774,7 @@ func (pcf *processComponentsFactory) Create() (*processComponents, error) { accountsParser: pcf.accountsParser, receiptsRepository: receiptsRepository, sentSignaturesTracker: sentSignaturesTracker, + interceptedDataVerifierFactory: pcf.interceptedDataVerifierFactory, }, nil } @@ -1324,17 +1337,19 @@ func (pcf *processComponentsFactory) newBlockTracker( ) (process.BlockTracker, error) { shardCoordinator := pcf.bootstrapComponents.ShardCoordinator() argBaseTracker := track.ArgBaseTracker{ - Hasher: pcf.coreData.Hasher(), - HeaderValidator: headerValidator, - Marshalizer: pcf.coreData.InternalMarshalizer(), - RequestHandler: requestHandler, - RoundHandler: pcf.coreData.RoundHandler(), - ShardCoordinator: shardCoordinator, - Store: pcf.data.StorageService(), - StartHeaders: genesisBlocks, - PoolsHolder: pcf.data.Datapool(), - WhitelistHandler: pcf.whiteListHandler, - FeeHandler: pcf.coreData.EconomicsData(), + Hasher: pcf.coreData.Hasher(), + HeaderValidator: headerValidator, + Marshalizer: pcf.coreData.InternalMarshalizer(), + RequestHandler: requestHandler, + RoundHandler: pcf.coreData.RoundHandler(), + ShardCoordinator: shardCoordinator, + Store: pcf.data.StorageService(), + StartHeaders: genesisBlocks, + PoolsHolder: pcf.data.Datapool(), + WhitelistHandler: pcf.whiteListHandler, + FeeHandler: pcf.coreData.EconomicsData(), + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), + ProofsPool: pcf.data.Datapool().Proofs(), } if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { @@ -1666,36 +1681,37 @@ func (pcf *processComponentsFactory) newShardInterceptorContainerFactory( ) (process.InterceptorsContainerFactory, process.TimeCacher, error) { headerBlackList := 
cache.NewTimeCache(timeSpanForBadHeaders) shardInterceptorsContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: pcf.coreData, - CryptoComponents: pcf.crypto, - Accounts: pcf.state.AccountsAdapter(), - ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(), - NodesCoordinator: pcf.nodesCoordinator, - MainMessenger: pcf.network.NetworkMessenger(), - FullArchiveMessenger: pcf.network.FullArchiveNetworkMessenger(), - Store: pcf.data.StorageService(), - DataPool: pcf.data.Datapool(), - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: pcf.coreData.EconomicsData(), - BlockBlackList: headerBlackList, - HeaderSigVerifier: headerSigVerifier, - HeaderIntegrityVerifier: headerIntegrityVerifier, - ValidityAttester: validityAttester, - EpochStartTrigger: epochStartTrigger, - WhiteListHandler: pcf.whiteListHandler, - WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, - AntifloodHandler: pcf.network.InputAntiFloodHandler(), - ArgumentsParser: smartContract.NewArgumentParser(), - PreferredPeersHolder: pcf.network.PreferredPeersHolderHandler(), - SizeCheckDelta: pcf.config.Marshalizer.SizeCheckDelta, - RequestHandler: requestHandler, - PeerSignatureHandler: pcf.crypto.PeerSignatureHandler(), - SignaturesHandler: pcf.network.NetworkMessenger(), - HeartbeatExpiryTimespanInSec: pcf.config.HeartbeatV2.HeartbeatExpiryTimespanInSec, - MainPeerShardMapper: mainPeerShardMapper, - FullArchivePeerShardMapper: fullArchivePeerShardMapper, - HardforkTrigger: hardforkTrigger, - NodeOperationMode: nodeOperationMode, + CoreComponents: pcf.coreData, + CryptoComponents: pcf.crypto, + Accounts: pcf.state.AccountsAdapter(), + ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(), + NodesCoordinator: pcf.nodesCoordinator, + MainMessenger: pcf.network.NetworkMessenger(), + FullArchiveMessenger: pcf.network.FullArchiveNetworkMessenger(), + Store: pcf.data.StorageService(), + DataPool: pcf.data.Datapool(), + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: pcf.coreData.EconomicsData(), + BlockBlackList: headerBlackList, + HeaderSigVerifier: headerSigVerifier, + HeaderIntegrityVerifier: headerIntegrityVerifier, + ValidityAttester: validityAttester, + EpochStartTrigger: epochStartTrigger, + WhiteListHandler: pcf.whiteListHandler, + WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, + AntifloodHandler: pcf.network.InputAntiFloodHandler(), + ArgumentsParser: smartContract.NewArgumentParser(), + PreferredPeersHolder: pcf.network.PreferredPeersHolderHandler(), + SizeCheckDelta: pcf.config.Marshalizer.SizeCheckDelta, + RequestHandler: requestHandler, + PeerSignatureHandler: pcf.crypto.PeerSignatureHandler(), + SignaturesHandler: pcf.network.NetworkMessenger(), + HeartbeatExpiryTimespanInSec: pcf.config.HeartbeatV2.HeartbeatExpiryTimespanInSec, + MainPeerShardMapper: mainPeerShardMapper, + FullArchivePeerShardMapper: fullArchivePeerShardMapper, + HardforkTrigger: hardforkTrigger, + NodeOperationMode: nodeOperationMode, + InterceptedDataVerifierFactory: pcf.interceptedDataVerifierFactory, } interceptorContainerFactory, err := interceptorscontainer.NewShardInterceptorsContainerFactory(shardInterceptorsContainerFactoryArgs) @@ -1719,36 +1735,37 @@ func (pcf *processComponentsFactory) newMetaInterceptorContainerFactory( ) (process.InterceptorsContainerFactory, process.TimeCacher, error) { headerBlackList := cache.NewTimeCache(timeSpanForBadHeaders) metaInterceptorsContainerFactoryArgs := 
interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: pcf.coreData, - CryptoComponents: pcf.crypto, - ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(), - NodesCoordinator: pcf.nodesCoordinator, - MainMessenger: pcf.network.NetworkMessenger(), - FullArchiveMessenger: pcf.network.FullArchiveNetworkMessenger(), - Store: pcf.data.StorageService(), - DataPool: pcf.data.Datapool(), - Accounts: pcf.state.AccountsAdapter(), - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: pcf.coreData.EconomicsData(), - BlockBlackList: headerBlackList, - HeaderSigVerifier: headerSigVerifier, - HeaderIntegrityVerifier: headerIntegrityVerifier, - ValidityAttester: validityAttester, - EpochStartTrigger: epochStartTrigger, - WhiteListHandler: pcf.whiteListHandler, - WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, - AntifloodHandler: pcf.network.InputAntiFloodHandler(), - ArgumentsParser: smartContract.NewArgumentParser(), - SizeCheckDelta: pcf.config.Marshalizer.SizeCheckDelta, - PreferredPeersHolder: pcf.network.PreferredPeersHolderHandler(), - RequestHandler: requestHandler, - PeerSignatureHandler: pcf.crypto.PeerSignatureHandler(), - SignaturesHandler: pcf.network.NetworkMessenger(), - HeartbeatExpiryTimespanInSec: pcf.config.HeartbeatV2.HeartbeatExpiryTimespanInSec, - MainPeerShardMapper: mainPeerShardMapper, - FullArchivePeerShardMapper: fullArchivePeerShardMapper, - HardforkTrigger: hardforkTrigger, - NodeOperationMode: nodeOperationMode, + CoreComponents: pcf.coreData, + CryptoComponents: pcf.crypto, + ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(), + NodesCoordinator: pcf.nodesCoordinator, + MainMessenger: pcf.network.NetworkMessenger(), + FullArchiveMessenger: pcf.network.FullArchiveNetworkMessenger(), + Store: pcf.data.StorageService(), + DataPool: pcf.data.Datapool(), + Accounts: pcf.state.AccountsAdapter(), + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: pcf.coreData.EconomicsData(), + BlockBlackList: headerBlackList, + HeaderSigVerifier: headerSigVerifier, + HeaderIntegrityVerifier: headerIntegrityVerifier, + ValidityAttester: validityAttester, + EpochStartTrigger: epochStartTrigger, + WhiteListHandler: pcf.whiteListHandler, + WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, + AntifloodHandler: pcf.network.InputAntiFloodHandler(), + ArgumentsParser: smartContract.NewArgumentParser(), + SizeCheckDelta: pcf.config.Marshalizer.SizeCheckDelta, + PreferredPeersHolder: pcf.network.PreferredPeersHolderHandler(), + RequestHandler: requestHandler, + PeerSignatureHandler: pcf.crypto.PeerSignatureHandler(), + SignaturesHandler: pcf.network.NetworkMessenger(), + HeartbeatExpiryTimespanInSec: pcf.config.HeartbeatV2.HeartbeatExpiryTimespanInSec, + MainPeerShardMapper: mainPeerShardMapper, + FullArchivePeerShardMapper: fullArchivePeerShardMapper, + HardforkTrigger: hardforkTrigger, + NodeOperationMode: nodeOperationMode, + InterceptedDataVerifierFactory: pcf.interceptedDataVerifierFactory, } interceptorContainerFactory, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(metaInterceptorsContainerFactoryArgs) @@ -1848,6 +1865,7 @@ func (pcf *processComponentsFactory) createExportFactoryHandler( NumConcurrentTrieSyncers: pcf.config.TrieSync.NumConcurrentTrieSyncers, TrieSyncerVersion: pcf.config.TrieSync.TrieSyncerVersion, NodeOperationMode: nodeOperationMode, + InterceptedDataVerifierFactory: pcf.interceptedDataVerifierFactory, } return updateFactory.NewExportHandlerFactory(argsExporter) } @@ -2045,6 
+2063,9 @@ func (pc *processComponents) Close() error { if !check.IfNil(pc.txsSender) { log.LogIfError(pc.txsSender.Close()) } + if !check.IfNil(pc.interceptedDataVerifierFactory) { + log.LogIfError(pc.interceptedDataVerifierFactory.Close()) + } return nil } diff --git a/factory/processing/processComponents_test.go b/factory/processing/processComponents_test.go index a1654ce3ba3..6ddf5ea2d8b 100644 --- a/factory/processing/processComponents_test.go +++ b/factory/processing/processComponents_test.go @@ -17,6 +17,8 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/blake2b" "github.com/multiversx/mx-chain-core-go/hashing/keccak" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/factory" disabledStatistics "github.com/multiversx/mx-chain-go/common/statistics/disabled" @@ -55,7 +57,6 @@ import ( testState "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" updateMocks "github.com/multiversx/mx-chain-go/update/mock" - "github.com/stretchr/testify/require" ) const ( diff --git a/fallback/headerValidator.go b/fallback/headerValidator.go index 8e2d0eda037..4b9110582b0 100644 --- a/fallback/headerValidator.go +++ b/fallback/headerValidator.go @@ -5,10 +5,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("fallback") @@ -45,28 +46,34 @@ func NewFallbackHeaderValidator( return hv, nil } -// ShouldApplyFallbackValidation returns if for the given header could be applied fallback validation or not -func (fhv *fallbackHeaderValidator) ShouldApplyFallbackValidation(headerHandler data.HeaderHandler) bool { - if check.IfNil(headerHandler) { - return false - } - if headerHandler.GetShardID() != core.MetachainShardId { +// ShouldApplyFallbackValidationForHeaderWith returns if for the given header data fallback validation could be applied or not +func (fhv *fallbackHeaderValidator) ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) bool { + if shardID != core.MetachainShardId { return false } - if !headerHandler.IsStartOfEpochBlock() { + if !startOfEpochBlock { return false } - previousHeader, err := process.GetMetaHeader(headerHandler.GetPrevHash(), fhv.headersPool, fhv.marshalizer, fhv.storageService) + previousHeader, err := process.GetMetaHeader(prevHeaderHash, fhv.headersPool, fhv.marshalizer, fhv.storageService) if err != nil { log.Debug("ShouldApplyFallbackValidation", "GetMetaHeader", err.Error()) return false } - isRoundTooOld := int64(headerHandler.GetRound())-int64(previousHeader.GetRound()) >= common.MaxRoundsWithoutCommittedStartInEpochBlock + isRoundTooOld := int64(round)-int64(previousHeader.GetRound()) >= common.MaxRoundsWithoutCommittedStartInEpochBlock return isRoundTooOld } +// ShouldApplyFallbackValidation returns if for the given header could be applied fallback validation or not +func (fhv *fallbackHeaderValidator) ShouldApplyFallbackValidation(headerHandler data.HeaderHandler) bool { + if check.IfNil(headerHandler) { + return false + } + + return 
fhv.ShouldApplyFallbackValidationForHeaderWith(headerHandler.GetShardID(), headerHandler.IsStartOfEpochBlock(), headerHandler.GetRound(), headerHandler.GetPrevHash()) +} + // IsInterfaceNil returns true if there is no value under the interface func (fhv *fallbackHeaderValidator) IsInterfaceNil() bool { return fhv == nil diff --git a/go.mod b/go.mod index 1b381e3a86f..895eb3ea982 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/klauspost/cpuid/v2 v2.2.5 github.com/mitchellh/mapstructure v1.5.0 github.com/multiversx/mx-chain-communication-go v1.0.15-0.20240508074652-e128a1c05c8e - github.com/multiversx/mx-chain-core-go v1.2.21-0.20240530111258-45870512bfbe + github.com/multiversx/mx-chain-core-go v1.2.21-0.20241204105459-ddd46264c030 github.com/multiversx/mx-chain-crypto-go v1.2.12-0.20240508074452-cc21c1b505df github.com/multiversx/mx-chain-es-indexer-go v1.7.2-0.20240619122842-05143459c554 github.com/multiversx/mx-chain-logger-go v1.0.15-0.20240508072523-3f00a726af57 @@ -33,6 +33,7 @@ require ( github.com/stretchr/testify v1.8.4 github.com/urfave/cli v1.22.10 golang.org/x/crypto v0.10.0 + golang.org/x/exp v0.0.0-20230321023759-10a507213a29 gopkg.in/go-playground/validator.v8 v8.18.2 ) @@ -173,7 +174,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.24.0 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect golang.org/x/mod v0.10.0 // indirect golang.org/x/net v0.11.0 // indirect golang.org/x/sync v0.2.0 // indirect diff --git a/go.sum b/go.sum index f7cc76137bf..7391ce4b459 100644 --- a/go.sum +++ b/go.sum @@ -387,8 +387,8 @@ github.com/multiversx/concurrent-map v0.1.4 h1:hdnbM8VE4b0KYJaGY5yJS2aNIW9TFFsUY github.com/multiversx/concurrent-map v0.1.4/go.mod h1:8cWFRJDOrWHOTNSqgYCUvwT7c7eFQ4U2vKMOp4A/9+o= github.com/multiversx/mx-chain-communication-go v1.0.15-0.20240508074652-e128a1c05c8e h1:Tsmwhu+UleE+l3buPuqXSKTqfu5FbPmzQ4MjMoUvCWA= github.com/multiversx/mx-chain-communication-go v1.0.15-0.20240508074652-e128a1c05c8e/go.mod h1:2yXl18wUbuV3cRZr7VHxM1xo73kTaC1WUcu2kx8R034= -github.com/multiversx/mx-chain-core-go v1.2.21-0.20240530111258-45870512bfbe h1:7ccy0nNJkCGDlRrIbAmZfVv5XkZAxXuBFnfUMNuESRA= -github.com/multiversx/mx-chain-core-go v1.2.21-0.20240530111258-45870512bfbe/go.mod h1:B5zU4MFyJezmEzCsAHE9YNULmGCm2zbPHvl9hazNxmE= +github.com/multiversx/mx-chain-core-go v1.2.21-0.20241204105459-ddd46264c030 h1:4XI4z1ceZC3OUXxTeMQD+6gmTgu9I934nsYlV6P8X4A= +github.com/multiversx/mx-chain-core-go v1.2.21-0.20241204105459-ddd46264c030/go.mod h1:B5zU4MFyJezmEzCsAHE9YNULmGCm2zbPHvl9hazNxmE= github.com/multiversx/mx-chain-crypto-go v1.2.12-0.20240508074452-cc21c1b505df h1:clihfi78bMEOWk/qw6WA4uQbCM2e2NGliqswLAvw19k= github.com/multiversx/mx-chain-crypto-go v1.2.12-0.20240508074452-cc21c1b505df/go.mod h1:gtJYB4rR21KBSqJlazn+2z6f9gFSqQP3KvAgL7Qgxw4= github.com/multiversx/mx-chain-es-indexer-go v1.7.2-0.20240619122842-05143459c554 h1:Fv8BfzJSzdovmoh9Jh/by++0uGsOVBlMP3XiN5Svkn4= diff --git a/heartbeat/monitor/monitor_test.go b/heartbeat/monitor/monitor_test.go index 83ae428fbee..02524882220 100644 --- a/heartbeat/monitor/monitor_test.go +++ b/heartbeat/monitor/monitor_test.go @@ -9,19 +9,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/heartbeat/data" 
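// Referring back to the fallback/headerValidator.go hunk above (just before the
// go.mod changes): ShouldApplyFallbackValidationForHeaderWith lets callers that
// only hold raw header fields (for instance, data taken from an equivalent
// proof) ask the same question as ShouldApplyFallbackValidation without
// building a full data.HeaderHandler first. A minimal usage sketch follows; the
// fallbackChecker interface and stubChecker type are local to this example and
// only mirror the method added in the diff, they are not repository code.
package main

import (
	"fmt"

	"github.com/multiversx/mx-chain-core-go/core"
)

type fallbackChecker interface {
	ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) bool
}

// stubChecker is a trivial stand-in kept only to make the sketch runnable; the
// real validator additionally checks the round distance to the previous
// committed meta header against common.MaxRoundsWithoutCommittedStartInEpochBlock
type stubChecker struct{}

func (s stubChecker) ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, _ uint64, _ []byte) bool {
	return shardID == core.MetachainShardId && startOfEpochBlock
}

func main() {
	var checker fallbackChecker = stubChecker{}
	applies := checker.ShouldApplyFallbackValidationForHeaderWith(core.MetachainShardId, true, 100, []byte("previous header hash"))
	fmt.Println("fallback validation applies:", applies)
}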
"github.com/multiversx/mx-chain-go/heartbeat/mock" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" - "github.com/stretchr/testify/assert" ) func createMockHeartbeatV2MonitorArgs() ArgHeartbeatV2Monitor { return ArgHeartbeatV2Monitor{ - Cache: testscommon.NewCacherMock(), + Cache: cache.NewCacherMock(), PubKeyConverter: &testscommon.PubkeyConverterMock{}, Marshaller: &marshallerMock.MarshalizerMock{}, MaxDurationPeerUnresponsive: time.Second * 3, diff --git a/heartbeat/processor/peerAuthenticationRequestsProcessor_test.go b/heartbeat/processor/peerAuthenticationRequestsProcessor_test.go index 39e21d9eb80..958ee50879b 100644 --- a/heartbeat/processor/peerAuthenticationRequestsProcessor_test.go +++ b/heartbeat/processor/peerAuthenticationRequestsProcessor_test.go @@ -14,18 +14,20 @@ import ( mxAtomic "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/random" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockArgPeerAuthenticationRequestsProcessor() ArgPeerAuthenticationRequestsProcessor { return ArgPeerAuthenticationRequestsProcessor{ RequestHandler: &testscommon.RequestHandlerStub{}, NodesCoordinator: &shardingMocks.NodesCoordinatorStub{}, - PeerAuthenticationPool: &testscommon.CacherMock{}, + PeerAuthenticationPool: &cache.CacherMock{}, ShardId: 0, Epoch: 0, MinPeersThreshold: 0.8, @@ -200,7 +202,7 @@ func TestPeerAuthenticationRequestsProcessor_startRequestingMessages(t *testing. 
}, } - args.PeerAuthenticationPool = &testscommon.CacherStub{ + args.PeerAuthenticationPool = &cache.CacherStub{ KeysCalled: func() [][]byte { return providedEligibleKeysMap[0] }, @@ -236,7 +238,7 @@ func TestPeerAuthenticationRequestsProcessor_isThresholdReached(t *testing.T) { args := createMockArgPeerAuthenticationRequestsProcessor() args.MinPeersThreshold = 0.6 counter := uint32(0) - args.PeerAuthenticationPool = &testscommon.CacherStub{ + args.PeerAuthenticationPool = &cache.CacherStub{ KeysCalled: func() [][]byte { var keys = make([][]byte, 0) switch atomic.LoadUint32(&counter) { @@ -323,7 +325,7 @@ func TestPeerAuthenticationRequestsProcessor_goRoutineIsWorkingAndCloseShouldSto }, } keysCalled := &mxAtomic.Flag{} - args.PeerAuthenticationPool = &testscommon.CacherStub{ + args.PeerAuthenticationPool = &cache.CacherStub{ KeysCalled: func() [][]byte { keysCalled.SetValue(true) return make([][]byte, 0) diff --git a/heartbeat/status/metricsUpdater_test.go b/heartbeat/status/metricsUpdater_test.go index 645f4edb0dd..c9cfd4e16df 100644 --- a/heartbeat/status/metricsUpdater_test.go +++ b/heartbeat/status/metricsUpdater_test.go @@ -8,18 +8,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/heartbeat/data" "github.com/multiversx/mx-chain-go/heartbeat/mock" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func createMockArgsMetricsUpdater() ArgsMetricsUpdater { return ArgsMetricsUpdater{ - PeerAuthenticationCacher: testscommon.NewCacherMock(), + PeerAuthenticationCacher: cache.NewCacherMock(), HeartbeatMonitor: &mock.HeartbeatMonitorStub{}, HeartbeatSenderInfoProvider: &mock.HeartbeatSenderInfoProviderStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, diff --git a/integrationTests/chainSimulator/staking/jail/jail_test.go b/integrationTests/chainSimulator/staking/jail/jail_test.go index 42c4e69eaca..f3e920a4dbf 100644 --- a/integrationTests/chainSimulator/staking/jail/jail_test.go +++ b/integrationTests/chainSimulator/staking/jail/jail_test.go @@ -24,8 +24,7 @@ import ( const ( stakingV4JailUnJailStep1EnableEpoch = 5 defaultPathToInitialConfig = "../../../../cmd/node/config/" - - epochWhenNodeIsJailed = 4 + epochWhenNodeIsJailed = 4 ) // Test description @@ -79,6 +78,8 @@ func testChainSimulatorJailAndUnJail(t *testing.T, targetEpoch int32, nodeStatus MetaChainMinNodes: 2, AlterConfigsFunction: func(cfg *config.Configs) { configs.SetStakingV4ActivationEpochs(cfg, stakingV4JailUnJailStep1EnableEpoch) + cfg.EpochConfig.EnableEpochs.FixedOrderInConsensusEnableEpoch = 100 + cfg.EpochConfig.EnableEpochs.EquivalentMessagesEnableEpoch = 100 newNumNodes := cfg.SystemSCConfig.StakingSystemSCConfig.MaxNumberOfNodesForStake + 8 // 8 nodes until new nodes will be placed on queue configs.SetMaxNumberOfNodesInConfigs(cfg, uint32(newNumNodes), 0, numOfShards) configs.SetQuickJailRatingConfig(cfg) diff --git a/integrationTests/chainSimulator/staking/stakingProvider/delegation_test.go b/integrationTests/chainSimulator/staking/stakingProvider/delegation_test.go index 4c7475701e4..392bce9ff02 100644 --- 
a/integrationTests/chainSimulator/staking/stakingProvider/delegation_test.go +++ b/integrationTests/chainSimulator/staking/stakingProvider/delegation_test.go @@ -10,6 +10,7 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/integrationTests" chainSimulatorIntegrationTests "github.com/multiversx/mx-chain-go/integrationTests/chainSimulator" "github.com/multiversx/mx-chain-go/integrationTests/chainSimulator/staking" "github.com/multiversx/mx-chain-go/node/chainSimulator" @@ -94,6 +95,10 @@ func TestChainSimulator_MakeNewContractFromValidatorData(t *testing.T) { cfg.EpochConfig.EnableEpochs.StakingV4Step3EnableEpoch = 102 cfg.EpochConfig.EnableEpochs.MaxNodesChangeEnableEpoch[2].EpochEnable = 102 + + // TODO[Sorin]: remove this once all equivalent messages PRs are merged + cfg.EpochConfig.EnableEpochs.EquivalentMessagesEnableEpoch = integrationTests.UnreachableEpoch + cfg.EpochConfig.EnableEpochs.FixedOrderInConsensusEnableEpoch = integrationTests.UnreachableEpoch }, }) require.Nil(t, err) @@ -139,6 +144,10 @@ func TestChainSimulator_MakeNewContractFromValidatorData(t *testing.T) { cfg.EpochConfig.EnableEpochs.StakingV4Step3EnableEpoch = 102 cfg.EpochConfig.EnableEpochs.MaxNodesChangeEnableEpoch[2].EpochEnable = 102 + + // TODO[Sorin]: remove this once all equivalent messages PRs are merged + cfg.EpochConfig.EnableEpochs.EquivalentMessagesEnableEpoch = integrationTests.UnreachableEpoch + cfg.EpochConfig.EnableEpochs.FixedOrderInConsensusEnableEpoch = integrationTests.UnreachableEpoch }, }) require.Nil(t, err) diff --git a/integrationTests/consensus/consensusSigning_test.go b/integrationTests/consensus/consensusSigning_test.go index 68f85cde15c..dfa6966f1f0 100644 --- a/integrationTests/consensus/consensusSigning_test.go +++ b/integrationTests/consensus/consensusSigning_test.go @@ -8,8 +8,9 @@ import ( "testing" "time" - "github.com/multiversx/mx-chain-go/integrationTests" "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/integrationTests" ) func initNodesWithTestSigner( @@ -23,6 +24,7 @@ func initNodesWithTestSigner( fmt.Println("Step 1. 
Setup nodes...") + enableEpochsConfig := integrationTests.CreateEnableEpochsConfig() nodes := integrationTests.CreateNodesWithTestConsensusNode( int(numMetaNodes), int(numNodes), @@ -30,6 +32,7 @@ func initNodesWithTestSigner( roundTime, consensusType, 1, + enableEpochsConfig, ) for shardID, nodesList := range nodes { diff --git a/integrationTests/consensus/consensus_test.go b/integrationTests/consensus/consensus_test.go index a94c5717efe..7a480f3ecc0 100644 --- a/integrationTests/consensus/consensus_test.go +++ b/integrationTests/consensus/consensus_test.go @@ -11,13 +11,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core/pubkeyConverter" "github.com/multiversx/mx-chain-core-go/data" crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/config" consensusComp "github.com/multiversx/mx-chain-go/factory/consensus" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" ) const ( @@ -52,6 +53,7 @@ func initNodesAndTest( roundTime uint64, consensusType string, numKeysOnEachNode int, + enableEpochsConfig config.EnableEpochs, ) map[uint32][]*integrationTests.TestConsensusNode { fmt.Println("Step 1. Setup nodes...") @@ -63,6 +65,7 @@ func initNodesAndTest( roundTime, consensusType, numKeysOnEachNode, + enableEpochsConfig, ) for shardID, nodesList := range nodes { @@ -229,7 +232,19 @@ func runFullConsensusTest(t *testing.T, consensusType string, numKeysOnEachNode "consensusSize", consensusSize, ) - nodes := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType, numKeysOnEachNode) + enableEpochsConfig := integrationTests.CreateEnableEpochsConfig() + enableEpochsConfig.EquivalentMessagesEnableEpoch = integrationTests.UnreachableEpoch + enableEpochsConfig.FixedOrderInConsensusEnableEpoch = integrationTests.UnreachableEpoch + nodes := initNodesAndTest( + numMetaNodes, + numNodes, + consensusSize, + numInvalid, + roundTime, + consensusType, + numKeysOnEachNode, + enableEpochsConfig, + ) defer func() { for shardID := range nodes { @@ -292,7 +307,10 @@ func runConsensusWithNotEnoughValidators(t *testing.T, consensusType string) { consensusSize := uint32(4) numInvalid := uint32(2) roundTime := uint64(1000) - nodes := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType, 1) + enableEpochsConfig := integrationTests.CreateEnableEpochsConfig() + enableEpochsConfig.EquivalentMessagesEnableEpoch = integrationTests.UnreachableEpoch + enableEpochsConfig.FixedOrderInConsensusEnableEpoch = integrationTests.UnreachableEpoch + nodes := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType, 1, enableEpochsConfig) defer func() { for shardID := range nodes { diff --git a/integrationTests/factory/bootstrapComponents/bootstrapComponents_test.go b/integrationTests/factory/bootstrapComponents/bootstrapComponents_test.go index 6c525ff9f12..2e9cb01e72a 100644 --- a/integrationTests/factory/bootstrapComponents/bootstrapComponents_test.go +++ b/integrationTests/factory/bootstrapComponents/bootstrapComponents_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + 
"github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test BootstrapComponents -------------------- diff --git a/integrationTests/factory/consensusComponents/consensusComponents_test.go b/integrationTests/factory/consensusComponents/consensusComponents_test.go index 1e32c0c574b..d4b120a9636 100644 --- a/integrationTests/factory/consensusComponents/consensusComponents_test.go +++ b/integrationTests/factory/consensusComponents/consensusComponents_test.go @@ -6,13 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/dataRetriever" bootstrapComp "github.com/multiversx/mx-chain-go/factory/bootstrap" "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test TestConsensusComponents -------------------- diff --git a/integrationTests/factory/dataComponents/dataComponents_test.go b/integrationTests/factory/dataComponents/dataComponents_test.go index c28a41c6543..d26cf7aa01f 100644 --- a/integrationTests/factory/dataComponents/dataComponents_test.go +++ b/integrationTests/factory/dataComponents/dataComponents_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) func TestDataComponents_Create_Close_ShouldWork(t *testing.T) { @@ -36,6 +37,7 @@ func TestDataComponents_Create_Close_ShouldWork(t *testing.T) { require.Nil(t, err) managedNetworkComponents, err := nr.CreateManagedNetworkComponents(managedCoreComponents, managedStatusCoreComponents, managedCryptoComponents) require.Nil(t, err) + managedBootstrapComponents, err := nr.CreateManagedBootstrapComponents(managedStatusCoreComponents, managedCoreComponents, managedCryptoComponents, managedNetworkComponents) require.Nil(t, err) managedDataComponents, err := nr.CreateManagedDataComponents(managedStatusCoreComponents, managedCoreComponents, managedBootstrapComponents, managedCryptoComponents) diff --git a/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go b/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go index 1c541f524ff..889c4ff38f8 100644 --- a/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go +++ b/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go @@ -6,13 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/dataRetriever" bootstrapComp "github.com/multiversx/mx-chain-go/factory/bootstrap" "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test TestHeartbeatComponents -------------------- diff --git a/integrationTests/factory/processComponents/processComponents_test.go 
b/integrationTests/factory/processComponents/processComponents_test.go index 897a1289d2c..110a8869878 100644 --- a/integrationTests/factory/processComponents/processComponents_test.go +++ b/integrationTests/factory/processComponents/processComponents_test.go @@ -6,13 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/dataRetriever" bootstrapComp "github.com/multiversx/mx-chain-go/factory/bootstrap" "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test TestProcessComponents -------------------- diff --git a/integrationTests/factory/stateComponents/stateComponents_test.go b/integrationTests/factory/stateComponents/stateComponents_test.go index 3c942f54e53..ba93bdf8263 100644 --- a/integrationTests/factory/stateComponents/stateComponents_test.go +++ b/integrationTests/factory/stateComponents/stateComponents_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) func TestStateComponents_Create_Close_ShouldWork(t *testing.T) { diff --git a/integrationTests/factory/statusComponents/statusComponents_test.go b/integrationTests/factory/statusComponents/statusComponents_test.go index 85cfbd155f7..38527da6a41 100644 --- a/integrationTests/factory/statusComponents/statusComponents_test.go +++ b/integrationTests/factory/statusComponents/statusComponents_test.go @@ -6,13 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/dataRetriever" bootstrapComp "github.com/multiversx/mx-chain-go/factory/bootstrap" "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test StatusComponents -------------------- diff --git a/integrationTests/frontend/staking/staking_test.go b/integrationTests/frontend/staking/staking_test.go index 8cba29bd032..fa29ea091cd 100644 --- a/integrationTests/frontend/staking/staking_test.go +++ b/integrationTests/frontend/staking/staking_test.go @@ -8,12 +8,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/vm" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/vm" ) var log = logger.GetOrCreate("integrationtests/frontend/staking") @@ -64,11 +65,11 @@ func TestSignatureOnStaking(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ 
{ - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -109,7 +110,7 @@ func TestSignatureOnStaking(t *testing.T) { nrRoundsToPropagateMultiShard := 10 integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) diff --git a/integrationTests/miniNetwork.go b/integrationTests/miniNetwork.go index e9c64f5606d..9424a566c07 100644 --- a/integrationTests/miniNetwork.go +++ b/integrationTests/miniNetwork.go @@ -71,10 +71,10 @@ func (n *MiniNetwork) Start() { // Continue advances processing with a number of rounds func (n *MiniNetwork) Continue(t *testing.T, numRounds int) { - idxProposers := []int{0, 1} + leaders := []*TestProcessorNode{n.Nodes[0], n.Nodes[1]} for i := int64(0); i < int64(numRounds); i++ { - n.Nonce, n.Round = ProposeAndSyncOneBlock(t, n.Nodes, idxProposers, n.Round, n.Nonce) + n.Nonce, n.Round = ProposeAndSyncOneBlock(t, n.Nodes, leaders, n.Round, n.Nonce) } } diff --git a/integrationTests/mock/blockProcessorMock.go b/integrationTests/mock/blockProcessorMock.go index fb83fcfb0af..b3f42dd8e52 100644 --- a/integrationTests/mock/blockProcessorMock.go +++ b/integrationTests/mock/blockProcessorMock.go @@ -24,6 +24,7 @@ type BlockProcessorMock struct { CreateNewHeaderCalled func(round uint64, nonce uint64) (data.HeaderHandler, error) PruneStateOnRollbackCalled func(currHeader data.HeaderHandler, currHeaderHash []byte, prevHeader data.HeaderHandler, prevHeaderHash []byte) RevertStateToBlockCalled func(header data.HeaderHandler, rootHash []byte) error + DecodeBlockHeaderCalled func(dta []byte) data.HeaderHandler } // ProcessBlock mocks processing a block @@ -137,6 +138,10 @@ func (bpm *BlockProcessorMock) DecodeBlockBody(dta []byte) data.BodyHandler { // DecodeBlockHeader method decodes block header from a given byte array func (bpm *BlockProcessorMock) DecodeBlockHeader(dta []byte) data.HeaderHandler { + if bpm.DecodeBlockHeaderCalled != nil { + return bpm.DecodeBlockHeaderCalled(dta) + } + if dta == nil { return nil } diff --git a/integrationTests/mock/headerSigVerifierStub.go b/integrationTests/mock/headerSigVerifierStub.go deleted file mode 100644 index b75b5615a12..00000000000 --- a/integrationTests/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,52 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - - return nil 
-} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/integrationTests/multiShard/block/common.go b/integrationTests/multiShard/block/common.go index e4fbd7403cc..481a7cf202a 100644 --- a/integrationTests/multiShard/block/common.go +++ b/integrationTests/multiShard/block/common.go @@ -2,28 +2,7 @@ package block import ( "time" - - "github.com/multiversx/mx-chain-go/integrationTests" ) // StepDelay - var StepDelay = time.Second / 10 - -// GetBlockProposersIndexes - -func GetBlockProposersIndexes( - consensusMap map[uint32][]*integrationTests.TestProcessorNode, - nodesMap map[uint32][]*integrationTests.TestProcessorNode, -) map[uint32]int { - - indexProposer := make(map[uint32]int) - - for sh, testNodeList := range nodesMap { - for k, testNode := range testNodeList { - if consensusMap[sh][0] == testNode { - indexProposer[sh] = k - } - } - } - - return indexProposer -} diff --git a/integrationTests/multiShard/block/edgecases/edgecases_test.go b/integrationTests/multiShard/block/edgecases/edgecases_test.go index 534cea84d31..6f041ee8609 100644 --- a/integrationTests/multiShard/block/edgecases/edgecases_test.go +++ b/integrationTests/multiShard/block/edgecases/edgecases_test.go @@ -9,12 +9,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-crypto-go" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/integrationTests/multiShard/block" - "github.com/multiversx/mx-chain-go/state" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/integrationTests/multiShard/block" + "github.com/multiversx/mx-chain-go/state" ) var log = logger.GetOrCreate("integrationTests/multishard/block") @@ -23,14 +24,14 @@ var log = logger.GetOrCreate("integrationTests/multishard/block") // A validator from shard 0 receives rewards from shard 1 (where it is assigned) and creates move balance // transactions. All other shard peers can and will sync the blocks containing the move balance transactions. 
func TestExecutingTransactionsFromRewardsFundsCrossShard(t *testing.T) { - //TODO fix this test + // TODO fix this test t.Skip("TODO fix this test") if testing.Short() { t.Skip("this is not a short test") } - //it is important to have all combinations here as to test more edgecases + // it is important to have all combinations here as to test more edgecases mapAssignements := map[uint32][]uint32{ 0: {1, 0}, 1: {0, 1}, @@ -73,17 +74,14 @@ func TestExecutingTransactionsFromRewardsFundsCrossShard(t *testing.T) { firstNode := nodesMap[senderShardID][0] numBlocksProduced := uint64(13) - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode for i := uint64(0); i < numBlocksProduced; i++ { printAccount(firstNode) for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - - indexesProposers := block.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposalData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposalData, nodesMap, round) time.Sleep(block.StepDelay) round++ @@ -132,7 +130,7 @@ func TestMetaShouldBeAbleToProduceBlockInAVeryHighRoundAndStartOfEpoch(t *testin } } - //edge case on the epoch change + // edge case on the epoch change round := roundsPerEpoch*10 - 1 nonce := uint64(1) round = integrationTests.IncrementAndPrintRound(round) @@ -141,9 +139,8 @@ func TestMetaShouldBeAbleToProduceBlockInAVeryHighRoundAndStartOfEpoch(t *testin integrationTests.UpdateRound(nodes, round) } - _, _, consensusNodes := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := block.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, nonce) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, nonce) for _, nodes := range nodesMap { for _, node := range nodes { @@ -163,14 +160,14 @@ func closeNodes(nodesMap map[uint32][]*integrationTests.TestProcessorNode) { } } -//nolint +// nolint func checkSameBlockHeight(t *testing.T, nodesMap map[uint32][]*integrationTests.TestProcessorNode) { for _, nodes := range nodesMap { referenceBlock := nodes[0].BlockChain.GetCurrentBlockHeader() for _, n := range nodes { crtBlock := n.BlockChain.GetCurrentBlockHeader() - //(crtBlock == nil) != (blkc == nil) actually does a XOR operation between the 2 conditions - //as if the reference is nil, the same must be all other nodes. Same if the reference is not nil. + // (crtBlock == nil) != (blkc == nil) actually does a XOR operation between the 2 conditions + // as if the reference is nil, the same must be all other nodes. Same if the reference is not nil. require.False(t, (referenceBlock == nil) != (crtBlock == nil)) if !check.IfNil(referenceBlock) { require.Equal(t, referenceBlock.GetNonce(), crtBlock.GetNonce()) @@ -179,7 +176,7 @@ func checkSameBlockHeight(t *testing.T, nodesMap map[uint32][]*integrationTests. 
} } -//nolint +// nolint func printAccount(node *integrationTests.TestProcessorNode) { accnt, _ := node.AccntState.GetExistingAccount(node.OwnAccount.Address) if check.IfNil(accnt) { diff --git a/integrationTests/multiShard/block/executingMiniblocks/executingMiniblocks_test.go b/integrationTests/multiShard/block/executingMiniblocks/executingMiniblocks_test.go index eec61878296..fcf5ec9178c 100644 --- a/integrationTests/multiShard/block/executingMiniblocks/executingMiniblocks_test.go +++ b/integrationTests/multiShard/block/executingMiniblocks/executingMiniblocks_test.go @@ -33,7 +33,6 @@ func TestShouldProcessBlocksInMultiShardArchitecture(t *testing.T) { nodesPerShard := 3 numMetachainNodes := 1 - idxProposers := []int{0, 3, 6, 9, 12, 15, 18} senderShard := uint32(0) recvShards := []uint32{1, 2} round := uint64(0) @@ -47,6 +46,7 @@ func TestShouldProcessBlocksInMultiShardArchitecture(t *testing.T) { nodesPerShard, numMetachainNodes, ) + leaders := []*integrationTests.TestProcessorNode{nodes[0], nodes[3], nodes[6], nodes[9], nodes[12], nodes[15], nodes[18]} integrationTests.DisplayAndStartNodes(nodes) defer func() { @@ -97,7 +97,7 @@ func TestShouldProcessBlocksInMultiShardArchitecture(t *testing.T) { nonce++ roundsToWait := 6 for i := 0; i < roundsToWait; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } gasPricePerTxBigInt := big.NewInt(0).SetUint64(integrationTests.MinTxGasPrice) @@ -163,11 +163,11 @@ func TestSimpleTransactionsWithMoreGasWhichYieldInReceiptsInMultiShardedEnvironm node.EconomicsData.SetMinGasLimit(minGasLimit, 0) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -192,8 +192,8 @@ func TestSimpleTransactionsWithMoreGasWhichYieldInReceiptsInMultiShardedEnvironm nrRoundsToTest := 10 for i := 0; i <= nrRoundsToTest; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -253,11 +253,11 @@ func TestSimpleTransactionsWithMoreValueThanBalanceYieldReceiptsInMultiShardedEn node.EconomicsData.SetMinGasLimit(minGasLimit, 0) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -294,8 +294,8 @@ func TestSimpleTransactionsWithMoreValueThanBalanceYieldReceiptsInMultiShardedEn time.Sleep(2 * time.Second) integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, 
leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -320,8 +320,8 @@ func TestSimpleTransactionsWithMoreValueThanBalanceYieldReceiptsInMultiShardedEn numRoundsToTest := 6 for i := 0; i < numRoundsToTest; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -420,22 +420,22 @@ func TestShouldSubtractTheCorrectTxFee(t *testing.T) { gasPrice, ) - _, _, consensusNodes := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) shardId0 := uint32(0) _ = integrationTests.IncrementAndPrintRound(round) // test sender account decreased its balance with gasPrice * gasLimit - accnt, err := consensusNodes[shardId0][0].AccntState.GetExistingAccount(ownerPk) + accnt, err := proposeData[shardId0].Leader.AccntState.GetExistingAccount(ownerPk) assert.Nil(t, err) ownerAccnt := accnt.(state.UserAccountHandler) expectedBalance := big.NewInt(0).Set(initialVal) tx := &transaction.Transaction{GasPrice: gasPrice, GasLimit: gasLimit, Data: []byte(txData)} - txCost := consensusNodes[shardId0][0].EconomicsData.ComputeTxFee(tx) + txCost := proposeData[shardId0].Leader.EconomicsData.ComputeTxFee(tx) expectedBalance.Sub(expectedBalance, txCost) assert.Equal(t, expectedBalance, ownerAccnt.GetBalance()) - printContainingTxs(consensusNodes[shardId0][0], consensusNodes[shardId0][0].BlockChain.GetCurrentBlockHeader().(*block.Header)) + printContainingTxs(proposeData[shardId0].Leader, proposeData[shardId0].Leader.BlockChain.GetCurrentBlockHeader().(*block.Header)) } func printContainingTxs(tpn *integrationTests.TestProcessorNode, hdr data.HeaderHandler) { diff --git a/integrationTests/multiShard/block/executingRewardMiniblocks/executingRewardMiniblocks_test.go b/integrationTests/multiShard/block/executingRewardMiniblocks/executingRewardMiniblocks_test.go index 645448af81a..787efdcab90 100644 --- a/integrationTests/multiShard/block/executingRewardMiniblocks/executingRewardMiniblocks_test.go +++ b/integrationTests/multiShard/block/executingRewardMiniblocks/executingRewardMiniblocks_test.go @@ -10,11 +10,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" - testBlock "github.com/multiversx/mx-chain-go/integrationTests/multiShard/block" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" - "github.com/stretchr/testify/assert" ) func getLeaderPercentage(node *integrationTests.TestProcessorNode) float64 { @@ -66,8 +66,6 @@ func TestExecuteBlocksWithTransactionsAndCheckRewards(t *testing.T) { nonce := uint64(1) nbBlocksProduced := 7 - var headers map[uint32]data.HeaderHandler - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode mapRewardsForShardAddresses := make(map[string]uint32) mapRewardsForMetachainAddresses := make(map[string]uint32) nbTxsForLeaderAddress := make(map[string]uint32) @@ -76,21 +74,18 @@ func TestExecuteBlocksWithTransactionsAndCheckRewards(t *testing.T) { for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - _, headers, 
consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - for shardId := range consensusNodes { + for shardId := range proposeData { addrRewards := make([]string, 0) updateExpectedRewards(mapRewardsForShardAddresses, addrRewards) - nbTxs := getTransactionsFromHeaderInShard(t, headers, shardId) + nbTxs := getTransactionsFromHeaderInShard(t, proposeData[shardId].Header, shardId) if len(addrRewards) > 0 { updateNumberTransactionsProposed(t, nbTxsForLeaderAddress, addrRewards[0], nbTxs) } } - updateRewardsForMetachain(mapRewardsForMetachainAddresses, consensusNodes[0][0]) - - indexesProposers := testBlock.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) time.Sleep(integrationTests.StepDelay) @@ -149,18 +144,16 @@ func TestExecuteBlocksWithTransactionsWhichReachedGasLimitAndCheckRewards(t *tes nonce := uint64(1) nbBlocksProduced := 2 - var headers map[uint32]data.HeaderHandler - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode mapRewardsForShardAddresses := make(map[string]uint32) nbTxsForLeaderAddress := make(map[string]uint32) for i := 0; i < nbBlocksProduced; i++ { - _, headers, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - for shardId := range consensusNodes { + for shardId := range nodesMap { addrRewards := make([]string, 0) updateExpectedRewards(mapRewardsForShardAddresses, addrRewards) - nbTxs := getTransactionsFromHeaderInShard(t, headers, shardId) + nbTxs := getTransactionsFromHeaderInShard(t, proposeData[shardId].Header, shardId) if len(addrRewards) > 0 { updateNumberTransactionsProposed(t, nbTxsForLeaderAddress, addrRewards[0], nbTxs) } @@ -169,8 +162,7 @@ func TestExecuteBlocksWithTransactionsWhichReachedGasLimitAndCheckRewards(t *tes for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - indexesProposers := testBlock.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ } @@ -213,15 +205,14 @@ func TestExecuteBlocksWithoutTransactionsAndCheckRewards(t *testing.T) { nonce := uint64(1) nbBlocksProduced := 7 - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode mapRewardsForShardAddresses := make(map[string]uint32) mapRewardsForMetachainAddresses := make(map[string]uint32) nbTxsForLeaderAddress := make(map[string]uint32) for i := 0; i < nbBlocksProduced; i++ { - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - for shardId := range consensusNodes { + for shardId := range nodesMap { if shardId == core.MetachainShardId { continue } @@ -231,13 +222,10 @@ func TestExecuteBlocksWithoutTransactionsAndCheckRewards(t *testing.T) { updateExpectedRewards(mapRewardsForShardAddresses, addrRewards) } - updateRewardsForMetachain(mapRewardsForMetachainAddresses, consensusNodes[0][0]) - for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - indexesProposers := testBlock.GetBlockProposersIndexes(consensusNodes, nodesMap) - 
integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ } @@ -248,16 +236,11 @@ func TestExecuteBlocksWithoutTransactionsAndCheckRewards(t *testing.T) { verifyRewardsForMetachain(t, mapRewardsForMetachainAddresses, nodesMap) } -func getTransactionsFromHeaderInShard(t *testing.T, headers map[uint32]data.HeaderHandler, shardId uint32) uint32 { +func getTransactionsFromHeaderInShard(t *testing.T, header data.HeaderHandler, shardId uint32) uint32 { if shardId == core.MetachainShardId { return 0 } - header, ok := headers[shardId] - if !ok { - return 0 - } - hdr, ok := header.(*block.Header) if !ok { assert.Error(t, process.ErrWrongTypeAssertion) diff --git a/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go b/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go index 82eca349947..099864c1dc8 100644 --- a/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go +++ b/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go @@ -11,8 +11,9 @@ import ( "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-crypto-go/signing" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" - "github.com/multiversx/mx-chain-go/integrationTests" "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/integrationTests" ) const broadcastDelay = 2 * time.Second @@ -57,12 +58,12 @@ func TestInterceptedShardBlockHeaderVerifiedWithCorrectConsensusGroup(t *testing nonce := uint64(1) var err error - body, header, _, _ := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, randomness, 0) - header, err = fillHeaderFields(nodesMap[0][0], header, singleSigner) + proposeBlockData := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, randomness, 0) + header, err := fillHeaderFields(proposeBlockData.Leader, proposeBlockData.Header, singleSigner) assert.Nil(t, err) pk := nodesMap[0][0].NodeKeys.MainKey.Pk - nodesMap[0][0].BroadcastBlock(body, header, pk) + nodesMap[0][0].BroadcastBlock(proposeBlockData.Body, header, pk) time.Sleep(broadcastDelay) @@ -122,7 +123,7 @@ func TestInterceptedMetaBlockVerifiedWithCorrectConsensusGroup(t *testing.T) { round := uint64(1) nonce := uint64(1) - body, header, _, _ := integrationTests.ProposeBlockWithConsensusSignature( + proposeBlockData := integrationTests.ProposeBlockWithConsensusSignature( core.MetachainShardId, nodesMap, round, @@ -132,13 +133,13 @@ func TestInterceptedMetaBlockVerifiedWithCorrectConsensusGroup(t *testing.T) { ) pk := nodesMap[core.MetachainShardId][0].NodeKeys.MainKey.Pk - nodesMap[core.MetachainShardId][0].BroadcastBlock(body, header, pk) + nodesMap[core.MetachainShardId][0].BroadcastBlock(proposeBlockData.Body, proposeBlockData.Header, pk) time.Sleep(broadcastDelay) - headerBytes, _ := integrationTests.TestMarshalizer.Marshal(header) + headerBytes, _ := integrationTests.TestMarshalizer.Marshal(proposeBlockData.Header) headerHash := integrationTests.TestHasher.Compute(string(headerBytes)) - hmb := header.(*block.MetaBlock) + hmb := proposeBlockData.Header.(*block.MetaBlock) // all nodes in metachain do not have the block in pool as interceptor does not validate it with a wrong consensus for _, metaNode := range 
nodesMap[core.MetachainShardId] { @@ -197,16 +198,16 @@ func TestInterceptedShardBlockHeaderWithLeaderSignatureAndRandSeedChecks(t *test round := uint64(1) nonce := uint64(1) - body, header, _, consensusNodes := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, randomness, 0) - nodeToSendFrom := consensusNodes[0] - err := header.SetPrevRandSeed(randomness) + proposeBlockData := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, randomness, 0) + nodeToSendFrom := proposeBlockData.Leader + err := proposeBlockData.Header.SetPrevRandSeed(randomness) assert.Nil(t, err) - header, err = fillHeaderFields(nodeToSendFrom, header, singleSigner) + header, err := fillHeaderFields(nodeToSendFrom, proposeBlockData.Header, singleSigner) assert.Nil(t, err) pk := nodeToSendFrom.NodeKeys.MainKey.Pk - nodeToSendFrom.BroadcastBlock(body, header, pk) + nodeToSendFrom.BroadcastBlock(proposeBlockData.Body, header, pk) time.Sleep(broadcastDelay) @@ -268,14 +269,14 @@ func TestInterceptedShardHeaderBlockWithWrongPreviousRandSeedShouldNotBeAccepted wrongRandomness := []byte("wrong randomness") round := uint64(2) nonce := uint64(2) - body, header, _, _ := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, wrongRandomness, 0) + proposeBlockData := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, wrongRandomness, 0) pk := nodesMap[0][0].NodeKeys.MainKey.Pk - nodesMap[0][0].BroadcastBlock(body, header, pk) + nodesMap[0][0].BroadcastBlock(proposeBlockData.Body, proposeBlockData.Header, pk) time.Sleep(broadcastDelay) - headerBytes, _ := integrationTests.TestMarshalizer.Marshal(header) + headerBytes, _ := integrationTests.TestMarshalizer.Marshal(proposeBlockData.Header) headerHash := integrationTests.TestHasher.Compute(string(headerBytes)) // all nodes in metachain have the block header in pool as interceptor validates it @@ -294,8 +295,11 @@ func TestInterceptedShardHeaderBlockWithWrongPreviousRandSeedShouldNotBeAccepted func fillHeaderFields(proposer *integrationTests.TestProcessorNode, hdr data.HeaderHandler, signer crypto.SingleSigner) (data.HeaderHandler, error) { leaderSk := proposer.NodeKeys.MainKey.Sk - randSeed, _ := signer.Sign(leaderSk, hdr.GetPrevRandSeed()) - err := hdr.SetRandSeed(randSeed) + randSeed, err := signer.Sign(leaderSk, hdr.GetPrevRandSeed()) + if err != nil { + return nil, err + } + err = hdr.SetRandSeed(randSeed) if err != nil { return nil, err } diff --git a/integrationTests/multiShard/endOfEpoch/common.go b/integrationTests/multiShard/endOfEpoch/common.go index c416479849d..4d3a6673703 100644 --- a/integrationTests/multiShard/endOfEpoch/common.go +++ b/integrationTests/multiShard/endOfEpoch/common.go @@ -6,9 +6,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/stretchr/testify/assert" ) // CreateAndPropagateBlocks - @@ -18,12 +19,12 @@ func CreateAndPropagateBlocks( currentRound uint64, currentNonce uint64, nodes []*integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, ) (uint64, uint64) { for i := uint64(0); i <= nbRounds; i++ { integrationTests.UpdateRound(nodes, currentRound) - integrationTests.ProposeBlock(nodes, idxProposers, currentRound, currentNonce) - integrationTests.SyncBlock(t, nodes, idxProposers, 
currentRound) + integrationTests.ProposeBlock(nodes, leaders, currentRound, currentNonce) + integrationTests.SyncBlock(t, nodes, leaders, currentRound) currentRound = integrationTests.IncrementAndPrintRound(currentRound) currentNonce++ } diff --git a/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShuffling/epochChangeWithNodesShuffling_test.go b/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShuffling/epochChangeWithNodesShuffling_test.go index a2b5846a759..a3d08fbd755 100644 --- a/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShuffling/epochChangeWithNodesShuffling_test.go +++ b/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShuffling/epochChangeWithNodesShuffling_test.go @@ -58,16 +58,14 @@ func TestEpochChangeWithNodesShuffling(t *testing.T) { nonce := uint64(1) nbBlocksToProduce := uint64(20) expectedLastEpoch := uint32(nbBlocksToProduce / roundsPerEpoch) - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode for i := uint64(0); i < nbBlocksToProduce; i++ { for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := endOfEpoch.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ diff --git a/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShufflingAndRater/epochChangeWithNodesShufflingAndRater_test.go b/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShufflingAndRater/epochChangeWithNodesShufflingAndRater_test.go index 9c81ff6e97e..59c0abc5156 100644 --- a/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShufflingAndRater/epochChangeWithNodesShufflingAndRater_test.go +++ b/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShufflingAndRater/epochChangeWithNodesShufflingAndRater_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/multiShard/endOfEpoch" "github.com/multiversx/mx-chain-go/process/rating" - logger "github.com/multiversx/mx-chain-logger-go" ) func TestEpochChangeWithNodesShufflingAndRater(t *testing.T) { @@ -68,16 +69,14 @@ func TestEpochChangeWithNodesShufflingAndRater(t *testing.T) { nonce := uint64(1) nbBlocksToProduce := uint64(20) expectedLastEpoch := uint32(nbBlocksToProduce / roundsPerEpoch) - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode for i := uint64(0); i < nbBlocksToProduce; i++ { for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := endOfEpoch.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ diff --git a/integrationTests/multiShard/endOfEpoch/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment_test.go 
b/integrationTests/multiShard/endOfEpoch/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment_test.go index dd964aeb745..92af5c08c28 100644 --- a/integrationTests/multiShard/endOfEpoch/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment_test.go +++ b/integrationTests/multiShard/endOfEpoch/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment_test.go @@ -26,6 +26,8 @@ func TestEpochStartChangeWithContinuousTransactionsInMultiShardedEnvironment(t * StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -40,11 +42,11 @@ func TestEpochStartChangeWithContinuousTransactionsInMultiShardedEnvironment(t * node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -71,8 +73,8 @@ func TestEpochStartChangeWithContinuousTransactionsInMultiShardedEnvironment(t * nrRoundsToPropagateMultiShard := uint64(5) for i := uint64(0); i <= (uint64(epoch)*roundsPerEpoch)+nrRoundsToPropagateMultiShard; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ diff --git a/integrationTests/multiShard/endOfEpoch/epochStartChangeWithoutTransactionInMultiShardedEnvironment/epochStartChangeWithoutTransactionInMultiShardedEnvironment_test.go b/integrationTests/multiShard/endOfEpoch/epochStartChangeWithoutTransactionInMultiShardedEnvironment/epochStartChangeWithoutTransactionInMultiShardedEnvironment_test.go index d14eb086de6..e8f6607112f 100644 --- a/integrationTests/multiShard/endOfEpoch/epochStartChangeWithoutTransactionInMultiShardedEnvironment/epochStartChangeWithoutTransactionInMultiShardedEnvironment_test.go +++ b/integrationTests/multiShard/endOfEpoch/epochStartChangeWithoutTransactionInMultiShardedEnvironment/epochStartChangeWithoutTransactionInMultiShardedEnvironment_test.go @@ -25,6 +25,8 @@ func TestEpochStartChangeWithoutTransactionInMultiShardedEnvironment(t *testing. StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -39,11 +41,11 @@ func TestEpochStartChangeWithoutTransactionInMultiShardedEnvironment(t *testing. 
node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -61,10 +63,10 @@ func TestEpochStartChangeWithoutTransactionInMultiShardedEnvironment(t *testing. time.Sleep(time.Second) // ----- wait for epoch end period - round, nonce = endOfEpoch.CreateAndPropagateBlocks(t, roundsPerEpoch, round, nonce, nodes, idxProposers) + round, nonce = endOfEpoch.CreateAndPropagateBlocks(t, roundsPerEpoch, round, nonce, nodes, leaders) nrRoundsToPropagateMultiShard := uint64(5) - _, _ = endOfEpoch.CreateAndPropagateBlocks(t, nrRoundsToPropagateMultiShard, round, nonce, nodes, idxProposers) + _, _ = endOfEpoch.CreateAndPropagateBlocks(t, nrRoundsToPropagateMultiShard, round, nonce, nodes, leaders) epoch := uint32(1) endOfEpoch.VerifyThatNodesHaveCorrectEpoch(t, epoch, nodes) diff --git a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go index 2ee087799e5..27c963a9747 100644 --- a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go +++ b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go @@ -11,6 +11,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" @@ -24,6 +26,7 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" "github.com/multiversx/mx-chain-go/process/block/pendingMb" + interceptorsFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" "github.com/multiversx/mx-chain-go/process/smartContract" "github.com/multiversx/mx-chain-go/process/sync/storageBootstrap" "github.com/multiversx/mx-chain-go/sharding" @@ -32,6 +35,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" @@ -41,7 +45,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/scheduledDataSyncer" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func TestStartInEpochForAShardNodeInMultiShardedEnvironment(t *testing.T) { @@ -73,6 +76,8 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: 
integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -87,11 +92,11 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * numNodesPerShard + leaders[i] = nodes[i*numNodesPerShard] } - idxProposers[numOfShards] = numOfShards * numNodesPerShard + leaders[numOfShards] = nodes[numOfShards*numNodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -118,8 +123,8 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui nrRoundsToPropagateMultiShard := uint64(5) for i := uint64(0); i <= (uint64(epoch)*roundsPerEpoch)+nrRoundsToPropagateMultiShard; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -241,6 +246,10 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui &marshallerMock.MarshalizerMock{}, 444, ) + interceptorDataVerifierArgs := interceptorsFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second * 5, + CacheExpiry: time.Second * 10, + } argsBootstrapHandler := bootstrap.ArgsEpochStartBootstrap{ NodesCoordinatorRegistryFactory: nodesCoordinatorRegistryFactory, CryptoComponentsHolder: cryptoComponents, @@ -279,8 +288,10 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui FlagsConfig: config.ContextFlagsConfig{ ForceStartFromNetwork: false, }, - TrieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, - StateStatsHandler: disabled.NewStateStatistics(), + TrieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, + StateStatsHandler: disabled.NewStateStatistics(), + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + InterceptedDataVerifierFactory: interceptorsFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierArgs), } epochStartBootstrap, err := bootstrap.NewEpochStartBootstrap(argsBootstrapHandler) diff --git a/integrationTests/multiShard/hardFork/hardFork_test.go b/integrationTests/multiShard/hardFork/hardFork_test.go index 61dbada5251..7da61a4dcc3 100644 --- a/integrationTests/multiShard/hardFork/hardFork_test.go +++ b/integrationTests/multiShard/hardFork/hardFork_test.go @@ -12,6 +12,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + logger "github.com/multiversx/mx-chain-logger-go" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -20,6 +25,7 @@ import ( "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/integrationTests/vm/wasm" vmFactory "github.com/multiversx/mx-chain-go/process/factory" + interceptorFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" 
"github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" @@ -31,10 +37,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/update/factory" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - logger "github.com/multiversx/mx-chain-logger-go" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var log = logger.GetOrCreate("integrationTests/hardfork") @@ -64,11 +66,11 @@ func TestHardForkWithoutTransactionInMultiShardedEnvironment(t *testing.T) { node.WaitTime = 100 * time.Millisecond } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -89,11 +91,11 @@ func TestHardForkWithoutTransactionInMultiShardedEnvironment(t *testing.T) { nrRoundsToPropagateMultiShard := 5 // ----- wait for epoch end period - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, int(roundsPerEpoch), nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, int(roundsPerEpoch), nonce, round) time.Sleep(time.Second) - nonce, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) @@ -135,11 +137,11 @@ func TestHardForkWithContinuousTransactionsInMultiShardedEnvironment(t *testing. node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -189,7 +191,7 @@ func TestHardForkWithContinuousTransactionsInMultiShardedEnvironment(t *testing. 
epoch := uint32(2) nrRoundsToPropagateMultiShard := uint64(6) for i := uint64(0); i <= (uint64(epoch)*roundsPerEpoch)+nrRoundsToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) for _, node := range nodes { integrationTests.CreateAndSendTransaction(node, nodes, sendValue, receiverAddress1, "", integrationTests.AdditionalGasLimit) @@ -253,11 +255,11 @@ func TestHardForkEarlyEndOfEpochWithContinuousTransactionsInMultiShardedEnvironm node.EpochStartTrigger.SetMinRoundsBetweenEpochs(minRoundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = allNodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = allNodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(allNodes) @@ -310,7 +312,7 @@ func TestHardForkEarlyEndOfEpochWithContinuousTransactionsInMultiShardedEnvironm log.LogIfError(err) } - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, consensusNodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, consensusNodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(consensusNodes) for _, node := range consensusNodes { integrationTests.CreateAndSendTransaction(node, allNodes, sendValue, receiverAddress1, "", integrationTests.AdditionalGasLimit) @@ -600,6 +602,11 @@ func createHardForkExporter( networkComponents.PeersRatingHandlerField = node.PeersRatingHandler networkComponents.InputAntiFlood = &mock.NilAntifloodHandler{} networkComponents.OutputAntiFlood = &mock.NilAntifloodHandler{} + + interceptorDataVerifierFactoryArgs := interceptorFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second * 5, + CacheExpiry: time.Second * 10, + } argsExportHandler := factory.ArgsExporter{ CoreComponents: coreComponents, CryptoComponents: cryptoComponents, @@ -649,11 +656,12 @@ func createHardForkExporter( NumResolveFailureThreshold: 3, DebugLineExpiration: 3, }, - MaxHardCapForMissingNodes: 500, - NumConcurrentTrieSyncers: 50, - TrieSyncerVersion: 2, - CheckNodesOnDisk: false, - NodeOperationMode: node.NodeOperationMode, + MaxHardCapForMissingNodes: 500, + NumConcurrentTrieSyncers: 50, + TrieSyncerVersion: 2, + CheckNodesOnDisk: false, + NodeOperationMode: node.NodeOperationMode, + InterceptedDataVerifierFactory: interceptorFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierFactoryArgs), } exportHandler, err := factory.NewExportHandlerFactory(argsExportHandler) diff --git a/integrationTests/multiShard/relayedTx/common.go b/integrationTests/multiShard/relayedTx/common.go index 33a5cedcc53..dec175abb73 100644 --- a/integrationTests/multiShard/relayedTx/common.go +++ b/integrationTests/multiShard/relayedTx/common.go @@ -8,13 +8,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" ) // CreateGeneralSetupForRelayTxTest will create the general setup for relayed transactions -func CreateGeneralSetupForRelayTxTest() 
([]*integrationTests.TestProcessorNode, []int, []*integrationTests.TestWalletAccount, *integrationTests.TestWalletAccount) { +func CreateGeneralSetupForRelayTxTest() ([]*integrationTests.TestProcessorNode, []*integrationTests.TestProcessorNode, []*integrationTests.TestWalletAccount, *integrationTests.TestWalletAccount) { numOfShards := 2 nodesPerShard := 2 numMetachainNodes := 1 @@ -25,11 +26,11 @@ func CreateGeneralSetupForRelayTxTest() ([]*integrationTests.TestProcessorNode, numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -47,7 +48,7 @@ func CreateGeneralSetupForRelayTxTest() ([]*integrationTests.TestProcessorNode, relayerAccount := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, 0) integrationTests.MintAllPlayers(nodes, []*integrationTests.TestWalletAccount{relayerAccount}, initialVal) - return nodes, idxProposers, players, relayerAccount + return nodes, leaders, players, relayerAccount } // CreateAndSendRelayedAndUserTx will create and send a relayed user transaction diff --git a/integrationTests/multiShard/relayedTx/edgecases/edgecases_test.go b/integrationTests/multiShard/relayedTx/edgecases/edgecases_test.go index 246a81fbe15..405c83d41c4 100644 --- a/integrationTests/multiShard/relayedTx/edgecases/edgecases_test.go +++ b/integrationTests/multiShard/relayedTx/edgecases/edgecases_test.go @@ -6,9 +6,10 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/multiShard/relayedTx" - "github.com/stretchr/testify/assert" ) func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWrongNonceShouldNotIncrementUserAccNonce(t *testing.T) { @@ -16,7 +17,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWrongNonceShoul t.Skip("this is not a short test") } - nodes, idxProposers, players, relayer := relayedTx.CreateGeneralSetupForRelayTxTest() + nodes, leaders, players, relayer := relayedTx.CreateGeneralSetupForRelayTxTest() defer func() { for _, n := range nodes { n.Close() @@ -46,7 +47,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWrongNonceShoul totalFees.Add(totalFees, totalFee) } - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) @@ -54,7 +55,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWrongNonceShoul roundToPropagateMultiShard := int64(20) for i := int64(0); i <= roundToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) } diff --git a/integrationTests/multiShard/relayedTx/relayedTxV2_test.go b/integrationTests/multiShard/relayedTx/relayedTxV2_test.go index 2795646c359..511bb80f638 100644 --- a/integrationTests/multiShard/relayedTx/relayedTxV2_test.go 
+++ b/integrationTests/multiShard/relayedTx/relayedTxV2_test.go @@ -7,10 +7,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/wasm" vmFactory "github.com/multiversx/mx-chain-go/process/factory" - "github.com/stretchr/testify/assert" ) func TestRelayedTransactionV2InMultiShardEnvironmentWithSmartContractTX(t *testing.T) { @@ -18,7 +19,7 @@ func TestRelayedTransactionV2InMultiShardEnvironmentWithSmartContractTX(t *testi t.Skip("this is not a short test") } - nodes, idxProposers, players, relayer := CreateGeneralSetupForRelayTxTest() + nodes, leaders, players, relayer := CreateGeneralSetupForRelayTxTest() defer func() { for _, n := range nodes { n.Close() @@ -69,13 +70,13 @@ func TestRelayedTransactionV2InMultiShardEnvironmentWithSmartContractTX(t *testi roundToPropagateMultiShard := int64(20) for i := int64(0); i <= roundToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) } nrRoundsToTest := int64(5) for i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) for _, player := range players { @@ -89,7 +90,7 @@ func TestRelayedTransactionV2InMultiShardEnvironmentWithSmartContractTX(t *testi } for i := int64(0); i <= roundToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) } diff --git a/integrationTests/multiShard/relayedTx/relayedTx_test.go b/integrationTests/multiShard/relayedTx/relayedTx_test.go index 43f713d5d09..412ae4b1dd9 100644 --- a/integrationTests/multiShard/relayedTx/relayedTx_test.go +++ b/integrationTests/multiShard/relayedTx/relayedTx_test.go @@ -9,6 +9,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/esdt" "github.com/multiversx/mx-chain-core-go/data/transaction" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/wasm" "github.com/multiversx/mx-chain-go/process" @@ -16,9 +20,6 @@ import ( "github.com/multiversx/mx-chain-go/process/smartContract/hooks" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestRelayedTransactionInMultiShardEnvironmentWithNormalTx(t *testing.T) { @@ -377,7 +378,7 @@ func checkSCBalance(t *testing.T, node *integrationTests.TestProcessorNode, scAd }) assert.Nil(t, err) actualBalance := big.NewInt(0).SetBytes(vmOutput.ReturnData[0]) - assert.Equal(t, 0, actualBalance.Cmp(balance)) + assert.Equal(t, balance, actualBalance) } func checkPlayerBalances( diff --git 
a/integrationTests/multiShard/smartContract/dns/dns_test.go b/integrationTests/multiShard/smartContract/dns/dns_test.go index 20135a2bda4..98dc1a1d674 100644 --- a/integrationTests/multiShard/smartContract/dns/dns_test.go +++ b/integrationTests/multiShard/smartContract/dns/dns_test.go @@ -12,14 +12,15 @@ import ( "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-core-go/hashing/keccak" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/genesis" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/multiShard/relayedTx" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestSCCallingDNSUserNames(t *testing.T) { @@ -27,7 +28,7 @@ func TestSCCallingDNSUserNames(t *testing.T) { t.Skip("this is not a short test") } - nodes, players, idxProposers := prepareNodesAndPlayers() + nodes, players, leaders := prepareNodesAndPlayers() defer func() { for _, n := range nodes { n.Close() @@ -45,7 +46,7 @@ func TestSCCallingDNSUserNames(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 25 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) checkUserNamesAreSetCorrectly(t, players, nodes, userNames, sortedDNSAddresses) } @@ -55,7 +56,7 @@ func TestSCCallingDNSUserNamesTwice(t *testing.T) { t.Skip("this is not a short test") } - nodes, players, idxProposers := prepareNodesAndPlayers() + nodes, players, leaders := prepareNodesAndPlayers() defer func() { for _, n := range nodes { n.Close() @@ -73,12 +74,12 @@ func TestSCCallingDNSUserNamesTwice(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 15 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) newUserNames := sendRegisterUserNameTxForPlayers(players, nodes, sortedDNSAddresses, dnsRegisterValue) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) checkUserNamesAreSetCorrectly(t, players, nodes, userNames, sortedDNSAddresses) checkUserNamesAreDeleted(t, nodes, newUserNames, sortedDNSAddresses) @@ -89,7 +90,7 @@ func TestDNSandRelayedTxNormal(t *testing.T) { t.Skip("this is not a short test") } - nodes, players, idxProposers := prepareNodesAndPlayers() + nodes, players, leaders := prepareNodesAndPlayers() defer func() { for _, n := range nodes { n.Close() @@ -108,7 +109,7 @@ func TestDNSandRelayedTxNormal(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 30 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) checkUserNamesAreSetCorrectly(t, players, nodes, userNames, 
sortedDNSAddresses) } @@ -122,7 +123,7 @@ func createAndMintRelayer(nodes []*integrationTests.TestProcessorNode) *integrat return relayer } -func prepareNodesAndPlayers() ([]*integrationTests.TestProcessorNode, []*integrationTests.TestWalletAccount, []int) { +func prepareNodesAndPlayers() ([]*integrationTests.TestProcessorNode, []*integrationTests.TestWalletAccount, []*integrationTests.TestProcessorNode) { numOfShards := 2 nodesPerShard := 1 numMetachainNodes := 1 @@ -143,11 +144,11 @@ func prepareNodesAndPlayers() ([]*integrationTests.TestProcessorNode, []*integra node.EconomicsData.SetMaxGasLimitPerBlock(1500000000, 0) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -163,7 +164,7 @@ func prepareNodesAndPlayers() ([]*integrationTests.TestProcessorNode, []*integra integrationTests.MintAllNodes(nodes, initialVal) integrationTests.MintAllPlayers(nodes, players, initialVal) - return nodes, players, idxProposers + return nodes, players, leaders } func getDNSContractsData(node *integrationTests.TestProcessorNode) (*big.Int, []string) { diff --git a/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go b/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go index b74acc3b392..0f9d559cf3b 100644 --- a/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go +++ b/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go @@ -7,6 +7,9 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process" @@ -14,8 +17,6 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestBridgeSetupAndBurn(t *testing.T) { @@ -31,6 +32,8 @@ func TestBridgeSetupAndBurn(t *testing.T) { GlobalMintBurnDisableEpoch: integrationTests.UnreachableEpoch, SCProcessorV2EnableEpoch: integrationTests.UnreachableEpoch, FixAsyncCallBackArgsListEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } arwenVersion := config.WasmVMVersionByEpoch{Version: "v1.4"} vmConfig := &config.VirtualMachineConfig{ @@ -48,11 +51,11 @@ func TestBridgeSetupAndBurn(t *testing.T) { ownerNode := nodes[0] shard := nodes[0:nodesPerShard] - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -73,7 +76,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { nonce++ tokenManagerPath := "../testdata/polynetworkbridge/esdt_token_manager.wasm" - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, 
idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) blockChainHook := ownerNode.BlockchainHook scAddressBytes, _ := blockChainHook.NewAddress( @@ -100,7 +103,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { deploymentData, 100000, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) txValue := big.NewInt(1000) txData := "performWrappedEgldIssue@05" @@ -112,7 +115,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { txData, integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 8, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 8, nonce, round) scQuery := &process.SCQuery{ CallerAddr: ownerNode.OwnAccount.Address, @@ -140,7 +143,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { integrationTests.AdditionalGasLimit, ) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round) checkBurnedOnESDTContract(t, nodes, tokenIdentifier, valueToBurn) } diff --git a/integrationTests/multiShard/smartContract/scCallingSC_test.go b/integrationTests/multiShard/smartContract/scCallingSC_test.go index 52b24371d15..74307489b9c 100644 --- a/integrationTests/multiShard/smartContract/scCallingSC_test.go +++ b/integrationTests/multiShard/smartContract/scCallingSC_test.go @@ -16,16 +16,17 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/transaction" vmData "github.com/multiversx/mx-chain-core-go/data/vm" + logger "github.com/multiversx/mx-chain-logger-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" systemVm "github.com/multiversx/mx-chain-go/vm" - logger "github.com/multiversx/mx-chain-logger-go" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var log = logger.GetOrCreate("integrationtests/multishard/smartcontract") @@ -45,11 +46,10 @@ func TestSCCallingIntraShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard integrationTests.DisplayAndStartNodes(nodes) @@ -86,7 +86,7 @@ func TestSCCallingIntraShard(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 + // 000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 secondSCAddress := putDeploySCToDataPool( "./testdata/second/output/second.wasm", secondSCOwner, @@ -96,10 +96,10 @@ func TestSCCallingIntraShard(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 + // 00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 
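The bridge test setup earlier in this hunk, like the txScenarios and stateTrie setups further down, pins the two newly introduced activation epochs to UnreachableEpoch so these legacy scenarios keep running without the new consensus behaviour. A condensed sketch of that recurring pattern, showing only the new fields (the surrounding EnableEpochs fields and the node-creation call differ from test to test):

	// Only the two fields added by this diff are shown; each test keeps its
	// own additional overrides in the same literal.
	enableEpochs := config.EnableEpochs{
		EquivalentMessagesEnableEpoch:    integrationTests.UnreachableEpoch,
		FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch,
	}
	nodes := integrationTests.CreateNodesWithEnableEpochs(
		numOfShards,
		nodesPerShard,
		numMetachainNodes,
		enableEpochs,
	)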
// Run two rounds, so the two SmartContracts get deployed. - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) @@ -113,7 +113,7 @@ func TestSCCallingIntraShard(t *testing.T) { } time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) // verify how many times was the first SC called for index, node := range nodes { @@ -142,11 +142,11 @@ func TestScDeployAndChangeScOwner(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numShards+1) for i := 0; i < numShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numShards] = numShards * nodesPerShard + leaders[numShards] = nodes[numShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -176,8 +176,8 @@ func TestScDeployAndChangeScOwner(t *testing.T) { nonce := uint64(0) round = integrationTests.IncrementAndPrintRound(round) nonce++ - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -195,8 +195,8 @@ func TestScDeployAndChangeScOwner(t *testing.T) { for i := 0; i < numRoundsToPropagateMultiShard; i++ { integrationTests.UpdateRound(nodes, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ } @@ -222,8 +222,8 @@ func TestScDeployAndChangeScOwner(t *testing.T) { for i := 0; i < numRoundsToPropagateMultiShard; i++ { integrationTests.UpdateRound(nodes, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ } @@ -252,11 +252,11 @@ func TestScDeployAndClaimSmartContractDeveloperRewards(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numShards+1) for i := 0; i < numShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numShards] = numShards * nodesPerShard + leaders[numShards] = nodes[numShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -289,8 +289,8 @@ func TestScDeployAndClaimSmartContractDeveloperRewards(t *testing.T) { nonce := uint64(0) round = integrationTests.IncrementAndPrintRound(round) nonce++ - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = 
integrationTests.IncrementAndPrintRound(round) nonce++ @@ -308,8 +308,8 @@ func TestScDeployAndClaimSmartContractDeveloperRewards(t *testing.T) { for i := 0; i < 5; i++ { integrationTests.UpdateRound(nodes, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ } @@ -346,8 +346,8 @@ func TestScDeployAndClaimSmartContractDeveloperRewards(t *testing.T) { for i := 0; i < 3; i++ { integrationTests.UpdateRound(nodes, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ } @@ -381,11 +381,11 @@ func TestSCCallingInCrossShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -422,7 +422,7 @@ func TestSCCallingInCrossShard(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 + // 000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 secondSCAddress := putDeploySCToDataPool( "./testdata/second/output/second.wasm", secondSCOwner, @@ -432,9 +432,9 @@ func TestSCCallingInCrossShard(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 + // 00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // make smart contract call to shard 1 which will do in shard 0 for _, node := range nodes { @@ -452,7 +452,7 @@ func TestSCCallingInCrossShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 10 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) // verify how many times was shard 0 and shard 1 called shId := nodes[0].ShardCoordinator.ComputeId(firstSCAddress) @@ -518,11 +518,11 @@ func TestSCCallingBuiltinAndFails(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -557,7 +557,7 @@ func TestSCCallingBuiltinAndFails(t *testing.T) { 
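Every multi-shard setup in these files gets the same mechanical change: the idxProposers index slice becomes a slice of the proposer nodes themselves. A sketch of the recurring loop under a hypothetical helper name (the tests inline it rather than sharing a function):

	// buildLeaders mirrors the inlined loops in the hunks above and below:
	// the first node of each shard acts as that shard's leader, and the node
	// placed right after the shard nodes is the metachain leader.
	func buildLeaders(nodes []*integrationTests.TestProcessorNode, numOfShards, nodesPerShard int) []*integrationTests.TestProcessorNode {
		leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
		for i := 0; i < numOfShards; i++ {
			leaders[i] = nodes[i*nodesPerShard]
		}
		leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
		return leaders
	}

The same slice then replaces the trailing idxProposers argument at the call sites and moves to the front of the list, e.g. integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRounds, nonce, round) where the old form was integrationTests.WaitOperationToBeDone(t, nodes, nrRounds, nonce, round, idxProposers).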
nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) sender := nodes[0] receiver := nodes[1] @@ -576,7 +576,7 @@ func TestSCCallingBuiltinAndFails(t *testing.T) { ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) testValue1 := vm.GetIntValueFromSC(nil, sender.AccntState, scAddress, "testValue1", nil) require.NotNil(t, testValue1) require.Equal(t, uint64(255), testValue1.Uint64()) @@ -606,18 +606,16 @@ func TestSCCallingInCrossShardDelegationMock(t *testing.T) { ) nodes := make([]*integrationTests.TestProcessorNode, 0) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for _, nds := range nodesMap { nodes = append(nodes, nds...) } - for _, nds := range nodesMap { - idx, err := getNodeIndex(nodes, nds[0]) - assert.Nil(t, err) - - idxProposers = append(idxProposers, idx) + for i := 0; i < numOfShards; i++ { + leaders[i] = nodesMap[uint32(i)][0] } + leaders[numOfShards] = nodesMap[core.MetachainShardId][0] integrationTests.DisplayAndStartNodes(nodes) @@ -652,7 +650,7 @@ func TestSCCallingInCrossShardDelegationMock(t *testing.T) { nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // one node calls to stake all the money from the delegation - that's how the contract is :D node := nodes[0] @@ -665,7 +663,7 @@ func TestSCCallingInCrossShardDelegationMock(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 10 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // verify system smart contract has the value @@ -707,18 +705,16 @@ func TestSCCallingInCrossShardDelegation(t *testing.T) { ) nodes := make([]*integrationTests.TestProcessorNode, 0) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for _, nds := range nodesMap { nodes = append(nodes, nds...) 
} - for _, nds := range nodesMap { - idx, err := getNodeIndex(nodes, nds[0]) - assert.Nil(t, err) - - idxProposers = append(idxProposers, idx) + for i := 0; i < numOfShards; i++ { + leaders[i] = nodesMap[uint32(i)][0] } + leaders[numOfShards] = nodesMap[core.MetachainShardId][0] integrationTests.DisplayAndStartNodes(nodes) @@ -761,7 +757,7 @@ func TestSCCallingInCrossShardDelegation(t *testing.T) { ) shardNode.OwnAccount.Nonce++ - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // check that the version is the expected one scQueryVersion := &process.SCQuery{ @@ -775,13 +771,13 @@ func TestSCCallingInCrossShardDelegation(t *testing.T) { require.True(t, bytes.Contains(vmOutputVersion.ReturnData[0], []byte("0.3."))) log.Info("SC deployed", "version", string(vmOutputVersion.ReturnData[0])) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // set stake per node setStakePerNodeTxData := "setStakePerNode@" + core.ConvertToEvenHexBigInt(nodePrice) integrationTests.CreateAndSendTransaction(shardNode, nodes, big.NewInt(0), delegateSCAddress, setStakePerNodeTxData, integrationTests.AdditionalGasLimit) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // add node addNodesTxData := fmt.Sprintf("addNodes@%s@%s", @@ -789,25 +785,25 @@ func TestSCCallingInCrossShardDelegation(t *testing.T) { hex.EncodeToString(stakerBLSSignature)) integrationTests.CreateAndSendTransaction(shardNode, nodes, big.NewInt(0), delegateSCAddress, addNodesTxData, integrationTests.AdditionalGasLimit) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // stake some coin! 
// here the node account fills all the required stake stakeTxData := "stake" integrationTests.CreateAndSendTransaction(shardNode, nodes, totalStake, delegateSCAddress, stakeTxData, integrationTests.AdditionalGasLimit) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // activate the delegation, this involves an async call to validatorSC stakeAllAvailableTxData := "stakeAllAvailable" integrationTests.CreateAndSendTransaction(shardNode, nodes, big.NewInt(0), delegateSCAddress, stakeAllAvailableTxData, 2*integrationTests.AdditionalGasLimit) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 10 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) @@ -890,11 +886,10 @@ func TestSCNonPayableIntraShardErrorShouldProcessBlock(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard integrationTests.DisplayAndStartNodes(nodes) @@ -931,7 +926,7 @@ func TestSCNonPayableIntraShardErrorShouldProcessBlock(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 + // 000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 secondSCAddress := putDeploySCToDataPool( "./testdata/second/output/second.wasm", secondSCOwner, @@ -941,10 +936,10 @@ func TestSCNonPayableIntraShardErrorShouldProcessBlock(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 + // 00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 // Run two rounds, so the two SmartContracts get deployed. 
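The two cross-shard delegation tests above build their nodes from a nodesMap keyed by shard id, so the former getNodeIndex lookup and its error assertion are dropped: each shard's leader is simply the first node of that shard's slice, and the metachain leader is taken from core.MetachainShardId. A sketch of that variant under the same assumptions (hypothetical helper name; the tests inline the loop):

	// leadersFromMap mirrors the inlined loops in the delegation tests above.
	func leadersFromMap(nodesMap map[uint32][]*integrationTests.TestProcessorNode, numOfShards int) []*integrationTests.TestProcessorNode {
		leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
		for i := 0; i < numOfShards; i++ {
			leaders[i] = nodesMap[uint32(i)][0]
		}
		leaders[numOfShards] = nodesMap[core.MetachainShardId][0]
		return leaders
	}

Note also that TestSCCallingIntraShard and TestSCNonPayableIntraShardErrorShouldProcessBlock now size their leaders slice with numOfShards only, dropping the metachain entry that the old idxProposers slice carried.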
- nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) @@ -958,7 +953,7 @@ func TestSCNonPayableIntraShardErrorShouldProcessBlock(t *testing.T) { } time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) for _, node := range nodes { assert.Equal(t, uint64(5), node.BlockChain.GetCurrentBlockHeader().GetNonce()) diff --git a/integrationTests/multiShard/softfork/scDeploy_test.go b/integrationTests/multiShard/softfork/scDeploy_test.go index 8af125f5797..5b4252b7806 100644 --- a/integrationTests/multiShard/softfork/scDeploy_test.go +++ b/integrationTests/multiShard/softfork/scDeploy_test.go @@ -11,12 +11,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/process/factory" - "github.com/multiversx/mx-chain-go/state" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/process/factory" + "github.com/multiversx/mx-chain-go/state" ) var log = logger.GetOrCreate("integrationtests/singleshard/block/softfork") @@ -67,7 +68,7 @@ func TestScDeploy(t *testing.T) { } integrationTests.ConnectNodes(connectableNodes) - idxProposers := []int{0, 1} + leaders := []*integrationTests.TestProcessorNode{nodes[0], nodes[1]} defer func() { for _, n := range nodes { @@ -93,7 +94,7 @@ func TestScDeploy(t *testing.T) { for i := uint64(0); i < numRounds; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -108,7 +109,7 @@ func TestScDeploy(t *testing.T) { deploySucceeded := deploySc(t, nodes) for i := uint64(0); i < 5; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) round = integrationTests.IncrementAndPrintRound(round) nonce++ diff --git a/integrationTests/multiShard/txScenarios/builtinFunctions_test.go b/integrationTests/multiShard/txScenarios/builtinFunctions_test.go index 1064239cbb0..0285cd0f5fd 100644 --- a/integrationTests/multiShard/txScenarios/builtinFunctions_test.go +++ b/integrationTests/multiShard/txScenarios/builtinFunctions_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process/factory" - "github.com/stretchr/testify/assert" ) func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { @@ -19,7 +20,7 @@ func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { } initialBalance := big.NewInt(1000000000000) - nodes, idxProposers, players := createGeneralSetupForTxTest(initialBalance) + nodes, leaders, players := createGeneralSetupForTxTest(initialBalance) defer func() { for _, n := range nodes { n.Close() @@ -50,7 +51,7 @@ func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { nrRoundsToTest := int64(5) for 
i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) @@ -74,7 +75,7 @@ func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { time.Sleep(time.Millisecond) for i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) @@ -103,7 +104,7 @@ func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { time.Sleep(time.Millisecond) for i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) } diff --git a/integrationTests/multiShard/txScenarios/common.go b/integrationTests/multiShard/txScenarios/common.go index d720b9d8df5..c5e65d772cf 100644 --- a/integrationTests/multiShard/txScenarios/common.go +++ b/integrationTests/multiShard/txScenarios/common.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/state" @@ -29,7 +30,7 @@ func createGeneralTestnetForTxTest( func createGeneralSetupForTxTest(initialBalance *big.Int) ( []*integrationTests.TestProcessorNode, - []int, + []*integrationTests.TestProcessorNode, []*integrationTests.TestWalletAccount, ) { numOfShards := 2 @@ -40,6 +41,8 @@ func createGeneralSetupForTxTest(initialBalance *big.Int) ( OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -49,11 +52,11 @@ func createGeneralSetupForTxTest(initialBalance *big.Int) ( enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -68,7 +71,7 @@ func createGeneralSetupForTxTest(initialBalance *big.Int) ( integrationTests.MintAllPlayers(nodes, players, initialBalance) - return nodes, idxProposers, players + return nodes, leaders, players } func createAndSendTransaction( diff --git a/integrationTests/multiShard/txScenarios/moveBalance_test.go b/integrationTests/multiShard/txScenarios/moveBalance_test.go index 5df383f7ebb..8599e5a45db 100644 --- a/integrationTests/multiShard/txScenarios/moveBalance_test.go +++ b/integrationTests/multiShard/txScenarios/moveBalance_test.go @@ -6,9 +6,10 @@ import ( "time" 
"github.com/multiversx/mx-chain-core-go/core/pubkeyConverter" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/vm" - "github.com/stretchr/testify/assert" ) func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { @@ -17,7 +18,7 @@ func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { } initialBalance := big.NewInt(1000000000000) - nodes, idxProposers, players := createGeneralSetupForTxTest(initialBalance) + nodes, leaders, players := createGeneralSetupForTxTest(initialBalance) defer func() { for _, n := range nodes { n.Close() @@ -65,7 +66,7 @@ func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { nrRoundsToTest := int64(7) for i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(integrationTests.StepDelay) @@ -80,7 +81,7 @@ func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { assert.Equal(t, players[2].Nonce, senderAccount.GetNonce()) assert.Equal(t, expectedBalance, senderAccount.GetBalance()) - //check balance intra shard tx insufficient gas limit + // check balance intra shard tx insufficient gas limit senderAccount = getUserAccount(nodes, players[1].Address) assert.Equal(t, uint64(0), senderAccount.GetNonce()) assert.Equal(t, initialBalance, senderAccount.GetBalance()) @@ -116,7 +117,7 @@ func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { roundToPropagateMultiShard := int64(15) for i := int64(0); i <= roundToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(integrationTests.StepDelay) } diff --git a/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go b/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go index b28c5dc054e..06e6d8892c7 100644 --- a/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go +++ b/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go @@ -8,13 +8,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestValidatorToDelegationManagerWithNewContract(t *testing.T) { @@ -34,11 +35,11 @@ func TestValidatorToDelegationManagerWithNewContract(t *testing.T) { stakingWalletAccount := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, nodes[0].ShardCoordinator.SelfId()) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + 
leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -71,7 +72,7 @@ func TestValidatorToDelegationManagerWithNewContract(t *testing.T) { t, nodes, stakingWalletAccount, - idxProposers, + leaders, nodePrice, frontendBLSPubkey, frontendHexSignature, @@ -87,7 +88,7 @@ func TestValidatorToDelegationManagerWithNewContract(t *testing.T) { t, nodes, stakingWalletAccount, - idxProposers, + leaders, "makeNewContractFromValidatorData", big.NewInt(0), []byte{10}, @@ -124,11 +125,11 @@ func testValidatorToDelegationWithMerge(t *testing.T, withJail bool) { stakingWalletAccount := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, nodes[0].ShardCoordinator.SelfId()) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -162,7 +163,7 @@ func testValidatorToDelegationWithMerge(t *testing.T, withJail bool) { t, nodes, stakingWalletAccount, - idxProposers, + leaders, nodePrice, frontendBLSPubkey, frontendHexSignature, @@ -182,7 +183,7 @@ func testValidatorToDelegationWithMerge(t *testing.T, withJail bool) { t, nodes, stakingWalletAccount, - idxProposers, + leaders, "createNewDelegationContract", big.NewInt(10000), []byte{0}, @@ -206,7 +207,7 @@ func testValidatorToDelegationWithMerge(t *testing.T, withJail bool) { time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) @@ -258,11 +259,11 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { stakingWalletAccount1 := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, nodes[0].ShardCoordinator.SelfId()) stakingWalletAccount2 := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, nodes[0].ShardCoordinator.SelfId()) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -296,7 +297,7 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { t, nodes, stakingWalletAccount1, - idxProposers, + leaders, nodePrice, frontendBLSPubkey, frontendHexSignature, @@ -312,7 +313,7 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { t, nodes, stakingWalletAccount2, - idxProposers, + leaders, "createNewDelegationContract", big.NewInt(10000), []byte{0}, @@ -335,7 +336,7 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 5, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 5, nonce, round) txData = txDataBuilder.NewBuilder().Clear(). Func("mergeValidatorToDelegationWithWhitelist"). 
@@ -352,7 +353,7 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) testBLSKeyOwnerIsAddress(t, nodes, scAddressBytes, frontendBLSPubkey) @@ -378,7 +379,7 @@ func generateSendAndWaitToExecuteStakeTransaction( t *testing.T, nodes []*integrationTests.TestProcessorNode, stakingWalletAccount *integrationTests.TestWalletAccount, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nodePrice *big.Int, frontendBLSPubkey []byte, frontendHexSignature string, @@ -398,7 +399,7 @@ func generateSendAndWaitToExecuteStakeTransaction( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 6 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) return nonce, round } @@ -407,7 +408,7 @@ func generateSendAndWaitToExecuteTransaction( t *testing.T, nodes []*integrationTests.TestProcessorNode, stakingWalletAccount *integrationTests.TestWalletAccount, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, function string, value *big.Int, serviceFee []byte, @@ -431,7 +432,7 @@ func generateSendAndWaitToExecuteTransaction( time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) return nonce, round } diff --git a/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go b/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go index 8d6c00af8ae..6685b5b1433 100644 --- a/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go +++ b/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go @@ -11,12 +11,13 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/integrationTests" testBlock "github.com/multiversx/mx-chain-go/integrationTests/singleShard/block" "github.com/multiversx/mx-chain-go/process" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" ) // TestShardShouldNotProposeAndExecuteTwoBlocksInSameRound tests that a shard can not continue building on a @@ -43,6 +44,7 @@ func TestShardShouldNotProposeAndExecuteTwoBlocksInSameRound(t *testing.T) { integrationTests.ConnectNodes(connectableNodes) idxProposer := 0 + leader := nodes[idxProposer] defer func() { for _, n := range nodes { @@ -57,24 +59,24 @@ func TestShardShouldNotProposeAndExecuteTwoBlocksInSameRound(t *testing.T) { nonce := uint64(1) round = integrationTests.IncrementAndPrintRound(round) - err := proposeAndCommitBlock(nodes[idxProposer], round, nonce) + err := proposeAndCommitBlock(leader, round, nonce) assert.Nil(t, err) - integrationTests.SyncBlock(t, nodes, []int{idxProposer}, nonce) + integrationTests.SyncBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, nonce) time.Sleep(testBlock.StepDelay) 
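The single-shard block tests just above apply the narrower form of the same refactor: the lone proposer is bound once to a leader variable, and the one-element []int{idxProposer} slices become one-element slices of node pointers. Before/after shape of those call sites, using the names from the test above:

	// before: the proposer was addressed by its index into the nodes slice
	// integrationTests.SyncBlock(t, nodes, []int{idxProposer}, nonce)

	// after: the node itself is captured once and passed as a one-element slice
	leader := nodes[idxProposer]
	integrationTests.SyncBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, nonce)

The ProposeAndSyncOneBlock call sites in these files change the same way, e.g. integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce).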
checkCurrentBlockHeight(t, nodes, nonce) - //only nonce increases, round stays the same + // only nonce increases, round stays the same nonce++ err = proposeAndCommitBlock(nodes[idxProposer], round, nonce) assert.Equal(t, process.ErrLowerRoundInBlock, err) - //mockTestingT is used as in normal case SyncBlock would fail as it doesn't find the header with nonce 2 + // mockTestingT is used as in normal case SyncBlock would fail as it doesn't find the header with nonce 2 mockTestingT := &testing.T{} - integrationTests.SyncBlock(mockTestingT, nodes, []int{idxProposer}, nonce) + integrationTests.SyncBlock(mockTestingT, nodes, []*integrationTests.TestProcessorNode{leader}, nonce) time.Sleep(testBlock.StepDelay) @@ -82,12 +84,12 @@ func TestShardShouldNotProposeAndExecuteTwoBlocksInSameRound(t *testing.T) { } // TestShardShouldProposeBlockContainingInvalidTransactions tests the following scenario: -// 1. generate 3 move balance transactions: one that can be executed, one that can not be executed but the account has -// the balance for the fee and one that is completely invalid (no balance left for it) -// 2. proposer will have those 3 transactions in its pools and will propose a block -// 3. another node will be able to sync the proposed block (and request - receive) the 2 transactions that -// will end up in the block (one valid and one invalid) -// 4. the non-executable transaction will be removed from the proposer's pool +// 1. generate 3 move balance transactions: one that can be executed, one that can not be executed but the account has +// the balance for the fee and one that is completely invalid (no balance left for it) +// 2. proposer will have those 3 transactions in its pools and will propose a block +// 3. another node will be able to sync the proposed block (and request - receive) the 2 transactions that +// will end up in the block (one valid and one invalid) +// 4. 
the non-executable transaction will be removed from the proposer's pool func TestShardShouldProposeBlockContainingInvalidTransactions(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") @@ -110,7 +112,7 @@ func TestShardShouldProposeBlockContainingInvalidTransactions(t *testing.T) { integrationTests.ConnectNodes(connectableNodes) idxProposer := 0 - proposer := nodes[idxProposer] + leader := nodes[idxProposer] defer func() { for _, n := range nodes { @@ -128,10 +130,10 @@ func TestShardShouldProposeBlockContainingInvalidTransactions(t *testing.T) { transferValue := uint64(1000000) mintAllNodes(nodes, transferValue) - txs, hashes := generateTransferTxs(transferValue, proposer.OwnAccount.SkTxSign, nodes[1].OwnAccount.PkTxSign) - addTxsInDataPool(proposer, txs, hashes) + txs, hashes := generateTransferTxs(transferValue, leader.OwnAccount.SkTxSign, nodes[1].OwnAccount.PkTxSign) + addTxsInDataPool(leader, txs, hashes) - _, _ = integrationTests.ProposeAndSyncOneBlock(t, nodes, []int{idxProposer}, round, nonce) + _, _ = integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce) fmt.Println(integrationTests.MakeDisplayTable(nodes)) diff --git a/integrationTests/singleShard/block/executingMiniblocksSc/executingMiniblocksSc_test.go b/integrationTests/singleShard/block/executingMiniblocksSc/executingMiniblocksSc_test.go index 81bf80dca55..238503d006a 100644 --- a/integrationTests/singleShard/block/executingMiniblocksSc/executingMiniblocksSc_test.go +++ b/integrationTests/singleShard/block/executingMiniblocksSc/executingMiniblocksSc_test.go @@ -9,10 +9,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/singleShard/block" "github.com/multiversx/mx-chain-go/process/factory" - "github.com/stretchr/testify/assert" ) func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { @@ -40,10 +41,11 @@ func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { integrationTests.ConnectNodes(connectableNodes) idxProposer := 0 + leader := nodes[idxProposer] numPlayers := 10 players := make([]*integrationTests.TestWalletAccount, numPlayers) for i := 0; i < numPlayers; i++ { - players[i] = integrationTests.CreateTestWalletAccount(nodes[idxProposer].ShardCoordinator, 0) + players[i] = integrationTests.CreateTestWalletAccount(leader.ShardCoordinator, 0) } defer func() { @@ -62,7 +64,7 @@ func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { hardCodedSk, _ := hex.DecodeString("5561d28b0d89fa425bbbf9e49a018b5d1e4a462c03d2efce60faf9ddece2af06") hardCodedScResultingAddress, _ := hex.DecodeString("000000000000000005006c560111a94e434413c1cdaafbc3e1348947d1d5b3a1") - nodes[idxProposer].LoadTxSignSkBytes(hardCodedSk) + leader.LoadTxSignSkBytes(hardCodedSk) initialVal := big.NewInt(100000000000) integrationTests.MintAllNodes(nodes, initialVal) @@ -70,11 +72,11 @@ func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { integrationTests.DeployScTx(nodes, idxProposer, hex.EncodeToString(scCode), factory.WasmVirtualMachine, "001000000000") time.Sleep(block.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []int{idxProposer}, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce) playersDoTopUp(nodes[idxProposer], 
players, hardCodedScResultingAddress, big.NewInt(10000000)) time.Sleep(block.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []int{idxProposer}, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce) for i := 0; i < 100; i++ { playersDoTransfer(nodes[idxProposer], players, hardCodedScResultingAddress, big.NewInt(100)) @@ -82,7 +84,7 @@ func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { for i := 0; i < 10; i++ { time.Sleep(block.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []int{idxProposer}, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce) } integrationTests.CheckRootHashes(t, nodes, []int{idxProposer}) diff --git a/integrationTests/state/stateTrie/stateTrie_test.go b/integrationTests/state/stateTrie/stateTrie_test.go index 12ec5115d28..68958b7f206 100644 --- a/integrationTests/state/stateTrie/stateTrie_test.go +++ b/integrationTests/state/stateTrie/stateTrie_test.go @@ -24,6 +24,10 @@ import ( dataTx "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-core-go/hashing/sha256" crypto "github.com/multiversx/mx-chain-crypto-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" @@ -49,9 +53,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" testStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const denomination = "000000000000000000" @@ -1299,7 +1300,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) numNodesPerShard := 1 numNodesMeta := 1 - nodes, idxProposers := integrationTests.SetupSyncNodesOneShardAndMeta(numNodesPerShard, numNodesMeta) + nodes, leaders := integrationTests.SetupSyncNodesOneShardAndMeta(numNodesPerShard, numNodesMeta) defer integrationTests.CloseProcessorNodes(nodes) integrationTests.BootstrapDelay() @@ -1331,7 +1332,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) round = integrationTests.IncrementAndPrintRound(round) nonce++ - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) rootHashOfFirstBlock, _ := shardNode.AccntState.RootHash() @@ -1340,7 +1341,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) delayRounds := 10 for i := 0; i < delayRounds; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } fmt.Println("Generating transactions...") @@ -1357,7 +1358,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) fmt.Println("Delaying for disseminating transactions...") time.Sleep(time.Second * 5) - round, _ = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, _ = 
integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) time.Sleep(time.Second * 5) rootHashOfRollbackedBlock, _ := shardNode.AccntState.RootHash() @@ -1390,7 +1391,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) integrationTests.ProposeBlocks( nodes, &round, - idxProposers, + leaders, nonces, numOfRounds, ) @@ -1559,11 +1560,11 @@ func TestStatePruningIsNotBuffered(t *testing.T) { ) shardNode := nodes[0] - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1583,21 +1584,21 @@ func TestStatePruningIsNotBuffered(t *testing.T) { time.Sleep(integrationTests.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) delayRounds := 5 for j := 0; j < 8; j++ { // alter the shardNode's state by placing the value0 variable inside it's data trie alterState(t, shardNode, nodes, []byte("key"), []byte("value0")) for i := 0; i < delayRounds; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } checkTrieCanBeRecreated(t, shardNode) // alter the shardNode's state by placing the value1 variable inside it's data trie alterState(t, shardNode, nodes, []byte("key"), []byte("value1")) for i := 0; i < delayRounds; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } checkTrieCanBeRecreated(t, shardNode) } @@ -1619,11 +1620,11 @@ func TestStatePruningIsNotBufferedOnConsecutiveBlocks(t *testing.T) { ) shardNode := nodes[0] - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1643,17 +1644,17 @@ func TestStatePruningIsNotBufferedOnConsecutiveBlocks(t *testing.T) { time.Sleep(integrationTests.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) for j := 0; j < 30; j++ { // alter the shardNode's state by placing the value0 variable inside it's data trie alterState(t, shardNode, nodes, []byte("key"), []byte("value0")) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) checkTrieCanBeRecreated(t, shardNode) // alter the shardNode's state by placing the value1 variable inside it's data trie alterState(t, shardNode, nodes, []byte("key"), []byte("value1")) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, 
leaders, round, nonce) checkTrieCanBeRecreated(t, shardNode) } } @@ -1733,11 +1734,11 @@ func TestSnapshotOnEpochChange(t *testing.T) { node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1767,7 +1768,7 @@ func TestSnapshotOnEpochChange(t *testing.T) { numRounds := uint32(20) for i := uint64(0); i < uint64(numRounds); i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) for _, node := range nodes { integrationTests.CreateAndSendTransaction(node, nodes, sendValue, receiverAddress, "", integrationTests.AdditionalGasLimit) @@ -1786,7 +1787,7 @@ func TestSnapshotOnEpochChange(t *testing.T) { numDelayRounds := uint32(15) for i := uint64(0); i < uint64(numDelayRounds); i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) for _, node := range nodes { integrationTests.CreateAndSendTransaction(node, nodes, sendValue, receiverAddress, "", integrationTests.AdditionalGasLimit) @@ -2455,7 +2456,7 @@ func migrateDataTrieBuiltInFunc( migrationAddress []byte, nonce uint64, round uint64, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, ) { require.True(t, nodes[shardId].EnableEpochsHandler.IsFlagEnabled(common.AutoBalanceDataTriesFlag)) isMigrated := getAddressMigrationStatus(t, nodes[shardId].AccntState, migrationAddress) @@ -2465,7 +2466,7 @@ func migrateDataTrieBuiltInFunc( time.Sleep(time.Second) nrRoundsToPropagate := 5 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagate, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagate, nonce, round) isMigrated = getAddressMigrationStatus(t, nodes[shardId].AccntState, migrationAddress) require.True(t, isMigrated) @@ -2475,7 +2476,7 @@ func startNodesAndIssueToken( t *testing.T, numOfShards int, issuerShardId byte, -) ([]*integrationTests.TestProcessorNode, []int, uint64, uint64) { +) (leaders []*integrationTests.TestProcessorNode, nodes []*integrationTests.TestProcessorNode, nonce uint64, round uint64) { nodesPerShard := 1 numMetachainNodes := 1 @@ -2489,9 +2490,11 @@ func startNodesAndIssueToken( StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, AutoBalanceDataTriesEnableEpoch: 1, } - nodes := integrationTests.CreateNodesWithEnableEpochs( + nodes = integrationTests.CreateNodesWithEnableEpochs( numOfShards, nodesPerShard, numMetachainNodes, @@ -2503,19 +2506,19 @@ func startNodesAndIssueToken( node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders = make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i 
* nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) initialVal := int64(10000000000) integrationTests.MintAllNodes(nodes, big.NewInt(initialVal)) - round := uint64(0) - nonce := uint64(0) + round = uint64(0) + nonce = uint64(0) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -2526,14 +2529,14 @@ func startNodesAndIssueToken( time.Sleep(time.Second) nrRoundsToPropagate := 8 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagate, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagate, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) esdtCommon.CheckAddressHasTokens(t, nodes[issuerShardId].OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply) - return nodes, idxProposers, nonce, round + return nodes, leaders, nonce, round } func getDestAccountAddress(migrationAddress []byte, shardId byte) []byte { diff --git a/integrationTests/state/stateTrieSync/stateTrieSync_test.go b/integrationTests/state/stateTrieSync/stateTrieSync_test.go index 74650d4ce11..7ccc5255cb0 100644 --- a/integrationTests/state/stateTrieSync/stateTrieSync_test.go +++ b/integrationTests/state/stateTrieSync/stateTrieSync_test.go @@ -10,6 +10,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/throttler" + logger "github.com/multiversx/mx-chain-logger-go" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" @@ -28,10 +33,6 @@ import ( "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/multiversx/mx-chain-go/trie/storageMarker" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - logger "github.com/multiversx/mx-chain-logger-go" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var log = logger.GetOrCreate("integrationtests/state/statetriesync") @@ -449,11 +450,11 @@ func testSyncMissingSnapshotNodes(t *testing.T, version int) { node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -476,7 +477,7 @@ func testSyncMissingSnapshotNodes(t *testing.T, version int) { nonce++ numDelayRounds := uint32(10) for i := uint64(0); i < uint64(numDelayRounds); i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) time.Sleep(integrationTests.StepDelay) } diff --git a/integrationTests/sync/basicSync/basicSync_test.go b/integrationTests/sync/basicSync/basicSync_test.go index 52cc2c7af79..1dfb82dcf80 100644 --- 
a/integrationTests/sync/basicSync/basicSync_test.go +++ b/integrationTests/sync/basicSync/basicSync_test.go @@ -8,9 +8,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-go/integrationTests" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/integrationTests" ) var log = logger.GetOrCreate("basicSync") @@ -19,7 +20,6 @@ func TestSyncWorksInShard_EmptyBlocksNoForks(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - maxShards := uint32(1) shardId := uint32(0) numNodesPerShard := 6 @@ -47,7 +47,7 @@ func TestSyncWorksInShard_EmptyBlocksNoForks(t *testing.T) { connectableNodes = append(connectableNodes, metachainNode) idxProposerShard0 := 0 - idxProposers := []int{idxProposerShard0, idxProposerMeta} + leaders := []*integrationTests.TestProcessorNode{nodes[idxProposerShard0], nodes[idxProposerMeta]} integrationTests.ConnectNodes(connectableNodes) @@ -72,7 +72,7 @@ func TestSyncWorksInShard_EmptyBlocksNoForks(t *testing.T) { numRoundsToTest := 5 for i := 0; i < numRoundsToTest; i++ { - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) time.Sleep(integrationTests.SyncDelay) @@ -110,7 +110,7 @@ func TestSyncWorksInShard_EmptyBlocksDoubleSign(t *testing.T) { integrationTests.ConnectNodes(connectableNodes) idxProposerShard0 := 0 - idxProposers := []int{idxProposerShard0} + leaders := []*integrationTests.TestProcessorNode{nodes[idxProposerShard0]} defer func() { for _, n := range nodes { @@ -133,7 +133,7 @@ func TestSyncWorksInShard_EmptyBlocksDoubleSign(t *testing.T) { numRoundsToTest := 2 for i := 0; i < numRoundsToTest; i++ { - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) time.Sleep(integrationTests.SyncDelay) @@ -197,3 +197,89 @@ func testAllNodesHaveSameLastBlock(t *testing.T, nodes []*integrationTests.TestP assert.Equal(t, 1, len(mapBlocksByHash)) } + +func TestSyncWorksInShard_EmptyBlocksNoForks_With_EquivalentProofs(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + maxShards := uint32(1) + shardId := uint32(0) + numNodesPerShard := 3 + + enableEpochs := integrationTests.CreateEnableEpochsConfig() + enableEpochs.EquivalentMessagesEnableEpoch = uint32(0) + + nodes := make([]*integrationTests.TestProcessorNode, numNodesPerShard+1) + connectableNodes := make([]integrationTests.Connectable, 0) + for i := 0; i < numNodesPerShard; i++ { + nodes[i] = integrationTests.NewTestProcessorNode(integrationTests.ArgTestProcessorNode{ + MaxShards: maxShards, + NodeShardId: shardId, + TxSignPrivKeyShardId: shardId, + WithSync: true, + EpochsConfig: &enableEpochs, + }) + connectableNodes = append(connectableNodes, nodes[i]) + } + + metachainNode := integrationTests.NewTestProcessorNode(integrationTests.ArgTestProcessorNode{ + MaxShards: maxShards, + NodeShardId: core.MetachainShardId, + TxSignPrivKeyShardId: shardId, + WithSync: true, + }) + idxProposerMeta := numNodesPerShard + nodes[idxProposerMeta] = metachainNode + connectableNodes = append(connectableNodes, metachainNode) + + idxProposerShard0 := 0 + leaders := []*integrationTests.TestProcessorNode{nodes[idxProposerShard0], nodes[idxProposerMeta]} + + integrationTests.ConnectNodes(connectableNodes) + + defer func() { + for _, n 
:= range nodes { + n.Close() + } + }() + + for _, n := range nodes { + _ = n.StartSync() + } + + fmt.Println("Delaying for nodes p2p bootstrap...") + time.Sleep(integrationTests.P2pBootstrapDelay) + + round := uint64(0) + nonce := uint64(0) + round = integrationTests.IncrementAndPrintRound(round) + integrationTests.UpdateRound(nodes, round) + nonce++ + + numRoundsToTest := 5 + for i := 0; i < numRoundsToTest; i++ { + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + + time.Sleep(integrationTests.SyncDelay) + + round = integrationTests.IncrementAndPrintRound(round) + integrationTests.UpdateRound(nodes, round) + nonce++ + } + + time.Sleep(integrationTests.SyncDelay) + + expectedNonce := nodes[0].BlockChain.GetCurrentBlockHeader().GetNonce() + for i := 1; i < len(nodes); i++ { + if check.IfNil(nodes[i].BlockChain.GetCurrentBlockHeader()) { + assert.Fail(t, fmt.Sprintf("Node with idx %d does not have a current block", i)) + } else { + if i == idxProposerMeta { // the metachain node has the highest nonce since it is a single node and it did not sync the header + assert.Equal(t, expectedNonce, nodes[i].BlockChain.GetCurrentBlockHeader().GetNonce()) + } else { // shard nodes could not sync the last header since there is no proof for it; in the complete flow, once nodes are fully synced they will get the current header directly from consensus, so they will receive the proof for the header + assert.Equal(t, expectedNonce-1, nodes[i].BlockChain.GetCurrentBlockHeader().GetNonce()) + } + } + } +} diff --git a/integrationTests/sync/edgeCases/edgeCases_test.go b/integrationTests/sync/edgeCases/edgeCases_test.go index f3167b0528e..285fed4dd8c 100644 --- a/integrationTests/sync/edgeCases/edgeCases_test.go +++ b/integrationTests/sync/edgeCases/edgeCases_test.go @@ -6,9 +6,10 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-go/integrationTests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" ) // TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard tests the following scenario: @@ -24,8 +25,8 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { numNodesPerShard := 3 numNodesMeta := 3 - nodes, idxProposers := integrationTests.SetupSyncNodesOneShardAndMeta(numNodesPerShard, numNodesMeta) - idxProposerMeta := idxProposers[1] + nodes, leaders := integrationTests.SetupSyncNodesOneShardAndMeta(numNodesPerShard, numNodesMeta) + leaderMeta := leaders[1] defer integrationTests.CloseProcessorNodes(nodes) integrationTests.BootstrapDelay() @@ -44,7 +45,7 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { integrationTests.ProposeBlocks( nodes, &round, - idxProposers, + leaders, nonces, numRoundsBlocksAreProposedCorrectly, ) @@ -54,14 +55,14 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { integrationTests.ResetHighestProbableNonce(nodes, shardIdToRollbackLastBlock, 2) integrationTests.EmptyDataPools(nodes, shardIdToRollbackLastBlock) - //revert also the nonce, so the same block nonce will be used when shard will propose the next block + // revert also the nonce, so the same block nonce will be used when shard will propose the next block atomic.AddUint64(nonces[idxNonceShard], ^uint64(0)) numRoundsBlocksAreProposedOnlyByMeta := 2 integrationTests.ProposeBlocks( nodes, &round, - []int{idxProposerMeta}, + []*integrationTests.TestProcessorNode{leaderMeta},
[]*uint64{nonces[idxNonceMeta]}, numRoundsBlocksAreProposedOnlyByMeta, ) @@ -70,7 +71,7 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { integrationTests.ProposeBlocks( nodes, &round, - idxProposers, + leaders, nonces, secondNumRoundsBlocksAreProposedCorrectly, ) @@ -99,12 +100,12 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { integrationTests.StartSyncingBlocks(syncNodesSlice) - //after joining the network we must propose a new block on the metachain as to be received by the sync - //node and to start the bootstrapping process + // after joining the network we must propose a new block on the metachain as to be received by the sync + // node and to start the bootstrapping process integrationTests.ProposeBlocks( nodes, &round, - []int{idxProposerMeta}, + []*integrationTests.TestProcessorNode{leaderMeta}, []*uint64{nonces[idxNonceMeta]}, 1, ) @@ -115,7 +116,7 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { time.Sleep(integrationTests.SyncDelay * time.Duration(numOfRoundsToWaitToCatchUp)) integrationTests.UpdateRound(nodes, round) - nonceProposerMeta := nodes[idxProposerMeta].BlockChain.GetCurrentBlockHeader().GetNonce() + nonceProposerMeta := leaderMeta.BlockChain.GetCurrentBlockHeader().GetNonce() nonceSyncNode := syncMetaNode.BlockChain.GetCurrentBlockHeader().GetNonce() assert.Equal(t, nonceProposerMeta, nonceSyncNode) } diff --git a/integrationTests/testConsensusNode.go b/integrationTests/testConsensusNode.go index 2e297291423..8651045eb7e 100644 --- a/integrationTests/testConsensusNode.go +++ b/integrationTests/testConsensusNode.go @@ -16,6 +16,7 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" mclMultiSig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus/round" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -40,6 +41,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -64,17 +66,18 @@ var testPubkeyConverter, _ = pubkeyConverter.NewHexPubkeyConverter(32) // ArgsTestConsensusNode represents the arguments for the test consensus node constructor(s) type ArgsTestConsensusNode struct { - ShardID uint32 - ConsensusSize int - RoundTime uint64 - ConsensusType string - NodeKeys *TestNodeKeys - EligibleMap map[uint32][]nodesCoordinator.Validator - WaitingMap map[uint32][]nodesCoordinator.Validator - KeyGen crypto.KeyGenerator - P2PKeyGen crypto.KeyGenerator - MultiSigner *cryptoMocks.MultisignerMock - StartTime int64 + ShardID uint32 + ConsensusSize int + RoundTime uint64 + ConsensusType string + NodeKeys *TestNodeKeys + EligibleMap map[uint32][]nodesCoordinator.Validator + WaitingMap map[uint32][]nodesCoordinator.Validator + KeyGen crypto.KeyGenerator + P2PKeyGen crypto.KeyGenerator + MultiSigner *cryptoMocks.MultisignerMock + StartTime int64 + EnableEpochsConfig config.EnableEpochs } // TestConsensusNode represents a structure used in integration tests used 
for consensus tests @@ -115,6 +118,7 @@ func CreateNodesWithTestConsensusNode( roundTime uint64, consensusType string, numKeysOnEachNode int, + enableEpochsConfig config.EnableEpochs, ) map[uint32][]*TestConsensusNode { nodes := make(map[uint32][]*TestConsensusNode, nodesPerShard) @@ -134,17 +138,18 @@ func CreateNodesWithTestConsensusNode( multiSignerMock := createCustomMultiSignerMock(multiSigner) args := ArgsTestConsensusNode{ - ShardID: shardID, - ConsensusSize: consensusSize, - RoundTime: roundTime, - ConsensusType: consensusType, - NodeKeys: keysPair, - EligibleMap: eligibleMap, - WaitingMap: waitingMap, - KeyGen: cp.KeyGen, - P2PKeyGen: cp.P2PKeyGen, - MultiSigner: multiSignerMock, - StartTime: startTime, + ShardID: shardID, + ConsensusSize: consensusSize, + RoundTime: roundTime, + ConsensusType: consensusType, + NodeKeys: keysPair, + EligibleMap: eligibleMap, + WaitingMap: waitingMap, + KeyGen: cp.KeyGen, + P2PKeyGen: cp.P2PKeyGen, + MultiSigner: multiSignerMock, + StartTime: startTime, + EnableEpochsConfig: enableEpochsConfig, } tcn := NewTestConsensusNode(args) @@ -188,7 +193,7 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { tcn.MainMessenger = CreateMessengerWithNoDiscovery() tcn.FullArchiveMessenger = &p2pmocks.MessengerStub{} tcn.initBlockChain(testHasher) - tcn.initBlockProcessor() + tcn.initBlockProcessor(tcn.ShardCoordinator.SelfId()) syncer := ntp.NewSyncTime(ntp.NewNTPGoogleConfig(), nil) syncer.StartSyncingTime() @@ -236,7 +241,7 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { tcn.initAccountsDB() - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + coreComponents := GetDefaultCoreComponents(args.EnableEpochsConfig) coreComponents.SyncTimerField = syncer coreComponents.RoundHandlerField = roundHandler coreComponents.InternalMarshalizerField = TestMarshalizer @@ -314,7 +319,7 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { processComponents.EpochNotifier = epochStartRegistrationHandler processComponents.BlackListHdl = &testscommon.TimeCacheStub{} processComponents.BootSore = &mock.BoostrapStorerMock{} - processComponents.HeaderSigVerif = &mock.HeaderSigVerifierStub{} + processComponents.HeaderSigVerif = &consensusMocks.HeaderSigVerifierMock{} processComponents.HeaderIntegrVerif = &mock.HeaderIntegrityVerifierStub{} processComponents.ReqHandler = &testscommon.RequestHandlerStub{} processComponents.MainPeerMapper = mock.NewNetworkShardingCollectorMock() @@ -435,7 +440,7 @@ func (tcn *TestConsensusNode) initBlockChain(hasher hashing.Hasher) { tcn.ChainHandler.SetGenesisHeaderHash(hasher.Compute(string(hdrMarshalized))) } -func (tcn *TestConsensusNode) initBlockProcessor() { +func (tcn *TestConsensusNode) initBlockProcessor(shardId uint32) { tcn.BlockProcessor = &mock.BlockProcessorMock{ Marshalizer: TestMarshalizer, CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { @@ -459,12 +464,38 @@ func (tcn *TestConsensusNode) initBlockProcessor() { return mrsData, mrsTxs, nil }, CreateNewHeaderCalled: func(round uint64, nonce uint64) (data.HeaderHandler, error) { - return &dataBlock.Header{ - Round: round, - Nonce: nonce, - SoftwareVersion: []byte("version"), + if shardId == common.MetachainShardId { + return &dataBlock.MetaBlock{ + Round: round, + Nonce: nonce, + SoftwareVersion: []byte("version"), + ValidatorStatsRootHash: []byte("validator stats root hash"), + AccumulatedFeesInEpoch: big.NewInt(0), + DeveloperFees: big.NewInt(0), + DevFeesInEpoch: 
big.NewInt(0), + }, nil + } + + return &dataBlock.HeaderV2{ + Header: &dataBlock.Header{ + Round: round, + Nonce: nonce, + SoftwareVersion: []byte("version"), + }, + ScheduledDeveloperFees: big.NewInt(0), + ScheduledAccumulatedFees: big.NewInt(0), }, nil }, + DecodeBlockHeaderCalled: func(dta []byte) data.HeaderHandler { + var header data.HeaderHandler + header = &dataBlock.HeaderV2{} + if shardId == common.MetachainShardId { + header = &dataBlock.MetaBlock{} + } + + _ = TestMarshalizer.Unmarshal(header, dta) + return header + }, } } diff --git a/integrationTests/testHeartbeatNode.go b/integrationTests/testHeartbeatNode.go index b74bfaf01fe..caea2235767 100644 --- a/integrationTests/testHeartbeatNode.go +++ b/integrationTests/testHeartbeatNode.go @@ -21,6 +21,8 @@ import ( "github.com/multiversx/mx-chain-crypto-go/signing/mcl" "github.com/multiversx/mx-chain-crypto-go/signing/mcl/singlesig" "github.com/multiversx/mx-chain-crypto-go/signing/secp256k1" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -59,7 +61,6 @@ import ( trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" "github.com/multiversx/mx-chain-go/update" - "github.com/stretchr/testify/require" ) // constants used for the hearbeat node & generated messages @@ -716,8 +717,9 @@ func (thn *TestHeartbeatNode) initMultiDataInterceptor(topic string, dataFactory return true }, }, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - CurrentPeerId: thn.MainMessenger.ID(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + CurrentPeerId: thn.MainMessenger.ID(), + InterceptedDataVerifier: &processMock.InterceptedDataVerifierMock{}, }, ) @@ -739,8 +741,9 @@ func (thn *TestHeartbeatNode) initSingleDataInterceptor(topic string, dataFactor return true }, }, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - CurrentPeerId: thn.MainMessenger.ID(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + CurrentPeerId: thn.MainMessenger.ID(), + InterceptedDataVerifier: &processMock.InterceptedDataVerifierMock{}, }, ) diff --git a/integrationTests/testInitializer.go b/integrationTests/testInitializer.go index a7c6cdac3c3..57af859a8df 100644 --- a/integrationTests/testInitializer.go +++ b/integrationTests/testInitializer.go @@ -28,6 +28,11 @@ import ( "github.com/multiversx/mx-chain-crypto-go/signing/ed25519" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" "github.com/multiversx/mx-chain-crypto-go/signing/secp256k1" + logger "github.com/multiversx/mx-chain-logger-go" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/statistics" "github.com/multiversx/mx-chain-go/config" @@ -78,10 +83,6 @@ import ( "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - logger "github.com/multiversx/mx-chain-logger-go" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" ) // StepDelay is used so that transactions can disseminate properly @@ -1142,16 +1143,12 @@ func IncrementAndPrintRound(round uint64) uint64 { } // ProposeBlock proposes a block for every shard -func ProposeBlock(nodes 
[]*TestProcessorNode, idxProposers []int, round uint64, nonce uint64) { +func ProposeBlock(nodes []*TestProcessorNode, leaders []*TestProcessorNode, round uint64, nonce uint64) { log.Info("All shards propose blocks...") stepDelayAdjustment := StepDelay * time.Duration(1+len(nodes)/3) - for idx, n := range nodes { - if !IsIntInSlice(idx, idxProposers) { - continue - } - + for _, n := range leaders { body, header, _ := n.ProposeBlock(round, nonce) n.WhiteListBody(nodes, body) pk := n.NodeKeys.MainKey.Pk @@ -1168,13 +1165,13 @@ func ProposeBlock(nodes []*TestProcessorNode, idxProposers []int, round uint64, func SyncBlock( t *testing.T, nodes []*TestProcessorNode, - idxProposers []int, + leaders []*TestProcessorNode, round uint64, ) { log.Info("All other shard nodes sync the proposed block...") - for idx, n := range nodes { - if IsIntInSlice(idx, idxProposers) { + for _, n := range nodes { + if IsNodeInSlice(n, leaders) { continue } @@ -1190,10 +1187,9 @@ func SyncBlock( log.Info("Synchronized block\n" + MakeDisplayTable(nodes)) } -// IsIntInSlice returns true if idx is found on any position in the provided slice -func IsIntInSlice(idx int, slice []int) bool { +func IsNodeInSlice(node *TestProcessorNode, slice []*TestProcessorNode) bool { for _, value := range slice { - if value == idx { + if value == node { return true } } @@ -2240,14 +2236,14 @@ func generateValidTx( func ProposeAndSyncOneBlock( t *testing.T, nodes []*TestProcessorNode, - idxProposers []int, + leaders []*TestProcessorNode, round uint64, nonce uint64, ) (uint64, uint64) { UpdateRound(nodes, round) - ProposeBlock(nodes, idxProposers, round, nonce) - SyncBlock(t, nodes, idxProposers, round) + ProposeBlock(nodes, leaders, round, nonce) + SyncBlock(t, nodes, leaders, round) round = IncrementAndPrintRound(round) nonce++ @@ -2418,7 +2414,7 @@ func BootstrapDelay() { func SetupSyncNodesOneShardAndMeta( numNodesPerShard int, numNodesMeta int, -) ([]*TestProcessorNode, []int) { +) ([]*TestProcessorNode, []*TestProcessorNode) { maxShardsLocal := uint32(1) shardId := uint32(0) @@ -2435,7 +2431,7 @@ func SetupSyncNodesOneShardAndMeta( nodes = append(nodes, shardNode) connectableNodes = append(connectableNodes, shardNode) } - idxProposerShard0 := 0 + leaderShard0 := nodes[0] for i := 0; i < numNodesMeta; i++ { metaNode := NewTestProcessorNode(ArgTestProcessorNode{ @@ -2447,13 +2443,13 @@ func SetupSyncNodesOneShardAndMeta( nodes = append(nodes, metaNode) connectableNodes = append(connectableNodes, metaNode) } - idxProposerMeta := len(nodes) - 1 + leaderMeta := nodes[len(nodes)-1] - idxProposers := []int{idxProposerShard0, idxProposerMeta} + leaders := []*TestProcessorNode{leaderShard0, leaderMeta} ConnectNodes(connectableNodes) - return nodes, idxProposers + return nodes, leaders } // StartSyncingBlocks starts the syncing process of all the nodes @@ -2535,14 +2531,14 @@ func UpdateRound(nodes []*TestProcessorNode, round uint64) { func ProposeBlocks( nodes []*TestProcessorNode, round *uint64, - idxProposers []int, + leaders []*TestProcessorNode, nonces []*uint64, numOfRounds int, ) { for i := 0; i < numOfRounds; i++ { crtRound := atomic.LoadUint64(round) - proposeBlocks(nodes, idxProposers, nonces, crtRound) + proposeBlocks(nodes, leaders, nonces, crtRound) time.Sleep(SyncDelay) @@ -2563,20 +2559,20 @@ func IncrementNonces(nonces []*uint64) { func proposeBlocks( nodes []*TestProcessorNode, - idxProposers []int, + leaders []*TestProcessorNode, nonces []*uint64, crtRound uint64, ) { - for idx, proposer := range idxProposers { + for idx, 
proposer := range leaders { crtNonce := atomic.LoadUint64(nonces[idx]) - ProposeBlock(nodes, []int{proposer}, crtRound, crtNonce) + ProposeBlock(nodes, []*TestProcessorNode{proposer}, crtRound, crtNonce) } } // WaitOperationToBeDone - -func WaitOperationToBeDone(t *testing.T, nodes []*TestProcessorNode, nrOfRounds int, nonce uint64, round uint64, idxProposers []int) (uint64, uint64) { +func WaitOperationToBeDone(t *testing.T, leaders []*TestProcessorNode, nodes []*TestProcessorNode, nrOfRounds int, nonce uint64, round uint64) (uint64, uint64) { for i := 0; i < nrOfRounds; i++ { - round, nonce = ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } return nonce, round diff --git a/integrationTests/testNetwork.go b/integrationTests/testNetwork.go index a08b3aa85c7..f5e1e2b9dfd 100644 --- a/integrationTests/testNetwork.go +++ b/integrationTests/testNetwork.go @@ -8,12 +8,13 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/transaction" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) // ShardIdentifier is the numeric index of a shard @@ -44,7 +45,7 @@ type TestNetwork struct { NodesSharded NodesByShardMap Wallets []*TestWalletAccount DeploymentAddress Address - Proposers []int + Proposers []*TestProcessorNode Round uint64 Nonce uint64 T *testing.T @@ -119,11 +120,11 @@ func (net *TestNetwork) Step() { func (net *TestNetwork) Steps(steps int) { net.Nonce, net.Round = WaitOperationToBeDone( net.T, + net.Proposers, net.Nodes, steps, net.Nonce, - net.Round, - net.Proposers) + net.Round) } // Close shuts down the test network. 
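For readers following the proposer-index to leader-node migration in this diff, the sketch below shows the resulting calling convention. It is an editor-added illustration, not part of the change set: the advanceChain helper and its parameters are hypothetical, while the leaders construction and the leaders-first WaitOperationToBeDone signature mirror the code introduced above in testInitializer.go and used in testNetwork.go.

package example

import (
	"testing"

	"github.com/multiversx/mx-chain-go/integrationTests"
)

// advanceChain is a hypothetical helper (editor illustration only) showing the migrated
// convention: proposers are passed as *TestProcessorNode pointers ("leaders") instead of
// integer indices, and WaitOperationToBeDone takes the leaders slice before the nodes slice.
func advanceChain(
	t *testing.T,
	nodes []*integrationTests.TestProcessorNode,
	numOfShards int,
	nodesPerShard int,
	nrOfRounds int,
	nonce uint64,
	round uint64,
) (uint64, uint64) {
	// the first node of each shard acts as leader, plus the first metachain node at the end
	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
	for i := 0; i < numOfShards; i++ {
		leaders[i] = nodes[i*nodesPerShard]
	}
	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]

	// propose and sync nrOfRounds blocks; returns the updated nonce and round
	return integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrOfRounds, nonce, round)
}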
@@ -421,6 +422,8 @@ func (net *TestNetwork) createNodes() { StakingV2EnableEpoch: UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, + EquivalentMessagesEnableEpoch: UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: UnreachableEpoch, } net.Nodes = CreateNodesWithEnableEpochs( @@ -432,11 +435,11 @@ func (net *TestNetwork) createNodes() { } func (net *TestNetwork) indexProposers() { - net.Proposers = make([]int, net.NumShards+1) + net.Proposers = make([]*TestProcessorNode, net.NumShards+1) for i := 0; i < net.NumShards; i++ { - net.Proposers[i] = i * net.NodesPerShard + net.Proposers[i] = net.Nodes[i*net.NodesPerShard] } - net.Proposers[net.NumShards] = net.NumShards * net.NodesPerShard + net.Proposers[net.NumShards] = net.Nodes[net.NumShards*net.NodesPerShard] } func (net *TestNetwork) mapNodesByShard() { diff --git a/integrationTests/testProcessorNode.go b/integrationTests/testProcessorNode.go index 6eeec34c3c9..6416f8b6c7c 100644 --- a/integrationTests/testProcessorNode.go +++ b/integrationTests/testProcessorNode.go @@ -31,6 +31,10 @@ import ( ed25519SingleSig "github.com/multiversx/mx-chain-crypto-go/signing/ed25519/singlesig" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" mclsig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/singlesig" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/multiversx/mx-chain-vm-common-go/parsers" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/enablers" @@ -42,6 +46,7 @@ import ( "github.com/multiversx/mx-chain-go/consensus/spos/sposFactory" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/dataRetriever/factory/containers" requesterscontainer "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" @@ -78,6 +83,7 @@ import ( "github.com/multiversx/mx-chain-go/process/factory/shard" "github.com/multiversx/mx-chain-go/process/heartbeat/validator" "github.com/multiversx/mx-chain-go/process/interceptors" + interceptorsFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/peer" "github.com/multiversx/mx-chain-go/process/rating" @@ -104,7 +110,9 @@ import ( "github.com/multiversx/mx-chain-go/storage/txcache" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" + cacheMocks "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" dblookupextMock "github.com/multiversx/mx-chain-go/testscommon/dblookupext" @@ -128,9 +136,6 @@ import ( "github.com/multiversx/mx-chain-go/vm" vmProcess "github.com/multiversx/mx-chain-go/vm/process" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - 
"github.com/multiversx/mx-chain-vm-common-go/parsers" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" ) var zero = big.NewInt(0) @@ -304,6 +309,7 @@ type ArgTestProcessorNode struct { StatusMetrics external.StatusMetricsHandler WithPeersRatingHandler bool NodeOperationMode common.NodeOperation + Proofs dataRetriever.ProofsPool } // TestProcessorNode represents a container type of class used in integration tests @@ -330,6 +336,7 @@ type TestProcessorNode struct { TrieContainer common.TriesHolder BlockChain data.ChainHandler GenesisBlocks map[uint32]data.HeaderHandler + ProofsPool dataRetriever.ProofsPool EconomicsData *economics.TestEconomicsData RatingsData *rating.RatingsData @@ -466,8 +473,8 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { var peersRatingMonitor p2p.PeersRatingMonitor peersRatingMonitor = &p2pmocks.PeersRatingMonitorStub{} if args.WithPeersRatingHandler { - topRatedCache := testscommon.NewCacherMock() - badRatedCache := testscommon.NewCacherMock() + topRatedCache := cacheMocks.NewCacherMock() + badRatedCache := cacheMocks.NewCacherMock() peersRatingHandler, _ = p2pFactory.NewPeersRatingHandler( p2pFactory.ArgPeersRatingHandler{ TopRatedCache: topRatedCache, @@ -560,7 +567,7 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { tpn.HeaderSigVerifier = args.HeaderSigVerifier if check.IfNil(tpn.HeaderSigVerifier) { - tpn.HeaderSigVerifier = &mock.HeaderSigVerifierStub{} + tpn.HeaderSigVerifier = &consensusMocks.HeaderSigVerifierMock{} } tpn.HeaderIntegrityVerifier = args.HeaderIntegrityVerifier @@ -848,6 +855,7 @@ func (tpn *TestProcessorNode) initTestNodeWithArgs(args ArgTestProcessorNode) { tpn.NodeKeys.MainKey.Sk, tpn.MainMessenger.ID(), ), + config.ConsensusGradualBroadcastConfig{GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{}}, ) if args.WithSync { @@ -882,7 +890,7 @@ func (tpn *TestProcessorNode) createFullSCQueryService(gasMap map[string]map[str argsBuiltIn.AutomaticCrawlerAddresses = GenerateOneAddressPerShard(argsBuiltIn.ShardCoordinator) builtInFuncFactory, _ := builtInFunctions.CreateBuiltInFunctionsFactory(argsBuiltIn) - smartContractsCache := testscommon.NewCacherMock() + smartContractsCache := cacheMocks.NewCacherMock() argsHook := hooks.ArgBlockChainHook{ Accounts: tpn.AccntState, @@ -1073,6 +1081,7 @@ func (tpn *TestProcessorNode) InitializeProcessors(gasMap map[string]map[string] tpn.NodeKeys.MainKey.Sk, tpn.MainMessenger.ID(), ), + config.ConsensusGradualBroadcastConfig{GradualIndexBroadcastDelay: []config.IndexBroadcastDelay{}}, ) tpn.setGenesisBlock() tpn.initNode() @@ -1081,7 +1090,8 @@ func (tpn *TestProcessorNode) InitializeProcessors(gasMap map[string]map[string] } func (tpn *TestProcessorNode) initDataPools() { - tpn.DataPool = dataRetrieverMock.CreatePoolsHolder(1, tpn.ShardCoordinator.SelfId()) + tpn.ProofsPool = proofscache.NewProofsPool() + tpn.DataPool = dataRetrieverMock.CreatePoolsHolderWithProofsPool(1, tpn.ShardCoordinator.SelfId(), tpn.ProofsPool) cacherCfg := storageunit.CacheConfig{Capacity: 10000, Type: storageunit.LRUCache, Shards: 1} suCache, _ := storageunit.NewCache(cacherCfg) tpn.WhiteListHandler, _ = interceptors.NewWhiteListDataVerifier(suCache) @@ -1287,6 +1297,11 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { cryptoComponents.BlKeyGen = tpn.OwnAccount.KeygenBlockSign cryptoComponents.TxKeyGen = tpn.OwnAccount.KeygenTxSign + interceptorDataVerifierArgs := interceptorsFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: 
time.Second * 3, + CacheExpiry: time.Second * 10, + } + if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { argsEpochStart := &metachain.ArgsNewMetaEpochStartTrigger{ GenesisTime: tpn.RoundHandler.TimeStamp(), @@ -1309,36 +1324,37 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { coreComponents.HardforkTriggerPubKeyField = providedHardforkPk metaInterceptorContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: coreComponents, - CryptoComponents: cryptoComponents, - Accounts: tpn.AccntState, - ShardCoordinator: tpn.ShardCoordinator, - NodesCoordinator: tpn.NodesCoordinator, - MainMessenger: tpn.MainMessenger, - FullArchiveMessenger: tpn.FullArchiveMessenger, - Store: tpn.Storage, - DataPool: tpn.DataPool, - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: tpn.EconomicsData, - BlockBlackList: tpn.BlockBlackListHandler, - HeaderSigVerifier: tpn.HeaderSigVerifier, - HeaderIntegrityVerifier: tpn.HeaderIntegrityVerifier, - ValidityAttester: tpn.BlockTracker, - EpochStartTrigger: tpn.EpochStartTrigger, - WhiteListHandler: tpn.WhiteListHandler, - WhiteListerVerifiedTxs: tpn.WhiteListerVerifiedTxs, - AntifloodHandler: &mock.NilAntifloodHandler{}, - ArgumentsParser: smartContract.NewArgumentParser(), - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - SizeCheckDelta: sizeCheckDelta, - RequestHandler: tpn.RequestHandler, - PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, - SignaturesHandler: &processMock.SignaturesHandlerStub{}, - HeartbeatExpiryTimespanInSec: 30, - MainPeerShardMapper: tpn.MainPeerShardMapper, - FullArchivePeerShardMapper: tpn.FullArchivePeerShardMapper, - HardforkTrigger: tpn.HardforkTrigger, - NodeOperationMode: tpn.NodeOperationMode, + CoreComponents: coreComponents, + CryptoComponents: cryptoComponents, + Accounts: tpn.AccntState, + ShardCoordinator: tpn.ShardCoordinator, + NodesCoordinator: tpn.NodesCoordinator, + MainMessenger: tpn.MainMessenger, + FullArchiveMessenger: tpn.FullArchiveMessenger, + Store: tpn.Storage, + DataPool: tpn.DataPool, + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: tpn.EconomicsData, + BlockBlackList: tpn.BlockBlackListHandler, + HeaderSigVerifier: tpn.HeaderSigVerifier, + HeaderIntegrityVerifier: tpn.HeaderIntegrityVerifier, + ValidityAttester: tpn.BlockTracker, + EpochStartTrigger: tpn.EpochStartTrigger, + WhiteListHandler: tpn.WhiteListHandler, + WhiteListerVerifiedTxs: tpn.WhiteListerVerifiedTxs, + AntifloodHandler: &mock.NilAntifloodHandler{}, + ArgumentsParser: smartContract.NewArgumentParser(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + SizeCheckDelta: sizeCheckDelta, + RequestHandler: tpn.RequestHandler, + PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, + SignaturesHandler: &processMock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: tpn.MainPeerShardMapper, + FullArchivePeerShardMapper: tpn.FullArchivePeerShardMapper, + HardforkTrigger: tpn.HardforkTrigger, + NodeOperationMode: tpn.NodeOperationMode, + InterceptedDataVerifierFactory: interceptorsFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierArgs), } interceptorContainerFactory, _ := interceptorscontainer.NewMetaInterceptorsContainerFactory(metaInterceptorContainerFactoryArgs) @@ -1377,37 +1393,39 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { coreComponents.HardforkTriggerPubKeyField = providedHardforkPk shardIntereptorContainerFactoryArgs := 
interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: coreComponents, - CryptoComponents: cryptoComponents, - Accounts: tpn.AccntState, - ShardCoordinator: tpn.ShardCoordinator, - NodesCoordinator: tpn.NodesCoordinator, - MainMessenger: tpn.MainMessenger, - FullArchiveMessenger: tpn.FullArchiveMessenger, - Store: tpn.Storage, - DataPool: tpn.DataPool, - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: tpn.EconomicsData, - BlockBlackList: tpn.BlockBlackListHandler, - HeaderSigVerifier: tpn.HeaderSigVerifier, - HeaderIntegrityVerifier: tpn.HeaderIntegrityVerifier, - ValidityAttester: tpn.BlockTracker, - EpochStartTrigger: tpn.EpochStartTrigger, - WhiteListHandler: tpn.WhiteListHandler, - WhiteListerVerifiedTxs: tpn.WhiteListerVerifiedTxs, - AntifloodHandler: &mock.NilAntifloodHandler{}, - ArgumentsParser: smartContract.NewArgumentParser(), - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - SizeCheckDelta: sizeCheckDelta, - RequestHandler: tpn.RequestHandler, - PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, - SignaturesHandler: &processMock.SignaturesHandlerStub{}, - HeartbeatExpiryTimespanInSec: 30, - MainPeerShardMapper: tpn.MainPeerShardMapper, - FullArchivePeerShardMapper: tpn.FullArchivePeerShardMapper, - HardforkTrigger: tpn.HardforkTrigger, - NodeOperationMode: tpn.NodeOperationMode, + CoreComponents: coreComponents, + CryptoComponents: cryptoComponents, + Accounts: tpn.AccntState, + ShardCoordinator: tpn.ShardCoordinator, + NodesCoordinator: tpn.NodesCoordinator, + MainMessenger: tpn.MainMessenger, + FullArchiveMessenger: tpn.FullArchiveMessenger, + Store: tpn.Storage, + DataPool: tpn.DataPool, + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: tpn.EconomicsData, + BlockBlackList: tpn.BlockBlackListHandler, + HeaderSigVerifier: tpn.HeaderSigVerifier, + HeaderIntegrityVerifier: tpn.HeaderIntegrityVerifier, + ValidityAttester: tpn.BlockTracker, + EpochStartTrigger: tpn.EpochStartTrigger, + WhiteListHandler: tpn.WhiteListHandler, + WhiteListerVerifiedTxs: tpn.WhiteListerVerifiedTxs, + AntifloodHandler: &mock.NilAntifloodHandler{}, + ArgumentsParser: smartContract.NewArgumentParser(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + SizeCheckDelta: sizeCheckDelta, + RequestHandler: tpn.RequestHandler, + PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, + SignaturesHandler: &processMock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: tpn.MainPeerShardMapper, + FullArchivePeerShardMapper: tpn.FullArchivePeerShardMapper, + HardforkTrigger: tpn.HardforkTrigger, + NodeOperationMode: tpn.NodeOperationMode, + InterceptedDataVerifierFactory: interceptorsFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierArgs), } + interceptorContainerFactory, _ := interceptorscontainer.NewShardInterceptorsContainerFactory(shardIntereptorContainerFactoryArgs) tpn.MainInterceptorsContainer, tpn.FullArchiveInterceptorsContainer, err = interceptorContainerFactory.Create() @@ -2743,6 +2761,20 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod return nil, nil, nil } + previousProof := &dataBlock.HeaderProof{ + PubKeysBitmap: []byte{1}, + AggregatedSignature: sig, + HeaderHash: currHdrHash, + HeaderEpoch: currHdr.GetEpoch(), + HeaderNonce: currHdr.GetNonce(), + HeaderShardId: currHdr.GetShardID(), + } + blockHeader.SetPreviousProof(previousProof) + + _ = tpn.ProofsPool.AddProof(previousProof) + + log.Error("added proof", 
"currHdrHash", currHdrHash, "node", tpn.OwnAccount.Address) + genesisRound := tpn.BlockChain.GetGenesisHeader().GetRound() err = blockHeader.SetTimeStamp((round - genesisRound) * uint64(tpn.RoundHandler.TimeDuration().Seconds())) if err != nil { @@ -3028,17 +3060,19 @@ func (tpn *TestProcessorNode) initRequestedItemsHandler() { func (tpn *TestProcessorNode) initBlockTracker() { argBaseTracker := track.ArgBaseTracker{ - Hasher: TestHasher, - HeaderValidator: tpn.HeaderValidator, - Marshalizer: TestMarshalizer, - RequestHandler: tpn.RequestHandler, - RoundHandler: tpn.RoundHandler, - ShardCoordinator: tpn.ShardCoordinator, - Store: tpn.Storage, - StartHeaders: tpn.GenesisBlocks, - PoolsHolder: tpn.DataPool, - WhitelistHandler: tpn.WhiteListHandler, - FeeHandler: tpn.EconomicsData, + Hasher: TestHasher, + HeaderValidator: tpn.HeaderValidator, + Marshalizer: TestMarshalizer, + RequestHandler: tpn.RequestHandler, + RoundHandler: tpn.RoundHandler, + ShardCoordinator: tpn.ShardCoordinator, + Store: tpn.Storage, + StartHeaders: tpn.GenesisBlocks, + PoolsHolder: tpn.DataPool, + WhitelistHandler: tpn.WhiteListHandler, + FeeHandler: tpn.EconomicsData, + EnableEpochsHandler: tpn.EnableEpochsHandler, + ProofsPool: tpn.DataPool.Proofs(), } if tpn.ShardCoordinator.SelfId() != core.MetachainShardId { @@ -3066,7 +3100,7 @@ func (tpn *TestProcessorNode) initHeaderValidator() { } func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { - cacher := testscommon.NewCacherMock() + cacher := cacheMocks.NewCacherMock() psh, err := peerSignatureHandler.NewPeerSignatureHandler( cacher, tpn.OwnAccount.BlockSingleSigner, @@ -3253,6 +3287,8 @@ func CreateEnableEpochsConfig() config.EnableEpochs { MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, RefactorPeersMiniBlocksEnableEpoch: UnreachableEpoch, SCProcessorV2EnableEpoch: UnreachableEpoch, + EquivalentMessagesEnableEpoch: UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: UnreachableEpoch, } } @@ -3312,7 +3348,7 @@ func GetDefaultProcessComponents() *mock.ProcessComponentsStub { BlockProcess: &mock.BlockProcessorMock{}, BlackListHdl: &testscommon.TimeCacheStub{}, BootSore: &mock.BoostrapStorerMock{}, - HeaderSigVerif: &mock.HeaderSigVerifierStub{}, + HeaderSigVerif: &consensusMocks.HeaderSigVerifierMock{}, HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, ValidatorStatistics: &testscommon.ValidatorStatisticsProcessorStub{}, ValidatorProvider: &stakingcommon.ValidatorsProviderStub{}, @@ -3413,7 +3449,11 @@ func GetDefaultStatusComponents() *mock.StatusComponentsStub { func getDefaultBootstrapComponents(shardCoordinator sharding.Coordinator) *mainFactoryMocks.BootstrapComponentsStub { var versionedHeaderFactory nodeFactory.VersionedHeaderFactory - headerVersionHandler := &testscommon.HeaderVersionHandlerStub{} + headerVersionHandler := &testscommon.HeaderVersionHandlerStub{ + GetVersionCalled: func(epoch uint32) string { + return "2" + }, + } versionedHeaderFactory, _ = hdrFactory.NewShardHeaderFactory(headerVersionHandler) if shardCoordinator.SelfId() == core.MetachainShardId { versionedHeaderFactory, _ = hdrFactory.NewMetaHeaderFactory(headerVersionHandler) @@ -3526,9 +3566,9 @@ func getDefaultNodesSetup(maxShards, numNodes uint32, address []byte, pksBytes m func getDefaultNodesCoordinator(maxShards uint32, pksBytes map[uint32][]byte) nodesCoordinator.NodesCoordinator { return &shardingMocks.NodesCoordinatorStub{ - ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators 
[]nodesCoordinator.Validator, err error) { + ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pksBytes[shardId], 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, GetAllValidatorsPublicKeysCalled: func() (map[uint32][][]byte, error) { keys := make(map[uint32][][]byte) @@ -3559,5 +3599,18 @@ func GetDefaultEnableEpochsConfig() *config.EnableEpochs { StakingV4Step1EnableEpoch: UnreachableEpoch, StakingV4Step2EnableEpoch: UnreachableEpoch, StakingV4Step3EnableEpoch: UnreachableEpoch, + EquivalentMessagesEnableEpoch: UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: UnreachableEpoch, + } +} + +// GetDefaultRoundsConfig - +func GetDefaultRoundsConfig() config.RoundConfig { + return config.RoundConfig{ + RoundActivations: map[string]config.ActivationRoundByName{ + "DisableAsyncCallV1": { + Round: "18446744073709551615", + }, + }, } } diff --git a/integrationTests/testProcessorNodeWithMultisigner.go b/integrationTests/testProcessorNodeWithMultisigner.go index 80f2a183ad2..7c20b09f349 100644 --- a/integrationTests/testProcessorNodeWithMultisigner.go +++ b/integrationTests/testProcessorNodeWithMultisigner.go @@ -17,6 +17,7 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" mclmultisig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart/notifier" "github.com/multiversx/mx-chain-go/factory/peerSignatureHandler" @@ -181,6 +182,8 @@ func CreateNodeWithBLSAndTxKeys( ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, RefactorPeersMiniBlocksEnableEpoch: UnreachableEpoch, + EquivalentMessagesEnableEpoch: UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: UnreachableEpoch, } return CreateNode( @@ -242,6 +245,8 @@ func CreateNodesWithNodesCoordinatorFactory( StakingV4Step1EnableEpoch: UnreachableEpoch, StakingV4Step2EnableEpoch: UnreachableEpoch, StakingV4Step3EnableEpoch: UnreachableEpoch, + EquivalentMessagesEnableEpoch: UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: UnreachableEpoch, } nodesMap := make(map[uint32][]*TestProcessorNode) @@ -467,6 +472,8 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( SingleSigVerifier: signer, KeyGen: keyGen, FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + HeadersPool: &mock.HeadersCacherStub{}, } headerSig, _ := headerCheck.NewHeaderSigVerifier(&args) @@ -490,6 +497,8 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( StakingV2EnableEpoch: UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, + EquivalentMessagesEnableEpoch: UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: UnreachableEpoch, }, NodeKeys: cp.NodesKeys[shardId][i], NodesSetup: nodesSetup, @@ -608,6 +617,8 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( SingleSigVerifier: singleSigner, KeyGen: keyGenForBlocks, FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + HeadersPool: &mock.HeadersCacherStub{}, } headerSig, _ := 
headerCheck.NewHeaderSigVerifier(&args) @@ -626,6 +637,8 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( StakingV2EnableEpoch: UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, + EquivalentMessagesEnableEpoch: UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: UnreachableEpoch, }, NodeKeys: cp.NodesKeys[shardId][i], NodesSetup: nodesSetup, @@ -648,6 +661,15 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( return nodesMap } +// ProposeBlockData is a struct that holds some context data for the proposed block +type ProposeBlockData struct { + Body data.BodyHandler + Header data.HeaderHandler + Txs [][]byte + Leader *TestProcessorNode + ConsensusGroup []*TestProcessorNode +} + // ProposeBlockWithConsensusSignature proposes func ProposeBlockWithConsensusSignature( shardId uint32, @@ -656,39 +678,48 @@ func ProposeBlockWithConsensusSignature( nonce uint64, randomness []byte, epoch uint32, -) (data.BodyHandler, data.HeaderHandler, [][]byte, []*TestProcessorNode) { +) *ProposeBlockData { nodesCoordinatorInstance := nodesMap[shardId][0].NodesCoordinator - pubKeys, err := nodesCoordinatorInstance.GetConsensusValidatorsPublicKeys(randomness, round, shardId, epoch) + leaderPubKey, pubKeys, err := nodesCoordinatorInstance.GetConsensusValidatorsPublicKeys(randomness, round, shardId, epoch) if err != nil { log.Error("nodesCoordinator.GetConsensusValidatorsPublicKeys", "error", err) } // select nodes from map based on their pub keys - consensusNodes := selectTestNodesForPubKeys(nodesMap[shardId], pubKeys) + leaderNode, consensusNodes := selectTestNodesForPubKeys(nodesMap[shardId], leaderPubKey, pubKeys) // first node is block proposer - body, header, txHashes := consensusNodes[0].ProposeBlock(round, nonce) + body, header, txHashes := leaderNode.ProposeBlock(round, nonce) err = header.SetPrevRandSeed(randomness) if err != nil { log.Error("header.SetPrevRandSeed", "error", err) } - header = DoConsensusSigningOnBlock(header, consensusNodes, pubKeys) + header = DoConsensusSigningOnBlock(header, leaderNode, consensusNodes, pubKeys) - return body, header, txHashes, consensusNodes + return &ProposeBlockData{ + Body: body, + Header: header, + Txs: txHashes, + Leader: leaderNode, + ConsensusGroup: consensusNodes, + } } -func selectTestNodesForPubKeys(nodes []*TestProcessorNode, pubKeys []string) []*TestProcessorNode { +func selectTestNodesForPubKeys(nodes []*TestProcessorNode, leaderPubKey string, pubKeys []string) (*TestProcessorNode, []*TestProcessorNode) { selectedNodes := make([]*TestProcessorNode, len(pubKeys)) cntNodes := 0 - + var leaderNode *TestProcessorNode for i, pk := range pubKeys { - for _, node := range nodes { + for j, node := range nodes { pubKeyBytes, _ := node.NodeKeys.MainKey.Pk.ToByteArray() if bytes.Equal(pubKeyBytes, []byte(pk)) { - selectedNodes[i] = node + selectedNodes[i] = nodes[j] cntNodes++ } + if string(pubKeyBytes) == leaderPubKey { + leaderNode = nodes[j] + } } } @@ -696,12 +727,13 @@ func selectTestNodesForPubKeys(nodes []*TestProcessorNode, pubKeys []string) []* fmt.Println("Error selecting nodes from public keys") } - return selectedNodes + return leaderNode, selectedNodes } -// DoConsensusSigningOnBlock simulates a consensus aggregated signature on the provided block +// DoConsensusSigningOnBlock simulates a ConsensusGroup aggregated signature on the provided block func DoConsensusSigningOnBlock( blockHeader data.HeaderHandler, + leaderNode *TestProcessorNode, consensusNodes 
[]*TestProcessorNode, pubKeys []string, ) data.HeaderHandler { @@ -732,7 +764,7 @@ func DoConsensusSigningOnBlock( pubKeysBytes := make([][]byte, len(consensusNodes)) sigShares := make([][]byte, len(consensusNodes)) - msig := consensusNodes[0].MultiSigner + msig := leaderNode.MultiSigner for i := 0; i < len(consensusNodes); i++ { pubKeysBytes[i] = []byte(pubKeys[i]) @@ -759,20 +791,14 @@ func DoConsensusSigningOnBlock( return blockHeader } -// AllShardsProposeBlock simulates each shard selecting a consensus group and proposing/broadcasting/committing a block +// AllShardsProposeBlock simulates each shard selecting a ConsensusGroup group and proposing/broadcasting/committing a block func AllShardsProposeBlock( round uint64, nonce uint64, nodesMap map[uint32][]*TestProcessorNode, -) ( - map[uint32]data.BodyHandler, - map[uint32]data.HeaderHandler, - map[uint32][]*TestProcessorNode, -) { +) map[uint32]*ProposeBlockData { - body := make(map[uint32]data.BodyHandler) - header := make(map[uint32]data.HeaderHandler) - consensusNodes := make(map[uint32][]*TestProcessorNode) + proposalData := make(map[uint32]*ProposeBlockData) newRandomness := make(map[uint32][]byte) nodesList := make([]*TestProcessorNode, 0) @@ -790,34 +816,36 @@ func AllShardsProposeBlock( // TODO: remove if start of epoch block needs to be validated by the new epoch nodes epoch := currentBlockHeader.GetEpoch() prevRandomness := currentBlockHeader.GetRandSeed() - body[i], header[i], _, consensusNodes[i] = ProposeBlockWithConsensusSignature( + proposalData[i] = ProposeBlockWithConsensusSignature( i, nodesMap, round, nonce, prevRandomness, epoch, ) - nodesMap[i][0].WhiteListBody(nodesList, body[i]) - newRandomness[i] = header[i].GetRandSeed() + proposalData[i].Leader.WhiteListBody(nodesList, proposalData[i].Body) + newRandomness[i] = proposalData[i].Header.GetRandSeed() } // propagate blocks for i := range nodesMap { - pk := consensusNodes[i][0].NodeKeys.MainKey.Pk - consensusNodes[i][0].BroadcastBlock(body[i], header[i], pk) - consensusNodes[i][0].CommitBlock(body[i], header[i]) + leader := proposalData[i].Leader + pk := proposalData[i].Leader.NodeKeys.MainKey.Pk + leader.BroadcastBlock(proposalData[i].Body, proposalData[i].Header, pk) + leader.CommitBlock(proposalData[i].Body, proposalData[i].Header) } time.Sleep(2 * StepDelay) - return body, header, consensusNodes + return proposalData } // SyncAllShardsWithRoundBlock enforces all nodes in each shard synchronizing the block for the given round func SyncAllShardsWithRoundBlock( t *testing.T, + proposalData map[uint32]*ProposeBlockData, nodesMap map[uint32][]*TestProcessorNode, - indexProposers map[uint32]int, round uint64, ) { - for shard, nodeList := range nodesMap { - SyncBlock(t, nodeList, []int{indexProposers[shard]}, round) + for shard, nodesList := range nodesMap { + proposal := proposalData[shard] + SyncBlock(t, nodesList, []*TestProcessorNode{proposal.Leader}, round) } time.Sleep(4 * StepDelay) } diff --git a/integrationTests/testProcessorNodeWithTestWebServer.go b/integrationTests/testProcessorNodeWithTestWebServer.go index 592d7d1bdba..b7d05e76f4c 100644 --- a/integrationTests/testProcessorNodeWithTestWebServer.go +++ b/integrationTests/testProcessorNodeWithTestWebServer.go @@ -7,6 +7,10 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "github.com/multiversx/mx-chain-vm-common-go/parsers" + datafield "github.com/multiversx/mx-chain-vm-common-go/parsers/dataField" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + 
"github.com/multiversx/mx-chain-go/api/groups" "github.com/multiversx/mx-chain-go/api/shared" "github.com/multiversx/mx-chain-go/config" @@ -22,13 +26,11 @@ import ( "github.com/multiversx/mx-chain-go/process/transactionEvaluator" "github.com/multiversx/mx-chain-go/process/txstatus" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - "github.com/multiversx/mx-chain-vm-common-go/parsers" - datafield "github.com/multiversx/mx-chain-vm-common-go/parsers/dataField" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" ) // TestProcessorNodeWithTestWebServer represents a TestProcessorNode with a test web server @@ -177,7 +179,7 @@ func createFacadeComponents(tpn *TestProcessorNode) nodeFacade.ApiResolver { ShardCoordinator: tpn.ShardCoordinator, Marshalizer: TestMarshalizer, Hasher: TestHasher, - VMOutputCacher: &testscommon.CacherMock{}, + VMOutputCacher: &cache.CacherMock{}, DataFieldParser: dataFieldParser, BlockChainHook: tpn.BlockchainHook, } diff --git a/integrationTests/testSyncNode.go b/integrationTests/testSyncNode.go index b28d5e3f953..31c2ac46111 100644 --- a/integrationTests/testSyncNode.go +++ b/integrationTests/testSyncNode.go @@ -176,6 +176,7 @@ func (tpn *TestProcessorNode) createShardBootstrapper() (TestBootstrapper, error ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessWaitTime: tpn.RoundHandler.TimeDuration(), RepopulateTokensSupplies: false, + EnableEpochsHandler: tpn.EnableEpochsHandler, } argsShardBootstrapper := sync.ArgShardBootstrapper{ @@ -222,6 +223,7 @@ func (tpn *TestProcessorNode) createMetaChainBootstrapper() (TestBootstrapper, e ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessWaitTime: tpn.RoundHandler.TimeDuration(), RepopulateTokensSupplies: false, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } argsMetaBootstrapper := sync.ArgMetaBootstrapper{ diff --git a/integrationTests/vm/delegation/delegation_test.go b/integrationTests/vm/delegation/delegation_test.go index 9bae5235076..3b766314ccc 100644 --- a/integrationTests/vm/delegation/delegation_test.go +++ b/integrationTests/vm/delegation/delegation_test.go @@ -7,16 +7,16 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/integrationTests/multiShard/endOfEpoch" integrationTestsVm "github.com/multiversx/mx-chain-go/integrationTests/vm" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestDelegationSystemSCWithValidatorStatisticsAndStakingPhase3p5(t *testing.T) { @@ -263,17 +263,14 @@ func processBlocks( blockToProduce uint64, nodesMap map[uint32][]*integrationTests.TestProcessorNode, ) (uint64, uint64) { - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode - for i := uint64(0); i 
< blockToProduce; i++ { for _, nodesSlice := range nodesMap { integrationTests.UpdateRound(nodesSlice, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodesSlice) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := endOfEpoch.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ diff --git a/integrationTests/vm/esdt/common.go b/integrationTests/vm/esdt/common.go index 0d3a798d592..a8e13b5e83a 100644 --- a/integrationTests/vm/esdt/common.go +++ b/integrationTests/vm/esdt/common.go @@ -9,6 +9,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/esdt" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" testVm "github.com/multiversx/mx-chain-go/integrationTests/vm" @@ -19,8 +22,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) // GetESDTTokenData - @@ -91,12 +92,12 @@ func SetRolesWithSenderAccount(nodes []*integrationTests.TestProcessorNode, issu func DeployNonPayableSmartContract( t *testing.T, nodes []*integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nonce *uint64, round *uint64, fileName string, ) []byte { - return DeployNonPayableSmartContractFromNode(t, nodes, 0, idxProposers, nonce, round, fileName) + return DeployNonPayableSmartContractFromNode(t, nodes, 0, leaders, nonce, round, fileName) } // DeployNonPayableSmartContractFromNode - @@ -104,7 +105,7 @@ func DeployNonPayableSmartContractFromNode( t *testing.T, nodes []*integrationTests.TestProcessorNode, idDeployer int, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nonce *uint64, round *uint64, fileName string, @@ -121,7 +122,7 @@ func DeployNonPayableSmartContractFromNode( integrationTests.AdditionalGasLimit, ) - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, 4, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, *nonce, *round) scShardID := nodes[0].ShardCoordinator.ComputeId(scAddress) for _, node := range nodes { @@ -165,11 +166,13 @@ func CheckAddressHasTokens( } // CreateNodesAndPrepareBalances - -func CreateNodesAndPrepareBalances(numOfShards int) ([]*integrationTests.TestProcessorNode, []int) { +func CreateNodesAndPrepareBalances(numOfShards int) ([]*integrationTests.TestProcessorNode, []*integrationTests.TestProcessorNode) { enableEpochs := config.EnableEpochs{ OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } roundsConfig := testscommon.GetDefaultRoundsConfig() return CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig( @@ -180,7 +183,11 @@ 
func CreateNodesAndPrepareBalances(numOfShards int) ([]*integrationTests.TestPro } // CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig - -func CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig(numOfShards int, enableEpochs config.EnableEpochs, roundsConfig config.RoundConfig) ([]*integrationTests.TestProcessorNode, []int) { +func CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig( + numOfShards int, + enableEpochs config.EnableEpochs, + roundsConfig config.RoundConfig, +) ([]*integrationTests.TestProcessorNode, []*integrationTests.TestProcessorNode) { nodesPerShard := 1 numMetachainNodes := 1 @@ -198,14 +205,14 @@ func CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig(numOfShards int, ena }, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) - return nodes, idxProposers + return nodes, leaders } // IssueNFT - @@ -387,7 +394,7 @@ func PrepareFungibleTokensWithLocalBurnAndMint( t *testing.T, nodes []*integrationTests.TestProcessorNode, addressWithRoles []byte, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, round *uint64, nonce *uint64, ) string { @@ -396,7 +403,7 @@ func PrepareFungibleTokensWithLocalBurnAndMint( nodes, nodes[0].OwnAccount, addressWithRoles, - idxProposers, + leaders, round, nonce) } @@ -407,7 +414,7 @@ func PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount( nodes []*integrationTests.TestProcessorNode, issuerAccount *integrationTests.TestWalletAccount, addressWithRoles []byte, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, round *uint64, nonce *uint64, ) string { @@ -415,7 +422,7 @@ func PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("TKN"))) @@ -424,7 +431,7 @@ func PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount( time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) return tokenIdentifier diff --git a/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go b/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go index 742531fb801..a33b882a58c 100644 --- a/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go +++ b/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go @@ -7,17 +7,18 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" esdtCommon "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" - "github.com/stretchr/testify/assert" ) func 
TestESDTLocalMintAndBurnFromSC(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -33,9 +34,9 @@ func TestESDTLocalMintAndBurnFromSC(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") + scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") - esdtLocalMintAndBurnFromSCRunTestsAndAsserts(t, nodes, nodes[0].OwnAccount, scAddress, idxProposers, nonce, round) + esdtLocalMintAndBurnFromSCRunTestsAndAsserts(t, nodes, nodes[0].OwnAccount, scAddress, leaders, nonce, round) } func esdtLocalMintAndBurnFromSCRunTestsAndAsserts( @@ -43,11 +44,11 @@ func esdtLocalMintAndBurnFromSCRunTestsAndAsserts( nodes []*integrationTests.TestProcessorNode, ownerWallet *integrationTests.TestWalletAccount, scAddress []byte, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nonce uint64, round uint64, ) { - tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount(t, nodes, ownerWallet, scAddress, idxProposers, &nonce, &round) + tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount(t, nodes, ownerWallet, scAddress, leaders, &nonce, &round) txData := []byte("localMint" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(100).Bytes())) @@ -72,7 +73,7 @@ func esdtLocalMintAndBurnFromSCRunTestsAndAsserts( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 2 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddress, nodes, []byte(tokenIdentifier), 0, 200) @@ -99,7 +100,7 @@ func esdtLocalMintAndBurnFromSCRunTestsAndAsserts( ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddress, nodes, []byte(tokenIdentifier), 0, 100) @@ -109,7 +110,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -125,7 +126,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") + scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") issuePrice := big.NewInt(1000) txData := []byte("issueFungibleToken" + "@" + hex.EncodeToString([]byte("TOKEN")) + @@ -141,7 +142,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { 
time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("TKR"))) @@ -157,7 +158,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) txData = []byte("localMint" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + @@ -180,7 +181,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddress, nodes, []byte(tokenIdentifier), 0, 201) @@ -205,7 +206,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddress, nodes, []byte(tokenIdentifier), 0, 101) @@ -215,7 +216,7 @@ func TestESDTSetTransferRoles(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(2) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(2) defer func() { for _, n := range nodes { @@ -231,14 +232,14 @@ func TestESDTSetTransferRoles(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/use-module.wasm") + scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/use-module.wasm") nrRoundsToPropagateMultiShard := 12 - tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMint(t, nodes, scAddress, idxProposers, &nonce, &round) + tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMint(t, nodes, scAddress, leaders, &nonce, &round) esdtCommon.SetRoles(nodes, scAddress, []byte(tokenIdentifier), [][]byte{[]byte(core.ESDTRoleTransfer)}) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) destAddress := nodes[1].OwnAccount.Address @@ -256,7 +257,7 @@ func TestESDTSetTransferRoles(t *testing.T) { integrationTests.AdditionalGasLimit+core.MinMetaTxExtraGasCost, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, 
destAddress, nodes, []byte(tokenIdentifier), 0, amount) @@ -279,7 +280,7 @@ func TestESDTSetTransferRolesForwardAsyncCallFailsCross(t *testing.T) { } func testESDTWithTransferRoleAndForwarder(t *testing.T, numShards int) { - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(numShards) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(numShards) defer func() { for _, n := range nodes { @@ -295,15 +296,15 @@ func testESDTWithTransferRoleAndForwarder(t *testing.T, numShards int) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddressA := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/use-module.wasm") - scAddressB := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 1, idxProposers, &nonce, &round, "../testdata/use-module.wasm") + scAddressA := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/use-module.wasm") + scAddressB := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 1, leaders, &nonce, &round, "../testdata/use-module.wasm") nrRoundsToPropagateMultiShard := 12 - tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMint(t, nodes, scAddressA, idxProposers, &nonce, &round) + tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMint(t, nodes, scAddressA, leaders, &nonce, &round) esdtCommon.SetRoles(nodes, scAddressA, []byte(tokenIdentifier), [][]byte{[]byte(core.ESDTRoleTransfer)}) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) amount := int64(100) @@ -319,7 +320,7 @@ func testESDTWithTransferRoleAndForwarder(t *testing.T, numShards int) { integrationTests.AdditionalGasLimit+core.MinMetaTxExtraGasCost, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 15, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 15, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddressB, nodes, []byte(tokenIdentifier), 0, 0) @@ -344,7 +345,7 @@ func TestAsyncCallsAndCallBacksArgumentsCross(t *testing.T) { } func testAsyncCallAndCallBacksArguments(t *testing.T, numShards int) { - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(numShards) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(numShards) defer func() { for _, n := range nodes { n.Close() @@ -359,8 +360,8 @@ func testAsyncCallAndCallBacksArguments(t *testing.T, numShards int) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddressA := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 0, idxProposers, &nonce, &round, "forwarder.wasm") - scAddressB := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 1, idxProposers, &nonce, &round, "vault.wasm") + scAddressA := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 0, leaders, &nonce, &round, "forwarder.wasm") + scAddressB := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 1, leaders, &nonce, &round, "vault.wasm") txData := txDataBuilder.NewBuilder() txData.Clear().Func("echo_args_async").Bytes(scAddressB).Str("AA").Str("BB") @@ -374,7 +375,7 @@ func testAsyncCallAndCallBacksArguments(t *testing.T, numShards int) { integrationTests.AdditionalGasLimit+core.MinMetaTxExtraGasCost, ) 
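To summarize the mechanical change applied throughout these tests, here is a small sketch of the new call shape; the helper name driveRounds and the surrounding setup values are assumptions, while the leaders wiring and the WaitOperationToBeDone(t, leaders, nodes, ...) order mirror the hunks above:

```go
package example_test // hypothetical, for illustration only

import (
	"testing"

	"github.com/multiversx/mx-chain-go/integrationTests"
)

// driveRounds shows the leaders slice that replaces the old idxProposers indexes.
func driveRounds(
	t *testing.T,
	nodes []*integrationTests.TestProcessorNode,
	numOfShards int,
	nodesPerShard int,
	nonce uint64,
	round uint64,
) (uint64, uint64) {
	// one leader per shard plus one for the metachain: the first node of each group
	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
	for i := 0; i < numOfShards; i++ {
		leaders[i] = nodes[i*nodesPerShard]
	}
	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]

	// leaders now come before the full node list, replacing the idxProposers argument
	nrRoundsToPropagateMultiShard := 5
	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)

	return nonce, round
}
```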
time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 15, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 15, nonce, round) time.Sleep(time.Second) callbackArgs := append([]byte("success"), []byte{0}...) @@ -391,7 +392,7 @@ func testAsyncCallAndCallBacksArguments(t *testing.T, numShards int) { integrationTests.AdditionalGasLimit+core.MinMetaTxExtraGasCost, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 15, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 15, nonce, round) time.Sleep(time.Second) checkDataFromAccountAndKey(t, nodes, scAddressA, []byte("callbackStorage"), append([]byte("success"), []byte{0}...)) diff --git a/integrationTests/vm/esdt/multisign/esdtMultisign_test.go b/integrationTests/vm/esdt/multisign/esdtMultisign_test.go index fd8e0b6fbb8..8a82988663a 100644 --- a/integrationTests/vm/esdt/multisign/esdtMultisign_test.go +++ b/integrationTests/vm/esdt/multisign/esdtMultisign_test.go @@ -8,14 +8,15 @@ import ( "testing" "time" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" - "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/vm" logger "github.com/multiversx/mx-chain-logger-go" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/vm" ) var vmType = []byte{5, 0} @@ -37,11 +38,11 @@ func TestESDTTransferWithMultisig(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -63,7 +64,7 @@ func TestESDTTransferWithMultisig(t *testing.T) { time.Sleep(time.Second) numRoundsToPropagateIntraShard := 2 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateIntraShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateIntraShard, nonce, round) time.Sleep(time.Second) // ----- issue ESDT token @@ -72,7 +73,7 @@ func TestESDTTransferWithMultisig(t *testing.T) { proposeIssueTokenAndTransferFunds(nodes, multisignContractAddress, initalSupply, 0, ticker) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateIntraShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateIntraShard, nonce, round) time.Sleep(time.Second) actionID := getActionID(t, nodes, multisignContractAddress) @@ -82,13 +83,13 @@ func TestESDTTransferWithMultisig(t *testing.T) { time.Sleep(time.Second) numRoundsToPropagateCrossShard := 10 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateCrossShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateCrossShard, 
nonce, round) time.Sleep(time.Second) performActionID(nodes, multisignContractAddress, actionID, 0) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateCrossShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateCrossShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := integrationTests.GetTokenIdentifier(nodes, []byte(ticker)) @@ -102,7 +103,7 @@ func TestESDTTransferWithMultisig(t *testing.T) { proposeTransferToken(nodes, multisignContractAddress, transferValue, 0, destinationAddress, tokenIdentifier) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateIntraShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateIntraShard, nonce, round) time.Sleep(time.Second) actionID = getActionID(t, nodes, multisignContractAddress) @@ -111,13 +112,13 @@ func TestESDTTransferWithMultisig(t *testing.T) { boardMembersSignActionID(nodes, multisignContractAddress, actionID, 1, 2) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateCrossShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateCrossShard, nonce, round) time.Sleep(time.Second) performActionID(nodes, multisignContractAddress, actionID, 0) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateCrossShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateCrossShard, nonce, round) time.Sleep(time.Second) expectedBalance := big.NewInt(0).Set(initalSupply) diff --git a/integrationTests/vm/esdt/nft/common.go b/integrationTests/vm/esdt/nft/common.go index 6df8dc7dd69..23cd837ba3a 100644 --- a/integrationTests/vm/esdt/nft/common.go +++ b/integrationTests/vm/esdt/nft/common.go @@ -8,9 +8,10 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" - "github.com/stretchr/testify/require" ) // NftArguments - @@ -70,7 +71,7 @@ func CheckNftData( func PrepareNFTWithRoles( t *testing.T, nodes []*integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nftCreator *integrationTests.TestProcessorNode, round *uint64, nonce *uint64, @@ -82,7 +83,7 @@ func PrepareNFTWithRoles( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 10 - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("SFT"))) @@ -91,7 +92,7 @@ func PrepareNFTWithRoles( esdt.SetRoles(nodes, nftCreator.OwnAccount.Address, []byte(tokenIdentifier), roles) time.Sleep(time.Second) - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) nftMetaData := NftArguments{ @@ 
-105,7 +106,7 @@ func PrepareNFTWithRoles( CreateNFT([]byte(tokenIdentifier), nftCreator, nodes, &nftMetaData) time.Sleep(time.Second) - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, 3, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, *nonce, *round) time.Sleep(time.Second) CheckNftData( diff --git a/integrationTests/vm/esdt/nft/esdtNFT/esdtNft_test.go b/integrationTests/vm/esdt/nft/esdtNFT/esdtNft_test.go index a1db92372bd..c35e513b357 100644 --- a/integrationTests/vm/esdt/nft/esdtNFT/esdtNft_test.go +++ b/integrationTests/vm/esdt/nft/esdtNFT/esdtNft_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt/nft" @@ -29,11 +30,11 @@ func TestESDTNonFungibleTokenCreateAndBurn(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -59,7 +60,7 @@ func TestESDTNonFungibleTokenCreateAndBurn(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[1], &round, &nonce, @@ -85,7 +86,7 @@ func TestESDTNonFungibleTokenCreateAndBurn(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // the token data is removed from trie if the quantity is 0, so we should not find it @@ -116,11 +117,11 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -148,7 +149,7 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[1], &round, &nonce, @@ -174,7 +175,7 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -190,7 +191,7 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = 
integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -219,7 +220,7 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity -= quantityToBurn @@ -249,11 +250,11 @@ func TestESDTNonFungibleTokenTransferSelfShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -278,7 +279,7 @@ func TestESDTNonFungibleTokenTransferSelfShard(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[1], &round, &nonce, @@ -315,7 +316,7 @@ func TestESDTNonFungibleTokenTransferSelfShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // check that the new address owns the NFT @@ -357,11 +358,11 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -398,7 +399,7 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodeInDifferentShard, &round, &nonce, @@ -424,7 +425,7 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -440,7 +441,7 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -469,7 +470,7 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 11 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 
nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = initialQuantity + quantityToAdd - quantityToTransfer @@ -510,11 +511,11 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -542,7 +543,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[0], &round, &nonce, @@ -568,7 +569,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -584,7 +585,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -613,7 +614,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 15 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = 0 // make sure that the ESDT SC address didn't receive the token @@ -640,7 +641,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test } func testNFTSendCreateRole(t *testing.T, numOfShards int) { - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(numOfShards) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(numOfShards) defer func() { for _, n := range nodes { @@ -665,7 +666,7 @@ func testNFTSendCreateRole(t *testing.T, numOfShards int) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nftCreator, &round, &nonce, @@ -698,7 +699,7 @@ func testNFTSendCreateRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 20 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CreateNFT( @@ -710,7 +711,7 @@ func testNFTSendCreateRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 2 
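These legacy ESDT tests keep the newly introduced activation epochs inactive; a minimal sketch of that setup follows, where the wrapper name createLegacyConsensusNodes and the shard counts are assumptions and only the EnableEpochs fields and CreateNodesWithEnableEpochs call come from the patch:

```go
package example_test // hypothetical, for illustration only

import (
	"github.com/multiversx/mx-chain-go/config"
	"github.com/multiversx/mx-chain-go/integrationTests"
)

// createLegacyConsensusNodes builds nodes with the new consensus-related epochs disabled.
func createLegacyConsensusNodes(numOfShards, nodesPerShard, numMetachainNodes int) []*integrationTests.TestProcessorNode {
	enableEpochs := config.EnableEpochs{
		OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch,
		ScheduledMiniBlocksEnableEpoch:              integrationTests.UnreachableEpoch,
		MiniBlockPartialExecutionEnableEpoch:        integrationTests.UnreachableEpoch,
		// keep the equivalent-messages and fixed-order consensus changes inactive here
		EquivalentMessagesEnableEpoch:    integrationTests.UnreachableEpoch,
		FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch,
	}

	return integrationTests.CreateNodesWithEnableEpochs(
		numOfShards,
		nodesPerShard,
		numMetachainNodes,
		enableEpochs,
	)
}
```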
- nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -766,11 +767,11 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -808,7 +809,7 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodeInDifferentShard, &round, &nonce, @@ -834,7 +835,7 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -850,7 +851,7 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -879,7 +880,7 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 11 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = initialQuantity + quantityToAdd - quantityToTransfer @@ -920,11 +921,11 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -954,7 +955,7 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, tokenIssuer, &round, &nonce, @@ -980,7 +981,7 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 2 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += 
quantityToAdd @@ -1013,7 +1014,7 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = initialQuantity + quantityToAdd - int64(len(nodes)-1)*quantityToTransfer @@ -1056,7 +1057,7 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = initialQuantity + quantityToAdd @@ -1101,7 +1102,7 @@ func TestNFTTransferCreateAndSetRolesCrossShard(t *testing.T) { } func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(numOfShards) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(numOfShards) defer func() { for _, n := range nodes { @@ -1126,7 +1127,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nftCreator, &round, &nonce, @@ -1158,7 +1159,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 15, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 15, nonce, round) time.Sleep(time.Second) // stopNFTCreate @@ -1173,7 +1174,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) // setCreateRole @@ -1190,7 +1191,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 20, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 20, nonce, round) time.Sleep(time.Second) newNFTMetaData := nft.NftArguments{ @@ -1210,7 +1211,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) // we check that old data remains on NONCE 1 - as creation must return failure diff --git a/integrationTests/vm/esdt/nft/esdtNFTSCs/esdtNFTSCs_test.go b/integrationTests/vm/esdt/nft/esdtNFTSCs/esdtNFTSCs_test.go index 534c1c7435e..a1c3b524c9f 100644 --- a/integrationTests/vm/esdt/nft/esdtNFTSCs/esdtNFTSCs_test.go +++ b/integrationTests/vm/esdt/nft/esdtNFTSCs/esdtNFTSCs_test.go @@ -7,17 +7,18 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" 
"github.com/multiversx/mx-chain-go/integrationTests/vm/esdt/nft" - "github.com/stretchr/testify/require" ) func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -33,7 +34,7 @@ func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, idxProposers, &nonce, &round, "nftIssue", "@03@05") + scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, leaders, &nonce, &round, "nftIssue", "@03@05") txData := []byte("nftCreate" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString([]byte("name")) + @@ -65,7 +66,7 @@ func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 3, big.NewInt(1)) @@ -92,7 +93,7 @@ func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 2, big.NewInt(1)) @@ -123,7 +124,7 @@ func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, destinationAddress, nodes, []byte(tokenIdentifier), 2, big.NewInt(1)) @@ -136,7 +137,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -152,7 +153,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, idxProposers, &nonce, &round, "sftIssue", "@03@04@05") + scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, leaders, &nonce, &round, "sftIssue", "@03@04@05") txData := []byte("nftCreate" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString([]byte("name")) + @@ -179,7 +180,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) 
checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(11)) @@ -204,7 +205,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(9)) @@ -234,7 +235,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, destinationAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(9)) @@ -245,7 +246,7 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -261,7 +262,7 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, idxProposers, &nonce, &round, "nftIssue", "@03@05") + scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, leaders, &nonce, &round, "nftIssue", "@03@05") txData := []byte("nftCreate" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString([]byte("name")) + @@ -285,13 +286,13 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 2, big.NewInt(1)) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) - destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../../testdata/nft-receiver.wasm") + destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../../testdata/nft-receiver.wasm") txData = []byte("transferNftViaAsyncCall" + "@" + hex.EncodeToString(destinationSCAddress) + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString([]byte("wrongFunctionToCall"))) @@ -316,7 +317,7 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. 
integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(0)) @@ -348,7 +349,7 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -361,7 +362,7 @@ func TestESDTTransferNFTToSCIntraShard(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -384,7 +385,7 @@ func TestESDTTransferNFTToSCIntraShard(t *testing.T) { tokenIdentifier, _ := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[0], &round, &nonce, @@ -395,7 +396,7 @@ func TestESDTTransferNFTToSCIntraShard(t *testing.T) { nonceArg := hex.EncodeToString(big.NewInt(0).SetUint64(1).Bytes()) quantityToTransfer := hex.EncodeToString(big.NewInt(1).Bytes()) - destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../../testdata/nft-receiver.wasm") + destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../../testdata/nft-receiver.wasm") txData := core.BuiltInFunctionESDTNFTTransfer + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + nonceArg + "@" + quantityToTransfer + "@" + hex.EncodeToString(destinationSCAddress) + "@" + hex.EncodeToString([]byte("acceptAndReturnCallData")) integrationTests.CreateAndSendTransaction( @@ -408,7 +409,7 @@ func TestESDTTransferNFTToSCIntraShard(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, nodes[0].OwnAccount.Address, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -418,7 +419,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(2) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(2) defer func() { for _, n := range nodes { @@ -434,7 +435,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../../testdata/nft-receiver.wasm") + destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../../testdata/nft-receiver.wasm") destinationSCShardID := nodes[0].ShardCoordinator.ComputeId(destinationSCAddress) @@ -454,7 +455,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { tokenIdentifier, _ := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodeFromOtherShard, &round, &nonce, @@ -478,7 +479,7 @@ func 
TestESDTTransferNFTToSCCrossShard(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, nodeFromOtherShard.OwnAccount.Address, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -495,7 +496,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, nodeFromOtherShard.OwnAccount.Address, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -512,7 +513,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, nodeFromOtherShard.OwnAccount.Address, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -521,13 +522,13 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { func deployAndIssueNFTSFTThroughSC( t *testing.T, nodes []*integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nonce *uint64, round *uint64, issueFunc string, rolesEncoded string, ) ([]byte, string) { - scAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, nonce, round, "../../testdata/local-esdt-and-nft.wasm") + scAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, nonce, round, "../../testdata/local-esdt-and-nft.wasm") issuePrice := big.NewInt(1000) txData := []byte(issueFunc + "@" + hex.EncodeToString([]byte("TOKEN")) + @@ -543,7 +544,7 @@ func deployAndIssueNFTSFTThroughSC( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("TKR"))) @@ -559,7 +560,7 @@ func deployAndIssueNFTSFTThroughSC( ) time.Sleep(time.Second) - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) return scAddress, tokenIdentifier diff --git a/integrationTests/vm/esdt/process/esdtProcess_test.go b/integrationTests/vm/esdt/process/esdtProcess_test.go index 8fa9fd04101..76b95987dce 100644 --- a/integrationTests/vm/esdt/process/esdtProcess_test.go +++ b/integrationTests/vm/esdt/process/esdtProcess_test.go @@ -13,6 +13,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data/esdt" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" vmData "github.com/multiversx/mx-chain-core-go/data/vm" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" + "github.com/stretchr/testify/require" + 
"github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" @@ -24,9 +28,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" - "github.com/stretchr/testify/require" ) func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { @@ -43,6 +44,8 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( numOfShards, @@ -51,11 +54,11 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -82,7 +85,7 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -106,7 +109,7 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), vm.ESDTSCAddress, txData.ToString(), core.MinMetaTxExtraGasCost) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) finalSupply := initialSupply + mintValue @@ -131,7 +134,7 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtFrozenData := esdtCommon.GetESDTTokenData(t, nodes[1].OwnAccount.Address, nodes, []byte(tokenIdentifier), 0) @@ -175,6 +178,8 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, MultiClaimOnDelegationEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + 
FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -184,11 +189,11 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -219,7 +224,7 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -233,7 +238,7 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), node.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit) } - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) finalSupply := initialSupply @@ -250,7 +255,7 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtSCAcc := esdtCommon.GetUserAccountWithAddress(t, vm.ESDTSCAddress, nodes) @@ -279,11 +284,11 @@ func TestESDTIssueAndSelfTransferShouldNotChangeBalance(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -308,7 +313,7 @@ func TestESDTIssueAndSelfTransferShouldNotChangeBalance(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -322,7 +327,7 @@ func TestESDTIssueAndSelfTransferShouldNotChangeBalance(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), nodes[0].OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 
nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, nodes[0].OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply) @@ -398,11 +403,11 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -428,7 +433,7 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -448,7 +453,7 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(vaultScAddress) require.Nil(t, err) @@ -461,7 +466,7 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), vaultScAddress, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToSendToSc) @@ -473,7 +478,7 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), vaultScAddress, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToSendToSc+valueToRequest) @@ -495,11 +500,11 @@ func TestESDTcallsSC(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -526,7 +531,7 @@ func TestESDTcallsSC(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, 
idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -541,7 +546,7 @@ func TestESDTcallsSC(t *testing.T) { } time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) numNodesWithoutIssuer := int64(len(nodes) - 1) @@ -567,7 +572,7 @@ func TestESDTcallsSC(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(scAddress) require.Nil(t, err) @@ -579,7 +584,7 @@ func TestESDTcallsSC(t *testing.T) { } time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) scQuery1 := &process.SCQuery{ @@ -613,11 +618,11 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -643,7 +648,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -663,7 +668,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(vault) require.Nil(t, err) @@ -679,7 +684,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(forwarder) require.Nil(t, err) @@ -692,7 +697,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 
nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIssuerBalance := initialSupply - valueToSendToSc @@ -711,7 +716,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, tokenIssuerBalance) @@ -735,7 +740,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(5 * time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) time.Sleep(5 * time.Second) tokenIssuerBalance -= valueToTransferWithExecSc @@ -750,7 +755,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(5 * time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) time.Sleep(5 * time.Second) tokenIssuerBalance -= valueToTransferWithExecSc @@ -774,11 +779,11 @@ func TestCallbackPaymentEgld(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -804,7 +809,7 @@ func TestCallbackPaymentEgld(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -824,7 +829,7 @@ func TestCallbackPaymentEgld(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(secondScAddress) require.Nil(t, err) @@ -840,7 +845,7 @@ func TestCallbackPaymentEgld(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(forwarder) require.Nil(t, err) @@ -851,7 +856,7 @@ func TestCallbackPaymentEgld(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, 
big.NewInt(valueToSendToSc), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) time.Sleep(time.Second) esdtCommon.CheckNumCallBacks(t, forwarder, nodes, 1) @@ -864,7 +869,7 @@ func TestCallbackPaymentEgld(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) time.Sleep(time.Second) esdtCommon.CheckNumCallBacks(t, forwarder, nodes, 2) @@ -893,11 +898,11 @@ func TestScCallsScWithEsdtIntraShard_SecondScRefusesPayment(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -923,7 +928,7 @@ func TestScCallsScWithEsdtIntraShard_SecondScRefusesPayment(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -944,7 +949,7 @@ func TestScCallsScWithEsdtIntraShard_SecondScRefusesPayment(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(secondScAddress) require.Nil(t, err) @@ -962,12 +967,12 @@ func TestScCallsScWithEsdtIntraShard_SecondScRefusesPayment(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(firstScAddress) require.Nil(t, err) - nonce, round = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, idxProposers, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejected", 2) - _, _ = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, idxProposers, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejectedWithTransferAndExecute", 2) + nonce, round = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, leaders, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejected", 2) + _, _ = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, leaders, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejectedWithTransferAndExecute", 2) } func 
TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) { @@ -985,11 +990,11 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1015,7 +1020,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 15 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -1035,7 +1040,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(callerScAddress) require.Nil(t, err) @@ -1052,7 +1057,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(receiverScAddress) require.Nil(t, err) @@ -1073,7 +1078,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(callerScAddress) require.Nil(t, err) @@ -1101,7 +1106,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // call caller sc with ESDTTransfer which will call the second sc with execute_on_dest_context @@ -1122,7 +1127,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) { ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToTransfer) @@ -1151,11 +1156,11 @@ func TestScACallsScBWithExecOnDestScAPerformsAsyncCall_NoCallbackInScB(t *testin numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * 
nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1194,7 +1199,7 @@ func TestScACallsScBWithExecOnDestScAPerformsAsyncCall_NoCallbackInScB(t *testin ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(callerScAddress) require.Nil(t, err) @@ -1214,7 +1219,7 @@ func TestScACallsScBWithExecOnDestScAPerformsAsyncCall_NoCallbackInScB(t *testin ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) // issue ESDT by calling exec on dest context on child contract @@ -1238,7 +1243,7 @@ func TestScACallsScBWithExecOnDestScAPerformsAsyncCall_NoCallbackInScB(t *testin nrRoundsToPropagateMultiShard := 12 time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenID := integrationTests.GetTokenIdentifier(nodes, []byte(ticker)) @@ -1285,6 +1290,8 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn GlobalMintBurnDisableEpoch: integrationTests.UnreachableEpoch, SCProcessorV2EnableEpoch: integrationTests.UnreachableEpoch, FailExecutionOnEveryAPIErrorEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } arwenVersion := config.WasmVMVersionByEpoch{Version: "v1.4"} vmConfig := &config.VirtualMachineConfig{ @@ -1299,11 +1306,11 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn vmConfig, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1328,7 +1335,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 15 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -1348,7 +1355,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(mapperScAddress) require.Nil(t, err) @@ -1365,7 +1372,7 @@ func 
TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(senderScAddress) require.Nil(t, err) @@ -1381,7 +1388,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(senderScAddress) require.Nil(t, err) @@ -1400,7 +1407,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(receiverScAddress) require.Nil(t, err) @@ -1415,7 +1422,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) time.Sleep(time.Second) issueCost := big.NewInt(1000) @@ -1430,7 +1437,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn ) nrRoundsToPropagateMultiShard = 25 time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) scQuery := nodes[0].SCQueryService @@ -1457,7 +1464,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) valueToTransfer := int64(1000) @@ -1475,7 +1482,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToTransfer) @@ -1501,11 +1508,11 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1534,7 +1541,7 @@ func 
TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 15 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -1543,7 +1550,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te esdtCommon.IssueTestToken(nodes, initialSupplyWEGLD, tickerWEGLD) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifierWEGLD := string(integrationTests.GetTokenIdentifier(nodes, []byte(tickerWEGLD))) @@ -1563,7 +1570,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(mapperScAddress) require.Nil(t, err) @@ -1580,7 +1587,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(senderScAddress) require.Nil(t, err) @@ -1596,7 +1603,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(senderScAddress) require.Nil(t, err) @@ -1615,7 +1622,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(receiverScAddress) require.Nil(t, err) @@ -1634,12 +1641,12 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round) _, err = nodes[0].AccntState.GetExistingAccount(receiverScAddressWEGLD) require.Nil(t, err) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) time.Sleep(time.Second) issueCost := big.NewInt(1000) @@ -1654,7 +1661,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) nrRoundsToPropagateMultiShard = 100 time.Sleep(time.Second) - nonce, round = 
integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) txData.Clear().Func("issue").Str(ticker).Str(tokenIdentifier).Str("B") @@ -1668,7 +1675,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) nrRoundsToPropagateMultiShard = 100 time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) txData.Clear().Func("issue").Str(tickerWEGLD).Str(tokenIdentifierWEGLD).Str("L") @@ -1682,7 +1689,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) nrRoundsToPropagateMultiShard = 25 time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) txData.Clear().Func("issue").Str(tickerWEGLD).Str(tokenIdentifierWEGLD).Str("B") @@ -1696,7 +1703,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) nrRoundsToPropagateMultiShard = 25 time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) txData.Clear().Func("setTicker").Str(tokenIdentifier).Str(string(receiverScAddress)) @@ -1710,7 +1717,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 400, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 400, nonce, round) time.Sleep(time.Second) txData.Clear().Func("setTicker").Str(tokenIdentifierWEGLD).Str(string(receiverScAddressWEGLD)) @@ -1761,7 +1768,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) txData.Clear().Func("setBorrowTokenRoles").Int(3).Int(4).Int(5) @@ -1813,7 +1820,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) txData.Clear().Func("setBorrowTokenRoles").Int(3).Int(4).Int(5) @@ -1828,7 +1835,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te // time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, 
idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) valueToTransfer := int64(1000) @@ -1846,7 +1853,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 40, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 40, nonce, round) time.Sleep(time.Second) valueToTransferWEGLD := int64(1000) @@ -1865,7 +1872,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 40, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 40, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToTransfer) @@ -1883,7 +1890,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 25, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 25, nonce, round) time.Sleep(time.Second) esdtBorrowBUSDData := esdtCommon.GetESDTTokenData(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdStrBorrow), 0) @@ -1906,11 +1913,11 @@ func TestIssueESDT_FromSCWithNotEnoughGas(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1938,7 +1945,7 @@ func TestIssueESDT_FromSCWithNotEnoughGas(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") + scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") alice := nodes[0] issuePrice := big.NewInt(1000) @@ -1954,14 +1961,14 @@ func TestIssueESDT_FromSCWithNotEnoughGas(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) userAccount := esdtCommon.GetUserAccountWithAddress(t, alice.OwnAccount.Address, nodes) balanceAfterTransfer := userAccount.GetBalance() nrRoundsToPropagateMultiShard := 15 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) userAccount = esdtCommon.GetUserAccountWithAddress(t, alice.OwnAccount.Address, nodes) require.Equal(t, userAccount.GetBalance(), big.NewInt(0).Add(balanceAfterTransfer, issuePrice)) @@ -1982,6 +1989,8 @@ func TestIssueAndBurnESDT_MaxGasPerBlockExceeded(t 
*testing.T) { enableEpochs := config.EnableEpochs{ GlobalMintBurnDisableEpoch: integrationTests.UnreachableEpoch, MaxBlockchainHookCountersEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( numOfShards, @@ -1990,11 +1999,11 @@ func TestIssueAndBurnESDT_MaxGasPerBlockExceeded(t *testing.T) { enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -2032,7 +2041,7 @@ func TestIssueAndBurnESDT_MaxGasPerBlockExceeded(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -2065,7 +2074,7 @@ func TestIssueAndBurnESDT_MaxGasPerBlockExceeded(t *testing.T) { } time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 25, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 25, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-int64(numBurns)) @@ -2106,11 +2115,11 @@ func TestScCallsScWithEsdtCrossShard_SecondScRefusesPayment(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -2137,7 +2146,7 @@ func TestScCallsScWithEsdtCrossShard_SecondScRefusesPayment(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -2158,7 +2167,7 @@ func TestScCallsScWithEsdtCrossShard_SecondScRefusesPayment(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(secondScAddress) require.Nil(t, err) @@ -2175,12 +2184,12 @@ func TestScCallsScWithEsdtCrossShard_SecondScRefusesPayment(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = 
integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err = nodes[2].AccntState.GetExistingAccount(firstScAddress) require.Nil(t, err) - nonce, round = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, idxProposers, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejected", 20) - _, _ = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, idxProposers, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejectedWithTransferAndExecute", 20) + nonce, round = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, leaders, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejected", 20) + _, _ = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, leaders, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejectedWithTransferAndExecute", 20) } func transferRejectedBySecondContract( @@ -2188,7 +2197,7 @@ func transferRejectedBySecondContract( nonce, round uint64, nodes []*integrationTests.TestProcessorNode, tokenIssuer *integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, initialSupply int64, tokenIdentifier string, firstScAddress []byte, @@ -2210,7 +2219,7 @@ func transferRejectedBySecondContract( integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundToPropagate, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundToPropagate, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToSendToSc) @@ -2250,11 +2259,11 @@ func multiTransferFromSC(t *testing.T, numOfShards int) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -2297,7 +2306,7 @@ func multiTransferFromSC(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := integrationTests.GetTokenIdentifier(nodes, []byte(ticker)) @@ -2319,7 +2328,7 @@ func multiTransferFromSC(t *testing.T, numOfShards int) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := ownerNode.AccntState.GetExistingAccount(scAddress) require.Nil(t, err) @@ -2327,7 +2336,7 @@ func multiTransferFromSC(t *testing.T, numOfShards int) { []byte(core.ESDTRoleLocalMint), } esdtCommon.SetRoles(nodes, scAddress, tokenIdentifier, roles) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers) + nonce, round = 
integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round) txData := txDataBuilder.NewBuilder() txData.Func("batchTransferEsdtToken") @@ -2349,7 +2358,7 @@ func multiTransferFromSC(t *testing.T, numOfShards int) { integrationTests.AdditionalGasLimit, ) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round) esdtCommon.CheckAddressHasTokens(t, destinationNode.OwnAccount.Address, nodes, tokenIdentifier, 0, 20) } @@ -2366,6 +2375,8 @@ func TestESDTIssueUnderProtectedKeyWillReturnTokensBack(t *testing.T) { OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -2375,11 +2386,11 @@ func TestESDTIssueUnderProtectedKeyWillReturnTokensBack(t *testing.T) { enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -2406,14 +2417,14 @@ func TestESDTIssueUnderProtectedKeyWillReturnTokensBack(t *testing.T) { time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) time.Sleep(time.Second) userAcc := esdtCommon.GetUserAccountWithAddress(t, tokenIssuer.OwnAccount.Address, nodes) balanceBefore := userAcc.GetBalance() nrRoundsToPropagateMultiShard := 12 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) tokenIdentifier := integrationTests.GetTokenIdentifier(nodes, []byte(ticker)) require.Equal(t, 0, len(tokenIdentifier)) diff --git a/integrationTests/vm/esdt/roles/esdtRoles_test.go b/integrationTests/vm/esdt/roles/esdtRoles_test.go index 5c117ed4edd..960a3bed393 100644 --- a/integrationTests/vm/esdt/roles/esdtRoles_test.go +++ b/integrationTests/vm/esdt/roles/esdtRoles_test.go @@ -7,12 +7,13 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/vm" - "github.com/stretchr/testify/require" ) // Test scenario @@ -35,11 +36,11 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = 
nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -65,7 +66,7 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 6 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("FTT"))) @@ -75,7 +76,7 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { setRole(nodes, nodes[0].OwnAccount.Address, []byte(tokenIdentifier), []byte(core.ESDTRoleLocalBurn)) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdt.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply.Int64()) @@ -93,7 +94,7 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // check balance ofter local mint @@ -112,7 +113,7 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // check balance ofter local mint @@ -141,11 +142,11 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -171,7 +172,7 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("FTT"))) @@ -180,14 +181,14 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme setRole(nodes, nodes[0].OwnAccount.Address, []byte(tokenIdentifier), []byte(core.ESDTRoleLocalMint)) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) 
    time.Sleep(time.Second)
    // unset special role
    unsetRole(nodes, nodes[0].OwnAccount.Address, []byte(tokenIdentifier), []byte(core.ESDTRoleLocalMint))
    time.Sleep(time.Second)
-   nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+   nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
    time.Sleep(time.Second)
    esdt.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply.Int64())
@@ -207,7 +208,7 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
    time.Sleep(time.Second)
    nrRoundsToPropagateMultiShard = 7
-   nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+   nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
    time.Sleep(time.Second)
    // check balance ofter local mint
@@ -215,7 +216,7 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
    setRole(nodes, nodes[0].OwnAccount.Address, []byte(tokenIdentifier), []byte(core.ESDTRoleLocalBurn))
    time.Sleep(time.Second)
-   nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+   nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
    time.Sleep(time.Second)
    // burn local tokens
@@ -231,7 +232,7 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
    )
    time.Sleep(time.Second)
-   _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+   _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
    time.Sleep(time.Second)
    // check balance ofter local mint
@@ -273,11 +274,11 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
        numMetachainNodes,
    )
-   idxProposers := make([]int, numOfShards+1)
+   leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
    for i := 0; i < numOfShards; i++ {
-       idxProposers[i] = i * nodesPerShard
+       leaders[i] = nodes[i*nodesPerShard]
    }
-   idxProposers[numOfShards] = numOfShards * nodesPerShard
+   leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
    integrationTests.DisplayAndStartNodes(nodes)
@@ -295,7 +296,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
    round = integrationTests.IncrementAndPrintRound(round)
    nonce++
-   scAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/egld-esdt-swap.wasm")
+   scAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/egld-esdt-swap.wasm")
    // issue ESDT by calling exec on dest context on child contract
    ticker := "DSN"
@@ -316,7 +317,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
    time.Sleep(time.Second)
    nrRoundsToPropagateMultiShard := 15
-   nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+   nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
    time.Sleep(time.Second)
    tokenIdentifier := integrationTests.GetTokenIdentifier(nodes, []byte(ticker))
@@ -329,7 +330,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
        integrationTests.AdditionalGasLimit,
    )
    time.Sleep(time.Second)
-   nonce, round = 
integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) valueToWrap := big.NewInt(1000) @@ -346,7 +347,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) { } time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) for i, n := range nodes { @@ -370,7 +371,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) { } time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) userAccount := esdt.GetUserAccountWithAddress(t, scAddress, nodes) @@ -387,7 +388,9 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) { numMetachainNodes := 2 enableEpochs := config.EnableEpochs{ - ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, + ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( numOfShards, @@ -396,11 +399,11 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) { enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -431,7 +434,7 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -445,7 +448,7 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), node.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit) } - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) finalSupply := initialSupply @@ -460,7 +463,7 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) { txData.Clear().LocalBurnESDT(tokenIdentifier, finalSupply) integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), tokenIssuer.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = 
integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) for _, node := range nodes { @@ -478,7 +481,9 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) { numMetachainNodes := 2 enableEpochs := config.EnableEpochs{ - ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, + ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( numOfShards, @@ -487,11 +492,11 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) { enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -522,7 +527,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -530,7 +535,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) { esdt.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) // send tx to other nodes @@ -540,7 +545,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), node.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit) } - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // send value back to the initial node @@ -550,7 +555,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) { integrationTests.CreateAndSendTransaction(node, nodes, big.NewInt(0), tokenIssuer.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit) } - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) for _, node := range nodes[1:] { diff --git a/integrationTests/vm/staking/componentsHolderCreator.go b/integrationTests/vm/staking/componentsHolderCreator.go index e3673b08ec7..2903fb09dba 100644 --- a/integrationTests/vm/staking/componentsHolderCreator.go +++ b/integrationTests/vm/staking/componentsHolderCreator.go @@ -11,6 +11,7 @@ import ( 
"github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" "github.com/multiversx/mx-chain-core-go/hashing/sha256" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/enablers" "github.com/multiversx/mx-chain-go/common/forking" @@ -69,6 +70,8 @@ func createCoreComponents() factory.CoreComponentsHolder { StakingV4Step3EnableEpoch: stakingV4Step3EnableEpoch, GovernanceEnableEpoch: integrationTests.UnreachableEpoch, RefactorPeersMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(configEnableEpochs, epochNotifier) diff --git a/integrationTests/vm/systemVM/stakingSC_test.go b/integrationTests/vm/systemVM/stakingSC_test.go index 75e958f926b..c178ee0b5c3 100644 --- a/integrationTests/vm/systemVM/stakingSC_test.go +++ b/integrationTests/vm/systemVM/stakingSC_test.go @@ -10,16 +10,16 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/integrationTests/multiShard/endOfEpoch" integrationTestsVm "github.com/multiversx/mx-chain-go/integrationTests/vm" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/accounts" "github.com/multiversx/mx-chain-go/vm" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { @@ -38,6 +38,8 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + EquivalentMessagesEnableEpoch: integrationTests.UnreachableEpoch, + FixedOrderInConsensusEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -47,11 +49,11 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { enableEpochsConfig, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -87,7 +89,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { nrRoundsToPropagateMultiShard := 10 integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) @@ -109,11 +111,11 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, 
nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) // ----- wait for unbond period integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) manualSetToInactiveStateStakedPeers(t, nodes) @@ -127,7 +129,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) verifyUnbound(t, nodes) } @@ -152,18 +154,16 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist ) nodes := make([]*integrationTests.TestProcessorNode, 0) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for _, nds := range nodesMap { nodes = append(nodes, nds...) } - for _, nds := range nodesMap { - idx, err := integrationTestsVm.GetNodeIndex(nodes, nds[0]) - require.Nil(t, err) - - idxProposers = append(idxProposers, idx) + for i := 0; i < numOfShards; i++ { + leaders[i] = nodesMap[uint32(i)][0] } + leaders[numOfShards] = nodesMap[core.MetachainShardId][0] integrationTests.DisplayAndStartNodes(nodes) @@ -203,7 +203,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist nrRoundsToPropagateMultiShard := 10 integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) @@ -227,7 +227,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist time.Sleep(time.Second) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) roundsPerEpoch := uint64(10) for _, node := range nodes { @@ -237,7 +237,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist // ----- wait for unbound period integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) // ----- send unBound for index, node := range nodes { @@ -252,7 +252,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist time.Sleep(time.Second) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) verifyUnbound(t, nodes) } @@ -322,7 +322,6 @@ func TestStakeWithRewardsAddressAndValidatorStatistics(t *testing.T) { } 
nbBlocksToProduce := roundsPerEpoch * 3 - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode for i := uint64(0); i < nbBlocksToProduce; i++ { for _, nodesSlice := range nodesMap { @@ -330,9 +329,8 @@ func TestStakeWithRewardsAddressAndValidatorStatistics(t *testing.T) { integrationTests.AddSelfNotarizedHeaderByMetachain(nodesSlice) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := endOfEpoch.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ diff --git a/keysManagement/managedPeersHolder.go b/keysManagement/managedPeersHolder.go index 8156b64c8eb..39f80f6bbaf 100644 --- a/keysManagement/managedPeersHolder.go +++ b/keysManagement/managedPeersHolder.go @@ -12,10 +12,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/redundancy/common" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("keysManagement") diff --git a/node/chainSimulator/components/coreComponents.go b/node/chainSimulator/components/coreComponents.go index f2bad834ad8..2d92e1dfa3e 100644 --- a/node/chainSimulator/components/coreComponents.go +++ b/node/chainSimulator/components/coreComponents.go @@ -12,7 +12,6 @@ import ( "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/epochStart/notifier" "github.com/multiversx/mx-chain-go/factory" "github.com/multiversx/mx-chain-go/ntp" @@ -148,7 +147,7 @@ func CreateCoreComponents(args ArgsCoreComponentsHolder) (*coreComponentsHolder, } instance.watchdog = &watchdog.DisabledWatchdog{} - instance.alarmScheduler = &mock.AlarmSchedulerStub{} + instance.alarmScheduler = &testscommon.AlarmSchedulerStub{} instance.syncTimer = &testscommon.SyncTimerStub{} instance.epochStartNotifierWithConfirm = notifier.NewEpochStartSubscriptionHandler() diff --git a/node/chainSimulator/components/dataComponents_test.go b/node/chainSimulator/components/dataComponents_test.go index a74f0b751f6..9bd27c36eba 100644 --- a/node/chainSimulator/components/dataComponents_test.go +++ b/node/chainSimulator/components/dataComponents_test.go @@ -3,12 +3,14 @@ package components import ( "testing" + "github.com/stretchr/testify/require" + retriever "github.com/multiversx/mx-chain-go/dataRetriever" chainStorage "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/require" ) func createArgsDataComponentsHolder() ArgsDataComponentsHolder { @@ -21,7 +23,7 @@ func createArgsDataComponentsHolder() ArgsDataComponentsHolder { }, DataPool: &dataRetriever.PoolsHolderStub{ MiniBlocksCalled: func() chainStorage.Cacher { - return &testscommon.CacherStub{} + return 
&cache.CacherStub{} }, }, InternalMarshaller: &testscommon.MarshallerStub{}, diff --git a/node/chainSimulator/components/instantBroadcastMessenger_test.go b/node/chainSimulator/components/instantBroadcastMessenger_test.go index 361caa03bbc..84770316337 100644 --- a/node/chainSimulator/components/instantBroadcastMessenger_test.go +++ b/node/chainSimulator/components/instantBroadcastMessenger_test.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/mock" errorsMx "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/stretchr/testify/require" ) @@ -22,14 +24,14 @@ func TestNewInstantBroadcastMessenger(t *testing.T) { t.Run("nil shardCoordinator should error", func(t *testing.T) { t.Parallel() - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{}, nil) + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{}, nil) require.Equal(t, errorsMx.ErrNilShardCoordinator, err) require.Nil(t, mes) }) t.Run("should work", func(t *testing.T) { t.Parallel() - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{}, &mock.ShardCoordinatorMock{}) + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{}, &mock.ShardCoordinatorMock{}) require.NoError(t, err) require.NotNil(t, mes) }) @@ -41,7 +43,7 @@ func TestInstantBroadcastMessenger_IsInterfaceNil(t *testing.T) { var mes *instantBroadcastMessenger require.True(t, mes.IsInterfaceNil()) - mes, _ = NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{}, &mock.ShardCoordinatorMock{}) + mes, _ = NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{}, &mock.ShardCoordinatorMock{}) require.False(t, mes.IsInterfaceNil()) } @@ -60,7 +62,7 @@ func TestInstantBroadcastMessenger_BroadcastBlockDataLeader(t *testing.T) { "topic_0": {[]byte("txs topic 0")}, "topic_1": {[]byte("txs topic 1")}, } - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{ + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{ BroadcastMiniBlocksCalled: func(mbs map[uint32][]byte, bytes []byte) error { require.Equal(t, providedMBs, mbs) return expectedErr // for coverage only @@ -94,7 +96,7 @@ func TestInstantBroadcastMessenger_BroadcastBlockDataLeader(t *testing.T) { expectedTxs := map[string][][]byte{ "topic_0_META": {[]byte("txs topic meta")}, } - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{ + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{ BroadcastMiniBlocksCalled: func(mbs map[uint32][]byte, bytes []byte) error { require.Equal(t, expectedMBs, mbs) return nil @@ -114,7 +116,7 @@ func TestInstantBroadcastMessenger_BroadcastBlockDataLeader(t *testing.T) { t.Run("shard, empty miniblocks should early exit", func(t *testing.T) { t.Parallel() - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{ + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{ BroadcastMiniBlocksCalled: func(mbs map[uint32][]byte, bytes []byte) error { require.Fail(t, "should have not been called") return nil diff --git a/node/chainSimulator/components/testOnlyProcessingNode.go b/node/chainSimulator/components/testOnlyProcessingNode.go index 93f1beb56da..6c799b203a6 100644 --- a/node/chainSimulator/components/testOnlyProcessingNode.go +++ b/node/chainSimulator/components/testOnlyProcessingNode.go @@ -228,7 +228,7 @@ func NewTestOnlyProcessingNode(args 
ArgsTestOnlyProcessingNode) (*testOnlyProces return nil, err } - err = instance.createBroadcastMessenger() + err = instance.createBroadcastMessenger(args.Configs.GeneralConfig.ConsensusGradualBroadcast) if err != nil { return nil, err } @@ -326,7 +326,7 @@ func (node *testOnlyProcessingNode) createNodesCoordinator(pref config.Preferenc return nil } -func (node *testOnlyProcessingNode) createBroadcastMessenger() error { +func (node *testOnlyProcessingNode) createBroadcastMessenger(gradualBroadcastConfig config.ConsensusGradualBroadcastConfig) error { broadcastMessenger, err := sposFactory.GetBroadcastMessenger( node.CoreComponentsHolder.InternalMarshalizer(), node.CoreComponentsHolder.Hasher(), @@ -337,6 +337,7 @@ func (node *testOnlyProcessingNode) createBroadcastMessenger() error { node.ProcessComponentsHolder.InterceptorsContainer(), node.CoreComponentsHolder.AlarmScheduler(), node.CryptoComponentsHolder.KeysHandler(), + gradualBroadcastConfig, ) if err != nil { return err diff --git a/node/chainSimulator/configs/configs.go b/node/chainSimulator/configs/configs.go index ce2cdf6b5d4..718329381e3 100644 --- a/node/chainSimulator/configs/configs.go +++ b/node/chainSimulator/configs/configs.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-go/common/factory" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/genesis/data" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" "github.com/multiversx/mx-chain-go/sharding" @@ -140,6 +141,10 @@ func CreateChainSimulatorConfigs(args ArgsChainSimulatorConfigs) (*ArgsConfigsSi configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[0].RoundDuration = args.RoundDurationInMillis configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[0].Hysteresis = args.Hysteresis + // TODO[Sorin]: remove this once all equivalent messages PRs are merged + configs.EpochConfig.EnableEpochs.EquivalentMessagesEnableEpoch = integrationTests.UnreachableEpoch + configs.EpochConfig.EnableEpochs.FixedOrderInConsensusEnableEpoch = integrationTests.UnreachableEpoch + node.ApplyArchCustomConfigs(configs) if args.AlterConfigsFunction != nil { diff --git a/node/chainSimulator/process/processor.go b/node/chainSimulator/process/processor.go index d8f225bfde8..1c32a1fe0c9 100644 --- a/node/chainSimulator/process/processor.go +++ b/node/chainSimulator/process/processor.go @@ -3,10 +3,10 @@ package process import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" - "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/node/chainSimulator/configs" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("process-block") @@ -82,11 +82,11 @@ func (creator *blocksCreator) CreateNewBlock() error { return err } - validatorsGroup, err := creator.nodeHandler.GetProcessComponents().NodesCoordinator().ComputeConsensusGroup(prevRandSeed, newHeader.GetRound(), shardID, epoch) + leader, _, err := creator.nodeHandler.GetProcessComponents().NodesCoordinator().ComputeConsensusGroup(prevRandSeed, newHeader.GetRound(), shardID, epoch) if err != nil { return err } - blsKey := validatorsGroup[spos.IndexOfLeaderInConsensusGroup] + blsKey := leader isManaged := creator.nodeHandler.GetCryptoComponents().KeysHandler().IsKeyManagedByCurrentNode(blsKey.PubKey()) 
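Note: ComputeConsensusGroup now returns the leader explicitly alongside the validators group, so callers no longer pick it by index via spos.IndexOfLeaderInConsensusGroup. A minimal sketch of consuming the new signature, with the receiver chains from the hunk above shortened to local variables purely for illustration:

    // nodesCoord and keysHandler stand in for the component accessors used above
    leader, validators, err := nodesCoord.ComputeConsensusGroup(prevRandSeed, round, shardID, epoch)
    if err != nil {
        return err
    }
    _ = validators // the full consensus group is still returned for callers that need it
    isManaged := keysHandler.IsKeyManagedByCurrentNode(leader.PubKey())
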
if !isManaged { diff --git a/node/chainSimulator/process/processor_test.go b/node/chainSimulator/process/processor_test.go index 80ffd568134..c7be21ff9cf 100644 --- a/node/chainSimulator/process/processor_test.go +++ b/node/chainSimulator/process/processor_test.go @@ -9,9 +9,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" - mockConsensus "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/factory" "github.com/multiversx/mx-chain-go/integrationTests/mock" chainSimulatorProcess "github.com/multiversx/mx-chain-go/node/chainSimulator/process" @@ -24,7 +25,6 @@ import ( testsFactory "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/require" ) var expectedErr = errors.New("expected error") @@ -221,8 +221,8 @@ func TestBlocksCreator_CreateNewBlock(t *testing.T) { }, }, NodesCoord: &shardingMocks.NodesCoordinatorStub{ - ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return nil, expectedErr + ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, expectedErr }, }, } @@ -515,7 +515,7 @@ func TestBlocksCreator_CreateNewBlock(t *testing.T) { nodeHandler := getNodeHandler() nodeHandler.GetBroadcastMessengerCalled = func() consensus.BroadcastMessenger { - return &mockConsensus.BroadcastMessengerMock{ + return &testsConsensus.BroadcastMessengerMock{ BroadcastHeaderCalled: func(handler data.HeaderHandler, bytes []byte) error { return expectedErr }, @@ -596,10 +596,9 @@ func getNodeHandler() *chainSimulator.NodeHandlerMock { }, }, NodesCoord: &shardingMocks.NodesCoordinatorStub{ - ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - shardingMocks.NewValidatorMock([]byte("A"), 1, 1), - }, nil + ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + v := shardingMocks.NewValidatorMock([]byte("A"), 1, 1) + return v, []nodesCoordinator.Validator{v}, nil }, }, } @@ -625,7 +624,7 @@ func getNodeHandler() *chainSimulator.NodeHandlerMock { } }, GetBroadcastMessengerCalled: func() consensus.BroadcastMessenger { - return &mockConsensus.BroadcastMessengerMock{} + return &testsConsensus.BroadcastMessengerMock{} }, } } diff --git a/node/interface.go b/node/interface.go index 236e7a131e3..05330285fb6 100644 --- a/node/interface.go +++ b/node/interface.go @@ -4,8 +4,9 @@ import ( "io" "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-go/update" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + + "github.com/multiversx/mx-chain-go/update" ) // NetworkShardingCollector defines the updating methods used by the network sharding component diff --git a/node/mock/headerSigVerifierStub.go 
b/node/mock/headerSigVerifierStub.go deleted file mode 100644 index b75b5615a12..00000000000 --- a/node/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,52 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/node/nodeRunner.go b/node/nodeRunner.go index 1378007ad64..f6fa53a660e 100644 --- a/node/nodeRunner.go +++ b/node/nodeRunner.go @@ -20,6 +20,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/data/endProcess" outportCore "github.com/multiversx/mx-chain-core-go/data/outport" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/api/gin" "github.com/multiversx/mx-chain-go/api/shared" "github.com/multiversx/mx-chain-go/common" @@ -61,7 +63,6 @@ import ( "github.com/multiversx/mx-chain-go/storage/storageunit" trieStatistics "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/multiversx/mx-chain-go/update/trigger" - logger "github.com/multiversx/mx-chain-logger-go" ) type nextOperationForNode int diff --git a/node/node_test.go b/node/node_test.go index 3279683a476..37efcdd4f50 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -50,6 +50,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" + "github.com/multiversx/mx-chain-go/testscommon/consensus" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -5227,7 +5228,7 @@ func getDefaultProcessComponents() *factoryMock.ProcessComponentsMock { BlockProcess: &testscommon.BlockProcessorStub{}, BlackListHdl: &testscommon.TimeCacheStub{}, BootSore: &mock.BootstrapStorerMock{}, - HeaderSigVerif: &mock.HeaderSigVerifierStub{}, + HeaderSigVerif: &consensus.HeaderSigVerifierMock{}, HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, ValidatorStatistics: &testscommon.ValidatorStatisticsProcessorStub{}, ValidatorProvider: 
&stakingcommon.ValidatorsProviderStub{}, diff --git a/outport/process/outportDataProvider.go b/outport/process/outportDataProvider.go index a99e0bc4827..3c80b1db990 100644 --- a/outport/process/outportDataProvider.go +++ b/outport/process/outportDataProvider.go @@ -16,12 +16,13 @@ import ( "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/outport/process/alteredaccounts/shared" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("outport/process/outportDataProvider") @@ -292,7 +293,7 @@ func (odp *outportDataProvider) computeEpoch(header data.HeaderHandler) uint32 { func (odp *outportDataProvider) getSignersIndexes(header data.HeaderHandler) ([]uint64, error) { epoch := odp.computeEpoch(header) - pubKeys, err := odp.nodesCoordinator.GetConsensusValidatorsPublicKeys( + _, pubKeys, err := odp.nodesCoordinator.GetConsensusValidatorsPublicKeys( header.GetPrevRandSeed(), header.GetRound(), odp.shardID, diff --git a/outport/process/outportDataProvider_test.go b/outport/process/outportDataProvider_test.go index c240fe50ab7..3b048eadf8e 100644 --- a/outport/process/outportDataProvider_test.go +++ b/outport/process/outportDataProvider_test.go @@ -12,6 +12,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/rewardTx" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/outport/mock" "github.com/multiversx/mx-chain-go/outport/process/transactionsfee" "github.com/multiversx/mx-chain-go/testscommon" @@ -20,7 +22,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/require" ) func createArgOutportDataProvider() ArgOutportDataProvider { @@ -81,8 +82,8 @@ func TestPrepareOutportSaveBlockData(t *testing.T) { arg := createArgOutportDataProvider() arg.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - GetValidatorsPublicKeysCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) { - return nil, nil + GetValidatorsPublicKeysCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) { + return "", nil, nil }, GetValidatorsIndexesCalled: func(publicKeys []string, epoch uint32) ([]uint64, error) { return []uint64{0, 1}, nil @@ -125,8 +126,8 @@ func TestOutportDataProvider_GetIntraShardMiniBlocks(t *testing.T) { arg := createArgOutportDataProvider() arg.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - GetValidatorsPublicKeysCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) { - return nil, nil + GetValidatorsPublicKeysCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) { + return "", nil, nil }, GetValidatorsIndexesCalled: func(publicKeys []string, epoch uint32) ([]uint64, error) { return []uint64{0, 1}, nil diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index 
0e3c573b23d..4f2a3661ece 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -123,6 +123,8 @@ type baseProcessor struct { mutNonceOfFirstCommittedBlock sync.RWMutex nonceOfFirstCommittedBlock core.OptionalUint64 extraDelayRequestBlockInfo time.Duration + + proofsPool dataRetriever.ProofsPool } type bootStorerDataArgs struct { @@ -358,6 +360,14 @@ func displayHeader(headerHandler data.HeaderHandler) []*display.LineData { if !check.IfNil(additionalData) { scheduledRootHash = additionalData.GetScheduledRootHash() } + + proof := headerHandler.GetPreviousProof() + + var prevAggregatedSig, prevBitmap []byte + if !check.IfNilReflect(proof) { + prevAggregatedSig, prevBitmap = proof.GetAggregatedSignature(), proof.GetPubKeysBitmap() + } + return []*display.LineData{ display.NewLineData(false, []string{ "", @@ -419,10 +429,18 @@ func displayHeader(headerHandler data.HeaderHandler) []*display.LineData { "", "Receipts hash", logger.DisplayByteSlice(headerHandler.GetReceiptsHash())}), - display.NewLineData(true, []string{ + display.NewLineData(false, []string{ "", "Epoch start meta hash", logger.DisplayByteSlice(epochStartMetaHash)}), + display.NewLineData(false, []string{ + "Previous proof", + "Aggregated signature", + logger.DisplayByteSlice(prevAggregatedSig)}), + display.NewLineData(true, []string{ + "", + "Pub keys bitmap", + logger.DisplayByteSlice(prevBitmap)}), } } @@ -597,15 +615,19 @@ func (bp *baseProcessor) verifyFees(header data.HeaderHandler) error { } // TODO: remove bool parameter and give instead the set to sort -func (bp *baseProcessor) sortHeadersForCurrentBlockByNonce(usedInBlock bool) map[uint32][]data.HeaderHandler { +func (bp *baseProcessor) sortHeadersForCurrentBlockByNonce(usedInBlock bool) (map[uint32][]data.HeaderHandler, error) { hdrsForCurrentBlock := make(map[uint32][]data.HeaderHandler) bp.hdrsForCurrBlock.mutHdrsForBlock.RLock() - for _, headerInfo := range bp.hdrsForCurrBlock.hdrHashAndInfo { + for hdrHash, headerInfo := range bp.hdrsForCurrBlock.hdrHashAndInfo { if headerInfo.usedInBlock != usedInBlock { continue } + if bp.hasMissingProof(headerInfo, hdrHash) { + return nil, fmt.Errorf("%w for header with hash %s", process.ErrMissingHeaderProof, hdrHash) + } + hdrsForCurrentBlock[headerInfo.hdr.GetShardID()] = append(hdrsForCurrentBlock[headerInfo.hdr.GetShardID()], headerInfo.hdr) } bp.hdrsForCurrBlock.mutHdrsForBlock.RUnlock() @@ -615,10 +637,10 @@ func (bp *baseProcessor) sortHeadersForCurrentBlockByNonce(usedInBlock bool) map process.SortHeadersByNonce(hdrsForShard) } - return hdrsForCurrentBlock + return hdrsForCurrentBlock, nil } -func (bp *baseProcessor) sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool) map[uint32][][]byte { +func (bp *baseProcessor) sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool) (map[uint32][][]byte, error) { hdrsForCurrentBlockInfo := make(map[uint32][]*nonceAndHashInfo) bp.hdrsForCurrBlock.mutHdrsForBlock.RLock() @@ -627,6 +649,10 @@ func (bp *baseProcessor) sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool continue } + if bp.hasMissingProof(headerInfo, metaBlockHash) { + return nil, fmt.Errorf("%w for header with hash %s", process.ErrMissingHeaderProof, metaBlockHash) + } + hdrsForCurrentBlockInfo[headerInfo.hdr.GetShardID()] = append(hdrsForCurrentBlockInfo[headerInfo.hdr.GetShardID()], &nonceAndHashInfo{nonce: headerInfo.hdr.GetNonce(), hash: []byte(metaBlockHash)}) } @@ -647,7 +673,13 @@ func (bp *baseProcessor) sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool } } - return 
hdrsHashesForCurrentBlock + return hdrsHashesForCurrentBlock, nil +} + +func (bp *baseProcessor) hasMissingProof(headerInfo *hdrInfo, hdrHash string) bool { + isFlagEnabledForHeader := bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerInfo.hdr.GetEpoch()) + hasProof := bp.proofsPool.HasProof(headerInfo.hdr.GetShardID(), []byte(hdrHash)) + return isFlagEnabledForHeader && !hasProof } func (bp *baseProcessor) createMiniBlockHeaderHandlers( @@ -957,7 +989,18 @@ func (bp *baseProcessor) cleanupPools(headerHandler data.HeaderHandler) { bp.removeHeadersBehindNonceFromPools( true, bp.shardCoordinator.SelfId(), - highestPrevFinalBlockNonce) + highestPrevFinalBlockNonce, + ) + + if bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerHandler.GetEpoch()) { + err := bp.dataPool.Proofs().CleanupProofsBehindNonce(bp.shardCoordinator.SelfId(), highestPrevFinalBlockNonce) + if err != nil { + log.Warn("failed to cleanup notarized proofs behind nonce", + "nonce", noncesToPrevFinal, + "shardID", bp.shardCoordinator.SelfId(), + "error", err) + } + } if bp.shardCoordinator.SelfId() == core.MetachainShardId { for shardID := uint32(0); shardID < bp.shardCoordinator.NumberOfShards(); shardID++ { @@ -966,6 +1009,7 @@ func (bp *baseProcessor) cleanupPools(headerHandler data.HeaderHandler) { } else { bp.cleanupPoolsForCrossShard(core.MetachainShardId, noncesToPrevFinal) } + } func (bp *baseProcessor) cleanupPoolsForCrossShard( @@ -986,6 +1030,16 @@ func (bp *baseProcessor) cleanupPoolsForCrossShard( shardID, crossNotarizedHeader.GetNonce(), ) + + if bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, crossNotarizedHeader.GetEpoch()) { + err = bp.dataPool.Proofs().CleanupProofsBehindNonce(shardID, noncesToPrevFinal) + if err != nil { + log.Warn("failed to cleanup notarized proofs behind nonce", + "nonce", noncesToPrevFinal, + "shardID", shardID, + "error", err) + } + } } func (bp *baseProcessor) removeHeadersBehindNonceFromPools( @@ -2119,7 +2173,7 @@ func (bp *baseProcessor) setNonceOfFirstCommittedBlock(nonce uint64) { } func (bp *baseProcessor) checkSentSignaturesAtCommitTime(header data.HeaderHandler) error { - validatorsGroup, err := headerCheck.ComputeConsensusGroup(header, bp.nodesCoordinator) + _, validatorsGroup, err := headerCheck.ComputeConsensusGroup(header, bp.nodesCoordinator) if err != nil { return err } diff --git a/process/block/baseProcess_test.go b/process/block/baseProcess_test.go index f88d7e8d667..017f7b3e1d0 100644 --- a/process/block/baseProcess_test.go +++ b/process/block/baseProcess_test.go @@ -24,10 +24,14 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/process" blproc "github.com/multiversx/mx-chain-go/process/block" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" @@ -40,6 +44,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + 
"github.com/multiversx/mx-chain-go/testscommon/cache" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" @@ -55,8 +60,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const ( @@ -160,7 +163,7 @@ func createShardedDataChacherNotifier( return func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, testHash) { return handler, true @@ -207,7 +210,7 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { UnsignedTransactionsCalled: unsignedTxCalled, RewardTransactionsCalled: rewardTransactionsCalled, MetaBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &transaction.Transaction{Nonce: 10}, true @@ -234,7 +237,7 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { } }, MiniBlocksCalled: func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -284,6 +287,9 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { } return cs }, + ProofsCalled: func() dataRetriever.ProofsPool { + return proofscache.NewProofsPool() + }, } return sdp @@ -3121,8 +3127,8 @@ func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { expectedErr := errors.New("expected error") t.Run("nodes coordinator errors, should return error", func(t *testing.T) { nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return nil, expectedErr + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, expectedErr } arguments := CreateMockArguments(createComponentHolderMocks()) @@ -3134,7 +3140,10 @@ func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { arguments.NodesCoordinator = nodesCoordinatorInstance bp, _ := blproc.NewShardProcessor(arguments) - err := bp.CheckSentSignaturesAtCommitTime(&block.Header{}) + err := bp.CheckSentSignaturesAtCommitTime(&block.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + }) assert.Equal(t, expectedErr, err) }) t.Run("should work with bitmap", func(t *testing.T) { @@ -3143,8 +3152,8 @@ func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { validator2, _ := nodesCoordinator.NewValidator([]byte("pk2"), 2, 2) nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - 
nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{validator0, validator1, validator2}, nil + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator0, []nodesCoordinator.Validator{validator0, validator1, validator2}, nil } resetCountersCalled := make([][]byte, 0) @@ -3158,6 +3167,8 @@ func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { bp, _ := blproc.NewShardProcessor(arguments) err := bp.CheckSentSignaturesAtCommitTime(&block.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), PubKeysBitmap: []byte{0b00000101}, }) assert.Nil(t, err) diff --git a/process/block/headerValidator.go b/process/block/headerValidator.go index b39787c7a96..9459280c847 100644 --- a/process/block/headerValidator.go +++ b/process/block/headerValidator.go @@ -87,6 +87,8 @@ func (h *headerValidator) IsHeaderConstructionValid(currHeader, prevHeader data. return process.ErrRandSeedDoesNotMatch } + // TODO: check here if proof from currHeader is valid for prevHeader + return nil } diff --git a/process/block/interceptedBlocks/argInterceptedBlockHeader.go b/process/block/interceptedBlocks/argInterceptedBlockHeader.go index 50d5b2be82f..7e493d8b311 100644 --- a/process/block/interceptedBlocks/argInterceptedBlockHeader.go +++ b/process/block/interceptedBlocks/argInterceptedBlockHeader.go @@ -3,6 +3,7 @@ package interceptedBlocks import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) @@ -17,4 +18,5 @@ type ArgInterceptedBlockHeader struct { HeaderIntegrityVerifier process.HeaderIntegrityVerifier ValidityAttester process.ValidityAttester EpochStartTrigger process.EpochStartTriggerHandler + EnableEpochsHandler common.EnableEpochsHandler } diff --git a/process/block/interceptedBlocks/common.go b/process/block/interceptedBlocks/common.go index f3d3f1e393f..90a604dba23 100644 --- a/process/block/interceptedBlocks/common.go +++ b/process/block/interceptedBlocks/common.go @@ -1,9 +1,13 @@ package interceptedBlocks import ( + "sync" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) @@ -39,6 +43,9 @@ func checkBlockHeaderArgument(arg *ArgInterceptedBlockHeader) error { if check.IfNil(arg.ValidityAttester) { return process.ErrNilValidityAttester } + if check.IfNil(arg.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } return nil } @@ -63,14 +70,16 @@ func checkMiniblockArgument(arg *ArgInterceptedMiniblock) error { return nil } -func checkHeaderHandler(hdr data.HeaderHandler) error { - if len(hdr.GetPubKeysBitmap()) == 0 { +func checkHeaderHandler(hdr data.HeaderHandler, enableEpochsHandler common.EnableEpochsHandler) error { + equivalentMessagesEnabled := enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hdr.GetEpoch()) + + if len(hdr.GetPubKeysBitmap()) == 0 && 
!equivalentMessagesEnabled { return process.ErrNilPubKeysBitmap } if len(hdr.GetPrevHash()) == 0 { return process.ErrNilPreviousBlockHash } - if len(hdr.GetSignature()) == 0 { + if len(hdr.GetSignature()) == 0 && !equivalentMessagesEnabled { return process.ErrNilSignature } if len(hdr.GetRootHash()) == 0 { @@ -83,10 +92,43 @@ func checkHeaderHandler(hdr data.HeaderHandler) error { return process.ErrNilPrevRandSeed } + err := checkProofIntegrity(hdr, enableEpochsHandler) + if err != nil { + return err + } + return hdr.CheckFieldsForNil() } -func checkMetaShardInfo(shardInfo []data.ShardDataHandler, coordinator sharding.Coordinator) error { +func checkProofIntegrity(hdr data.HeaderHandler, enableEpochsHandler common.EnableEpochsHandler) error { + equivalentMessagesEnabled := enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hdr.GetEpoch()) + + prevHeaderProof := hdr.GetPreviousProof() + nilPreviousProof := check.IfNilReflect(prevHeaderProof) + missingProof := nilPreviousProof && equivalentMessagesEnabled + unexpectedProof := !nilPreviousProof && !equivalentMessagesEnabled + hasProof := !nilPreviousProof && equivalentMessagesEnabled + + if missingProof { + return process.ErrMissingHeaderProof + } + if unexpectedProof { + return process.ErrUnexpectedHeaderProof + } + if hasProof && isIncompleteProof(prevHeaderProof) { + return process.ErrInvalidHeaderProof + } + + return nil +} + +func checkMetaShardInfo( + shardInfo []data.ShardDataHandler, + coordinator sharding.Coordinator, + headerSigVerifier process.InterceptedHeaderSigVerifier, +) error { + wgProofsVerification := sync.WaitGroup{} + errChan := make(chan error, len(shardInfo)) for _, sd := range shardInfo { if sd.GetShardID() >= coordinator.NumberOfShards() && sd.GetShardID() != core.MetachainShardId { return process.ErrInvalidShardId @@ -96,9 +138,49 @@ func checkMetaShardInfo(shardInfo []data.ShardDataHandler, coordinator sharding. 
if err != nil { return err } + + wgProofsVerification.Add(1) + checkProofAsync(sd.GetPreviousProof(), headerSigVerifier, &wgProofsVerification, errChan) } - return nil + wgProofsVerification.Wait() + close(errChan) + + return <-errChan +} + +func checkProofAsync( + proof data.HeaderProofHandler, + headerSigVerifier process.InterceptedHeaderSigVerifier, + wg *sync.WaitGroup, + errChan chan error, +) { + go func(proof data.HeaderProofHandler) { + errCheckProof := checkProof(proof, headerSigVerifier) + if errCheckProof != nil { + errChan <- errCheckProof + } + + wg.Done() + }(proof) +} + +func checkProof(proof data.HeaderProofHandler, headerSigVerifier process.InterceptedHeaderSigVerifier) error { + if check.IfNilReflect(proof) { + return nil + } + + if isIncompleteProof(proof) { + return process.ErrInvalidHeaderProof + } + + return headerSigVerifier.VerifyHeaderProof(proof) +} + +func isIncompleteProof(proof data.HeaderProofHandler) bool { + return len(proof.GetAggregatedSignature()) == 0 || + len(proof.GetPubKeysBitmap()) == 0 || + len(proof.GetHeaderHash()) == 0 } func checkShardData(sd data.ShardDataHandler, coordinator sharding.Coordinator) error { diff --git a/process/block/interceptedBlocks/common_test.go b/process/block/interceptedBlocks/common_test.go index 02be37e9bde..321a41b6217 100644 --- a/process/block/interceptedBlocks/common_test.go +++ b/process/block/interceptedBlocks/common_test.go @@ -2,6 +2,8 @@ package interceptedBlocks import ( "errors" + "strconv" + "sync" "testing" "github.com/multiversx/mx-chain-core-go/data" @@ -9,6 +11,8 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/stretchr/testify/assert" ) @@ -19,10 +23,11 @@ func createDefaultBlockHeaderArgument() *ArgInterceptedBlockHeader { Hasher: &hashingMocks.HasherMock{}, Marshalizer: &mock.MarshalizerMock{}, HdrBuff: []byte("test buffer"), - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, ValidityAttester: &mock.ValidityAttesterStub{}, EpochStartTrigger: &mock.EpochStartTriggerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } return arg @@ -138,6 +143,17 @@ func TestCheckBlockHeaderArgument_NilValidityAttesterShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilValidityAttester, err) } +func TestCheckBlockHeaderArgument_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + arg := createDefaultBlockHeaderArgument() + arg.EnableEpochsHandler = nil + + err := checkBlockHeaderArgument(arg) + + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + func TestCheckBlockHeaderArgument_ShouldWork(t *testing.T) { t.Parallel() @@ -222,7 +238,7 @@ func TestCheckHeaderHandler_NilPubKeysBitmapShouldErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilPubKeysBitmap, err) } @@ -235,7 +251,7 @@ func TestCheckHeaderHandler_NilPrevHashShouldErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, 
process.ErrNilPreviousBlockHash, err) } @@ -248,7 +264,7 @@ func TestCheckHeaderHandler_NilSignatureShouldErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilSignature, err) } @@ -261,7 +277,7 @@ func TestCheckHeaderHandler_NilRootHashErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilRootHash, err) } @@ -274,7 +290,7 @@ func TestCheckHeaderHandler_NilRandSeedErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilRandSeed, err) } @@ -287,7 +303,7 @@ func TestCheckHeaderHandler_NilPrevRandSeedErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilPrevRandSeed, err) } @@ -301,7 +317,7 @@ func TestCheckHeaderHandler_CheckFieldsForNilErrors(t *testing.T) { return expectedErr } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, expectedErr, err) } @@ -311,7 +327,7 @@ func TestCheckHeaderHandler_ShouldWork(t *testing.T) { hdr := createDefaultHeaderHandler() - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Nil(t, err) } @@ -323,8 +339,8 @@ func TestCheckMetaShardInfo_WithNilOrEmptyShouldReturnNil(t *testing.T) { shardCoordinator := mock.NewOneShardCoordinatorMock() - err1 := checkMetaShardInfo(nil, shardCoordinator) - err2 := checkMetaShardInfo(make([]data.ShardDataHandler, 0), shardCoordinator) + err1 := checkMetaShardInfo(nil, shardCoordinator, &consensus.HeaderSigVerifierMock{}) + err2 := checkMetaShardInfo(make([]data.ShardDataHandler, 0), shardCoordinator, &consensus.HeaderSigVerifierMock{}) assert.Nil(t, err1) assert.Nil(t, err2) @@ -342,7 +358,7 @@ func TestCheckMetaShardInfo_WrongShardIdShouldErr(t *testing.T) { TxCount: 0, } - err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator) + err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator, &consensus.HeaderSigVerifierMock{}) assert.Equal(t, process.ErrInvalidShardId, err) } @@ -366,7 +382,7 @@ func TestCheckMetaShardInfo_WrongMiniblockSenderShardIdShouldErr(t *testing.T) { TxCount: 0, } - err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator) + err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator, &consensus.HeaderSigVerifierMock{}) assert.Equal(t, process.ErrInvalidShardId, err) } @@ -390,7 +406,7 @@ func TestCheckMetaShardInfo_WrongMiniblockReceiverShardIdShouldErr(t *testing.T) TxCount: 0, } - err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator) + err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator, &consensus.HeaderSigVerifierMock{}) assert.Equal(t, process.ErrInvalidShardId, err) } @@ -414,7 +430,7 @@ func TestCheckMetaShardInfo_ReservedPopulatedShouldErr(t *testing.T) { TxCount: 0, } - err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator) + err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator, &consensus.HeaderSigVerifierMock{}) assert.Equal(t, process.ErrReservedFieldInvalid, err) } @@ -437,15 +453,62 @@ func 
TestCheckMetaShardInfo_OkValsShouldWork(t *testing.T) { TxCount: 0, } - err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator) + err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator, &consensus.HeaderSigVerifierMock{}) assert.Nil(t, err) miniBlock.Reserved = []byte("r") sd.ShardMiniBlockHeaders = []block.MiniBlockHeader{miniBlock} - err = checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator) + err = checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator, &consensus.HeaderSigVerifierMock{}) assert.Nil(t, err) } +func TestCheckMetaShardInfo_FewShardDataErrorShouldReturnError(t *testing.T) { + t.Parallel() + + shardCoordinator := mock.NewOneShardCoordinatorMock() + miniBlock := block.MiniBlockHeader{ + Hash: make([]byte, 0), + ReceiverShardID: shardCoordinator.SelfId(), + SenderShardID: shardCoordinator.SelfId(), + TxCount: 0, + } + + calledCnt := 0 + mutCalled := sync.Mutex{} + providedRandomError := errors.New("random error") + sigVerifier := &consensus.HeaderSigVerifierMock{ + VerifyHeaderProofCalled: func(proofHandler data.HeaderProofHandler) error { + mutCalled.Lock() + defer mutCalled.Unlock() + + calledCnt++ + if calledCnt%5 == 0 { + return providedRandomError + } + + return nil + }, + } + + numShardData := 1000 + shardData := make([]data.ShardDataHandler, numShardData) + for i := 0; i < numShardData; i++ { + shardData[i] = &block.ShardData{ + ShardID: shardCoordinator.SelfId(), + HeaderHash: []byte("hash" + strconv.Itoa(i)), + ShardMiniBlockHeaders: []block.MiniBlockHeader{miniBlock}, + PreviousShardHeaderProof: &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig" + strconv.Itoa(i)), + HeaderHash: []byte("hash" + strconv.Itoa(i)), + }, + } + } + + err := checkMetaShardInfo(shardData, shardCoordinator, sigVerifier) + assert.Equal(t, providedRandomError, err) +} + //------- checkMiniBlocksHeaders func TestCheckMiniBlocksHeaders_WithNilOrEmptyShouldReturnNil(t *testing.T) { diff --git a/process/block/interceptedBlocks/errors.go b/process/block/interceptedBlocks/errors.go new file mode 100644 index 00000000000..afd3f50cf03 --- /dev/null +++ b/process/block/interceptedBlocks/errors.go @@ -0,0 +1,8 @@ +package interceptedBlocks + +import "errors" + +var ( + // ErrInvalidProof signals that an invalid proof has been provided + ErrInvalidProof = errors.New("invalid proof") +) diff --git a/process/block/interceptedBlocks/interceptedBlockHeader.go b/process/block/interceptedBlocks/interceptedBlockHeader.go index 81d78bef5c0..cde4be46170 100644 --- a/process/block/interceptedBlocks/interceptedBlockHeader.go +++ b/process/block/interceptedBlocks/interceptedBlockHeader.go @@ -6,9 +6,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ process.HdrValidatorHandler = (*InterceptedHeader)(nil) @@ -17,15 +19,16 @@ var _ process.InterceptedData = (*InterceptedHeader)(nil) // InterceptedHeader represents the wrapper over HeaderWrapper struct. 
// It implements Newer and Hashed interfaces type InterceptedHeader struct { - hdr data.HeaderHandler - sigVerifier process.InterceptedHeaderSigVerifier - integrityVerifier process.HeaderIntegrityVerifier - hasher hashing.Hasher - shardCoordinator sharding.Coordinator - hash []byte - isForCurrentShard bool - validityAttester process.ValidityAttester - epochStartTrigger process.EpochStartTriggerHandler + hdr data.HeaderHandler + sigVerifier process.InterceptedHeaderSigVerifier + integrityVerifier process.HeaderIntegrityVerifier + hasher hashing.Hasher + shardCoordinator sharding.Coordinator + hash []byte + isForCurrentShard bool + validityAttester process.ValidityAttester + epochStartTrigger process.EpochStartTriggerHandler + enableEpochsHandler common.EnableEpochsHandler } // NewInterceptedHeader creates a new instance of InterceptedHeader struct @@ -41,13 +44,14 @@ func NewInterceptedHeader(arg *ArgInterceptedBlockHeader) (*InterceptedHeader, e } inHdr := &InterceptedHeader{ - hdr: hdr, - hasher: arg.Hasher, - sigVerifier: arg.HeaderSigVerifier, - integrityVerifier: arg.HeaderIntegrityVerifier, - shardCoordinator: arg.ShardCoordinator, - validityAttester: arg.ValidityAttester, - epochStartTrigger: arg.EpochStartTrigger, + hdr: hdr, + hasher: arg.Hasher, + sigVerifier: arg.HeaderSigVerifier, + integrityVerifier: arg.HeaderIntegrityVerifier, + shardCoordinator: arg.ShardCoordinator, + validityAttester: arg.ValidityAttester, + epochStartTrigger: arg.EpochStartTrigger, + enableEpochsHandler: arg.EnableEpochsHandler, } inHdr.processFields(arg.HdrBuff) @@ -74,7 +78,11 @@ func (inHdr *InterceptedHeader) CheckValidity() error { return err } - err = inHdr.sigVerifier.VerifyRandSeedAndLeaderSignature(inHdr.hdr) + return inHdr.verifySignatures() +} + +func (inHdr *InterceptedHeader) verifySignatures() error { + err := inHdr.sigVerifier.VerifyRandSeedAndLeaderSignature(inHdr.hdr) if err != nil { return err } @@ -121,7 +129,7 @@ func (inHdr *InterceptedHeader) integrity() error { inHdr.epochStartTrigger.EpochFinalityAttestingRound()) } - err := checkHeaderHandler(inHdr.HeaderHandler()) + err := checkHeaderHandler(inHdr.HeaderHandler(), inHdr.enableEpochsHandler) if err != nil { return err } diff --git a/process/block/interceptedBlocks/interceptedBlockHeader_test.go b/process/block/interceptedBlocks/interceptedBlockHeader_test.go index a107e01dc3e..bb58691cd38 100644 --- a/process/block/interceptedBlocks/interceptedBlockHeader_test.go +++ b/process/block/interceptedBlocks/interceptedBlockHeader_test.go @@ -8,14 +8,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/mock" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var testMarshalizer = &mock.MarshalizerMock{} @@ -30,10 +35,11 @@ func createDefaultShardArgument() 
*interceptedBlocks.ArgInterceptedBlockHeader { ShardCoordinator: mock.NewOneShardCoordinatorMock(), Hasher: testHasher, Marshalizer: testMarshalizer, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, ValidityAttester: &mock.ValidityAttesterStub{}, EpochStartTrigger: &mock.EpochStartTriggerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } hdr := createMockShardHeader() @@ -47,10 +53,11 @@ func createDefaultShardArgumentWithV2Support() *interceptedBlocks.ArgIntercepted ShardCoordinator: mock.NewOneShardCoordinatorMock(), Hasher: testHasher, Marshalizer: &marshal.GogoProtoMarshalizer{}, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, ValidityAttester: &mock.ValidityAttesterStub{}, EpochStartTrigger: &mock.EpochStartTriggerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } hdr := createMockShardHeader() arg.HdrBuff, _ = arg.Marshalizer.Marshal(hdr) @@ -83,7 +90,7 @@ func createMockShardHeader() *dataBlock.Header { } } -//------- TestNewInterceptedHeader +// ------- TestNewInterceptedHeader func TestNewInterceptedHeader_NilArgumentShouldErr(t *testing.T) { t.Parallel() @@ -167,7 +174,7 @@ func TestNewInterceptedHeader_MetachainForThisShardShouldWork(t *testing.T) { assert.True(t, inHdr.IsForCurrentShard()) } -//------- CheckValidity +// ------- Verify func TestInterceptedHeader_CheckValidityNilPubKeyBitmapShouldErr(t *testing.T) { t.Parallel() @@ -194,7 +201,7 @@ func TestInterceptedHeader_CheckValidityLeaderSignatureNotCorrectShouldErr(t *te expectedErr := errors.New("expected err") buff, _ := marshaller.Marshal(hdr) - arg.HeaderSigVerifier = &mock.HeaderSigVerifierStub{ + arg.HeaderSigVerifier = &consensus.HeaderSigVerifierMock{ VerifyRandSeedAndLeaderSignatureCalled: func(header data.HeaderHandler) error { return expectedErr }, @@ -226,6 +233,53 @@ func TestInterceptedHeader_CheckValidityLeaderSignatureOkShouldWork(t *testing.T assert.Nil(t, err) } +func TestInterceptedHeader_CheckValidityLeaderSignatureOkWithFlagActiveShouldWork(t *testing.T) { + t.Parallel() + + headerHash := []byte("header hash") + arg := createDefaultShardArgumentWithV2Support() + arg.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + wasVerifySignatureCalled := false + providedPrevBitmap := []byte{1, 1, 1, 1} + providedPrevSig := []byte("provided sig") + arg.HeaderSigVerifier = &consensus.HeaderSigVerifierMock{ + VerifySignatureCalled: func(header data.HeaderHandler) error { + wasVerifySignatureCalled = true + proof := header.GetPreviousProof() + prevSig, prevBitmap := proof.GetAggregatedSignature(), proof.GetPubKeysBitmap() + assert.Equal(t, providedPrevBitmap, prevBitmap) + assert.Equal(t, providedPrevSig, prevSig) + return nil + }, + } + marshaller := arg.Marshalizer + hdr := &dataBlock.HeaderV2{ + Header: createMockShardHeader(), + ScheduledRootHash: []byte("root hash"), + ScheduledAccumulatedFees: big.NewInt(0), + ScheduledDeveloperFees: big.NewInt(0), + PreviousHeaderProof: &block.HeaderProof{ + PubKeysBitmap: providedPrevBitmap, + AggregatedSignature: providedPrevSig, + HeaderHash: headerHash, + }, + } + buff, _ := marshaller.Marshal(hdr) + + 
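An aside on the data shape this test exercises (illustrative only, not part of the diff): once EquivalentMessagesFlag is active, the previous block's aggregated signature and public keys bitmap travel inside the current header and are read back through GetPreviousProof(), which is what the HeaderSigVerifierMock above asserts on. A minimal completeness check over such a proof could look as follows; the previousProofIsComplete helper is hypothetical and simply mirrors the verifyProof/isIncompleteProof helpers added to metablock.go further below.

```go
package interceptedBlocks

import (
	"github.com/multiversx/mx-chain-core-go/core/check"
	"github.com/multiversx/mx-chain-core-go/data"
)

// previousProofIsComplete reports whether the header carries a usable previous
// proof: a non-nil proof with a non-empty aggregated signature, bitmap and hash.
func previousProofIsComplete(header data.HeaderHandler) bool {
	proof := header.GetPreviousProof()
	if check.IfNilReflect(proof) {
		return false
	}

	return len(proof.GetAggregatedSignature()) > 0 &&
		len(proof.GetPubKeysBitmap()) > 0 &&
		len(proof.GetHeaderHash()) > 0
}
```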
arg.HdrBuff = buff + inHdr, err := interceptedBlocks.NewInterceptedHeader(arg) + require.Nil(t, err) + require.NotNil(t, inHdr) + + err = inHdr.CheckValidity() + assert.Nil(t, err) + assert.True(t, wasVerifySignatureCalled) +} + func TestInterceptedHeader_ErrorInMiniBlockShouldErr(t *testing.T) { t.Parallel() @@ -305,7 +359,7 @@ func TestInterceptedHeader_CheckAgainstFinalHeaderErrorsShouldErr(t *testing.T) assert.Equal(t, expectedErr, err) } -//------- getters +// ------- getters func TestInterceptedHeader_Getters(t *testing.T) { t.Parallel() @@ -318,7 +372,7 @@ func TestInterceptedHeader_Getters(t *testing.T) { assert.Equal(t, hash, inHdr.Hash()) } -//------- IsInterfaceNil +// ------- IsInterfaceNil func TestInterceptedHeader_IsInterfaceNil(t *testing.T) { t.Parallel() diff --git a/process/block/interceptedBlocks/interceptedEquivalentProof.go b/process/block/interceptedBlocks/interceptedEquivalentProof.go new file mode 100644 index 00000000000..a7937a5aef2 --- /dev/null +++ b/process/block/interceptedBlocks/interceptedEquivalentProof.go @@ -0,0 +1,171 @@ +package interceptedBlocks + +import ( + "fmt" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/dataRetriever" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/sharding" + logger "github.com/multiversx/mx-chain-logger-go" +) + +const interceptedEquivalentProofType = "intercepted equivalent proof" + +// ArgInterceptedEquivalentProof is the argument used in the intercepted equivalent proof constructor +type ArgInterceptedEquivalentProof struct { + DataBuff []byte + Marshaller marshal.Marshalizer + ShardCoordinator sharding.Coordinator + HeaderSigVerifier consensus.HeaderSigVerifier + Proofs dataRetriever.ProofsPool +} + +type interceptedEquivalentProof struct { + proof *block.HeaderProof + isForCurrentShard bool + headerSigVerifier consensus.HeaderSigVerifier + proofsPool dataRetriever.ProofsPool +} + +// NewInterceptedEquivalentProof returns a new instance of interceptedEquivalentProof +func NewInterceptedEquivalentProof(args ArgInterceptedEquivalentProof) (*interceptedEquivalentProof, error) { + err := checkArgInterceptedEquivalentProof(args) + if err != nil { + return nil, err + } + + equivalentProof, err := createEquivalentProof(args.Marshaller, args.DataBuff) + if err != nil { + return nil, err + } + + return &interceptedEquivalentProof{ + proof: equivalentProof, + isForCurrentShard: extractIsForCurrentShard(args.ShardCoordinator, equivalentProof), + headerSigVerifier: args.HeaderSigVerifier, + proofsPool: args.Proofs, + }, nil +} + +func checkArgInterceptedEquivalentProof(args ArgInterceptedEquivalentProof) error { + if len(args.DataBuff) == 0 { + return process.ErrNilBuffer + } + if check.IfNil(args.Marshaller) { + return process.ErrNilMarshalizer + } + if check.IfNil(args.ShardCoordinator) { + return process.ErrNilShardCoordinator + } + if check.IfNil(args.HeaderSigVerifier) { + return process.ErrNilHeaderSigVerifier + } + if check.IfNil(args.Proofs) { + return process.ErrNilProofsPool + } + + return nil +} + +func createEquivalentProof(marshaller marshal.Marshalizer, buff []byte) (*block.HeaderProof, error) { + headerProof := 
&block.HeaderProof{} + err := marshaller.Unmarshal(headerProof, buff) + if err != nil { + return nil, err + } + + log.Trace("interceptedEquivalentProof successfully created", + "header hash", logger.DisplayByteSlice(headerProof.HeaderHash), + "header shard", headerProof.HeaderShardId, + "header epoch", headerProof.HeaderEpoch, + "header nonce", headerProof.HeaderNonce, + "bitmap", logger.DisplayByteSlice(headerProof.PubKeysBitmap), + "signature", logger.DisplayByteSlice(headerProof.AggregatedSignature), + ) + + return headerProof, nil +} + +func extractIsForCurrentShard(shardCoordinator sharding.Coordinator, equivalentProof *block.HeaderProof) bool { + proofShardId := equivalentProof.GetHeaderShardId() + if proofShardId == core.MetachainShardId { + return true + } + + return proofShardId == shardCoordinator.SelfId() +} + +// CheckValidity checks if the received proof is valid +func (iep *interceptedEquivalentProof) CheckValidity() error { + err := iep.integrity() + if err != nil { + return err + } + + ok := iep.proofsPool.HasProof(iep.proof.GetHeaderShardId(), iep.proof.GetHeaderHash()) + if ok { + return proofscache.ErrAlreadyExistingEquivalentProof + } + + return iep.headerSigVerifier.VerifyHeaderProof(iep.proof) +} + +func (iep *interceptedEquivalentProof) integrity() error { + isProofValid := len(iep.proof.AggregatedSignature) > 0 && + len(iep.proof.PubKeysBitmap) > 0 && + len(iep.proof.HeaderHash) > 0 + if !isProofValid { + return ErrInvalidProof + } + + return nil +} + +// GetProof returns the underlying intercepted header proof +func (iep *interceptedEquivalentProof) GetProof() data.HeaderProofHandler { + return iep.proof +} + +// IsForCurrentShard returns true if the equivalent proof should be processed by the current shard +func (iep *interceptedEquivalentProof) IsForCurrentShard() bool { + return iep.isForCurrentShard +} + +// Hash returns the header hash the proof belongs to +func (iep *interceptedEquivalentProof) Hash() []byte { + return iep.proof.HeaderHash +} + +// Type returns the type of this intercepted data +func (iep *interceptedEquivalentProof) Type() string { + return interceptedEquivalentProofType +} + +// Identifiers returns the identifiers used in requests +func (iep *interceptedEquivalentProof) Identifiers() [][]byte { + return [][]byte{iep.proof.HeaderHash} +} + +// String returns the proof's most important fields as string +func (iep *interceptedEquivalentProof) String() string { + return fmt.Sprintf("bitmap=%s, signature=%s, hash=%s, epoch=%d, shard=%d, nonce=%d", + logger.DisplayByteSlice(iep.proof.PubKeysBitmap), + logger.DisplayByteSlice(iep.proof.AggregatedSignature), + logger.DisplayByteSlice(iep.proof.HeaderHash), + iep.proof.HeaderEpoch, + iep.proof.HeaderShardId, + iep.proof.HeaderNonce, + ) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (iep *interceptedEquivalentProof) IsInterfaceNil() bool { + return iep == nil +} diff --git a/process/block/interceptedBlocks/interceptedEquivalentProof_test.go b/process/block/interceptedBlocks/interceptedEquivalentProof_test.go new file mode 100644 index 00000000000..b0a8cd6c9c9 --- /dev/null +++ b/process/block/interceptedBlocks/interceptedEquivalentProof_test.go @@ -0,0 +1,260 @@ +package interceptedBlocks + +import ( + "bytes" + "errors" + "fmt" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/consensus/mock" + proofscache 
"github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/require" +) + +var ( + expectedErr = errors.New("expected error") + testMarshaller = &marshallerMock.MarshalizerMock{} +) + +func createMockDataBuff() []byte { + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderEpoch: 123, + HeaderNonce: 345, + HeaderShardId: 0, + } + + dataBuff, _ := testMarshaller.Marshal(proof) + return dataBuff +} + +func createMockArgInterceptedEquivalentProof() ArgInterceptedEquivalentProof { + return ArgInterceptedEquivalentProof{ + DataBuff: createMockDataBuff(), + Marshaller: testMarshaller, + ShardCoordinator: &mock.ShardCoordinatorMock{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + Proofs: &dataRetriever.ProofsPoolMock{}, + } +} + +func TestInterceptedEquivalentProof_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var iep *interceptedEquivalentProof + require.True(t, iep.IsInterfaceNil()) + + iep, _ = NewInterceptedEquivalentProof(createMockArgInterceptedEquivalentProof()) + require.False(t, iep.IsInterfaceNil()) +} + +func TestNewInterceptedEquivalentProof(t *testing.T) { + t.Parallel() + + t.Run("nil DataBuff should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.DataBuff = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilBuffer, err) + require.Nil(t, iep) + }) + t.Run("nil Marshaller should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Marshaller = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilMarshalizer, err) + require.Nil(t, iep) + }) + t.Run("nil ShardCoordinator should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.ShardCoordinator = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilShardCoordinator, err) + require.Nil(t, iep) + }) + t.Run("nil HeaderSigVerifier should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.HeaderSigVerifier = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilHeaderSigVerifier, err) + require.Nil(t, iep) + }) + t.Run("nil proofs pool should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Proofs = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilProofsPool, err) + require.Nil(t, iep) + }) + t.Run("unmarshal error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Marshaller = &marshallerMock.MarshalizerStub{ + UnmarshalCalled: func(obj interface{}, buff []byte) error { + return expectedErr + }, + } + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, expectedErr, err) + require.Nil(t, iep) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + iep, err := NewInterceptedEquivalentProof(createMockArgInterceptedEquivalentProof()) + require.NoError(t, err) + require.NotNil(t, iep) + }) +} 
+ +func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) { + t.Parallel() + + t.Run("invalid proof should error", func(t *testing.T) { + t.Parallel() + + // no header hash + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.Equal(t, ErrInvalidProof, err) + }) + + t.Run("already exiting proof should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Proofs = &dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + } + + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.Equal(t, proofscache.ErrAlreadyExistingEquivalentProof, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + iep, err := NewInterceptedEquivalentProof(createMockArgInterceptedEquivalentProof()) + require.NoError(t, err) + + err = iep.CheckValidity() + require.NoError(t, err) + }) +} + +func TestInterceptedEquivalentProof_IsForCurrentShard(t *testing.T) { + t.Parallel() + + t.Run("meta should return true", func(t *testing.T) { + t.Parallel() + + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderShardId: core.MetachainShardId, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + args.ShardCoordinator = &mock.ShardCoordinatorMock{ShardID: core.MetachainShardId} + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.True(t, iep.IsForCurrentShard()) + }) + t.Run("self shard id return true", func(t *testing.T) { + t.Parallel() + + selfShardId := uint32(1234) + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderShardId: selfShardId, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + args.ShardCoordinator = &mock.ShardCoordinatorMock{ShardID: selfShardId} + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.True(t, iep.IsForCurrentShard()) + }) + t.Run("other shard id return true", func(t *testing.T) { + t.Parallel() + + selfShardId := uint32(1234) + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderShardId: selfShardId, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.False(t, iep.IsForCurrentShard()) + }) +} + +func TestInterceptedEquivalentProof_Getters(t *testing.T) { + t.Parallel() + + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderEpoch: 123, + HeaderNonce: 345, + HeaderShardId: 0, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.Equal(t, proof, iep.GetProof()) // pointer testing + require.True(t, bytes.Equal(proof.HeaderHash, 
iep.Hash())) + require.Equal(t, [][]byte{proof.HeaderHash}, iep.Identifiers()) + require.Equal(t, interceptedEquivalentProofType, iep.Type()) + expectedStr := fmt.Sprintf("bitmap=%s, signature=%s, hash=%s, epoch=123, shard=0, nonce=345", + logger.DisplayByteSlice(proof.PubKeysBitmap), + logger.DisplayByteSlice(proof.AggregatedSignature), + logger.DisplayByteSlice(proof.HeaderHash)) + require.Equal(t, expectedStr, iep.String()) +} diff --git a/process/block/interceptedBlocks/interceptedMetaBlockHeader.go b/process/block/interceptedBlocks/interceptedMetaBlockHeader.go index 415e2da3967..c3f92781e7e 100644 --- a/process/block/interceptedBlocks/interceptedMetaBlockHeader.go +++ b/process/block/interceptedBlocks/interceptedMetaBlockHeader.go @@ -8,9 +8,11 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ process.HdrValidatorHandler = (*InterceptedMetaHeader)(nil) @@ -20,14 +22,15 @@ var log = logger.GetOrCreate("process/block/interceptedBlocks") // InterceptedMetaHeader represents the wrapper over the meta block header struct type InterceptedMetaHeader struct { - hdr data.MetaHeaderHandler - sigVerifier process.InterceptedHeaderSigVerifier - integrityVerifier process.HeaderIntegrityVerifier - hasher hashing.Hasher - shardCoordinator sharding.Coordinator - hash []byte - validityAttester process.ValidityAttester - epochStartTrigger process.EpochStartTriggerHandler + hdr data.MetaHeaderHandler + sigVerifier process.InterceptedHeaderSigVerifier + integrityVerifier process.HeaderIntegrityVerifier + hasher hashing.Hasher + shardCoordinator sharding.Coordinator + hash []byte + validityAttester process.ValidityAttester + epochStartTrigger process.EpochStartTriggerHandler + enableEpochsHandler common.EnableEpochsHandler } // NewInterceptedMetaHeader creates a new instance of InterceptedMetaHeader struct @@ -43,13 +46,14 @@ func NewInterceptedMetaHeader(arg *ArgInterceptedBlockHeader) (*InterceptedMetaH } inHdr := &InterceptedMetaHeader{ - hdr: hdr, - hasher: arg.Hasher, - sigVerifier: arg.HeaderSigVerifier, - integrityVerifier: arg.HeaderIntegrityVerifier, - shardCoordinator: arg.ShardCoordinator, - validityAttester: arg.ValidityAttester, - epochStartTrigger: arg.EpochStartTrigger, + hdr: hdr, + hasher: arg.Hasher, + sigVerifier: arg.HeaderSigVerifier, + integrityVerifier: arg.HeaderIntegrityVerifier, + shardCoordinator: arg.ShardCoordinator, + validityAttester: arg.ValidityAttester, + epochStartTrigger: arg.EpochStartTrigger, + enableEpochsHandler: arg.EnableEpochsHandler, } inHdr.processFields(arg.HdrBuff) @@ -137,12 +141,12 @@ func (imh *InterceptedMetaHeader) isMetaHeaderEpochOutOfRange() bool { // integrity checks the integrity of the meta header block wrapper func (imh *InterceptedMetaHeader) integrity() error { - err := checkHeaderHandler(imh.HeaderHandler()) + err := checkHeaderHandler(imh.HeaderHandler(), imh.enableEpochsHandler) if err != nil { return err } - err = checkMetaShardInfo(imh.hdr.GetShardInfoHandlers(), imh.shardCoordinator) + err = checkMetaShardInfo(imh.hdr.GetShardInfoHandlers(), imh.shardCoordinator, imh.sigVerifier) if err != nil { return err } diff --git a/process/block/interceptedBlocks/interceptedMetaBlockHeader_test.go 
b/process/block/interceptedBlocks/interceptedMetaBlockHeader_test.go index 99fc49d1dd3..b895a6a81cc 100644 --- a/process/block/interceptedBlocks/interceptedMetaBlockHeader_test.go +++ b/process/block/interceptedBlocks/interceptedMetaBlockHeader_test.go @@ -8,11 +8,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/mock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" ) func createDefaultMetaArgument() *interceptedBlocks.ArgInterceptedBlockHeader { @@ -20,7 +23,7 @@ func createDefaultMetaArgument() *interceptedBlocks.ArgInterceptedBlockHeader { ShardCoordinator: mock.NewOneShardCoordinatorMock(), Hasher: testHasher, Marshalizer: testMarshalizer, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, ValidityAttester: &mock.ValidityAttesterStub{}, EpochStartTrigger: &mock.EpochStartTriggerStub{ @@ -28,6 +31,7 @@ func createDefaultMetaArgument() *interceptedBlocks.ArgInterceptedBlockHeader { return hdrEpoch }, }, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } hdr := createMockMetaHeader() @@ -204,7 +208,7 @@ func TestInterceptedMetaHeader_CheckValidityLeaderSignatureNotCorrectShouldErr(t buff, _ := testMarshalizer.Marshal(hdr) arg := createDefaultMetaArgument() - arg.HeaderSigVerifier = &mock.HeaderSigVerifierStub{ + arg.HeaderSigVerifier = &consensus.HeaderSigVerifierMock{ VerifyRandSeedAndLeaderSignatureCalled: func(header data.HeaderHandler) error { return expectedErr }, diff --git a/process/block/interceptedBlocks/interceptedMiniblock_test.go b/process/block/interceptedBlocks/interceptedMiniblock_test.go index 57d53ec251d..46b489b259d 100644 --- a/process/block/interceptedBlocks/interceptedMiniblock_test.go +++ b/process/block/interceptedBlocks/interceptedMiniblock_test.go @@ -5,10 +5,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/mock" - "github.com/stretchr/testify/assert" ) func createDefaultMiniblockArgument() *interceptedBlocks.ArgInterceptedMiniblock { @@ -69,7 +70,7 @@ func TestNewInterceptedMiniblock_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- CheckValidity +//------- Verify func TestInterceptedMiniblock_InvalidReceiverShardIdShouldErr(t *testing.T) { t.Parallel() diff --git a/process/block/metablock.go b/process/block/metablock.go index fb53f1207d7..fbd963f4da4 100644 --- a/process/block/metablock.go +++ b/process/block/metablock.go @@ -139,6 +139,7 @@ func NewMetaProcessor(arguments ArgMetaProcessor) (*metaProcessor, error) { managedPeersHolder: arguments.ManagedPeersHolder, sentSignaturesTracker: arguments.SentSignaturesTracker, extraDelayRequestBlockInfo: 
time.Duration(arguments.Config.EpochStartConfig.ExtraDelayForRequestBlockInfoInMilliseconds) * time.Millisecond, + proofsPool: arguments.DataComponents.Datapool().Proofs(), } mp := metaProcessor{ @@ -340,6 +341,11 @@ func (mp *metaProcessor) ProcessBlock( } } + err = mp.checkProofsForShardData(header) + if err != nil { + return err + } + defer func() { go mp.checkAndRequestIfShardHeadersMissing() }() @@ -410,6 +416,23 @@ func (mp *metaProcessor) ProcessBlock( return nil } +func (mp *metaProcessor) checkProofsForShardData(header *block.MetaBlock) error { + if !mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.Epoch) { + return nil + } + + for _, shardData := range header.ShardInfo { + // TODO: consider the validation of the proof: + // compare the one from proofsPool with what shardData.CurrentSignature and shardData.CurrentPubKeysBitmap hold + // if they are different, verify the proof received on header + if !mp.proofsPool.HasProof(shardData.ShardID, shardData.HeaderHash) { + return fmt.Errorf("%w for header hash %s", process.ErrMissingHeaderProof, hex.EncodeToString(shardData.HeaderHash)) + } + } + + return nil +} + func (mp *metaProcessor) processEpochStartMetaBlock( header *block.MetaBlock, body *block.Body, @@ -1081,8 +1104,23 @@ func (mp *metaProcessor) createAndProcessCrossMiniBlocksDstMe( continue } + shouldCheckProof := mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, currShardHdr.GetEpoch()) + if shouldCheckProof { + hasProofForHdr := mp.proofsPool.HasProof(currShardHdr.GetShardID(), orderedHdrsHashes[i]) + if !hasProofForHdr { + log.Trace("no proof for shard header", + "shard", currShardHdr.GetShardID(), + "hash", logger.DisplayByteSlice(orderedHdrsHashes[i]), + ) + continue + } + } + if len(currShardHdr.GetMiniBlockHeadersWithDst(mp.shardCoordinator.SelfId())) == 0 { - mp.hdrsForCurrBlock.hdrHashAndInfo[string(orderedHdrsHashes[i])] = &hdrInfo{hdr: currShardHdr, usedInBlock: true} + mp.hdrsForCurrBlock.hdrHashAndInfo[string(orderedHdrsHashes[i])] = &hdrInfo{ + hdr: currShardHdr, + usedInBlock: true, + } hdrsAdded++ hdrsAddedForShard[currShardHdr.GetShardID()]++ lastShardHdr[currShardHdr.GetShardID()] = currShardHdr @@ -1121,7 +1159,10 @@ func (mp *metaProcessor) createAndProcessCrossMiniBlocksDstMe( miniBlocks = append(miniBlocks, currMBProcessed...) 
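The hunks above add the proof gate for shard header notarization: ProcessBlock fails fast through checkProofsForShardData when a referenced shard header has no proof in the pool, while block creation simply skips such headers until their proof arrives. A condensed sketch of that inclusion rule follows (the canIncludeShardHeader helper is hypothetical and only illustrates the predicate used in both places).

```go
package block

import (
	"github.com/multiversx/mx-chain-core-go/data"

	"github.com/multiversx/mx-chain-go/common"
	"github.com/multiversx/mx-chain-go/dataRetriever"
)

// canIncludeShardHeader condenses the rule applied by checkProofsForShardData
// and createAndProcessCrossMiniBlocksDstMe: once the equivalent messages flag
// is active for the header's epoch, the shard header may only be notarized if
// the proofs pool already holds an equivalent proof for its hash.
func canIncludeShardHeader(
	enableEpochsHandler common.EnableEpochsHandler,
	proofsPool dataRetriever.ProofsPool,
	hdr data.HeaderHandler,
	hdrHash []byte,
) bool {
	if !enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hdr.GetEpoch()) {
		// pre-activation path: finality is still attested by subsequent headers
		return true
	}

	return proofsPool.HasProof(hdr.GetShardID(), hdrHash)
}
```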
txsAdded += currTxsAdded - mp.hdrsForCurrBlock.hdrHashAndInfo[string(orderedHdrsHashes[i])] = &hdrInfo{hdr: currShardHdr, usedInBlock: true} + mp.hdrsForCurrBlock.hdrHashAndInfo[string(orderedHdrsHashes[i])] = &hdrInfo{ + hdr: currShardHdr, + usedInBlock: true, + } hdrsAdded++ hdrsAddedForShard[currShardHdr.GetShardID()]++ @@ -1284,7 +1325,20 @@ func (mp *metaProcessor) CommitBlock( mp.lastRestartNonce = header.GetNonce() } - mp.updateState(lastMetaBlock, lastMetaBlockHash) + finalMetaBlock := lastMetaBlock + finalMetaBlockHash := lastMetaBlockHash + isBlockAfterEquivalentMessagesFlag := !check.IfNil(finalMetaBlock) && + mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, finalMetaBlock.GetEpoch()) + if isBlockAfterEquivalentMessagesFlag { + // for the first block we need to update both the state of the previous one and for current + if common.IsEpochChangeBlockForFlagActivation(header, mp.enableEpochsHandler, common.EquivalentMessagesFlag) { + mp.updateState(lastMetaBlock, lastMetaBlockHash) + } + + finalMetaBlock = header + finalMetaBlockHash = headerHash + } + mp.updateState(finalMetaBlock, finalMetaBlockHash) committedRootHash, err := mp.accountsDB[state.UserAccountsState].RootHash() if err != nil { @@ -1298,12 +1352,12 @@ func (mp *metaProcessor) CommitBlock( mp.blockChain.SetCurrentBlockHeaderHash(headerHash) - if !check.IfNil(lastMetaBlock) && lastMetaBlock.IsStartOfEpochBlock() { + if !check.IfNil(finalMetaBlock) && finalMetaBlock.IsStartOfEpochBlock() { mp.blockTracker.CleanupInvalidCrossHeaders(header.Epoch, header.Round) } // TODO: Should be sent also validatorInfoTxs alongside rewardsTxs -> mp.validatorInfoCreator.GetValidatorInfoTxs(body) ? - mp.indexBlock(header, headerHash, body, lastMetaBlock, notarizedHeadersHashes, rewardsTxs) + mp.indexBlock(header, headerHash, body, finalMetaBlock, notarizedHeadersHashes, rewardsTxs) mp.recordBlockInHistory(headerHash, headerHandler, bodyHandler) highestFinalBlockNonce := mp.forkDetector.GetHighestFinalBlockNonce() @@ -1723,7 +1777,10 @@ func (mp *metaProcessor) getLastCrossNotarizedShardHdrs() (map[uint32]data.Heade log.Debug("lastCrossNotarizedHeader for shard", "shardID", shardID, "hash", hash) lastCrossNotarizedHeader[shardID] = lastCrossNotarizedHeaderForShard usedInBlock := mp.isGenesisShardBlockAndFirstMeta(lastCrossNotarizedHeaderForShard.GetNonce()) - mp.hdrsForCurrBlock.hdrHashAndInfo[string(hash)] = &hdrInfo{hdr: lastCrossNotarizedHeaderForShard, usedInBlock: usedInBlock} + mp.hdrsForCurrBlock.hdrHashAndInfo[string(hash)] = &hdrInfo{ + hdr: lastCrossNotarizedHeaderForShard, + usedInBlock: usedInBlock, + } } return lastCrossNotarizedHeader, nil @@ -1737,7 +1794,10 @@ func (mp *metaProcessor) checkShardHeadersValidity(metaHdr *block.MetaBlock) (ma return nil, err } - usedShardHdrs := mp.sortHeadersForCurrentBlockByNonce(true) + usedShardHdrs, err := mp.sortHeadersForCurrentBlockByNonce(true) + if err != nil { + return nil, err + } highestNonceHdrs := make(map[uint32]data.HeaderHandler, len(usedShardHdrs)) if len(usedShardHdrs) == 0 { @@ -1785,6 +1845,11 @@ func (mp *metaProcessor) checkShardHeadersValidity(metaHdr *block.MetaBlock) (ma return nil, process.ErrDeveloperFeesDoNotMatch } + err = verifyProof(shardData.GetPreviousProof()) + if err != nil { + return nil, err + } + mapMiniBlockHeadersInMetaBlock := make(map[string]struct{}) for _, shardMiniBlockHdr := range shardData.ShardMiniBlockHeaders { mapMiniBlockHeadersInMetaBlock[string(shardMiniBlockHdr.Hash)] = struct{}{} @@ -1800,6 +1865,24 @@ func (mp 
*metaProcessor) checkShardHeadersValidity(metaHdr *block.MetaBlock) (ma return highestNonceHdrs, nil } +func verifyProof(proof data.HeaderProofHandler) error { + if check.IfNilReflect(proof) { + return nil + } + + if isIncompleteProof(proof) { + return process.ErrInvalidHeaderProof + } + + return nil +} + +func isIncompleteProof(proof data.HeaderProofHandler) bool { + return len(proof.GetAggregatedSignature()) == 0 || + len(proof.GetPubKeysBitmap()) == 0 || + len(proof.GetHeaderHash()) == 0 +} + func (mp *metaProcessor) getFinalMiniBlockHeaders(miniBlockHeaderHandlers []data.MiniBlockHeaderHandler) []data.MiniBlockHeaderHandler { if !mp.enableEpochsHandler.IsFlagEnabled(common.ScheduledMiniBlocksFlag) { return miniBlockHeaderHandlers @@ -1822,7 +1905,10 @@ func (mp *metaProcessor) getFinalMiniBlockHeaders(miniBlockHeaderHandlers []data func (mp *metaProcessor) checkShardHeadersFinality( highestNonceHdrs map[uint32]data.HeaderHandler, ) error { - finalityAttestingShardHdrs := mp.sortHeadersForCurrentBlockByNonce(false) + finalityAttestingShardHdrs, err := mp.sortHeadersForCurrentBlockByNonce(false) + if err != nil { + return err + } var errFinal error @@ -1840,7 +1926,20 @@ func (mp *metaProcessor) checkShardHeadersFinality( // verify if there are "K" block after current to make this one final nextBlocksVerified := uint32(0) + isNotarizedBasedOnProofs := false for _, shardHdr := range finalityAttestingShardHdrs[shardId] { + var errCheckProof error + isNotarizedBasedOnProofs, errCheckProof = mp.checkShardHeaderFinalityBasedOnProofs(shardHdr, shardId) + // if the header is notarized based on proofs and there is no error, break the loop + // if there is any error, forward it, header is not final + if isNotarizedBasedOnProofs { + if errCheckProof != nil { + return errCheckProof + } + + break + } + if nextBlocksVerified >= mp.shardBlockFinality { break } @@ -1859,7 +1958,7 @@ func (mp *metaProcessor) checkShardHeadersFinality( } } - if nextBlocksVerified < mp.shardBlockFinality { + if nextBlocksVerified < mp.shardBlockFinality && !isNotarizedBasedOnProofs { go mp.requestHandler.RequestShardHeaderByNonce(lastVerifiedHdr.GetShardID(), lastVerifiedHdr.GetNonce()) go mp.requestHandler.RequestShardHeaderByNonce(lastVerifiedHdr.GetShardID(), lastVerifiedHdr.GetNonce()+1) errFinal = process.ErrHeaderNotFinal @@ -1869,6 +1968,24 @@ func (mp *metaProcessor) checkShardHeadersFinality( return errFinal } +func (mp *metaProcessor) checkShardHeaderFinalityBasedOnProofs(shardHdr data.HeaderHandler, shardId uint32) (bool, error) { + if !mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, shardHdr.GetEpoch()) { + return false, nil + } + + marshalledHeader, err := mp.marshalizer.Marshal(shardHdr) + if err != nil { + return true, err + } + + headerHash := mp.hasher.Compute(string(marshalledHeader)) + if !mp.proofsPool.HasProof(shardId, headerHash) { + return true, process.ErrHeaderNotFinal + } + + return true, nil +} + // receivedShardHeader is a call back function which is called when a new header // is added in the headers pool func (mp *metaProcessor) receivedShardHeader(headerHandler data.HeaderHandler, shardHeaderHash []byte) { @@ -1900,7 +2017,8 @@ func (mp *metaProcessor) receivedShardHeader(headerHandler data.HeaderHandler, s } } - if mp.hdrsForCurrBlock.missingHdrs == 0 { + shouldConsiderProofsForNotarization := mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, shardHeader.GetEpoch()) + if mp.hdrsForCurrBlock.missingHdrs == 0 && !shouldConsiderProofsForNotarization {
mp.hdrsForCurrBlock.missingFinalityAttestingHdrs = mp.requestMissingFinalityAttestingShardHeaders() if mp.hdrsForCurrBlock.missingFinalityAttestingHdrs == 0 { log.Debug("received all missing finality attesting shard headers") @@ -1955,6 +2073,7 @@ func (mp *metaProcessor) computeExistingAndRequestMissingShardHeaders(metaBlock mp.hdrsForCurrBlock.mutHdrsForBlock.Lock() defer mp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() + notarizedShardHdrsBasedOnProofs := 0 for _, shardData := range metaBlock.ShardInfo { if shardData.Nonce == mp.genesisNonce { lastCrossNotarizedHeaderForShard, hash, err := mp.blockTracker.GetLastCrossNotarizedHeader(shardData.ShardID) @@ -1982,6 +2101,7 @@ func (mp *metaProcessor) computeExistingAndRequestMissingShardHeaders(metaBlock hdr: nil, usedInBlock: true, } + go mp.requestHandler.RequestShardHeader(shardData.ShardID, shardData.HeaderHash) continue } @@ -1994,9 +2114,14 @@ func (mp *metaProcessor) computeExistingAndRequestMissingShardHeaders(metaBlock if hdr.GetNonce() > mp.hdrsForCurrBlock.highestHdrNonce[shardData.ShardID] { mp.hdrsForCurrBlock.highestHdrNonce[shardData.ShardID] = hdr.GetNonce() } + + if mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hdr.GetEpoch()) { + notarizedShardHdrsBasedOnProofs++ + } } - if mp.hdrsForCurrBlock.missingHdrs == 0 { + shouldRequestMissingFinalityAttestingShardHeaders := notarizedShardHdrsBasedOnProofs != len(metaBlock.ShardInfo) + if mp.hdrsForCurrBlock.missingHdrs == 0 && shouldRequestMissingFinalityAttestingShardHeaders { mp.hdrsForCurrBlock.missingFinalityAttestingHdrs = mp.requestMissingFinalityAttestingShardHeaders() } @@ -2017,6 +2142,13 @@ func (mp *metaProcessor) createShardInfo() ([]data.ShardDataHandler, error) { continue } + isBlockAfterEquivalentMessagesFlag := !check.IfNil(headerInfo.hdr) && + mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerInfo.hdr.GetEpoch()) + hasMissingShardHdrProof := isBlockAfterEquivalentMessagesFlag && !mp.proofsPool.HasProof(headerInfo.hdr.GetShardID(), []byte(hdrHash)) + if hasMissingShardHdrProof { + return nil, fmt.Errorf("%w for shard header with hash %s", process.ErrMissingHeaderProof, hdrHash) + } + shardHdr, ok := headerInfo.hdr.(data.ShardHeaderHandler) if !ok { return nil, process.ErrWrongTypeAssertion @@ -2031,6 +2163,13 @@ func (mp *metaProcessor) createShardInfo() ([]data.ShardDataHandler, error) { shardData.Nonce = shardHdr.GetNonce() shardData.PrevRandSeed = shardHdr.GetPrevRandSeed() shardData.PubKeysBitmap = shardHdr.GetPubKeysBitmap() + if mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, shardHdr.GetEpoch()) { + prevProof := shardHdr.GetPreviousProof() + err := shardData.SetPreviousProof(prevProof) + if err != nil { + return nil, err + } + } shardData.NumPendingMiniBlocks = uint32(len(mp.pendingMiniBlocksHandler.GetPendingMiniBlocks(shardData.ShardID))) header, _, err := mp.blockTracker.GetLastSelfNotarizedHeader(shardHdr.GetShardID()) if err != nil { @@ -2268,7 +2407,10 @@ func (mp *metaProcessor) prepareBlockHeaderInternalMapForValidatorProcessor() { } mp.hdrsForCurrBlock.mutHdrsForBlock.Lock() - mp.hdrsForCurrBlock.hdrHashAndInfo[string(currentBlockHeaderHash)] = &hdrInfo{false, currentBlockHeader} + mp.hdrsForCurrBlock.hdrHashAndInfo[string(currentBlockHeaderHash)] = &hdrInfo{ + usedInBlock: false, + hdr: currentBlockHeader, + } mp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() } diff --git a/process/block/metrics.go b/process/block/metrics.go index ce29ddb23f8..94ab2e00276 100644 ---
a/process/block/metrics.go +++ b/process/block/metrics.go @@ -12,11 +12,12 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" outportcore "github.com/multiversx/mx-chain-core-go/data/outport" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/outport" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - logger "github.com/multiversx/mx-chain-logger-go" ) const leaderIndex = 0 @@ -129,7 +130,7 @@ func incrementMetricCountConsensusAcceptedBlocks( appStatusHandler core.AppStatusHandler, managedPeersHolder common.ManagedPeersHolder, ) { - pubKeys, err := nodesCoordinator.GetConsensusValidatorsPublicKeys( + _, pubKeys, err := nodesCoordinator.GetConsensusValidatorsPublicKeys( header.GetPrevRandSeed(), header.GetRound(), header.GetShardID(), @@ -184,7 +185,7 @@ func indexRoundInfo( roundsInfo := make([]*outportcore.RoundInfo, 0) roundsInfo = append(roundsInfo, roundInfo) for i := lastBlockRound + 1; i < currentBlockRound; i++ { - publicKeys, err := nodesCoordinator.GetConsensusValidatorsPublicKeys(lastHeader.GetRandSeed(), i, shardId, lastHeader.GetEpoch()) + _, publicKeys, err := nodesCoordinator.GetConsensusValidatorsPublicKeys(lastHeader.GetRandSeed(), i, shardId, lastHeader.GetEpoch()) if err != nil { continue } diff --git a/process/block/metrics_test.go b/process/block/metrics_test.go index 2457bd67ac1..eff2950f371 100644 --- a/process/block/metrics_test.go +++ b/process/block/metrics_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func TestMetrics_CalculateRoundDuration(t *testing.T) { @@ -32,8 +33,8 @@ func TestMetrics_IncrementMetricCountConsensusAcceptedBlocks(t *testing.T) { t.Parallel() nodesCoord := &shardingMocks.NodesCoordinatorMock{ - GetValidatorsPublicKeysCalled: func(_ []byte, _ uint64, _ uint32, _ uint32) ([]string, error) { - return nil, expectedErr + GetValidatorsPublicKeysCalled: func(_ []byte, _ uint64, _ uint32, _ uint32) (string, []string, error) { + return "", nil, expectedErr }, } statusHandler := &statusHandlerMock.AppStatusHandlerStub{ @@ -54,9 +55,10 @@ func TestMetrics_IncrementMetricCountConsensusAcceptedBlocks(t *testing.T) { GetOwnPublicKeyCalled: func() []byte { return []byte(mainKey) }, - GetValidatorsPublicKeysCalled: func(_ []byte, _ uint64, _ uint32, _ uint32) ([]string, error) { - return []string{ - "some leader", + GetValidatorsPublicKeysCalled: func(_ []byte, _ uint64, _ uint32, _ uint32) (string, []string, error) { + leader := "some leader" + return leader, []string{ + leader, mainKey, managedKeyInConsensus, "some other key", diff --git a/process/block/poolsCleaner/miniBlocksPoolsCleaner_test.go b/process/block/poolsCleaner/miniBlocksPoolsCleaner_test.go index b590009bdf7..ba16c9dadbb 100644 --- a/process/block/poolsCleaner/miniBlocksPoolsCleaner_test.go +++ b/process/block/poolsCleaner/miniBlocksPoolsCleaner_test.go @@ -6,9 +6,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" - 
"github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -19,7 +21,7 @@ func createMockArgMiniBlocksPoolsCleaner() ArgMiniBlocksPoolsCleaner { ShardCoordinator: &mock.CoordinatorStub{}, MaxRoundsToKeepUnprocessedData: 1, }, - MiniblocksPool: testscommon.NewCacherStub(), + MiniblocksPool: cache.NewCacherStub(), } } @@ -103,7 +105,7 @@ func TestCleanMiniblocksPoolsIfNeeded_MiniblockNotInPoolShouldBeRemovedFromMap(t t.Parallel() args := createMockArgMiniBlocksPoolsCleaner() - args.MiniblocksPool = &testscommon.CacherStub{ + args.MiniblocksPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -122,7 +124,7 @@ func TestCleanMiniblocksPoolsIfNeeded_RoundDiffTooSmallMiniblockShouldRemainInMa t.Parallel() args := createMockArgMiniBlocksPoolsCleaner() - args.MiniblocksPool = &testscommon.CacherStub{ + args.MiniblocksPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -142,7 +144,7 @@ func TestCleanMiniblocksPoolsIfNeeded_MbShouldBeRemovedFromPoolAndMap(t *testing args := createMockArgMiniBlocksPoolsCleaner() called := false - args.MiniblocksPool = &testscommon.CacherStub{ + args.MiniblocksPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, diff --git a/process/block/poolsCleaner/txsPoolsCleaner_test.go b/process/block/poolsCleaner/txsPoolsCleaner_test.go index 125f44e1870..cbcab2aae85 100644 --- a/process/block/poolsCleaner/txsPoolsCleaner_test.go +++ b/process/block/poolsCleaner/txsPoolsCleaner_test.go @@ -6,14 +6,16 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/txcache" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" - "github.com/stretchr/testify/assert" ) func createMockArgTxsPoolsCleaner() ArgTxsPoolsCleaner { @@ -174,7 +176,7 @@ func TestReceivedBlockTx_ShouldBeAddedInMapTxsRounds(t *testing.T) { TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return testscommon.NewCacherMock() + return cache.NewCacherMock() }, } }, @@ -199,7 +201,7 @@ func TestReceivedRewardTx_ShouldBeAddedInMapTxsRounds(t *testing.T) { RewardTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return testscommon.NewCacherMock() + return cache.NewCacherMock() }, } }, @@ -223,7 +225,7 @@ func TestReceivedUnsignedTx_ShouldBeAddedInMapTxsRounds(t *testing.T) { UnsignedTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return testscommon.NewCacherMock() + return cache.NewCacherMock() }, } }, @@ -252,7 +254,7 @@ func TestCleanTxsPoolsIfNeeded_CannotFindTxInPoolShouldBeRemovedFromMap(t *testi UnsignedTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return 
&testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return testscommon.NewCacherMock() + return cache.NewCacherMock() }, } }, @@ -283,7 +285,7 @@ func TestCleanTxsPoolsIfNeeded_RoundDiffTooSmallShouldNotBeRemoved(t *testing.T) UnsignedTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -323,7 +325,7 @@ func TestCleanTxsPoolsIfNeeded_RoundDiffTooBigShouldBeRemoved(t *testing.T) { UnsignedTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, diff --git a/process/block/preprocess/rewardTxPreProcessor_test.go b/process/block/preprocess/rewardTxPreProcessor_test.go index ad0d0952569..836a85d8652 100644 --- a/process/block/preprocess/rewardTxPreProcessor_test.go +++ b/process/block/preprocess/rewardTxPreProcessor_test.go @@ -9,17 +9,19 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/rewardTx" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" ) const testTxHash = "tx1_hash" @@ -904,7 +906,7 @@ func TestRewardTxPreprocessor_RestoreBlockDataIntoPools(t *testing.T) { blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() numRestoredTxs, err := rtp.RestoreBlockDataIntoPools(blockBody, miniBlockPool) assert.Equal(t, 1, numRestoredTxs) diff --git a/process/block/preprocess/smartContractResults_test.go b/process/block/preprocess/smartContractResults_test.go index 6f56571c7d7..37a03255c66 100644 --- a/process/block/preprocess/smartContractResults_test.go +++ b/process/block/preprocess/smartContractResults_test.go @@ -13,20 +13,22 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" commonTests 
"github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" ) func haveTime() time.Duration { @@ -691,7 +693,7 @@ func TestScrsPreprocessor_ReceivedTransactionShouldEraseRequested(t *testing.T) shardedDataStub := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return &smartContractResult.SmartContractResult{}, true }, @@ -1430,7 +1432,7 @@ func TestScrsPreprocessor_ProcessMiniBlock(t *testing.T) { tdp.TransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &smartContractResult.SmartContractResult{Nonce: 10}, true @@ -1589,7 +1591,7 @@ func TestScrsPreprocessor_RestoreBlockDataIntoPools(t *testing.T) { } body.MiniBlocks = append(body.MiniBlocks, &miniblock) - miniblockPool := testscommon.NewCacherMock() + miniblockPool := cache.NewCacherMock() scrRestored, err := scr.RestoreBlockDataIntoPools(body, miniblockPool) assert.Equal(t, scrRestored, 1) diff --git a/process/block/preprocess/transactions_test.go b/process/block/preprocess/transactions_test.go index 67a5b312994..ba1f0dd8601 100644 --- a/process/block/preprocess/transactions_test.go +++ b/process/block/preprocess/transactions_test.go @@ -21,6 +21,10 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/blake2b" "github.com/multiversx/mx-chain-core-go/hashing/sha256" "github.com/multiversx/mx-chain-core-go/marshal" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -29,6 +33,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/txcache" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -39,9 +44,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const MaxGasLimitPerBlock = uint64(100000) @@ -78,7 +80,7 @@ func feeHandlerMock() *economicsmocks.EconomicsHandlerStub { func shardedDataCacherNotifier() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return 
&testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &smartContractResult.SmartContractResult{Nonce: 10}, true @@ -123,7 +125,7 @@ func initDataPool() *dataRetrieverMock.PoolsHolderStub { RewardTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &rewardTx.RewardTx{Value: big.NewInt(100)}, true @@ -155,7 +157,7 @@ func initDataPool() *dataRetrieverMock.PoolsHolderStub { } }, MetaBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &transaction.Transaction{Nonce: 10}, true @@ -178,7 +180,7 @@ func initDataPool() *dataRetrieverMock.PoolsHolderStub { } }, MiniBlocksCalled: func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -512,7 +514,7 @@ func TestTransactionPreprocessor_ReceivedTransactionShouldEraseRequested(t *test shardedDataStub := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return &transaction.Transaction{}, true }, @@ -1214,7 +1216,7 @@ func TestTransactionsPreprocessor_ProcessMiniBlockShouldWork(t *testing.T) { TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx_hash1")) { return &transaction.Transaction{Nonce: 10}, true @@ -1300,7 +1302,7 @@ func TestTransactionsPreprocessor_ProcessMiniBlockShouldErrMaxGasLimitUsedForDes TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx_hash1")) { return &transaction.Transaction{}, true @@ -2012,7 +2014,7 @@ func TestTransactions_RestoreBlockDataIntoPools(t *testing.T) { args.Store = genericMocks.NewChainStorerMock(0) txs, _ := NewTransactionPreprocessor(args) - mbPool := testscommon.NewCacherMock() + mbPool := cache.NewCacherMock() body, allTxs := createMockBlockBody() storer, _ := args.Store.GetStorer(dataRetriever.TransactionUnit) diff --git a/process/block/preprocess/validatorInfoPreProcessor_test.go b/process/block/preprocess/validatorInfoPreProcessor_test.go index 059c6c3d0b1..59cf03baa6c 100644 --- a/process/block/preprocess/validatorInfoPreProcessor_test.go +++ b/process/block/preprocess/validatorInfoPreProcessor_test.go @@ -8,17 +8,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/rewardTx" + 
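The preprocessor test changes above are a mechanical move of the cacher test doubles from the root testscommon package to the testscommon/cache subpackage, plus import regrouping; behaviour is unchanged. A minimal sketch of the new wiring, using only identifiers that appear in this diff (testscommon.ShardedDataStub, cache.NewCacherMock) and an illustrative helper name:

package example

import (
	"github.com/multiversx/mx-chain-go/storage"
	"github.com/multiversx/mx-chain-go/testscommon"
	"github.com/multiversx/mx-chain-go/testscommon/cache"
)

// newShardedDataWithCacheMock is an illustrative helper, not part of the diff:
// the sharded data stub is built exactly as before, only the cacher mock now
// comes from testscommon/cache instead of testscommon.
func newShardedDataWithCacheMock() *testscommon.ShardedDataStub {
	return &testscommon.ShardedDataStub{
		ShardDataStoreCalled: func(cacheID string) storage.Cacher {
			return cache.NewCacherMock()
		},
	}
}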
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestNewValidatorInfoPreprocessor_NilHasherShouldErr(t *testing.T) { @@ -289,7 +291,7 @@ func TestNewValidatorInfoPreprocessor_RestorePeerBlockIntoPools(t *testing.T) { blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() marshalizedMb, _ := marshalizer.Marshal(mb1) mbHash := hasher.Compute(string(marshalizedMb)) @@ -334,7 +336,7 @@ func TestNewValidatorInfoPreprocessor_RestoreOtherBlockTypeIntoPoolsShouldNotRes blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() marshalizedMb, _ := marshalizer.Marshal(mb1) mbHash := hasher.Compute(string(marshalizedMb)) @@ -382,7 +384,7 @@ func TestNewValidatorInfoPreprocessor_RemovePeerBlockFromPool(t *testing.T) { blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() miniBlockPool.Put(mbHash, marshalizedMb, len(marshalizedMb)) foundMb, ok := miniBlockPool.Get(mbHash) @@ -427,7 +429,7 @@ func TestNewValidatorInfoPreprocessor_RemoveOtherBlockTypeFromPoolShouldNotRemov blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() miniBlockPool.Put(mbHash, marshalizedMb, len(marshalizedMb)) foundMb, ok := miniBlockPool.Get(mbHash) diff --git a/process/block/shardblock.go b/process/block/shardblock.go index a73b4e0be6d..d35ed73aa6b 100644 --- a/process/block/shardblock.go +++ b/process/block/shardblock.go @@ -2,6 +2,8 @@ package block import ( "bytes" + "encoding/hex" + "errors" "fmt" "math/big" "time" @@ -11,6 +13,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/headerVersionData" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -20,7 +24,6 @@ import ( "github.com/multiversx/mx-chain-go/process/block/helpers" "github.com/multiversx/mx-chain-go/process/block/processedMb" "github.com/multiversx/mx-chain-go/state" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ process.BlockProcessor = (*shardProcessor)(nil) @@ -124,6 +127,7 @@ func NewShardProcessor(arguments ArgShardProcessor) (*shardProcessor, error) { managedPeersHolder: arguments.ManagedPeersHolder, sentSignaturesTracker: arguments.SentSignaturesTracker, extraDelayRequestBlockInfo: time.Duration(arguments.Config.EpochStartConfig.ExtraDelayForRequestBlockInfoInMilliseconds) 
* time.Millisecond, + proofsPool: arguments.DataComponents.Datapool().Proofs(), } sp := shardProcessor{ @@ -171,7 +175,7 @@ func (sp *shardProcessor) ProcessBlock( err := sp.checkBlockValidity(headerHandler, bodyHandler) if err != nil { - if err == process.ErrBlockHashDoesNotMatch { + if errors.Is(err, process.ErrBlockHashDoesNotMatch) { log.Debug("requested missing shard header", "hash", headerHandler.GetPrevHash(), "for shard", headerHandler.GetShardID(), @@ -289,6 +293,15 @@ func (sp *shardProcessor) ProcessBlock( return process.ErrAccountStateDirty } + if sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + // check proofs for cross notarized metablocks + for _, metaBlockHash := range header.GetMetaBlockHashes() { + if !sp.proofsPool.HasProof(core.MetachainShardId, metaBlockHash) { + return fmt.Errorf("%w for header hash %s", process.ErrMissingHeaderProof, hex.EncodeToString(metaBlockHash)) + } + } + } + defer func() { go sp.checkAndRequestIfMetaHeadersMissing() }() @@ -524,7 +537,10 @@ func (sp *shardProcessor) checkMetaHeadersValidityAndFinality() error { } log.Trace("checkMetaHeadersValidityAndFinality", "lastCrossNotarizedHeader nonce", lastCrossNotarizedHeader.GetNonce()) - usedMetaHdrs := sp.sortHeadersForCurrentBlockByNonce(true) + usedMetaHdrs, err := sp.sortHeadersForCurrentBlockByNonce(true) + if err != nil { + return err + } if len(usedMetaHdrs[core.MetachainShardId]) == 0 { return nil } @@ -553,7 +569,24 @@ func (sp *shardProcessor) checkMetaHdrFinality(header data.HeaderHandler) error return process.ErrNilBlockHeader } - finalityAttestingMetaHdrs := sp.sortHeadersForCurrentBlockByNonce(false) + if common.IsFlagEnabledAfterEpochsStartBlock(header, sp.enableEpochsHandler, common.EquivalentMessagesFlag) { + marshalledHeader, err := sp.marshalizer.Marshal(header) + if err != nil { + return err + } + + headerHash := sp.hasher.Compute(string(marshalledHeader)) + if !sp.proofsPool.HasProof(header.GetShardID(), headerHash) { + return fmt.Errorf("%w, missing proof for header %s", process.ErrHeaderNotFinal, hex.EncodeToString(headerHash)) + } + + return nil + } + + finalityAttestingMetaHdrs, err := sp.sortHeadersForCurrentBlockByNonce(false) + if err != nil { + return err + } lastVerifiedHdr := header // verify if there are "K" block after current to make this one final @@ -1708,12 +1741,11 @@ func (sp *shardProcessor) receivedMetaBlock(headerHandler data.HeaderHandler, me } } + hasProofForMetablock := false // attesting something if sp.hdrsForCurrBlock.missingHdrs == 0 { - sp.hdrsForCurrBlock.missingFinalityAttestingHdrs = sp.requestMissingFinalityAttestingHeaders( - core.MetachainShardId, - sp.metaBlockFinality, - ) + hasProofForMetablock = sp.hasProofForMetablock(metaBlockHash, metaBlock) + if sp.hdrsForCurrBlock.missingFinalityAttestingHdrs == 0 { log.Debug("received all missing finality attesting meta headers") } @@ -1723,7 +1755,7 @@ func (sp *shardProcessor) receivedMetaBlock(headerHandler data.HeaderHandler, me missingFinalityAttestingMetaHdrs := sp.hdrsForCurrBlock.missingFinalityAttestingHdrs sp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() - allMissingMetaHeadersReceived := missingMetaHdrs == 0 && missingFinalityAttestingMetaHdrs == 0 + allMissingMetaHeadersReceived := missingMetaHdrs == 0 && missingFinalityAttestingMetaHdrs == 0 && hasProofForMetablock if allMissingMetaHeadersReceived { sp.chRcvAllMetaHdrs <- true } @@ -1734,6 +1766,20 @@ func (sp *shardProcessor) receivedMetaBlock(headerHandler data.HeaderHandler, me go 
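The shardblock.go hunks above gate block processing on equivalent proofs: once EquivalentMessagesFlag is active for the header's epoch, every cross-notarized metablock hash must already have a proof in the proofs pool, otherwise processing stops with a wrapped ErrMissingHeaderProof (note the related switch from == to errors.Is for ErrBlockHashDoesNotMatch). The gate can be read in isolation as the sketch below; the proofsPool interface is reduced to the single method used here and the sentinel error is local, so the names are illustrative rather than the node's API.

package example

import (
	"encoding/hex"
	"errors"
	"fmt"
)

// errMissingHeaderProof mirrors process.ErrMissingHeaderProof from this diff.
var errMissingHeaderProof = errors.New("missing header proof")

// proofsPool is a reduced stand-in for the node's proofs pool.
type proofsPool interface {
	HasProof(shardID uint32, headerHash []byte) bool
}

// checkCrossNotarizedProofs returns an error as soon as one referenced
// metablock has no equivalent proof; callers can branch on it with
// errors.Is(err, errMissingHeaderProof), matching the %w wrapping used above.
func checkCrossNotarizedProofs(flagActiveInEpoch bool, metaShardID uint32, metaBlockHashes [][]byte, pool proofsPool) error {
	if !flagActiveInEpoch {
		return nil
	}
	for _, hash := range metaBlockHashes {
		if !pool.HasProof(metaShardID, hash) {
			return fmt.Errorf("%w for header hash %s", errMissingHeaderProof, hex.EncodeToString(hash))
		}
	}
	return nil
}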
sp.requestMiniBlocksIfNeeded(headerHandler) } +func (sp *shardProcessor) hasProofForMetablock(metaBlockHash []byte, metaBlock *block.MetaBlock) bool { + shouldConsiderProofsForNotarization := sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, metaBlock.Epoch) + if !shouldConsiderProofsForNotarization { + sp.hdrsForCurrBlock.missingFinalityAttestingHdrs = sp.requestMissingFinalityAttestingHeaders( + core.MetachainShardId, + sp.metaBlockFinality, + ) + + return true // no proof needed + } + + return sp.proofsPool.HasProof(core.MetachainShardId, metaBlockHash) +} + func (sp *shardProcessor) requestMetaHeaders(shardHeader data.ShardHeaderHandler) (uint32, uint32) { _ = core.EmptyChannel(sp.chRcvAllMetaHdrs) @@ -1748,6 +1794,7 @@ func (sp *shardProcessor) computeExistingAndRequestMissingMetaHeaders(header dat sp.hdrsForCurrBlock.mutHdrsForBlock.Lock() defer sp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() + notarizedMetaHdrsBasedOnProofs := 0 metaBlockHashes := header.GetMetaBlockHashes() for i := 0; i < len(metaBlockHashes); i++ { hdr, err := process.GetMetaHeaderFromPool( @@ -1760,6 +1807,7 @@ func (sp *shardProcessor) computeExistingAndRequestMissingMetaHeaders(header dat hdr: nil, usedInBlock: true, } + go sp.requestHandler.RequestMetaHeader(metaBlockHashes[i]) continue } @@ -1772,9 +1820,14 @@ func (sp *shardProcessor) computeExistingAndRequestMissingMetaHeaders(header dat if hdr.Nonce > sp.hdrsForCurrBlock.highestHdrNonce[core.MetachainShardId] { sp.hdrsForCurrBlock.highestHdrNonce[core.MetachainShardId] = hdr.Nonce } + shouldConsiderProofsForNotarization := sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, hdr.Epoch) + if shouldConsiderProofsForNotarization { + notarizedMetaHdrsBasedOnProofs++ + } } - if sp.hdrsForCurrBlock.missingHdrs == 0 { + shouldRequestMissingFinalityAttestingMetaHeaders := notarizedMetaHdrsBasedOnProofs != len(metaBlockHashes) + if sp.hdrsForCurrBlock.missingHdrs == 0 && shouldRequestMissingFinalityAttestingMetaHeaders { sp.hdrsForCurrBlock.missingFinalityAttestingHdrs = sp.requestMissingFinalityAttestingHeaders( core.MetachainShardId, sp.metaBlockFinality, @@ -1902,9 +1955,21 @@ func (sp *shardProcessor) createAndProcessMiniBlocksDstMe(haveTime func() bool) break } + hasProofForHdr := sp.proofsPool.HasProof(core.MetachainShardId, orderedMetaBlocksHashes[i]) + shouldConsiderProofsForNotarization := sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, orderedMetaBlocks[i].GetEpoch()) + if !hasProofForHdr && shouldConsiderProofsForNotarization { + log.Trace("no proof for meta header", + "hash", logger.DisplayByteSlice(orderedMetaBlocksHashes[i]), + ) + break + } + createAndProcessInfo.currMetaHdrHash = orderedMetaBlocksHashes[i] if len(createAndProcessInfo.currMetaHdr.GetMiniBlockHeadersWithDst(sp.shardCoordinator.SelfId())) == 0 { - sp.hdrsForCurrBlock.hdrHashAndInfo[string(createAndProcessInfo.currMetaHdrHash)] = &hdrInfo{hdr: createAndProcessInfo.currMetaHdr, usedInBlock: true} + sp.hdrsForCurrBlock.hdrHashAndInfo[string(createAndProcessInfo.currMetaHdrHash)] = &hdrInfo{ + hdr: createAndProcessInfo.currMetaHdr, + usedInBlock: true, + } createAndProcessInfo.numHdrsAdded++ lastMetaHdr = createAndProcessInfo.currMetaHdr continue @@ -1968,7 +2033,10 @@ func (sp *shardProcessor) createMbsAndProcessCrossShardTransactionsDstMe( createAndProcessInfo.numTxsAdded += currNumTxsAdded if !createAndProcessInfo.hdrAdded && currNumTxsAdded > 0 { - 
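hasProofForMetablock and computeExistingAndRequestMissingMetaHeaders above only fall back to the legacy request for finality-attesting meta headers when at least one referenced metablock is not covered by an equivalent proof (and hasProofForMetablock answers true immediately when the flag is inactive, since no proof is needed then). The resulting condition can be restated as the small sketch below, with illustrative names:

package example

// shouldRequestFinalityAttestingHeaders mirrors the condition built in
// computeExistingAndRequestMissingMetaHeaders: request the extra attesting
// headers only when all referenced headers were found (missingHdrs == 0) and
// not every one of them is notarized through an equivalent proof.
func shouldRequestFinalityAttestingHeaders(missingHdrs uint32, notarizedBasedOnProofs int, totalMetaBlockHashes int) bool {
	allCoveredByProofs := notarizedBasedOnProofs == totalMetaBlockHashes
	return missingHdrs == 0 && !allCoveredByProofs
}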
sp.hdrsForCurrBlock.hdrHashAndInfo[string(createAndProcessInfo.currMetaHdrHash)] = &hdrInfo{hdr: createAndProcessInfo.currMetaHdr, usedInBlock: true} + sp.hdrsForCurrBlock.hdrHashAndInfo[string(createAndProcessInfo.currMetaHdrHash)] = &hdrInfo{ + hdr: createAndProcessInfo.currMetaHdr, + usedInBlock: true, + } createAndProcessInfo.numHdrsAdded++ createAndProcessInfo.hdrAdded = true } @@ -2182,8 +2250,12 @@ func (sp *shardProcessor) applyBodyToHeader( } sw.Start("sortHeaderHashesForCurrentBlockByNonce") - metaBlockHashes := sp.sortHeaderHashesForCurrentBlockByNonce(true) + metaBlockHashes, err := sp.sortHeaderHashesForCurrentBlockByNonce(true) sw.Stop("sortHeaderHashesForCurrentBlockByNonce") + if err != nil { + return nil, err + } + err = shardHeader.SetMetaBlockHashes(metaBlockHashes[core.MetachainShardId]) if err != nil { return nil, err diff --git a/process/common_test.go b/process/common_test.go index a79e2fd5c32..b6e308ec3ab 100644 --- a/process/common_test.go +++ b/process/common_test.go @@ -12,14 +12,16 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-core-go/data/typeConverters" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestGetShardHeaderShouldErrNilCacher(t *testing.T) { @@ -1800,7 +1802,7 @@ func TestGetTransactionHandlerShouldGetTransactionFromPool(t *testing.T) { storageService := &storageStubs.ChainStorerStub{} shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return txFromPool, true }, @@ -1843,7 +1845,7 @@ func TestGetTransactionHandlerShouldGetTransactionFromStorage(t *testing.T) { } shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -1871,7 +1873,7 @@ func TestGetTransactionHandlerFromPool_Errors(t *testing.T) { shardedDataCacherNotifier := testscommon.NewShardedDataStub() shardedDataCacherNotifier.ShardDataStoreCalled = func(cacheID string) storage.Cacher { - return testscommon.NewCacherMock() + return cache.NewCacherMock() } t.Run("nil sharded cache", func(t *testing.T) { @@ -1922,7 +1924,7 @@ func TestGetTransactionHandlerFromPoolShouldErrTxNotFound(t *testing.T) { shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -1948,7 +1950,7 @@ func TestGetTransactionHandlerFromPoolShouldErrInvalidTxInPool(t *testing.T) { shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return 
&cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1975,7 +1977,7 @@ func TestGetTransactionHandlerFromPoolShouldWorkWithPeek(t *testing.T) { shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return txFromPool, true }, @@ -2026,7 +2028,7 @@ func TestGetTransactionHandlerFromPoolShouldWorkWithPeekFallbackToSearchFirst(t peekCalled := false shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { peekCalled = true return nil, false diff --git a/process/coordinator/process_test.go b/process/coordinator/process_test.go index d1dff667cb7..80e26980e81 100644 --- a/process/coordinator/process_test.go +++ b/process/coordinator/process_test.go @@ -36,6 +36,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" commonMock "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -81,7 +82,7 @@ func createShardedDataChacherNotifier( return &testscommon.ShardedDataStub{ RegisterOnAddedCalled: func(i func(key []byte, value interface{})) {}, ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, testHash) { return handler, true @@ -126,7 +127,7 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { UnsignedTransactionsCalled: unsignedTxHandler, RewardTransactionsCalled: rewardTxCalled, MetaBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &transaction.Transaction{Nonce: 10}, true @@ -149,7 +150,7 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { } }, MiniBlocksCalled: func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -1162,7 +1163,7 @@ func TestTransactionCoordinator_CreateMbsAndProcessTransactionsFromMeNothingToPr shardedCacheMock := &testscommon.ShardedDataStub{ RegisterOnAddedCalled: func(i func(key []byte, value interface{})) {}, ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -2317,7 +2318,7 @@ func TestTransactionCoordinator_VerifyCreatedBlockTransactionsOk(t *testing.T) { return &testscommon.ShardedDataStub{ RegisterOnAddedCalled: func(i func(key []byte, value interface{})) {}, ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if 
reflect.DeepEqual(key, scrHash) { return scr, true @@ -4455,7 +4456,7 @@ func TestTransactionCoordinator_requestMissingMiniBlocksAndTransactionsShouldWor t.Parallel() args := createMockTransactionCoordinatorArguments() - args.MiniBlockPool = &testscommon.CacherStub{ + args.MiniBlockPool = &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, []byte("hash0")) || bytes.Equal(key, []byte("hash1")) || bytes.Equal(key, []byte("hash2")) { if bytes.Equal(key, []byte("hash0")) { diff --git a/process/errors.go b/process/errors.go index c2ebf0299a8..395ebf17620 100644 --- a/process/errors.go +++ b/process/errors.go @@ -239,6 +239,9 @@ var ErrNilMiniBlockPool = errors.New("nil mini block pool") // ErrNilMetaBlocksPool signals that a nil meta blocks pool was used var ErrNilMetaBlocksPool = errors.New("nil meta blocks pool") +// ErrNilProofsPool signals that a nil proofs pool was used +var ErrNilProofsPool = errors.New("nil proofs pool") + // ErrNilTxProcessor signals that a nil transactions processor was used var ErrNilTxProcessor = errors.New("nil transactions processor") @@ -696,6 +699,9 @@ var ErrNilWhiteListHandler = errors.New("nil whitelist handler") // ErrNilPreferredPeersHolder signals that preferred peers holder is nil var ErrNilPreferredPeersHolder = errors.New("nil preferred peers holder") +// ErrNilInterceptedDataVerifier signals that intercepted data verifier is nil +var ErrNilInterceptedDataVerifier = errors.New("nil intercepted data verifier") + // ErrMiniBlocksInWrongOrder signals the miniblocks are in wrong order var ErrMiniBlocksInWrongOrder = errors.New("miniblocks in wrong order, should have been only from me") @@ -1095,6 +1101,9 @@ var ErrInvalidExpiryTimespan = errors.New("invalid expiry timespan") // ErrNilPeerSignatureHandler signals that a nil peer signature handler was provided var ErrNilPeerSignatureHandler = errors.New("nil peer signature handler") +// ErrNilInterceptedDataVerifierFactory signals that a nil intercepted data verifier factory was provided +var ErrNilInterceptedDataVerifierFactory = errors.New("nil intercepted data verifier factory") + // ErrNilPeerAuthenticationCacher signals that a nil peer authentication cacher was provided var ErrNilPeerAuthenticationCacher = errors.New("nil peer authentication cacher") @@ -1241,3 +1250,27 @@ var ErrEmptyChainParametersConfiguration = errors.New("empty chain parameters co // ErrNoMatchingConfigForProvidedEpoch signals that there is no matching configuration for the provided epoch var ErrNoMatchingConfigForProvidedEpoch = errors.New("no matching configuration") + +// ErrInvalidHeader is raised when header is invalid +var ErrInvalidHeader = errors.New("header is invalid") + +// ErrNilHeaderProof signals that a nil header proof has been provided +var ErrNilHeaderProof = errors.New("nil header proof") + +// ErrNilInterceptedDataCache signals that a nil cacher was provided for intercepted data verifier +var ErrNilInterceptedDataCache = errors.New("nil cache for intercepted data") + +// ErrFlagNotActive signals that a flag is not active +var ErrFlagNotActive = errors.New("flag not active") + +// ErrInvalidInterceptedData signals that an invalid data has been intercepted +var ErrInvalidInterceptedData = errors.New("invalid intercepted data") + +// ErrMissingHeaderProof signals that the proof for the header is missing +var ErrMissingHeaderProof = errors.New("missing header proof") + +// ErrInvalidHeaderProof signals that an invalid equivalent proof has been provided +var 
ErrInvalidHeaderProof = errors.New("invalid equivalent proof") + +// ErrUnexpectedHeaderProof signals that a header proof has been provided unexpectedly +var ErrUnexpectedHeaderProof = errors.New("unexpected header proof") diff --git a/process/factory/interceptorscontainer/args.go b/process/factory/interceptorscontainer/args.go index 294e66290b3..8e98c7c18ab 100644 --- a/process/factory/interceptorscontainer/args.go +++ b/process/factory/interceptorscontainer/args.go @@ -2,6 +2,7 @@ package interceptorscontainer import ( crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/heartbeat" @@ -13,34 +14,35 @@ import ( // CommonInterceptorsContainerFactoryArgs holds the arguments needed for the metachain/shard interceptors factories type CommonInterceptorsContainerFactoryArgs struct { - CoreComponents process.CoreComponentsHolder - CryptoComponents process.CryptoComponentsHolder - Accounts state.AccountsAdapter - ShardCoordinator sharding.Coordinator - NodesCoordinator nodesCoordinator.NodesCoordinator - MainMessenger process.TopicHandler - FullArchiveMessenger process.TopicHandler - Store dataRetriever.StorageService - DataPool dataRetriever.PoolsHolder - MaxTxNonceDeltaAllowed int - TxFeeHandler process.FeeHandler - BlockBlackList process.TimeCacher - HeaderSigVerifier process.InterceptedHeaderSigVerifier - HeaderIntegrityVerifier process.HeaderIntegrityVerifier - ValidityAttester process.ValidityAttester - EpochStartTrigger process.EpochStartTriggerHandler - WhiteListHandler process.WhiteListHandler - WhiteListerVerifiedTxs process.WhiteListHandler - AntifloodHandler process.P2PAntifloodHandler - ArgumentsParser process.ArgumentsParser - PreferredPeersHolder process.PreferredPeersHolderHandler - SizeCheckDelta uint32 - RequestHandler process.RequestHandler - PeerSignatureHandler crypto.PeerSignatureHandler - SignaturesHandler process.SignaturesHandler - HeartbeatExpiryTimespanInSec int64 - MainPeerShardMapper process.PeerShardMapper - FullArchivePeerShardMapper process.PeerShardMapper - HardforkTrigger heartbeat.HardforkTrigger - NodeOperationMode common.NodeOperation + CoreComponents process.CoreComponentsHolder + CryptoComponents process.CryptoComponentsHolder + Accounts state.AccountsAdapter + ShardCoordinator sharding.Coordinator + NodesCoordinator nodesCoordinator.NodesCoordinator + MainMessenger process.TopicHandler + FullArchiveMessenger process.TopicHandler + Store dataRetriever.StorageService + DataPool dataRetriever.PoolsHolder + MaxTxNonceDeltaAllowed int + TxFeeHandler process.FeeHandler + BlockBlackList process.TimeCacher + HeaderSigVerifier process.InterceptedHeaderSigVerifier + HeaderIntegrityVerifier process.HeaderIntegrityVerifier + ValidityAttester process.ValidityAttester + EpochStartTrigger process.EpochStartTriggerHandler + WhiteListHandler process.WhiteListHandler + WhiteListerVerifiedTxs process.WhiteListHandler + AntifloodHandler process.P2PAntifloodHandler + ArgumentsParser process.ArgumentsParser + PreferredPeersHolder process.PreferredPeersHolderHandler + SizeCheckDelta uint32 + RequestHandler process.RequestHandler + PeerSignatureHandler crypto.PeerSignatureHandler + SignaturesHandler process.SignaturesHandler + HeartbeatExpiryTimespanInSec int64 + MainPeerShardMapper process.PeerShardMapper + FullArchivePeerShardMapper process.PeerShardMapper + HardforkTrigger heartbeat.HardforkTrigger + NodeOperationMode common.NodeOperation + 
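The new sentinel errors in process/errors.go above (ErrMissingHeaderProof, ErrInvalidHeaderProof, ErrUnexpectedHeaderProof, and the nil-dependency errors) follow the wrap-with-%w pattern already adopted in ProcessBlock for ErrBlockHashDoesNotMatch. A minimal sketch of that usage, with a local sentinel standing in for the real one:

package example

import (
	"errors"
	"fmt"
)

// errInvalidHeaderProof is a local stand-in for process.ErrInvalidHeaderProof.
var errInvalidHeaderProof = errors.New("invalid equivalent proof")

// verifyProof wraps the sentinel with context so callers can still branch on
// the error class instead of comparing error strings.
func verifyProof(ok bool, headerHashHex string) error {
	if ok {
		return nil
	}
	return fmt.Errorf("%w, header hash %s", errInvalidHeaderProof, headerHashHex)
}

// isInvalidProof shows the matching side: errors.Is sees through the wrapping.
func isInvalidProof(err error) bool {
	return errors.Is(err, errInvalidHeaderProof)
}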
InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } diff --git a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go index cfed22b39c9..bc167e0dab5 100644 --- a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/heartbeat" @@ -31,29 +32,31 @@ const ( ) type baseInterceptorsContainerFactory struct { - mainContainer process.InterceptorsContainer - fullArchiveContainer process.InterceptorsContainer - shardCoordinator sharding.Coordinator - accounts state.AccountsAdapter - store dataRetriever.StorageService - dataPool dataRetriever.PoolsHolder - mainMessenger process.TopicHandler - fullArchiveMessenger process.TopicHandler - nodesCoordinator nodesCoordinator.NodesCoordinator - blockBlackList process.TimeCacher - argInterceptorFactory *interceptorFactory.ArgInterceptedDataFactory - globalThrottler process.InterceptorThrottler - maxTxNonceDeltaAllowed int - antifloodHandler process.P2PAntifloodHandler - whiteListHandler process.WhiteListHandler - whiteListerVerifiedTxs process.WhiteListHandler - preferredPeersHolder process.PreferredPeersHolderHandler - hasher hashing.Hasher - requestHandler process.RequestHandler - mainPeerShardMapper process.PeerShardMapper - fullArchivePeerShardMapper process.PeerShardMapper - hardforkTrigger heartbeat.HardforkTrigger - nodeOperationMode common.NodeOperation + mainContainer process.InterceptorsContainer + fullArchiveContainer process.InterceptorsContainer + shardCoordinator sharding.Coordinator + accounts state.AccountsAdapter + store dataRetriever.StorageService + dataPool dataRetriever.PoolsHolder + mainMessenger process.TopicHandler + fullArchiveMessenger process.TopicHandler + nodesCoordinator nodesCoordinator.NodesCoordinator + blockBlackList process.TimeCacher + argInterceptorFactory *interceptorFactory.ArgInterceptedDataFactory + globalThrottler process.InterceptorThrottler + maxTxNonceDeltaAllowed int + antifloodHandler process.P2PAntifloodHandler + whiteListHandler process.WhiteListHandler + whiteListerVerifiedTxs process.WhiteListHandler + preferredPeersHolder process.PreferredPeersHolderHandler + hasher hashing.Hasher + requestHandler process.RequestHandler + mainPeerShardMapper process.PeerShardMapper + fullArchivePeerShardMapper process.PeerShardMapper + hardforkTrigger heartbeat.HardforkTrigger + nodeOperationMode common.NodeOperation + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory + enableEpochsHandler common.EnableEpochsHandler } func checkBaseParams( @@ -285,18 +288,24 @@ func (bicf *baseInterceptorsContainerFactory) createOneTxInterceptor(topic strin return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + internalMarshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: txFactory, - Processor: txProcessor, - Throttler: bicf.globalThrottler, 
- AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + DataFactory: txFactory, + Processor: txProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -328,18 +337,24 @@ func (bicf *baseInterceptorsContainerFactory) createOneUnsignedTxInterceptor(top return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + internalMarshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: txFactory, - Processor: txProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + DataFactory: txFactory, + Processor: txProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -371,18 +386,24 @@ func (bicf *baseInterceptorsContainerFactory) createOneRewardTxInterceptor(topic return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + internalMarshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: txFactory, - Processor: txProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + DataFactory: txFactory, + Processor: txProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -403,8 +424,10 @@ func (bicf *baseInterceptorsContainerFactory) generateHeaderInterceptors() error } argProcessor := &processor.ArgHdrInterceptorProcessor{ - Headers: bicf.dataPool.Headers(), - BlockBlackList: bicf.blockBlackList, + Headers: bicf.dataPool.Headers(), + BlockBlackList: bicf.blockBlackList, + Proofs: bicf.dataPool.Proofs(), + EnableEpochsHandler: bicf.enableEpochsHandler, } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { @@ -414,17 +437,23 @@ func (bicf *baseInterceptorsContainerFactory) generateHeaderInterceptors() error // compose header shard topic, for example: shardBlocks_0_META identifierHdr := factory.ShardBlocksTopic + 
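Every createOne*Interceptor above now asks interceptedDataVerifierFactory.Create(topic) for a per-topic verifier and passes it as the new InterceptedDataVerifier field of the interceptor arguments, while the header processors additionally receive the Proofs pool and the EnableEpochsHandler. The call pattern is sketched below with reduced, locally defined interfaces; only the Create(topic) shape is taken from this diff, the rest is assumption.

package example

import "errors"

var errNilInterceptedDataVerifierFactory = errors.New("nil intercepted data verifier factory")

// interceptedData and interceptedDataVerifier are reduced stand-ins for the
// process-package interfaces referenced in this diff.
type interceptedData interface {
	Hash() []byte
}

type interceptedDataVerifier interface {
	Verify(data interceptedData) error
	IsInterfaceNil() bool
}

type interceptedDataVerifierFactory interface {
	Create(topic string) (interceptedDataVerifier, error)
	IsInterfaceNil() bool
}

// createVerifierForTopic shows the pattern each interceptor constructor
// follows: fail fast when the factory is missing, otherwise build exactly one
// verifier per topic.
func createVerifierForTopic(factory interceptedDataVerifierFactory, topic string) (interceptedDataVerifier, error) {
	if factory == nil {
		return nil, errNilInterceptedDataVerifierFactory
	}
	return factory.Create(topic)
}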
shardC.CommunicationIdentifier(core.MetachainShardId) + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifierHdr) + if err != nil { + return err + } + // only one intrashard header topic interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifierHdr, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: identifierHdr, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -502,17 +531,23 @@ func (bicf *baseInterceptorsContainerFactory) createOneMiniBlocksInterceptor(top return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: miniblockFactory, - Processor: miniblockProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + DataFactory: miniblockFactory, + Processor: miniblockProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -533,25 +568,33 @@ func (bicf *baseInterceptorsContainerFactory) generateMetachainHeaderInterceptor } argProcessor := &processor.ArgHdrInterceptorProcessor{ - Headers: bicf.dataPool.Headers(), - BlockBlackList: bicf.blockBlackList, + Headers: bicf.dataPool.Headers(), + BlockBlackList: bicf.blockBlackList, + Proofs: bicf.dataPool.Proofs(), + EnableEpochsHandler: bicf.enableEpochsHandler, } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { return err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifierHdr) + if err != nil { + return err + } + // only one metachain header topic interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifierHdr, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: identifierHdr, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -577,18 +620,24 @@ func (bicf *baseInterceptorsContainerFactory) 
createOneTrieNodesInterceptor(topi return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + internalMarshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: trieNodesFactory, - Processor: trieNodesProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + DataFactory: trieNodesFactory, + Processor: trieNodesProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -669,17 +718,23 @@ func (bicf *baseInterceptorsContainerFactory) generatePeerAuthenticationIntercep return err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifierPeerAuthentication) + if err != nil { + return err + } + mdInterceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: identifierPeerAuthentication, - Marshalizer: internalMarshaller, - DataFactory: peerAuthenticationFactory, - Processor: peerAuthenticationProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - PreferredPeersHolder: bicf.preferredPeersHolder, - CurrentPeerId: bicf.mainMessenger.ID(), + Topic: identifierPeerAuthentication, + Marshalizer: internalMarshaller, + DataFactory: peerAuthenticationFactory, + Processor: peerAuthenticationProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + PreferredPeersHolder: bicf.preferredPeersHolder, + CurrentPeerId: bicf.mainMessenger.ID(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -728,16 +783,22 @@ func (bicf *baseInterceptorsContainerFactory) createHeartbeatV2Interceptor( return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifier) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifier, - DataFactory: heartbeatFactory, - Processor: heartbeatProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - PreferredPeersHolder: bicf.preferredPeersHolder, - CurrentPeerId: bicf.mainMessenger.ID(), + Topic: identifier, + DataFactory: heartbeatFactory, + Processor: heartbeatProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + PreferredPeersHolder: bicf.preferredPeersHolder, + CurrentPeerId: bicf.mainMessenger.ID(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -777,16 +838,22 @@ func (bicf *baseInterceptorsContainerFactory) createPeerShardInterceptor( return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifier) + if err != nil { + return nil, err + } + interceptor, err 
:= interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifier, - DataFactory: interceptedPeerShardFactory, - Processor: psiProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: identifier, + DataFactory: interceptedPeerShardFactory, + Processor: psiProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -814,17 +881,23 @@ func (bicf *baseInterceptorsContainerFactory) generateValidatorInfoInterceptor() return err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifier) + if err != nil { + return err + } + mdInterceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: identifier, - Marshalizer: internalMarshaller, - DataFactory: interceptedValidatorInfoFactory, - Processor: validatorInfoProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - PreferredPeersHolder: bicf.preferredPeersHolder, - CurrentPeerId: bicf.mainMessenger.ID(), + Topic: identifier, + Marshalizer: internalMarshaller, + DataFactory: interceptedValidatorInfoFactory, + Processor: validatorInfoProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + PreferredPeersHolder: bicf.preferredPeersHolder, + CurrentPeerId: bicf.mainMessenger.ID(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -839,6 +912,44 @@ func (bicf *baseInterceptorsContainerFactory) generateValidatorInfoInterceptor() return bicf.addInterceptorsToContainers([]string{identifier}, []process.Interceptor{interceptor}) } +func (bicf *baseInterceptorsContainerFactory) createOneShardEquivalentProofsInterceptor(topic string) (process.Interceptor, error) { + equivalentProofsFactory := interceptorFactory.NewInterceptedEquivalentProofsFactory(*bicf.argInterceptorFactory, bicf.dataPool.Proofs()) + + marshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() + argProcessor := processor.ArgEquivalentProofsInterceptorProcessor{ + EquivalentProofsPool: bicf.dataPool.Proofs(), + Marshaller: marshaller, + } + equivalentProofsProcessor, err := processor.NewEquivalentProofsInterceptorProcessor(argProcessor) + if err != nil { + return nil, err + } + + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + + interceptor, err := interceptors.NewSingleDataInterceptor( + interceptors.ArgSingleDataInterceptor{ + Topic: topic, + DataFactory: equivalentProofsFactory, + Processor: equivalentProofsProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, + }, + ) + if err != nil { + return nil, err + } + + return bicf.createTopicAndAssignHandler(topic, interceptor, true) +} + func (bicf *baseInterceptorsContainerFactory) addInterceptorsToContainers(keys []string, interceptors 
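Conceptually, createOneShardEquivalentProofsInterceptor above builds the write side of the same proofs pool that the shard processor reads: once a proof for a metablock hash has been intercepted and stored, HasProof for that shard and hash starts returning true and block processing can proceed. The sketch below is an illustrative in-memory pool under that assumption; it is not the node's dataRetriever proofs pool, and addProof is a hypothetical name for whatever the interceptor processor calls.

package example

import "sync"

// inMemoryProofsPool is an illustrative stand-in keyed by shard ID and header hash.
type inMemoryProofsPool struct {
	mut    sync.RWMutex
	proofs map[uint32]map[string]struct{}
}

func newInMemoryProofsPool() *inMemoryProofsPool {
	return &inMemoryProofsPool{proofs: make(map[uint32]map[string]struct{})}
}

// addProof models the write side an equivalent-proofs processor would drive.
func (p *inMemoryProofsPool) addProof(shardID uint32, headerHash []byte) {
	p.mut.Lock()
	defer p.mut.Unlock()
	if _, ok := p.proofs[shardID]; !ok {
		p.proofs[shardID] = make(map[string]struct{})
	}
	p.proofs[shardID][string(headerHash)] = struct{}{}
}

// HasProof is the read side used by the shard processor hunks earlier in this diff.
func (p *inMemoryProofsPool) HasProof(shardID uint32, headerHash []byte) bool {
	p.mut.RLock()
	defer p.mut.RUnlock()
	_, ok := p.proofs[shardID][string(headerHash)]
	return ok
}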
[]process.Interceptor) error { err := bicf.mainContainer.AddMultiple(keys, interceptors) if err != nil { diff --git a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory.go index 38d3e460bce..e3c304b3f83 100644 --- a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory.go @@ -5,6 +5,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/process/factory/containers" @@ -79,6 +81,9 @@ func NewMetaInterceptorsContainerFactory( if check.IfNil(args.PeerSignatureHandler) { return nil, process.ErrNilPeerSignatureHandler } + if check.IfNil(args.InterceptedDataVerifierFactory) { + return nil, process.ErrNilInterceptedDataVerifierFactory + } if args.HeartbeatExpiryTimespanInSec < minTimespanDurationInSec { return nil, process.ErrInvalidExpiryTimespan } @@ -102,28 +107,30 @@ func NewMetaInterceptorsContainerFactory( } base := &baseInterceptorsContainerFactory{ - mainContainer: containers.NewInterceptorsContainer(), - fullArchiveContainer: containers.NewInterceptorsContainer(), - shardCoordinator: args.ShardCoordinator, - mainMessenger: args.MainMessenger, - fullArchiveMessenger: args.FullArchiveMessenger, - store: args.Store, - dataPool: args.DataPool, - nodesCoordinator: args.NodesCoordinator, - blockBlackList: args.BlockBlackList, - argInterceptorFactory: argInterceptorFactory, - maxTxNonceDeltaAllowed: args.MaxTxNonceDeltaAllowed, - accounts: args.Accounts, - antifloodHandler: args.AntifloodHandler, - whiteListHandler: args.WhiteListHandler, - whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, - preferredPeersHolder: args.PreferredPeersHolder, - hasher: args.CoreComponents.Hasher(), - requestHandler: args.RequestHandler, - mainPeerShardMapper: args.MainPeerShardMapper, - fullArchivePeerShardMapper: args.FullArchivePeerShardMapper, - hardforkTrigger: args.HardforkTrigger, - nodeOperationMode: args.NodeOperationMode, + mainContainer: containers.NewInterceptorsContainer(), + fullArchiveContainer: containers.NewInterceptorsContainer(), + shardCoordinator: args.ShardCoordinator, + mainMessenger: args.MainMessenger, + fullArchiveMessenger: args.FullArchiveMessenger, + store: args.Store, + dataPool: args.DataPool, + nodesCoordinator: args.NodesCoordinator, + blockBlackList: args.BlockBlackList, + argInterceptorFactory: argInterceptorFactory, + maxTxNonceDeltaAllowed: args.MaxTxNonceDeltaAllowed, + accounts: args.Accounts, + antifloodHandler: args.AntifloodHandler, + whiteListHandler: args.WhiteListHandler, + whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, + preferredPeersHolder: args.PreferredPeersHolder, + hasher: args.CoreComponents.Hasher(), + requestHandler: args.RequestHandler, + mainPeerShardMapper: args.MainPeerShardMapper, + fullArchivePeerShardMapper: args.FullArchivePeerShardMapper, + hardforkTrigger: args.HardforkTrigger, + nodeOperationMode: args.NodeOperationMode, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, + enableEpochsHandler: args.CoreComponents.EnableEpochsHandler(), } icf := &metaInterceptorsContainerFactory{ @@ -195,6 +202,11 @@ func (micf *metaInterceptorsContainerFactory) 
Create() (process.InterceptorsCont return nil, nil, err } + err = micf.generateEquivalentProofsInterceptors() + if err != nil { + return nil, nil, err + } + return micf.mainContainer, micf.fullArchiveContainer, nil } @@ -253,24 +265,32 @@ func (micf *metaInterceptorsContainerFactory) createOneShardHeaderInterceptor(to } argProcessor := &processor.ArgHdrInterceptorProcessor{ - Headers: micf.dataPool.Headers(), - BlockBlackList: micf.blockBlackList, + Headers: micf.dataPool.Headers(), + BlockBlackList: micf.blockBlackList, + Proofs: micf.dataPool.Proofs(), + EnableEpochsHandler: micf.enableEpochsHandler, } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { return nil, err } + interceptedDataVerifier, err := micf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := processInterceptors.NewSingleDataInterceptor( processInterceptors.ArgSingleDataInterceptor{ - Topic: topic, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: micf.globalThrottler, - AntifloodHandler: micf.antifloodHandler, - WhiteListRequest: micf.whiteListHandler, - CurrentPeerId: micf.mainMessenger.ID(), - PreferredPeersHolder: micf.preferredPeersHolder, + Topic: topic, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: micf.globalThrottler, + AntifloodHandler: micf.antifloodHandler, + WhiteListRequest: micf.whiteListHandler, + CurrentPeerId: micf.mainMessenger.ID(), + PreferredPeersHolder: micf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -329,6 +349,39 @@ func (micf *metaInterceptorsContainerFactory) generateRewardTxInterceptors() err return micf.addInterceptorsToContainers(keys, interceptorSlice) } +func (micf *metaInterceptorsContainerFactory) generateEquivalentProofsInterceptors() error { + shardC := micf.shardCoordinator + noOfShards := shardC.NumberOfShards() + + keys := make([]string, noOfShards+1) + interceptorSlice := make([]process.Interceptor, noOfShards+1) + + for idx := uint32(0); idx < noOfShards; idx++ { + // equivalent proofs shard topic, to listen for shard proofs, for example: equivalentProofs_0_META + identifierEquivalentProofs := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(idx) + interceptor, err := micf.createOneShardEquivalentProofsInterceptor(identifierEquivalentProofs) + if err != nil { + return err + } + + keys[int(idx)] = identifierEquivalentProofs + interceptorSlice[int(idx)] = interceptor + } + + // equivalent proofs meta all topic, equivalentProofs_META_ALL + identifierEquivalentProofs := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(core.AllShardId) + + interceptor, err := micf.createOneShardEquivalentProofsInterceptor(identifierEquivalentProofs) + if err != nil { + return err + } + + keys[noOfShards] = identifierEquivalentProofs + interceptorSlice[noOfShards] = interceptor + + return micf.addInterceptorsToContainers(keys, interceptorSlice) +} + // IsInterfaceNil returns true if there is no value under the interface func (micf *metaInterceptorsContainerFactory) IsInterfaceNil() bool { return micf == nil diff --git a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go index c8ed20b5fad..ec699e5803b 100644 --- a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go +++ b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go @@ -5,6 +5,9 @@ 
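generateEquivalentProofsInterceptors above registers one proofs interceptor per shard topic plus one on the all-shards topic (the diff's own examples: equivalentProofs_0_META and equivalentProofs_META_ALL). The way the key slice is assembled can be read as the sketch below, with the shard coordinator's identifier function abstracted into a parameter:

package example

// buildEquivalentProofsTopics mirrors the loop in
// generateEquivalentProofsInterceptors: noOfShards per-shard identifiers plus
// one identifier for the all-shards topic. communicationIdentifier stands in
// for shardCoordinator.CommunicationIdentifier and allShardID for core.AllShardId.
func buildEquivalentProofsTopics(
	baseTopic string,
	noOfShards uint32,
	allShardID uint32,
	communicationIdentifier func(destShardID uint32) string,
) []string {
	keys := make([]string, 0, noOfShards+1)
	for idx := uint32(0); idx < noOfShards; idx++ {
		keys = append(keys, baseTopic+communicationIdentifier(idx))
	}
	keys = append(keys, baseTopic+communicationIdentifier(allShardID))

	return keys
}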
import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" @@ -14,6 +17,8 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -21,8 +26,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const maxTxNonceDeltaAllowed = 100 @@ -63,7 +66,7 @@ func createMetaDataPools() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() @@ -72,11 +75,14 @@ func createMetaDataPools() dataRetriever.PoolsHolder { return testscommon.NewShardedDataStub() }, TrieNodesCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, RewardTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } return pools @@ -397,6 +403,18 @@ func TestNewMetaInterceptorsContainerFactory_NilPeerSignatureHandler(t *testing. 
assert.Equal(t, process.ErrNilPeerSignatureHandler, err) } +func TestNewMetaInterceptorsContainerFactory_NilInterceptedDataVerifierFactory(t *testing.T) { + t.Parallel() + + coreComp, cryptoComp := createMockComponentHolders() + args := getArgumentsShard(coreComp, cryptoComp) + args.InterceptedDataVerifierFactory = nil + icf, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) + + assert.Nil(t, icf) + assert.Equal(t, process.ErrNilInterceptedDataVerifierFactory, err) +} + func TestNewMetaInterceptorsContainerFactory_InvalidExpiryTimespan(t *testing.T) { t.Parallel() @@ -521,6 +539,8 @@ func TestMetaInterceptorsContainerFactory_CreateTopicsAndRegisterFailure(t *test testCreateMetaTopicShouldFailOnAllMessenger(t, "generatePeerShardInterceptor", common.ConnectionTopic, "") + testCreateMetaTopicShouldFailOnAllMessenger(t, "generateEquivalentProofsInterceptors", common.EquivalentProofsTopic, "") + t.Run("generatePeerAuthenticationInterceptor_main", testCreateMetaTopicShouldFail(common.PeerAuthenticationTopic, "")) } @@ -541,6 +561,7 @@ func testCreateMetaTopicShouldFail(matchStrToErrOnCreate string, matchStrToErrOn } else { args.MainMessenger = createMetaStubTopicHandler(matchStrToErrOnCreate, matchStrToErrOnRegister) } + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) mainContainer, fullArchiveConatiner, err := icf.Create() @@ -556,13 +577,15 @@ func TestMetaInterceptorsContainerFactory_CreateShouldWork(t *testing.T) { coreComp, cryptoComp := createMockComponentHolders() args := getArgumentsMeta(coreComp, cryptoComp) + + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) mainContainer, fullArchiveContainer, err := icf.Create() + require.Nil(t, err) assert.NotNil(t, mainContainer) assert.NotNil(t, fullArchiveContainer) - assert.Nil(t, err) } func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { @@ -588,6 +611,8 @@ func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { args := getArgumentsMeta(coreComp, cryptoComp) args.ShardCoordinator = shardCoordinator args.NodesCoordinator = nodesCoordinator + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} + icf, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) require.Nil(t, err) @@ -604,10 +629,11 @@ func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { numInterceptorsHeartbeatForMetachain := 1 numInterceptorsShardValidatorInfoForMetachain := 1 numInterceptorValidatorInfo := 1 + numInterceptorsEquivalentProofs := noOfShards + 1 totalInterceptors := numInterceptorsMetablock + numInterceptorsShardHeadersForMetachain + numInterceptorsTrieNodes + numInterceptorsTransactionsForMetachain + numInterceptorsUnsignedTxsForMetachain + numInterceptorsMiniBlocksForMetachain + numInterceptorsRewardsTxsForMetachain + numInterceptorsPeerAuthForMetachain + numInterceptorsHeartbeatForMetachain + - numInterceptorsShardValidatorInfoForMetachain + numInterceptorValidatorInfo + numInterceptorsShardValidatorInfoForMetachain + numInterceptorValidatorInfo + numInterceptorsEquivalentProofs assert.Nil(t, err) assert.Equal(t, totalInterceptors, mainContainer.Len()) @@ -637,6 +663,7 @@ func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { args.NodeOperationMode = common.FullArchiveMode 
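The updated 4-shards test above only has to grow the expected container size by noOfShards + 1, matching the generation loop: one proofs interceptor per shard topic plus the META_ALL one. Restated as a trivial helper with an assumed name:

package example

// expectedEquivalentProofsInterceptors restates the test's arithmetic:
// one interceptor per shard topic plus one for the all-shards topic,
// e.g. 4 shards -> 5 extra interceptors in the meta container.
func expectedEquivalentProofsInterceptors(noOfShards int) int {
	return noOfShards + 1
}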
args.ShardCoordinator = shardCoordinator args.NodesCoordinator = nodesCoordinator + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) require.Nil(t, err) @@ -654,10 +681,11 @@ func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { numInterceptorsHeartbeatForMetachain := 1 numInterceptorsShardValidatorInfoForMetachain := 1 numInterceptorValidatorInfo := 1 + numInterceptorsEquivalentProofs := noOfShards + 1 totalInterceptors := numInterceptorsMetablock + numInterceptorsShardHeadersForMetachain + numInterceptorsTrieNodes + numInterceptorsTransactionsForMetachain + numInterceptorsUnsignedTxsForMetachain + numInterceptorsMiniBlocksForMetachain + numInterceptorsRewardsTxsForMetachain + numInterceptorsPeerAuthForMetachain + numInterceptorsHeartbeatForMetachain + - numInterceptorsShardValidatorInfoForMetachain + numInterceptorValidatorInfo + numInterceptorsShardValidatorInfoForMetachain + numInterceptorValidatorInfo + numInterceptorsEquivalentProofs assert.Nil(t, err) assert.Equal(t, totalInterceptors, mainContainer.Len()) @@ -678,34 +706,35 @@ func getArgumentsMeta( cryptoComp *mock.CryptoComponentsMock, ) interceptorscontainer.CommonInterceptorsContainerFactoryArgs { return interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: coreComp, - CryptoComponents: cryptoComp, - Accounts: &stateMock.AccountsStub{}, - ShardCoordinator: mock.NewOneShardCoordinatorMock(), - NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), - MainMessenger: &mock.TopicHandlerStub{}, - FullArchiveMessenger: &mock.TopicHandlerStub{}, - Store: createMetaStore(), - DataPool: createMetaDataPools(), - MaxTxNonceDeltaAllowed: maxTxNonceDeltaAllowed, - TxFeeHandler: &economicsmocks.EconomicsHandlerStub{}, - BlockBlackList: &testscommon.TimeCacheStub{}, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, - ValidityAttester: &mock.ValidityAttesterStub{}, - EpochStartTrigger: &mock.EpochStartTriggerStub{}, - WhiteListHandler: &testscommon.WhiteListHandlerStub{}, - WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, - AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, - ArgumentsParser: &mock.ArgumentParserMock{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - RequestHandler: &testscommon.RequestHandlerStub{}, - PeerSignatureHandler: &mock.PeerSignatureHandlerStub{}, - SignaturesHandler: &mock.SignaturesHandlerStub{}, - HeartbeatExpiryTimespanInSec: 30, - MainPeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - FullArchivePeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - HardforkTrigger: &testscommon.HardforkTriggerStub{}, - NodeOperationMode: common.NormalOperation, + CoreComponents: coreComp, + CryptoComponents: cryptoComp, + Accounts: &stateMock.AccountsStub{}, + ShardCoordinator: mock.NewOneShardCoordinatorMock(), + NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), + MainMessenger: &mock.TopicHandlerStub{}, + FullArchiveMessenger: &mock.TopicHandlerStub{}, + Store: createMetaStore(), + DataPool: createMetaDataPools(), + MaxTxNonceDeltaAllowed: maxTxNonceDeltaAllowed, + TxFeeHandler: &economicsmocks.EconomicsHandlerStub{}, + BlockBlackList: &testscommon.TimeCacheStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + ValidityAttester: &mock.ValidityAttesterStub{}, + EpochStartTrigger: 
&mock.EpochStartTriggerStub{}, + WhiteListHandler: &testscommon.WhiteListHandlerStub{}, + WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, + AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, + ArgumentsParser: &mock.ArgumentParserMock{}, + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + PeerSignatureHandler: &mock.PeerSignatureHandlerStub{}, + SignaturesHandler: &mock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, + FullArchivePeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, + HardforkTrigger: &testscommon.HardforkTriggerStub{}, + NodeOperationMode: common.NormalOperation, + InterceptedDataVerifierFactory: &mock.InterceptedDataVerifierFactoryMock{}, } } diff --git a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory.go index beef288c54c..e3a4e639d5b 100644 --- a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory.go @@ -5,6 +5,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/process/factory/containers" @@ -78,6 +81,9 @@ func NewShardInterceptorsContainerFactory( if check.IfNil(args.PeerSignatureHandler) { return nil, process.ErrNilPeerSignatureHandler } + if check.IfNil(args.InterceptedDataVerifierFactory) { + return nil, process.ErrNilInterceptedDataVerifierFactory + } if args.HeartbeatExpiryTimespanInSec < minTimespanDurationInSec { return nil, process.ErrInvalidExpiryTimespan } @@ -101,28 +107,30 @@ func NewShardInterceptorsContainerFactory( } base := &baseInterceptorsContainerFactory{ - mainContainer: containers.NewInterceptorsContainer(), - fullArchiveContainer: containers.NewInterceptorsContainer(), - accounts: args.Accounts, - shardCoordinator: args.ShardCoordinator, - mainMessenger: args.MainMessenger, - fullArchiveMessenger: args.FullArchiveMessenger, - store: args.Store, - dataPool: args.DataPool, - nodesCoordinator: args.NodesCoordinator, - argInterceptorFactory: argInterceptorFactory, - blockBlackList: args.BlockBlackList, - maxTxNonceDeltaAllowed: args.MaxTxNonceDeltaAllowed, - antifloodHandler: args.AntifloodHandler, - whiteListHandler: args.WhiteListHandler, - whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, - preferredPeersHolder: args.PreferredPeersHolder, - hasher: args.CoreComponents.Hasher(), - requestHandler: args.RequestHandler, - mainPeerShardMapper: args.MainPeerShardMapper, - fullArchivePeerShardMapper: args.FullArchivePeerShardMapper, - hardforkTrigger: args.HardforkTrigger, - nodeOperationMode: args.NodeOperationMode, + mainContainer: containers.NewInterceptorsContainer(), + fullArchiveContainer: containers.NewInterceptorsContainer(), + accounts: args.Accounts, + shardCoordinator: args.ShardCoordinator, + mainMessenger: args.MainMessenger, + fullArchiveMessenger: args.FullArchiveMessenger, + store: args.Store, + dataPool: args.DataPool, + nodesCoordinator: args.NodesCoordinator, + argInterceptorFactory: argInterceptorFactory, + blockBlackList: args.BlockBlackList, + maxTxNonceDeltaAllowed: 
args.MaxTxNonceDeltaAllowed, + antifloodHandler: args.AntifloodHandler, + whiteListHandler: args.WhiteListHandler, + whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, + preferredPeersHolder: args.PreferredPeersHolder, + hasher: args.CoreComponents.Hasher(), + requestHandler: args.RequestHandler, + mainPeerShardMapper: args.MainPeerShardMapper, + fullArchivePeerShardMapper: args.FullArchivePeerShardMapper, + hardforkTrigger: args.HardforkTrigger, + nodeOperationMode: args.NodeOperationMode, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, + enableEpochsHandler: args.CoreComponents.EnableEpochsHandler(), } icf := &shardInterceptorsContainerFactory{ @@ -194,6 +202,11 @@ func (sicf *shardInterceptorsContainerFactory) Create() (process.InterceptorsCon return nil, nil, err } + err = sicf.generateEquivalentProofsInterceptor() + if err != nil { + return nil, nil, err + } + return sicf.mainContainer, sicf.fullArchiveContainer, nil } @@ -235,6 +248,28 @@ func (sicf *shardInterceptorsContainerFactory) generateRewardTxInterceptor() err return sicf.addInterceptorsToContainers(keys, interceptorSlice) } +func (sicf *shardInterceptorsContainerFactory) generateEquivalentProofsInterceptor() error { + shardC := sicf.shardCoordinator + + // equivalent proofs shard topic, for example: equivalentProofs_0_META + identifierEquivalentProofsShardMeta := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(core.MetachainShardId) + + interceptorShardMeta, err := sicf.createOneShardEquivalentProofsInterceptor(identifierEquivalentProofsShardMeta) + if err != nil { + return err + } + + // equivalent proofs _ALL topic, to listen for meta proofs, example: equivalentProofs_META_ALL + identifierEquivalentProofsMetaAll := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + + interceptorMetaAll, err := sicf.createOneShardEquivalentProofsInterceptor(identifierEquivalentProofsMetaAll) + if err != nil { + return err + } + + return sicf.addInterceptorsToContainers([]string{identifierEquivalentProofsShardMeta, identifierEquivalentProofsMetaAll}, []process.Interceptor{interceptorShardMeta, interceptorMetaAll}) +} + // IsInterfaceNil returns true if there is no value under the interface func (sicf *shardInterceptorsContainerFactory) IsInterfaceNil() bool { return sicf == nil diff --git a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go index 24472c24f32..d05099299d5 100644 --- a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go +++ b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go @@ -6,6 +6,9 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/versioning" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" @@ -15,6 +18,8 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ 
-25,7 +30,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" ) var providedHardforkPubKey = []byte("provided hardfork pub key") @@ -64,13 +68,13 @@ func createShardDataPools() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} } pools.MiniBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.PeerChangesBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.MetaBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.UnsignedTransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() @@ -79,14 +83,18 @@ func createShardDataPools() dataRetriever.PoolsHolder { return testscommon.NewShardedDataStub() } pools.TrieNodesCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.TrieNodesChunksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.CurrBlockTxsCalled = func() dataRetriever.TransactionCacher { return &mock.TxForCurrentBlockStub{} } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } + return pools } @@ -353,6 +361,18 @@ func TestNewShardInterceptorsContainerFactory_NilValidityAttesterShouldErr(t *te assert.Equal(t, process.ErrNilValidityAttester, err) } +func TestNewShardInterceptorsContainerFactory_NilInterceptedDataVerifierFactory(t *testing.T) { + t.Parallel() + + coreComp, cryptoComp := createMockComponentHolders() + args := getArgumentsShard(coreComp, cryptoComp) + args.InterceptedDataVerifierFactory = nil + icf, err := interceptorscontainer.NewShardInterceptorsContainerFactory(args) + + assert.Nil(t, icf) + assert.Equal(t, process.ErrNilInterceptedDataVerifierFactory, err) +} + func TestNewShardInterceptorsContainerFactory_InvalidChainIDShouldErr(t *testing.T) { t.Parallel() @@ -479,6 +499,8 @@ func TestShardInterceptorsContainerFactory_CreateTopicsAndRegisterFailure(t *tes testCreateShardTopicShouldFailOnAllMessenger(t, "generatePeerShardIntercepto", common.ConnectionTopic, "") + testCreateShardTopicShouldFailOnAllMessenger(t, "generateEquivalentProofsInterceptor", common.EquivalentProofsTopic, "") + t.Run("generatePeerAuthenticationInterceptor_main", testCreateShardTopicShouldFail(common.PeerAuthenticationTopic, "")) } func testCreateShardTopicShouldFailOnAllMessenger(t *testing.T, testNamePrefix string, matchStrToErrOnCreate string, matchStrToErrOnRegister string) { @@ -492,6 +514,7 @@ func testCreateShardTopicShouldFail(matchStrToErrOnCreate string, matchStrToErrO coreComp, cryptoComp := createMockComponentHolders() args := getArgumentsShard(coreComp, cryptoComp) + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} if strings.Contains(t.Name(), "full_archive") { args.NodeOperationMode = common.FullArchiveMode args.FullArchiveMessenger = createShardStubTopicHandler(matchStrToErrOnCreate, matchStrToErrOnRegister) @@ -558,14 +581,15 @@ func TestShardInterceptorsContainerFactory_CreateShouldWork(t *testing.T) { }, } args.WhiteListerVerifiedTxs = &testscommon.WhiteListHandlerStub{} + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} 
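// Illustrative sketch, not part of the PR: how the two topics registered by
// generateEquivalentProofsInterceptor above are composed. The shard coordinator below is
// built with sharding.NewMultiShardCoordinator purely for the example (3 shards, self id 0);
// the expected output strings mirror the examples quoted in the factory's comments.
package main

import (
	"fmt"

	"github.com/multiversx/mx-chain-core-go/core"

	"github.com/multiversx/mx-chain-go/common"
	"github.com/multiversx/mx-chain-go/sharding"
)

func main() {
	shardCoordinator, err := sharding.NewMultiShardCoordinator(3, 0)
	if err != nil {
		panic(err)
	}

	// shard <-> META topic, expected to look like "equivalentProofs_0_META"
	shardMetaTopic := common.EquivalentProofsTopic + shardCoordinator.CommunicationIdentifier(core.MetachainShardId)
	// META <-> ALL topic, expected to look like "equivalentProofs_META_ALL"
	metaAllTopic := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId)

	fmt.Println(shardMetaTopic, metaAllTopic)
}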
icf, _ := interceptorscontainer.NewShardInterceptorsContainerFactory(args) mainContainer, fullArchiveContainer, err := icf.Create() + require.Nil(t, err) assert.NotNil(t, mainContainer) assert.NotNil(t, fullArchiveContainer) - assert.Nil(t, err) } func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { @@ -593,6 +617,7 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { args.ShardCoordinator = shardCoordinator args.NodesCoordinator = nodesCoordinator args.PreferredPeersHolder = &p2pmocks.PeersHolderStub{} + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewShardInterceptorsContainerFactory(args) @@ -609,9 +634,11 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { numInterceptorHeartbeat := 1 numInterceptorsShardValidatorInfo := 1 numInterceptorValidatorInfo := 1 + numInterceptorEquivalentProofs := 2 totalInterceptors := numInterceptorTxs + numInterceptorsUnsignedTxs + numInterceptorsRewardTxs + numInterceptorHeaders + numInterceptorMiniBlocks + numInterceptorMetachainHeaders + numInterceptorTrieNodes + - numInterceptorPeerAuth + numInterceptorHeartbeat + numInterceptorsShardValidatorInfo + numInterceptorValidatorInfo + numInterceptorPeerAuth + numInterceptorHeartbeat + numInterceptorsShardValidatorInfo + numInterceptorValidatorInfo + + numInterceptorEquivalentProofs assert.Nil(t, err) assert.Equal(t, totalInterceptors, mainContainer.Len()) @@ -641,6 +668,7 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { args.ShardCoordinator = shardCoordinator args.NodesCoordinator = nodesCoordinator args.PreferredPeersHolder = &p2pmocks.PeersHolderStub{} + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewShardInterceptorsContainerFactory(args) @@ -657,9 +685,11 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { numInterceptorHeartbeat := 1 numInterceptorsShardValidatorInfo := 1 numInterceptorValidatorInfo := 1 + numInterceptorEquivalentProofs := 2 totalInterceptors := numInterceptorTxs + numInterceptorsUnsignedTxs + numInterceptorsRewardTxs + numInterceptorHeaders + numInterceptorMiniBlocks + numInterceptorMetachainHeaders + numInterceptorTrieNodes + - numInterceptorPeerAuth + numInterceptorHeartbeat + numInterceptorsShardValidatorInfo + numInterceptorValidatorInfo + numInterceptorPeerAuth + numInterceptorHeartbeat + numInterceptorsShardValidatorInfo + numInterceptorValidatorInfo + + numInterceptorEquivalentProofs assert.Nil(t, err) assert.Equal(t, totalInterceptors, mainContainer.Len()) @@ -703,34 +733,35 @@ func getArgumentsShard( cryptoComp *mock.CryptoComponentsMock, ) interceptorscontainer.CommonInterceptorsContainerFactoryArgs { return interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: coreComp, - CryptoComponents: cryptoComp, - Accounts: &stateMock.AccountsStub{}, - ShardCoordinator: mock.NewOneShardCoordinatorMock(), - NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), - MainMessenger: &mock.TopicHandlerStub{}, - FullArchiveMessenger: &mock.TopicHandlerStub{}, - Store: createShardStore(), - DataPool: createShardDataPools(), - MaxTxNonceDeltaAllowed: maxTxNonceDeltaAllowed, - TxFeeHandler: &economicsmocks.EconomicsHandlerStub{}, - BlockBlackList: &testscommon.TimeCacheStub{}, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: 
&mock.HeaderIntegrityVerifierStub{}, - SizeCheckDelta: 0, - ValidityAttester: &mock.ValidityAttesterStub{}, - EpochStartTrigger: &mock.EpochStartTriggerStub{}, - AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, - WhiteListHandler: &testscommon.WhiteListHandlerStub{}, - WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, - ArgumentsParser: &mock.ArgumentParserMock{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - RequestHandler: &testscommon.RequestHandlerStub{}, - PeerSignatureHandler: &mock.PeerSignatureHandlerStub{}, - SignaturesHandler: &mock.SignaturesHandlerStub{}, - HeartbeatExpiryTimespanInSec: 30, - MainPeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - FullArchivePeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - HardforkTrigger: &testscommon.HardforkTriggerStub{}, + CoreComponents: coreComp, + CryptoComponents: cryptoComp, + Accounts: &stateMock.AccountsStub{}, + ShardCoordinator: mock.NewOneShardCoordinatorMock(), + NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), + MainMessenger: &mock.TopicHandlerStub{}, + FullArchiveMessenger: &mock.TopicHandlerStub{}, + Store: createShardStore(), + DataPool: createShardDataPools(), + MaxTxNonceDeltaAllowed: maxTxNonceDeltaAllowed, + TxFeeHandler: &economicsmocks.EconomicsHandlerStub{}, + BlockBlackList: &testscommon.TimeCacheStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + SizeCheckDelta: 0, + ValidityAttester: &mock.ValidityAttesterStub{}, + EpochStartTrigger: &mock.EpochStartTriggerStub{}, + AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, + WhiteListHandler: &testscommon.WhiteListHandlerStub{}, + WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, + ArgumentsParser: &mock.ArgumentParserMock{}, + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + PeerSignatureHandler: &mock.PeerSignatureHandlerStub{}, + SignaturesHandler: &mock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, + FullArchivePeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, + HardforkTrigger: &testscommon.HardforkTriggerStub{}, + InterceptedDataVerifierFactory: &mock.InterceptedDataVerifierFactoryMock{}, } } diff --git a/process/factory/shard/intermediateProcessorsContainerFactory_test.go b/process/factory/shard/intermediateProcessorsContainerFactory_test.go index 5835a7361ac..a1a39c28402 100644 --- a/process/factory/shard/intermediateProcessorsContainerFactory_test.go +++ b/process/factory/shard/intermediateProcessorsContainerFactory_test.go @@ -3,6 +3,8 @@ package shard_test import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -10,13 +12,13 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" txExecOrderStub "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" storageStubs 
"github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" ) func createDataPools() dataRetriever.PoolsHolder { @@ -28,13 +30,13 @@ func createDataPools() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} } pools.MiniBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.PeerChangesBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.MetaBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.UnsignedTransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() @@ -43,7 +45,7 @@ func createDataPools() dataRetriever.PoolsHolder { return testscommon.NewShardedDataStub() } pools.TrieNodesCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.CurrBlockTxsCalled = func() dataRetriever.TransactionCacher { return &mock.TxForCurrentBlockStub{} diff --git a/process/headerCheck/common.go b/process/headerCheck/common.go index 01946580d87..353c112e501 100644 --- a/process/headerCheck/common.go +++ b/process/headerCheck/common.go @@ -3,20 +3,24 @@ package headerCheck import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) // ComputeConsensusGroup will compute the consensus group that assembled the provided block -func ComputeConsensusGroup(header data.HeaderHandler, nodesCoordinator nodesCoordinator.NodesCoordinator) (validatorsGroup []nodesCoordinator.Validator, err error) { +func ComputeConsensusGroup(header data.HeaderHandler, nodesCoordinator nodesCoordinator.NodesCoordinator) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { if check.IfNil(header) { - return nil, process.ErrNilHeaderHandler + return nil, nil, process.ErrNilHeaderHandler } if check.IfNil(nodesCoordinator) { - return nil, process.ErrNilNodesCoordinator + return nil, nil, process.ErrNilNodesCoordinator } prevRandSeed := header.GetPrevRandSeed() + if prevRandSeed == nil { + return nil, nil, process.ErrNilPrevRandSeed + } // TODO: change here with an activation flag if start of epoch block needs to be validated by the new epoch nodes epoch := header.GetEpoch() diff --git a/process/headerCheck/common_test.go b/process/headerCheck/common_test.go index 0961b7f2a20..8924327fcbd 100644 --- a/process/headerCheck/common_test.go +++ b/process/headerCheck/common_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) func TestComputeConsensusGroup(t *testing.T) { @@ -16,14 +17,15 @@ func TestComputeConsensusGroup(t *testing.T) { t.Run("nil header should error", func(t *testing.T) { nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId 
uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { assert.Fail(t, "should have not called ComputeValidatorsGroupCalled") - return nil, nil + return nil, nil, nil } - vGroup, err := ComputeConsensusGroup(nil, nodesCoordinatorInstance) + leader, vGroup, err := ComputeConsensusGroup(nil, nodesCoordinatorInstance) assert.Equal(t, process.ErrNilHeaderHandler, err) assert.Nil(t, vGroup) + assert.Nil(t, leader) }) t.Run("nil nodes coordinator should error", func(t *testing.T) { header := &block.Header{ @@ -34,9 +36,10 @@ func TestComputeConsensusGroup(t *testing.T) { PrevRandSeed: []byte("prev rand seed"), } - vGroup, err := ComputeConsensusGroup(header, nil) + leader, vGroup, err := ComputeConsensusGroup(header, nil) assert.Equal(t, process.ErrNilNodesCoordinator, err) assert.Nil(t, vGroup) + assert.Nil(t, leader) }) t.Run("should work for a random block", func(t *testing.T) { header := &block.Header{ @@ -52,18 +55,19 @@ func TestComputeConsensusGroup(t *testing.T) { validatorGroup := []nodesCoordinator.Validator{validator1, validator2} nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { assert.Equal(t, header.PrevRandSeed, randomness) assert.Equal(t, header.Round, round) assert.Equal(t, header.ShardID, shardId) assert.Equal(t, header.Epoch, epoch) - return validatorGroup, nil + return validator1, validatorGroup, nil } - vGroup, err := ComputeConsensusGroup(header, nodesCoordinatorInstance) + leader, vGroup, err := ComputeConsensusGroup(header, nodesCoordinatorInstance) assert.Nil(t, err) assert.Equal(t, validatorGroup, vGroup) + assert.Equal(t, validator1, leader) }) t.Run("should work for a start of epoch block", func(t *testing.T) { header := &block.Header{ @@ -80,18 +84,19 @@ func TestComputeConsensusGroup(t *testing.T) { validatorGroup := []nodesCoordinator.Validator{validator1, validator2} nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { assert.Equal(t, header.PrevRandSeed, randomness) assert.Equal(t, header.Round, round) assert.Equal(t, header.ShardID, shardId) assert.Equal(t, header.Epoch-1, epoch) - return validatorGroup, nil + return validator1, validatorGroup, nil } - vGroup, err := ComputeConsensusGroup(header, nodesCoordinatorInstance) + leader, vGroup, err := ComputeConsensusGroup(header, nodesCoordinatorInstance) assert.Nil(t, err) assert.Equal(t, validatorGroup, vGroup) + assert.Equal(t, validator1, leader) }) } diff --git a/process/headerCheck/errors.go b/process/headerCheck/errors.go index e0d4123ae2b..b808de98518 100644 --- a/process/headerCheck/errors.go +++ b/process/headerCheck/errors.go @@ -23,3 +23,9 @@ var ErrIndexOutOfBounds = errors.New("index is out of bounds") // 
ErrIndexNotSelected signals that the given index is not selected var ErrIndexNotSelected = errors.New("index is not selected") + +// ErrProofShardMismatch signals that the proof shard does not match the header shard +var ErrProofShardMismatch = errors.New("proof shard mismatch") + +// ErrProofHeaderHashMismatch signals that the proof header hash does not match the header hash +var ErrProofHeaderHashMismatch = errors.New("proof header hash mismatch") diff --git a/process/headerCheck/headerSignatureVerify.go b/process/headerCheck/headerSignatureVerify.go index 308af919366..50bc3ff42ac 100644 --- a/process/headerCheck/headerSignatureVerify.go +++ b/process/headerCheck/headerSignatureVerify.go @@ -1,6 +1,8 @@ package headerCheck import ( + "bytes" + "fmt" "math/bits" "github.com/multiversx/mx-chain-core-go/core" @@ -9,10 +11,13 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ process.InterceptedHeaderSigVerifier = (*HeaderSigVerifier)(nil) @@ -28,6 +33,8 @@ type ArgsHeaderSigVerifier struct { SingleSigVerifier crypto.SingleSigner KeyGen crypto.KeyGenerator FallbackHeaderValidator process.FallbackHeaderValidator + EnableEpochsHandler common.EnableEpochsHandler + HeadersPool dataRetriever.HeadersPool } // HeaderSigVerifier is component used to check if a header is valid @@ -39,6 +46,8 @@ type HeaderSigVerifier struct { singleSigVerifier crypto.SingleSigner keyGen crypto.KeyGenerator fallbackHeaderValidator process.FallbackHeaderValidator + enableEpochsHandler common.EnableEpochsHandler + headersPool dataRetriever.HeadersPool } // NewHeaderSigVerifier will create a new instance of HeaderSigVerifier @@ -56,6 +65,8 @@ func NewHeaderSigVerifier(arguments *ArgsHeaderSigVerifier) (*HeaderSigVerifier, singleSigVerifier: arguments.SingleSigVerifier, keyGen: arguments.KeyGen, fallbackHeaderValidator: arguments.FallbackHeaderValidator, + enableEpochsHandler: arguments.EnableEpochsHandler, + headersPool: arguments.HeadersPool, }, nil } @@ -91,6 +102,12 @@ func checkArgsHeaderSigVerifier(arguments *ArgsHeaderSigVerifier) error { if check.IfNil(arguments.FallbackHeaderValidator) { return process.ErrNilFallbackHeaderValidator } + if check.IfNil(arguments.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } + if check.IfNil(arguments.HeadersPool) { + return process.ErrNilHeadersDataPool + } return nil } @@ -109,77 +126,192 @@ func isIndexInBitmap(index uint16, bitmap []byte) error { return nil } -func (hsv *HeaderSigVerifier) getConsensusSigners(header data.HeaderHandler) ([][]byte, error) { - randSeed := header.GetPrevRandSeed() - bitmap := header.GetPubKeysBitmap() - if len(bitmap) == 0 { +func (hsv *HeaderSigVerifier) getConsensusSigners( + randSeed []byte, + shardID uint32, + epoch uint32, + startOfEpochBlock bool, + round uint64, + prevHash []byte, + pubKeysBitmap []byte, +) ([][]byte, error) { + if len(pubKeysBitmap) == 0 { return nil, process.ErrNilPubKeysBitmap } - if bitmap[0]&1 == 0 { - return nil, process.ErrBlockProposerSignatureMissing + + if 
!hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, epoch) { + if pubKeysBitmap[0]&1 == 0 { + return nil, process.ErrBlockProposerSignatureMissing + } } // TODO: remove if start of epochForConsensus block needs to be validated by the new epochForConsensus nodes - epochForConsensus := header.GetEpoch() - if header.IsStartOfEpochBlock() && epochForConsensus > 0 { + epochForConsensus := epoch + if startOfEpochBlock && epochForConsensus > 0 { epochForConsensus = epochForConsensus - 1 } - consensusPubKeys, err := hsv.nodesCoordinator.GetConsensusValidatorsPublicKeys( + _, consensusPubKeys, err := hsv.nodesCoordinator.GetConsensusValidatorsPublicKeys( randSeed, - header.GetRound(), - header.GetShardID(), + round, + shardID, epochForConsensus, ) if err != nil { return nil, err } - err = hsv.verifyConsensusSize(consensusPubKeys, header) + err = hsv.verifyConsensusSize( + consensusPubKeys, + pubKeysBitmap, + shardID, + startOfEpochBlock, + round, + prevHash) if err != nil { return nil, err } + return getPubKeySigners(consensusPubKeys, pubKeysBitmap), nil +} + +func getPubKeySigners(consensusPubKeys []string, pubKeysBitmap []byte) [][]byte { pubKeysSigners := make([][]byte, 0, len(consensusPubKeys)) for i := range consensusPubKeys { - err = isIndexInBitmap(uint16(i), bitmap) + err := isIndexInBitmap(uint16(i), pubKeysBitmap) if err != nil { continue } pubKeysSigners = append(pubKeysSigners, []byte(consensusPubKeys[i])) } - return pubKeysSigners, nil + return pubKeysSigners } // VerifySignature will check if signature is correct func (hsv *HeaderSigVerifier) VerifySignature(header data.HeaderHandler) error { + if hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + return hsv.VerifyHeaderWithProof(header) + } + + headerCopy, err := hsv.copyHeaderWithoutSig(header) + if err != nil { + return err + } + + hash, err := core.CalculateHash(hsv.marshalizer, hsv.hasher, headerCopy) + if err != nil { + return err + } + + bitmap := header.GetPubKeysBitmap() + sig := header.GetSignature() + return hsv.VerifySignatureForHash(headerCopy, hash, bitmap, sig) +} + +func verifyPrevProofForHeader(header data.HeaderHandler) error { + prevProof := header.GetPreviousProof() + if check.IfNilReflect(prevProof) { + return process.ErrNilHeaderProof + } + + if header.GetShardID() != prevProof.GetHeaderShardId() { + return ErrProofShardMismatch + } + + if !bytes.Equal(header.GetPrevHash(), prevProof.GetHeaderHash()) { + return ErrProofHeaderHashMismatch + } + + return nil +} + +// VerifySignatureForHash will check if signature is correct for the provided hash +func (hsv *HeaderSigVerifier) VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error { multiSigVerifier, err := hsv.multiSigContainer.GetMultiSigner(header.GetEpoch()) if err != nil { return err } - headerCopy, err := hsv.copyHeaderWithoutSig(header) + randSeed := header.GetPrevRandSeed() + if randSeed == nil { + return process.ErrNilPrevRandSeed + } + pubKeysSigners, err := hsv.getConsensusSigners( + randSeed, + header.GetShardID(), + header.GetEpoch(), + header.IsStartOfEpochBlock(), + header.GetRound(), + header.GetPrevHash(), + pubkeysBitmap, + ) if err != nil { return err } - hash, err := core.CalculateHash(hsv.marshalizer, hsv.hasher, headerCopy) + return multiSigVerifier.VerifyAggregatedSig(pubKeysSigners, hash, signature) +} + +// VerifyHeaderWithProof checks if the proof on the header is correct +func (hsv *HeaderSigVerifier) 
VerifyHeaderWithProof(header data.HeaderHandler) error { + err := verifyPrevProofForHeader(header) + if err != nil { + return err + } + + prevProof := header.GetPreviousProof() + return hsv.VerifyHeaderProof(prevProof) +} + +// VerifyHeaderProof checks if the proof is correct for the header +func (hsv *HeaderSigVerifier) VerifyHeaderProof(proofHandler data.HeaderProofHandler) error { + if check.IfNilReflect(proofHandler) { + return process.ErrNilHeaderProof + } + + if !hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, proofHandler.GetHeaderEpoch()) { + return fmt.Errorf("%w for flag %s", process.ErrFlagNotActive, common.EquivalentMessagesFlag) + } + + // for the start of epoch block the consensus is taken from the previous epoch + header, err := hsv.headersPool.GetHeaderByHash(proofHandler.GetHeaderHash()) if err != nil { return err } - pubKeysSigners, err := hsv.getConsensusSigners(header) + multiSigVerifier, err := hsv.multiSigContainer.GetMultiSigner(proofHandler.GetHeaderEpoch()) if err != nil { return err } - return multiSigVerifier.VerifyAggregatedSig(pubKeysSigners, hash, header.GetSignature()) + // round, prevHash and prevRandSeed could be removed when we remove fallback validation and we don't need backwards compatibility + // (e.g new binary from epoch x forward) + consensusPubKeys, err := hsv.getConsensusSigners( + header.GetPrevRandSeed(), + proofHandler.GetHeaderShardId(), + proofHandler.GetHeaderEpoch(), + header.IsStartOfEpochBlock(), + header.GetRound(), + header.GetPrevHash(), + proofHandler.GetPubKeysBitmap(), + ) + if err != nil { + return err + } + + return multiSigVerifier.VerifyAggregatedSig(consensusPubKeys, proofHandler.GetHeaderHash(), proofHandler.GetAggregatedSignature()) } -func (hsv *HeaderSigVerifier) verifyConsensusSize(consensusPubKeys []string, header data.HeaderHandler) error { +func (hsv *HeaderSigVerifier) verifyConsensusSize( + consensusPubKeys []string, + bitmap []byte, + shardID uint32, + startOfEpochBlock bool, + round uint64, + prevHash []byte, +) error { consensusSize := len(consensusPubKeys) - bitmap := header.GetPubKeysBitmap() expectedBitmapSize := consensusSize / 8 if consensusSize%8 != 0 { @@ -198,7 +330,12 @@ func (hsv *HeaderSigVerifier) verifyConsensusSize(consensusPubKeys []string, hea } minNumRequiredSignatures := core.GetPBFTThreshold(consensusSize) - if hsv.fallbackHeaderValidator.ShouldApplyFallbackValidation(header) { + if hsv.fallbackHeaderValidator.ShouldApplyFallbackValidationForHeaderWith( + shardID, + startOfEpochBlock, + round, + prevHash, + ) { minNumRequiredSignatures = core.GetPBFTFallbackThreshold(consensusSize) log.Warn("HeaderSigVerifier.verifyConsensusSize: fallback validation has been applied", "minimum number of signatures required", minNumRequiredSignatures, @@ -282,7 +419,15 @@ func (hsv *HeaderSigVerifier) IsInterfaceNil() bool { func (hsv *HeaderSigVerifier) verifyRandSeed(leaderPubKey crypto.PublicKey, header data.HeaderHandler) error { prevRandSeed := header.GetPrevRandSeed() + if prevRandSeed == nil { + return process.ErrNilPrevRandSeed + } + randSeed := header.GetRandSeed() + if randSeed == nil { + return process.ErrNilRandSeed + } + return hsv.singleSigVerifier.Verify(leaderPubKey, prevRandSeed, randSeed) } @@ -301,13 +446,11 @@ func (hsv *HeaderSigVerifier) verifyLeaderSignature(leaderPubKey crypto.PublicKe } func (hsv *HeaderSigVerifier) getLeader(header data.HeaderHandler) (crypto.PublicKey, error) { - headerConsensusGroup, err := ComputeConsensusGroup(header, hsv.nodesCoordinator) 
+ leader, _, err := ComputeConsensusGroup(header, hsv.nodesCoordinator) if err != nil { return nil, err } - - leaderPubKeyValidator := headerConsensusGroup[0] - return hsv.keyGen.PublicKeyFromByteArray(leaderPubKeyValidator.PubKey()) + return hsv.keyGen.PublicKeyFromByteArray(leader.PubKey()) } func (hsv *HeaderSigVerifier) copyHeaderWithoutSig(header data.HeaderHandler) (data.HeaderHandler, error) { @@ -322,9 +465,11 @@ func (hsv *HeaderSigVerifier) copyHeaderWithoutSig(header data.HeaderHandler) (d return nil, err } - err = headerCopy.SetLeaderSignature(nil) - if err != nil { - return nil, err + if !hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + err = headerCopy.SetLeaderSignature(nil) + if err != nil { + return nil, err + } } return headerCopy, nil diff --git a/process/headerCheck/headerSignatureVerify_test.go b/process/headerCheck/headerSignatureVerify_test.go index f89b8cf90ca..adb372ba15c 100644 --- a/process/headerCheck/headerSignatureVerify_test.go +++ b/process/headerCheck/headerSignatureVerify_test.go @@ -3,32 +3,57 @@ package headerCheck import ( "bytes" "errors" + "strings" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/require" ) const defaultChancesSelection = 1 +var expectedErr = errors.New("expected error") + func createHeaderSigVerifierArgs() *ArgsHeaderSigVerifier { + v1, _ := nodesCoordinator.NewValidator([]byte("pubKey1"), 1, defaultChancesSelection) + v2, _ := nodesCoordinator.NewValidator([]byte("pubKey1"), 1, defaultChancesSelection) return &ArgsHeaderSigVerifier{ - Marshalizer: &mock.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - NodesCoordinator: &shardingMocks.NodesCoordinatorMock{}, - MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(cryptoMocks.NewMultiSigner()), - SingleSigVerifier: &mock.SignerMock{}, - KeyGen: &mock.SingleSignKeyGenMock{}, + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + NodesCoordinator: &shardingMocks.NodesCoordinatorMock{ + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { + return v1, []nodesCoordinator.Validator{v1, v2}, nil + }, + }, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(cryptoMocks.NewMultiSigner()), + SingleSigVerifier: &mock.SignerMock{}, + KeyGen: &mock.SingleSignKeyGenMock{ + PublicKeyFromByteArrayCalled: func(b []byte) (key crypto.PublicKey, err error) { + return &mock.SingleSignPublicKey{}, nil + }, + }, FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + HeadersPool: &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) 
(data.HeaderHandler, error) { + return &dataBlock.Header{ + PrevRandSeed: []byte("prevRandSeed"), + }, nil + }, + }, } } @@ -107,6 +132,17 @@ func TestNewHeaderSigVerifier_NilSingleSigShouldErr(t *testing.T) { require.Equal(t, process.ErrNilSingleSigner, err) } +func TestNewHeaderSigVerifier_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.EnableEpochsHandler = nil + hdrSigVerifier, err := NewHeaderSigVerifier(args) + + require.Nil(t, hdrSigVerifier) + require.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + func TestNewHeaderSigVerifier_OkValsShouldWork(t *testing.T) { t.Parallel() @@ -123,10 +159,13 @@ func TestHeaderSigVerifier_VerifySignatureNilPrevRandSeedShouldErr(t *testing.T) args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + PrevRandSeed: nil, + RandSeed: []byte("rand seed"), + } err := hdrSigVerifier.VerifyRandSeed(header) - require.Equal(t, nodesCoordinator.ErrNilRandomness, err) + require.Equal(t, process.ErrNilPrevRandSeed, err) } func TestHeaderSigVerifier_VerifyRandSeedOk(t *testing.T) { @@ -149,14 +188,17 @@ func TestHeaderSigVerifier_VerifyRandSeedOk(t *testing.T) { pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + PrevRandSeed: []byte("prev rand seed"), + RandSeed: []byte("rand seed"), + } err := hdrSigVerifier.VerifyRandSeed(header) require.Nil(t, err) @@ -184,14 +226,17 @@ func TestHeaderSigVerifier_VerifyRandSeedShouldErrWhenVerificationFails(t *testi pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyRandSeed(header) require.Equal(t, localError, err) @@ -203,10 +248,13 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureNilRandomnessShouldEr args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: nil, + PrevRandSeed: []byte("prev rand seed"), + } err := hdrSigVerifier.VerifyRandSeedAndLeaderSignature(header) 
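// Illustrative sketch, not part of the PR: the bitmap-to-signer selection that
// getPubKeySigners/isIndexInBitmap rely on above — index i counts as a signer when bit
// i%8 of byte i/8 is set, consistent with the existing `pubKeysBitmap[0]&1` proposer check.
// The helper name and the key values below are made up for the example.
package main

import "fmt"

func selectedSigners(consensusPubKeys []string, bitmap []byte) [][]byte {
	signers := make([][]byte, 0, len(consensusPubKeys))
	for i, pk := range consensusPubKeys {
		byteIdx, bitIdx := i/8, uint(i)%8
		if byteIdx >= len(bitmap) {
			continue // index outside the bitmap, never selected
		}
		if bitmap[byteIdx]&(1<<bitIdx) == 0 {
			continue // bit not set, validator did not sign
		}
		signers = append(signers, []byte(pk))
	}
	return signers
}

func main() {
	keys := []string{"pk0", "pk1", "pk2", "pk3"}
	// 0x0B = 0b00001011 -> indices 0, 1 and 3 are selected
	fmt.Printf("%s\n", selectedSigners(keys, []byte{0x0B})) // [pk0 pk1 pk3]
}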
- require.Equal(t, nodesCoordinator.ErrNilRandomness, err) + require.Equal(t, process.ErrNilRandSeed, err) } func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureVerifyShouldErrWhenValidationFails(t *testing.T) { @@ -230,14 +278,17 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureVerifyShouldErrWhenVa pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyRandSeedAndLeaderSignature(header) require.Equal(t, localErr, err) @@ -269,14 +320,16 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureVerifyLeaderSigShould pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), LeaderSignature: leaderSig, } @@ -305,29 +358,35 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureOk(t *testing.T) { pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyRandSeedAndLeaderSignature(header) require.Nil(t, err) require.Equal(t, 2, count) } -func TestHeaderSigVerifier_VerifyLeaderSignatureNilRandomnessShouldErr(t *testing.T) { +func TestHeaderSigVerifier_VerifyLeaderSignatureNilPrevRandomnessShouldErr(t *testing.T) { t.Parallel() args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + 
RandSeed: []byte("rand seed "), + PrevRandSeed: nil, + } err := hdrSigVerifier.VerifyLeaderSignature(header) - require.Equal(t, nodesCoordinator.ErrNilRandomness, err) + require.Equal(t, process.ErrNilPrevRandSeed, err) } func TestHeaderSigVerifier_VerifyLeaderSignatureVerifyShouldErrWhenValidationFails(t *testing.T) { @@ -351,14 +410,17 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureVerifyShouldErrWhenValidationFai pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyLeaderSignature(header) require.Equal(t, localErr, err) @@ -390,14 +452,16 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureVerifyLeaderSigShouldErr(t *test pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), LeaderSignature: leaderSig, } @@ -426,14 +490,17 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureOk(t *testing.T) { pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyLeaderSignature(header) require.Nil(t, err) @@ -445,7 +512,11 @@ func TestHeaderSigVerifier_VerifySignatureNilBitmapShouldErr(t *testing.T) { args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + PubKeysBitmap: nil, + RandSeed: []byte("randSeed"), + 
PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifySignature(header) require.Equal(t, process.ErrNilPubKeysBitmap, err) @@ -458,6 +529,8 @@ func TestHeaderSigVerifier_VerifySignatureBlockProposerSigMissingShouldErr(t *te hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ PubKeysBitmap: []byte("0"), + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) @@ -470,11 +543,12 @@ func TestHeaderSigVerifier_VerifySignatureNilRandomnessShouldErr(t *testing.T) { args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ + PrevRandSeed: nil, PubKeysBitmap: []byte("1"), } err := hdrSigVerifier.VerifySignature(header) - require.Equal(t, nodesCoordinator.ErrNilRandomness, err) + require.Equal(t, process.ErrNilPrevRandSeed, err) } func TestHeaderSigVerifier_VerifySignatureWrongSizeBitmapShouldErr(t *testing.T) { @@ -483,9 +557,9 @@ func TestHeaderSigVerifier_VerifySignatureWrongSizeBitmapShouldErr(t *testing.T) args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc @@ -493,6 +567,8 @@ func TestHeaderSigVerifier_VerifySignatureWrongSizeBitmapShouldErr(t *testing.T) hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ PubKeysBitmap: []byte("11"), + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) @@ -505,9 +581,9 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErr(t *testing.T) { args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v, v, v, v, v}, nil + return v, []nodesCoordinator.Validator{v, v, v, v, v}, nil }, } args.NodesCoordinator = nc @@ -515,6 +591,8 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErr(t *testing.T) { hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ PubKeysBitmap: []byte("A"), + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) @@ -528,9 +606,9 @@ func TestHeaderSigVerifier_VerifySignatureOk(t *testing.T) { args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators 
[]nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc @@ -544,6 +622,7 @@ func TestHeaderSigVerifier_VerifySignatureOk(t *testing.T) { hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ PubKeysBitmap: []byte("1"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) @@ -558,9 +637,9 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErrWhenFallbackThre args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v, v, v, v, v}, nil + return v, []nodesCoordinator.Validator{v, v, v, v, v}, nil }, } fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{ @@ -582,6 +661,7 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErrWhenFallbackThre hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.MetaBlock{ PubKeysBitmap: []byte("C"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) @@ -596,9 +676,9 @@ func TestHeaderSigVerifier_VerifySignatureOkWhenFallbackThresholdCouldBeApplied( args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v, v, v, v, v}, nil + return v, []nodesCoordinator.Validator{v, v, v, v, v}, nil }, } fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{ @@ -618,10 +698,110 @@ func TestHeaderSigVerifier_VerifySignatureOkWhenFallbackThresholdCouldBeApplied( hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.MetaBlock{ - PubKeysBitmap: []byte("C"), + PubKeysBitmap: []byte{15}, + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) require.Nil(t, err) require.True(t, wasCalled) } + +func getFilledHeader() data.HeaderHandler { + return &dataBlock.Header{ + PrevHash: []byte("prev hash"), + PrevRandSeed: []byte("prev rand seed"), + RandSeed: []byte("rand seed"), + PubKeysBitmap: []byte{0xFF}, + LeaderSignature: []byte("leader signature"), + } +} + +func TestHeaderSigVerifier_VerifyHeaderProof(t *testing.T) { + t.Parallel() + + t.Run("nil proof should error", func(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + 
args.EnableEpochsHandler = enableEpochsHandlerMock.NewEnableEpochsHandlerStub(common.FixedOrderInConsensusFlag) + hdrSigVerifier, err := NewHeaderSigVerifier(args) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(nil) + require.Equal(t, process.ErrNilHeaderProof, err) + }) + t.Run("flag not active should error", func(t *testing.T) { + t.Parallel() + + hdrSigVerifier, err := NewHeaderSigVerifier(createHeaderSigVerifierArgs()) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{ + PubKeysBitmap: []byte{3}, + }) + require.True(t, errors.Is(err, process.ErrFlagNotActive)) + require.True(t, strings.Contains(err.Error(), string(common.EquivalentMessagesFlag))) + }) + t.Run("GetMultiSigner error should error", func(t *testing.T) { + t.Parallel() + + cnt := 0 + args := createHeaderSigVerifierArgs() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + args.MultiSigContainer = &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + cnt++ + if cnt > 1 { + return nil, expectedErr + } + return &cryptoMocks.MultiSignerStub{}, nil + }, + } + hdrSigVerifier, err := NewHeaderSigVerifier(args) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{}) + require.Equal(t, expectedErr, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + headerHash := []byte("header hash") + wasVerifyAggregatedSigCalled := false + args := createHeaderSigVerifierArgs() + args.HeadersPool = &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return getFilledHeader(), nil + }, + } + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.FixedOrderInConsensusFlag || flag == common.EquivalentMessagesFlag + }, + } + args.MultiSigContainer = &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return &cryptoMocks.MultiSignerStub{ + VerifyAggregatedSigCalled: func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + wasVerifyAggregatedSigCalled = true + return nil + }, + }, nil + }, + } + hdrSigVerifier, err := NewHeaderSigVerifier(args) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{ + PubKeysBitmap: []byte{0x3}, + AggregatedSignature: make([]byte, 10), + HeaderHash: headerHash, + }) + require.NoError(t, err) + require.True(t, wasVerifyAggregatedSigCalled) + }) +} diff --git a/process/interceptors/baseDataInterceptor.go b/process/interceptors/baseDataInterceptor.go index 64efb852238..cec00abd756 100644 --- a/process/interceptors/baseDataInterceptor.go +++ b/process/interceptors/baseDataInterceptor.go @@ -6,19 +6,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" ) type baseDataInterceptor struct { - throttler process.InterceptorThrottler - antifloodHandler process.P2PAntifloodHandler - topic string - currentPeerId core.PeerID - processor process.InterceptorProcessor - mutDebugHandler sync.RWMutex - debugHandler process.InterceptedDebugger - preferredPeersHolder 
process.PreferredPeersHolderHandler + throttler process.InterceptorThrottler + antifloodHandler process.P2PAntifloodHandler + topic string + currentPeerId core.PeerID + processor process.InterceptorProcessor + mutDebugHandler sync.RWMutex + debugHandler process.InterceptedDebugger + preferredPeersHolder process.PreferredPeersHolderHandler + interceptedDataVerifier process.InterceptedDataVerifier } func (bdi *baseDataInterceptor) preProcessMesage(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { diff --git a/process/interceptors/factory/argInterceptedDataFactory.go b/process/interceptors/factory/argInterceptedDataFactory.go index 37701a92f7a..4cc4cb846a8 100644 --- a/process/interceptors/factory/argInterceptedDataFactory.go +++ b/process/interceptors/factory/argInterceptedDataFactory.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" diff --git a/process/interceptors/factory/interceptedDataVerifierFactory.go b/process/interceptors/factory/interceptedDataVerifierFactory.go new file mode 100644 index 00000000000..2775bbdc61a --- /dev/null +++ b/process/interceptors/factory/interceptedDataVerifierFactory.go @@ -0,0 +1,72 @@ +package factory + +import ( + "fmt" + "sync" + "time" + + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/process/interceptors" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/storage/cache" +) + +// InterceptedDataVerifierFactoryArgs holds the required arguments for interceptedDataVerifierFactory +type InterceptedDataVerifierFactoryArgs struct { + CacheSpan time.Duration + CacheExpiry time.Duration +} + +// interceptedDataVerifierFactory encapsulates the required arguments to create InterceptedDataVerifier +// Furthermore it will hold all such instances in an internal map. +type interceptedDataVerifierFactory struct { + cacheSpan time.Duration + cacheExpiry time.Duration + + interceptedDataVerifierMap map[string]storage.Cacher + mutex sync.Mutex +} + +// NewInterceptedDataVerifierFactory will create a factory instance that will create instance of InterceptedDataVerifiers +func NewInterceptedDataVerifierFactory(args InterceptedDataVerifierFactoryArgs) *interceptedDataVerifierFactory { + return &interceptedDataVerifierFactory{ + cacheSpan: args.CacheSpan, + cacheExpiry: args.CacheExpiry, + interceptedDataVerifierMap: make(map[string]storage.Cacher), + mutex: sync.Mutex{}, + } +} + +// Create will return an instance of InterceptedDataVerifier +func (idvf *interceptedDataVerifierFactory) Create(topic string) (process.InterceptedDataVerifier, error) { + internalCache, err := cache.NewTimeCacher(cache.ArgTimeCacher{ + DefaultSpan: idvf.cacheSpan, + CacheExpiry: idvf.cacheExpiry, + }) + if err != nil { + return nil, err + } + + idvf.mutex.Lock() + idvf.interceptedDataVerifierMap[topic] = internalCache + idvf.mutex.Unlock() + + return interceptors.NewInterceptedDataVerifier(internalCache) +} + +// Close will close all the sweeping routines created by the cache. 
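A minimal usage sketch (editorial, not part of this patch) of the factory above: a single factory serves all interceptors and hands out one time-cache-backed verifier per topic. The topic string and durations below are illustrative only.

package main

import (
	"time"

	"github.com/multiversx/mx-chain-go/process/interceptors/factory"
)

func main() {
	// One factory is enough: it keeps a per-topic map of time caches internally.
	verifierFactory := factory.NewInterceptedDataVerifierFactory(factory.InterceptedDataVerifierFactoryArgs{
		CacheSpan:   30 * time.Second, // span for which a cached verdict is reused
		CacheExpiry: 30 * time.Second,
	})
	defer func() {
		_ = verifierFactory.Close() // closes every per-topic time cache created so far
	}()

	// Each interceptor asks for its own verifier, keyed by topic.
	verifier, err := verifierFactory.Create("transactions_0")
	if err != nil {
		panic(err)
	}
	_ = verifier // to be wired into ArgSingleDataInterceptor / ArgMultiDataInterceptor as InterceptedDataVerifier
}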
+func (idvf *interceptedDataVerifierFactory) Close() error { + for topic, cacher := range idvf.interceptedDataVerifierMap { + err := cacher.Close() + if err != nil { + return fmt.Errorf("failed to close cacher on topic %q: %w", topic, err) + } + } + + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (idvf *interceptedDataVerifierFactory) IsInterfaceNil() bool { + return idvf == nil +} diff --git a/process/interceptors/factory/interceptedDataVerifierFactory_test.go b/process/interceptors/factory/interceptedDataVerifierFactory_test.go new file mode 100644 index 00000000000..45f42ec05fd --- /dev/null +++ b/process/interceptors/factory/interceptedDataVerifierFactory_test.go @@ -0,0 +1,44 @@ +package factory + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func createMockArgInterceptedDataVerifierFactory() InterceptedDataVerifierFactoryArgs { + return InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second, + CacheExpiry: time.Second, + } +} + +func TestInterceptedDataVerifierFactory_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var factory *interceptedDataVerifierFactory + require.True(t, factory.IsInterfaceNil()) + + factory = NewInterceptedDataVerifierFactory(createMockArgInterceptedDataVerifierFactory()) + require.False(t, factory.IsInterfaceNil()) +} + +func TestNewInterceptedDataVerifierFactory(t *testing.T) { + t.Parallel() + + factory := NewInterceptedDataVerifierFactory(createMockArgInterceptedDataVerifierFactory()) + require.NotNil(t, factory) +} + +func TestInterceptedDataVerifierFactory_Create(t *testing.T) { + t.Parallel() + + factory := NewInterceptedDataVerifierFactory(createMockArgInterceptedDataVerifierFactory()) + require.NotNil(t, factory) + + interceptedDataVerifier, err := factory.Create("mockTopic") + require.NoError(t, err) + + require.False(t, interceptedDataVerifier.IsInterfaceNil()) +} diff --git a/process/interceptors/factory/interceptedEquivalentProofsFactory.go b/process/interceptors/factory/interceptedEquivalentProofsFactory.go new file mode 100644 index 00000000000..4c5694d1e4d --- /dev/null +++ b/process/interceptors/factory/interceptedEquivalentProofsFactory.go @@ -0,0 +1,44 @@ +package factory + +import ( + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" + "github.com/multiversx/mx-chain-go/sharding" +) + +type interceptedEquivalentProofsFactory struct { + marshaller marshal.Marshalizer + shardCoordinator sharding.Coordinator + headerSigVerifier consensus.HeaderSigVerifier + proofsPool dataRetriever.ProofsPool +} + +// NewInterceptedEquivalentProofsFactory creates a new instance of interceptedEquivalentProofsFactory +func NewInterceptedEquivalentProofsFactory(args ArgInterceptedDataFactory, proofsPool dataRetriever.ProofsPool) *interceptedEquivalentProofsFactory { + return &interceptedEquivalentProofsFactory{ + marshaller: args.CoreComponents.InternalMarshalizer(), + shardCoordinator: args.ShardCoordinator, + headerSigVerifier: args.HeaderSigVerifier, + proofsPool: proofsPool, + } +} + +// Create creates instances of InterceptedData by unmarshalling provided buffer +func (factory *interceptedEquivalentProofsFactory) Create(buff []byte) (process.InterceptedData, error) { + args := interceptedBlocks.ArgInterceptedEquivalentProof{ + DataBuff: buff, + Marshaller: 
factory.marshaller, + ShardCoordinator: factory.shardCoordinator, + HeaderSigVerifier: factory.headerSigVerifier, + Proofs: factory.proofsPool, + } + return interceptedBlocks.NewInterceptedEquivalentProof(args) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (factory *interceptedEquivalentProofsFactory) IsInterfaceNil() bool { + return factory == nil +} diff --git a/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go b/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go new file mode 100644 index 00000000000..c96ade9528b --- /dev/null +++ b/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go @@ -0,0 +1,76 @@ +package factory + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/consensus/mock" + processMock "github.com/multiversx/mx-chain-go/process/mock" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/stretchr/testify/require" +) + +func createMockArgInterceptedDataFactory() ArgInterceptedDataFactory { + return ArgInterceptedDataFactory{ + CoreComponents: &processMock.CoreComponentsMock{ + IntMarsh: &mock.MarshalizerMock{}, + }, + ShardCoordinator: &mock.ShardCoordinatorMock{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + } +} + +func TestInterceptedEquivalentProofsFactory_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var factory *interceptedEquivalentProofsFactory + require.True(t, factory.IsInterfaceNil()) + + factory = NewInterceptedEquivalentProofsFactory(createMockArgInterceptedDataFactory(), &dataRetriever.ProofsPoolMock{}) + require.False(t, factory.IsInterfaceNil()) +} + +func TestNewInterceptedEquivalentProofsFactory(t *testing.T) { + t.Parallel() + + factory := NewInterceptedEquivalentProofsFactory(createMockArgInterceptedDataFactory(), &dataRetriever.ProofsPoolMock{}) + require.NotNil(t, factory) +} + +func TestInterceptedEquivalentProofsFactory_Create(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedDataFactory() + factory := NewInterceptedEquivalentProofsFactory(args, &dataRetriever.ProofsPoolMock{}) + require.NotNil(t, factory) + + providedProof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderEpoch: 123, + HeaderNonce: 345, + HeaderShardId: 0, + } + providedDataBuff, _ := args.CoreComponents.InternalMarshalizer().Marshal(providedProof) + interceptedData, err := factory.Create(providedDataBuff) + require.NoError(t, err) + require.NotNil(t, interceptedData) + + type interceptedEquivalentProof interface { + GetProof() data.HeaderProofHandler + } + interceptedHeaderProof, ok := interceptedData.(interceptedEquivalentProof) + require.True(t, ok) + + proof := interceptedHeaderProof.GetProof() + require.NotNil(t, proof) + require.Equal(t, providedProof.GetPubKeysBitmap(), proof.GetPubKeysBitmap()) + require.Equal(t, providedProof.GetAggregatedSignature(), proof.GetAggregatedSignature()) + require.Equal(t, providedProof.GetHeaderHash(), proof.GetHeaderHash()) + require.Equal(t, providedProof.GetHeaderEpoch(), proof.GetHeaderEpoch()) + require.Equal(t, providedProof.GetHeaderNonce(), proof.GetHeaderNonce()) + require.Equal(t, providedProof.GetHeaderShardId(), proof.GetHeaderShardId()) +} diff --git a/process/interceptors/factory/interceptedMetaHeaderDataFactory.go 
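An editorial sketch (not part of this patch) of the payload handled by the factory and processor above: a block.HeaderProof is marshalled, travels on the equivalentProofs topic, and is rebuilt by interceptedEquivalentProofsFactory.Create before Save forwards it to the proofs pool. The field values are illustrative, and GogoProtoMarshalizer from mx-chain-core-go is assumed here as a stand-in for the node's internal marshalizer.

package main

import (
	"fmt"

	"github.com/multiversx/mx-chain-core-go/data/block"
	"github.com/multiversx/mx-chain-core-go/marshal"
)

func main() {
	marshaller := &marshal.GogoProtoMarshalizer{}

	// This is the kind of payload that circulates on the equivalentProofs topic.
	proof := &block.HeaderProof{
		PubKeysBitmap:       []byte{0x3},
		AggregatedSignature: []byte("aggregated signature"),
		HeaderHash:          []byte("header hash"),
		HeaderEpoch:         4,
		HeaderNonce:         345,
		HeaderShardId:       0,
	}

	buff, err := marshaller.Marshal(proof)
	if err != nil {
		panic(err)
	}

	// interceptedEquivalentProofsFactory.Create receives such a buffer and rebuilds the proof;
	// equivalentProofsInterceptorProcessor.Save then pushes it into the proofs pool via AddProof.
	recovered := &block.HeaderProof{}
	if err = marshaller.Unmarshal(recovered, buff); err != nil {
		panic(err)
	}
	fmt.Printf("nonce %d, epoch %d, shard %d\n", recovered.HeaderNonce, recovered.HeaderEpoch, recovered.HeaderShardId)
}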
b/process/interceptors/factory/interceptedMetaHeaderDataFactory.go index 7567727571d..54c66ad687c 100644 --- a/process/interceptors/factory/interceptedMetaHeaderDataFactory.go +++ b/process/interceptors/factory/interceptedMetaHeaderDataFactory.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/sharding" @@ -19,6 +20,7 @@ type interceptedMetaHeaderDataFactory struct { headerIntegrityVerifier process.HeaderIntegrityVerifier validityAttester process.ValidityAttester epochStartTrigger process.EpochStartTriggerHandler + enableEpochsHandler common.EnableEpochsHandler } // NewInterceptedMetaHeaderDataFactory creates an instance of interceptedMetaHeaderDataFactory @@ -65,6 +67,7 @@ func NewInterceptedMetaHeaderDataFactory(argument *ArgInterceptedDataFactory) (* headerIntegrityVerifier: argument.HeaderIntegrityVerifier, validityAttester: argument.ValidityAttester, epochStartTrigger: argument.EpochStartTrigger, + enableEpochsHandler: argument.CoreComponents.EnableEpochsHandler(), }, nil } @@ -79,6 +82,7 @@ func (imhdf *interceptedMetaHeaderDataFactory) Create(buff []byte) (process.Inte HeaderIntegrityVerifier: imhdf.headerIntegrityVerifier, ValidityAttester: imhdf.validityAttester, EpochStartTrigger: imhdf.epochStartTrigger, + EnableEpochsHandler: imhdf.enableEpochsHandler, } return interceptedBlocks.NewInterceptedMetaHeader(arg) diff --git a/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go b/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go index 0912de698c1..03859b63cb9 100644 --- a/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go +++ b/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" processMocks "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -97,7 +98,7 @@ func createMockArgument( NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), FeeHandler: createMockFeeHandler(), WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, ValidityAttester: &mock.ValidityAttesterStub{}, HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, EpochStartTrigger: &mock.EpochStartTriggerStub{}, diff --git a/process/interceptors/factory/interceptedShardHeaderDataFactory.go b/process/interceptors/factory/interceptedShardHeaderDataFactory.go index fd19194dbd0..1a8b7518e63 100644 --- a/process/interceptors/factory/interceptedShardHeaderDataFactory.go +++ b/process/interceptors/factory/interceptedShardHeaderDataFactory.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" 
"github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/sharding" @@ -19,6 +20,7 @@ type interceptedShardHeaderDataFactory struct { headerIntegrityVerifier process.HeaderIntegrityVerifier validityAttester process.ValidityAttester epochStartTrigger process.EpochStartTriggerHandler + enableEpochsHandler common.EnableEpochsHandler } // NewInterceptedShardHeaderDataFactory creates an instance of interceptedShardHeaderDataFactory @@ -65,6 +67,7 @@ func NewInterceptedShardHeaderDataFactory(argument *ArgInterceptedDataFactory) ( headerIntegrityVerifier: argument.HeaderIntegrityVerifier, validityAttester: argument.ValidityAttester, epochStartTrigger: argument.EpochStartTrigger, + enableEpochsHandler: argument.CoreComponents.EnableEpochsHandler(), }, nil } @@ -79,6 +82,7 @@ func (ishdf *interceptedShardHeaderDataFactory) Create(buff []byte) (process.Int HeaderIntegrityVerifier: ishdf.headerIntegrityVerifier, ValidityAttester: ishdf.validityAttester, EpochStartTrigger: ishdf.epochStartTrigger, + EnableEpochsHandler: ishdf.enableEpochsHandler, } return interceptedBlocks.NewInterceptedHeader(arg) diff --git a/process/interceptors/interceptedDataVerifier.go b/process/interceptors/interceptedDataVerifier.go new file mode 100644 index 00000000000..0accf41d3fc --- /dev/null +++ b/process/interceptors/interceptedDataVerifier.go @@ -0,0 +1,70 @@ +package interceptors + +import ( + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/sync" + + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/storage" +) + +type interceptedDataStatus int8 + +const ( + validInterceptedData interceptedDataStatus = iota + invalidInterceptedData + + interceptedDataStatusBytesSize = 8 +) + +type interceptedDataVerifier struct { + km sync.KeyRWMutexHandler + cache storage.Cacher +} + +// NewInterceptedDataVerifier creates a new instance of intercepted data verifier +func NewInterceptedDataVerifier(cache storage.Cacher) (*interceptedDataVerifier, error) { + if check.IfNil(cache) { + return nil, process.ErrNilInterceptedDataCache + } + + return &interceptedDataVerifier{ + km: sync.NewKeyRWMutex(), + cache: cache, + }, nil +} + +// Verify will check if the intercepted data has been validated before and put in the time cache. +// It will retrieve the status in the cache if it exists, otherwise it will validate it and store the status of the +// validation in the cache. 
Note that the entries are stored for a set period of time +func (idv *interceptedDataVerifier) Verify(interceptedData process.InterceptedData) error { + if len(interceptedData.Hash()) == 0 { + return interceptedData.CheckValidity() + } + + hash := string(interceptedData.Hash()) + idv.km.Lock(hash) + defer idv.km.Unlock(hash) + + if val, ok := idv.cache.Get(interceptedData.Hash()); ok { + if val == validInterceptedData { + return nil + } + + return process.ErrInvalidInterceptedData + } + + err := interceptedData.CheckValidity() + if err != nil { + idv.cache.Put(interceptedData.Hash(), invalidInterceptedData, interceptedDataStatusBytesSize) + return process.ErrInvalidInterceptedData + } + + idv.cache.Put(interceptedData.Hash(), validInterceptedData, interceptedDataStatusBytesSize) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (idv *interceptedDataVerifier) IsInterfaceNil() bool { + return idv == nil +} diff --git a/process/interceptors/interceptedDataVerifier_test.go b/process/interceptors/interceptedDataVerifier_test.go new file mode 100644 index 00000000000..8913f5828d8 --- /dev/null +++ b/process/interceptors/interceptedDataVerifier_test.go @@ -0,0 +1,237 @@ +package interceptors + +import ( + "sync" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core/atomic" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/storage/cache" + "github.com/multiversx/mx-chain-go/testscommon" +) + +const defaultSpan = 1 * time.Second + +func defaultInterceptedDataVerifier(span time.Duration) *interceptedDataVerifier { + c, _ := cache.NewTimeCacher(cache.ArgTimeCacher{ + DefaultSpan: span, + CacheExpiry: span, + }) + + verifier, _ := NewInterceptedDataVerifier(c) + return verifier +} + +func TestNewInterceptedDataVerifier(t *testing.T) { + t.Parallel() + + var c storage.Cacher + verifier, err := NewInterceptedDataVerifier(c) + require.Equal(t, process.ErrNilInterceptedDataCache, err) + require.Nil(t, verifier) +} + +func TestInterceptedDataVerifier_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var verifier *interceptedDataVerifier + require.True(t, verifier.IsInterfaceNil()) + + verifier = defaultInterceptedDataVerifier(defaultSpan) + require.False(t, verifier.IsInterfaceNil()) +} + +func TestInterceptedDataVerifier_EmptyHash(t *testing.T) { + t.Parallel() + + var checkValidityCounter int + verifier := defaultInterceptedDataVerifier(defaultSpan) + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter++ + return nil + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return nil + }, + } + + err := verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, 1, checkValidityCounter) + + err = verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, 2, checkValidityCounter) +} + +func TestInterceptedDataVerifier_CheckValidityShouldWork(t *testing.T) { + t.Parallel() + + checkValidityCounter := atomic.Counter{} + + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter.Add(1) + return nil + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return []byte("hash") + }, + } + + verifier := defaultInterceptedDataVerifier(defaultSpan) + + err := verifier.Verify(interceptedData) + 
require.NoError(t, err) + + errCount := atomic.Counter{} + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := verifier.Verify(interceptedData) + if err != nil { + errCount.Add(1) + } + }() + } + wg.Wait() + + require.Equal(t, int64(0), errCount.Get()) + require.Equal(t, int64(1), checkValidityCounter.Get()) +} + +func TestInterceptedDataVerifier_CheckValidityShouldNotWork(t *testing.T) { + t.Parallel() + + checkValidityCounter := atomic.Counter{} + + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter.Add(1) + return process.ErrInvalidInterceptedData + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return []byte("hash") + }, + } + + verifier := defaultInterceptedDataVerifier(defaultSpan) + + err := verifier.Verify(interceptedData) + require.Equal(t, process.ErrInvalidInterceptedData, err) + + errCount := atomic.Counter{} + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := verifier.Verify(interceptedData) + if err != nil { + errCount.Add(1) + } + }() + } + wg.Wait() + + require.Equal(t, int64(100), errCount.Get()) + require.Equal(t, int64(1), checkValidityCounter.Get()) +} + +func TestInterceptedDataVerifier_CheckExpiryTime(t *testing.T) { + t.Parallel() + + t.Run("expiry on valid data", func(t *testing.T) { + expiryTestDuration := defaultSpan * 2 + + checkValidityCounter := atomic.Counter{} + + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter.Add(1) + return nil + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return []byte("hash") + }, + } + + verifier := defaultInterceptedDataVerifier(expiryTestDuration) + + // First retrieval, check validity is reached. + err := verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, int64(1), checkValidityCounter.Get()) + + // Second retrieval should be from the cache. + err = verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, int64(1), checkValidityCounter.Get()) + + // Wait for the cache expiry + <-time.After(expiryTestDuration + 100*time.Millisecond) + + // Third retrieval should reach validity check again. + err = verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, int64(2), checkValidityCounter.Get()) + }) + + t.Run("expiry on invalid data", func(t *testing.T) { + expiryTestDuration := defaultSpan * 2 + + checkValidityCounter := atomic.Counter{} + + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter.Add(1) + return process.ErrInvalidInterceptedData + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return []byte("hash") + }, + } + + verifier := defaultInterceptedDataVerifier(expiryTestDuration) + + // First retrieval, check validity is reached. + err := verifier.Verify(interceptedData) + require.Equal(t, process.ErrInvalidInterceptedData, err) + require.Equal(t, int64(1), checkValidityCounter.Get()) + + // Second retrieval should be from the cache. 
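// (editorial note, not part of the patch) This second call is the cache-hit path: Verify returns
// the stored invalid verdict without running CheckValidity again, so the counter stays at 1.
// Only after the TimeCacher span elapses (the <-time.After further below) is CheckValidity
// reached a second time.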
+ err = verifier.Verify(interceptedData) + require.Equal(t, process.ErrInvalidInterceptedData, err) + require.Equal(t, int64(1), checkValidityCounter.Get()) + + // Wait for the cache expiry + <-time.After(expiryTestDuration + 100*time.Millisecond) + + // Third retrieval should reach validity check again. + err = verifier.Verify(interceptedData) + require.Equal(t, process.ErrInvalidInterceptedData, err) + require.Equal(t, int64(2), checkValidityCounter.Get()) + }) +} diff --git a/process/interceptors/multiDataInterceptor.go b/process/interceptors/multiDataInterceptor.go index 9e0197ea741..923c9b360e9 100644 --- a/process/interceptors/multiDataInterceptor.go +++ b/process/interceptors/multiDataInterceptor.go @@ -7,27 +7,30 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/pkg/errors" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/debug/handler" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors/disabled" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("process/interceptors") // ArgMultiDataInterceptor is the argument for the multi-data interceptor type ArgMultiDataInterceptor struct { - Topic string - Marshalizer marshal.Marshalizer - DataFactory process.InterceptedDataFactory - Processor process.InterceptorProcessor - Throttler process.InterceptorThrottler - AntifloodHandler process.P2PAntifloodHandler - WhiteListRequest process.WhiteListHandler - PreferredPeersHolder process.PreferredPeersHolderHandler - CurrentPeerId core.PeerID + Topic string + Marshalizer marshal.Marshalizer + DataFactory process.InterceptedDataFactory + Processor process.InterceptorProcessor + Throttler process.InterceptorThrottler + AntifloodHandler process.P2PAntifloodHandler + WhiteListRequest process.WhiteListHandler + PreferredPeersHolder process.PreferredPeersHolderHandler + CurrentPeerId core.PeerID + InterceptedDataVerifier process.InterceptedDataVerifier } // MultiDataInterceptor is used for intercepting packed multi data @@ -66,19 +69,23 @@ func NewMultiDataInterceptor(arg ArgMultiDataInterceptor) (*MultiDataInterceptor if check.IfNil(arg.PreferredPeersHolder) { return nil, process.ErrNilPreferredPeersHolder } + if check.IfNil(arg.InterceptedDataVerifier) { + return nil, process.ErrNilInterceptedDataVerifier + } if len(arg.CurrentPeerId) == 0 { return nil, process.ErrEmptyPeerID } multiDataIntercept := &MultiDataInterceptor{ baseDataInterceptor: &baseDataInterceptor{ - throttler: arg.Throttler, - antifloodHandler: arg.AntifloodHandler, - topic: arg.Topic, - currentPeerId: arg.CurrentPeerId, - processor: arg.Processor, - preferredPeersHolder: arg.PreferredPeersHolder, - debugHandler: handler.NewDisabledInterceptorDebugHandler(), + throttler: arg.Throttler, + antifloodHandler: arg.AntifloodHandler, + topic: arg.Topic, + currentPeerId: arg.CurrentPeerId, + processor: arg.Processor, + preferredPeersHolder: arg.PreferredPeersHolder, + debugHandler: handler.NewDisabledInterceptorDebugHandler(), + interceptedDataVerifier: arg.InterceptedDataVerifier, }, marshalizer: arg.Marshalizer, factory: arg.DataFactory, @@ -153,6 +160,7 @@ func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, var interceptedData process.InterceptedData interceptedData, err = 
mdi.interceptedData(dataBuff, message.Peer(), fromConnectedPeer) listInterceptedData[index] = interceptedData + if err != nil { mdi.throttler.EndProcessing() return err @@ -207,11 +215,11 @@ func (mdi *MultiDataInterceptor) interceptedData(dataBuff []byte, originator cor mdi.receivedDebugInterceptedData(interceptedData) - err = interceptedData.CheckValidity() + err = mdi.interceptedDataVerifier.Verify(interceptedData) if err != nil { mdi.processDebugInterceptedData(interceptedData, err) - isWrongVersion := err == process.ErrInvalidTransactionVersion || err == process.ErrInvalidChainID + isWrongVersion := errors.Is(err, process.ErrInvalidTransactionVersion) || errors.Is(err, process.ErrInvalidChainID) if isWrongVersion { // this situation is so severe that we need to black list de peers reason := "wrong version of received intercepted data, topic " + mdi.topic + ", error " + err.Error() diff --git a/process/interceptors/multiDataInterceptor_test.go b/process/interceptors/multiDataInterceptor_test.go index 6ca244409b7..ede867dba07 100644 --- a/process/interceptors/multiDataInterceptor_test.go +++ b/process/interceptors/multiDataInterceptor_test.go @@ -10,28 +10,30 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var fromConnectedPeerId = core.PeerID("from connected peer Id") func createMockArgMultiDataInterceptor() interceptors.ArgMultiDataInterceptor { return interceptors.ArgMultiDataInterceptor{ - Topic: "test topic", - Marshalizer: &mock.MarshalizerMock{}, - DataFactory: &mock.InterceptedDataFactoryStub{}, - Processor: &mock.InterceptorProcessorStub{}, - Throttler: createMockThrottler(), - AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, - WhiteListRequest: &testscommon.WhiteListHandlerStub{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - CurrentPeerId: "pid", + Topic: "test topic", + Marshalizer: &mock.MarshalizerMock{}, + DataFactory: &mock.InterceptedDataFactoryStub{}, + Processor: &mock.InterceptorProcessorStub{}, + Throttler: createMockThrottler(), + AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, + WhiteListRequest: &testscommon.WhiteListHandlerStub{}, + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + CurrentPeerId: "pid", + InterceptedDataVerifier: &mock.InterceptedDataVerifierMock{}, } } @@ -68,6 +70,17 @@ func TestNewMultiDataInterceptor_NilInterceptedDataFactoryShouldErr(t *testing.T assert.Equal(t, process.ErrNilInterceptedDataFactory, err) } +func TestNewMultiDataInterceptor_NilInterceptedDataVerifierShouldErr(t *testing.T) { + t.Parallel() + + arg := createMockArgMultiDataInterceptor() + arg.InterceptedDataVerifier = nil + mdi, err := interceptors.NewMultiDataInterceptor(arg) + + assert.True(t, check.IfNil(mdi)) + assert.Equal(t, process.ErrNilInterceptedDataVerifier, err) +} + func TestNewMultiDataInterceptor_NilInterceptedDataProcessorShouldErr(t *testing.T) { t.Parallel() @@ -282,6 +295,7 @@ func TestMultiDataInterceptor_ProcessReceivedPartiallyCorrectDataShouldErr(t *te IsForCurrentShardCalled: func() bool { return true }, + 
HashCalled: func() []byte { return []byte("hash") }, } arg := createMockArgMultiDataInterceptor() arg.DataFactory = &mock.InterceptedDataFactoryStub{ @@ -354,6 +368,11 @@ func testProcessReceiveMessageMultiData(t *testing.T, isForCurrentShard bool, ex } arg.Processor = createMockInterceptorStub(&checkCalledNum, &processCalledNum) arg.Throttler = throttler + arg.InterceptedDataVerifier = &mock.InterceptedDataVerifierMock{ + VerifyCalled: func(interceptedData process.InterceptedData) error { + return interceptedData.CheckValidity() + }, + } mdi, _ := interceptors.NewMultiDataInterceptor(arg) dataField, _ := marshalizer.Marshal(&batch.Batch{Data: buffData}) @@ -570,6 +589,9 @@ func processReceivedMessageMultiDataInvalidVersion(t *testing.T, expectedErr err checkCalledNum := int32(0) processCalledNum := int32(0) interceptedData := &testscommon.InterceptedDataStub{ + HashCalled: func() []byte { + return []byte("hash") + }, CheckValidityCalled: func() error { return expectedErr }, @@ -603,6 +625,11 @@ func processReceivedMessageMultiDataInvalidVersion(t *testing.T, expectedErr err return true }, } + arg.InterceptedDataVerifier = &mock.InterceptedDataVerifierMock{ + VerifyCalled: func(interceptedData process.InterceptedData) error { + return interceptedData.CheckValidity() + }, + } mdi, _ := interceptors.NewMultiDataInterceptor(arg) dataField, _ := marshalizer.Marshal(&batch.Batch{Data: buffData}) @@ -658,6 +685,9 @@ func TestMultiDataInterceptor_ProcessReceivedMessageIsOriginatorNotOkButWhiteLis IsForCurrentShardCalled: func() bool { return false }, + HashCalled: func() []byte { + return []byte("hash") + }, } whiteListHandler := &testscommon.WhiteListHandlerStub{ diff --git a/process/interceptors/processor/argHdrInterceptorProcessor.go b/process/interceptors/processor/argHdrInterceptorProcessor.go index 53e79b731b7..0f9616fb2cf 100644 --- a/process/interceptors/processor/argHdrInterceptorProcessor.go +++ b/process/interceptors/processor/argHdrInterceptorProcessor.go @@ -1,12 +1,15 @@ package processor import ( + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" ) // ArgHdrInterceptorProcessor is the argument for the interceptor processor used for headers (shard, meta and so on) type ArgHdrInterceptorProcessor struct { - Headers dataRetriever.HeadersPool - BlockBlackList process.TimeCacher + Headers dataRetriever.HeadersPool + Proofs dataRetriever.ProofsPool + BlockBlackList process.TimeCacher + EnableEpochsHandler common.EnableEpochsHandler } diff --git a/process/interceptors/processor/equivalentProofsInterceptorProcessor.go b/process/interceptors/processor/equivalentProofsInterceptorProcessor.go new file mode 100644 index 00000000000..ef8beff12af --- /dev/null +++ b/process/interceptors/processor/equivalentProofsInterceptorProcessor.go @@ -0,0 +1,70 @@ +package processor + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/process" +) + +// ArgEquivalentProofsInterceptorProcessor is the argument for the interceptor processor used for equivalent proofs +type ArgEquivalentProofsInterceptorProcessor struct { + EquivalentProofsPool EquivalentProofsPool + Marshaller marshal.Marshalizer +} + +// equivalentProofsInterceptorProcessor is the processor used when intercepting equivalent proofs +type equivalentProofsInterceptorProcessor struct { + equivalentProofsPool 
EquivalentProofsPool + marshaller marshal.Marshalizer +} + +// NewEquivalentProofsInterceptorProcessor creates a new equivalentProofsInterceptorProcessor +func NewEquivalentProofsInterceptorProcessor(args ArgEquivalentProofsInterceptorProcessor) (*equivalentProofsInterceptorProcessor, error) { + err := checkArgsEquivalentProofs(args) + if err != nil { + return nil, err + } + + return &equivalentProofsInterceptorProcessor{ + equivalentProofsPool: args.EquivalentProofsPool, + marshaller: args.Marshaller, + }, nil +} + +func checkArgsEquivalentProofs(args ArgEquivalentProofsInterceptorProcessor) error { + if check.IfNil(args.EquivalentProofsPool) { + return process.ErrNilProofsPool + } + if check.IfNil(args.Marshaller) { + return process.ErrNilMarshalizer + } + + return nil +} + +// Validate checks if the intercepted data can be processed +// returns nil as proper validity checks are done at intercepted data level +func (epip *equivalentProofsInterceptorProcessor) Validate(_ process.InterceptedData, _ core.PeerID) error { + return nil +} + +// Save will save the intercepted equivalent proof inside the proofs tracker +func (epip *equivalentProofsInterceptorProcessor) Save(data process.InterceptedData, _ core.PeerID, _ string) error { + interceptedProof, ok := data.(interceptedEquivalentProof) + if !ok { + return process.ErrWrongTypeAssertion + } + + return epip.equivalentProofsPool.AddProof(interceptedProof.GetProof()) +} + +// RegisterHandler registers a callback function to be notified of incoming equivalent proofs +func (epip *equivalentProofsInterceptorProcessor) RegisterHandler(_ func(topic string, hash []byte, data interface{})) { + log.Error("equivalentProofsInterceptorProcessor.RegisterHandler", "error", "not implemented") +} + +// IsInterfaceNil returns true if there is no value under the interface +func (epip *equivalentProofsInterceptorProcessor) IsInterfaceNil() bool { + return epip == nil +} diff --git a/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go b/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go new file mode 100644 index 00000000000..b11eca03aec --- /dev/null +++ b/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go @@ -0,0 +1,133 @@ +package processor + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" + "github.com/multiversx/mx-chain-go/process/transaction" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/stretchr/testify/require" +) + +func createMockArgEquivalentProofsInterceptorProcessor() ArgEquivalentProofsInterceptorProcessor { + return ArgEquivalentProofsInterceptorProcessor{ + EquivalentProofsPool: &dataRetriever.ProofsPoolMock{}, + Marshaller: &marshallerMock.MarshalizerMock{}, + } +} + +func TestEquivalentProofsInterceptorProcessor_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var epip *equivalentProofsInterceptorProcessor + require.True(t, epip.IsInterfaceNil()) + + epip, _ = NewEquivalentProofsInterceptorProcessor(createMockArgEquivalentProofsInterceptorProcessor()) + require.False(t, epip.IsInterfaceNil()) +} + +func TestNewEquivalentProofsInterceptorProcessor(t *testing.T) { + 
t.Parallel() + + t.Run("nil EquivalentProofsPool should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsInterceptorProcessor() + args.EquivalentProofsPool = nil + + epip, err := NewEquivalentProofsInterceptorProcessor(args) + require.Equal(t, process.ErrNilProofsPool, err) + require.Nil(t, epip) + }) + t.Run("nil Marshaller should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsInterceptorProcessor() + args.Marshaller = nil + + epip, err := NewEquivalentProofsInterceptorProcessor(args) + require.Equal(t, process.ErrNilMarshalizer, err) + require.Nil(t, epip) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + epip, err := NewEquivalentProofsInterceptorProcessor(createMockArgEquivalentProofsInterceptorProcessor()) + require.NoError(t, err) + require.NotNil(t, epip) + }) +} + +func TestEquivalentProofsInterceptorProcessor_Validate(t *testing.T) { + t.Parallel() + + epip, err := NewEquivalentProofsInterceptorProcessor(createMockArgEquivalentProofsInterceptorProcessor()) + require.NoError(t, err) + + // coverage only + require.Nil(t, epip.Validate(nil, "")) +} + +func TestEquivalentProofsInterceptorProcessor_Save(t *testing.T) { + t.Parallel() + + t.Run("wrong assertion should error", func(t *testing.T) { + t.Parallel() + + epip, err := NewEquivalentProofsInterceptorProcessor(createMockArgEquivalentProofsInterceptorProcessor()) + require.NoError(t, err) + + err = epip.Save(&transaction.InterceptedTransaction{}, "", "") + require.Equal(t, process.ErrWrongTypeAssertion, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + wasCalled := false + args := createMockArgEquivalentProofsInterceptorProcessor() + args.EquivalentProofsPool = &dataRetriever.ProofsPoolMock{ + AddProofCalled: func(notarizedProof data.HeaderProofHandler) error { + wasCalled = true + return nil + }, + } + epip, err := NewEquivalentProofsInterceptorProcessor(args) + require.NoError(t, err) + + argInterceptedEquivalentProof := interceptedBlocks.ArgInterceptedEquivalentProof{ + Marshaller: args.Marshaller, + ShardCoordinator: &mock.ShardCoordinatorMock{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + Proofs: &dataRetriever.ProofsPoolMock{}, + } + argInterceptedEquivalentProof.DataBuff, _ = argInterceptedEquivalentProof.Marshaller.Marshal(&block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderEpoch: 123, + HeaderNonce: 345, + HeaderShardId: 0, + }) + iep, _ := interceptedBlocks.NewInterceptedEquivalentProof(argInterceptedEquivalentProof) + + err = epip.Save(iep, "", "") + require.NoError(t, err) + require.True(t, wasCalled) + }) +} + +func TestEquivalentProofsInterceptorProcessor_RegisterHandler(t *testing.T) { + t.Parallel() + + epip, err := NewEquivalentProofsInterceptorProcessor(createMockArgEquivalentProofsInterceptorProcessor()) + require.NoError(t, err) + + // coverage only + epip.RegisterHandler(nil) +} diff --git a/process/interceptors/processor/hdrInterceptorProcessor.go b/process/interceptors/processor/hdrInterceptorProcessor.go index b71d5b73e59..e60489c2ae5 100644 --- a/process/interceptors/processor/hdrInterceptorProcessor.go +++ b/process/interceptors/processor/hdrInterceptorProcessor.go @@ -1,11 +1,13 @@ package processor import ( + "reflect" "sync" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + 
"github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" ) @@ -15,10 +17,12 @@ var _ process.InterceptorProcessor = (*HdrInterceptorProcessor)(nil) // HdrInterceptorProcessor is the processor used when intercepting headers // (shard headers, meta headers) structs which satisfy HeaderHandler interface. type HdrInterceptorProcessor struct { - headers dataRetriever.HeadersPool - blackList process.TimeCacher - registeredHandlers []func(topic string, hash []byte, data interface{}) - mutHandlers sync.RWMutex + headers dataRetriever.HeadersPool + proofs dataRetriever.ProofsPool + blackList process.TimeCacher + enableEpochsHandler common.EnableEpochsHandler + registeredHandlers []func(topic string, hash []byte, data interface{}) + mutHandlers sync.RWMutex } // NewHdrInterceptorProcessor creates a new TxInterceptorProcessor instance @@ -29,14 +33,22 @@ func NewHdrInterceptorProcessor(argument *ArgHdrInterceptorProcessor) (*HdrInter if check.IfNil(argument.Headers) { return nil, process.ErrNilCacher } + if check.IfNil(argument.Proofs) { + return nil, process.ErrNilProofsPool + } if check.IfNil(argument.BlockBlackList) { return nil, process.ErrNilBlackListCacher } + if check.IfNil(argument.EnableEpochsHandler) { + return nil, process.ErrNilEnableEpochsHandler + } return &HdrInterceptorProcessor{ - headers: argument.Headers, - blackList: argument.BlockBlackList, - registeredHandlers: make([]func(topic string, hash []byte, data interface{}), 0), + headers: argument.Headers, + proofs: argument.Proofs, + blackList: argument.BlockBlackList, + enableEpochsHandler: argument.EnableEpochsHandler, + registeredHandlers: make([]func(topic string, hash []byte, data interface{}), 0), }, nil } @@ -68,6 +80,13 @@ func (hip *HdrInterceptorProcessor) Save(data process.InterceptedData, _ core.Pe hip.headers.AddHeader(interceptedHdr.Hash(), interceptedHdr.HeaderHandler()) + if common.IsFlagEnabledAfterEpochsStartBlock(interceptedHdr.HeaderHandler(), hip.enableEpochsHandler, common.EquivalentMessagesFlag) { + err := hip.proofs.AddProof(interceptedHdr.HeaderHandler().GetPreviousProof()) + if err != nil { + log.Error("failed to add proof", "error", err, "intercepted header hash", interceptedHdr.Hash(), "header type", reflect.TypeOf(interceptedHdr.HeaderHandler())) + } + } + return nil } diff --git a/process/interceptors/processor/hdrInterceptorProcessor_test.go b/process/interceptors/processor/hdrInterceptorProcessor_test.go index 87fe3521ff7..cc35b04d06b 100644 --- a/process/interceptors/processor/hdrInterceptorProcessor_test.go +++ b/process/interceptors/processor/hdrInterceptorProcessor_test.go @@ -4,19 +4,25 @@ import ( "testing" "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/stretchr/testify/assert" ) func createMockHdrArgument() *processor.ArgHdrInterceptorProcessor { arg := &processor.ArgHdrInterceptorProcessor{ - Headers: &mock.HeadersCacherStub{}, - BlockBlackList: &testscommon.TimeCacheStub{}, + Headers: &mock.HeadersCacherStub{}, + 
Proofs: &dataRetriever.ProofsPoolMock{}, + BlockBlackList: &testscommon.TimeCacheStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } return arg @@ -55,6 +61,28 @@ func TestNewHdrInterceptorProcessor_NilBlackListHandlerShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilBlackListCacher, err) } +func TestNewHdrInterceptorProcessor_NilProofsPoolShouldErr(t *testing.T) { + t.Parallel() + + arg := createMockHdrArgument() + arg.Proofs = nil + hip, err := processor.NewHdrInterceptorProcessor(arg) + + assert.Nil(t, hip) + assert.Equal(t, process.ErrNilProofsPool, err) +} + +func TestNewHdrInterceptorProcessor_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + arg := createMockHdrArgument() + arg.EnableEpochsHandler = nil + hip, err := processor.NewHdrInterceptorProcessor(arg) + + assert.Nil(t, hip) + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + func TestNewHdrInterceptorProcessor_ShouldWork(t *testing.T) { t.Parallel() @@ -165,6 +193,19 @@ func TestHdrInterceptorProcessor_SaveShouldWork(t *testing.T) { wasAddedHeaders = true }, } + arg.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + wasAddedProofs := false + arg.Proofs = &dataRetriever.ProofsPoolMock{ + AddProofCalled: func(headerProof data.HeaderProofHandler) error { + wasAddedProofs = true + return nil + }, + } hip, _ := processor.NewHdrInterceptorProcessor(arg) chanCalled := make(chan struct{}, 1) @@ -176,6 +217,7 @@ func TestHdrInterceptorProcessor_SaveShouldWork(t *testing.T) { assert.Nil(t, err) assert.True(t, wasAddedHeaders) + assert.True(t, wasAddedProofs) timeout := time.Second * 2 select { diff --git a/process/interceptors/processor/heartbeatInterceptorProcessor_test.go b/process/interceptors/processor/heartbeatInterceptorProcessor_test.go index 3a2c3a03aff..1667e35abc6 100644 --- a/process/interceptors/processor/heartbeatInterceptorProcessor_test.go +++ b/process/interceptors/processor/heartbeatInterceptorProcessor_test.go @@ -6,19 +6,21 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + heartbeatMessages "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/heartbeat" "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" ) func createHeartbeatInterceptorProcessArg() processor.ArgHeartbeatInterceptorProcessor { return processor.ArgHeartbeatInterceptorProcessor{ - HeartbeatCacher: testscommon.NewCacherStub(), + HeartbeatCacher: cache.NewCacherStub(), ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, PeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, } @@ -133,7 +135,7 @@ func TestHeartbeatInterceptorProcessor_Save(t *testing.T) { wasCalled := false providedPid := core.PeerID("pid") arg := createHeartbeatInterceptorProcessArg() - arg.HeartbeatCacher = &testscommon.CacherStub{ + arg.HeartbeatCacher = &cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { assert.True(t, bytes.Equal(providedPid.Bytes(), key)) ihb := value.(*heartbeatMessages.HeartbeatV2) diff 
--git a/process/interceptors/processor/interface.go b/process/interceptors/processor/interface.go index 147d8f30270..14c0ae73bd6 100644 --- a/process/interceptors/processor/interface.go +++ b/process/interceptors/processor/interface.go @@ -1,6 +1,7 @@ package processor import ( + "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-go/state" ) @@ -24,3 +25,17 @@ type interceptedValidatorInfo interface { Hash() []byte ValidatorInfo() *state.ShardValidatorInfo } + +type interceptedEquivalentProof interface { + Hash() []byte + GetProof() data.HeaderProofHandler +} + +// EquivalentProofsPool defines the behaviour of a proofs pool components +type EquivalentProofsPool interface { + AddProof(headerProof data.HeaderProofHandler) error + CleanupProofsBehindNonce(shardID uint32, nonce uint64) error + GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + HasProof(shardID uint32, headerHash []byte) bool + IsInterfaceNil() bool +} diff --git a/process/interceptors/processor/miniblockInterceptorProcessor_test.go b/process/interceptors/processor/miniblockInterceptorProcessor_test.go index eff36ae8281..149befd1a98 100644 --- a/process/interceptors/processor/miniblockInterceptorProcessor_test.go +++ b/process/interceptors/processor/miniblockInterceptorProcessor_test.go @@ -6,13 +6,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/stretchr/testify/assert" ) var testMarshalizer = &mock.MarshalizerMock{} @@ -20,7 +22,7 @@ var testHasher = &hashingMocks.HasherMock{} func createMockMiniblockArgument() *processor.ArgMiniblockInterceptorProcessor { return &processor.ArgMiniblockInterceptorProcessor{ - MiniblockCache: testscommon.NewCacherStub(), + MiniblockCache: cache.NewCacherStub(), Marshalizer: testMarshalizer, Hasher: testHasher, ShardCoordinator: mock.NewOneShardCoordinatorMock(), @@ -103,7 +105,7 @@ func TestNewMiniblockInterceptorProcessor_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- Validate +// ------- Validate func TestMiniblockInterceptorProcessor_ValidateShouldWork(t *testing.T) { t.Parallel() @@ -113,7 +115,7 @@ func TestMiniblockInterceptorProcessor_ValidateShouldWork(t *testing.T) { assert.Nil(t, mip.Validate(nil, "")) } -//------- Save +// ------- Save func TestMiniblockInterceptorProcessor_SaveWrongTypeAssertion(t *testing.T) { t.Parallel() @@ -129,7 +131,7 @@ func TestMiniblockInterceptorProcessor_NilMiniblockShouldNotAdd(t *testing.T) { t.Parallel() arg := createMockMiniblockArgument() - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { assert.Fail(t, "hasOrAdd should have not been called") return @@ -152,7 +154,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblockNotForCurrentShardShouldNotA } arg := createMockMiniblockArgument() - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key 
[]byte, value interface{}, sizeInBytes int) (has, added bool) { assert.Fail(t, "hasOrAdd should have not been called") return @@ -174,7 +176,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblockWithSenderInSameShardShouldA } arg := createMockMiniblockArgument() - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { _, ok := value.(*block.MiniBlock) if !ok { @@ -204,7 +206,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblocksWithReceiverInSameShardShou } arg := createMockMiniblockArgument() - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { _, ok := value.(*block.MiniBlock) if !ok { @@ -248,7 +250,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblockCrossShardForMeNotWhiteListe return false } - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { assert.Fail(t, "hasOrAdd should have not been called") return @@ -277,7 +279,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblockCrossShardForMeWhiteListedSh } addedInPool := false - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { addedInPool = true return false, true diff --git a/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go b/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go index 38a56751f05..3a1db0b6b66 100644 --- a/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go +++ b/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + heartbeatMessages "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/heartbeat" @@ -13,9 +15,9 @@ import ( "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" ) type interceptedDataHandler interface { @@ -25,7 +27,7 @@ type interceptedDataHandler interface { func createPeerAuthenticationInterceptorProcessArg() processor.ArgPeerAuthenticationInterceptorProcessor { return processor.ArgPeerAuthenticationInterceptorProcessor{ - PeerAuthenticationCacher: testscommon.NewCacherStub(), + PeerAuthenticationCacher: cache.NewCacherStub(), PeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, Marshaller: marshallerMock.MarshalizerMock{}, HardforkTrigger: &testscommon.HardforkTriggerStub{}, @@ -188,7 +190,7 @@ func TestPeerAuthenticationInterceptorProcessor_Save(t *testing.T) { wasPutCalled := false providedPid := core.PeerID("pid") arg := createPeerAuthenticationInterceptorProcessArg() - arg.PeerAuthenticationCacher = &testscommon.CacherStub{ + arg.PeerAuthenticationCacher = 
&cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { assert.Equal(t, providedIPAMessage.Pubkey, key) ipa := value.(*heartbeatMessages.PeerAuthentication) diff --git a/process/interceptors/processor/trieNodeChunksProcessor_test.go b/process/interceptors/processor/trieNodeChunksProcessor_test.go index f6602cddf67..ad63ca7adc6 100644 --- a/process/interceptors/processor/trieNodeChunksProcessor_test.go +++ b/process/interceptors/processor/trieNodeChunksProcessor_test.go @@ -9,8 +9,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -31,7 +34,7 @@ func createMockTrieNodesChunksProcessorArgs() TrieNodesChunksProcessorArgs { return 32 }, }, - ChunksCacher: testscommon.NewCacherMock(), + ChunksCacher: cache.NewCacherMock(), RequestInterval: time.Second, RequestHandler: &testscommon.RequestHandlerStub{}, Topic: "topic", diff --git a/process/interceptors/processor/trieNodeInterceptorProcessor_test.go b/process/interceptors/processor/trieNodeInterceptorProcessor_test.go index d0bf3f66c27..b580f4ab65a 100644 --- a/process/interceptors/processor/trieNodeInterceptorProcessor_test.go +++ b/process/interceptors/processor/trieNodeInterceptorProcessor_test.go @@ -4,10 +4,12 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -22,27 +24,27 @@ func TestNewTrieNodesInterceptorProcessor_NilCacherShouldErr(t *testing.T) { func TestNewTrieNodesInterceptorProcessor_OkValsShouldWork(t *testing.T) { t.Parallel() - tnip, err := processor.NewTrieNodesInterceptorProcessor(testscommon.NewCacherMock()) + tnip, err := processor.NewTrieNodesInterceptorProcessor(cache.NewCacherMock()) assert.Nil(t, err) assert.NotNil(t, tnip) } -//------- Validate +// ------- Validate func TestTrieNodesInterceptorProcessor_ValidateShouldWork(t *testing.T) { t.Parallel() - tnip, _ := processor.NewTrieNodesInterceptorProcessor(testscommon.NewCacherMock()) + tnip, _ := processor.NewTrieNodesInterceptorProcessor(cache.NewCacherMock()) assert.Nil(t, tnip.Validate(nil, "")) } -//------- Save +// ------- Save func TestTrieNodesInterceptorProcessor_SaveWrongTypeAssertion(t *testing.T) { t.Parallel() - tnip, _ := processor.NewTrieNodesInterceptorProcessor(testscommon.NewCacherMock()) + tnip, _ := processor.NewTrieNodesInterceptorProcessor(cache.NewCacherMock()) err := tnip.Save(nil, "", "") assert.Equal(t, process.ErrWrongTypeAssertion, err) @@ -61,7 +63,7 @@ func TestTrieNodesInterceptorProcessor_SaveShouldPutInCacher(t *testing.T) { } putCalled := false - cacher := &testscommon.CacherStub{ + cacher := &cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { putCalled = true assert.Equal(t, len(nodeHash)+nodeSize, sizeInBytes) @@ -75,7 +77,7 @@ func TestTrieNodesInterceptorProcessor_SaveShouldPutInCacher(t *testing.T) { assert.True(t, putCalled) } -//------- IsInterfaceNil +// ------- IsInterfaceNil func TestTrieNodesInterceptorProcessor_IsInterfaceNil(t *testing.T) { 
t.Parallel() diff --git a/process/interceptors/singleDataInterceptor.go b/process/interceptors/singleDataInterceptor.go index 84f3296acd7..7e5a4257fd6 100644 --- a/process/interceptors/singleDataInterceptor.go +++ b/process/interceptors/singleDataInterceptor.go @@ -1,8 +1,11 @@ package interceptors import ( + "errors" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/debug/handler" "github.com/multiversx/mx-chain-go/p2p" @@ -11,14 +14,15 @@ import ( // ArgSingleDataInterceptor is the argument for the single-data interceptor type ArgSingleDataInterceptor struct { - Topic string - DataFactory process.InterceptedDataFactory - Processor process.InterceptorProcessor - Throttler process.InterceptorThrottler - AntifloodHandler process.P2PAntifloodHandler - WhiteListRequest process.WhiteListHandler - PreferredPeersHolder process.PreferredPeersHolderHandler - CurrentPeerId core.PeerID + Topic string + DataFactory process.InterceptedDataFactory + Processor process.InterceptorProcessor + Throttler process.InterceptorThrottler + AntifloodHandler process.P2PAntifloodHandler + WhiteListRequest process.WhiteListHandler + PreferredPeersHolder process.PreferredPeersHolderHandler + CurrentPeerId core.PeerID + InterceptedDataVerifier process.InterceptedDataVerifier } // SingleDataInterceptor is used for intercepting packed multi data @@ -51,19 +55,23 @@ func NewSingleDataInterceptor(arg ArgSingleDataInterceptor) (*SingleDataIntercep if check.IfNil(arg.PreferredPeersHolder) { return nil, process.ErrNilPreferredPeersHolder } + if check.IfNil(arg.InterceptedDataVerifier) { + return nil, process.ErrNilInterceptedDataVerifier + } if len(arg.CurrentPeerId) == 0 { return nil, process.ErrEmptyPeerID } singleDataIntercept := &SingleDataInterceptor{ baseDataInterceptor: &baseDataInterceptor{ - throttler: arg.Throttler, - antifloodHandler: arg.AntifloodHandler, - topic: arg.Topic, - currentPeerId: arg.CurrentPeerId, - processor: arg.Processor, - preferredPeersHolder: arg.PreferredPeersHolder, - debugHandler: handler.NewDisabledInterceptorDebugHandler(), + throttler: arg.Throttler, + antifloodHandler: arg.AntifloodHandler, + topic: arg.Topic, + currentPeerId: arg.CurrentPeerId, + processor: arg.Processor, + preferredPeersHolder: arg.PreferredPeersHolder, + debugHandler: handler.NewDisabledInterceptorDebugHandler(), + interceptedDataVerifier: arg.InterceptedDataVerifier, }, factory: arg.DataFactory, whiteListRequest: arg.WhiteListRequest, @@ -93,13 +101,12 @@ func (sdi *SingleDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, } sdi.receivedDebugInterceptedData(interceptedData) - - err = interceptedData.CheckValidity() + err = sdi.interceptedDataVerifier.Verify(interceptedData) if err != nil { sdi.throttler.EndProcessing() sdi.processDebugInterceptedData(interceptedData, err) - isWrongVersion := err == process.ErrInvalidTransactionVersion || err == process.ErrInvalidChainID + isWrongVersion := errors.Is(err, process.ErrInvalidTransactionVersion) || errors.Is(err, process.ErrInvalidChainID) if isWrongVersion { // this situation is so severe that we need to black list de peers reason := "wrong version of received intercepted data, topic " + sdi.topic + ", error " + err.Error() diff --git a/process/interceptors/singleDataInterceptor_test.go b/process/interceptors/singleDataInterceptor_test.go index 515c2a8724c..9b1fad0a840 100644 --- 
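The singleDataInterceptor.go hunk above swaps the direct interceptedData.CheckValidity() call for an injected process.InterceptedDataVerifier. As a rough sketch with local stand-in interfaces (not the real process package types), the smallest verifier that preserves the previous behaviour is a plain pass-through; a production verifier could additionally memoize the result per Hash() so repeated gossip of the same data is not re-validated.

package verifiersketch

// interceptedData is a stand-in for the process.InterceptedData methods used here.
type interceptedData interface {
	CheckValidity() error
	Hash() []byte
}

// passThroughVerifier reproduces the pre-patch behaviour: every message is
// fully re-validated on arrival.
type passThroughVerifier struct{}

// Verify simply delegates to the intercepted data's own validity check.
func (v *passThroughVerifier) Verify(data interceptedData) error {
	return data.CheckValidity()
}

// IsInterfaceNil returns true if there is no value under the interface.
func (v *passThroughVerifier) IsInterfaceNil() bool {
	return v == nil
}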
a/process/interceptors/singleDataInterceptor_test.go +++ b/process/interceptors/singleDataInterceptor_test.go @@ -8,25 +8,27 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockArgSingleDataInterceptor() interceptors.ArgSingleDataInterceptor { return interceptors.ArgSingleDataInterceptor{ - Topic: "test topic", - DataFactory: &mock.InterceptedDataFactoryStub{}, - Processor: &mock.InterceptorProcessorStub{}, - Throttler: createMockThrottler(), - AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, - WhiteListRequest: &testscommon.WhiteListHandlerStub{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - CurrentPeerId: "pid", + Topic: "test topic", + DataFactory: &mock.InterceptedDataFactoryStub{}, + Processor: &mock.InterceptorProcessorStub{}, + Throttler: createMockThrottler(), + AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, + WhiteListRequest: &testscommon.WhiteListHandlerStub{}, + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + CurrentPeerId: "pid", + InterceptedDataVerifier: createMockInterceptedDataVerifier(), } } @@ -57,6 +59,14 @@ func createMockThrottler() *mock.InterceptorThrottlerStub { } } +func createMockInterceptedDataVerifier() *mock.InterceptedDataVerifierMock { + return &mock.InterceptedDataVerifierMock{ + VerifyCalled: func(interceptedData process.InterceptedData) error { + return interceptedData.CheckValidity() + }, + } +} + func TestNewSingleDataInterceptor_EmptyTopicShouldErr(t *testing.T) { t.Parallel() @@ -145,6 +155,17 @@ func TestNewSingleDataInterceptor_EmptyPeerIDShouldErr(t *testing.T) { assert.Equal(t, process.ErrEmptyPeerID, err) } +func TestNewSingleDataInterceptor_NilInterceptedDataVerifierShouldErr(t *testing.T) { + t.Parallel() + + arg := createMockArgMultiDataInterceptor() + arg.InterceptedDataVerifier = nil + mdi, err := interceptors.NewMultiDataInterceptor(arg) + + assert.True(t, check.IfNil(mdi)) + assert.Equal(t, process.ErrNilInterceptedDataVerifier, err) +} + func TestNewSingleDataInterceptor(t *testing.T) { t.Parallel() diff --git a/process/interceptors/whiteListDataVerifier_test.go b/process/interceptors/whiteListDataVerifier_test.go index c1567465fcc..f974f2f2c02 100644 --- a/process/interceptors/whiteListDataVerifier_test.go +++ b/process/interceptors/whiteListDataVerifier_test.go @@ -6,8 +6,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -23,7 +26,7 @@ func TestNewWhiteListDataVerifier_NilCacherShouldErr(t *testing.T) { func TestNewWhiteListDataVerifier_ShouldWork(t *testing.T) { t.Parallel() - wldv, err := NewWhiteListDataVerifier(testscommon.NewCacherStub()) + wldv, err := NewWhiteListDataVerifier(cache.NewCacherStub()) assert.False(t, check.IfNil(wldv)) assert.Nil(t, err) @@ -34,7 +37,7 @@ func TestWhiteListDataVerifier_Add(t *testing.T) { keys := [][]byte{[]byte("key1"), []byte("key2")} added := map[string]struct{}{} - cacher := 
&testscommon.CacherStub{ + cacher := &cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { added[string(key)] = struct{}{} return false @@ -55,7 +58,7 @@ func TestWhiteListDataVerifier_Remove(t *testing.T) { keys := [][]byte{[]byte("key1"), []byte("key2")} removed := map[string]struct{}{} - cacher := &testscommon.CacherStub{ + cacher := &cache.CacherStub{ RemoveCalled: func(key []byte) { removed[string(key)] = struct{}{} }, @@ -73,7 +76,7 @@ func TestWhiteListDataVerifier_Remove(t *testing.T) { func TestWhiteListDataVerifier_IsWhiteListedNilInterceptedDataShouldRetFalse(t *testing.T) { t.Parallel() - wldv, _ := NewWhiteListDataVerifier(testscommon.NewCacherStub()) + wldv, _ := NewWhiteListDataVerifier(cache.NewCacherStub()) assert.False(t, wldv.IsWhiteListed(nil)) } @@ -83,7 +86,7 @@ func TestWhiteListDataVerifier_IsWhiteListedNotFoundShouldRetFalse(t *testing.T) keyCheck := []byte("key") wldv, _ := NewWhiteListDataVerifier( - &testscommon.CacherStub{ + &cache.CacherStub{ HasCalled: func(key []byte) bool { return !bytes.Equal(key, keyCheck) }, @@ -104,7 +107,7 @@ func TestWhiteListDataVerifier_IsWhiteListedFoundShouldRetTrue(t *testing.T) { keyCheck := []byte("key") wldv, _ := NewWhiteListDataVerifier( - &testscommon.CacherStub{ + &cache.CacherStub{ HasCalled: func(key []byte) bool { return bytes.Equal(key, keyCheck) }, diff --git a/process/interface.go b/process/interface.go index fd022f5f8a9..d7cbe87825b 100644 --- a/process/interface.go +++ b/process/interface.go @@ -849,6 +849,9 @@ type InterceptedHeaderSigVerifier interface { VerifyRandSeed(header data.HeaderHandler) error VerifyLeaderSignature(header data.HeaderHandler) error VerifySignature(header data.HeaderHandler) error + VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderProof(headerProof data.HeaderProofHandler) error + VerifyHeaderWithProof(header data.HeaderHandler) error IsInterfaceNil() bool } @@ -1201,6 +1204,7 @@ type PayableHandler interface { // FallbackHeaderValidator defines the behaviour of a component able to signal when a fallback header validation could be applied type FallbackHeaderValidator interface { + ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) bool ShouldApplyFallbackValidation(headerHandler data.HeaderHandler) bool IsInterfaceNil() bool } @@ -1398,3 +1402,22 @@ type SentSignaturesTracker interface { ResetCountersForManagedBlockSigner(signerPk []byte) IsInterfaceNil() bool } + +// InterceptedDataVerifier defines a component able to verify intercepted data validity +type InterceptedDataVerifier interface { + Verify(interceptedData InterceptedData) error + IsInterfaceNil() bool +} + +// InterceptedDataVerifierFactory defines a component that is able to create intercepted data verifiers +type InterceptedDataVerifierFactory interface { + Create(topic string) (InterceptedDataVerifier, error) + Close() error + IsInterfaceNil() bool +} + +// ProofsPool defines the behaviour of a proofs pool components +type ProofsPool interface { + HasProof(shardID uint32, headerHash []byte) bool + IsInterfaceNil() bool +} diff --git a/process/mock/forkDetectorMock.go b/process/mock/forkDetectorMock.go index a574e4724b1..51e79af246f 100644 --- a/process/mock/forkDetectorMock.go +++ b/process/mock/forkDetectorMock.go @@ -28,17 +28,27 @@ func (fdm *ForkDetectorMock) RestoreToGenesis() { // AddHeader - func (fdm *ForkDetectorMock) 
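process/interface.go above also introduces an InterceptedDataVerifierFactory that builds one verifier per topic. One possible shape for such a factory, sketched with simplified local types (the newVerifier field and all names are hypothetical), is to lazily create and cache a verifier per topic and release any resources in Close.

package factorysketch

import "sync"

// interceptedData and interceptedDataVerifier are stand-ins for the process package types.
type interceptedData interface {
	CheckValidity() error
}

type interceptedDataVerifier interface {
	Verify(data interceptedData) error
	IsInterfaceNil() bool
}

// verifierFactory lazily builds and reuses one verifier per topic.
type verifierFactory struct {
	mut         sync.Mutex
	verifiers   map[string]interceptedDataVerifier
	newVerifier func(topic string) interceptedDataVerifier
}

// Create returns the verifier registered for the topic, building it on first use.
func (vf *verifierFactory) Create(topic string) (interceptedDataVerifier, error) {
	vf.mut.Lock()
	defer vf.mut.Unlock()
	if verifier, ok := vf.verifiers[topic]; ok {
		return verifier, nil
	}
	verifier := vf.newVerifier(topic)
	vf.verifiers[topic] = verifier
	return verifier, nil
}

// Close is where a real factory would stop any per-verifier cache eviction loops.
func (vf *verifierFactory) Close() error {
	return nil
}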
AddHeader(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { - return fdm.AddHeaderCalled(header, hash, state, selfNotarizedHeaders, selfNotarizedHeadersHashes) + if fdm.AddHeaderCalled != nil { + return fdm.AddHeaderCalled(header, hash, state, selfNotarizedHeaders, selfNotarizedHeadersHashes) + } + + return nil } // RemoveHeader - func (fdm *ForkDetectorMock) RemoveHeader(nonce uint64, hash []byte) { - fdm.RemoveHeaderCalled(nonce, hash) + if fdm.RemoveHeaderCalled != nil { + fdm.RemoveHeaderCalled(nonce, hash) + } } // CheckFork - func (fdm *ForkDetectorMock) CheckFork() *process.ForkInfo { - return fdm.CheckForkCalled() + if fdm.CheckForkCalled != nil { + return fdm.CheckForkCalled() + } + + return nil } // GetHighestFinalBlockNonce - @@ -51,12 +61,20 @@ func (fdm *ForkDetectorMock) GetHighestFinalBlockNonce() uint64 { // GetHighestFinalBlockHash - func (fdm *ForkDetectorMock) GetHighestFinalBlockHash() []byte { - return fdm.GetHighestFinalBlockHashCalled() + if fdm.GetHighestFinalBlockHashCalled != nil { + return fdm.GetHighestFinalBlockHashCalled() + } + + return nil } // ProbableHighestNonce - func (fdm *ForkDetectorMock) ProbableHighestNonce() uint64 { - return fdm.ProbableHighestNonceCalled() + if fdm.ProbableHighestNonceCalled != nil { + return fdm.ProbableHighestNonceCalled() + } + + return 0 } // SetRollBackNonce - @@ -68,12 +86,18 @@ func (fdm *ForkDetectorMock) SetRollBackNonce(nonce uint64) { // ResetFork - func (fdm *ForkDetectorMock) ResetFork() { - fdm.ResetForkCalled() + if fdm.ResetForkCalled != nil { + fdm.ResetForkCalled() + } } // GetNotarizedHeaderHash - func (fdm *ForkDetectorMock) GetNotarizedHeaderHash(nonce uint64) []byte { - return fdm.GetNotarizedHeaderHashCalled(nonce) + if fdm.GetNotarizedHeaderHashCalled != nil { + return fdm.GetNotarizedHeaderHashCalled(nonce) + } + + return nil } // ResetProbableHighestNonce - diff --git a/process/mock/headerSigVerifierStub.go b/process/mock/headerSigVerifierStub.go deleted file mode 100644 index efc83c06e18..00000000000 --- a/process/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,52 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) 
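The forkDetectorMock updates above all apply the same guard idiom: invoke the injected function only when it has been set, otherwise return a zero value, so tests that leave a stub field nil no longer panic. The pattern, reduced to a single method on a hypothetical stub:

package stubsketch

// forkInfo is a stand-in for the value returned by the mocked method.
type forkInfo struct {
	IsDetected bool
}

// forkCheckerStub exposes one optional hook, like the mocks in the hunks above.
type forkCheckerStub struct {
	CheckForkCalled func() *forkInfo
}

// CheckFork calls the hook when it is set and falls back to a zero value otherwise.
func (stub *forkCheckerStub) CheckFork() *forkInfo {
	if stub.CheckForkCalled != nil {
		return stub.CheckForkCalled()
	}
	return nil
}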
- } - - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/process/mock/interceptedDataVerifierFactoryMock.go b/process/mock/interceptedDataVerifierFactoryMock.go new file mode 100644 index 00000000000..245be014b15 --- /dev/null +++ b/process/mock/interceptedDataVerifierFactoryMock.go @@ -0,0 +1,29 @@ +package mock + +import ( + "github.com/multiversx/mx-chain-go/process" +) + +// InterceptedDataVerifierFactoryMock - +type InterceptedDataVerifierFactoryMock struct { + CreateCalled func(topic string) (process.InterceptedDataVerifier, error) +} + +// Create - +func (idvfs *InterceptedDataVerifierFactoryMock) Create(topic string) (process.InterceptedDataVerifier, error) { + if idvfs.CreateCalled != nil { + return idvfs.CreateCalled(topic) + } + + return &InterceptedDataVerifierMock{}, nil +} + +// Close - +func (idvfs *InterceptedDataVerifierFactoryMock) Close() error { + return nil +} + +// IsInterfaceNil - +func (idvfs *InterceptedDataVerifierFactoryMock) IsInterfaceNil() bool { + return idvfs == nil +} diff --git a/process/mock/interceptedDataVerifierMock.go b/process/mock/interceptedDataVerifierMock.go index c8d4d14392b..6668a6ea625 100644 --- a/process/mock/interceptedDataVerifierMock.go +++ b/process/mock/interceptedDataVerifierMock.go @@ -1,17 +1,24 @@ package mock -import "github.com/multiversx/mx-chain-go/process" +import ( + "github.com/multiversx/mx-chain-go/process" +) // InterceptedDataVerifierMock - type InterceptedDataVerifierMock struct { + VerifyCalled func(interceptedData process.InterceptedData) error } -// IsForCurrentShard - -func (i *InterceptedDataVerifierMock) IsForCurrentShard(_ process.InterceptedData) bool { - return true +// Verify - +func (idv *InterceptedDataVerifierMock) Verify(interceptedData process.InterceptedData) error { + if idv.VerifyCalled != nil { + return idv.VerifyCalled(interceptedData) + } + + return nil } -// IsInterfaceNil returns true if underlying object is -func (i *InterceptedDataVerifierMock) IsInterfaceNil() bool { - return i == nil +// IsInterfaceNil - +func (idv *InterceptedDataVerifierMock) IsInterfaceNil() bool { + return idv == nil } diff --git a/process/mock/peerShardResolverStub.go b/process/mock/peerShardResolverStub.go index 4239fbeaee4..a5bd8a66d98 100644 --- a/process/mock/peerShardResolverStub.go +++ b/process/mock/peerShardResolverStub.go @@ -11,7 +11,11 @@ type PeerShardResolverStub struct { // GetPeerInfo - func (psrs *PeerShardResolverStub) GetPeerInfo(pid core.PeerID) core.P2PPeerInfo { - return psrs.GetPeerInfoCalled(pid) + if psrs.GetPeerInfoCalled != nil { + return psrs.GetPeerInfoCalled(pid) + } + + return core.P2PPeerInfo{} } // IsInterfaceNil - diff --git a/process/peer/process.go b/process/peer/process.go index 7cb50db55be..c5ebb890d8a 100644 --- a/process/peer/process.go +++ b/process/peer/process.go @@ -1,6 +1,7 @@ package peer import ( + "bytes" "context" "encoding/hex" "fmt" @@ -388,7 +389,7 @@ func (vs *validatorStatistics) UpdatePeerState(header data.MetaHeaderHandler, ca log.Trace("Increasing", "round", previousHeader.GetRound(), "prevRandSeed", previousHeader.GetPrevRandSeed()) consensusGroupEpoch := computeEpoch(previousHeader) - consensusGroup, err := vs.nodesCoordinator.ComputeConsensusGroup( + leader, consensusGroup, err := vs.nodesCoordinator.ComputeConsensusGroup( previousHeader.GetPrevRandSeed(), previousHeader.GetRound(), previousHeader.GetShardID(), @@ -397,15 +398,23 @@ func (vs *validatorStatistics) 
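UpdatePeerState above now derives the signers' bitmap from the previous header's attached proof once the equivalent-messages flag is active for that epoch. Condensed into a hypothetical helper with simplified stand-in types (the real code works on data.HeaderHandler and check.IfNilReflect):

package bitmapsketch

// proofInfo stands in for the previous proof attached to a header.
type proofInfo struct {
	pubKeysBitmap []byte
}

// headerInfo stands in for the previous-header fields used here.
type headerInfo struct {
	epoch         uint32
	pubKeysBitmap []byte
	previousProof *proofInfo
}

// selectSigningBitmap prefers the proof's bitmap when the flag is enabled for
// the header's epoch and a proof is present, otherwise it falls back to the
// header's own bitmap, mirroring the branch added in UpdatePeerState.
func selectSigningBitmap(prevHeader headerInfo, equivalentMessagesEnabled func(epoch uint32) bool) []byte {
	bitmap := prevHeader.pubKeysBitmap
	if equivalentMessagesEnabled(prevHeader.epoch) && prevHeader.previousProof != nil {
		bitmap = prevHeader.previousProof.pubKeysBitmap
	}
	return bitmap
}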
UpdatePeerState(header data.MetaHeaderHandler, ca return nil, err } - encodedLeaderPk := vs.pubkeyConv.SilentEncode(consensusGroup[0].PubKey(), log) + encodedLeaderPk := vs.pubkeyConv.SilentEncode(leader.PubKey(), log) leaderPK := core.GetTrimmedPk(encodedLeaderPk) log.Trace("Increasing for leader", "leader", leaderPK, "round", previousHeader.GetRound()) log.Debug("UpdatePeerState - registering meta previous leader fees", "metaNonce", previousHeader.GetNonce()) + bitmap := previousHeader.GetPubKeysBitmap() + if vs.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, previousHeader.GetEpoch()) { + proof := previousHeader.GetPreviousProof() + if !check.IfNilReflect(proof) { + bitmap = proof.GetPubKeysBitmap() + } + } err = vs.updateValidatorInfoOnSuccessfulBlock( + leader, consensusGroup, - previousHeader.GetPubKeysBitmap(), + bitmap, big.NewInt(0).Sub(previousHeader.GetAccumulatedFees(), previousHeader.GetDeveloperFees()), previousHeader.GetShardID(), ) @@ -799,16 +808,16 @@ func (vs *validatorStatistics) computeDecrease( swInner.Start("ComputeValidatorsGroup") log.Debug("decreasing", "round", i, "prevRandSeed", prevRandSeed, "shardId", shardID) - consensusGroup, err := vs.nodesCoordinator.ComputeConsensusGroup(prevRandSeed, i, shardID, epoch) + leader, consensusGroup, err := vs.nodesCoordinator.ComputeConsensusGroup(prevRandSeed, i, shardID, epoch) swInner.Stop("ComputeValidatorsGroup") if err != nil { return err } swInner.Start("loadPeerAccount") - leaderPeerAcc, err := vs.loadPeerAccount(consensusGroup[0].PubKey()) + leaderPeerAcc, err := vs.loadPeerAccount(leader.PubKey()) - encodedLeaderPk := vs.pubkeyConv.SilentEncode(consensusGroup[0].PubKey(), log) + encodedLeaderPk := vs.pubkeyConv.SilentEncode(leader.PubKey(), log) leaderPK := core.GetTrimmedPk(encodedLeaderPk) swInner.Stop("loadPeerAccount") if err != nil { @@ -816,7 +825,7 @@ func (vs *validatorStatistics) computeDecrease( } vs.mutValidatorStatistics.Lock() - vs.missedBlocksCounters.decreaseLeader(consensusGroup[0].PubKey()) + vs.missedBlocksCounters.decreaseLeader(leader.PubKey()) vs.mutValidatorStatistics.Unlock() swInner.Start("ComputeDecreaseProposer") @@ -920,13 +929,14 @@ func (vs *validatorStatistics) updateShardDataPeerState( epoch := computeEpoch(currentHeader) - shardConsensus, shardInfoErr := vs.nodesCoordinator.ComputeConsensusGroup(h.PrevRandSeed, h.Round, h.ShardID, epoch) + leader, shardConsensus, shardInfoErr := vs.nodesCoordinator.ComputeConsensusGroup(h.PrevRandSeed, h.Round, h.ShardID, epoch) if shardInfoErr != nil { return shardInfoErr } log.Debug("updateShardDataPeerState - registering shard leader fees", "shard headerHash", h.HeaderHash, "accumulatedFees", h.AccumulatedFees.String(), "developerFees", h.DeveloperFees.String()) shardInfoErr = vs.updateValidatorInfoOnSuccessfulBlock( + leader, shardConsensus, h.PubKeysBitmap, big.NewInt(0).Sub(h.AccumulatedFees, h.DeveloperFees), @@ -1014,6 +1024,7 @@ func (vs *validatorStatistics) savePeerAccountData( } func (vs *validatorStatistics) updateValidatorInfoOnSuccessfulBlock( + leader nodesCoordinator.Validator, validatorList []nodesCoordinator.Validator, signingBitmap []byte, accumulatedFees *big.Int, @@ -1033,7 +1044,7 @@ func (vs *validatorStatistics) updateValidatorInfoOnSuccessfulBlock( peerAcc.IncreaseNumSelectedInSuccessBlocks() newRating := peerAcc.GetRating() - isLeader := i == 0 + isLeader := bytes.Equal(leader.PubKey(), validatorList[i].PubKey()) validatorSigned := (signingBitmap[i/8] & (1 << (uint16(i) % 8))) != 0 actionType := 
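Two per-validator checks in updateValidatorInfoOnSuccessfulBlock are worth isolating: the leader is now matched by public key instead of being assumed at index 0, and participation is read from the signing bitmap with the least-significant bit first inside each byte (validator 10, for example, maps to bit 2 of byte 1). A small sketch using the same bit arithmetic as the hunk above:

package leadersketch

import "bytes"

// isLeaderAt reports whether the validator at index i is the leader, matching
// by public key rather than by position in the consensus group.
func isLeaderAt(leaderPubKey []byte, validatorPubKeys [][]byte, i int) bool {
	return bytes.Equal(leaderPubKey, validatorPubKeys[i])
}

// validatorSigned reads bit i of the signing bitmap: byte i/8, bit i%8,
// least-significant bit first.
func validatorSigned(signingBitmap []byte, i int) bool {
	return signingBitmap[i/8]&(1<<(uint16(i)%8)) != 0
}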
vs.computeValidatorActionType(isLeader, validatorSigned) @@ -1164,6 +1175,11 @@ func (vs *validatorStatistics) getTempRating(s string) uint32 { } func (vs *validatorStatistics) display(validatorKey string) { + if log.GetLevel() != logger.LogTrace { + // do not need to load peer account if not log level trace + return + } + peerAcc, err := vs.loadPeerAccount([]byte(validatorKey)) if err != nil { log.Trace("display peer acc", "error", err) diff --git a/process/peer/process_test.go b/process/peer/process_test.go index 4a3bd5a212b..fde64825452 100644 --- a/process/peer/process_test.go +++ b/process/peer/process_test.go @@ -467,8 +467,8 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateComputeValidatorErrShouldEr arguments := createMockArguments() arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return nil, computeValidatorsErr + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, computeValidatorsErr }, } validatorStatistics, _ := peer.NewValidatorStatisticsProcessor(arguments) @@ -492,9 +492,10 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateGetExistingAccountErr(t *te } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{validator}, nil }, } arguments.PeerAdapter = adapter @@ -517,9 +518,10 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateGetExistingAccountInvalidTy } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{validator}, nil }, } arguments.PeerAdapter = adapter @@ -561,9 +563,11 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateGetHeaderError(t *testing.T }, nil }, } + + validator1 := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}, &shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup 
[]nodesCoordinator.Validator, err error) { + return validator1, []nodesCoordinator.Validator{validator1, &shardingMocks.ValidatorMock{}}, nil }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -617,9 +621,15 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateCallsIncrease(t *testing.T) }, nil }, } + + validator1 := &shardingMocks.ValidatorMock{ + PubKeyCalled: func() []byte { + return []byte("pk1") + }, + } arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}, &shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator1, []nodesCoordinator.Validator{validator1, &shardingMocks.ValidatorMock{PubKeyCalled: func() []byte { return []byte("pk2") }}}, nil }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -1289,9 +1299,11 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateCheckForMissedBlocksErr(t * }, nil }, } + + validator1 := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}, &shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator1, []nodesCoordinator.Validator{validator1, &shardingMocks.ValidatorMock{}}, nil }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -1357,9 +1369,9 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksNoMissedBlocks(t *test arguments.DataPool = dataRetrieverMock.NewPoolsHolderStub() arguments.StorageService = &storageStubs.ChainStorerStub{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { computeValidatorGroupCalled = true - return nil, nil + return nil, nil, nil }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -1443,8 +1455,8 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksErrOnComputeValidatorL arguments.DataPool = dataRetrieverMock.NewPoolsHolderStub() arguments.StorageService = &storageStubs.ChainStorerStub{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return nil, computeErr + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, computeErr }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -1470,10 +1482,11 @@ 
func TestValidatorStatisticsProcessor_CheckForMissedBlocksErrOnDecrease(t *testi } arguments := createMockArguments() + validator1 := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - &shardingMocks.ValidatorMock{}, + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator1, []nodesCoordinator.Validator{ + validator1, }, nil }, } @@ -1504,14 +1517,15 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksCallsDecrease(t *testi } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{ + PubKeyCalled: func() []byte { + return pubKey + }, + } arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - &shardingMocks.ValidatorMock{ - PubKeyCalled: func() []byte { - return pubKey - }, - }, + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{ + validator, }, nil }, } @@ -1555,10 +1569,11 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksWithRoundDifferenceGre } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - &shardingMocks.ValidatorMock{}, + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{ + validator, }, nil }, GetAllEligibleValidatorsPublicKeysCalled: func(_ uint32) (map[uint32][][]byte, error) { @@ -1614,10 +1629,11 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksWithRoundDifferenceGre } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - &shardingMocks.ValidatorMock{}, + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{ + validator, }, nil }, GetAllEligibleValidatorsPublicKeysCalled: func(_ uint32) (map[uint32][][]byte, error) { @@ -1816,8 +1832,8 @@ func DoComputeMissingBlocks( arguments := createMockArguments() arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (validatorsGroup 
[]nodesCoordinator.Validator, err error) { - return consensus, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return consensus[0], consensus, nil }, GetAllEligibleValidatorsPublicKeysCalled: func(_ uint32) (map[uint32][][]byte, error) { return validatorPublicKeys, nil @@ -1891,14 +1907,18 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateCallsPubKeyForValidator(t * pubKeyCalled := false arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{ + PubKeyCalled: func() []byte { + pubKeyCalled = true + return make([]byte, 0) + }, + } arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{ - PubKeyCalled: func() []byte { - pubKeyCalled = true - return make([]byte, 0) - }, - }, &shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{ + validator, + &shardingMocks.ValidatorMock{}, + }, nil }, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ @@ -2603,13 +2623,13 @@ func createUpdateTestArgs(consensusGroup map[string][]nodesCoordinator.Validator arguments.PeerAdapter = adapter arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { key := fmt.Sprintf(consensusGroupFormat, string(randomness), round, shardId, epoch) validatorsArray, ok := consensusGroup[key] if !ok { - return nil, process.ErrEmptyConsensusGroup + return nil, nil, process.ErrEmptyConsensusGroup } - return validatorsArray, nil + return validatorsArray[0], validatorsArray, nil }, } return arguments diff --git a/process/rating/peerHonesty/peerHonesty_test.go b/process/rating/peerHonesty/peerHonesty_test.go index 73ca45e2623..0d7cf263ca6 100644 --- a/process/rating/peerHonesty/peerHonesty_test.go +++ b/process/rating/peerHonesty/peerHonesty_test.go @@ -7,9 +7,12 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -44,7 +47,7 @@ func TestNewP2pPeerHonesty_NilBlacklistedPkCacheShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( createMockPeerHonestyConfig(), nil, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -59,7 +62,7 @@ func TestNewP2pPeerHonesty_InvalidDecayCoefficientShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -74,7 +77,7 @@ func TestNewP2pPeerHonesty_InvalidDecayUpdateIntervalShouldErr(t *testing.T) { 
pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -89,7 +92,7 @@ func TestNewP2pPeerHonesty_InvalidMinScoreShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -104,7 +107,7 @@ func TestNewP2pPeerHonesty_InvalidMaxScoreShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -119,7 +122,7 @@ func TestNewP2pPeerHonesty_InvalidUnitValueShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -134,7 +137,7 @@ func TestNewP2pPeerHonesty_InvalidBadPeerThresholdShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -148,7 +151,7 @@ func TestNewP2pPeerHonesty_ShouldWork(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.False(t, check.IfNil(pph)) @@ -167,7 +170,7 @@ func TestP2pPeerHonesty_Close(t *testing.T) { pph, _ := NewP2pPeerHonestyWithCustomExecuteDelayFunction( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, handler, ) @@ -189,7 +192,7 @@ func TestP2pPeerHonesty_ChangeScoreShouldWork(t *testing.T) { pph, _ := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -210,7 +213,7 @@ func TestP2pPeerHonesty_DoubleChangeScoreShouldWork(t *testing.T) { pph, _ := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -243,7 +246,7 @@ func TestP2pPeerHonesty_CheckBlacklistNotBlacklisted(t *testing.T) { return nil }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -275,7 +278,7 @@ func TestP2pPeerHonesty_CheckBlacklistMaxScoreReached(t *testing.T) { return nil }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -310,7 +313,7 @@ func TestP2pPeerHonesty_CheckBlacklistMinScoreReached(t *testing.T) { return nil }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -345,7 +348,7 @@ func TestP2pPeerHonesty_CheckBlacklistHasShouldNotCallUpsert(t *testing.T) { return nil }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -374,7 +377,7 @@ func TestP2pPeerHonesty_CheckBlacklistUpsertErrorsShouldWork(t *testing.T) { return errors.New("expected error") }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -392,7 +395,7 @@ func TestP2pPeerHonesty_ApplyDecay(t *testing.T) { pph, _ := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pks := []string{"pkMin", "pkMax", "pkNearZero", "pkZero", "pkValue"} @@ -422,7 +425,7 @@ func TestP2pPeerHonesty_ApplyDecayWillEventuallyGoTheScoreToZero(t *testing.T) { pph, _ := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" diff --git a/process/smartContract/hooks/blockChainHook_test.go b/process/smartContract/hooks/blockChainHook_test.go index 
92636c1baf0..fd46e206498 100644 --- a/process/smartContract/hooks/blockChainHook_test.go +++ b/process/smartContract/hooks/blockChainHook_test.go @@ -15,6 +15,13 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/esdt" "github.com/multiversx/mx-chain-core-go/data/transaction" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" + "github.com/multiversx/mx-chain-vm-common-go/parsers" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -26,6 +33,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" @@ -33,12 +41,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" - "github.com/multiversx/mx-chain-vm-common-go/parsers" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockBlockChainHookArgs() hooks.ArgBlockChainHook { @@ -1258,7 +1260,7 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { args := createMockBlockChainHookArgs() wasCodeSavedInPool := &atomic.Flag{} - args.CompiledSCPool = &testscommon.CacherStub{ + args.CompiledSCPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { require.Equal(t, codeHash, key) return code, true @@ -1280,7 +1282,7 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { args.NilCompiledSCStore = true wasCodeSavedInPool := &atomic.Flag{} - args.CompiledSCPool = &testscommon.CacherStub{ + args.CompiledSCPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { require.Equal(t, codeHash, key) return struct{}{}, true @@ -1313,7 +1315,7 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { }, } wasCodeSavedInPool := &atomic.Flag{} - args.CompiledSCPool = &testscommon.CacherStub{ + args.CompiledSCPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { require.Equal(t, codeHash, key) return nil, false @@ -1350,7 +1352,7 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { }, } args.NilCompiledSCStore = false - args.CompiledSCPool = &testscommon.CacherStub{ + args.CompiledSCPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { require.Equal(t, codeHash, key) return nil, false @@ -2213,7 +2215,7 @@ func TestBlockChainHookImpl_ClearCompiledCodes(t *testing.T) { args.EnableEpochs.IsPayableBySCEnableEpoch = 11 clearCalled := 0 - args.CompiledSCPool = &testscommon.CacherStub{ClearCalled: func() { + args.CompiledSCPool = &cache.CacherStub{ClearCalled: func() { clearCalled++ }} diff --git a/process/sync/argBootstrapper.go 
b/process/sync/argBootstrapper.go index ec3f64a58d8..587ecedd258 100644 --- a/process/sync/argBootstrapper.go +++ b/process/sync/argBootstrapper.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dblookupext" @@ -48,6 +49,7 @@ type ArgBaseBootstrapper struct { ScheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler ProcessWaitTime time.Duration RepopulateTokensSupplies bool + EnableEpochsHandler common.EnableEpochsHandler } // ArgShardBootstrapper holds all dependencies required by the bootstrap data factory in order to create diff --git a/process/sync/baseSync.go b/process/sync/baseSync.go index aa43d8cecc1..cf13638912f 100644 --- a/process/sync/baseSync.go +++ b/process/sync/baseSync.go @@ -3,6 +3,7 @@ package sync import ( "bytes" "context" + "encoding/hex" "fmt" "math" "sync" @@ -57,21 +58,23 @@ type notarizedInfo struct { type baseBootstrap struct { historyRepo dblookupext.HistoryRepository headers dataRetriever.HeadersPool + proofs dataRetriever.ProofsPool chainHandler data.ChainHandler blockProcessor process.BlockProcessor store dataRetriever.StorageService - roundHandler consensus.RoundHandler - hasher hashing.Hasher - marshalizer marshal.Marshalizer - epochHandler dataRetriever.EpochHandler - forkDetector process.ForkDetector - requestHandler process.RequestHandler - shardCoordinator sharding.Coordinator - accounts state.AccountsAdapter - blockBootstrapper blockBootstrapper - blackListHandler process.TimeCacher + roundHandler consensus.RoundHandler + hasher hashing.Hasher + marshalizer marshal.Marshalizer + epochHandler dataRetriever.EpochHandler + forkDetector process.ForkDetector + requestHandler process.RequestHandler + shardCoordinator sharding.Coordinator + accounts state.AccountsAdapter + blockBootstrapper blockBootstrapper + blackListHandler process.TimeCacher + enableEpochsHandler common.EnableEpochsHandler mutHeader sync.RWMutex headerNonce *uint64 @@ -491,6 +494,9 @@ func checkBaseBootstrapParameters(arguments ArgBaseBootstrapper) error { if arguments.ProcessWaitTime < minimumProcessWaitTime { return fmt.Errorf("%w, minimum is %v, provided is %v", process.ErrInvalidProcessWaitTime, minimumProcessWaitTime, arguments.ProcessWaitTime) } + if check.IfNil(arguments.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } return nil } @@ -630,13 +636,18 @@ func (boot *baseBootstrap) syncBlock() error { } }() - header, err = boot.getNextHeaderRequestingIfMissing() + header, headerHash, err := boot.getNextHeaderRequestingIfMissing() if err != nil { return err } go boot.requestHeadersFromNonceIfMissing(header.GetNonce() + 1) + err = boot.handleEquivalentProof(header, headerHash) + if err != nil { + return err + } + body, err = boot.blockBootstrapper.getBlockBodyRequestingIfMissing(header) if err != nil { return err @@ -687,6 +698,47 @@ func (boot *baseBootstrap) syncBlock() error { ) boot.cleanNoncesSyncedWithErrorsBehindFinal() + boot.cleanProofsBehindFinal(header) + + return nil +} + +func (boot *baseBootstrap) handleEquivalentProof( + header data.HeaderHandler, + headerHash []byte, +) error { + if !boot.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + return nil + } + + prevHeader, err := 
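syncBlock above now keeps the next header's hash alongside the header itself, because the new proofs pool is keyed by shard and header hash rather than by nonce. In outline (simplified stand-in types, not the repository's signatures), the reordered prologue is: fetch header plus hash, check the equivalent proof, then fetch the body.

package syncsketch

type headerWithHash struct {
	header interface{}
	hash   []byte
}

// nextBlockPrologue mirrors, in simplified form, the reordered steps in
// syncBlock: header and hash first, proof check second, block body last.
func nextBlockPrologue(
	getNextHeader func() (headerWithHash, error),
	ensureProof func(hash []byte) error,
	getBody func(header interface{}) (interface{}, error),
) (interface{}, interface{}, error) {
	hh, err := getNextHeader()
	if err != nil {
		return nil, nil, err
	}
	if err = ensureProof(hh.hash); err != nil {
		return nil, nil, err
	}
	body, err := getBody(hh.header)
	if err != nil {
		return nil, nil, err
	}
	return hh.header, body, nil
}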
boot.blockBootstrapper.getHeaderWithHashRequestingIfMissing(header.GetPrevHash()) + if err != nil { + return err + } + + if !boot.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, prevHeader.GetEpoch()) { + // no need to check proof for first block after activation + log.Info("handleEquivalentProof: no need to check equivalent proof for first activation block") + return nil + } + + // process block only if there is a proof for it + hasProof := boot.proofs.HasProof(header.GetShardID(), headerHash) + if hasProof { + return nil + } + + log.Trace("baseBootstrap.handleEquivalentProof: did not have proof for header, will try again", "headerHash", headerHash) + + _, _, err = boot.blockBootstrapper.getHeaderWithNonceRequestingIfMissing(header.GetNonce() + 1) + if err != nil { + return err + } + + hasProof = boot.proofs.HasProof(header.GetShardID(), headerHash) + if !hasProof { + return fmt.Errorf("baseBootstrap.handleEquivalentProof: did not have proof for header, headerHash %s", hex.EncodeToString(headerHash)) + } return nil } @@ -715,6 +767,25 @@ func (boot *baseBootstrap) cleanNoncesSyncedWithErrorsBehindFinal() { } } +func (boot *baseBootstrap) cleanProofsBehindFinal(header data.HeaderHandler) { + if !boot.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, header.GetEpoch()) { + return + } + + // TODO: analyse fork detection by proofs + finalNonce := boot.forkDetector.GetHighestFinalBlockNonce() + + err := boot.proofs.CleanupProofsBehindNonce(header.GetShardID(), finalNonce) + if err != nil { + log.Warn("failed to cleanup notarized proofs behind nonce", + "nonce", finalNonce, + "shardID", header.GetShardID(), + "error", err) + } + + log.Trace("baseBootstrap.cleanProofsBehindFinal cleanup successfully", "finalNonce", finalNonce) +} + // rollBack decides if rollBackOneBlock must be called func (boot *baseBootstrap) rollBack(revertUsingForkNonce bool) error { var roleBackOneBlockExecuted bool @@ -935,7 +1006,7 @@ func (boot *baseBootstrap) getRootHashFromBlock(hdr data.HeaderHandler, hdrHash return hdrRootHash } -func (boot *baseBootstrap) getNextHeaderRequestingIfMissing() (data.HeaderHandler, error) { +func (boot *baseBootstrap) getNextHeaderRequestingIfMissing() (data.HeaderHandler, []byte, error) { nonce := boot.getNonceForNextBlock() boot.setRequestedHeaderHash(nil) @@ -947,7 +1018,8 @@ func (boot *baseBootstrap) getNextHeaderRequestingIfMissing() (data.HeaderHandle } if hash != nil { - return boot.blockBootstrapper.getHeaderWithHashRequestingIfMissing(hash) + header, err := boot.blockBootstrapper.getHeaderWithHashRequestingIfMissing(hash) + return header, hash, err } return boot.blockBootstrapper.getHeaderWithNonceRequestingIfMissing(nonce) diff --git a/process/sync/export_test.go b/process/sync/export_test.go index 719e7599f9f..16a91ead8b3 100644 --- a/process/sync/export_test.go +++ b/process/sync/export_test.go @@ -288,3 +288,11 @@ func (boot *baseBootstrap) IsInImportMode() bool { func (boot *baseBootstrap) ProcessWaitTime() time.Duration { return boot.processWaitTime } + +// HandleEquivalentProof - +func (boot *baseBootstrap) HandleEquivalentProof( + header data.HeaderHandler, + headerHash []byte, +) error { + return boot.handleEquivalentProof(header, headerHash) +} diff --git a/process/sync/interface.go b/process/sync/interface.go index 88f644df160..d672cafb88b 100644 --- a/process/sync/interface.go +++ b/process/sync/interface.go @@ -13,7 +13,7 @@ type blockBootstrapper interface { getPrevHeader(data.HeaderHandler, storage.Storer) 
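handleEquivalentProof above gates block processing on the presence of a proof for the header's hash: the flag must be active, the first block after activation is exempt, and when the proof is missing the code requests the following header and re-checks before giving up (presumably because the proof is delivered around the same time as that next header). A condensed, hypothetical version of the same decision flow, with stand-in types:

package proofgatesketch

import "fmt"

// proofsChecker is a stand-in for the ProofsPool dependency used by the sync code.
type proofsChecker interface {
	HasProof(shardID uint32, headerHash []byte) bool
}

type proofGate struct {
	proofs             proofsChecker
	flagEnabledInEpoch func(epoch uint32) bool
	requestNextHeader  func(nonce uint64) error
}

// ensureProof reproduces the gating rule in simplified form.
func (g *proofGate) ensureProof(shardID, epoch, prevEpoch uint32, nonce uint64, headerHash []byte) error {
	if !g.flagEnabledInEpoch(epoch) {
		return nil // feature not active for this header
	}
	if !g.flagEnabledInEpoch(prevEpoch) {
		return nil // first block after activation carries no proof to check
	}
	if g.proofs.HasProof(shardID, headerHash) {
		return nil
	}
	// missing proof: ask for the next header, then check once more
	if err := g.requestNextHeader(nonce + 1); err != nil {
		return err
	}
	if !g.proofs.HasProof(shardID, headerHash) {
		return fmt.Errorf("no equivalent proof for header %x", headerHash)
	}
	return nil
}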
(data.HeaderHandler, error) getBlockBody(headerHandler data.HeaderHandler) (data.BodyHandler, error) getHeaderWithHashRequestingIfMissing(hash []byte) (data.HeaderHandler, error) - getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, error) + getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, []byte, error) haveHeaderInPoolWithNonce(nonce uint64) bool getBlockBodyRequestingIfMissing(headerHandler data.HeaderHandler) (data.BodyHandler, error) isForkTriggeredByMeta() bool diff --git a/process/sync/metablock.go b/process/sync/metablock.go index 1b3c69c7386..72fc8a8688b 100644 --- a/process/sync/metablock.go +++ b/process/sync/metablock.go @@ -31,6 +31,9 @@ func NewMetaBootstrap(arguments ArgMetaBootstrapper) (*MetaBootstrap, error) { if check.IfNil(arguments.PoolsHolder.Headers()) { return nil, process.ErrNilMetaBlocksPool } + if check.IfNil(arguments.PoolsHolder.Proofs()) { + return nil, process.ErrNilProofsPool + } if check.IfNil(arguments.EpochBootstrapper) { return nil, process.ErrNilEpochStartTrigger } @@ -54,6 +57,7 @@ func NewMetaBootstrap(arguments ArgMetaBootstrapper) (*MetaBootstrap, error) { blockProcessor: arguments.BlockProcessor, store: arguments.Store, headers: arguments.PoolsHolder.Headers(), + proofs: arguments.PoolsHolder.Proofs(), roundHandler: arguments.RoundHandler, waitTime: arguments.WaitTime, hasher: arguments.Hasher, @@ -78,6 +82,7 @@ func NewMetaBootstrap(arguments ArgMetaBootstrapper) (*MetaBootstrap, error) { historyRepo: arguments.HistoryRepo, scheduledTxsExecutionHandler: arguments.ScheduledTxsExecutionHandler, processWaitTime: arguments.ProcessWaitTime, + enableEpochsHandler: arguments.EnableEpochsHandler, } if base.isInImportMode { @@ -243,8 +248,8 @@ func (boot *MetaBootstrap) requestHeaderWithHash(hash []byte) { // getHeaderWithNonceRequestingIfMissing method gets the header with a given nonce from pool. If it is not found there, it will // be requested from network -func (boot *MetaBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, error) { - hdr, _, err := process.GetMetaHeaderFromPoolWithNonce( +func (boot *MetaBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, []byte, error) { + hdr, hash, err := process.GetMetaHeaderFromPoolWithNonce( nonce, boot.headers) if err != nil { @@ -252,18 +257,18 @@ func (boot *MetaBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) ( boot.requestHeaderWithNonce(nonce) err = boot.waitForHeaderNonce() if err != nil { - return nil, err + return nil, nil, err } - hdr, _, err = process.GetMetaHeaderFromPoolWithNonce( + hdr, hash, err = process.GetMetaHeaderFromPoolWithNonce( nonce, boot.headers) if err != nil { - return nil, err + return nil, nil, err } } - return hdr, nil + return hdr, hash, nil } // getHeaderWithHashRequestingIfMissing method gets the header with a given hash from pool. 
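cleanProofsBehindFinal, added to baseSync.go just above, pairs the fork detector's highest final nonce with the pool's CleanupProofsBehindNonce so stored proofs do not grow without bound. A minimal sketch of that pairing, using stand-in interfaces for the two dependencies:

package proofcleanupsketch

import "log"

// proofsCleaner stands in for the CleanupProofsBehindNonce part of the proofs pool.
type proofsCleaner interface {
	CleanupProofsBehindNonce(shardID uint32, nonce uint64) error
}

// finalNonceProvider stands in for the fork detector's final-nonce query.
type finalNonceProvider interface {
	GetHighestFinalBlockNonce() uint64
}

// cleanProofsBehindFinal drops proofs older than the highest final nonce and
// only logs on failure, mirroring the behaviour of the function above.
func cleanProofsBehindFinal(shardID uint32, proofs proofsCleaner, forkDetector finalNonceProvider) {
	finalNonce := forkDetector.GetHighestFinalBlockNonce()
	if err := proofs.CleanupProofsBehindNonce(shardID, finalNonce); err != nil {
		log.Printf("cleanup proofs behind nonce %d failed: %v", finalNonce, err)
	}
}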
If it is not found there, diff --git a/process/sync/metablock_test.go b/process/sync/metablock_test.go index 6d183fbf821..73386a021f1 100644 --- a/process/sync/metablock_test.go +++ b/process/sync/metablock_test.go @@ -15,6 +15,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/round" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -24,14 +27,15 @@ import ( "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/outport" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMetaBlockProcessor(blk data.ChainHandler) *testscommon.BlockProcessorStub { @@ -92,6 +96,7 @@ func CreateMetaBootstrapMockArguments() sync.ArgMetaBootstrapper { ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessWaitTime: testProcessWaitTime, RepopulateTokensSupplies: false, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } argsMetaBootstrapper := sync.ArgMetaBootstrapper{ @@ -170,6 +175,22 @@ func TestNewMetaBootstrap_PoolsHolderRetNilOnHeadersShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilMetaBlocksPool, err) } +func TestNewMetaBootstrap_NilProofsPool(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return nil + } + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewMetaBootstrap_NilStoreShouldErr(t *testing.T) { t.Parallel() @@ -386,6 +407,34 @@ func TestNewMetaBootstrap_InvalidProcessTimeShouldErr(t *testing.T) { assert.True(t, errors.Is(err, process.ErrInvalidProcessWaitTime)) } +func TestNewMetaBootstrap_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + args.EnableEpochsHandler = nil + + bs, err := sync.NewMetaBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.True(t, errors.Is(err, process.ErrNilEnableEpochsHandler)) +} + +func TestNewMetaBootstrap_PoolsHolderRetNilOnProofsShouldErr(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return nil + } + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewMetaBootstrap_MissingStorer(t *testing.T) { t.Parallel() @@ -652,7 +701,7 @@ func TestMetaBootstrap_ShouldReturnNilErr(t *testing.T) { 
return sds } pools.MiniBlocksCalled = func() storage.Cacher { - sds := &testscommon.CacherStub{ + sds := &cache.CacherStub{ HasOrAddCalled: func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { return false, true }, @@ -1810,3 +1859,333 @@ func TestMetaBootstrap_SyncAccountsDBs(t *testing.T) { require.True(t, accountsSyncCalled) }) } + +func TestMetaBootstrap_HandleEquivalentProof(t *testing.T) { + t.Parallel() + + prevHeaderHash1 := []byte("prevHeaderHash") + headerHash1 := []byte("headerHash") + + t.Run("flag not activated, should return directly", func(t *testing.T) { + t.Parallel() + + header := &block.MetaBlock{ + Nonce: 11, + } + + args := CreateMetaBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + + bs, err := sync.NewMetaBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header, headerHash1) + require.Nil(t, err) + }) + + t.Run("should return nil if first block after activation", func(t *testing.T) { + t.Parallel() + + prevHeader := &block.MetaBlock{ + Epoch: 3, + Nonce: 10, + } + + header := &block.MetaBlock{ + Epoch: 4, + Nonce: 11, + PrevHash: prevHeaderHash1, + } + + args := CreateMetaBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + if epoch == 4 { + return flag == common.EquivalentMessagesFlag + } + + return false + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, prevHeaderHash1) { + return prevHeader, nil + } + + return nil, sync.ErrHeaderNotFound + } + + return sds + } + + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header, headerHash1) + require.Nil(t, err) + }) + + t.Run("should work, proof already in pool", func(t *testing.T) { + t.Parallel() + + prevHeader := &block.MetaBlock{ + Nonce: 10, + } + + header := &block.MetaBlock{ + Nonce: 11, + PrevHash: prevHeaderHash1, + } + + args := CreateMetaBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, prevHeaderHash1) { + return prevHeader, nil + } + + return nil, sync.ErrHeaderNotFound + } + + return sds + } + + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + } + } + + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header, headerHash1) + require.Nil(t, err) + }) + + t.Run("should work, by checking for next header", func(t *testing.T) { + t.Parallel() + + headerHash1 := []byte("headerHash1") + headerHash2 := []byte("headerHash2") + + header1 := &block.MetaBlock{ + Nonce: 10, + } + + header2 := &block.MetaBlock{
+ Nonce: 11, + PrevHash: headerHash1, + } + + header3 := &block.MetaBlock{ + Nonce: 12, + PrevHash: headerHash2, + } + + args := CreateMetaBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, headerHash1) { + return header1, nil + } + + return nil, sync.ErrHeaderNotFound + } + sds.GetHeaderByNonceAndShardIdCalled = func(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { + if hdrNonce == header2.GetNonce()+1 { + return []data.HeaderHandler{header3}, [][]byte{headerHash2}, nil + } + + return nil, nil, process.ErrMissingHeader + } + + return sds + } + + hasProofCalled := 0 + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if hasProofCalled == 0 { + hasProofCalled++ + return false + } + + return true + }, + } + } + + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header2, headerHash2) + require.Nil(t, err) + }) + + t.Run("should return err if failing to get proof after second request", func(t *testing.T) { + t.Parallel() + + headerHash1 := []byte("headerHash1") + headerHash2 := []byte("headerHash2") + + header1 := &block.MetaBlock{ + Nonce: 10, + } + + header2 := &block.MetaBlock{ + Nonce: 11, + PrevHash: headerHash1, + } + + header3 := &block.MetaBlock{ + Nonce: 12, + PrevHash: headerHash2, + } + + args := CreateMetaBootstrapMockArguments() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, headerHash1) { + return header1, nil + } + + return nil, sync.ErrHeaderNotFound + } + sds.GetHeaderByNonceAndShardIdCalled = func(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { + if hdrNonce == header2.GetNonce()+1 { + return []data.HeaderHandler{header3}, [][]byte{headerHash2}, nil + } + + return nil, nil, process.ErrMissingHeader + } + + return sds + } + + hasProofCalled := 0 + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if hasProofCalled < 2 { + hasProofCalled++ + return false + } + + return true + }, + } + } + + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header2, headerHash2) + require.Error(t, err) + }) + + t.Run("should return err if failing to request next header", func(t *testing.T) { + t.Parallel() + + headerHash1 := []byte("headerHash1") + headerHash2 := []byte("headerHash2") + + header1 := &block.MetaBlock{ + Nonce: 10, + } + + header2 := &block.MetaBlock{ + Nonce: 11, + PrevHash: headerHash1, + } + + args := CreateMetaBootstrapMockArguments() + 
args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.EquivalentMessagesFlag + }, + } + + pools := createMockPools() + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, headerHash1) { + return header1, nil + } + + return nil, sync.ErrHeaderNotFound + } + sds.GetHeaderByNonceAndShardIdCalled = func(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { + return nil, nil, process.ErrMissingHeader + } + + return sds + } + + hasProofCalled := 0 + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if hasProofCalled < 2 { + hasProofCalled++ + return false + } + + return true + }, + } + } + + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + require.Nil(t, err) + + err = bs.HandleEquivalentProof(header2, headerHash2) + require.Error(t, err) + }) +} diff --git a/process/sync/shardblock.go b/process/sync/shardblock.go index 8cca3954ef0..10a3492d024 100644 --- a/process/sync/shardblock.go +++ b/process/sync/shardblock.go @@ -27,6 +27,9 @@ func NewShardBootstrap(arguments ArgShardBootstrapper) (*ShardBootstrap, error) if check.IfNil(arguments.PoolsHolder.Headers()) { return nil, process.ErrNilHeadersDataPool } + if check.IfNil(arguments.PoolsHolder.Proofs()) { + return nil, process.ErrNilProofsPool + } if check.IfNil(arguments.PoolsHolder.MiniBlocks()) { return nil, process.ErrNilTxBlockBody } @@ -41,6 +44,7 @@ func NewShardBootstrap(arguments ArgShardBootstrapper) (*ShardBootstrap, error) blockProcessor: arguments.BlockProcessor, store: arguments.Store, headers: arguments.PoolsHolder.Headers(), + proofs: arguments.PoolsHolder.Proofs(), roundHandler: arguments.RoundHandler, waitTime: arguments.WaitTime, hasher: arguments.Hasher, @@ -66,6 +70,7 @@ func NewShardBootstrap(arguments ArgShardBootstrapper) (*ShardBootstrap, error) scheduledTxsExecutionHandler: arguments.ScheduledTxsExecutionHandler, processWaitTime: arguments.ProcessWaitTime, repopulateTokensSupplies: arguments.RepopulateTokensSupplies, + enableEpochsHandler: arguments.EnableEpochsHandler, } if base.isInImportMode { @@ -196,8 +201,8 @@ func (boot *ShardBootstrap) requestHeaderWithHash(hash []byte) { // getHeaderWithNonceRequestingIfMissing method gets the header with a given nonce from pool. 
If it is not found there, it will // be requested from network -func (boot *ShardBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, error) { - hdr, _, err := process.GetShardHeaderFromPoolWithNonce( +func (boot *ShardBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, []byte, error) { + hdr, hash, err := process.GetShardHeaderFromPoolWithNonce( nonce, boot.shardCoordinator.SelfId(), boot.headers) @@ -206,19 +211,19 @@ func (boot *ShardBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) boot.requestHeaderWithNonce(nonce) err = boot.waitForHeaderNonce() if err != nil { - return nil, err + return nil, nil, err } - hdr, _, err = process.GetShardHeaderFromPoolWithNonce( + hdr, hash, err = process.GetShardHeaderFromPoolWithNonce( nonce, boot.shardCoordinator.SelfId(), boot.headers) if err != nil { - return nil, err + return nil, nil, err } } - return hdr, nil + return hdr, hash, nil } // getHeaderWithHashRequestingIfMissing method gets the header with a given hash from pool. If it is not found there, diff --git a/process/sync/shardblock_test.go b/process/sync/shardblock_test.go index 070b926df0f..fbf974c1ee4 100644 --- a/process/sync/shardblock_test.go +++ b/process/sync/shardblock_test.go @@ -16,6 +16,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/round" @@ -28,15 +31,15 @@ import ( "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/outport" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // waitTime defines the time in milliseconds until node waits the requested info from the network @@ -55,7 +58,7 @@ func createMockPools() *dataRetrieverMock.PoolsHolderStub { return &mock.HeadersCacherStub{} } pools.MiniBlocksCalled = func() storage.Cacher { - cs := &testscommon.CacherStub{ + cs := &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -63,6 +66,9 @@ func createMockPools() *dataRetrieverMock.PoolsHolderStub { } return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } return pools } @@ -219,6 +225,7 @@ func CreateShardBootstrapMockArguments() sync.ArgShardBootstrapper { ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessWaitTime: testProcessWaitTime, RepopulateTokensSupplies: false, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } argsShardBootstrapper := sync.ArgShardBootstrapper{ @@ -270,6 +277,22 @@ func 
TestNewShardBootstrap_PoolsHolderRetNilOnHeadersShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilHeadersDataPool, err) } +func TestNewShardBootstrap_NilProofsPool(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return nil + } + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewShardBootstrap_PoolsHolderRetNilOnTxBlockBodyShouldErr(t *testing.T) { t.Parallel() @@ -442,6 +465,34 @@ func TestNewShardBootstrap_InvalidProcessTimeShouldErr(t *testing.T) { assert.True(t, errors.Is(err, process.ErrInvalidProcessWaitTime)) } +func TestNewShardBootstrap_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + args.EnableEpochsHandler = nil + + bs, err := sync.NewShardBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.True(t, errors.Is(err, process.ErrNilEnableEpochsHandler)) +} + +func TestNewShardBootstrap_PoolsHolderRetNilOnProofsShouldErr(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return nil + } + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewShardBootstrap_MissingStorer(t *testing.T) { t.Parallel() @@ -491,13 +542,17 @@ func TestNewShardBootstrap_OkValsShouldWork(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { wasCalled++ } return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } + args.PoolsHolder = pools args.IsInImportMode = true bs, err := sync.NewShardBootstrap(args) @@ -708,7 +763,7 @@ func TestBootstrap_SyncShouldSyncOneBlock(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -721,6 +776,10 @@ func TestBootstrap_SyncShouldSyncOneBlock(t *testing.T) { return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } + args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -803,7 +862,7 @@ func TestBootstrap_ShouldReturnNilErr(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -816,6 +875,9 @@ func TestBootstrap_ShouldReturnNilErr(t *testing.T) { return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -885,7 +947,7 @@ func TestBootstrap_SyncBlockShouldReturnErrorWhenProcessBlockFailed(t *testing.T return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i 
func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -898,6 +960,9 @@ func TestBootstrap_SyncBlockShouldReturnErrorWhenProcessBlockFailed(t *testing.T return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -1874,12 +1939,15 @@ func TestShardBootstrap_RequestMiniBlocksFromHeaderWithNonceIfMissing(t *testing return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } args.PoolsHolder = pools blkc := initBlockchain() @@ -2093,7 +2161,7 @@ func TestShardBootstrap_SyncBlockGetNodeDBErrorShouldSync(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -2106,6 +2174,9 @@ func TestShardBootstrap_SyncBlockGetNodeDBErrorShouldSync(t *testing.T) { return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -2144,9 +2215,10 @@ func TestShardBootstrap_SyncBlockGetNodeDBErrorShouldSync(t *testing.T) { return []byte("roothash"), nil }} - bs, _ := sync.NewShardBootstrap(args) + bs, err := sync.NewShardBootstrap(args) + require.Nil(t, err) - err := bs.SyncBlock(context.Background()) + err = bs.SyncBlock(context.Background()) assert.Equal(t, errGetNodeFromDB, err) assert.True(t, syncCalled) } diff --git a/process/throttle/antiflood/blackList/p2pBlackListProcessor_test.go b/process/throttle/antiflood/blackList/p2pBlackListProcessor_test.go index 0d5eee28a06..686b49031d1 100644 --- a/process/throttle/antiflood/blackList/p2pBlackListProcessor_test.go +++ b/process/throttle/antiflood/blackList/p2pBlackListProcessor_test.go @@ -7,16 +7,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/throttle/antiflood/blackList" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) const selfPid = "current pid" -//-------- NewP2PQuotaBlacklistProcessor +// -------- NewP2PQuotaBlacklistProcessor func TestNewP2PQuotaBlacklistProcessor_NilCacherShouldErr(t *testing.T) { t.Parallel() @@ -40,7 +42,7 @@ func TestNewP2PQuotaBlacklistProcessor_NilBlackListHandlerShouldErr(t *testing.T t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), nil, 1, 1, @@ -58,7 +60,7 @@ func TestNewP2PQuotaBlacklistProcessor_InvalidThresholdNumReceivedFloodShouldErr t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 0, 1, @@ -76,7 +78,7 @@ func TestNewP2PQuotaBlacklistProcessor_InvalidThresholdSizeReceivedFloodShouldEr t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + 
cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 1, 0, @@ -94,7 +96,7 @@ func TestNewP2PQuotaBlacklistProcessor_InvalidNumFloodingRoundsShouldErr(t *test t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 1, 1, @@ -112,7 +114,7 @@ func TestNewP2PQuotaBlacklistProcessor_InvalidBanDurationShouldErr(t *testing.T) t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 1, 1, @@ -130,7 +132,7 @@ func TestNewP2PQuotaBlacklistProcessor_ShouldWork(t *testing.T) { t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 1, 1, @@ -144,7 +146,7 @@ func TestNewP2PQuotaBlacklistProcessor_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- AddQuota +// ------- AddQuota func TestP2PQuotaBlacklistProcessor_AddQuotaUnderThresholdShouldNotCallGetOrPut(t *testing.T) { t.Parallel() @@ -153,7 +155,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaUnderThresholdShouldNotCallGetOrPut( thresholdSize := uint64(20) pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { assert.Fail(t, "should not have called get") return nil, false @@ -184,7 +186,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaOverThresholdInexistentDataOnGetShou putCalled := false identifier := core.PeerID("identifier") pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { return nil, false }, @@ -219,7 +221,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaOverThresholdDataNotValidOnGetShould putCalled := false identifier := core.PeerID("identifier") pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { return "invalid data", true }, @@ -255,7 +257,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaShouldIncrement(t *testing.T) { identifier := core.PeerID("identifier") existingValue := uint32(445) pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { return existingValue, true }, @@ -290,7 +292,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaForSelfShouldNotIncrement(t *testing putCalled := false existingValue := uint32(445) pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { return existingValue, true }, @@ -313,7 +315,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaForSelfShouldNotIncrement(t *testing assert.False(t, putCalled) } -//------- ResetStatistics +// ------- ResetStatistics func TestP2PQuotaBlacklistProcessor_ResetStatisticsRemoveNilValueKey(t *testing.T) { t.Parallel() @@ -324,7 +326,7 @@ func TestP2PQuotaBlacklistProcessor_ResetStatisticsRemoveNilValueKey(t *testing. 
nilValKey := "nil val key" removedCalled := false pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ KeysCalled: func() [][]byte { return [][]byte{[]byte(nilValKey)} }, @@ -360,7 +362,7 @@ func TestP2PQuotaBlacklistProcessor_ResetStatisticsShouldRemoveInvalidValueKey(t invalidValKey := "invalid val key" removedCalled := false pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ KeysCalled: func() [][]byte { return [][]byte{[]byte(invalidValKey)} }, @@ -399,7 +401,7 @@ func TestP2PQuotaBlacklistProcessor_ResetStatisticsUnderNumFloodingRoundsShouldN upsertCalled := false duration := time.Second * 3892 pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ KeysCalled: func() [][]byte { return [][]byte{[]byte(key)} }, @@ -444,7 +446,7 @@ func TestP2PQuotaBlacklistProcessor_ResetStatisticsOverNumFloodingRoundsShouldBl upsertCalled := false duration := time.Second * 3892 pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ KeysCalled: func() [][]byte { return [][]byte{[]byte(key)} }, diff --git a/process/throttle/antiflood/floodPreventers/quotaFloodPreventer_test.go b/process/throttle/antiflood/floodPreventers/quotaFloodPreventer_test.go index 068ba97591d..5dc21b68e35 100644 --- a/process/throttle/antiflood/floodPreventers/quotaFloodPreventer_test.go +++ b/process/throttle/antiflood/floodPreventers/quotaFloodPreventer_test.go @@ -9,16 +9,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) func createDefaultArgument() ArgQuotaFloodPreventer { return ArgQuotaFloodPreventer{ Name: "test", - Cacher: testscommon.NewCacherStub(), + Cacher: cache.NewCacherStub(), StatusHandlers: []QuotaStatusHandler{&mock.QuotaStatusHandlerStub{}}, BaseMaxNumMessagesPerPeer: minMessages, MaxTotalSizePerPeer: minTotalSize, @@ -28,7 +30,7 @@ func createDefaultArgument() ArgQuotaFloodPreventer { } } -//------- NewQuotaFloodPreventer +// ------- NewQuotaFloodPreventer func TestNewQuotaFloodPreventer_NilCacherShouldErr(t *testing.T) { t.Parallel() @@ -128,7 +130,7 @@ func TestNewQuotaFloodPreventer_NilListShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- IncreaseLoad +// ------- IncreaseLoad func TestNewQuotaFloodPreventer_IncreaseLoadIdentifierNotPresentPutQuotaAndReturnTrue(t *testing.T) { t.Parallel() @@ -136,7 +138,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadIdentifierNotPresentPutQuotaAndRetur putWasCalled := false size := uint64(minTotalSize * 5) arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -168,7 +170,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadNotQuotaSavedInCacheShouldPutQuotaAn putWasCalled := false size := uint64(minTotalSize * 5) arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return "bad value", true }, @@ -205,7 +207,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadUnderMaxValuesShouldIncrementAndRetu } size := uint64(minTotalSize * 2) arg := createDefaultArgument() - arg.Cacher 
= &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return existingQuota, true }, @@ -219,7 +221,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadUnderMaxValuesShouldIncrementAndRetu assert.Nil(t, err) } -//------- IncreaseLoad per peer +// ------- IncreaseLoad per peer func TestNewQuotaFloodPreventer_IncreaseLoadOverMaxPeerNumMessagesShouldNotPutAndReturnFalse(t *testing.T) { t.Parallel() @@ -231,7 +233,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadOverMaxPeerNumMessagesShouldNotPutAn sizeReceivedMessages: existingSize, } arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return existingQuota, true }, @@ -260,7 +262,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadOverMaxPeerSizeShouldNotPutAndReturn sizeReceivedMessages: existingSize, } arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return existingQuota, true }, @@ -284,7 +286,7 @@ func TestCountersMap_IncreaseLoadShouldWorkConcurrently(t *testing.T) { numIterations := 1000 arg := createDefaultArgument() - arg.Cacher = testscommon.NewCacherMock() + arg.Cacher = cache.NewCacherMock() qfp, _ := NewQuotaFloodPreventer(arg) wg := sync.WaitGroup{} wg.Add(numIterations) @@ -299,14 +301,14 @@ wg.Wait() } -//------- Reset +// ------- Reset func TestCountersMap_ResetShouldCallCacherClear(t *testing.T) { t.Parallel() clearCalled := false arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ ClearCalled: func() { clearCalled = true }, @@ -324,7 +326,7 @@ func TestCountersMap_ResetShouldCallCacherClear(t *testing.T) { func TestCountersMap_ResetShouldCallQuotaStatus(t *testing.T) { t.Parallel() - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() key1 := core.PeerID("key1") quota1 := &quota{ numReceivedMessages: 1, @@ -391,7 +393,7 @@ func TestCountersMap_IncrementAndResetShouldWorkConcurrently(t *testing.T) { numIterations := 1000 arg := createDefaultArgument() - arg.Cacher = testscommon.NewCacherMock() + arg.Cacher = cache.NewCacherMock() qfp, _ := NewQuotaFloodPreventer(arg) wg := sync.WaitGroup{} wg.Add(numIterations + numIterations/10) @@ -418,7 +420,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadWithMockCacherShouldWork(t *testing. numMessages := uint32(100) arg := createDefaultArgument() - arg.Cacher = testscommon.NewCacherMock() + arg.Cacher = cache.NewCacherMock() arg.BaseMaxNumMessagesPerPeer = numMessages arg.MaxTotalSizePerPeer = math.MaxUint64 arg.PercentReserved = float32(17) @@ -437,7 +439,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadWithMockCacherShouldWork(t *testing.
} } -//------- ApplyConsensusSize +// ------- ApplyConsensusSize func TestQuotaFloodPreventer_ApplyConsensusSizeInvalidConsensusSize(t *testing.T) { t.Parallel() @@ -468,7 +470,7 @@ func TestQuotaFloodPreventer_ApplyConsensusShouldWork(t *testing.T) { t.Parallel() arg := createDefaultArgument() - arg.Cacher = testscommon.NewCacherMock() + arg.Cacher = cache.NewCacherMock() arg.BaseMaxNumMessagesPerPeer = 2000 arg.IncreaseThreshold = 1000 arg.IncreaseFactor = 0.25 diff --git a/process/track/argBlockProcessor.go b/process/track/argBlockProcessor.go index 0b7b02b20c9..60a4b17edf5 100644 --- a/process/track/argBlockProcessor.go +++ b/process/track/argBlockProcessor.go @@ -1,6 +1,9 @@ package track import ( + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) @@ -19,4 +22,8 @@ type ArgBlockProcessor struct { SelfNotarizedHeadersNotifier blockNotifierHandler FinalMetachainHeadersNotifier blockNotifierHandler RoundHandler process.RoundHandler + EnableEpochsHandler common.EnableEpochsHandler + ProofsPool process.ProofsPool + Marshaller marshal.Marshalizer + Hasher hashing.Hasher } diff --git a/process/track/argBlockTrack.go b/process/track/argBlockTrack.go index ea655d3937b..c44bb6254b7 100644 --- a/process/track/argBlockTrack.go +++ b/process/track/argBlockTrack.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" @@ -12,17 +13,19 @@ import ( // ArgBaseTracker holds all dependencies required by the process data factory in order to create // new instances of shard/meta block tracker type ArgBaseTracker struct { - Hasher hashing.Hasher - HeaderValidator process.HeaderConstructionValidator - Marshalizer marshal.Marshalizer - RequestHandler process.RequestHandler - RoundHandler process.RoundHandler - ShardCoordinator sharding.Coordinator - Store dataRetriever.StorageService - StartHeaders map[uint32]data.HeaderHandler - PoolsHolder dataRetriever.PoolsHolder - WhitelistHandler process.WhiteListHandler - FeeHandler process.FeeHandler + Hasher hashing.Hasher + HeaderValidator process.HeaderConstructionValidator + Marshalizer marshal.Marshalizer + RequestHandler process.RequestHandler + RoundHandler process.RoundHandler + ShardCoordinator sharding.Coordinator + Store dataRetriever.StorageService + StartHeaders map[uint32]data.HeaderHandler + PoolsHolder dataRetriever.PoolsHolder + WhitelistHandler process.WhiteListHandler + FeeHandler process.FeeHandler + EnableEpochsHandler common.EnableEpochsHandler + ProofsPool process.ProofsPool } // ArgShardTracker holds all dependencies required by the process data factory in order to create diff --git a/process/track/baseBlockTrack_test.go b/process/track/baseBlockTrack_test.go index 8c919cd9ee7..b32b943faf9 100644 --- a/process/track/baseBlockTrack_test.go +++ b/process/track/baseBlockTrack_test.go @@ -22,6 +22,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" 
"github.com/multiversx/mx-chain-go/testscommon/hashingMocks" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" @@ -122,17 +123,19 @@ func CreateShardTrackerMockArguments() track.ArgShardTracker { arguments := track.ArgShardTracker{ ArgBaseTracker: track.ArgBaseTracker{ - Hasher: &hashingMocks.HasherMock{}, - HeaderValidator: headerValidator, - Marshalizer: &mock.MarshalizerMock{}, - RequestHandler: &testscommon.RequestHandlerStub{}, - RoundHandler: &mock.RoundHandlerMock{}, - ShardCoordinator: shardCoordinatorMock, - Store: initStore(), - StartHeaders: genesisBlocks, - PoolsHolder: dataRetrieverMock.NewPoolsHolderMock(), - WhitelistHandler: whitelistHandler, - FeeHandler: feeHandler, + Hasher: &hashingMocks.HasherMock{}, + HeaderValidator: headerValidator, + Marshalizer: &mock.MarshalizerMock{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + RoundHandler: &mock.RoundHandlerMock{}, + ShardCoordinator: shardCoordinatorMock, + Store: initStore(), + StartHeaders: genesisBlocks, + PoolsHolder: dataRetrieverMock.NewPoolsHolderMock(), + WhitelistHandler: whitelistHandler, + FeeHandler: feeHandler, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ProofsPool: &dataRetrieverMock.ProofsPoolMock{}, }, } @@ -160,17 +163,19 @@ func CreateMetaTrackerMockArguments() track.ArgMetaTracker { arguments := track.ArgMetaTracker{ ArgBaseTracker: track.ArgBaseTracker{ - Hasher: &hashingMocks.HasherMock{}, - HeaderValidator: headerValidator, - Marshalizer: &mock.MarshalizerMock{}, - RequestHandler: &testscommon.RequestHandlerStub{}, - RoundHandler: &mock.RoundHandlerMock{}, - ShardCoordinator: shardCoordinatorMock, - Store: initStore(), - StartHeaders: genesisBlocks, - PoolsHolder: dataRetrieverMock.NewPoolsHolderMock(), - WhitelistHandler: whitelistHandler, - FeeHandler: feeHandler, + Hasher: &hashingMocks.HasherMock{}, + HeaderValidator: headerValidator, + Marshalizer: &mock.MarshalizerMock{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + RoundHandler: &mock.RoundHandlerMock{}, + ShardCoordinator: shardCoordinatorMock, + Store: initStore(), + StartHeaders: genesisBlocks, + PoolsHolder: dataRetrieverMock.NewPoolsHolderMock(), + WhitelistHandler: whitelistHandler, + FeeHandler: feeHandler, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ProofsPool: &dataRetrieverMock.ProofsPoolMock{}, }, } @@ -195,15 +200,17 @@ func CreateBaseTrackerMockArguments() track.ArgBaseTracker { } arguments := track.ArgBaseTracker{ - Hasher: &hashingMocks.HasherMock{}, - HeaderValidator: headerValidator, - Marshalizer: &mock.MarshalizerMock{}, - RequestHandler: &testscommon.RequestHandlerStub{}, - RoundHandler: &mock.RoundHandlerMock{}, - ShardCoordinator: shardCoordinatorMock, - Store: initStore(), - StartHeaders: genesisBlocks, - FeeHandler: feeHandler, + Hasher: &hashingMocks.HasherMock{}, + HeaderValidator: headerValidator, + Marshalizer: &mock.MarshalizerMock{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + RoundHandler: &mock.RoundHandlerMock{}, + ShardCoordinator: shardCoordinatorMock, + Store: initStore(), + StartHeaders: genesisBlocks, + FeeHandler: feeHandler, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ProofsPool: &dataRetrieverMock.ProofsPoolMock{}, } return arguments diff --git a/process/track/blockProcessor.go b/process/track/blockProcessor.go index e24ff02e35d..11b1d9aef3f 100644 --- a/process/track/blockProcessor.go +++ b/process/track/blockProcessor.go @@ -4,9 +4,11 @@ 
import ( "sort" "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) @@ -25,6 +27,11 @@ type blockProcessor struct { finalMetachainHeadersNotifier blockNotifierHandler roundHandler process.RoundHandler + enableEpochsHandler common.EnableEpochsHandler + proofsPool process.ProofsPool + marshaller marshal.Marshalizer + hasher hashing.Hasher + blockFinality uint64 } @@ -47,6 +54,10 @@ func NewBlockProcessor(arguments ArgBlockProcessor) (*blockProcessor, error) { selfNotarizedHeadersNotifier: arguments.SelfNotarizedHeadersNotifier, finalMetachainHeadersNotifier: arguments.FinalMetachainHeadersNotifier, roundHandler: arguments.RoundHandler, + enableEpochsHandler: arguments.EnableEpochsHandler, + proofsPool: arguments.ProofsPool, + marshaller: arguments.Marshaller, + hasher: arguments.Hasher, } bp.blockFinality = process.BlockFinality @@ -154,7 +165,7 @@ func (bp *blockProcessor) doJobOnReceivedMetachainHeader() { } } - sortedHeaders, _ := bp.blockTracker.SortHeadersFromNonce(core.MetachainShardId, header.GetNonce()+1) + sortedHeaders, sortedHeadersHashes := bp.blockTracker.SortHeadersFromNonce(core.MetachainShardId, header.GetNonce()+1) if len(sortedHeaders) == 0 { return } @@ -162,7 +173,7 @@ func (bp *blockProcessor) doJobOnReceivedMetachainHeader() { finalMetachainHeaders := make([]data.HeaderHandler, 0) finalMetachainHeadersHashes := make([][]byte, 0) - err = bp.checkHeaderFinality(header, sortedHeaders, 0) + err = bp.checkHeaderFinality(header, sortedHeaders, sortedHeadersHashes, 0) if err == nil { finalMetachainHeaders = append(finalMetachainHeaders, header) finalMetachainHeadersHashes = append(finalMetachainHeadersHashes, headerHash) @@ -234,14 +245,15 @@ func (bp *blockProcessor) ComputeLongestChain(shardID uint32, header data.Header go bp.requestHeadersIfNeeded(header, sortedHeaders, headers) }() - sortedHeaders, sortedHeadersHashes = bp.blockTracker.SortHeadersFromNonce(shardID, header.GetNonce()+1) + startingNonce := header.GetNonce() + 1 + sortedHeaders, sortedHeadersHashes = bp.blockTracker.SortHeadersFromNonce(shardID, startingNonce) if len(sortedHeaders) == 0 { return headers, headersHashes } longestChainHeadersIndexes := make([]int, 0) headersIndexes := make([]int, 0) - bp.getNextHeader(&longestChainHeadersIndexes, headersIndexes, header, sortedHeaders, 0) + bp.getNextHeader(&longestChainHeadersIndexes, headersIndexes, header, sortedHeaders, sortedHeadersHashes, 0) for _, index := range longestChainHeadersIndexes { headers = append(headers, sortedHeaders[index]) @@ -256,6 +268,7 @@ func (bp *blockProcessor) getNextHeader( headersIndexes []int, prevHeader data.HeaderHandler, sortedHeaders []data.HeaderHandler, + sortedHeadersHashes [][]byte, index int, ) { defer func() { @@ -279,13 +292,13 @@ func (bp *blockProcessor) getNextHeader( continue } - err = bp.checkHeaderFinality(currHeader, sortedHeaders, i+1) + err = bp.checkHeaderFinality(currHeader, sortedHeaders, sortedHeadersHashes, i+1) if err != nil { continue } headersIndexes = append(headersIndexes, i) - bp.getNextHeader(longestChainHeadersIndexes, headersIndexes, currHeader, sortedHeaders, i+1) + bp.getNextHeader(longestChainHeadersIndexes, headersIndexes, currHeader, sortedHeaders, 
sortedHeadersHashes, i+1) headersIndexes = headersIndexes[:len(headersIndexes)-1] } } @@ -293,6 +306,7 @@ func (bp *blockProcessor) getNextHeader( func (bp *blockProcessor) checkHeaderFinality( header data.HeaderHandler, sortedHeaders []data.HeaderHandler, + sortedHeadersHashes [][]byte, index int, ) error { @@ -300,6 +314,14 @@ func (bp *blockProcessor) checkHeaderFinality( return process.ErrNilBlockHeader } + if common.IsFlagEnabledAfterEpochsStartBlock(header, bp.enableEpochsHandler, common.EquivalentMessagesFlag) { + if bp.proofsPool.HasProof(header.GetShardID(), sortedHeadersHashes[index]) { + return nil + } + + return process.ErrHeaderNotFinal + } + prevHeader := header numFinalityAttestingHeaders := uint64(0) @@ -484,6 +506,18 @@ func checkBlockProcessorNilParameters(arguments ArgBlockProcessor) error { if check.IfNil(arguments.RoundHandler) { return ErrNilRoundHandler } + if check.IfNil(arguments.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } + if check.IfNil(arguments.ProofsPool) { + return ErrNilProofsPool + } + if check.IfNil(arguments.Marshaller) { + return process.ErrNilMarshalizer + } + if check.IfNilReflect(arguments.Hasher) { + return process.ErrNilHasher + } return nil } diff --git a/process/track/blockProcessor_test.go b/process/track/blockProcessor_test.go index ad30bd35e06..05d6275047f 100644 --- a/process/track/blockProcessor_test.go +++ b/process/track/blockProcessor_test.go @@ -8,6 +8,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-core-go/data" @@ -56,7 +58,11 @@ func CreateBlockProcessorMockArguments() track.ArgBlockProcessor { return 1 }, }, - RoundHandler: &mock.RoundHandlerMock{}, + RoundHandler: &mock.RoundHandlerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + Marshaller: &testscommon.MarshallerStub{}, + Hasher: &hashingMocks.HasherMock{}, } return arguments @@ -172,6 +178,50 @@ func TestNewBlockProcessor_ShouldErrFinalMetachainHeadersNotifier(t *testing.T) assert.Nil(t, bp) } +func TestNewBlockProcessor_ShouldErrNilEnableEpochsHandler(t *testing.T) { + t.Parallel() + + blockProcessorArguments := CreateBlockProcessorMockArguments() + blockProcessorArguments.EnableEpochsHandler = nil + bp, err := track.NewBlockProcessor(blockProcessorArguments) + + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) + assert.Nil(t, bp) +} + +func TestNewBlockProcessor_ShouldErrNilProofsPool(t *testing.T) { + t.Parallel() + + blockProcessorArguments := CreateBlockProcessorMockArguments() + blockProcessorArguments.ProofsPool = nil + bp, err := track.NewBlockProcessor(blockProcessorArguments) + + assert.Equal(t, track.ErrNilProofsPool, err) + assert.Nil(t, bp) +} + +func TestNewBlockProcessor_ShouldErrNilMarshaller(t *testing.T) { + t.Parallel() + + blockProcessorArguments := CreateBlockProcessorMockArguments() + blockProcessorArguments.Marshaller = nil + bp, err := track.NewBlockProcessor(blockProcessorArguments) + + assert.Equal(t, process.ErrNilMarshalizer, err) + assert.Nil(t, bp) +} + +func TestNewBlockProcessor_ShouldErrNilHasher(t *testing.T) { + t.Parallel() + + blockProcessorArguments := CreateBlockProcessorMockArguments() + blockProcessorArguments.Hasher = nil + bp, err := 
track.NewBlockProcessor(blockProcessorArguments) + + assert.Equal(t, process.ErrNilHasher, err) + assert.Nil(t, bp) +} + func TestNewBlockProcessor_ShouldErrNilRoundHandler(t *testing.T) { t.Parallel() @@ -553,7 +603,7 @@ func TestGetNextHeader_ShouldReturnEmptySliceWhenPrevHeaderIsNil(t *testing.T) { longestChainHeadersIndexes := make([]int, 0) headersIndexes := make([]int, 0) sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 1}} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, nil, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, nil, sortedHeaders, [][]byte{}, 0) assert.Equal(t, 0, len(longestChainHeadersIndexes)) } @@ -568,7 +618,7 @@ func TestGetNextHeader_ShouldReturnEmptySliceWhenSortedHeadersHaveHigherNonces(t headersIndexes := make([]int, 0) prevHeader := &dataBlock.Header{} sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 2}} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, [][]byte{}, 0) assert.Equal(t, 0, len(longestChainHeadersIndexes)) } @@ -583,7 +633,7 @@ func TestGetNextHeader_ShouldReturnEmptySliceWhenHeaderConstructionIsNotValid(t headersIndexes := make([]int, 0) prevHeader := &dataBlock.Header{} sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 1}} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, [][]byte{}, 0) assert.Equal(t, 0, len(longestChainHeadersIndexes)) } @@ -614,7 +664,7 @@ func TestGetNextHeader_ShouldReturnEmptySliceWhenHeaderFinalityIsNotChecked(t *t } sortedHeaders := []data.HeaderHandler{header2} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, header1, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, header1, sortedHeaders, [][]byte{}, 0) assert.Equal(t, 0, len(longestChainHeadersIndexes)) } @@ -653,7 +703,7 @@ func TestGetNextHeader_ShouldWork(t *testing.T) { } sortedHeaders := []data.HeaderHandler{header2, header3} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, header1, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, header1, sortedHeaders, [][]byte{}, 0) require.Equal(t, 1, len(longestChainHeadersIndexes)) assert.Equal(t, 0, longestChainHeadersIndexes[0]) @@ -666,7 +716,7 @@ func TestCheckHeaderFinality_ShouldErrNilBlockHeader(t *testing.T) { bp, _ := track.NewBlockProcessor(blockProcessorArguments) sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 1}} - err := bp.CheckHeaderFinality(nil, sortedHeaders, 0) + err := bp.CheckHeaderFinality(nil, sortedHeaders, [][]byte{}, 0) assert.Equal(t, process.ErrNilBlockHeader, err) } @@ -679,7 +729,7 @@ func TestCheckHeaderFinality_ShouldErrHeaderNotFinal(t *testing.T) { header := &dataBlock.Header{} sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 1}} - err := bp.CheckHeaderFinality(header, sortedHeaders, 0) + err := bp.CheckHeaderFinality(header, sortedHeaders, [][]byte{}, 0) assert.Equal(t, process.ErrHeaderNotFinal, err) } @@ -707,7 +757,7 @@ func TestCheckHeaderFinality_ShouldWork(t *testing.T) { } sortedHeaders := []data.HeaderHandler{header2} - err := bp.CheckHeaderFinality(header1, sortedHeaders, 0) + err := bp.CheckHeaderFinality(header1, sortedHeaders, [][]byte{}, 0) assert.Nil(t, err) } diff --git a/process/track/errors.go 
b/process/track/errors.go index 2c9a3a5c297..220863ce86e 100644 --- a/process/track/errors.go +++ b/process/track/errors.go @@ -33,3 +33,6 @@ var ErrNilRoundHandler = errors.New("nil roundHandler") // ErrNilKeysHandler signals that a nil keys handler was provided var ErrNilKeysHandler = errors.New("nil keys handler") + +// ErrNilProofsPool signals that a nil proofs pool has been provided +var ErrNilProofsPool = errors.New("nil proofs pool") diff --git a/process/track/export_test.go b/process/track/export_test.go index 8a2752afb2c..8cbcccb2919 100644 --- a/process/track/export_test.go +++ b/process/track/export_test.go @@ -11,70 +11,86 @@ import ( // shardBlockTrack +// SetNumPendingMiniBlocks - func (sbt *shardBlockTrack) SetNumPendingMiniBlocks(shardID uint32, numPendingMiniBlocks uint32) { sbt.blockBalancer.SetNumPendingMiniBlocks(shardID, numPendingMiniBlocks) } +// GetNumPendingMiniBlocks - func (sbt *shardBlockTrack) GetNumPendingMiniBlocks(shardID uint32) uint32 { return sbt.blockBalancer.GetNumPendingMiniBlocks(shardID) } +// SetLastShardProcessedMetaNonce - func (sbt *shardBlockTrack) SetLastShardProcessedMetaNonce(shardID uint32, nonce uint64) { sbt.blockBalancer.SetLastShardProcessedMetaNonce(shardID, nonce) } +// GetLastShardProcessedMetaNonce - func (sbt *shardBlockTrack) GetLastShardProcessedMetaNonce(shardID uint32) uint64 { return sbt.blockBalancer.GetLastShardProcessedMetaNonce(shardID) } +// GetTrackedShardHeaderWithNonceAndHash - func (sbt *shardBlockTrack) GetTrackedShardHeaderWithNonceAndHash(shardID uint32, nonce uint64, hash []byte) (data.HeaderHandler, error) { return sbt.getTrackedShardHeaderWithNonceAndHash(shardID, nonce, hash) } // metaBlockTrack +// GetTrackedMetaBlockWithHash - func (mbt *metaBlockTrack) GetTrackedMetaBlockWithHash(hash []byte) (*block.MetaBlock, error) { return mbt.getTrackedMetaBlockWithHash(hash) } // baseBlockTrack +// ReceivedHeader - func (bbt *baseBlockTrack) ReceivedHeader(headerHandler data.HeaderHandler, headerHash []byte) { bbt.receivedHeader(headerHandler, headerHash) } +// CheckTrackerNilParameters - func CheckTrackerNilParameters(arguments ArgBaseTracker) error { return checkTrackerNilParameters(arguments) } +// InitNotarizedHeaders - func (bbt *baseBlockTrack) InitNotarizedHeaders(startHeaders map[uint32]data.HeaderHandler) error { return bbt.initNotarizedHeaders(startHeaders) } +// ReceivedShardHeader - func (bbt *baseBlockTrack) ReceivedShardHeader(headerHandler data.HeaderHandler, shardHeaderHash []byte) { bbt.receivedShardHeader(headerHandler, shardHeaderHash) } +// ReceivedMetaBlock - func (bbt *baseBlockTrack) ReceivedMetaBlock(headerHandler data.HeaderHandler, metaBlockHash []byte) { bbt.receivedMetaBlock(headerHandler, metaBlockHash) } +// GetMaxNumHeadersToKeepPerShard - func (bbt *baseBlockTrack) GetMaxNumHeadersToKeepPerShard() int { return bbt.maxNumHeadersToKeepPerShard } +// ShouldAddHeaderForCrossShard - func (bbt *baseBlockTrack) ShouldAddHeaderForCrossShard(headerHandler data.HeaderHandler) bool { return bbt.shouldAddHeaderForShard(headerHandler, bbt.crossNotarizer, headerHandler.GetShardID()) } +// ShouldAddHeaderForSelfShard - func (bbt *baseBlockTrack) ShouldAddHeaderForSelfShard(headerHandler data.HeaderHandler) bool { return bbt.shouldAddHeaderForShard(headerHandler, bbt.selfNotarizer, core.MetachainShardId) } +// AddHeader - func (bbt *baseBlockTrack) AddHeader(header data.HeaderHandler, hash []byte) bool { return bbt.addHeader(header, hash) } +// AppendTrackedHeader - func (bbt *baseBlockTrack) 
AppendTrackedHeader(headerHandler data.HeaderHandler) { bbt.mutHeaders.Lock() headersForShard, ok := bbt.headers[headerHandler.GetShardID()] @@ -87,48 +103,59 @@ func (bbt *baseBlockTrack) AppendTrackedHeader(headerHandler data.HeaderHandler) bbt.mutHeaders.Unlock() } +// CleanupTrackedHeadersBehindNonce - func (bbt *baseBlockTrack) CleanupTrackedHeadersBehindNonce(shardID uint32, nonce uint64) { bbt.cleanupTrackedHeadersBehindNonce(shardID, nonce) } +// DisplayTrackedHeadersForShard - func (bbt *baseBlockTrack) DisplayTrackedHeadersForShard(shardID uint32, message string) { bbt.displayTrackedHeadersForShard(shardID, message) } +// SetRoundHandler - func (bbt *baseBlockTrack) SetRoundHandler(roundHandler process.RoundHandler) { bbt.roundHandler = roundHandler } +// SetCrossNotarizer - func (bbt *baseBlockTrack) SetCrossNotarizer(notarizer blockNotarizerHandler) { bbt.crossNotarizer = notarizer } +// SetSelfNotarizer - func (bbt *baseBlockTrack) SetSelfNotarizer(notarizer blockNotarizerHandler) { bbt.selfNotarizer = notarizer } +// SetShardCoordinator - func (bbt *baseBlockTrack) SetShardCoordinator(coordinator sharding.Coordinator) { bbt.shardCoordinator = coordinator } +// NewBaseBlockTrack - func NewBaseBlockTrack() *baseBlockTrack { return &baseBlockTrack{} } +// DoWhitelistWithMetaBlockIfNeeded - func (bbt *baseBlockTrack) DoWhitelistWithMetaBlockIfNeeded(metaBlock *block.MetaBlock) { bbt.doWhitelistWithMetaBlockIfNeeded(metaBlock) } +// DoWhitelistWithShardHeaderIfNeeded - func (bbt *baseBlockTrack) DoWhitelistWithShardHeaderIfNeeded(shardHeader *block.Header) { bbt.doWhitelistWithShardHeaderIfNeeded(shardHeader) } +// IsHeaderOutOfRange - func (bbt *baseBlockTrack) IsHeaderOutOfRange(headerHandler data.HeaderHandler) bool { return bbt.isHeaderOutOfRange(headerHandler) } // blockNotifier +// GetNotarizedHeadersHandlers - func (bn *blockNotifier) GetNotarizedHeadersHandlers() []func(shardID uint32, headers []data.HeaderHandler, headersHashes [][]byte) { bn.mutNotarizedHeadersHandlers.RLock() notarizedHeadersHandlers := bn.notarizedHeadersHandlers @@ -139,12 +166,14 @@ func (bn *blockNotifier) GetNotarizedHeadersHandlers() []func(shardID uint32, he // blockNotarizer +// AppendNotarizedHeader - func (bn *blockNotarizer) AppendNotarizedHeader(headerHandler data.HeaderHandler) { bn.mutNotarizedHeaders.Lock() bn.notarizedHeaders[headerHandler.GetShardID()] = append(bn.notarizedHeaders[headerHandler.GetShardID()], &HeaderInfo{Header: headerHandler}) bn.mutNotarizedHeaders.Unlock() } +// GetNotarizedHeaders - func (bn *blockNotarizer) GetNotarizedHeaders() map[uint32][]*HeaderInfo { bn.mutNotarizedHeaders.RLock() notarizedHeaders := bn.notarizedHeaders @@ -153,6 +182,7 @@ func (bn *blockNotarizer) GetNotarizedHeaders() map[uint32][]*HeaderInfo { return notarizedHeaders } +// GetNotarizedHeaderWithIndex - func (bn *blockNotarizer) GetNotarizedHeaderWithIndex(shardID uint32, index int) data.HeaderHandler { bn.mutNotarizedHeaders.RLock() notarizedHeader := bn.notarizedHeaders[shardID][index].Header @@ -161,70 +191,98 @@ func (bn *blockNotarizer) GetNotarizedHeaderWithIndex(shardID uint32, index int) return notarizedHeader } +// LastNotarizedHeaderInfo - func (bn *blockNotarizer) LastNotarizedHeaderInfo(shardID uint32) *HeaderInfo { return bn.lastNotarizedHeaderInfo(shardID) } // blockProcessor +// DoJobOnReceivedHeader - func (bp *blockProcessor) DoJobOnReceivedHeader(shardID uint32) { bp.doJobOnReceivedHeader(shardID) } +// DoJobOnReceivedCrossNotarizedHeader - func (bp *blockProcessor) 
DoJobOnReceivedCrossNotarizedHeader(shardID uint32) { bp.doJobOnReceivedCrossNotarizedHeader(shardID) } +// ComputeLongestChainFromLastCrossNotarized - func (bp *blockProcessor) ComputeLongestChainFromLastCrossNotarized(shardID uint32) (data.HeaderHandler, []byte, []data.HeaderHandler, [][]byte) { return bp.computeLongestChainFromLastCrossNotarized(shardID) } +// ComputeSelfNotarizedHeaders - func (bp *blockProcessor) ComputeSelfNotarizedHeaders(headers []data.HeaderHandler) ([]data.HeaderHandler, [][]byte) { return bp.computeSelfNotarizedHeaders(headers) } -func (bp *blockProcessor) GetNextHeader(longestChainHeadersIndexes *[]int, headersIndexes []int, prevHeader data.HeaderHandler, sortedHeaders []data.HeaderHandler, index int) { - bp.getNextHeader(longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, index) +// GetNextHeader - +func (bp *blockProcessor) GetNextHeader( + longestChainHeadersIndexes *[]int, + headersIndexes []int, + prevHeader data.HeaderHandler, + sortedHeaders []data.HeaderHandler, + sortedHashes [][]byte, + index int, +) { + bp.getNextHeader(longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, sortedHashes, index) } -func (bp *blockProcessor) CheckHeaderFinality(header data.HeaderHandler, sortedHeaders []data.HeaderHandler, index int) error { - return bp.checkHeaderFinality(header, sortedHeaders, index) +// CheckHeaderFinality - +func (bp *blockProcessor) CheckHeaderFinality( + header data.HeaderHandler, + sortedHeaders []data.HeaderHandler, + sortedHashes [][]byte, + index int, +) error { + return bp.checkHeaderFinality(header, sortedHeaders, sortedHashes, index) } +// RequestHeadersIfNeeded - func (bp *blockProcessor) RequestHeadersIfNeeded(lastNotarizedHeader data.HeaderHandler, sortedHeaders []data.HeaderHandler, longestChainHeaders []data.HeaderHandler) { bp.requestHeadersIfNeeded(lastNotarizedHeader, sortedHeaders, longestChainHeaders) } +// GetLatestValidHeader - func (bp *blockProcessor) GetLatestValidHeader(lastNotarizedHeader data.HeaderHandler, longestChainHeaders []data.HeaderHandler) data.HeaderHandler { return bp.getLatestValidHeader(lastNotarizedHeader, longestChainHeaders) } +// GetHighestRoundInReceivedHeaders - func (bp *blockProcessor) GetHighestRoundInReceivedHeaders(latestValidHeader data.HeaderHandler, sortedReceivedHeaders []data.HeaderHandler) uint64 { return bp.getHighestRoundInReceivedHeaders(latestValidHeader, sortedReceivedHeaders) } +// RequestHeadersIfNothingNewIsReceived - func (bp *blockProcessor) RequestHeadersIfNothingNewIsReceived(lastNotarizedHeaderNonce uint64, latestValidHeader data.HeaderHandler, highestRoundInReceivedHeaders uint64) { bp.requestHeadersIfNothingNewIsReceived(lastNotarizedHeaderNonce, latestValidHeader, highestRoundInReceivedHeaders) } +// RequestHeaders - func (bp *blockProcessor) RequestHeaders(shardID uint32, fromNonce uint64) { bp.requestHeaders(shardID, fromNonce) } +// ShouldProcessReceivedHeader - func (bp *blockProcessor) ShouldProcessReceivedHeader(headerHandler data.HeaderHandler) bool { return bp.shouldProcessReceivedHeader(headerHandler) } // miniBlockTrack +// ReceivedMiniBlock - func (mbt *miniBlockTrack) ReceivedMiniBlock(key []byte, value interface{}) { mbt.receivedMiniBlock(key, value) } +// GetTransactionPool - func (mbt *miniBlockTrack) GetTransactionPool(mbType block.Type) dataRetriever.ShardedDataCacherNotifier { return mbt.getTransactionPool(mbType) } +// SetBlockTransactionsPool - func (mbt *miniBlockTrack) SetBlockTransactionsPool(blockTransactionsPool 
dataRetriever.ShardedDataCacherNotifier) { mbt.blockTransactionsPool = blockTransactionsPool } diff --git a/process/track/metaBlockTrack.go b/process/track/metaBlockTrack.go index 26e13d58e1c..392c85eaeaf 100644 --- a/process/track/metaBlockTrack.go +++ b/process/track/metaBlockTrack.go @@ -46,6 +46,10 @@ func NewMetaBlockTrack(arguments ArgMetaTracker) (*metaBlockTrack, error) { SelfNotarizedHeadersNotifier: bbt.selfNotarizedHeadersNotifier, FinalMetachainHeadersNotifier: bbt.finalMetachainHeadersNotifier, RoundHandler: arguments.RoundHandler, + EnableEpochsHandler: arguments.EnableEpochsHandler, + ProofsPool: arguments.ProofsPool, + Marshaller: arguments.Marshalizer, + Hasher: arguments.Hasher, } blockProcessorObject, err := NewBlockProcessor(argBlockProcessor) diff --git a/process/track/miniBlockTrack_test.go b/process/track/miniBlockTrack_test.go index 123c3813052..6a72d7ad9d0 100644 --- a/process/track/miniBlockTrack_test.go +++ b/process/track/miniBlockTrack_test.go @@ -4,14 +4,16 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/track" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" - "github.com/stretchr/testify/assert" ) func TestNewMiniBlockTrack_NilDataPoolHolderErr(t *testing.T) { @@ -256,7 +258,7 @@ func TestGetTransactionPool_ShouldWork(t *testing.T) { return unsignedTransactionsPool }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, } mbt, _ := track.NewMiniBlockTrack(dataPool, mock.NewMultipleShardsCoordinatorMock(), &testscommon.WhiteListHandlerStub{}) @@ -286,7 +288,7 @@ func createDataPool() dataRetriever.PoolsHolder { return testscommon.NewShardedDataStub() }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, } } diff --git a/process/track/shardBlockTrack.go b/process/track/shardBlockTrack.go index 327282725bc..72b918713d2 100644 --- a/process/track/shardBlockTrack.go +++ b/process/track/shardBlockTrack.go @@ -46,6 +46,10 @@ func NewShardBlockTrack(arguments ArgShardTracker) (*shardBlockTrack, error) { SelfNotarizedHeadersNotifier: bbt.selfNotarizedHeadersNotifier, FinalMetachainHeadersNotifier: bbt.finalMetachainHeadersNotifier, RoundHandler: arguments.RoundHandler, + EnableEpochsHandler: arguments.EnableEpochsHandler, + ProofsPool: arguments.ProofsPool, + Marshaller: arguments.Marshalizer, + Hasher: arguments.Hasher, } blockProcessorObject, err := NewBlockProcessor(argBlockProcessor) diff --git a/process/transaction/interceptedTransaction_test.go b/process/transaction/interceptedTransaction_test.go index b2aa2e81526..1312f5cba4f 100644 --- a/process/transaction/interceptedTransaction_test.go +++ b/process/transaction/interceptedTransaction_test.go @@ -14,18 +14,20 @@ import ( "github.com/multiversx/mx-chain-core-go/data" dataTransaction "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" 
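The tracker changes above thread header hashes through the longest-chain computation: getNextHeader and checkHeaderFinality now take a sortedHashes slice next to sortedHeaders, and the shard/meta trackers additionally wire EnableEpochsHandler, ProofsPool, Marshaller and Hasher into the block processor. The snippet below is a minimal, self-contained sketch of the resulting caller contract (stand-in types only, not the tracker's real code): the two slices are built together, and index i refers to the same block in both.

package main

import (
	"crypto/sha256"
	"fmt"
)

type header struct{ nonce uint64 }

// encode stands in for marshalling a header before hashing it
func encode(h header) []byte { return []byte(fmt.Sprintf("header-%d", h.nonce)) }

// buildSortedSlices returns the headers together with their hashes, aligned by index
func buildSortedSlices(sortedHeaders []header) ([]header, [][]byte) {
	sortedHashes := make([][]byte, 0, len(sortedHeaders))
	for _, h := range sortedHeaders {
		sum := sha256.Sum256(encode(h))
		sortedHashes = append(sortedHashes, sum[:])
	}
	return sortedHeaders, sortedHashes
}

func main() {
	sortedHeaders, sortedHashes := buildSortedSlices([]header{{nonce: 7}, {nonce: 8}})
	for i := range sortedHeaders {
		// index i refers to the same block in both slices
		fmt.Printf("nonce %d -> hash %x\n", sortedHeaders[i].nonce, sortedHashes[i][:4])
	}
}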
"github.com/multiversx/mx-chain-go/process/interceptors" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/smartContract" "github.com/multiversx/mx-chain-go/process/transaction" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var errSingleSignKeyGenMock = errors.New("errSingleSignKeyGenMock") @@ -1309,7 +1311,7 @@ func TestInterceptedTransaction_CheckValiditySecondTimeDoesNotVerifySig(t *testi return shardCoordinator.CurrentShard } - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() whiteListerVerifiedTxs, err := interceptors.NewWhiteListDataVerifier(cache) require.Nil(t, err) @@ -1510,7 +1512,7 @@ func TestRelayTransaction_NotAddedToWhitelistUntilIntegrityChecked(t *testing.T) t.Parallel() marshalizer := &mock.MarshalizerMock{} - whiteListHandler, _ := interceptors.NewWhiteListDataVerifier(testscommon.NewCacherMock()) + whiteListHandler, _ := interceptors.NewWhiteListDataVerifier(cache.NewCacherMock()) userTx := &dataTransaction.Transaction{ SndAddr: recvAddress, diff --git a/process/unsigned/interceptedUnsignedTransaction_test.go b/process/unsigned/interceptedUnsignedTransaction_test.go index b0c00e4982e..102b76c0975 100644 --- a/process/unsigned/interceptedUnsignedTransaction_test.go +++ b/process/unsigned/interceptedUnsignedTransaction_test.go @@ -11,13 +11,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/unsigned" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" ) var senderShard = uint32(2) @@ -170,7 +171,7 @@ func TestNewInterceptedUnsignedTransaction_ShouldWork(t *testing.T) { assert.Nil(t, err) } -// ------- CheckValidity +// ------- Verify func TestInterceptedUnsignedTransaction_CheckValidityNilTxHashShouldErr(t *testing.T) { t.Parallel() diff --git a/sharding/chainParametersHolder_test.go b/sharding/chainParametersHolder_test.go index f2a9b33e64a..7ec5876cc7d 100644 --- a/sharding/chainParametersHolder_test.go +++ b/sharding/chainParametersHolder_test.go @@ -7,9 +7,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/testscommon/commonmocks" - "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" + mock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" + "github.com/stretchr/testify/require" ) @@ -18,7 +20,7 @@ func TestNewChainParametersHolder(t *testing.T) { getDummyArgs := func() ArgsChainParametersHolder { return ArgsChainParametersHolder{ - EpochStartEventNotifier: &epochstartmock.EpochStartNotifierStub{}, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, 
ChainParameters: []config.ChainParametersByEpochConfig{ { EnableEpoch: 0, @@ -177,7 +179,7 @@ func TestChainParametersHolder_EpochStartActionShouldCallTheNotifier(t *testing. MetachainMinNumNodes: 7, }, }, - EpochStartEventNotifier: &epochstartmock.EpochStartNotifierStub{}, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, ChainParametersNotifier: notifier, }) @@ -203,7 +205,7 @@ func TestChainParametersHolder_ChainParametersForEpoch(t *testing.T) { paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ ChainParameters: params, - EpochStartEventNotifier: &epochstartmock.EpochStartNotifierStub{}, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, }) @@ -249,7 +251,7 @@ func TestChainParametersHolder_ChainParametersForEpoch(t *testing.T) { paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ ChainParameters: params, - EpochStartEventNotifier: &epochstartmock.EpochStartNotifierStub{}, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, }) @@ -291,7 +293,7 @@ func TestChainParametersHolder_CurrentChainParameters(t *testing.T) { paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ ChainParameters: params, - EpochStartEventNotifier: &epochstartmock.EpochStartNotifierStub{}, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, }) @@ -330,7 +332,7 @@ func TestChainParametersHolder_AllChainParameters(t *testing.T) { paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ ChainParameters: params, - EpochStartEventNotifier: &epochstartmock.EpochStartNotifierStub{}, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, }) @@ -356,7 +358,7 @@ func TestChainParametersHolder_ConcurrentOperations(t *testing.T) { paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ ChainParameters: chainParams, - EpochStartEventNotifier: &epochstartmock.EpochStartNotifierStub{}, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, }) diff --git a/sharding/mock/enableEpochsHandlerMock.go b/sharding/mock/enableEpochsHandlerMock.go index 32c6b4fa14c..48dfcedfa52 100644 --- a/sharding/mock/enableEpochsHandlerMock.go +++ b/sharding/mock/enableEpochsHandlerMock.go @@ -43,11 +43,6 @@ func (mock *EnableEpochsHandlerMock) GetCurrentEpoch() uint32 { return mock.CurrentEpoch } -// FixGasRemainingForSaveKeyValueBuiltinFunctionEnabled - -func (mock *EnableEpochsHandlerMock) FixGasRemainingForSaveKeyValueBuiltinFunctionEnabled() bool { - return false -} - // IsInterfaceNil returns true if there is no value under the interface func (mock *EnableEpochsHandlerMock) IsInterfaceNil() bool { return mock == nil diff --git a/sharding/networksharding/peerShardMapper_test.go b/sharding/networksharding/peerShardMapper_test.go index fef620ed90d..6b03abe6805 100644 --- a/sharding/networksharding/peerShardMapper_test.go +++ b/sharding/networksharding/peerShardMapper_test.go @@ -9,23 +9,24 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/sharding/networksharding" 
"github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) // ------- NewPeerShardMapper func createMockArgumentForPeerShardMapper() networksharding.ArgPeerShardMapper { return networksharding.ArgPeerShardMapper{ - PeerIdPkCache: testscommon.NewCacherMock(), - FallbackPkShardCache: testscommon.NewCacherMock(), - FallbackPidShardCache: testscommon.NewCacherMock(), + PeerIdPkCache: cache.NewCacherMock(), + FallbackPkShardCache: cache.NewCacherMock(), + FallbackPidShardCache: cache.NewCacherMock(), NodesCoordinator: &shardingMocks.NodesCoordinatorMock{}, PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, } diff --git a/sharding/nodesCoordinator/errors.go b/sharding/nodesCoordinator/errors.go index 19c1bda084f..901559116ab 100644 --- a/sharding/nodesCoordinator/errors.go +++ b/sharding/nodesCoordinator/errors.go @@ -123,3 +123,6 @@ var ErrNilEpochNotifier = errors.New("nil epoch notifier provided") // ErrNilChainParametersHandler signals that a nil chain parameters handler has been provided var ErrNilChainParametersHandler = errors.New("nil chain parameters handler") + +// ErrEmptyValidatorsList signals that the validators list is empty +var ErrEmptyValidatorsList = errors.New("empty validators list") diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinator.go b/sharding/nodesCoordinator/indexHashedNodesCoordinator.go index ca95be3a522..f0029a8fd9b 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinator.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinator.go @@ -30,6 +30,7 @@ const ( keyFormat = "%s_%v_%v_%v" defaultSelectionChances = uint32(1) minEpochsToWait = uint32(1) + leaderSelectionSize = 1 ) // TODO: move this to config parameters @@ -40,6 +41,12 @@ type validatorWithShardID struct { shardID uint32 } +// savedConsensusGroup holds the leader and consensus group for a specific selection +type savedConsensusGroup struct { + leader Validator + consensusGroup []Validator +} + type validatorList []Validator // Len will return the length of the validatorList @@ -346,7 +353,7 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( round uint64, shardID uint32, epoch uint32, -) (validatorsGroup []Validator, err error) { +) (leader Validator, validatorsGroup []Validator, err error) { var selector RandomSelector var eligibleList []Validator @@ -357,7 +364,7 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( "round", round) if len(randomness) == 0 { - return nil, ErrNilRandomness + return nil, nil, ErrNilRandomness } ihnc.mutNodesConfig.RLock() @@ -366,7 +373,7 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( if shardID >= nodesConfig.nbShards && shardID != core.MetachainShardId { log.Warn("shardID is not ok", "shardID", shardID, "nbShards", nodesConfig.nbShards) ihnc.mutNodesConfig.RUnlock() - return nil, ErrInvalidShardId + return nil, nil, ErrInvalidShardId } selector = nodesConfig.selectors[shardID] eligibleList = nodesConfig.eligibleMap[shardID] @@ -374,13 +381,13 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( ihnc.mutNodesConfig.RUnlock() if !ok { - return nil, fmt.Errorf("%w epoch=%v", ErrEpochNodesConfigDoesNotExist, epoch) + return nil, nil, fmt.Errorf("%w epoch=%v", ErrEpochNodesConfigDoesNotExist, epoch) } key := 
[]byte(fmt.Sprintf(keyFormat, string(randomness), round, shardID, epoch)) - validators := ihnc.searchConsensusForKey(key) - if validators != nil { - return validators, nil + savedCG := ihnc.searchConsensusForKey(key) + if savedCG != nil { + return savedCG.leader, savedCG.consensusGroup, nil } consensusSize := ihnc.ConsensusGroupSizeForShardAndEpoch(shardID, epoch) @@ -394,27 +401,59 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( "round", round, "shardID", shardID) - tempList, err := selectValidators(selector, randomness, uint32(consensusSize), eligibleList) + leader, validatorsGroup, err = ihnc.selectLeaderAndConsensusGroup(selector, randomness, eligibleList, consensusSize, epoch) if err != nil { - return nil, err + return nil, nil, err } - size := 0 - for _, v := range tempList { - size += v.Size() + ihnc.cacheConsensusGroup(key, validatorsGroup, leader) + + return leader, validatorsGroup, nil +} + +func (ihnc *indexHashedNodesCoordinator) cacheConsensusGroup(key []byte, consensusGroup []Validator, leader Validator) { + size := leader.Size() * len(consensusGroup) + savedCG := &savedConsensusGroup{ + leader: leader, + consensusGroup: consensusGroup, } + ihnc.consensusGroupCacher.Put(key, savedCG, size) +} - ihnc.consensusGroupCacher.Put(key, tempList, size) +func (ihnc *indexHashedNodesCoordinator) selectLeaderAndConsensusGroup( + selector RandomSelector, + randomness []byte, + eligibleList []Validator, + consensusSize int, + epoch uint32, +) (Validator, []Validator, error) { + leaderPositionInSelection := 0 + if !ihnc.enableEpochsHandler.IsFlagEnabledInEpoch(common.FixedOrderInConsensusFlag, epoch) { + tempList, err := selectValidators(selector, randomness, uint32(consensusSize), eligibleList) + if err != nil { + return nil, nil, err + } + + if len(tempList) == 0 { + return nil, nil, ErrEmptyValidatorsList + } + + return tempList[leaderPositionInSelection], tempList, nil + } - return tempList, nil + selectedValidators, err := selectValidators(selector, randomness, leaderSelectionSize, eligibleList) + if err != nil { + return nil, nil, err + } + return selectedValidators[leaderPositionInSelection], eligibleList, nil } -func (ihnc *indexHashedNodesCoordinator) searchConsensusForKey(key []byte) []Validator { +func (ihnc *indexHashedNodesCoordinator) searchConsensusForKey(key []byte) *savedConsensusGroup { value, ok := ihnc.consensusGroupCacher.Get(key) if ok { - consensusGroup, typeOk := value.([]Validator) + savedCG, typeOk := value.(*savedConsensusGroup) if typeOk { - return consensusGroup + return savedCG } } return nil @@ -442,10 +481,10 @@ func (ihnc *indexHashedNodesCoordinator) GetConsensusValidatorsPublicKeys( round uint64, shardID uint32, epoch uint32, -) ([]string, error) { - consensusNodes, err := ihnc.ComputeConsensusGroup(randomness, round, shardID, epoch) +) (string, []string, error) { + leader, consensusNodes, err := ihnc.ComputeConsensusGroup(randomness, round, shardID, epoch) if err != nil { - return nil, err + return "", nil, err } pubKeys := make([]string, 0) @@ -454,7 +493,29 @@ func (ihnc *indexHashedNodesCoordinator) GetConsensusValidatorsPublicKeys( pubKeys = append(pubKeys, string(v.PubKey())) } - return pubKeys, nil + return string(leader.PubKey()), pubKeys, nil +} + +// GetAllEligibleValidatorsPublicKeysForShard will return all validators public keys for the provided shard +func (ihnc *indexHashedNodesCoordinator) GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) { + ihnc.mutNodesConfig.RLock() + 
nodesConfig, ok := ihnc.nodesConfig[epoch] + ihnc.mutNodesConfig.RUnlock() + + if !ok { + return nil, fmt.Errorf("%w epoch=%v", ErrEpochNodesConfigDoesNotExist, epoch) + } + + nodesConfig.mutNodesMaps.RLock() + defer nodesConfig.mutNodesMaps.RUnlock() + + shardEligible := nodesConfig.eligibleMap[shardID] + validatorsPubKeys := make([]string, 0, len(shardEligible)) + for i := 0; i < len(shardEligible); i++ { + validatorsPubKeys = append(validatorsPubKeys, string(shardEligible[i].PubKey())) + } + + return validatorsPubKeys, nil } // GetAllEligibleValidatorsPublicKeys will return all validators public keys for all shards @@ -1262,7 +1323,7 @@ func computeActuallyLeaving( func selectValidators( selector RandomSelector, randomness []byte, - consensusSize uint32, + selectionSize uint32, eligibleList []Validator, ) ([]Validator, error) { if check.IfNil(selector) { @@ -1273,19 +1334,19 @@ func selectValidators( } // todo: checks for indexes - selectedIndexes, err := selector.Select(randomness, consensusSize) + selectedIndexes, err := selector.Select(randomness, selectionSize) if err != nil { return nil, err } - consensusGroup := make([]Validator, consensusSize) - for i := range consensusGroup { - consensusGroup[i] = eligibleList[selectedIndexes[i]] + selectedValidators := make([]Validator, selectionSize) + for i := range selectedValidators { + selectedValidators[i] = eligibleList[selectedIndexes[i]] } - displayValidatorsForRandomness(consensusGroup, randomness) + displayValidatorsForRandomness(selectedValidators, randomness) - return consensusGroup, nil + return selectedValidators, nil } // createValidatorInfoFromBody unmarshalls body data to create validator info diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinatorWithRater_test.go b/sharding/nodesCoordinator/indexHashedNodesCoordinatorWithRater_test.go index 2a879d125d2..1154d93ae1a 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinatorWithRater_test.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinatorWithRater_test.go @@ -21,7 +21,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/chainParameters" - "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" @@ -75,7 +75,7 @@ func TestIndexHashedGroupSelectorWithRater_OkValShouldWork(t *testing.T) { nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -143,8 +143,9 @@ func TestIndexHashedGroupSelectorWithRater_ComputeValidatorsGroup1ValidatorShoul assert.Equal(t, false, chancesCalled) ihnc, _ := NewIndexHashedNodesCoordinatorWithRater(nc, rater) assert.Equal(t, true, chancesCalled) - list2, err := ihnc.ComputeConsensusGroup([]byte("randomness"), 0, 0, 0) + leader, list2, err := ihnc.ComputeConsensusGroup([]byte("randomness"), 0, 0, 0) + assert.Equal(t, list[0], leader) assert.Nil(t, err) assert.Equal(t, 1, len(list2)) } @@ -176,7 +177,7 @@ func BenchmarkIndexHashedGroupSelectorWithRater_ComputeValidatorsGroup63of400(b } nodeShuffler, err := 
NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -214,7 +215,7 @@ func BenchmarkIndexHashedGroupSelectorWithRater_ComputeValidatorsGroup63of400(b for i := 0; i < b.N; i++ { randomness := strconv.Itoa(0) - list2, _ := ihncRater.ComputeConsensusGroup([]byte(randomness), uint64(0), 0, 0) + _, list2, _ := ihncRater.ComputeConsensusGroup([]byte(randomness), uint64(0), 0, 0) assert.Equal(b, consensusGroupSize, len(list2)) } @@ -255,7 +256,7 @@ func Test_ComputeValidatorsGroup63of400(t *testing.T) { nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -288,8 +289,8 @@ func Test_ComputeValidatorsGroup63of400(t *testing.T) { hasher := sha256.NewSha256() for i := uint64(0); i < numRounds; i++ { randomness := hasher.Compute(fmt.Sprintf("%v%v", i, time.Millisecond)) - consensusGroup, _ := ihnc.ComputeConsensusGroup(randomness, uint64(0), 0, 0) - leaderAppearances[string(consensusGroup[0].PubKey())]++ + leader, consensusGroup, _ := ihnc.ComputeConsensusGroup(randomness, uint64(0), 0, 0) + leaderAppearances[string(leader.PubKey())]++ for _, v := range consensusGroup { consensusAppearances[string(v.PubKey())]++ } @@ -331,7 +332,7 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldReturn nodeShuffler, err := NewHashValidatorsShuffler(sufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -389,7 +390,7 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldReturn nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -457,7 +458,7 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldWork(t nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap[core.MetachainShardId] = listMeta @@ -545,7 +546,7 @@ func TestIndexHashedGroupSelectorWithRater_GetAllEligibleValidatorsPublicKeys(t } nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap[core.MetachainShardId] = listMeta @@ -860,7 +861,7 @@ func BenchmarkIndexHashedWithRaterGroupSelector_ComputeValidatorsGroup21of400(b nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := 
genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -899,7 +900,7 @@ func BenchmarkIndexHashedWithRaterGroupSelector_ComputeValidatorsGroup21of400(b for i := 0; i < b.N; i++ { randomness := strconv.Itoa(i) - list2, _ := ihncRater.ComputeConsensusGroup([]byte(randomness), 0, 0, 0) + _, list2, _ := ihncRater.ComputeConsensusGroup([]byte(randomness), 0, 0, 0) assert.Equal(b, consensusGroupSize, len(list2)) } diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go b/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go index 7f516e7cd6e..26a4340021d 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go @@ -32,10 +32,11 @@ import ( "github.com/multiversx/mx-chain-go/storage/cache" "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" - "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks/nodesCoordinatorMocks" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" ) @@ -106,7 +107,7 @@ func createArguments() ArgNodesCoordinator { } nodeShuffler, _ := NewHashValidatorsShuffler(shufflerArgs) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -297,7 +298,7 @@ func TestIndexHashedNodesCoordinator_OkValShouldWork(t *testing.T) { nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -350,7 +351,7 @@ func TestIndexHashedNodesCoordinator_NewCoordinatorTooFewNodesShouldErr(t *testi nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -397,10 +398,11 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroupNilRandomnessShouldEr arguments := createArguments() ihnc, _ := NewIndexHashedNodesCoordinator(arguments) - list2, err := ihnc.ComputeConsensusGroup(nil, 0, 0, 0) + leader, list2, err := ihnc.ComputeConsensusGroup(nil, 0, 0, 0) require.Equal(t, ErrNilRandomness, err) require.Nil(t, list2) + require.Nil(t, leader) } func TestIndexHashedNodesCoordinator_ComputeValidatorsGroupInvalidShardIdShouldErr(t *testing.T) { @@ -408,10 +410,11 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroupInvalidShardIdShouldE arguments := createArguments() ihnc, _ := NewIndexHashedNodesCoordinator(arguments) - list2, err := ihnc.ComputeConsensusGroup([]byte("radomness"), 0, 5, 0) + leader, list2, err := ihnc.ComputeConsensusGroup([]byte("radomness"), 0, 5, 0) require.Equal(t, ErrInvalidShardId, err) require.Nil(t, list2) + require.Nil(t, leader) } // ------- functionality tests @@ -434,7 
+437,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup1ValidatorShouldRetur nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -471,10 +474,11 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup1ValidatorShouldRetur NodesCoordinatorRegistryFactory: createNodesCoordinatorRegistryFactory(), } ihnc, _ := NewIndexHashedNodesCoordinator(arguments) - list2, err := ihnc.ComputeConsensusGroup([]byte("randomness"), 0, 0, 0) + leader, list2, err := ihnc.ComputeConsensusGroup([]byte("randomness"), 0, 0, 0) - require.Equal(t, list, list2) require.Nil(t, err) + require.Equal(t, list, list2) + require.Equal(t, list[0], leader) } func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoMemoization(t *testing.T) { @@ -490,7 +494,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoM nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() getCounter := int32(0) @@ -547,12 +551,14 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoM miniBlocks := 10 var list2 []Validator + var leader Validator for i := 0; i < miniBlocks; i++ { for j := 0; j <= i; j++ { randomness := strconv.Itoa(j) - list2, err = ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) + leader, list2, err = ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) require.Nil(t, err) require.Equal(t, consensusGroupSize, len(list2)) + require.NotNil(t, leader) } } @@ -575,7 +581,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() getCounter := 0 @@ -645,12 +651,14 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe miniBlocks := 10 var list2 []Validator + var leader Validator for i := 0; i < miniBlocks; i++ { for j := 0; j <= i; j++ { randomness := strconv.Itoa(j) - list2, err = ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) + leader, list2, err = ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) require.Nil(t, err) require.Equal(t, consensusGroupSize, len(list2)) + require.NotNil(t, leader) } } @@ -684,7 +692,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400TestEqualSameP nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -720,13 +728,15 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400TestEqualSameP repeatPerSampling := 100 list := make([][]Validator, repeatPerSampling) + var leader Validator for i := 0; i < nbDifferentSamplings; i++ { randomness := arguments.Hasher.Compute(strconv.Itoa(i)) fmt.Printf("starting selection with 
randomness: %s\n", hex.EncodeToString(randomness)) for j := 0; j < repeatPerSampling; j++ { - list[j], err = ihnc.ComputeConsensusGroup(randomness, 0, 0, 0) + leader, list[j], err = ihnc.ComputeConsensusGroup(randomness, 0, 0, 0) require.Nil(t, err) require.Equal(t, consensusGroupSize, len(list[j])) + require.NotNil(t, leader) } for j := 1; j < repeatPerSampling; j++ { @@ -750,7 +760,7 @@ func BenchmarkIndexHashedGroupSelector_ComputeValidatorsGroup21of400(b *testing. nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -785,9 +795,10 @@ func BenchmarkIndexHashedGroupSelector_ComputeValidatorsGroup21of400(b *testing. for i := 0; i < b.N; i++ { randomness := strconv.Itoa(i) - list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), 0, 0, 0) + leader, list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), 0, 0, 0) require.Equal(b, consensusGroupSize, len(list2)) + require.NotNil(b, leader) } } @@ -826,7 +837,7 @@ func runBenchmark(consensusGroupCache Cacher, consensusGroupSize int, nodesMap m nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -863,8 +874,9 @@ func runBenchmark(consensusGroupCache Cacher, consensusGroupSize int, nodesMap m missedBlocks := 1000 for j := 0; j < missedBlocks; j++ { randomness := strconv.Itoa(j) - list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) + leader, list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) require.Equal(b, consensusGroupSize, len(list2)) + require.NotNil(b, leader) } } } @@ -879,7 +891,7 @@ func computeMemoryRequirements(consensusGroupCache Cacher, consensusGroupSize in nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -917,8 +929,9 @@ func computeMemoryRequirements(consensusGroupCache Cacher, consensusGroupSize in missedBlocks := 1000 for i := 0; i < missedBlocks; i++ { randomness := strconv.Itoa(i) - list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), uint64(i), 0, 0) + leader, list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), uint64(i), 0, 0) require.Equal(b, consensusGroupSize, len(list2)) + require.NotNil(b, leader) } m2 := runtime.MemStats{} @@ -1022,7 +1035,7 @@ func TestIndexHashedNodesCoordinator_GetValidatorWithPublicKeyShouldWork(t *test nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -1109,7 +1122,7 @@ func TestIndexHashedGroupSelector_GetAllEligibleValidatorsPublicKeys(t *testing. 
nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -1188,7 +1201,7 @@ func TestIndexHashedGroupSelector_GetAllWaitingValidatorsPublicKeys(t *testing.T nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap := make(map[uint32][]Validator) @@ -1586,7 +1599,7 @@ func TestIndexHashedNodesCoordinator_EpochStart_EligibleSortedAscendingByIndex(t nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ @@ -1657,10 +1670,12 @@ func TestIndexHashedNodesCoordinator_GetConsensusValidatorsPublicKeysNotExisting require.Nil(t, err) var pKeys []string + var leader string randomness := []byte("randomness") - pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 1) + leader, pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 1) require.True(t, errors.Is(err, ErrEpochNodesConfigDoesNotExist)) require.Nil(t, pKeys) + require.Empty(t, leader) } func TestIndexHashedNodesCoordinator_GetConsensusValidatorsPublicKeysExistingEpoch(t *testing.T) { @@ -1673,11 +1688,13 @@ func TestIndexHashedNodesCoordinator_GetConsensusValidatorsPublicKeysExistingEpo shard0PubKeys := validatorsPubKeys(args.EligibleNodes[0]) var pKeys []string + var leader string randomness := []byte("randomness") - pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) + leader, pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) require.Nil(t, err) require.True(t, len(pKeys) > 0) require.True(t, isStringSubgroup(pKeys, shard0PubKeys)) + require.NotEmpty(t, leader) } func TestIndexHashedNodesCoordinator_GetValidatorsIndexes(t *testing.T) { @@ -1689,13 +1706,15 @@ func TestIndexHashedNodesCoordinator_GetValidatorsIndexes(t *testing.T) { randomness := []byte("randomness") var pKeys []string - pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) + var leader string + leader, pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) require.Nil(t, err) var indexes []uint64 indexes, err = ihnc.GetValidatorsIndexes(pKeys, 0) require.Nil(t, err) require.Equal(t, len(pKeys), len(indexes)) + require.NotEmpty(t, leader) } func TestIndexHashedNodesCoordinator_GetValidatorsIndexesInvalidPubKey(t *testing.T) { @@ -1707,8 +1726,10 @@ func TestIndexHashedNodesCoordinator_GetValidatorsIndexesInvalidPubKey(t *testin randomness := []byte("randomness") var pKeys []string - pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) + var leader string + leader, pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) require.Nil(t, err) + require.NotEmpty(t, leader) var indexes []uint64 pKeys[0] = "dummy" @@ -1843,6 +1864,39 @@ func TestIndexHashedNodesCoordinator_GetConsensusWhitelistedNodesEpoch1(t *testi } } +func TestIndexHashedNodesCoordinator_GetAllEligibleValidatorsPublicKeysForShard(t *testing.T) { + t.Parallel() + + t.Run("missing 
nodes config should error", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + arguments.ValidatorInfoCacher = dataPool.NewCurrentEpochValidatorInfoPool() + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + validators, err := ihnc.GetAllEligibleValidatorsPublicKeysForShard(100, 0) + require.True(t, errors.Is(err, ErrEpochNodesConfigDoesNotExist)) + require.Nil(t, validators) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + arguments.ValidatorInfoCacher = dataPool.NewCurrentEpochValidatorInfoPool() + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + expectedValidators := make([]string, 0, len(arguments.EligibleNodes[0])) + for _, val := range arguments.EligibleNodes[0] { + expectedValidators = append(expectedValidators, string(val.PubKey())) + } + validators, err := ihnc.GetAllEligibleValidatorsPublicKeysForShard(0, 0) + require.NoError(t, err) + require.Equal(t, expectedValidators, validators) + }) +} + func TestIndexHashedNodesCoordinator_GetConsensusWhitelistedNodesAfterRevertToEpoch(t *testing.T) { t.Parallel() @@ -2604,7 +2658,7 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) t.Run("missing nodes config for current epoch should error ", func(t *testing.T) { t.Parallel() - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() shufflerArgs := &NodesShufflerArgs{ @@ -2674,7 +2728,7 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) waitingMap[core.MetachainShardId] = listMeta waitingMap[shardZeroId] = listShard0 - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap := make(map[uint32][]Validator) @@ -2761,7 +2815,7 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) waitingMap[core.MetachainShardId] = listMeta waitingMap[shardZeroId] = listShard0 - epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap := make(map[uint32][]Validator) @@ -2952,6 +3006,183 @@ func TestNodesCoordinator_CustomConsensusGroupSize(t *testing.T) { require.Equal(t, numEpochsToCheck, uint32(checksCounter)) } +func TestIndexHashedNodesCoordinator_cacheConsensusGroup(t *testing.T) { + t.Parallel() + + maxNumValuesCache := 3 + key := []byte("key") + + leader := &validator{ + pubKey: []byte("leader"), + chances: 10, + index: 20, + } + validator1 := &validator{ + pubKey: []byte("validator1"), + chances: 10, + index: 20, + } + + t.Run("adding a key should work", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + + arguments.ConsensusGroupCache, _ = cache.NewLRUCache(maxNumValuesCache) + nodesCoordinator, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + consensusGroup := []Validator{leader, validator1} + expectedData := &savedConsensusGroup{ + leader: leader, + consensusGroup: consensusGroup, + } + + nodesCoordinator.cacheConsensusGroup(key, consensusGroup, leader) + value := nodesCoordinator.searchConsensusForKey(key) + + require.NotNil(t, value) + require.Equal(t, expectedData, value) + }) + + t.Run("adding a key twice should overwrite the 
value", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + + arguments.ConsensusGroupCache, _ = cache.NewLRUCache(maxNumValuesCache) + nodesCoordinator, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + cg1 := []Validator{leader, validator1} + cg2 := []Validator{leader} + expectedData := &savedConsensusGroup{ + leader: leader, + consensusGroup: cg2, + } + + nodesCoordinator.cacheConsensusGroup(key, cg1, leader) + nodesCoordinator.cacheConsensusGroup(key, cg2, leader) + value := nodesCoordinator.searchConsensusForKey(key) + require.NotNil(t, value) + require.Equal(t, expectedData, value) + }) + + t.Run("adding more keys than the cache size should remove the oldest key", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + + key1 := []byte("key1") + key2 := []byte("key2") + key3 := []byte("key3") + key4 := []byte("key4") + + cg1 := []Validator{leader, validator1} + cg2 := []Validator{leader} + cg3 := []Validator{validator1} + cg4 := []Validator{leader, validator1, validator1} + + arguments.ConsensusGroupCache, _ = cache.NewLRUCache(maxNumValuesCache) + nodesCoordinator, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + nodesCoordinator.cacheConsensusGroup(key1, cg1, leader) + nodesCoordinator.cacheConsensusGroup(key2, cg2, leader) + nodesCoordinator.cacheConsensusGroup(key3, cg3, leader) + nodesCoordinator.cacheConsensusGroup(key4, cg4, leader) + + value := nodesCoordinator.searchConsensusForKey(key1) + require.Nil(t, value) + + value = nodesCoordinator.searchConsensusForKey(key2) + require.Equal(t, cg2, value.consensusGroup) + + value = nodesCoordinator.searchConsensusForKey(key3) + require.Equal(t, cg3, value.consensusGroup) + + value = nodesCoordinator.searchConsensusForKey(key4) + require.Equal(t, cg4, value.consensusGroup) + }) +} + +func TestIndexHashedNodesCoordinator_selectLeaderAndConsensusGroup(t *testing.T) { + t.Parallel() + + validator1 := &validator{pubKey: []byte("validator1")} + validator2 := &validator{pubKey: []byte("validator2")} + validator3 := &validator{pubKey: []byte("validator3")} + validator4 := &validator{pubKey: []byte("validator4")} + + randomness := []byte("randomness") + epoch := uint32(1) + + eligibleList := []Validator{validator1, validator2, validator3, validator4} + consensusSize := len(eligibleList) + expectedError := errors.New("expected error") + selectFunc := func(randSeed []byte, sampleSize uint32) ([]uint32, error) { + if len(eligibleList) < int(sampleSize) { + return nil, expectedError + } + + result := make([]uint32, sampleSize) + for i := 0; i < int(sampleSize); i++ { + // reverse order from eligible list + result[i] = uint32(len(eligibleList) - 1 - i) + } + + return result, nil + } + expectedConsensusFixedOrder := []Validator{validator1, validator2, validator3, validator4} + expectedConsensusNotFixedOrder := []Validator{validator4, validator3, validator2, validator1} + expectedLeader := validator4 + + t.Run("with fixed ordering enabled, data not cached", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + selector := &nodesCoordinatorMocks.RandomSelectorMock{ + SelectCalled: selectFunc, + } + + leader, cg, err := ihnc.selectLeaderAndConsensusGroup(selector, randomness, 
eligibleList, consensusSize, epoch) + require.Nil(t, err) + require.Equal(t, validator4, leader) + require.Equal(t, expectedLeader, leader) + require.Equal(t, expectedConsensusFixedOrder, cg) + }) + t.Run("with fixed ordering disabled, data not cached", func(t *testing.T) { + t.Parallel() + arguments := createArguments() + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + selector := &nodesCoordinatorMocks.RandomSelectorMock{ + SelectCalled: selectFunc, + } + + leader, cg, err := ihnc.selectLeaderAndConsensusGroup(selector, randomness, eligibleList, consensusSize, epoch) + require.Nil(t, err) + require.Equal(t, expectedLeader, leader) + require.Equal(t, expectedConsensusNotFixedOrder, cg) + }) +} + type consensusSizeChangeTestArgs struct { t *testing.T ihnc *indexHashedNodesCoordinator diff --git a/sharding/nodesCoordinator/interface.go b/sharding/nodesCoordinator/interface.go index aa1d386fc02..d9e3e4c7999 100644 --- a/sharding/nodesCoordinator/interface.go +++ b/sharding/nodesCoordinator/interface.go @@ -3,10 +3,11 @@ package nodesCoordinator import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/state" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // Validator defines a node that can be allocated to a shard for participation in a consensus group as validator @@ -22,7 +23,7 @@ type Validator interface { type NodesCoordinator interface { NodesCoordinatorHelper PublicKeysSelector - ComputeConsensusGroup(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []Validator, err error) + ComputeConsensusGroup(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader Validator, lidatorsGroup []Validator, err error) GetValidatorWithPublicKey(publicKey []byte) (validator Validator, shardId uint32, err error) LoadState(key []byte) error GetSavedStateKey() []byte @@ -46,11 +47,12 @@ type EpochStartEventNotifier interface { type PublicKeysSelector interface { GetValidatorsIndexes(publicKeys []string, epoch uint32) ([]uint64, error) GetAllEligibleValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) GetAllWaitingValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) GetAllLeavingValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) GetAllShuffledOutValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) GetShuffledOutToAuctionValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) - GetConsensusValidatorsPublicKeys(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) + GetConsensusValidatorsPublicKeys(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) GetOwnPublicKey() []byte } diff --git a/sharding/nodesSetup.go b/sharding/nodesSetup.go index 26e8bee3351..32f9b1dbc92 100644 --- a/sharding/nodesSetup.go +++ b/sharding/nodesSetup.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/config" 
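The interface.go hunk above changes the NodesCoordinator contract so that the leader is an explicit return value of ComputeConsensusGroup and GetConsensusValidatorsPublicKeys rather than, by convention, the first element of the returned group. The stand-alone sketch below (toy types, not the real nodesCoordinator package) restates the rule implemented by selectLeaderAndConsensusGroup and exercised by the tests above: with FixedOrderInConsensusFlag disabled the whole group is sampled and position 0 leads; with the flag enabled only the leader is sampled and the consensus group is the eligible list in its fixed order.

package main

import "fmt"

type validator struct{ pubKey string }

// sample stands in for the RandomSelector: it derives sampleSize indexes from the randomness
func sample(randomness []byte, sampleSize int, listLen int) []int {
	indexes := make([]int, sampleSize)
	for i := range indexes {
		indexes[i] = (int(randomness[0]) + i) % listLen
	}
	return indexes
}

func selectLeaderAndGroup(randomness []byte, eligible []validator, consensusSize int, fixedOrderEnabled bool) (validator, []validator) {
	if !fixedOrderEnabled {
		// legacy behaviour: sample the full consensus group, the leader is position 0
		indexes := sample(randomness, consensusSize, len(eligible))
		group := make([]validator, 0, consensusSize)
		for _, idx := range indexes {
			group = append(group, eligible[idx])
		}
		return group[0], group
	}

	// fixed order: sample only the leader, the group is the whole eligible list as-is
	leaderIdx := sample(randomness, 1, len(eligible))[0]
	return eligible[leaderIdx], eligible
}

func main() {
	eligible := []validator{{"v1"}, {"v2"}, {"v3"}, {"v4"}}
	leader, group := selectLeaderAndGroup([]byte{2}, eligible, 3, true)
	fmt.Println(leader.pubKey, len(group)) // v3 4
}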
"github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) diff --git a/state/syncer/baseAccoutnsSyncer_test.go b/state/syncer/baseAccoutnsSyncer_test.go index da3819b05ce..e2fcf5336f0 100644 --- a/state/syncer/baseAccoutnsSyncer_test.go +++ b/state/syncer/baseAccoutnsSyncer_test.go @@ -4,15 +4,17 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storageManager" - "github.com/stretchr/testify/require" ) func getDefaultBaseAccSyncerArgs() syncer.ArgsNewBaseAccountsSyncer { @@ -22,7 +24,7 @@ func getDefaultBaseAccSyncerArgs() syncer.ArgsNewBaseAccountsSyncer { TrieStorageManager: &storageManager.StorageManagerStub{}, RequestHandler: &testscommon.RequestHandlerStub{}, Timeout: time.Second, - Cacher: testscommon.NewCacherMock(), + Cacher: cache.NewCacherMock(), UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, diff --git a/state/syncer/userAccountSyncer_test.go b/state/syncer/userAccountSyncer_test.go index eefdd96778f..3ecdf5cd178 100644 --- a/state/syncer/userAccountSyncer_test.go +++ b/state/syncer/userAccountSyncer_test.go @@ -4,15 +4,17 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/multiversx/mx-chain-go/trie" - "github.com/stretchr/testify/assert" ) // TODO add more tests @@ -24,7 +26,7 @@ func getDefaultBaseAccSyncerArgs() ArgsNewBaseAccountsSyncer { TrieStorageManager: &storageManager.StorageManagerStub{}, RequestHandler: &testscommon.RequestHandlerStub{}, Timeout: time.Second, - Cacher: testscommon.NewCacherMock(), + Cacher: cache.NewCacherMock(), UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, MaxTrieLevelInMemory: 0, @@ -95,7 +97,7 @@ func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { rootHash, _ := tr.RootHash() _ = tr.Commit() - args.Cacher = &testscommon.CacherStub{ + args.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { interceptedNode, _ := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) return interceptedNode, true diff --git a/state/syncer/userAccountsSyncer_test.go b/state/syncer/userAccountsSyncer_test.go index 176a4ec7497..5d7252d3b2e 100644 --- a/state/syncer/userAccountsSyncer_test.go +++ b/state/syncer/userAccountsSyncer_test.go @@ -10,6 +10,9 @@ import ( "github.com/multiversx/mx-chain-core-go/data" 
"github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/api/mock" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" @@ -20,14 +23,13 @@ import ( "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/storageManager" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/storageMarker" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func getDefaultUserAccountsSyncerArgs() syncer.ArgsNewUserAccountsSyncer { @@ -148,7 +150,7 @@ func TestUserAccountsSyncer_SyncAccounts(t *testing.T) { }, } - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() cacher.Put(key, itn, 0) args.Cacher = cacher @@ -228,7 +230,7 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { }, } - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() cacher.Put(key, itn, 0) args.Cacher = cacher @@ -285,7 +287,7 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { }, } - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() cacher.Put(key, itn, 0) args.Cacher = cacher @@ -366,7 +368,7 @@ func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { rootHash, _ := tr.RootHash() _ = tr.Commit() - args.Cacher = &testscommon.CacherStub{ + args.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { interceptedNode, _ := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) return interceptedNode, true diff --git a/state/syncer/validatorAccountsSyncer_test.go b/state/syncer/validatorAccountsSyncer_test.go index b4a025883f1..1ba90712704 100644 --- a/state/syncer/validatorAccountsSyncer_test.go +++ b/state/syncer/validatorAccountsSyncer_test.go @@ -4,15 +4,16 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/syncer" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/storageMarker" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestNewValidatorAccountsSyncer(t *testing.T) { @@ -93,7 +94,7 @@ func TestValidatorAccountsSyncer_SyncAccounts(t *testing.T) { }, } - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() cacher.Put(key, itn, 0) args.Cacher = cacher diff --git a/storage/pruning/triePruningStorer_test.go b/storage/pruning/triePruningStorer_test.go index 28dc5c93f8e..c9ea19e93a3 100644 --- a/storage/pruning/triePruningStorer_test.go +++ b/storage/pruning/triePruningStorer_test.go @@ -8,7 +8,8 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/mock" 
"github.com/multiversx/mx-chain-go/storage/pruning" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -44,7 +45,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheSearchesOnlyOldEpochsAndR args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() ps.SetCacher(cacher) testKey1 := []byte("key1") @@ -81,7 +82,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithCache(t *testing.T) { args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() ps.SetCacher(cacher) testKey1 := []byte("key1") @@ -185,7 +186,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheDoesNotSearchInCurrentSto args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherStub() + cacher := cache.NewCacherStub() cacher.PutCalled = func(_ []byte, _ interface{}, _ int) bool { require.Fail(t, "this should not be called") return false @@ -209,7 +210,7 @@ func TestTriePruningStorer_GetFromLastEpochSearchesOnlyLastEpoch(t *testing.T) { args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() ps.SetCacher(cacher) testKey1 := []byte("key1") @@ -258,7 +259,7 @@ func TestTriePruningStorer_GetFromCurrentEpochSearchesOnlyCurrentEpoch(t *testin args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() ps.SetCacher(cacher) testKey1 := []byte("key1") diff --git a/storage/storageunit/storageunit_test.go b/storage/storageunit/storageunit_test.go index da4aea63b33..f92d70a48f7 100644 --- a/storage/storageunit/storageunit_test.go +++ b/storage/storageunit/storageunit_test.go @@ -5,21 +5,22 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-storage-go/common" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/mock" "github.com/multiversx/mx-chain-go/storage/storageunit" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/multiversx/mx-chain-storage-go/common" - "github.com/stretchr/testify/assert" ) func TestNewStorageUnit(t *testing.T) { t.Parallel() - cacher := &testscommon.CacherStub{} + cacher := &cache.CacherStub{} persister := &mock.PersisterStub{} t.Run("nil cacher should error", func(t *testing.T) { diff --git a/consensus/mock/bootstrapperStub.go b/testscommon/bootstrapperStubs/bootstrapperStub.go similarity index 98% rename from consensus/mock/bootstrapperStub.go rename to testscommon/bootstrapperStubs/bootstrapperStub.go index bd4a1b98bf2..346656e1b8e 100644 --- a/consensus/mock/bootstrapperStub.go +++ b/testscommon/bootstrapperStubs/bootstrapperStub.go @@ -1,8 +1,9 @@ -package mock +package bootstrapperStubs import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" ) diff --git a/testscommon/cacherMock.go b/testscommon/cache/cacherMock.go similarity 
index 99% rename from testscommon/cacherMock.go rename to testscommon/cache/cacherMock.go index 0b1a9aa5edf..4b569d34375 100644 --- a/testscommon/cacherMock.go +++ b/testscommon/cache/cacherMock.go @@ -1,4 +1,4 @@ -package testscommon +package cache import ( "sync" diff --git a/testscommon/cacherStub.go b/testscommon/cache/cacherStub.go similarity index 99% rename from testscommon/cacherStub.go rename to testscommon/cache/cacherStub.go index e3e11dd811f..82e30610563 100644 --- a/testscommon/cacherStub.go +++ b/testscommon/cache/cacherStub.go @@ -1,4 +1,4 @@ -package testscommon +package cache // CacherStub - type CacherStub struct { diff --git a/node/mock/throttlerStub.go b/testscommon/common/throttlerStub.go similarity index 98% rename from node/mock/throttlerStub.go rename to testscommon/common/throttlerStub.go index 24ab94c45c3..f4f5e0a34d0 100644 --- a/node/mock/throttlerStub.go +++ b/testscommon/common/throttlerStub.go @@ -1,4 +1,4 @@ -package mock +package common // ThrottlerStub - type ThrottlerStub struct { diff --git a/testscommon/components/components.go b/testscommon/components/components.go index 6d33ad04fa0..6e630b9050d 100644 --- a/testscommon/components/components.go +++ b/testscommon/components/components.go @@ -8,6 +8,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" "github.com/multiversx/mx-chain-core-go/data/outport" + logger "github.com/multiversx/mx-chain-logger-go" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" commonFactory "github.com/multiversx/mx-chain-go/common/factory" "github.com/multiversx/mx-chain-go/config" @@ -41,9 +45,6 @@ import ( statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" - logger "github.com/multiversx/mx-chain-logger-go" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" - "github.com/stretchr/testify/require" ) var log = logger.GetOrCreate("componentsMock") diff --git a/testscommon/components/default.go b/testscommon/components/default.go index 514b8355407..aebd690c51a 100644 --- a/testscommon/components/default.go +++ b/testscommon/components/default.go @@ -133,7 +133,7 @@ func GetDefaultProcessComponents(shardCoordinator sharding.Coordinator) *mock.Pr BlockProcess: &testscommon.BlockProcessorStub{}, BlackListHdl: &testscommon.TimeCacheStub{}, BootSore: &mock.BootstrapStorerMock{}, - HeaderSigVerif: &mock.HeaderSigVerifierStub{}, + HeaderSigVerif: &consensus.HeaderSigVerifierMock{}, HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, ValidatorStatistics: &testscommon.ValidatorStatisticsProcessorStub{}, ValidatorProvider: &stakingcommon.ValidatorsProviderStub{}, diff --git a/consensus/mock/broadcastMessangerMock.go b/testscommon/consensus/broadcastMessangerMock.go similarity index 81% rename from consensus/mock/broadcastMessangerMock.go rename to testscommon/consensus/broadcastMessangerMock.go index 2d659490725..80b0298ada9 100644 --- a/consensus/mock/broadcastMessangerMock.go +++ b/testscommon/consensus/broadcastMessangerMock.go @@ -1,7 +1,9 @@ -package mock +package consensus import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/consensus" ) @@ -9,12 +11,14 @@ import ( type BroadcastMessengerMock struct { BroadcastBlockCalled 
func(data.BodyHandler, data.HeaderHandler) error BroadcastHeaderCalled func(data.HeaderHandler, []byte) error + BroadcastEquivalentProofCalled func(proof data.HeaderProofHandler, pkBytes []byte) error PrepareBroadcastBlockDataValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int, pkBytes []byte) error PrepareBroadcastHeaderValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int, pkBytes []byte) BroadcastMiniBlocksCalled func(map[uint32][]byte, []byte) error BroadcastTransactionsCalled func(map[string][][]byte, []byte) error BroadcastConsensusMessageCalled func(*consensus.Message) error BroadcastBlockDataLeaderCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, pkBytes []byte) error + PrepareBroadcastEquivalentProofCalled func(proof data.HeaderProofHandler, consensusIndex int, pkBytes []byte) } // BroadcastBlock - @@ -114,6 +118,25 @@ func (bmm *BroadcastMessengerMock) BroadcastHeader(headerhandler data.HeaderHand return nil } +// BroadcastEquivalentProof - +func (bmm *BroadcastMessengerMock) BroadcastEquivalentProof(proof *block.HeaderProof, pkBytes []byte) error { + if bmm.BroadcastEquivalentProofCalled != nil { + return bmm.BroadcastEquivalentProofCalled(proof, pkBytes) + } + return nil +} + +// PrepareBroadcastEquivalentProof - +func (bmm *BroadcastMessengerMock) PrepareBroadcastEquivalentProof( + proof *block.HeaderProof, + consensusIndex int, + pkBytes []byte, +) { + if bmm.PrepareBroadcastEquivalentProofCalled != nil { + bmm.PrepareBroadcastEquivalentProofCalled(proof, consensusIndex, pkBytes) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (bmm *BroadcastMessengerMock) IsInterfaceNil() bool { return bmm == nil diff --git a/consensus/mock/chronologyHandlerMock.go b/testscommon/consensus/chronologyHandlerMock.go similarity index 98% rename from consensus/mock/chronologyHandlerMock.go rename to testscommon/consensus/chronologyHandlerMock.go index 789387845de..0cfceca2eb9 100644 --- a/consensus/mock/chronologyHandlerMock.go +++ b/testscommon/consensus/chronologyHandlerMock.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "github.com/multiversx/mx-chain-go/consensus" diff --git a/consensus/mock/consensusDataContainerMock.go b/testscommon/consensus/consensusDataContainerMock.go similarity index 79% rename from consensus/mock/consensusDataContainerMock.go rename to testscommon/consensus/consensusDataContainerMock.go index 88f837b1da1..ad00574ca6b 100644 --- a/consensus/mock/consensusDataContainerMock.go +++ b/testscommon/consensus/consensusDataContainerMock.go @@ -1,9 +1,11 @@ -package mock +package consensus import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/epochStart" @@ -13,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) +// TODO: remove this mock component; implement setters for main component in export_test.go // ConsensusCoreMock - type ConsensusCoreMock struct { blockChain data.ChainHandler @@ -38,6 +41,8 @@ type ConsensusCoreMock struct { messageSigningHandler consensus.P2PSigningHandler peerBlacklistHandler consensus.PeerBlacklistHandler signingHandler consensus.SigningHandler + 
enableEpochsHandler common.EnableEpochsHandler + equivalentProofsPool consensus.EquivalentProofsPool } // GetAntiFloodHandler - @@ -120,6 +125,11 @@ func (ccm *ConsensusCoreMock) SetBlockchain(blockChain data.ChainHandler) { ccm.blockChain = blockChain } +// SetHeaderSubscriber - +func (ccm *ConsensusCoreMock) SetHeaderSubscriber(headersSubscriber consensus.HeadersPoolSubscriber) { + ccm.headersSubscriber = headersSubscriber +} + // SetBlockProcessor - func (ccm *ConsensusCoreMock) SetBlockProcessor(blockProcessor process.BlockProcessor) { ccm.blockProcessor = blockProcessor @@ -175,6 +185,31 @@ func (ccm *ConsensusCoreMock) SetValidatorGroupSelector(validatorGroupSelector n ccm.validatorGroupSelector = validatorGroupSelector } +// SetEpochStartNotifier - +func (ccm *ConsensusCoreMock) SetEpochStartNotifier(epochStartNotifier epochStart.RegistrationHandler) { + ccm.epochStartNotifier = epochStartNotifier +} + +// SetAntifloodHandler - +func (ccm *ConsensusCoreMock) SetAntifloodHandler(antifloodHandler consensus.P2PAntifloodHandler) { + ccm.antifloodHandler = antifloodHandler +} + +// SetPeerHonestyHandler - +func (ccm *ConsensusCoreMock) SetPeerHonestyHandler(peerHonestyHandler consensus.PeerHonestyHandler) { + ccm.peerHonestyHandler = peerHonestyHandler +} + +// SetScheduledProcessor - +func (ccm *ConsensusCoreMock) SetScheduledProcessor(scheduledProcessor consensus.ScheduledProcessor) { + ccm.scheduledProcessor = scheduledProcessor +} + +// SetPeerBlacklistHandler - +func (ccm *ConsensusCoreMock) SetPeerBlacklistHandler(peerBlacklistHandler consensus.PeerBlacklistHandler) { + ccm.peerBlacklistHandler = peerBlacklistHandler +} + // PeerHonestyHandler - func (ccm *ConsensusCoreMock) PeerHonestyHandler() consensus.PeerHonestyHandler { return ccm.peerHonestyHandler @@ -240,6 +275,26 @@ func (ccm *ConsensusCoreMock) SetSigningHandler(signingHandler consensus.Signing ccm.signingHandler = signingHandler } +// EnableEpochsHandler - +func (ccm *ConsensusCoreMock) EnableEpochsHandler() common.EnableEpochsHandler { + return ccm.enableEpochsHandler +} + +// SetEnableEpochsHandler - +func (ccm *ConsensusCoreMock) SetEnableEpochsHandler(enableEpochsHandler common.EnableEpochsHandler) { + ccm.enableEpochsHandler = enableEpochsHandler +} + +// EquivalentProofsPool - +func (ccm *ConsensusCoreMock) EquivalentProofsPool() consensus.EquivalentProofsPool { + return ccm.equivalentProofsPool +} + +// SetEquivalentProofsPool - +func (ccm *ConsensusCoreMock) SetEquivalentProofsPool(proofPool consensus.EquivalentProofsPool) { + ccm.equivalentProofsPool = proofPool +} + // IsInterfaceNil returns true if there is no value under the interface func (ccm *ConsensusCoreMock) IsInterfaceNil() bool { return ccm == nil diff --git a/testscommon/consensus/consensusStateMock.go b/testscommon/consensus/consensusStateMock.go new file mode 100644 index 00000000000..dae02a0323c --- /dev/null +++ b/testscommon/consensus/consensusStateMock.go @@ -0,0 +1,652 @@ +package consensus + +import ( + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" +) + +// ConsensusStateMock - +type ConsensusStateMock struct { + ResetConsensusStateCalled func() + IsNodeLeaderInCurrentRoundCalled func(node string) bool + IsSelfLeaderInCurrentRoundCalled func() bool + GetLeaderCalled func() (string, error) + GetNextConsensusGroupCalled 
func(randomSource []byte, round uint64, shardId uint32, nodesCoordinator nodesCoordinator.NodesCoordinator, epoch uint32) (string, []string, error) + IsConsensusDataSetCalled func() bool + IsConsensusDataEqualCalled func(data []byte) bool + IsJobDoneCalled func(node string, currentSubroundId int) bool + IsSelfJobDoneCalled func(currentSubroundId int) bool + IsCurrentSubroundFinishedCalled func(currentSubroundId int) bool + IsNodeSelfCalled func(node string) bool + IsBlockBodyAlreadyReceivedCalled func() bool + IsHeaderAlreadyReceivedCalled func() bool + CanDoSubroundJobCalled func(currentSubroundId int) bool + CanProcessReceivedMessageCalled func(cnsDta *consensus.Message, currentRoundIndex int64, currentSubroundId int) bool + GenerateBitmapCalled func(subroundId int) []byte + ProcessingBlockCalled func() bool + SetProcessingBlockCalled func(processingBlock bool) + ConsensusGroupSizeCalled func() int + SetThresholdCalled func(subroundId int, threshold int) + AddReceivedHeaderCalled func(headerHandler data.HeaderHandler) + GetReceivedHeadersCalled func() []data.HeaderHandler + AddMessageWithSignatureCalled func(key string, message p2p.MessageP2P) + GetMessageWithSignatureCalled func(key string) (p2p.MessageP2P, bool) + IsSubroundFinishedCalled func(subroundID int) bool + GetDataCalled func() []byte + SetDataCalled func(data []byte) + IsMultiKeyLeaderInCurrentRoundCalled func() bool + IsLeaderJobDoneCalled func(currentSubroundId int) bool + IsMultiKeyJobDoneCalled func(currentSubroundId int) bool + GetMultikeyRedundancyStepInReasonCalled func() string + ResetRoundsWithoutReceivedMessagesCalled func(pkBytes []byte, pid core.PeerID) + GetRoundCanceledCalled func() bool + SetRoundCanceledCalled func(state bool) + GetRoundIndexCalled func() int64 + SetRoundIndexCalled func(roundIndex int64) + GetRoundTimeStampCalled func() time.Time + SetRoundTimeStampCalled func(roundTimeStamp time.Time) + GetExtendedCalledCalled func() bool + GetBodyCalled func() data.BodyHandler + SetBodyCalled func(body data.BodyHandler) + GetHeaderCalled func() data.HeaderHandler + SetHeaderCalled func(header data.HeaderHandler) + GetWaitingAllSignaturesTimeOutCalled func() bool + SetWaitingAllSignaturesTimeOutCalled func(b bool) + ConsensusGroupIndexCalled func(pubKey string) (int, error) + SelfConsensusGroupIndexCalled func() (int, error) + SetEligibleListCalled func(eligibleList map[string]struct{}) + ConsensusGroupCalled func() []string + SetConsensusGroupCalled func(consensusGroup []string) + SetLeaderCalled func(leader string) + SetConsensusGroupSizeCalled func(consensusGroupSize int) + SelfPubKeyCalled func() string + SetSelfPubKeyCalled func(selfPubKey string) + JobDoneCalled func(key string, subroundId int) (bool, error) + SetJobDoneCalled func(key string, subroundId int, value bool) error + SelfJobDoneCalled func(subroundId int) (bool, error) + IsNodeInConsensusGroupCalled func(node string) bool + IsNodeInEligibleListCalled func(node string) bool + ComputeSizeCalled func(subroundId int) int + ResetRoundStateCalled func() + IsMultiKeyInConsensusGroupCalled func() bool + IsKeyManagedBySelfCalled func(pkBytes []byte) bool + IncrementRoundsWithoutReceivedMessagesCalled func(pkBytes []byte) + GetKeysHandlerCalled func() consensus.KeysHandler + LeaderCalled func() string + StatusCalled func(subroundId int) int + SetStatusCalled func(subroundId int, subroundStatus int) + ResetRoundStatusCalled func() + ThresholdCalled func(subroundId int) int + FallbackThresholdCalled func(subroundId int) int + SetFallbackThresholdCalled 
func(subroundId int, threshold int) +} + +// AddReceivedHeader - +func (cnsm *ConsensusStateMock) AddReceivedHeader(headerHandler data.HeaderHandler) { + if cnsm.AddReceivedHeaderCalled != nil { + cnsm.AddReceivedHeaderCalled(headerHandler) + } +} + +// GetReceivedHeaders - +func (cnsm *ConsensusStateMock) GetReceivedHeaders() []data.HeaderHandler { + if cnsm.GetReceivedHeadersCalled != nil { + return cnsm.GetReceivedHeadersCalled() + } + return nil +} + +// AddMessageWithSignature - +func (cnsm *ConsensusStateMock) AddMessageWithSignature(key string, message p2p.MessageP2P) { + if cnsm.AddMessageWithSignatureCalled != nil { + cnsm.AddMessageWithSignatureCalled(key, message) + } +} + +// GetMessageWithSignature - +func (cnsm *ConsensusStateMock) GetMessageWithSignature(key string) (p2p.MessageP2P, bool) { + if cnsm.GetMessageWithSignatureCalled != nil { + return cnsm.GetMessageWithSignatureCalled(key) + } + return nil, false +} + +// IsSubroundFinished - +func (cnsm *ConsensusStateMock) IsSubroundFinished(subroundID int) bool { + if cnsm.IsSubroundFinishedCalled != nil { + return cnsm.IsSubroundFinishedCalled(subroundID) + } + return false +} + +// GetData - +func (cnsm *ConsensusStateMock) GetData() []byte { + if cnsm.GetDataCalled != nil { + return cnsm.GetDataCalled() + } + return nil +} + +// SetData - +func (cnsm *ConsensusStateMock) SetData(data []byte) { + if cnsm.SetDataCalled != nil { + cnsm.SetDataCalled(data) + } +} + +// IsMultiKeyLeaderInCurrentRound - +func (cnsm *ConsensusStateMock) IsMultiKeyLeaderInCurrentRound() bool { + if cnsm.IsMultiKeyLeaderInCurrentRoundCalled != nil { + return cnsm.IsMultiKeyLeaderInCurrentRoundCalled() + } + return false +} + +// IsLeaderJobDone - +func (cnsm *ConsensusStateMock) IsLeaderJobDone(currentSubroundId int) bool { + if cnsm.IsLeaderJobDoneCalled != nil { + return cnsm.IsLeaderJobDoneCalled(currentSubroundId) + } + return false +} + +// IsMultiKeyJobDone - +func (cnsm *ConsensusStateMock) IsMultiKeyJobDone(currentSubroundId int) bool { + if cnsm.IsMultiKeyJobDoneCalled != nil { + return cnsm.IsMultiKeyJobDoneCalled(currentSubroundId) + } + return false +} + +// GetMultikeyRedundancyStepInReason - +func (cnsm *ConsensusStateMock) GetMultikeyRedundancyStepInReason() string { + if cnsm.GetMultikeyRedundancyStepInReasonCalled != nil { + return cnsm.GetMultikeyRedundancyStepInReasonCalled() + } + return "" +} + +// ResetRoundsWithoutReceivedMessages - +func (cnsm *ConsensusStateMock) ResetRoundsWithoutReceivedMessages(pkBytes []byte, pid core.PeerID) { + if cnsm.ResetRoundsWithoutReceivedMessagesCalled != nil { + cnsm.ResetRoundsWithoutReceivedMessagesCalled(pkBytes, pid) + } +} + +// GetRoundCanceled - +func (cnsm *ConsensusStateMock) GetRoundCanceled() bool { + if cnsm.GetRoundCanceledCalled != nil { + return cnsm.GetRoundCanceledCalled() + } + return false +} + +// SetRoundCanceled - +func (cnsm *ConsensusStateMock) SetRoundCanceled(state bool) { + if cnsm.SetRoundCanceledCalled != nil { + cnsm.SetRoundCanceledCalled(state) + } +} + +// GetRoundIndex - +func (cnsm *ConsensusStateMock) GetRoundIndex() int64 { + if cnsm.GetRoundIndexCalled != nil { + return cnsm.GetRoundIndexCalled() + } + return 0 +} + +// SetRoundIndex - +func (cnsm *ConsensusStateMock) SetRoundIndex(roundIndex int64) { + if cnsm.SetRoundIndexCalled != nil { + cnsm.SetRoundIndexCalled(roundIndex) + } +} + +// GetRoundTimeStamp - +func (cnsm *ConsensusStateMock) GetRoundTimeStamp() time.Time { + if cnsm.GetRoundTimeStampCalled != nil { + return cnsm.GetRoundTimeStampCalled() 
+ } + return time.Time{} +} + +// SetRoundTimeStamp - +func (cnsm *ConsensusStateMock) SetRoundTimeStamp(roundTimeStamp time.Time) { + if cnsm.SetRoundTimeStampCalled != nil { + cnsm.SetRoundTimeStampCalled(roundTimeStamp) + } +} + +// GetExtendedCalled - +func (cnsm *ConsensusStateMock) GetExtendedCalled() bool { + if cnsm.GetExtendedCalledCalled != nil { + return cnsm.GetExtendedCalledCalled() + } + return false +} + +// GetBody - +func (cnsm *ConsensusStateMock) GetBody() data.BodyHandler { + if cnsm.GetBodyCalled != nil { + return cnsm.GetBodyCalled() + } + return nil +} + +// SetBody - +func (cnsm *ConsensusStateMock) SetBody(body data.BodyHandler) { + if cnsm.SetBodyCalled != nil { + cnsm.SetBodyCalled(body) + } +} + +// GetHeader - +func (cnsm *ConsensusStateMock) GetHeader() data.HeaderHandler { + if cnsm.GetHeaderCalled != nil { + return cnsm.GetHeaderCalled() + } + return nil +} + +// SetHeader - +func (cnsm *ConsensusStateMock) SetHeader(header data.HeaderHandler) { + if cnsm.SetHeaderCalled != nil { + cnsm.SetHeaderCalled(header) + } +} + +// GetWaitingAllSignaturesTimeOut - +func (cnsm *ConsensusStateMock) GetWaitingAllSignaturesTimeOut() bool { + if cnsm.GetWaitingAllSignaturesTimeOutCalled != nil { + return cnsm.GetWaitingAllSignaturesTimeOutCalled() + } + return false +} + +// SetWaitingAllSignaturesTimeOut - +func (cnsm *ConsensusStateMock) SetWaitingAllSignaturesTimeOut(b bool) { + if cnsm.SetWaitingAllSignaturesTimeOutCalled != nil { + cnsm.SetWaitingAllSignaturesTimeOutCalled(b) + } +} + +// ConsensusGroupIndex - +func (cnsm *ConsensusStateMock) ConsensusGroupIndex(pubKey string) (int, error) { + if cnsm.ConsensusGroupIndexCalled != nil { + return cnsm.ConsensusGroupIndexCalled(pubKey) + } + return 0, nil +} + +// SelfConsensusGroupIndex - +func (cnsm *ConsensusStateMock) SelfConsensusGroupIndex() (int, error) { + if cnsm.SelfConsensusGroupIndexCalled != nil { + return cnsm.SelfConsensusGroupIndexCalled() + } + return 0, nil +} + +// SetEligibleList - +func (cnsm *ConsensusStateMock) SetEligibleList(eligibleList map[string]struct{}) { + if cnsm.SetEligibleListCalled != nil { + cnsm.SetEligibleListCalled(eligibleList) + } +} + +// ConsensusGroup - +func (cnsm *ConsensusStateMock) ConsensusGroup() []string { + if cnsm.ConsensusGroupCalled != nil { + return cnsm.ConsensusGroupCalled() + } + return nil +} + +// SetConsensusGroup - +func (cnsm *ConsensusStateMock) SetConsensusGroup(consensusGroup []string) { + if cnsm.SetConsensusGroupCalled != nil { + cnsm.SetConsensusGroupCalled(consensusGroup) + } +} + +// SetLeader - +func (cnsm *ConsensusStateMock) SetLeader(leader string) { + if cnsm.SetLeaderCalled != nil { + cnsm.SetLeaderCalled(leader) + } +} + +// SetConsensusGroupSize - +func (cnsm *ConsensusStateMock) SetConsensusGroupSize(consensusGroupSize int) { + if cnsm.SetConsensusGroupSizeCalled != nil { + cnsm.SetConsensusGroupSizeCalled(consensusGroupSize) + } +} + +// SelfPubKey - +func (cnsm *ConsensusStateMock) SelfPubKey() string { + if cnsm.SelfPubKeyCalled != nil { + return cnsm.SelfPubKeyCalled() + } + return "" +} + +// SetSelfPubKey - +func (cnsm *ConsensusStateMock) SetSelfPubKey(selfPubKey string) { + if cnsm.SetSelfPubKeyCalled != nil { + cnsm.SetSelfPubKeyCalled(selfPubKey) + } +} + +// JobDone - +func (cnsm *ConsensusStateMock) JobDone(key string, subroundId int) (bool, error) { + if cnsm.JobDoneCalled != nil { + return cnsm.JobDoneCalled(key, subroundId) + } + return false, nil +} + +// SetJobDone - +func (cnsm *ConsensusStateMock) SetJobDone(key string, 
subroundId int, value bool) error { + if cnsm.SetJobDoneCalled != nil { + return cnsm.SetJobDoneCalled(key, subroundId, value) + } + return nil +} + +// SelfJobDone - +func (cnsm *ConsensusStateMock) SelfJobDone(subroundId int) (bool, error) { + if cnsm.SelfJobDoneCalled != nil { + return cnsm.SelfJobDoneCalled(subroundId) + } + return false, nil +} + +// IsNodeInConsensusGroup - +func (cnsm *ConsensusStateMock) IsNodeInConsensusGroup(node string) bool { + if cnsm.IsNodeInConsensusGroupCalled != nil { + return cnsm.IsNodeInConsensusGroupCalled(node) + } + return false +} + +// IsNodeInEligibleList - +func (cnsm *ConsensusStateMock) IsNodeInEligibleList(node string) bool { + if cnsm.IsNodeInEligibleListCalled != nil { + return cnsm.IsNodeInEligibleListCalled(node) + } + return false +} + +// ComputeSize - +func (cnsm *ConsensusStateMock) ComputeSize(subroundId int) int { + if cnsm.ComputeSizeCalled != nil { + return cnsm.ComputeSizeCalled(subroundId) + } + return 0 +} + +// ResetRoundState - +func (cnsm *ConsensusStateMock) ResetRoundState() { + if cnsm.ResetRoundStateCalled != nil { + cnsm.ResetRoundStateCalled() + } +} + +// IsMultiKeyInConsensusGroup - +func (cnsm *ConsensusStateMock) IsMultiKeyInConsensusGroup() bool { + if cnsm.IsMultiKeyInConsensusGroupCalled != nil { + return cnsm.IsMultiKeyInConsensusGroupCalled() + } + return false +} + +// IsKeyManagedBySelf - +func (cnsm *ConsensusStateMock) IsKeyManagedBySelf(pkBytes []byte) bool { + if cnsm.IsKeyManagedBySelfCalled != nil { + return cnsm.IsKeyManagedBySelfCalled(pkBytes) + } + return false +} + +// IncrementRoundsWithoutReceivedMessages - +func (cnsm *ConsensusStateMock) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { + if cnsm.IncrementRoundsWithoutReceivedMessagesCalled != nil { + cnsm.IncrementRoundsWithoutReceivedMessagesCalled(pkBytes) + } +} + +// GetKeysHandler - +func (cnsm *ConsensusStateMock) GetKeysHandler() consensus.KeysHandler { + if cnsm.GetKeysHandlerCalled != nil { + return cnsm.GetKeysHandlerCalled() + } + return nil +} + +// Leader - +func (cnsm *ConsensusStateMock) Leader() string { + if cnsm.LeaderCalled != nil { + return cnsm.LeaderCalled() + } + return "" +} + +// Status - +func (cnsm *ConsensusStateMock) Status(subroundId int) int { + if cnsm.StatusCalled != nil { + return cnsm.StatusCalled(subroundId) + } + return 0 +} + +// SetStatus - +func (cnsm *ConsensusStateMock) SetStatus(subroundId int, subroundStatus int) { + if cnsm.SetStatusCalled != nil { + cnsm.SetStatusCalled(subroundId, subroundStatus) + } +} + +// ResetRoundStatus - +func (cnsm *ConsensusStateMock) ResetRoundStatus() { + if cnsm.ResetRoundStatusCalled != nil { + cnsm.ResetRoundStatusCalled() + } +} + +// Threshold - +func (cnsm *ConsensusStateMock) Threshold(subroundId int) int { + if cnsm.ThresholdCalled != nil { + return cnsm.ThresholdCalled(subroundId) + } + return 0 +} + +// FallbackThreshold - +func (cnsm *ConsensusStateMock) FallbackThreshold(subroundId int) int { + if cnsm.FallbackThresholdCalled != nil { + return cnsm.FallbackThresholdCalled(subroundId) + } + return 0 +} + +func (cnsm *ConsensusStateMock) SetFallbackThreshold(subroundId int, threshold int) { + if cnsm.SetFallbackThresholdCalled != nil { + cnsm.SetFallbackThresholdCalled(subroundId, threshold) + } +} + +// ResetConsensusState - +func (cnsm *ConsensusStateMock) ResetConsensusState() { + if cnsm.ResetConsensusStateCalled != nil { + cnsm.ResetConsensusStateCalled() + } +} + +// IsNodeLeaderInCurrentRound - +func (cnsm *ConsensusStateMock) 
IsNodeLeaderInCurrentRound(node string) bool { + if cnsm.IsNodeLeaderInCurrentRoundCalled != nil { + return cnsm.IsNodeLeaderInCurrentRoundCalled(node) + } + return false +} + +// IsSelfLeaderInCurrentRound - +func (cnsm *ConsensusStateMock) IsSelfLeaderInCurrentRound() bool { + if cnsm.IsSelfLeaderInCurrentRoundCalled != nil { + return cnsm.IsSelfLeaderInCurrentRoundCalled() + } + return false +} + +// GetLeader - +func (cnsm *ConsensusStateMock) GetLeader() (string, error) { + if cnsm.GetLeaderCalled != nil { + return cnsm.GetLeaderCalled() + } + return "", nil +} + +// GetNextConsensusGroup - +func (cnsm *ConsensusStateMock) GetNextConsensusGroup( + randomSource []byte, + round uint64, + shardId uint32, + nodesCoordinator nodesCoordinator.NodesCoordinator, + epoch uint32, +) (string, []string, error) { + if cnsm.GetNextConsensusGroupCalled != nil { + return cnsm.GetNextConsensusGroupCalled(randomSource, round, shardId, nodesCoordinator, epoch) + } + return "", nil, nil +} + +// IsConsensusDataSet - +func (cnsm *ConsensusStateMock) IsConsensusDataSet() bool { + if cnsm.IsConsensusDataSetCalled != nil { + return cnsm.IsConsensusDataSetCalled() + } + return false +} + +// IsConsensusDataEqual - +func (cnsm *ConsensusStateMock) IsConsensusDataEqual(data []byte) bool { + if cnsm.IsConsensusDataEqualCalled != nil { + return cnsm.IsConsensusDataEqualCalled(data) + } + return false +} + +// IsJobDone - +func (cnsm *ConsensusStateMock) IsJobDone(node string, currentSubroundId int) bool { + if cnsm.IsJobDoneCalled != nil { + return cnsm.IsJobDoneCalled(node, currentSubroundId) + } + return false +} + +// IsSelfJobDone - +func (cnsm *ConsensusStateMock) IsSelfJobDone(currentSubroundId int) bool { + if cnsm.IsSelfJobDoneCalled != nil { + return cnsm.IsSelfJobDoneCalled(currentSubroundId) + } + return false +} + +// IsCurrentSubroundFinished - +func (cnsm *ConsensusStateMock) IsCurrentSubroundFinished(currentSubroundId int) bool { + if cnsm.IsCurrentSubroundFinishedCalled != nil { + return cnsm.IsCurrentSubroundFinishedCalled(currentSubroundId) + } + return false +} + +// IsNodeSelf - +func (cnsm *ConsensusStateMock) IsNodeSelf(node string) bool { + if cnsm.IsNodeSelfCalled != nil { + return cnsm.IsNodeSelfCalled(node) + } + return false +} + +// IsBlockBodyAlreadyReceived - +func (cnsm *ConsensusStateMock) IsBlockBodyAlreadyReceived() bool { + if cnsm.IsBlockBodyAlreadyReceivedCalled != nil { + return cnsm.IsBlockBodyAlreadyReceivedCalled() + } + return false +} + +// IsHeaderAlreadyReceived - +func (cnsm *ConsensusStateMock) IsHeaderAlreadyReceived() bool { + if cnsm.IsHeaderAlreadyReceivedCalled != nil { + return cnsm.IsHeaderAlreadyReceivedCalled() + } + return false +} + +// CanDoSubroundJob - +func (cnsm *ConsensusStateMock) CanDoSubroundJob(currentSubroundId int) bool { + if cnsm.CanDoSubroundJobCalled != nil { + return cnsm.CanDoSubroundJobCalled(currentSubroundId) + } + return false +} + +// CanProcessReceivedMessage - +func (cnsm *ConsensusStateMock) CanProcessReceivedMessage( + cnsDta *consensus.Message, + currentRoundIndex int64, + currentSubroundId int, +) bool { + return cnsm.CanProcessReceivedMessageCalled(cnsDta, currentRoundIndex, currentSubroundId) +} + +// GenerateBitmap - +func (cnsm *ConsensusStateMock) GenerateBitmap(subroundId int) []byte { + if cnsm.GenerateBitmapCalled != nil { + return cnsm.GenerateBitmapCalled(subroundId) + } + return nil +} + +// ProcessingBlock - +func (cnsm *ConsensusStateMock) ProcessingBlock() bool { + if cnsm.ProcessingBlockCalled != nil { + return 
cnsm.ProcessingBlockCalled() + } + return false +} + +// SetProcessingBlock - +func (cnsm *ConsensusStateMock) SetProcessingBlock(processingBlock bool) { + if cnsm.SetProcessingBlockCalled != nil { + cnsm.SetProcessingBlockCalled(processingBlock) + } +} + +// ConsensusGroupSize - +func (cnsm *ConsensusStateMock) ConsensusGroupSize() int { + if cnsm.ConsensusGroupSizeCalled != nil { + return cnsm.ConsensusGroupSizeCalled() + } + return 0 +} + +// SetThreshold - +func (cnsm *ConsensusStateMock) SetThreshold(subroundId int, threshold int) { + if cnsm.SetThresholdCalled != nil { + cnsm.SetThresholdCalled(subroundId, threshold) + } +} + +// IsInterfaceNil returns true if there is no value under the interface +func (cnsm *ConsensusStateMock) IsInterfaceNil() bool { + return cnsm == nil +} diff --git a/testscommon/consensus/delayedBroadcasterMock.go b/testscommon/consensus/delayedBroadcasterMock.go new file mode 100644 index 00000000000..1c0aba7aee0 --- /dev/null +++ b/testscommon/consensus/delayedBroadcasterMock.go @@ -0,0 +1,88 @@ +package consensus + +import ( + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + + "github.com/multiversx/mx-chain-go/consensus" + + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" +) + +// DelayedBroadcasterMock - +type DelayedBroadcasterMock struct { + SetLeaderDataCalled func(data *shared.DelayedBroadcastData) error + SetValidatorDataCalled func(data *shared.DelayedBroadcastData) error + SetHeaderForValidatorCalled func(vData *shared.ValidatorHeaderBroadcastData) error + SetBroadcastHandlersCalled func( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + equivalentProofsBroadcast func(proof *block.HeaderProof, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error) error + CloseCalled func() + SetFinalProofForValidatorCalled func(proof *block.HeaderProof, consensusIndex int, pkBytes []byte) error +} + +// SetFinalProofForValidator - +func (mock *DelayedBroadcasterMock) SetFinalProofForValidator(proof *block.HeaderProof, consensusIndex int, pkBytes []byte) error { + if mock.SetFinalProofForValidatorCalled != nil { + return mock.SetFinalProofForValidatorCalled(proof, consensusIndex, pkBytes) + } + return nil +} + +// SetLeaderData - +func (mock *DelayedBroadcasterMock) SetLeaderData(data *shared.DelayedBroadcastData) error { + if mock.SetLeaderDataCalled != nil { + return mock.SetLeaderDataCalled(data) + } + return nil +} + +// SetValidatorData - +func (mock *DelayedBroadcasterMock) SetValidatorData(data *shared.DelayedBroadcastData) error { + if mock.SetValidatorDataCalled != nil { + return mock.SetValidatorDataCalled(data) + } + return nil +} + +// SetHeaderForValidator - +func (mock *DelayedBroadcasterMock) SetHeaderForValidator(vData *shared.ValidatorHeaderBroadcastData) error { + if mock.SetHeaderForValidatorCalled != nil { + return mock.SetHeaderForValidatorCalled(vData) + } + return nil +} + +// SetBroadcastHandlers - +func (mock *DelayedBroadcasterMock) SetBroadcastHandlers( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + equivalentProofBroadcast func(proof *block.HeaderProof, pkBytes []byte) error, + consensusMessageBroadcast 
func(message *consensus.Message) error, +) error { + if mock.SetBroadcastHandlersCalled != nil { + return mock.SetBroadcastHandlersCalled( + mbBroadcast, + txBroadcast, + headerBroadcast, + equivalentProofBroadcast, + consensusMessageBroadcast) + } + return nil +} + +// Close - +func (mock *DelayedBroadcasterMock) Close() { + if mock.CloseCalled != nil { + mock.CloseCalled() + } +} + +// IsInterfaceNil returns true if there is no value under the interface +func (mock *DelayedBroadcasterMock) IsInterfaceNil() bool { + return mock == nil +} diff --git a/consensus/mock/hasherStub.go b/testscommon/consensus/hasherStub.go similarity index 97% rename from consensus/mock/hasherStub.go rename to testscommon/consensus/hasherStub.go index f05c2fd2cc8..05bea1aaa6d 100644 --- a/consensus/mock/hasherStub.go +++ b/testscommon/consensus/hasherStub.go @@ -1,4 +1,4 @@ -package mock +package consensus // HasherStub - type HasherStub struct { diff --git a/testscommon/consensus/headerSigVerifierStub.go b/testscommon/consensus/headerSigVerifierStub.go new file mode 100644 index 00000000000..d6f1004e9fd --- /dev/null +++ b/testscommon/consensus/headerSigVerifierStub.go @@ -0,0 +1,82 @@ +package consensus + +import "github.com/multiversx/mx-chain-core-go/data" + +// HeaderSigVerifierMock - +type HeaderSigVerifierMock struct { + VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error + VerifySignatureCalled func(header data.HeaderHandler) error + VerifyRandSeedCalled func(header data.HeaderHandler) error + VerifyLeaderSignatureCalled func(header data.HeaderHandler) error + VerifySignatureForHashCalled func(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderWithProofCalled func(header data.HeaderHandler) error + VerifyHeaderProofCalled func(proofHandler data.HeaderProofHandler) error +} + +// VerifyRandSeed - +func (mock *HeaderSigVerifierMock) VerifyRandSeed(header data.HeaderHandler) error { + if mock.VerifyRandSeedCalled != nil { + return mock.VerifyRandSeedCalled(header) + } + + return nil +} + +// VerifyRandSeedAndLeaderSignature - +func (mock *HeaderSigVerifierMock) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { + if mock.VerifyRandSeedAndLeaderSignatureCalled != nil { + return mock.VerifyRandSeedAndLeaderSignatureCalled(header) + } + + return nil +} + +// VerifySignature - +func (mock *HeaderSigVerifierMock) VerifySignature(header data.HeaderHandler) error { + if mock.VerifySignatureCalled != nil { + return mock.VerifySignatureCalled(header) + } + + return nil +} + +// VerifyLeaderSignature - +func (mock *HeaderSigVerifierMock) VerifyLeaderSignature(header data.HeaderHandler) error { + if mock.VerifyLeaderSignatureCalled != nil { + return mock.VerifyLeaderSignatureCalled(header) + } + + return nil +} + +// VerifySignatureForHash - +func (mock *HeaderSigVerifierMock) VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error { + if mock.VerifySignatureForHashCalled != nil { + return mock.VerifySignatureForHashCalled(header, hash, pubkeysBitmap, signature) + } + + return nil +} + +// VerifyHeaderWithProof - +func (mock *HeaderSigVerifierMock) VerifyHeaderWithProof(header data.HeaderHandler) error { + if mock.VerifyHeaderWithProofCalled != nil { + return mock.VerifyHeaderWithProofCalled(header) + } + + return nil +} + +// VerifyHeaderProof - +func (mock *HeaderSigVerifierMock) VerifyHeaderProof(proofHandler data.HeaderProofHandler) error { + if 
mock.VerifyHeaderProofCalled != nil { + return mock.VerifyHeaderProofCalled(proofHandler) + } + + return nil +} + +// IsInterfaceNil - +func (mock *HeaderSigVerifierMock) IsInterfaceNil() bool { + return mock == nil +} diff --git a/testscommon/consensus/initializers/initializers.go b/testscommon/consensus/initializers/initializers.go new file mode 100644 index 00000000000..187c8f02892 --- /dev/null +++ b/testscommon/consensus/initializers/initializers.go @@ -0,0 +1,156 @@ +package initializers + +import ( + crypto "github.com/multiversx/mx-chain-crypto-go" + "golang.org/x/exp/slices" + + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" +) + +func createEligibleList(size int) []string { + eligibleList := make([]string, 0) + for i := 0; i < size; i++ { + eligibleList = append(eligibleList, string([]byte{byte(i + 65)})) + } + return eligibleList +} + +// CreateEligibleListFromMap creates a list of eligible nodes from a map of private keys +func CreateEligibleListFromMap(mapKeys map[string]crypto.PrivateKey) []string { + eligibleList := make([]string, 0, len(mapKeys)) + for key := range mapKeys { + eligibleList = append(eligibleList, key) + } + slices.Sort(eligibleList) + return eligibleList +} + +// InitConsensusStateWithNodesCoordinator creates a consensus state with a nodes coordinator +func InitConsensusStateWithNodesCoordinator(validatorsGroupSelector nodesCoordinator.NodesCoordinator) *spos.ConsensusState { + return initConsensusStateWithKeysHandlerAndNodesCoordinator(&testscommon.KeysHandlerStub{}, validatorsGroupSelector) +} + +// InitConsensusState creates a consensus state +func InitConsensusState() *spos.ConsensusState { + return InitConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) +} + +// InitConsensusStateWithArgs creates a consensus state with the given arguments +func InitConsensusStateWithArgs(keysHandler consensus.KeysHandler, mapKeys map[string]crypto.PrivateKey) *spos.ConsensusState { + return initConsensusStateWithKeysHandlerWithGroupSizeWithRealKeys(keysHandler, mapKeys) +} + +// InitConsensusStateWithKeysHandler creates a consensus state with a keys handler +func InitConsensusStateWithKeysHandler(keysHandler consensus.KeysHandler) *spos.ConsensusState { + consensusGroupSize := 9 + return initConsensusStateWithKeysHandlerWithGroupSize(keysHandler, consensusGroupSize) +} + +func initConsensusStateWithKeysHandlerAndNodesCoordinator(keysHandler consensus.KeysHandler, validatorsGroupSelector nodesCoordinator.NodesCoordinator) *spos.ConsensusState { + leader, consensusValidators, _ := validatorsGroupSelector.GetConsensusValidatorsPublicKeys([]byte("randomness"), 0, 0, 0) + eligibleNodesPubKeys := make(map[string]struct{}) + for _, key := range consensusValidators { + eligibleNodesPubKeys[key] = struct{}{} + } + return createConsensusStateWithNodes(eligibleNodesPubKeys, consensusValidators, leader, keysHandler) +} + +// InitConsensusStateWithArgsVerifySignature creates a consensus state with the given arguments for signature verification +func InitConsensusStateWithArgsVerifySignature(keysHandler consensus.KeysHandler, keys []string) *spos.ConsensusState { + numberOfKeys := len(keys) + eligibleNodesPubKeys := make(map[string]struct{}, numberOfKeys) + for _, key := range keys { + eligibleNodesPubKeys[key] = struct{}{} + } + + indexLeader := 1 + rcns, _ := spos.NewRoundConsensus( + eligibleNodesPubKeys,
numberOfKeys, + keys[indexLeader], + keysHandler, + ) + rcns.SetConsensusGroup(keys) + rcns.ResetRoundState() + + pBFTThreshold := numberOfKeys*2/3 + 1 + pBFTFallbackThreshold := numberOfKeys*1/2 + 1 + rthr := spos.NewRoundThreshold() + rthr.SetThreshold(1, 1) + rthr.SetThreshold(2, pBFTThreshold) + rthr.SetFallbackThreshold(1, 1) + rthr.SetFallbackThreshold(2, pBFTFallbackThreshold) + + rstatus := spos.NewRoundStatus() + rstatus.ResetRoundStatus() + cns := spos.NewConsensusState( + rcns, + rthr, + rstatus, + ) + cns.Data = []byte("X") + cns.SetRoundIndex(0) + + return cns +} + +func initConsensusStateWithKeysHandlerWithGroupSize(keysHandler consensus.KeysHandler, consensusGroupSize int) *spos.ConsensusState { + eligibleList := createEligibleList(consensusGroupSize) + + eligibleNodesPubKeys := make(map[string]struct{}) + for _, key := range eligibleList { + eligibleNodesPubKeys[key] = struct{}{} + } + + return createConsensusStateWithNodes(eligibleNodesPubKeys, eligibleList, eligibleList[0], keysHandler) +} + +func initConsensusStateWithKeysHandlerWithGroupSizeWithRealKeys(keysHandler consensus.KeysHandler, mapKeys map[string]crypto.PrivateKey) *spos.ConsensusState { + eligibleList := CreateEligibleListFromMap(mapKeys) + + eligibleNodesPubKeys := make(map[string]struct{}, len(eligibleList)) + for _, key := range eligibleList { + eligibleNodesPubKeys[key] = struct{}{} + } + + return createConsensusStateWithNodes(eligibleNodesPubKeys, eligibleList, eligibleList[0], keysHandler) +} + +func createConsensusStateWithNodes(eligibleNodesPubKeys map[string]struct{}, consensusValidators []string, leader string, keysHandler consensus.KeysHandler) *spos.ConsensusState { + consensusGroupSize := len(consensusValidators) + rcns, _ := spos.NewRoundConsensus( + eligibleNodesPubKeys, + consensusGroupSize, + consensusValidators[1], + keysHandler, + ) + + rcns.SetConsensusGroup(consensusValidators) + rcns.SetLeader(leader) + rcns.ResetRoundState() + + pBFTThreshold := consensusGroupSize*2/3 + 1 + pBFTFallbackThreshold := consensusGroupSize*1/2 + 1 + + rthr := spos.NewRoundThreshold() + rthr.SetThreshold(1, 1) + rthr.SetThreshold(2, pBFTThreshold) + rthr.SetFallbackThreshold(1, 1) + rthr.SetFallbackThreshold(2, pBFTFallbackThreshold) + + rstatus := spos.NewRoundStatus() + rstatus.ResetRoundStatus() + + cns := spos.NewConsensusState( + rcns, + rthr, + rstatus, + ) + + cns.Data = []byte("X") + cns.SetRoundIndex(0) + return cns +} diff --git a/consensus/mock/mockTestInitializer.go b/testscommon/consensus/mockTestInitializer.go similarity index 81% rename from consensus/mock/mockTestInitializer.go rename to testscommon/consensus/mockTestInitializer.go index 6fa62a5a49d..4cdd7174618 100644 --- a/consensus/mock/mockTestInitializer.go +++ b/testscommon/consensus/mockTestInitializer.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "time" @@ -7,12 +7,18 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" - consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + 
"github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + epochstartmock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/pool" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" ) @@ -119,14 +125,14 @@ func InitMultiSignerMock() *cryptoMocks.MultisignerMock { } // InitKeys - -func InitKeys() (*KeyGenMock, *PrivateKeyMock, *PublicKeyMock) { +func InitKeys() (*mock.KeyGenMock, *mock.PrivateKeyMock, *mock.PublicKeyMock) { toByteArrayMock := func() ([]byte, error) { return []byte("byteArray"), nil } - privKeyMock := &PrivateKeyMock{ + privKeyMock := &mock.PrivateKeyMock{ ToByteArrayMock: toByteArrayMock, } - pubKeyMock := &PublicKeyMock{ + pubKeyMock := &mock.PublicKeyMock{ ToByteArrayMock: toByteArrayMock, } privKeyFromByteArr := func(b []byte) (crypto.PrivateKey, error) { @@ -135,7 +141,7 @@ func InitKeys() (*KeyGenMock, *PrivateKeyMock, *PublicKeyMock) { pubKeyFromByteArr := func(b []byte) (crypto.PublicKey, error) { return pubKeyMock, nil } - keyGenMock := &KeyGenMock{ + keyGenMock := &mock.KeyGenMock{ PrivateKeyFromByteArrayMock: privKeyFromByteArr, PublicKeyFromByteArrayMock: pubKeyFromByteArr, } @@ -161,12 +167,14 @@ func InitConsensusCore() *ConsensusCoreMock { func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *ConsensusCoreMock { blockChain := &testscommon.ChainHandlerStub{ GetGenesisHeaderCalled: func() data.HeaderHandler { - return &block.Header{} + return &block.Header{ + RandSeed: []byte("randSeed"), + } }, } - marshalizerMock := MarshalizerMock{} + marshalizerMock := mock.MarshalizerMock{} blockProcessorMock := InitBlockProcessorMock(marshalizerMock) - bootstrapperMock := &BootstrapperStub{} + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} broadcastMessengerMock := &BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { return nil @@ -176,13 +184,14 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus chronologyHandlerMock := InitChronologyHandlerMock() hasherMock := &hashingMocks.HasherMock{} roundHandlerMock := &RoundHandlerMock{} - shardCoordinatorMock := ShardCoordinatorMock{} + shardCoordinatorMock := mock.ShardCoordinatorMock{} syncTimerMock := &SyncTimerMock{} validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]nodesCoordinator.Validator, error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { defaultSelectionChances := uint32(1) - return []nodesCoordinator.Validator{ - shardingMocks.NewValidatorMock([]byte("A"), 1, defaultSelectionChances), + leader := shardingMocks.NewValidatorMock([]byte("A"), 1, defaultSelectionChances) + return leader, []nodesCoordinator.Validator{ + leader, shardingMocks.NewValidatorMock([]byte("B"), 1, defaultSelectionChances), shardingMocks.NewValidatorMock([]byte("C"), 1, defaultSelectionChances), shardingMocks.NewValidatorMock([]byte("D"), 1, defaultSelectionChances), @@ -194,18 +203,20 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus }, nil }, } - epochStartSubscriber := &EpochStartNotifierStub{} - antifloodHandler := &P2PAntifloodHandlerStub{} - headerPoolSubscriber := &HeadersCacherStub{} + epochStartSubscriber := 
&epochstartmock.EpochStartNotifierStub{} + antifloodHandler := &mock.P2PAntifloodHandlerStub{} + headerPoolSubscriber := &pool.HeadersPoolStub{} peerHonestyHandler := &testscommon.PeerHonestyHandlerStub{} - headerSigVerifier := &HeaderSigVerifierStub{} + headerSigVerifier := &HeaderSigVerifierMock{} fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{} - nodeRedundancyHandler := &NodeRedundancyHandlerStub{} - scheduledProcessor := &consensusMocks.ScheduledProcessorStub{} - messageSigningHandler := &MessageSigningHandlerStub{} - peerBlacklistHandler := &PeerBlacklistHandlerStub{} + nodeRedundancyHandler := &mock.NodeRedundancyHandlerStub{} + scheduledProcessor := &ScheduledProcessorStub{} + messageSigningHandler := &mock.MessageSigningHandlerStub{} + peerBlacklistHandler := &mock.PeerBlacklistHandlerStub{} multiSignerContainer := cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signingHandler := &consensusMocks.SigningHandlerStub{} + signingHandler := &SigningHandlerStub{} + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} + equivalentProofsPool := &dataRetriever.ProofsPoolMock{} container := &ConsensusCoreMock{ blockChain: blockChain, @@ -231,6 +242,8 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus messageSigningHandler: messageSigningHandler, peerBlacklistHandler: peerBlacklistHandler, signingHandler: signingHandler, + enableEpochsHandler: enableEpochsHandler, + equivalentProofsPool: equivalentProofsPool, } return container diff --git a/consensus/mock/rounderMock.go b/testscommon/consensus/rounderMock.go similarity index 98% rename from consensus/mock/rounderMock.go rename to testscommon/consensus/rounderMock.go index 6a0625932a1..bb463f38c33 100644 --- a/consensus/mock/rounderMock.go +++ b/testscommon/consensus/rounderMock.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "time" diff --git a/consensus/mock/sposWorkerMock.go b/testscommon/consensus/sposWorkerMock.go similarity index 68% rename from consensus/mock/sposWorkerMock.go rename to testscommon/consensus/sposWorkerMock.go index 0454370bedf..3a7e1ef384b 100644 --- a/consensus/mock/sposWorkerMock.go +++ b/testscommon/consensus/sposWorkerMock.go @@ -1,10 +1,11 @@ -package mock +package consensus import ( "context" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/p2p" ) @@ -16,6 +17,7 @@ type SposWorkerMock struct { receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool, ) AddReceivedHeaderHandlerCalled func(handler func(data.HeaderHandler)) + AddReceivedProofHandlerCalled func(handler func(proofHandler consensus.ProofHandler)) RemoveAllReceivedMessagesCallsCalled func() ProcessReceivedMessageCalled func(message p2p.MessageP2P) error SendConsensusMessageCalled func(cnsDta *consensus.Message) bool @@ -28,12 +30,15 @@ type SposWorkerMock struct { ReceivedHeaderCalled func(headerHandler data.HeaderHandler, headerHash []byte) SetAppStatusHandlerCalled func(ash core.AppStatusHandler) error ResetConsensusMessagesCalled func() + ReceivedProofCalled func(proofHandler consensus.ProofHandler) } // AddReceivedMessageCall - func (sposWorkerMock *SposWorkerMock) AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) { - sposWorkerMock.AddReceivedMessageCallCalled(messageType, receivedMessageCall) + if 
sposWorkerMock.AddReceivedMessageCallCalled != nil { + sposWorkerMock.AddReceivedMessageCallCalled(messageType, receivedMessageCall) + } } // AddReceivedHeaderHandler - @@ -43,39 +48,64 @@ func (sposWorkerMock *SposWorkerMock) AddReceivedHeaderHandler(handler func(data } } +func (sposWorkerMock *SposWorkerMock) AddReceivedProofHandler(handler func(proofHandler consensus.ProofHandler)) { + if sposWorkerMock.AddReceivedProofHandlerCalled != nil { + sposWorkerMock.AddReceivedProofHandlerCalled(handler) + } +} + // RemoveAllReceivedMessagesCalls - func (sposWorkerMock *SposWorkerMock) RemoveAllReceivedMessagesCalls() { - sposWorkerMock.RemoveAllReceivedMessagesCallsCalled() + if sposWorkerMock.RemoveAllReceivedMessagesCallsCalled != nil { + sposWorkerMock.RemoveAllReceivedMessagesCallsCalled() + } } // ProcessReceivedMessage - func (sposWorkerMock *SposWorkerMock) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { - return sposWorkerMock.ProcessReceivedMessageCalled(message) + if sposWorkerMock.ProcessReceivedMessageCalled != nil { + return sposWorkerMock.ProcessReceivedMessageCalled(message) + } + return nil } // SendConsensusMessage - func (sposWorkerMock *SposWorkerMock) SendConsensusMessage(cnsDta *consensus.Message) bool { - return sposWorkerMock.SendConsensusMessageCalled(cnsDta) + if sposWorkerMock.SendConsensusMessageCalled != nil { + return sposWorkerMock.SendConsensusMessageCalled(cnsDta) + } + return false } // Extend - func (sposWorkerMock *SposWorkerMock) Extend(subroundId int) { - sposWorkerMock.ExtendCalled(subroundId) + if sposWorkerMock.ExtendCalled != nil { + sposWorkerMock.ExtendCalled(subroundId) + } } // GetConsensusStateChangedChannel - func (sposWorkerMock *SposWorkerMock) GetConsensusStateChangedChannel() chan bool { - return sposWorkerMock.GetConsensusStateChangedChannelsCalled() + if sposWorkerMock.GetConsensusStateChangedChannelsCalled != nil { + return sposWorkerMock.GetConsensusStateChangedChannelsCalled() + } + + return nil } // BroadcastBlock - func (sposWorkerMock *SposWorkerMock) BroadcastBlock(body data.BodyHandler, header data.HeaderHandler) error { - return sposWorkerMock.GetBroadcastBlockCalled(body, header) + if sposWorkerMock.GetBroadcastBlockCalled != nil { + return sposWorkerMock.GetBroadcastBlockCalled(body, header) + } + return nil } // ExecuteStoredMessages - func (sposWorkerMock *SposWorkerMock) ExecuteStoredMessages() { - sposWorkerMock.ExecuteStoredMessagesCalled() + if sposWorkerMock.ExecuteStoredMessagesCalled != nil { + sposWorkerMock.ExecuteStoredMessagesCalled() + } } // DisplayStatistics - @@ -108,6 +138,13 @@ func (sposWorkerMock *SposWorkerMock) ResetConsensusMessages() { } } +// ReceivedProof - +func (sposWorkerMock *SposWorkerMock) ReceivedProof(proofHandler consensus.ProofHandler) { + if sposWorkerMock.ReceivedProofCalled != nil { + sposWorkerMock.ReceivedProofCalled(proofHandler) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (sposWorkerMock *SposWorkerMock) IsInterfaceNil() bool { return sposWorkerMock == nil diff --git a/consensus/mock/syncTimerMock.go b/testscommon/consensus/syncTimerMock.go similarity index 98% rename from consensus/mock/syncTimerMock.go rename to testscommon/consensus/syncTimerMock.go index 2fa41d42341..32b92bbe33b 100644 --- a/consensus/mock/syncTimerMock.go +++ b/testscommon/consensus/syncTimerMock.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "time" diff --git a/testscommon/dataRetriever/poolFactory.go
b/testscommon/dataRetriever/poolFactory.go index a8f4374e800..43aaeb3e78f 100644 --- a/testscommon/dataRetriever/poolFactory.go +++ b/testscommon/dataRetriever/poolFactory.go @@ -6,10 +6,12 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/headersCache" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/dataRetriever/shardedData" "github.com/multiversx/mx-chain-go/dataRetriever/txpool" "github.com/multiversx/mx-chain-go/storage/cache" @@ -49,8 +51,7 @@ func CreateTxPool(numShards uint32, selfShard uint32) (dataRetriever.ShardedData ) } -// CreatePoolsHolder - -func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHolder { +func createPoolHolderArgs(numShards uint32, selfShard uint32) dataPool.DataPoolArgs { var err error txPool, err := CreateTxPool(numShards, selfShard) @@ -137,6 +138,8 @@ func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHo }) panicIfError("CreatePoolsHolder", err) + proofsPool := proofscache.NewProofsPool() + currentBlockTransactions := dataPool.NewCurrentBlockTransactionsPool() currentEpochValidatorInfo := dataPool.NewCurrentEpochValidatorInfoPool() dataPoolArgs := dataPool.DataPoolArgs{ @@ -154,13 +157,37 @@ func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHo PeerAuthentications: peerAuthPool, Heartbeats: heartbeatPool, ValidatorsInfo: validatorsInfo, + Proofs: proofsPool, } + + return dataPoolArgs +} + +// CreatePoolsHolder - +func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHolder { + + dataPoolArgs := createPoolHolderArgs(numShards, selfShard) + holder, err := dataPool.NewDataPool(dataPoolArgs) panicIfError("CreatePoolsHolder", err) return holder } +// CreatePoolsHolderWithProofsPool - +func CreatePoolsHolderWithProofsPool( + numShards uint32, selfShard uint32, + proofsPool dataRetriever.ProofsPool, +) dataRetriever.PoolsHolder { + dataPoolArgs := createPoolHolderArgs(numShards, selfShard) + dataPoolArgs.Proofs = proofsPool + + holder, err := dataPool.NewDataPool(dataPoolArgs) + panicIfError("CreatePoolsHolderWithProofsPool", err) + + return holder +} + // CreatePoolsHolderWithTxPool - func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier) dataRetriever.PoolsHolder { var err error @@ -221,6 +248,8 @@ func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier) heartbeatPool, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolderWithTxPool", err) + proofsPool := proofscache.NewProofsPool() + currentBlockTransactions := dataPool.NewCurrentBlockTransactionsPool() currentEpochValidatorInfo := dataPool.NewCurrentEpochValidatorInfoPool() dataPoolArgs := dataPool.DataPoolArgs{ @@ -238,6 +267,7 @@ func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier) PeerAuthentications: peerAuthPool, Heartbeats: heartbeatPool, ValidatorsInfo: validatorsInfo, + Proofs: proofsPool, } holder, err := dataPool.NewDataPool(dataPoolArgs) panicIfError("CreatePoolsHolderWithTxPool", err) diff --git a/testscommon/dataRetriever/poolsHolderMock.go b/testscommon/dataRetriever/poolsHolderMock.go index d3d30562954..7e5cd64f5a4 100644 --- a/testscommon/dataRetriever/poolsHolderMock.go +++ 
b/testscommon/dataRetriever/poolsHolderMock.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/headersCache" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/dataRetriever/shardedData" "github.com/multiversx/mx-chain-go/dataRetriever/txpool" "github.com/multiversx/mx-chain-go/storage" @@ -33,6 +34,7 @@ type PoolsHolderMock struct { peerAuthentications storage.Cacher heartbeats storage.Cacher validatorsInfo dataRetriever.ShardedDataCacherNotifier + proofs dataRetriever.ProofsPool } // NewPoolsHolderMock - @@ -110,6 +112,8 @@ func NewPoolsHolderMock() *PoolsHolderMock { }) panicIfError("NewPoolsHolderMock", err) + holder.proofs = proofscache.NewProofsPool() + return holder } @@ -198,6 +202,11 @@ func (holder *PoolsHolderMock) ValidatorsInfo() dataRetriever.ShardedDataCacherN return holder.validatorsInfo } +// Proofs - +func (holder *PoolsHolderMock) Proofs() dataRetriever.ProofsPool { + return holder.proofs +} + // Close - func (holder *PoolsHolderMock) Close() error { var lastError error diff --git a/testscommon/dataRetriever/poolsHolderStub.go b/testscommon/dataRetriever/poolsHolderStub.go index 106c8b96bb5..7d9051d6f10 100644 --- a/testscommon/dataRetriever/poolsHolderStub.go +++ b/testscommon/dataRetriever/poolsHolderStub.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" ) // PoolsHolderStub - @@ -23,6 +24,7 @@ type PoolsHolderStub struct { PeerAuthenticationsCalled func() storage.Cacher HeartbeatsCalled func() storage.Cacher ValidatorsInfoCalled func() dataRetriever.ShardedDataCacherNotifier + ProofsCalled func() dataRetriever.ProofsPool CloseCalled func() error } @@ -73,7 +75,7 @@ func (holder *PoolsHolderStub) MiniBlocks() storage.Cacher { return holder.MiniBlocksCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // MetaBlocks - @@ -82,7 +84,7 @@ func (holder *PoolsHolderStub) MetaBlocks() storage.Cacher { return holder.MetaBlocksCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // CurrentBlockTxs - @@ -109,7 +111,7 @@ func (holder *PoolsHolderStub) TrieNodes() storage.Cacher { return holder.TrieNodesCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // TrieNodesChunks - @@ -118,7 +120,7 @@ func (holder *PoolsHolderStub) TrieNodesChunks() storage.Cacher { return holder.TrieNodesChunksCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // PeerChangesBlocks - @@ -127,7 +129,7 @@ func (holder *PoolsHolderStub) PeerChangesBlocks() storage.Cacher { return holder.PeerChangesBlocksCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // SmartContracts - @@ -136,7 +138,7 @@ func (holder *PoolsHolderStub) SmartContracts() storage.Cacher { return holder.SmartContractsCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // PeerAuthentications - @@ -145,7 +147,7 @@ func (holder *PoolsHolderStub) PeerAuthentications() storage.Cacher { return holder.PeerAuthenticationsCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // Heartbeats - @@ -154,7 +156,7 @@ func (holder *PoolsHolderStub) Heartbeats() storage.Cacher { 
return holder.HeartbeatsCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // ValidatorsInfo - @@ -166,6 +168,15 @@ func (holder *PoolsHolderStub) ValidatorsInfo() dataRetriever.ShardedDataCacherN return testscommon.NewShardedDataStub() } +// Proofs - +func (holder *PoolsHolderStub) Proofs() dataRetriever.ProofsPool { + if holder.ProofsCalled != nil { + return holder.ProofsCalled() + } + + return nil +} + // Close - func (holder *PoolsHolderStub) Close() error { if holder.CloseCalled != nil { diff --git a/testscommon/dataRetriever/proofsPoolMock.go b/testscommon/dataRetriever/proofsPoolMock.go new file mode 100644 index 00000000000..8154659a134 --- /dev/null +++ b/testscommon/dataRetriever/proofsPoolMock.go @@ -0,0 +1,55 @@ +package dataRetriever + +import ( + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" +) + +// ProofsPoolMock - +type ProofsPoolMock struct { + AddProofCalled func(headerProof data.HeaderProofHandler) error + CleanupProofsBehindNonceCalled func(shardID uint32, nonce uint64) error + GetProofCalled func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + HasProofCalled func(shardID uint32, headerHash []byte) bool +} + +// AddProof - +func (p *ProofsPoolMock) AddProof(headerProof data.HeaderProofHandler) error { + if p.AddProofCalled != nil { + return p.AddProofCalled(headerProof) + } + + return nil +} + +// CleanupProofsBehindNonce - +func (p *ProofsPoolMock) CleanupProofsBehindNonce(shardID uint32, nonce uint64) error { + if p.CleanupProofsBehindNonceCalled != nil { + return p.CleanupProofsBehindNonceCalled(shardID, nonce) + } + + return nil +} + +// GetProof - +func (p *ProofsPoolMock) GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + if p.GetProofCalled != nil { + return p.GetProofCalled(shardID, headerHash) + } + + return &block.HeaderProof{}, nil +} + +// HasProof - +func (p *ProofsPoolMock) HasProof(shardID uint32, headerHash []byte) bool { + if p.HasProofCalled != nil { + return p.HasProofCalled(shardID, headerHash) + } + + return false +} + +// IsInterfaceNil - +func (p *ProofsPoolMock) IsInterfaceNil() bool { + return p == nil +} diff --git a/testscommon/epochstartmock/epochStartNotifierStub.go b/testscommon/epochstartmock/epochStartNotifierStub.go index d8a7bdceea3..2072ad30b5a 100644 --- a/testscommon/epochstartmock/epochStartNotifierStub.go +++ b/testscommon/epochstartmock/epochStartNotifierStub.go @@ -1,7 +1,8 @@ -package epochstartmock +package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) @@ -9,8 +10,9 @@ import ( type EpochStartNotifierStub struct { RegisterHandlerCalled func(handler epochStart.ActionHandler) UnregisterHandlerCalled func(handler epochStart.ActionHandler) - NotifyAllPrepareCalled func(hdr data.HeaderHandler, body data.BodyHandler, validatorInfoCacher epochStart.ValidatorInfoCacher) NotifyAllCalled func(hdr data.HeaderHandler) + NotifyAllPrepareCalled func(hdr data.HeaderHandler, body data.BodyHandler) + epochStartHdls []epochStart.ActionHandler } // RegisterHandler - @@ -18,6 +20,8 @@ func (esnm *EpochStartNotifierStub) RegisterHandler(handler epochStart.ActionHan if esnm.RegisterHandlerCalled != nil { esnm.RegisterHandlerCalled(handler) } + + esnm.epochStartHdls = append(esnm.epochStartHdls, handler) } // UnregisterHandler - @@ -25,12 +29,23 @@ func (esnm *EpochStartNotifierStub) UnregisterHandler(handler epochStart.ActionH if 
esnm.UnregisterHandlerCalled != nil { esnm.UnregisterHandlerCalled(handler) } + + for i, hdl := range esnm.epochStartHdls { + if hdl == handler { + esnm.epochStartHdls = append(esnm.epochStartHdls[:i], esnm.epochStartHdls[i+1:]...) + break + } + } } // NotifyAllPrepare - -func (esnm *EpochStartNotifierStub) NotifyAllPrepare(metaHdr data.HeaderHandler, body data.BodyHandler, validatorInfoCacher epochStart.ValidatorInfoCacher) { +func (esnm *EpochStartNotifierStub) NotifyAllPrepare(metaHdr data.HeaderHandler, body data.BodyHandler) { if esnm.NotifyAllPrepareCalled != nil { - esnm.NotifyAllPrepareCalled(metaHdr, body, validatorInfoCacher) + esnm.NotifyAllPrepareCalled(metaHdr, body) + } + + for _, hdl := range esnm.epochStartHdls { + hdl.EpochStartPrepare(metaHdr, body) } } @@ -39,6 +54,10 @@ func (esnm *EpochStartNotifierStub) NotifyAll(hdr data.HeaderHandler) { if esnm.NotifyAllCalled != nil { esnm.NotifyAllCalled(hdr) } + + for _, hdl := range esnm.epochStartHdls { + hdl.EpochStartAction(hdr) + } } // IsInterfaceNil - diff --git a/testscommon/fallbackHeaderValidatorStub.go b/testscommon/fallbackHeaderValidatorStub.go index b769aa94976..2ba582c7118 100644 --- a/testscommon/fallbackHeaderValidatorStub.go +++ b/testscommon/fallbackHeaderValidatorStub.go @@ -6,7 +6,16 @@ import ( // FallBackHeaderValidatorStub - type FallBackHeaderValidatorStub struct { - ShouldApplyFallbackValidationCalled func(headerHandler data.HeaderHandler) bool + ShouldApplyFallbackValidationCalled func(headerHandler data.HeaderHandler) bool + ShouldApplyFallbackValidationForHeaderWithCalled func(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) bool +} + +// ShouldApplyFallbackValidationForHeaderWith - +func (fhvs *FallBackHeaderValidatorStub) ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) bool { + if fhvs.ShouldApplyFallbackValidationForHeaderWithCalled != nil { + return fhvs.ShouldApplyFallbackValidationForHeaderWithCalled(shardID, startOfEpochBlock, round, prevHeaderHash) + } + return false } // ShouldApplyFallbackValidation - diff --git a/testscommon/generalConfig.go b/testscommon/generalConfig.go index 515c64518b4..f5777cfae6b 100644 --- a/testscommon/generalConfig.go +++ b/testscommon/generalConfig.go @@ -441,6 +441,10 @@ func GetGeneralConfig() config.Config { ResourceStats: config.ResourceStatsConfig{ RefreshIntervalInSec: 1, }, + InterceptedDataVerifier: config.InterceptedDataVerifierConfig{ + CacheSpanInSec: 1, + CacheExpiryInSec: 1, + }, } } diff --git a/testscommon/headerHandlerStub.go b/testscommon/headerHandlerStub.go index ab1d354ec60..00613c26d4d 100644 --- a/testscommon/headerHandlerStub.go +++ b/testscommon/headerHandlerStub.go @@ -38,6 +38,8 @@ type HeaderHandlerStub struct { SetRandSeedCalled func(seed []byte) error SetSignatureCalled func(signature []byte) error SetLeaderSignatureCalled func(signature []byte) error + GetPreviousProofCalled func() data.HeaderProofHandler + SetPreviousProofCalled func(proof data.HeaderProofHandler) } // GetAccumulatedFees - @@ -427,3 +429,19 @@ func (hhs *HeaderHandlerStub) SetBlockBodyTypeInt32(blockBodyType int32) error { return nil } + +// GetPreviousProof - +func (hhs *HeaderHandlerStub) GetPreviousProof() data.HeaderProofHandler { + if hhs.GetPreviousProofCalled != nil { + return hhs.GetPreviousProofCalled() + } + + return nil +} + +// SetPreviousProof - +func (hhs *HeaderHandlerStub) SetPreviousProof(proof data.HeaderProofHandler) { + if hhs.SetPreviousProofCalled 
!= nil { + hhs.SetPreviousProofCalled(proof) + } +} diff --git a/testscommon/outport/outportStub.go b/testscommon/outport/outportStub.go index e9cd2649d3e..c6a2996036b 100644 --- a/testscommon/outport/outportStub.go +++ b/testscommon/outport/outportStub.go @@ -11,6 +11,7 @@ type OutportStub struct { SaveValidatorsRatingCalled func(validatorsRating *outportcore.ValidatorsRating) SaveValidatorsPubKeysCalled func(validatorsPubKeys *outportcore.ValidatorsPubKeys) HasDriversCalled func() bool + SaveRoundsInfoCalled func(roundsInfo *outportcore.RoundsInfo) } // SaveBlock - @@ -65,7 +66,10 @@ func (as *OutportStub) Close() error { } // SaveRoundsInfo - -func (as *OutportStub) SaveRoundsInfo(_ *outportcore.RoundsInfo) { +func (as *OutportStub) SaveRoundsInfo(roundsInfo *outportcore.RoundsInfo) { + if as.SaveRoundsInfoCalled != nil { + as.SaveRoundsInfoCalled(roundsInfo) + } } diff --git a/testscommon/shardedDataCacheNotifierMock.go b/testscommon/shardedDataCacheNotifierMock.go index d5af2000ab3..f6043415b08 100644 --- a/testscommon/shardedDataCacheNotifierMock.go +++ b/testscommon/shardedDataCacheNotifierMock.go @@ -4,7 +4,9 @@ import ( "sync" "github.com/multiversx/mx-chain-core-go/core/counting" + "github.com/multiversx/mx-chain-go/storage" + cacheMocks "github.com/multiversx/mx-chain-go/testscommon/cache" ) // ShardedDataCacheNotifierMock - @@ -31,7 +33,7 @@ func (mock *ShardedDataCacheNotifierMock) ShardDataStore(cacheId string) (c stor cache, found := mock.caches[cacheId] if !found { - cache = NewCacherMock() + cache = cacheMocks.NewCacherMock() mock.caches[cacheId] = cache } diff --git a/testscommon/shardingMocks/nodesCoordinatorMock.go b/testscommon/shardingMocks/nodesCoordinatorMock.go index 0343546364f..c7de88a268e 100644 --- a/testscommon/shardingMocks/nodesCoordinatorMock.go +++ b/testscommon/shardingMocks/nodesCoordinatorMock.go @@ -12,25 +12,26 @@ import ( // NodesCoordinatorMock defines the behaviour of a struct able to do validator group selection type NodesCoordinatorMock struct { - Validators map[uint32][]nodesCoordinator.Validator - ShardConsensusSize uint32 - MetaConsensusSize uint32 - ShardId uint32 - NbShards uint32 - GetSelectedPublicKeysCalled func(selection []byte, shardId uint32, epoch uint32) (publicKeys []string, err error) - GetValidatorsPublicKeysCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) - GetValidatorsRewardsAddressesCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) - SetNodesPerShardsCalled func(nodes map[uint32][]nodesCoordinator.Validator, epoch uint32) error - ComputeValidatorsGroupCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) - GetValidatorWithPublicKeyCalled func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) - GetAllEligibleValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) - GetAllWaitingValidatorsPublicKeysCalled func() (map[uint32][][]byte, error) - ConsensusGroupSizeCalled func(uint32, uint32) int - GetValidatorsIndexesCalled func(publicKeys []string, epoch uint32) ([]uint64, error) - GetConsensusWhitelistedNodesCalled func(epoch uint32) (map[string]struct{}, error) - GetAllShuffledOutValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) + Validators map[uint32][]nodesCoordinator.Validator + ShardConsensusSize uint32 + MetaConsensusSize uint32 + ShardId uint32 + NbShards uint32 + 
GetSelectedPublicKeysCalled func(selection []byte, shardId uint32, epoch uint32) (publicKeys []string, err error) + GetValidatorsPublicKeysCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) + GetValidatorsRewardsAddressesCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) + SetNodesPerShardsCalled func(nodes map[uint32][]nodesCoordinator.Validator, epoch uint32) error + ComputeValidatorsGroupCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) + GetValidatorWithPublicKeyCalled func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) + GetAllEligibleValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysForShardCalled func(epoch uint32, shardID uint32) ([]string, error) + GetAllWaitingValidatorsPublicKeysCalled func() (map[uint32][][]byte, error) + ConsensusGroupSizeCalled func(uint32, uint32) int + GetValidatorsIndexesCalled func(publicKeys []string, epoch uint32) ([]uint64, error) + GetConsensusWhitelistedNodesCalled func(epoch uint32) (map[string]struct{}, error) + GetAllShuffledOutValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) GetShuffledOutToAuctionValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) - GetNumTotalEligibleCalled func() uint64 + GetNumTotalEligibleCalled func() uint64 } // NewNodesCoordinatorMock - @@ -97,6 +98,14 @@ func (ncm *NodesCoordinatorMock) GetAllEligibleValidatorsPublicKeys(epoch uint32 return nil, nil } +// GetAllEligibleValidatorsPublicKeysForShard - +func (ncm *NodesCoordinatorMock) GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) { + if ncm.GetAllEligibleValidatorsPublicKeysForShardCalled != nil { + return ncm.GetAllEligibleValidatorsPublicKeysForShardCalled(epoch, shardID) + } + return nil, nil +} + // GetAllWaitingValidatorsPublicKeys - func (ncm *NodesCoordinatorMock) GetAllWaitingValidatorsPublicKeys(_ uint32) (map[uint32][][]byte, error) { if ncm.GetAllWaitingValidatorsPublicKeysCalled != nil { @@ -156,14 +165,14 @@ func (ncm *NodesCoordinatorMock) GetConsensusValidatorsPublicKeys( round uint64, shardId uint32, epoch uint32, -) ([]string, error) { +) (string, []string, error) { if ncm.GetValidatorsPublicKeysCalled != nil { return ncm.GetValidatorsPublicKeysCalled(randomness, round, shardId, epoch) } - validators, err := ncm.ComputeConsensusGroup(randomness, round, shardId, epoch) + leader, validators, err := ncm.ComputeConsensusGroup(randomness, round, shardId, epoch) if err != nil { - return nil, err + return "", nil, err } valGrStr := make([]string, 0) @@ -172,7 +181,7 @@ func (ncm *NodesCoordinatorMock) GetConsensusValidatorsPublicKeys( valGrStr = append(valGrStr, string(v.PubKey())) } - return valGrStr, nil + return string(leader.PubKey()), valGrStr, nil } // SetNodesPerShards - @@ -205,7 +214,7 @@ func (ncm *NodesCoordinatorMock) ComputeConsensusGroup( round uint64, shardId uint32, epoch uint32, -) ([]nodesCoordinator.Validator, error) { +) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { var consensusSize uint32 if ncm.ComputeValidatorsGroupCalled != nil { @@ -219,7 +228,7 @@ func (ncm *NodesCoordinatorMock) ComputeConsensusGroup( } if randomess == nil { - return nil, nodesCoordinator.ErrNilRandomness + return nil, nil, 
nodesCoordinator.ErrNilRandomness } validatorsGroup := make([]nodesCoordinator.Validator, 0) @@ -228,7 +237,7 @@ func (ncm *NodesCoordinatorMock) ComputeConsensusGroup( validatorsGroup = append(validatorsGroup, ncm.Validators[shardId][i]) } - return validatorsGroup, nil + return validatorsGroup[0], validatorsGroup, nil } // ConsensusGroupSizeForShardAndEpoch - diff --git a/testscommon/shardingMocks/nodesCoordinatorMocks/randomSelectorMock.go b/testscommon/shardingMocks/nodesCoordinatorMocks/randomSelectorMock.go new file mode 100644 index 00000000000..13c74dad98d --- /dev/null +++ b/testscommon/shardingMocks/nodesCoordinatorMocks/randomSelectorMock.go @@ -0,0 +1,19 @@ +package nodesCoordinatorMocks + +// RandomSelectorMock is a mock for the RandomSelector interface +type RandomSelectorMock struct { + SelectCalled func(randSeed []byte, sampleSize uint32) ([]uint32, error) +} + +// Select calls the mocked method +func (rsm *RandomSelectorMock) Select(randSeed []byte, sampleSize uint32) ([]uint32, error) { + if rsm.SelectCalled != nil { + return rsm.SelectCalled(randSeed, sampleSize) + } + return nil, nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (rsm *RandomSelectorMock) IsInterfaceNil() bool { + return rsm == nil +} diff --git a/testscommon/shardingMocks/nodesCoordinatorStub.go b/testscommon/shardingMocks/nodesCoordinatorStub.go index 4694676a9b0..9da3f317064 100644 --- a/testscommon/shardingMocks/nodesCoordinatorStub.go +++ b/testscommon/shardingMocks/nodesCoordinatorStub.go @@ -9,19 +9,22 @@ import ( // NodesCoordinatorStub - type NodesCoordinatorStub struct { - GetValidatorsPublicKeysCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) - GetValidatorsRewardsAddressesCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) - GetValidatorWithPublicKeyCalled func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) - GetAllValidatorsPublicKeysCalled func() (map[uint32][][]byte, error) - GetAllWaitingValidatorsPublicKeysCalled func(_ uint32) (map[uint32][][]byte, error) - GetAllEligibleValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) - ConsensusGroupSizeCalled func(shardID uint32, epoch uint32) int - ComputeConsensusGroupCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) - EpochStartPrepareCalled func(metaHdr data.HeaderHandler, body data.BodyHandler) - GetConsensusWhitelistedNodesCalled func(epoch uint32) (map[string]struct{}, error) - GetOwnPublicKeyCalled func() []byte - GetWaitingEpochsLeftForPublicKeyCalled func(publicKey []byte) (uint32, error) - GetNumTotalEligibleCalled func() uint64 + GetValidatorsPublicKeysCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) + GetValidatorsRewardsAddressesCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) + GetValidatorWithPublicKeyCalled func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) + GetAllValidatorsPublicKeysCalled func() (map[uint32][][]byte, error) + GetAllWaitingValidatorsPublicKeysCalled func(_ uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysForShardCalled func(epoch uint32, shardID uint32) ([]string, error) + GetValidatorsIndexesCalled 
func(pubKeys []string, epoch uint32) ([]uint64, error) + ConsensusGroupSizeCalled func(shardID uint32, epoch uint32) int + ComputeConsensusGroupCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) + EpochStartPrepareCalled func(metaHdr data.HeaderHandler, body data.BodyHandler) + GetConsensusWhitelistedNodesCalled func(epoch uint32) (map[string]struct{}, error) + GetOwnPublicKeyCalled func() []byte + GetWaitingEpochsLeftForPublicKeyCalled func(publicKey []byte) (uint32, error) + GetNumTotalEligibleCalled func() uint64 + ShardIdForEpochCalled func(epoch uint32) (uint32, error) } // NodesCoordinatorToRegistry - @@ -69,6 +72,14 @@ func (ncm *NodesCoordinatorStub) GetAllEligibleValidatorsPublicKeys(epoch uint32 return nil, nil } +// GetAllEligibleValidatorsPublicKeysForShard - +func (ncm *NodesCoordinatorStub) GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) { + if ncm.GetAllEligibleValidatorsPublicKeysForShardCalled != nil { + return ncm.GetAllEligibleValidatorsPublicKeysForShardCalled(epoch, shardID) + } + return nil, nil +} + // GetAllWaitingValidatorsPublicKeys - func (ncm *NodesCoordinatorStub) GetAllWaitingValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) { if ncm.GetAllWaitingValidatorsPublicKeysCalled != nil { @@ -106,7 +117,10 @@ func (ncm *NodesCoordinatorStub) GetAllValidatorsPublicKeys(_ uint32) (map[uint3 } // GetValidatorsIndexes - -func (ncm *NodesCoordinatorStub) GetValidatorsIndexes(_ []string, _ uint32) ([]uint64, error) { +func (ncm *NodesCoordinatorStub) GetValidatorsIndexes(pubkeys []string, epoch uint32) ([]uint64, error) { + if ncm.GetValidatorsIndexesCalled != nil { + return ncm.GetValidatorsIndexesCalled(pubkeys, epoch) + } return nil, nil } @@ -116,14 +130,12 @@ func (ncm *NodesCoordinatorStub) ComputeConsensusGroup( round uint64, shardId uint32, epoch uint32, -) (validatorsGroup []nodesCoordinator.Validator, err error) { +) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { if ncm.ComputeConsensusGroupCalled != nil { return ncm.ComputeConsensusGroupCalled(randomness, round, shardId, epoch) } - var list []nodesCoordinator.Validator - - return list, nil + return nil, nil, nil } // ConsensusGroupSizeForShardAndEpoch - @@ -140,12 +152,12 @@ func (ncm *NodesCoordinatorStub) GetConsensusValidatorsPublicKeys( round uint64, shardId uint32, epoch uint32, -) ([]string, error) { +) (string, []string, error) { if ncm.GetValidatorsPublicKeysCalled != nil { return ncm.GetValidatorsPublicKeysCalled(randomness, round, shardId, epoch) } - return nil, nil + return "", nil, nil } // SetNodesPerShards - @@ -165,8 +177,12 @@ func (ncm *NodesCoordinatorStub) GetSavedStateKey() []byte { // ShardIdForEpoch returns the nodesCoordinator configured ShardId for specified epoch if epoch configuration exists, // otherwise error -func (ncm *NodesCoordinatorStub) ShardIdForEpoch(_ uint32) (uint32, error) { - panic("not implemented") +func (ncm *NodesCoordinatorStub) ShardIdForEpoch(epoch uint32) (uint32, error) { + + if ncm.ShardIdForEpochCalled != nil { + return ncm.ShardIdForEpochCalled(epoch) + } + return 0, nil } // ShuffleOutForEpoch verifies if the shards changed in the new epoch and calls the shuffleOutHandler diff --git a/trie/sync_test.go b/trie/sync_test.go index ab5083eb85a..7d6c26b3ba5 100644 --- a/trie/sync_test.go +++ b/trie/sync_test.go @@ -10,14 +10,16 @@ import ( 
"github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie/statistics" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockArgument(timeout time.Duration) ArgTrieSyncer { @@ -32,7 +34,7 @@ func createMockArgument(timeout time.Duration) ArgTrieSyncer { return ArgTrieSyncer{ RequestHandler: &testscommon.RequestHandlerStub{}, - InterceptedNodes: testscommon.NewCacherMock(), + InterceptedNodes: cache.NewCacherMock(), DB: trieStorage, Hasher: &hashingMocks.HasherMock{}, Marshalizer: &marshallerMock.MarshalizerMock{}, diff --git a/update/factory/exportHandlerFactory.go b/update/factory/exportHandlerFactory.go index c13f25f3f5a..0cda7a5d2e0 100644 --- a/update/factory/exportHandlerFactory.go +++ b/update/factory/exportHandlerFactory.go @@ -8,6 +8,8 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -30,7 +32,6 @@ import ( "github.com/multiversx/mx-chain-go/update/genesis" "github.com/multiversx/mx-chain-go/update/storing" "github.com/multiversx/mx-chain-go/update/sync" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("update/factory") @@ -69,6 +70,7 @@ type ArgsExporter struct { TrieSyncerVersion int CheckNodesOnDisk bool NodeOperationMode common.NodeOperation + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } type exportHandlerFactory struct { @@ -108,6 +110,7 @@ type exportHandlerFactory struct { trieSyncerVersion int checkNodesOnDisk bool nodeOperationMode common.NodeOperation + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewExportHandlerFactory creates an exporter factory @@ -266,6 +269,7 @@ func NewExportHandlerFactory(args ArgsExporter) (*exportHandlerFactory, error) { checkNodesOnDisk: args.CheckNodesOnDisk, statusCoreComponents: args.StatusCoreComponents, nodeOperationMode: args.NodeOperationMode, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } return e, nil @@ -588,6 +592,7 @@ func (e *exportHandlerFactory) createInterceptors() error { FullArchiveInterceptorsContainer: e.fullArchiveInterceptorsContainer, AntifloodHandler: e.networkComponents.InputAntiFloodHandler(), NodeOperationMode: e.nodeOperationMode, + InterceptedDataVerifierFactory: e.interceptedDataVerifierFactory, } fullSyncInterceptors, err := NewFullSyncInterceptorsContainerFactory(argsInterceptors) if err != nil { diff --git a/update/factory/fullSyncInterceptors.go b/update/factory/fullSyncInterceptors.go index 0fe0298c4d6..ad8602e1e5b 100644 --- a/update/factory/fullSyncInterceptors.go +++ b/update/factory/fullSyncInterceptors.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/marshal" + 
"github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -29,25 +30,26 @@ const numGoRoutines = 2000 // fullSyncInterceptorsContainerFactory will handle the creation the interceptors container for shards type fullSyncInterceptorsContainerFactory struct { - mainContainer process.InterceptorsContainer - fullArchiveContainer process.InterceptorsContainer - shardCoordinator sharding.Coordinator - accounts state.AccountsAdapter - store dataRetriever.StorageService - dataPool dataRetriever.PoolsHolder - mainMessenger process.TopicHandler - fullArchiveMessenger process.TopicHandler - nodesCoordinator nodesCoordinator.NodesCoordinator - blockBlackList process.TimeCacher - argInterceptorFactory *interceptorFactory.ArgInterceptedDataFactory - globalThrottler process.InterceptorThrottler - maxTxNonceDeltaAllowed int - addressPubkeyConv core.PubkeyConverter - whiteListHandler update.WhiteListHandler - whiteListerVerifiedTxs update.WhiteListHandler - antifloodHandler process.P2PAntifloodHandler - preferredPeersHolder update.PreferredPeersHolderHandler - nodeOperationMode common.NodeOperation + mainContainer process.InterceptorsContainer + fullArchiveContainer process.InterceptorsContainer + shardCoordinator sharding.Coordinator + accounts state.AccountsAdapter + store dataRetriever.StorageService + dataPool dataRetriever.PoolsHolder + mainMessenger process.TopicHandler + fullArchiveMessenger process.TopicHandler + nodesCoordinator nodesCoordinator.NodesCoordinator + blockBlackList process.TimeCacher + argInterceptorFactory *interceptorFactory.ArgInterceptedDataFactory + globalThrottler process.InterceptorThrottler + maxTxNonceDeltaAllowed int + addressPubkeyConv core.PubkeyConverter + whiteListHandler update.WhiteListHandler + whiteListerVerifiedTxs update.WhiteListHandler + antifloodHandler process.P2PAntifloodHandler + preferredPeersHolder update.PreferredPeersHolderHandler + nodeOperationMode common.NodeOperation + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // ArgsNewFullSyncInterceptorsContainerFactory holds the arguments needed for fullSyncInterceptorsContainerFactory @@ -75,6 +77,7 @@ type ArgsNewFullSyncInterceptorsContainerFactory struct { FullArchiveInterceptorsContainer process.InterceptorsContainer AntifloodHandler process.P2PAntifloodHandler NodeOperationMode common.NodeOperation + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewFullSyncInterceptorsContainerFactory is responsible for creating a new interceptors factory object @@ -132,6 +135,9 @@ func NewFullSyncInterceptorsContainerFactory( if check.IfNil(args.AntifloodHandler) { return nil, process.ErrNilAntifloodHandler } + if check.IfNil(args.InterceptedDataVerifierFactory) { + return nil, process.ErrNilInterceptedDataVerifierFactory + } argInterceptorFactory := &interceptorFactory.ArgInterceptedDataFactory{ CoreComponents: args.CoreComponents, @@ -164,8 +170,9 @@ func NewFullSyncInterceptorsContainerFactory( whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, antifloodHandler: args.AntifloodHandler, //TODO: inject the real peers holder once we have the peers mapping before epoch bootstrap finishes - preferredPeersHolder: disabled.NewPreferredPeersHolder(), - nodeOperationMode: args.NodeOperationMode, + preferredPeersHolder: disabled.NewPreferredPeersHolder(), + nodeOperationMode: args.NodeOperationMode, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } icf.globalThrottler, 
err = throttler.NewNumGoRoutinesThrottler(numGoRoutines) @@ -343,21 +350,28 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneShardHeaderIntercepto argProcessor := &processor.ArgHdrInterceptorProcessor{ Headers: ficf.dataPool.Headers(), BlockBlackList: ficf.blockBlackList, + Proofs: ficf.dataPool.Proofs(), } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: topic, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), + Topic: topic, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -551,17 +565,23 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneTxInterceptor(topic s return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), - DataFactory: txFactory, - Processor: txProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -586,17 +606,23 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneUnsignedTxInterceptor return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), - DataFactory: txFactory, - Processor: txProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -621,17 +647,23 @@ func (ficf *fullSyncInterceptorsContainerFactory) 
createOneRewardTxInterceptor(t return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), - DataFactory: txFactory, - Processor: txProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -694,16 +726,22 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneMiniBlocksInterceptor return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: topic, - DataFactory: txFactory, - Processor: txBlockBodyProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + DataFactory: txFactory, + Processor: txBlockBodyProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -727,23 +765,30 @@ func (ficf *fullSyncInterceptorsContainerFactory) generateMetachainHeaderInterce argProcessor := &processor.ArgHdrInterceptorProcessor{ Headers: ficf.dataPool.Headers(), BlockBlackList: ficf.blockBlackList, + Proofs: ficf.dataPool.Proofs(), } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { return err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(identifierHdr) + if err != nil { + return err + } + //only one metachain header topic interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifierHdr, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: identifierHdr, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -769,17 +814,23 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneTrieNodesInterceptor( return nil, err } + interceptedDataVerifier, err := 
ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), - DataFactory: trieNodesFactory, - Processor: trieNodesProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), + DataFactory: trieNodesFactory, + Processor: trieNodesProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -811,7 +862,6 @@ func (ficf *fullSyncInterceptorsContainerFactory) generateRewardTxInterceptors() if err != nil { return err } - keys[int(idx)] = identifierScr interceptorSlice[int(idx)] = interceptor } diff --git a/update/sync/coordinator_test.go b/update/sync/coordinator_test.go index b56b2d8f99a..e5f3067dd33 100644 --- a/update/sync/coordinator_test.go +++ b/update/sync/coordinator_test.go @@ -11,18 +11,20 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" dataTransaction "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/syncer" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/update" "github.com/multiversx/mx-chain-go/update/mock" - "github.com/stretchr/testify/require" ) func createHeaderSyncHandler(retErr bool) update.HeaderSyncHandler { @@ -71,7 +73,7 @@ func createPendingMiniBlocksSyncHandler() update.EpochStartPendingMiniBlocksSync mb := &block.MiniBlock{TxHashes: [][]byte{txHash}} args := ArgsNewPendingMiniBlocksSyncer{ Storage: &storageStubs.StorerStub{}, - Cache: &testscommon.CacherStub{ + Cache: &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, val interface{})) {}, PeekCalled: func(key []byte) (value interface{}, ok bool) { return mb, true diff --git a/update/sync/syncMiniBlocks_test.go b/update/sync/syncMiniBlocks_test.go index 9fc8f96db1f..3f1c00a4773 100644 --- a/update/sync/syncMiniBlocks_test.go +++ b/update/sync/syncMiniBlocks_test.go @@ -10,19 +10,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/update" 
"github.com/multiversx/mx-chain-go/update/mock" - "github.com/stretchr/testify/require" ) func createMockArgsPendingMiniBlock() ArgsNewPendingMiniBlocksSyncer { return ArgsNewPendingMiniBlocksSyncer{ Storage: &storageStubs.StorerStub{}, - Cache: &testscommon.CacherStub{ + Cache: &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, val interface{})) {}, }, Marshalizer: &mock.MarshalizerFake{}, @@ -93,7 +95,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInPool(t *testing.T) { mb := &block.MiniBlock{} args := ArgsNewPendingMiniBlocksSyncer{ Storage: &storageStubs.StorerStub{}, - Cache: &testscommon.CacherStub{ + Cache: &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, val interface{})) {}, PeekCalled: func(key []byte) (value interface{}, ok bool) { miniBlockInPool = true @@ -147,7 +149,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInPoolWithRewards(t *testing.T) } args := ArgsNewPendingMiniBlocksSyncer{ Storage: &storageStubs.StorerStub{}, - Cache: &testscommon.CacherStub{ + Cache: &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, val interface{})) {}, PeekCalled: func(key []byte) (value interface{}, ok bool) { miniBlockInPool = true @@ -223,7 +225,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInPoolMissingTimeout(t *testing return nil, localErr }, }, - Cache: &testscommon.CacherStub{ + Cache: &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, val interface{})) {}, PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, false @@ -274,7 +276,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInPoolReceive(t *testing.T) { return nil, localErr }, }, - Cache: testscommon.NewCacherMock(), + Cache: cache.NewCacherMock(), Marshalizer: &mock.MarshalizerFake{}, RequestHandler: &testscommon.RequestHandlerStub{}, } @@ -322,7 +324,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInStorageReceive(t *testing.T) return mbBytes, nil }, }, - Cache: &testscommon.CacherStub{ + Cache: &cache.CacherStub{ RegisterHandlerCalled: func(_ func(_ []byte, _ interface{})) {}, PeekCalled: func(key []byte) (interface{}, bool) { return nil, false @@ -376,7 +378,7 @@ func TestSyncPendingMiniBlocksFromMeta_GetMiniBlocksShouldWork(t *testing.T) { return nil, localErr }, }, - Cache: &testscommon.CacherStub{ + Cache: &cache.CacherStub{ RegisterHandlerCalled: func(_ func(_ []byte, _ interface{})) {}, PeekCalled: func(key []byte) (interface{}, bool) { return nil, false diff --git a/update/sync/syncTransactions_test.go b/update/sync/syncTransactions_test.go index aa087bcbbe2..95ead49717f 100644 --- a/update/sync/syncTransactions_test.go +++ b/update/sync/syncTransactions_test.go @@ -16,17 +16,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" dataTransaction "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/update" "github.com/multiversx/mx-chain-go/update/mock" - "github.com/stretchr/testify/assert" - 
"github.com/stretchr/testify/require" ) func createMockArgs() ArgsNewTransactionsSyncer { @@ -529,7 +531,7 @@ func TestTransactionsSync_GetValidatorInfoFromPoolShouldWork(t *testing.T) { ValidatorsInfoCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheID string) storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, txHash) { return nil, true @@ -690,7 +692,7 @@ func TestTransactionsSync_GetValidatorInfoFromPoolOrStorage(t *testing.T) { ValidatorsInfoCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheID string) storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -852,7 +854,7 @@ func getDataPoolsWithShardValidatorInfoAndTxHash(svi *state.ShardValidatorInfo, ValidatorsInfoCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheID string) storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, txHash) { return svi, true