From 7319423190de968d14398191c12765d48ba5722c Mon Sep 17 00:00:00 2001 From: skosito Date: Mon, 16 Sep 2024 23:25:34 +0200 Subject: [PATCH] refactor execute funcs --- v2/contracts/evm/GatewayEVM.sol | 68 +++++++++--------- v2/contracts/evm/interfaces/IGatewayEVM.sol | 4 +- v2/test/GatewayEVM.t.sol | 40 ++--------- v2/test/GatewayEVMZEVM.t.sol | 2 +- v2/test/utils/GatewayEVMUpgradeTest.sol | 76 +++++++++++++-------- v2/test/utils/ReceiverEVM.sol | 3 +- 6 files changed, 89 insertions(+), 104 deletions(-) diff --git a/v2/contracts/evm/GatewayEVM.sol b/v2/contracts/evm/GatewayEVM.sol index 53559ed2..e1a1a9a3 100644 --- a/v2/contracts/evm/GatewayEVM.sol +++ b/v2/contracts/evm/GatewayEVM.sol @@ -71,34 +71,6 @@ contract GatewayEVM is /// @param newImplementation Address of the new implementation. function _authorizeUpgrade(address newImplementation) internal override onlyRole(DEFAULT_ADMIN_ROLE) { } - /// @dev Internal function to execute an arbitrary call to a destination address. - /// @param destination Address to call. - /// @param data Calldata to pass to the call. - /// @return The result of the call. - function _executeArbitraryCall(address destination, bytes calldata data) internal returns (bytes memory) { - revertIfAuthenticatedCall(data); - (bool success, bytes memory result) = destination.call{ value: msg.value }(data); - if (!success) revert ExecutionFailed(); - - return result; - } - - /// @dev Internal function to execute an authenticated call to a destination address. - /// @param messageContext Message context containing sender and arbitrary call flag. - /// @param destination Address to call. - /// @param data Calldata to pass to the call. - /// @return The result of the call. - function _executeAuthenticatedCall( - MessageContext calldata messageContext, - address destination, - bytes calldata data - ) - internal - returns (bytes memory) - { - return Callable(destination).onCall(messageContext.sender, data); - } - /// @notice Pause contract. function pause() external onlyRole(PAUSER_ROLE) { _pause(); @@ -132,9 +104,9 @@ contract GatewayEVM is emit Reverted(destination, address(0), msg.value, data, revertContext); } - /// @notice Executes a call to a destination address without ERC20 tokens. + /// @notice Executes an authenticated call to a destination address without ERC20 tokens. /// @dev This function can only be called by the TSS address and it is payable. - /// @param messageContext Message context containing sender and arbitrary call flag. + /// @param messageContext Message context containing sender. /// @param destination Address to call. /// @param data Calldata to pass to the call. /// @return The result of the call. @@ -152,18 +124,14 @@ contract GatewayEVM is { if (destination == address(0)) revert ZeroAddress(); bytes memory result; - if (messageContext.isArbitraryCall) { - result = _executeArbitraryCall(destination, data); - } else { - result = _executeAuthenticatedCall(messageContext, destination, data); - } + result = _executeAuthenticatedCall(messageContext, destination, data); emit Executed(destination, msg.value, data); return result; } - /// @notice Executes a call to a destination address without ERC20 tokens. + /// @notice Executes an arbitrary call to a destination address without ERC20 tokens. /// @dev This function can only be called by the TSS address and it is payable. /// @param destination Address to call. /// @param data Calldata to pass to the call. @@ -433,6 +401,34 @@ contract GatewayEVM is } } + /// @dev Private function to execute an arbitrary call to a destination address. + /// @param destination Address to call. + /// @param data Calldata to pass to the call. + /// @return The result of the call. + function _executeArbitraryCall(address destination, bytes calldata data) private returns (bytes memory) { + revertIfAuthenticatedCall(data); + (bool success, bytes memory result) = destination.call{ value: msg.value }(data); + if (!success) revert ExecutionFailed(); + + return result; + } + + /// @dev Private function to execute an authenticated call to a destination address. + /// @param messageContext Message context containing sender and arbitrary call flag. + /// @param destination Address to call. + /// @param data Calldata to pass to the call. + /// @return The result of the call. + function _executeAuthenticatedCall( + MessageContext calldata messageContext, + address destination, + bytes calldata data + ) + private + returns (bytes memory) + { + return Callable(destination).onCall(messageContext, data); + } + // @dev prevent calling onCall function reserved for authenticated calls function revertIfAuthenticatedCall(bytes calldata data) private pure { if (data.length >= 4) { diff --git a/v2/contracts/evm/interfaces/IGatewayEVM.sol b/v2/contracts/evm/interfaces/IGatewayEVM.sol index 241e2951..d1f4dcb9 100644 --- a/v2/contracts/evm/interfaces/IGatewayEVM.sol +++ b/v2/contracts/evm/interfaces/IGatewayEVM.sol @@ -192,13 +192,11 @@ interface IGatewayEVM is IGatewayEVMErrors, IGatewayEVMEvents { /// @notice Message context passed to execute function. /// @param sender Sender from omnichain contract. -/// @param isArbitraryCall Indicates if call should be arbitrary or authenticated. struct MessageContext { address sender; - bool isArbitraryCall; } /// @notice Interface implemented by contracts receiving authenticated calls. interface Callable { - function onCall(address sender, bytes calldata message) external returns (bytes memory); + function onCall(MessageContext calldata context, bytes calldata message) external returns (bytes memory); } diff --git a/v2/test/GatewayEVM.t.sol b/v2/test/GatewayEVM.t.sol index 12640382..77163435 100644 --- a/v2/test/GatewayEVM.t.sol +++ b/v2/test/GatewayEVM.t.sol @@ -137,33 +137,13 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver gateway.execute(address(receiver), data); } - function testForwardCallToReceiveNonPayableUsingArbCall() public { - string[] memory str = new string[](1); - str[0] = "Hello, Foundry!"; - uint256[] memory num = new uint256[](1); - num[0] = 42; - bool flag = true; - - bytes memory data = abi.encodeWithSignature("receiveNonPayable(string[],uint256[],bool)", str, num, flag); - - vm.expectCall(address(receiver), 0, data); - vm.expectEmit(true, true, true, true, address(receiver)); - emit ReceivedNonPayable(address(gateway), str, num, flag); - vm.expectEmit(true, true, true, true, address(gateway)); - emit Executed(address(receiver), 0, data); - vm.prank(tssAddress); - gateway.execute(MessageContext({ sender: address(0x123), isArbitraryCall: true }), address(receiver), data); - } - function testForwardCallToReceiveOnCallUsingAuthCall() public { vm.expectEmit(true, true, true, true, address(receiver)); emit ReceivedOnCall(); vm.expectEmit(true, true, true, true, address(gateway)); emit Executed(address(receiver), 0, bytes("1")); vm.prank(tssAddress); - gateway.execute( - MessageContext({ sender: address(0x123), isArbitraryCall: false }), address(receiver), bytes("1") - ); + gateway.execute(MessageContext({ sender: address(0x123) }), address(receiver), bytes("1")); } function testForwardCallToReceiveNonPayableFailsIfSenderIsNotTSS() public { @@ -189,7 +169,7 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver vm.prank(owner); vm.expectRevert(abi.encodeWithSelector(AccessControlUnauthorizedAccount.selector, owner, TSS_ROLE)); - gateway.execute(MessageContext({ sender: address(0x123), isArbitraryCall: false }), address(receiver), data); + gateway.execute(MessageContext({ sender: address(0x123) }), address(receiver), data); } function testForwardCallToReceivePayable() public { @@ -224,20 +204,8 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver gateway.execute(address(receiver), data); } - function testForwardCallToReceiveNoParamsWithMsgContext() public { - bytes memory data = abi.encodeWithSignature("receiveNoParams()"); - - vm.expectCall(address(receiver), 0, data); - vm.expectEmit(true, true, true, true, address(receiver)); - emit ReceivedNoParams(address(gateway)); - vm.expectEmit(true, true, true, true, address(gateway)); - emit Executed(address(receiver), 0, data); - vm.prank(tssAddress); - gateway.execute(MessageContext({ sender: address(0x123), isArbitraryCall: true }), address(receiver), data); - } - function testForwardCallToReceiveOnCallFails() public { - bytes memory data = abi.encodeWithSignature("onCall(address,bytes)", address(123), bytes("")); + bytes memory data = abi.encodeWithSignature("onCall((address),bytes)", address(123), bytes("")); vm.prank(tssAddress); vm.expectRevert(NotAllowedToCallOnCall.selector); @@ -257,7 +225,7 @@ contract GatewayEVMTest is Test, IGatewayEVMErrors, IGatewayEVMEvents, IReceiver vm.prank(tssAddress); vm.expectRevert(ZeroAddress.selector); - gateway.execute(MessageContext({ sender: address(0x123), isArbitraryCall: true }), address(0), data); + gateway.execute(MessageContext({ sender: address(0x123) }), address(0), data); } function testForwardCallToReceiveNoParamsTogglePause() public { diff --git a/v2/test/GatewayEVMZEVM.t.sol b/v2/test/GatewayEVMZEVM.t.sol index c5380dea..f4af846f 100644 --- a/v2/test/GatewayEVMZEVM.t.sol +++ b/v2/test/GatewayEVMZEVM.t.sol @@ -13,7 +13,7 @@ import "./utils/TestERC20.sol"; import "./utils/SenderZEVM.sol"; -import "./utils/SystemContractMock.sol"; +import { SystemContractMock } from "./utils/SystemContractMock.sol"; import { GatewayZEVM } from "../contracts/zevm/GatewayZEVM.sol"; import { IGatewayZEVM } from "../contracts/zevm/GatewayZEVM.sol"; diff --git a/v2/test/utils/GatewayEVMUpgradeTest.sol b/v2/test/utils/GatewayEVMUpgradeTest.sol index 84150cc3..6db211f6 100644 --- a/v2/test/utils/GatewayEVMUpgradeTest.sol +++ b/v2/test/utils/GatewayEVMUpgradeTest.sol @@ -75,28 +75,6 @@ contract GatewayEVMUpgradeTest is /// @param newImplementation Address of the new implementation. function _authorizeUpgrade(address newImplementation) internal override onlyRole(DEFAULT_ADMIN_ROLE) { } - /// @dev Internal function to execute a call to a destination address. - /// @param destination Address to call. - /// @param data Calldata to pass to the call. - /// @return The result of the call. - function _executeArbitraryCall(address destination, bytes calldata data) internal returns (bytes memory) { - (bool success, bytes memory result) = destination.call{ value: msg.value }(data); - if (!success) revert ExecutionFailed(); - - return result; - } - - function _executeAuthenticatedCall( - MessageContext calldata messageContext, - address destination, - bytes calldata data - ) - internal - returns (bytes memory) - { - return Callable(destination).onCall(messageContext.sender, data); - } - /// @notice Pause contract. function pause() external onlyRole(PAUSER_ROLE) { _pause(); @@ -132,6 +110,7 @@ contract GatewayEVMUpgradeTest is /// @notice Executes a call to a destination address without ERC20 tokens. /// @dev This function can only be called by the TSS address and it is payable. + /// @param messageContext Message context containing sender and arbitrary call flag. /// @param destination Address to call. /// @param data Calldata to pass to the call. /// @return The result of the call. @@ -149,17 +128,18 @@ contract GatewayEVMUpgradeTest is { if (destination == address(0)) revert ZeroAddress(); bytes memory result; - if (messageContext.isArbitraryCall) { - result = _executeArbitraryCall(destination, data); - } else { - result = _executeAuthenticatedCall(messageContext, destination, data); - } + result = _executeAuthenticatedCall(messageContext, destination, data); emit Executed(destination, msg.value, data); return result; } + /// @notice Executes a call to a destination address without ERC20 tokens. + /// @dev This function can only be called by the TSS address and it is payable. + /// @param destination Address to call. + /// @param data Calldata to pass to the call. + /// @return The result of the call. function execute( address destination, bytes calldata data @@ -424,4 +404,46 @@ contract GatewayEVMUpgradeTest is IERC20(token).safeTransfer(custody, amount); } } + + /// @dev Internal function to execute an arbitrary call to a destination address. + /// @param destination Address to call. + /// @param data Calldata to pass to the call. + /// @return The result of the call. + function _executeArbitraryCall(address destination, bytes calldata data) internal returns (bytes memory) { + revertIfAuthenticatedCall(data); + (bool success, bytes memory result) = destination.call{ value: msg.value }(data); + if (!success) revert ExecutionFailed(); + + return result; + } + + /// @dev Internal function to execute an authenticated call to a destination address. + /// @param messageContext Message context containing sender and arbitrary call flag. + /// @param destination Address to call. + /// @param data Calldata to pass to the call. + /// @return The result of the call. + function _executeAuthenticatedCall( + MessageContext calldata messageContext, + address destination, + bytes calldata data + ) + internal + returns (bytes memory) + { + return Callable(destination).onCall(messageContext, data); + } + + // @dev prevent calling onCall function reserved for authenticated calls + function revertIfAuthenticatedCall(bytes calldata data) internal pure { + if (data.length >= 4) { + bytes4 functionSelector; + assembly { + functionSelector := calldataload(data.offset) + } + + if (functionSelector == Callable.onCall.selector) { + revert NotAllowedToCallOnCall(); + } + } + } } diff --git a/v2/test/utils/ReceiverEVM.sol b/v2/test/utils/ReceiverEVM.sol index 61557527..e8063551 100644 --- a/v2/test/utils/ReceiverEVM.sol +++ b/v2/test/utils/ReceiverEVM.sol @@ -2,6 +2,7 @@ pragma solidity 0.8.26; import { RevertContext } from "../../contracts/Revert.sol"; +import { MessageContext } from "../../contracts/evm/interfaces/IGatewayEVM.sol"; import "./IReceiverEVM.sol"; import "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; @@ -72,7 +73,7 @@ contract ReceiverEVM is IReceiverEVMEvents, ReentrancyGuard { emit ReceivedRevert(msg.sender, revertContext); } - function onCall(address sender, bytes calldata message) external returns (bytes memory) { + function onCall(MessageContext calldata messageContext, bytes calldata message) external returns (bytes memory) { emit ReceivedOnCall(); }