From 3adcfabb413329359f730c54d192e1d2328fe33c Mon Sep 17 00:00:00 2001 From: easyfold <137396765+easyfold@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:06:07 +0800 Subject: [PATCH] core/systemcontracts: use vm.StateDB in UpgradeBuildInSystemContract (#2578) --- core/systemcontracts/upgrade.go | 12 ++++++----- core/systemcontracts/upgrade_test.go | 32 ++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/core/systemcontracts/upgrade.go b/core/systemcontracts/upgrade.go index 714aa16115..83a2aa491a 100644 --- a/core/systemcontracts/upgrade.go +++ b/core/systemcontracts/upgrade.go @@ -4,10 +4,10 @@ import ( "encoding/hex" "fmt" "math/big" + "reflect" "strings" "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" "github.com/ethereum/go-ethereum/core/systemcontracts/bohr" "github.com/ethereum/go-ethereum/core/systemcontracts/bruno" "github.com/ethereum/go-ethereum/core/systemcontracts/euler" @@ -23,6 +23,7 @@ import ( "github.com/ethereum/go-ethereum/core/systemcontracts/planck" "github.com/ethereum/go-ethereum/core/systemcontracts/plato" "github.com/ethereum/go-ethereum/core/systemcontracts/ramanujan" + "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" ) @@ -40,7 +41,7 @@ type Upgrade struct { Configs []*UpgradeConfig } -type upgradeHook func(blockNumber *big.Int, contractAddr common.Address, statedb *state.StateDB) error +type upgradeHook func(blockNumber *big.Int, contractAddr common.Address, statedb vm.StateDB) error const ( mainNet = "Mainnet" @@ -789,10 +790,11 @@ func init() { } } -func UpgradeBuildInSystemContract(config *params.ChainConfig, blockNumber *big.Int, lastBlockTime uint64, blockTime uint64, statedb *state.StateDB) { - if config == nil || blockNumber == nil || statedb == nil { +func UpgradeBuildInSystemContract(config *params.ChainConfig, blockNumber *big.Int, lastBlockTime uint64, blockTime uint64, statedb vm.StateDB) { + if config == nil || blockNumber == nil || statedb == nil || reflect.ValueOf(statedb).IsNil() { return } + var network string switch GenesisHash { /* Add mainnet genesis hash */ @@ -876,7 +878,7 @@ func UpgradeBuildInSystemContract(config *params.ChainConfig, blockNumber *big.I */ } -func applySystemContractUpgrade(upgrade *Upgrade, blockNumber *big.Int, statedb *state.StateDB, logger log.Logger) { +func applySystemContractUpgrade(upgrade *Upgrade, blockNumber *big.Int, statedb vm.StateDB, logger log.Logger) { if upgrade == nil { logger.Info("Empty upgrade config", "height", blockNumber.String()) return diff --git a/core/systemcontracts/upgrade_test.go b/core/systemcontracts/upgrade_test.go index 1d8270fd94..3f88d7687b 100644 --- a/core/systemcontracts/upgrade_test.go +++ b/core/systemcontracts/upgrade_test.go @@ -2,9 +2,13 @@ package systemcontracts import ( "crypto/sha256" + "math/big" "testing" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/params" "github.com/stretchr/testify/require" ) @@ -39,3 +43,31 @@ func TestAllCodesHash(t *testing.T) { allCodeHash := sha256.Sum256(allCodes) require.Equal(t, allCodeHash[:], common.Hex2Bytes("833cc0fc87c46ad8a223e44ccfdc16a51a7e7383525136441bd0c730f06023df")) } + +func TestUpgradeBuildInSystemContractNilInterface(t *testing.T) { + var ( + config = params.BSCChainConfig + blockNumber = big.NewInt(37959559) + lastBlockTime uint64 = 1713419337 + blockTime uint64 = 1713419340 + statedb vm.StateDB + ) + + GenesisHash = params.BSCGenesisHash + + UpgradeBuildInSystemContract(config, blockNumber, lastBlockTime, blockTime, statedb) +} + +func TestUpgradeBuildInSystemContractNilValue(t *testing.T) { + var ( + config = params.BSCChainConfig + blockNumber = big.NewInt(37959559) + lastBlockTime uint64 = 1713419337 + blockTime uint64 = 1713419340 + statedb vm.StateDB = (*state.StateDB)(nil) + ) + + GenesisHash = params.BSCGenesisHash + + UpgradeBuildInSystemContract(config, blockNumber, lastBlockTime, blockTime, statedb) +}