diff --git a/precompiles/types/address.go b/precompiles/types/address.go index d587b550fa..5e6654cc5b 100644 --- a/precompiles/types/address.go +++ b/precompiles/types/address.go @@ -1,6 +1,8 @@ package types import ( + "errors" + sdk "github.com/cosmos/cosmos-sdk/types" bank "github.com/cosmos/cosmos-sdk/x/bank/keeper" "github.com/ethereum/go-ethereum/common" @@ -12,6 +14,10 @@ import ( // If contract.CallerAddress != evm.Origin is true, it means the call was made through a contract, // on which case there is a need to set the caller to the evm.Origin. func GetEVMCallerAddress(evm *vm.EVM, contract *vm.Contract) (common.Address, error) { + if evm == nil || contract == nil { + return common.Address{}, errors.New("invalid input: evm or contract is nil") + } + caller := contract.CallerAddress if contract.CallerAddress != evm.Origin { caller = evm.Origin diff --git a/precompiles/types/address_test.go b/precompiles/types/address_test.go index c329fcd6ab..801df4f8d4 100644 --- a/precompiles/types/address_test.go +++ b/precompiles/types/address_test.go @@ -11,9 +11,41 @@ import ( ) func Test_GetEVMCallerAddress(t *testing.T) { + t.Run("should raise error when evm is nil", func(t *testing.T) { + _, mockVMContract := setupMockEVMAndContract(common.Address{}) + caller, err := GetEVMCallerAddress(nil, &mockVMContract) + require.Error(t, err) + require.Equal(t, common.Address{}, caller, "address should be zeroed") + }) + + t.Run("should raise error when contract is nil", func(t *testing.T) { + mockEVM, _ := setupMockEVMAndContract(common.Address{}) + caller, err := GetEVMCallerAddress(&mockEVM, nil) + require.Error(t, err) + require.Equal(t, common.Address{}, caller, "address should be zeroed") + }) + + // When contract.CallerAddress == evm.Origin, caller is set to contract.CallerAddress. + t.Run("when caller address equals origin", func(t *testing.T) { + mockEVM, mockVMContract := setupMockEVMAndContract(common.Address{}) + caller, err := GetEVMCallerAddress(&mockEVM, &mockVMContract) + require.NoError(t, err) + require.Equal(t, common.Address{}, caller, "address should be the same") + }) + + // When contract.CallerAddress != evm.Origin, caller should be set to evm.Origin. + t.Run("when caller address equals origin", func(t *testing.T) { + mockEVM, mockVMContract := setupMockEVMAndContract(sample.EthAddress()) + caller, err := GetEVMCallerAddress(&mockEVM, &mockVMContract) + require.NoError(t, err) + require.Equal(t, mockEVM.Origin, caller, "address should be evm.Origin") + }) +} + +func setupMockEVMAndContract(address common.Address) (vm.EVM, vm.Contract) { mockEVM := vm.EVM{ TxContext: vm.TxContext{ - Origin: common.Address{}, + Origin: address, }, } @@ -24,16 +56,7 @@ func Test_GetEVMCallerAddress(t *testing.T) { 0, ) - // When contract.CallerAddress == evm.Origin, caller is set to contract.CallerAddress. - caller, err := GetEVMCallerAddress(&mockEVM, mockVMContract) - require.NoError(t, err) - require.Equal(t, common.Address{}, caller, "address shouldn be the same") - - // When contract.CallerAddress != evm.Origin, caller should be set to evm.Origin. - mockEVM.Origin = sample.EthAddress() - caller, err = GetEVMCallerAddress(&mockEVM, mockVMContract) - require.NoError(t, err) - require.Equal(t, mockEVM.Origin, caller, "address should be evm.Origin") + return mockEVM, *mockVMContract } type contractRef struct { diff --git a/precompiles/types/coin.go b/precompiles/types/coin.go index 16037e0a94..0c95eb108b 100644 --- a/precompiles/types/coin.go +++ b/precompiles/types/coin.go @@ -11,12 +11,18 @@ import ( const ZRC20DenomPrefix = "zrc20/" // ZRC20ToCosmosDenom returns the cosmos coin address for a given ZRC20 address. -// This is converted to "zevm/{ZRC20Address}". +// This is converted to "zrc20/{ZRC20Address}". func ZRC20ToCosmosDenom(ZRC20Address common.Address) string { return ZRC20DenomPrefix + ZRC20Address.String() } func CreateCoinSet(tokenDenom string, amount *big.Int) (sdk.Coins, error) { + defer func() { + if r := recover(); r != nil { + return + } + }() + coin := sdk.NewCoin(tokenDenom, math.NewIntFromBigInt(amount)) if !coin.IsValid() { return nil, &ErrInvalidCoin{