diff --git a/README.md b/README.md index e65e6bd..4193b30 100644 --- a/README.md +++ b/README.md @@ -36,9 +36,9 @@ For example, if you are using a solc version newer than `0.8.19` and are plannin It is strongly recommended that you run the forge test suite of this SDK with your own compiler version to catch potential errors that stem from differences in compiler versions early. Yes, strictly speaking the Solidity version pragma should prevent these issues, but better to be safe than sorry, especially given that some components make extensive use of inline assembly. -**IERC20 Remapping** +**IERC20 and SafeERC20 Remapping** -This SDK comes with its own IERC20 interface. Given that projects tend to combine different SDKs, there's often this annoying issue of clashes of IERC20 interfaces, even though the are effectively the same. We handle this issue by importing `IERC20/IERC20.sol` which allows remapping the `IERC20/` prefix to whatever directory contains `IERC20.sol` in your project, thus providing an override mechanism that should allow dealing with this problem seamlessly until forge allows remapping of individual files. +This SDK comes with its own IERC20 interface and SafeERC20 implementation. Given that projects tend to combine different SDKs, there's often this annoying issue of clashes of IERC20 interfaces, even though they are effectively the same. We handle this issue by importing `IERC20/IERC20.sol` which allows remapping the `IERC20/` prefix to whatever directory contains `IERC20.sol` in your project, thus providing an override mechanism that should allow dealing with this problem seamlessly until forge allows remapping of individual files. The same approach is used for SafeERC20. ## Components diff --git a/foundry.toml b/foundry.toml index 2ab3aca..f553a94 100644 --- a/foundry.toml +++ b/foundry.toml @@ -16,6 +16,7 @@ remappings = [ "forge-std/=lib/forge-std/src/", "wormhole-sdk/=src/", "IERC20/=src/interfaces/token/", + "SafeERC20/=src/libraries/", ] verbosity = 3 diff --git a/src/Utils.sol b/src/Utils.sol index 63949f8..4b39b85 100644 --- a/src/Utils.sol +++ b/src/Utils.sol @@ -2,27 +2,6 @@ // SPDX-License-Identifier: Apache 2 pragma solidity ^0.8.19; -error NotAnEvmAddress(bytes32); - -function toUniversalAddress(address addr) pure returns (bytes32 universalAddr) { - universalAddr = bytes32(uint256(uint160(addr))); -} - -function fromUniversalAddress(bytes32 universalAddr) pure returns (address addr) { - if (bytes12(universalAddr) != 0) - revert NotAnEvmAddress(universalAddr); - - assembly ("memory-safe") { - addr := universalAddr - } -} - -/** - * Reverts with a given buffer data. - * Meant to be used to easily bubble up errors from low level calls when they fail. - */ -function reRevert(bytes memory err) pure { - assembly ("memory-safe") { - revert(add(err, 32), mload(err)) - } -} +import {tokenOrNativeTransfer} from "wormhole-sdk/utils/Transfer.sol"; +import {reRevert} from "wormhole-sdk/utils/Revert.sol"; +import {toUniversalAddress, fromUniversalAddress} from "wormhole-sdk/utils/UniversalAddress.sol"; diff --git a/src/components/dispatcher/AccessControl.sol b/src/components/dispatcher/AccessControl.sol new file mode 100644 index 0000000..2d617c5 --- /dev/null +++ b/src/components/dispatcher/AccessControl.sol @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +import {BytesParsing} from "wormhole-sdk/libraries/BytesParsing.sol"; +import { + ACCESS_CONTROL_ID, + ACCESS_CONTROL_QUERIES_ID, + OWNER_ID, + PENDING_OWNER_ID, + IS_ADMIN_ID, + ADMINS_ID, + REVOKE_ADMIN_ID, + ADD_ADMIN_ID, + PROPOSE_OWNERSHIP_TRANSFER_ID, + ACQUIRE_OWNERSHIP_ID, + RELINQUISH_OWNERSHIP_ID +} from "wormhole-sdk/components/dispatcher/Ids.sol"; + +//rationale for different roles (owner, admin): +// * owner should be a mulit-sig / ultra cold wallet that is only activated in exceptional +// circumstances. +// * admin should also be either a cold wallet or Admin contract. In either case, +// the expectation is that multiple, slightly less trustworthy parties than the owner will +// have access to it, lowering trust assumptions and increasing attack surface. Admins +// perform rare but not exceptional operations. + +struct AccessControlState { + address owner; //puts owner address in eip1967 admin slot + address pendingOwner; + address[] admins; + mapping(address => uint256) isAdmin; +} + +// we use the designated eip1967 admin storage slot: +// keccak256("eip1967.proxy.admin") - 1 +bytes32 constant ACCESS_CONTROL_STORAGE_SLOT = + 0xb53127684a568b3173ae13b9f8a6016e243e63b6e8ee1178d6a717850b5d6103; + +function accessControlState() pure returns (AccessControlState storage state) { + assembly ("memory-safe") { state.slot := ACCESS_CONTROL_STORAGE_SLOT } +} + +error NotAuthorized(); +error InvalidAccessControlCommand(uint8 command); +error InvalidAccessControlQuery(uint8 query); + +event OwnerUpdated(address oldAddress, address newAddress, uint256 timestamp); +event AdminsUpdated(address addr, bool isAdmin, uint256 timestamp); + +enum Role { + None, + Owner, + Admin +} + +function failAuthIf(bool condition) pure { + if (condition) + revert NotAuthorized(); +} + +function senderAtLeastAdmin() view returns (Role) { + Role role = senderRole(); + failAuthIf(role == Role.None); + + return role; +} + +function senderRole() view returns (Role) { + AccessControlState storage state = accessControlState(); + if (msg.sender == state.owner) //check highest privilege level first + return Role.Owner; + + return state.isAdmin[msg.sender] != 0 ? Role.Admin : Role.None; +} + +abstract contract AccessControl { + using BytesParsing for bytes; + + // ---- construction ---- + + function _accessControlConstruction( + address owner, + address[] memory admins + ) internal { + accessControlState().owner = owner; + for (uint i = 0; i < admins.length; ++i) + _updateAdmins(admins[i], true); + } + + // ---- external ----- + + function transferOwnership(address newOwner) external { + AccessControlState storage state = accessControlState(); + failAuthIf(msg.sender != state.owner); + + state.pendingOwner = newOwner; + } + + function cancelOwnershipTransfer() external { + AccessControlState storage state = accessControlState(); + failAuthIf(msg.sender != state.owner); + + state.pendingOwner = address(0); + } + + function receiveOwnership() external { + _acquireOwnership(); + } + + // ---- internals ---- + + function dispatchExecAccessControl( + bytes calldata data, + uint offset, + uint8 command + ) internal returns (bool, uint) { + if (command == ACCESS_CONTROL_ID) + offset = _batchAccessControlCommands(data, offset); + else if (command == ACQUIRE_OWNERSHIP_ID) + _acquireOwnership(); + else + return (false, offset); + + return (true, offset); + } + + function dispatchQueryAccessControl( + bytes calldata data, + uint offset, + uint8 query + ) view internal returns (bool, bytes memory, uint) { + bytes memory result; + if (query == ACCESS_CONTROL_QUERIES_ID) + (result, offset) = _batchAccessControlQueries(data, offset); + else + return (false, new bytes(0), offset); + + return (true, result, offset); + } + + function _batchAccessControlCommands( + bytes calldata commands, + uint offset + ) internal returns (uint) { + AccessControlState storage state = accessControlState(); + bool isOwner = senderAtLeastAdmin() == Role.Owner; + + uint remainingCommands; + (remainingCommands, offset) = commands.asUint8CdUnchecked(offset); + for (uint i = 0; i < remainingCommands; ++i) { + uint8 command; + (command, offset) = commands.asUint8CdUnchecked(offset); + if (command == REVOKE_ADMIN_ID) { + address admin; + (admin, offset) = commands.asAddressCdUnchecked(offset); + _updateAdmins(admin, false); + } + else { + if (!isOwner) + revert NotAuthorized(); + + if (command == ADD_ADMIN_ID) { + address newAdmin; + (newAdmin, offset) = commands.asAddressCdUnchecked(offset); + + _updateAdmins(newAdmin, true); + } + else if (command == PROPOSE_OWNERSHIP_TRANSFER_ID) { + address newOwner; + (newOwner, offset) = commands.asAddressCdUnchecked(offset); + + state.pendingOwner = newOwner; + } + else if (command == RELINQUISH_OWNERSHIP_ID) { + _updateOwner(address(0)); + + //ownership relinquishment must be the last command in the batch + commands.checkLengthCd(offset); + } + else + revert InvalidAccessControlCommand(command); + } + } + return offset; + } + + function _batchAccessControlQueries( + bytes calldata queries, + uint offset + ) internal view returns (bytes memory, uint) { + AccessControlState storage state = accessControlState(); + bytes memory ret; + + uint remainingQueries; + (remainingQueries, offset) = queries.asUint8CdUnchecked(offset); + for (uint i = 0; i < remainingQueries; ++i) { + uint8 query; + (query, offset) = queries.asUint8CdUnchecked(offset); + + if (query == IS_ADMIN_ID) { + address admin; + (admin, offset) = queries.asAddressCdUnchecked(offset); + ret = abi.encodePacked(ret, state.isAdmin[admin] != 0); + } + else if (query == ADMINS_ID) { + ret = abi.encodePacked(ret, uint8(state.admins.length)); + for (uint j = 0; j < state.admins.length; ++j) + ret = abi.encodePacked(ret, state.admins[j]); + } + else { + address addr; + if (query == OWNER_ID) + addr = state.owner; + else if (query == PENDING_OWNER_ID) + addr = state.pendingOwner; + else + revert InvalidAccessControlQuery(query); + + ret = abi.encodePacked(ret, addr); + } + } + + return (ret, offset); + } + + function _acquireOwnership() internal { + AccessControlState storage state = accessControlState(); + if (msg.sender !=state.pendingOwner) + revert NotAuthorized(); + + state.pendingOwner = address(0); + _updateOwner(msg.sender); + } + + // ---- private ---- + + function _updateOwner(address newOwner) private { + address oldAddress; + accessControlState().owner = newOwner; + emit OwnerUpdated(oldAddress, newOwner, block.timestamp); + } + + function _updateAdmins(address admin, bool authorization) private { unchecked { + AccessControlState storage state = accessControlState(); + if ((state.isAdmin[admin] != 0) == authorization) + return; + + if (authorization) { + state.admins.push(admin); + state.isAdmin[admin] = state.admins.length; + } + else { + uint256 rawIndex = state.isAdmin[admin]; + if (rawIndex != state.admins.length) + state.admins[rawIndex - 1] = state.admins[state.admins.length - 1]; + + state.isAdmin[admin] = 0; + state.admins.pop(); + } + + emit AdminsUpdated(admin, authorization, block.timestamp); + }} +} diff --git a/src/components/dispatcher/Ids.sol b/src/components/dispatcher/Ids.sol new file mode 100644 index 0000000..2cfe31c --- /dev/null +++ b/src/components/dispatcher/Ids.sol @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +// ----------- Dispatcher Ids ----------- + +// Execute commands + +uint8 constant ACCESS_CONTROL_ID = 0x60; +uint8 constant ACQUIRE_OWNERSHIP_ID = 0x61; +uint8 constant UPGRADE_CONTRACT_ID = 0x62; +uint8 constant SWEEP_TOKENS_ID = 0x63; + +// Query commands + +uint8 constant ACCESS_CONTROL_QUERIES_ID = 0xe0; +uint8 constant IMPLEMENTATION_ID = 0xe1; + +// ----------- Access Control Ids ----------- + +// Execute commands + +//admin: +uint8 constant REVOKE_ADMIN_ID = 0x00; + +//owner only: +uint8 constant PROPOSE_OWNERSHIP_TRANSFER_ID = 0x10; +uint8 constant RELINQUISH_OWNERSHIP_ID = 0x11; +uint8 constant ADD_ADMIN_ID = 0x12; + +// Query commands + +uint8 constant OWNER_ID = 0x80; +uint8 constant PENDING_OWNER_ID = 0x81; +uint8 constant IS_ADMIN_ID = 0x82; +uint8 constant ADMINS_ID = 0x83; \ No newline at end of file diff --git a/src/components/dispatcher/SweepTokens.sol b/src/components/dispatcher/SweepTokens.sol new file mode 100644 index 0000000..780cae5 --- /dev/null +++ b/src/components/dispatcher/SweepTokens.sol @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +import {BytesParsing} from "wormhole-sdk/libraries/BytesParsing.sol"; +import {tokenOrNativeTransfer} from "wormhole-sdk/utils/Transfer.sol"; +import {senderAtLeastAdmin} from "wormhole-sdk/components/dispatcher/AccessControl.sol"; +import {SWEEP_TOKENS_ID} from "wormhole-sdk/components/dispatcher/Ids.sol"; + +abstract contract SweepTokens { + using BytesParsing for bytes; + + function dispatchExecSweepTokens( + bytes calldata data, + uint offset, + uint8 command + ) internal returns (bool, uint) { + return command == SWEEP_TOKENS_ID + ? (true, _sweepTokens(data, offset)) + : (false, offset); + } + + function _sweepTokens( + bytes calldata commands, + uint offset + ) internal returns (uint) { + senderAtLeastAdmin(); + + address token; + uint256 amount; + (token, offset) = commands.asAddressCdUnchecked(offset); + (amount, offset) = commands.asUint256CdUnchecked(offset); + + tokenOrNativeTransfer(token, msg.sender, amount); + return offset; + } +} \ No newline at end of file diff --git a/src/components/dispatcher/Upgrade.sol b/src/components/dispatcher/Upgrade.sol new file mode 100644 index 0000000..bba180c --- /dev/null +++ b/src/components/dispatcher/Upgrade.sol @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +import {BytesParsing} from "wormhole-sdk/libraries/BytesParsing.sol"; +import {ProxyBase} from "wormhole-sdk/proxy/ProxyBase.sol"; +import {Role, senderRole, failAuthIf} from "wormhole-sdk/components/dispatcher/AccessControl.sol"; +import {UPGRADE_CONTRACT_ID, IMPLEMENTATION_ID} from "wormhole-sdk/components/dispatcher/Ids.sol"; + +error InvalidGovernanceCommand(uint8 command); +error InvalidGovernanceQuery(uint8 query); + +abstract contract Upgrade is ProxyBase { + using BytesParsing for bytes; + + function dispatchExecUpgrade( + bytes calldata data, + uint offset, + uint8 command + ) internal returns (bool, uint) { + return (command == UPGRADE_CONTRACT_ID) + ? (true, _upgradeContract(data, offset)) + : (false, offset); + } + + function dispatchQueryUpgrade( + bytes calldata, + uint offset, + uint8 query + ) view internal returns (bool, bytes memory, uint) { + return query == IMPLEMENTATION_ID + ? (true, abi.encodePacked(_getImplementation()), offset) + : (false, new bytes(0), offset); + } + + function upgrade(address implementation, bytes calldata data) external { + failAuthIf(senderRole() != Role.Owner); + + _upgradeTo(implementation, data); + } + + function _upgradeContract( + bytes calldata commands, + uint offset + ) internal returns (uint) { + failAuthIf(senderRole() != Role.Owner); + + address newImplementation; + (newImplementation, offset) = commands.asAddressCdUnchecked(offset); + //contract upgrades must be the last command in the batch + commands.checkLengthCd(offset); + + _upgradeTo(newImplementation, new bytes(0)); + + return offset; + } +} diff --git a/src/constants/Common.sol b/src/constants/Common.sol index 27f19d9..f3ac5f5 100644 --- a/src/constants/Common.sol +++ b/src/constants/Common.sol @@ -1,11 +1,17 @@ // SPDX-License-Identifier: Apache 2 pragma solidity ^0.8.4; -uint256 constant FREE_MEMORY_PTR = 0x40; +// ┌──────────────────────────────────────────────────────────────────────────────┐ +// │ NOTE: We can't define e.g. WORD_SIZE_MINUS_ONE via WORD_SIZE - 1 because │ +// │ of solc restrictions on what constants can be used in inline assembly. │ +// └──────────────────────────────────────────────────────────────────────────────┘ + uint256 constant WORD_SIZE = 32; -//we can't define _WORD_SIZE_MINUS_ONE via _WORD_SIZE - 1 because of solc restrictions -// what constants can be used in inline assembly uint256 constant WORD_SIZE_MINUS_ONE = 31; //=0x1f=0b00011111 - //see section "prefer `< MAX + 1` over `<= MAX` for const comparison" in docs/Optimization.md uint256 constant WORD_SIZE_PLUS_ONE = 33; + +uint256 constant SCRATCH_SPACE_PTR = 0x00; +uint256 constant SCRATCH_SPACE_SIZE = 64; + +uint256 constant FREE_MEMORY_PTR = 0x40; \ No newline at end of file diff --git a/src/libraries/BytesParsing.sol b/src/libraries/BytesParsing.sol index 86a0aac..9f886af 100644 --- a/src/libraries/BytesParsing.sol +++ b/src/libraries/BytesParsing.sol @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache 2 pragma solidity ^0.8.4; -import "../constants/Common.sol"; +import "wormhole-sdk/constants/Common.sol"; //This file appears comically large, but all unused functions are removed by the compiler. library BytesParsing { diff --git a/src/libraries/SafeERC20.sol b/src/libraries/SafeERC20.sol new file mode 100644 index 0000000..dfa7e90 --- /dev/null +++ b/src/libraries/SafeERC20.sol @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity ^0.8.4; + +import {IERC20} from "IERC20/IERC20.sol"; +import {WORD_SIZE, SCRATCH_SPACE_PTR} from "wormhole-sdk/constants/Common.sol"; + +//Like OpenZeppelin's SafeERC20.sol, but slimmed down and more gas efficient. +// +//The main difference to OZ's implementation (besides the missing functions) is that we skip the +// EXTCODESIZE check that OZ does upon successful calls to ensure that an actual contract was +// called. The rationale for omitting this check is that ultimately the contract using the token +// has to verify that it "makes sense" for its use case regardless. Otherwise, a random token, or +// even just a contract that always returns true, could be passed, which makes this check +// superfluous in the final analysis. +// +//We also save on code size by not duplicating the assembly code in two separate functions. +// Otoh, we simply swallow revert reasons of failing token operations instead of bubbling them up. +// This is less clean and makes debugging harder, but is likely still a worthwhile trade-off +// given the cost in gas and code size. +library SafeERC20 { + error SafeERC20FailedOperation(address token); + + function safeTransfer(IERC20 token, address to, uint256 value) internal { + _revertOnFailure(token, abi.encodeCall(token.transfer, (to, value))); + } + + function safeTransferFrom(IERC20 token, address from, address to, uint256 value) internal { + _revertOnFailure(token, abi.encodeCall(token.transferFrom, (from, to, value))); + } + + function forceApprove(IERC20 token, address spender, uint256 value) internal { + bytes memory approveCall = abi.encodeCall(token.approve, (spender, value)); + + if (!_callWithOptionalReturnCheck(token, approveCall)) { + _revertOnFailure(token, abi.encodeCall(token.approve, (spender, 0))); + _revertOnFailure(token, approveCall); + } + } + + function _callWithOptionalReturnCheck( + IERC20 token, + bytes memory encodedCall + ) private returns (bool success) { + /// @solidity memory-safe-assembly + assembly { + mstore(SCRATCH_SPACE_PTR, 0) + success := call( //see https://www.evm.codes/?fork=cancun#f1 + gas(), //gas + token, //callee + 0, //value + add(encodedCall, WORD_SIZE), //input ptr + mload(encodedCall), //input size + SCRATCH_SPACE_PTR, //output ptr + WORD_SIZE //output size + ) + //calls to addresses without code are always successful + if success { + success := or(iszero(returndatasize()), mload(SCRATCH_SPACE_PTR)) + } + } + } + + function _revertOnFailure(IERC20 token, bytes memory encodedCall) private { + if (!_callWithOptionalReturnCheck(token, encodedCall)) + revert SafeERC20FailedOperation(address(token)); + } +} diff --git a/src/testing/UpgradeTester.sol b/src/testing/UpgradeTester.sol new file mode 100644 index 0000000..9fd318f --- /dev/null +++ b/src/testing/UpgradeTester.sol @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.24; + +import {ProxyBase} from "wormhole-sdk/proxy/ProxyBase.sol"; + +contract UpgradeTester is ProxyBase { + event Constructed(bytes data); + event Upgraded(bytes data); + + function upgradeTo(address newImplementation, bytes calldata data) external { + _upgradeTo(newImplementation, data); + } + + function getImplementation() external view returns (address) { + return _getImplementation(); + } + + function _proxyConstructor(bytes calldata data) internal override { + emit Constructed(data); + } + + function _contractUpgrade(bytes calldata data) internal override { + emit Upgraded(data); + } +} diff --git a/src/utils/Revert.sol b/src/utils/Revert.sol new file mode 100644 index 0000000..c5c2928 --- /dev/null +++ b/src/utils/Revert.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity ^0.8.19; + +import {WORD_SIZE} from "wormhole-sdk/constants/Common.sol"; + +//bubble up errors from low level calls +function reRevert(bytes memory err) pure { + assembly ("memory-safe") { + revert(add(err, WORD_SIZE), mload(err)) + } +} diff --git a/src/utils/Transfer.sol b/src/utils/Transfer.sol new file mode 100644 index 0000000..c1774b2 --- /dev/null +++ b/src/utils/Transfer.sol @@ -0,0 +1,23 @@ + +// SPDX-License-Identifier: Apache 2 +pragma solidity ^0.8.19; + +import {IERC20} from "IERC20/IERC20.sol"; +import {SafeERC20} from "SafeERC20/SafeERC20.sol"; + +error PaymentFailure(address target); + +//Note: Always forwards all gas, so consider gas griefing attack opportunities by the recipient. +//Note: Don't use this method if you need events for 0 amount transfers. +function tokenOrNativeTransfer(address tokenOrZeroForNative, address to, uint256 amount) { + if (amount == 0) + return; + + if (tokenOrZeroForNative == address(0)) { + (bool success, ) = to.call{value: amount}(new bytes(0)); + if (!success) + revert PaymentFailure(to); + } + else + SafeERC20.safeTransfer(IERC20(tokenOrZeroForNative), to, amount); +} diff --git a/src/utils/UniversalAddress.sol b/src/utils/UniversalAddress.sol new file mode 100644 index 0000000..eb607f1 --- /dev/null +++ b/src/utils/UniversalAddress.sol @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: Apache 2 +pragma solidity ^0.8.19; + +error NotAnEvmAddress(bytes32); + +function toUniversalAddress(address addr) pure returns (bytes32 universalAddr) { + universalAddr = bytes32(uint256(uint160(addr))); +} + +function fromUniversalAddress(bytes32 universalAddr) pure returns (address addr) { + if (bytes12(universalAddr) != 0) + revert NotAnEvmAddress(universalAddr); + + assembly ("memory-safe") { + addr := universalAddr + } +} diff --git a/test/components/dispatcher/AccessControl.t.sol b/test/components/dispatcher/AccessControl.t.sol new file mode 100644 index 0000000..1098a5a --- /dev/null +++ b/test/components/dispatcher/AccessControl.t.sol @@ -0,0 +1,500 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +import {BytesParsing} from "wormhole-sdk/libraries/BytesParsing.sol"; +import {AdminsUpdated, NotAuthorized} from "wormhole-sdk/components/dispatcher/AccessControl.sol"; +import { + ACCESS_CONTROL_ID, + ACCESS_CONTROL_QUERIES_ID, + OWNER_ID, + PENDING_OWNER_ID, + ACQUIRE_OWNERSHIP_ID, + IS_ADMIN_ID, + ADMINS_ID, + REVOKE_ADMIN_ID, + ADD_ADMIN_ID, + PROPOSE_OWNERSHIP_TRANSFER_ID, + RELINQUISH_OWNERSHIP_ID +} from "wormhole-sdk/components/dispatcher/Ids.sol"; +import {DispatcherTestBase} from "./utils/DispatcherTestBase.sol"; + +contract AcessControlTest is DispatcherTestBase { + using BytesParsing for bytes; + + function testCompleteOwnershipTransfer(address newOwner) public { + vm.assume(newOwner != address(this)); + uint8 commandCount = 1; + + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + PROPOSE_OWNERSHIP_TRANSFER_ID, + newOwner + ) + ); + + commandCount = 2; + bytes memory queries = abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + OWNER_ID, + PENDING_OWNER_ID + ); + bytes memory getRes = invokeStaticDispatcher(queries); + (address owner_, ) = getRes.asAddressUnchecked(0); + (address pendingOwner_, ) = getRes.asAddressUnchecked(20); + + assertEq(owner_, owner); + assertEq(pendingOwner_, newOwner); + + vm.prank(newOwner); + invokeDispatcher( + abi.encodePacked( + ACQUIRE_OWNERSHIP_ID + ) + ); + + getRes = invokeStaticDispatcher(queries); + (owner_, ) = getRes.asAddressUnchecked(0); + (pendingOwner_, ) = getRes.asAddressUnchecked(20); + + assertEq(owner_, newOwner); + assertEq(pendingOwner_, address(0)); + } + + function testOwnershipTransfer_NotAuthorized() public { + uint8 commandCount = 1; + address newOwner = makeAddr("newOwner"); + + vm.expectRevert(NotAuthorized.selector); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + PROPOSE_OWNERSHIP_TRANSFER_ID, + newOwner + ) + ); + } + + function testOwnershipTransfer() public { + address newOwner = makeAddr("newOwner"); + uint8 commandCount = 1; + + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + PROPOSE_OWNERSHIP_TRANSFER_ID, + newOwner + ) + ); + + commandCount = 2; + bytes memory getRes = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + OWNER_ID, + PENDING_OWNER_ID + ) + ); + (address owner_, ) = getRes.asAddressUnchecked(0); + (address pendingOwner_, ) = getRes.asAddressUnchecked(20); + + assertEq(owner_, owner); + assertEq(pendingOwner_, newOwner); + } + + + + + + function testAcquireOwnership_NotAuthorized() public { + vm.expectRevert(NotAuthorized.selector); + invokeDispatcher( + abi.encodePacked( + ACQUIRE_OWNERSHIP_ID + ) + ); + } + + function testAcquireOwnership() public { + address newOwner = makeAddr("newOwner"); + uint8 commandCount = 1; + + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + PROPOSE_OWNERSHIP_TRANSFER_ID, + newOwner + ) + ); + + vm.prank(newOwner); + invokeDispatcher( + abi.encodePacked( + ACQUIRE_OWNERSHIP_ID + ) + ); + + bytes memory getRes = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + OWNER_ID, + PENDING_OWNER_ID + ) + ); + (address owner_, ) = getRes.asAddressUnchecked(0); + (address pendingOwner_, ) = getRes.asAddressUnchecked(20); + + assertEq(owner_, newOwner); + assertEq(pendingOwner_, address(0)); + } + + function testExternalOwnershipTransfer_NotAuthorized() public { + address newOwner = makeAddr("newOwner"); + + vm.expectRevert(NotAuthorized.selector); + dispatcher.transferOwnership(newOwner); + } + + function testExternalOwnershipTransfer(address newOwner) public { + vm.prank(owner); + dispatcher.transferOwnership(newOwner); + + uint8 commandCount = 2; + bytes memory getRes = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + OWNER_ID, + PENDING_OWNER_ID + ) + ); + (address owner_, ) = getRes.asAddressUnchecked(0); + (address pendingOwner_, ) = getRes.asAddressUnchecked(20); + + assertEq(owner_, owner); + assertEq(pendingOwner_, newOwner); + } + + function testExternalCancelOwnershipTransfer_NotAuthorized() public { + vm.expectRevert(NotAuthorized.selector); + dispatcher.cancelOwnershipTransfer(); + } + + function testExternalCancelOwnershipTransfer(address newOwner) public { + vm.prank(owner); + dispatcher.transferOwnership(newOwner); + + vm.prank(owner); + dispatcher.cancelOwnershipTransfer(); + + uint8 commandCount = 2; + bytes memory getRes = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + OWNER_ID, + PENDING_OWNER_ID + ) + ); + (address owner_, ) = getRes.asAddressUnchecked(0); + (address pendingOwner_, ) = getRes.asAddressUnchecked(20); + + assertEq(owner_, owner); + assertEq(pendingOwner_, address(0)); + } + + function testExternalReceiveOwnership_NotAuthorized() public { + vm.expectRevert(NotAuthorized.selector); + dispatcher.receiveOwnership(); + } + + function testExternalReceiveOwnership(address newOwner) public { + vm.prank(owner); + dispatcher.transferOwnership(newOwner); + + vm.prank(newOwner); + dispatcher.receiveOwnership(); + + uint8 commandCount = 2; + bytes memory getRes = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + OWNER_ID, + PENDING_OWNER_ID + ) + ); + (address owner_, ) = getRes.asAddressUnchecked(0); + (address pendingOwner_, ) = getRes.asAddressUnchecked(20); + + assertEq(owner_, newOwner); + assertEq(pendingOwner_, address(0)); + } + + function testBatchAfterAcquire(address newOwner, address newAdmin) public { + vm.assume(newOwner != address(0)); + vm.assume(newOwner != owner); + vm.assume(newAdmin != address(0)); + vm.assume(newAdmin != admin); + uint8 commandCount = 1; + + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + PROPOSE_OWNERSHIP_TRANSFER_ID, + newOwner + ) + ); + + vm.prank(newOwner); + invokeDispatcher( + abi.encodePacked( + ACQUIRE_OWNERSHIP_ID, + ACCESS_CONTROL_ID, + commandCount, + ADD_ADMIN_ID, + newAdmin + ) + ); + + commandCount = 3; + bytes memory getRes = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + OWNER_ID, + PENDING_OWNER_ID, + IS_ADMIN_ID, + newAdmin + ) + ); + uint offset = 0; + address owner_; + address pendingOwner_; + bool isAdmin; + (owner_, offset) = getRes.asAddressUnchecked(offset); + (pendingOwner_, offset) = getRes.asAddressUnchecked(offset); + (isAdmin, offset) = getRes.asBoolUnchecked(offset); + + assertEq(owner_, newOwner); + assertEq(pendingOwner_, address(0)); + assertEq(isAdmin, true); + } + + function testAddAdmin_NotAuthorized() public { + address newAdmin = makeAddr("newAdmin"); + uint8 commandCount = 1; + + bytes memory addAdminCommand = abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + ADD_ADMIN_ID, + newAdmin + ); + + vm.expectRevert(NotAuthorized.selector); + invokeDispatcher(addAdminCommand); + + vm.expectRevert(NotAuthorized.selector); + vm.prank(admin); + invokeDispatcher(addAdminCommand); + } + + function testAddAdmin(address newAdmin) public { + vm.assume(newAdmin != admin); + uint8 commandCount = 1; + + vm.expectEmit(); + emit AdminsUpdated(newAdmin, true, block.timestamp); + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + ADD_ADMIN_ID, + newAdmin + ) + ); + + commandCount = 2; + bytes memory res = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + IS_ADMIN_ID, + newAdmin, + ADMINS_ID + ) + ); + + (bool isAdmin, ) = res.asBoolUnchecked(0); + (uint8 adminsCount, ) = res.asUint8Unchecked(1); + (address newAdmin_, ) = res.asAddressUnchecked(res.length - 20); + + assertEq(isAdmin, true); + assertEq(adminsCount, 2); + assertEq(newAdmin_, newAdmin); + } + + function testRevokeAdmin_NotAuthorized() public { + address newAdmin = makeAddr("newAdmin"); + uint8 commandCount = 1; + + vm.expectRevert(NotAuthorized.selector); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + REVOKE_ADMIN_ID, + newAdmin + ) + ); + } + + function testRevokeAdmin(address newAdmin) public { + vm.assume(newAdmin != admin); + uint8 commandCount = 1; + + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + ADD_ADMIN_ID, + newAdmin + ) + ); + + commandCount = 1; + + vm.expectEmit(); + emit AdminsUpdated(newAdmin, false, block.timestamp); + + vm.prank(admin); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + REVOKE_ADMIN_ID, + newAdmin + ) + ); + + commandCount = 2; + bytes memory res = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + IS_ADMIN_ID, + newAdmin, + ADMINS_ID + ) + ); + + (bool isAdmin, ) = res.asBoolUnchecked(0); + (uint8 adminsCount, ) = res.asUint8Unchecked(1); + + assertEq(isAdmin, false); + assertEq(adminsCount, 1); + } + + function testRelinquishAdministration() public { + uint8 commandCount = 1; + + vm.expectEmit(); + emit AdminsUpdated(admin, false, block.timestamp); + + vm.prank(admin); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + REVOKE_ADMIN_ID, + admin + ) + ); + + bool isAdmin; + (isAdmin, ) = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + IS_ADMIN_ID, + admin + ) + ).asBoolUnchecked(0); + + assertEq(isAdmin, false); + } + + function testRelinquishOwnership_NotAuthorized() public { + uint8 commandCount = 1; + bytes memory relinquishCommand = abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + RELINQUISH_OWNERSHIP_ID + ); + + vm.expectRevert(NotAuthorized.selector); + invokeDispatcher(relinquishCommand); + + + vm.expectRevert(NotAuthorized.selector); + vm.prank(admin); + invokeDispatcher(relinquishCommand); + } + + function testRelinquishOwnership_LengthMismatch() public { + uint8 commandCount = 2; + + vm.prank(owner); + vm.expectRevert( + abi.encodeWithSelector(BytesParsing.LengthMismatch.selector, 4, 3) + ); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + RELINQUISH_OWNERSHIP_ID, + ADD_ADMIN_ID + ) + ); + } + + + function testRelinquishOwnership() public { + uint8 commandCount = 1; + + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + ACCESS_CONTROL_ID, + commandCount, + RELINQUISH_OWNERSHIP_ID + ) + ); + + (address owner_, ) = invokeStaticDispatcher( + abi.encodePacked( + ACCESS_CONTROL_QUERIES_ID, + commandCount, + OWNER_ID + ) + ).asAddressUnchecked(0); + + assertEq(owner_, address(0)); + } +} diff --git a/test/components/dispatcher/SweepTokens.t.sol b/test/components/dispatcher/SweepTokens.t.sol new file mode 100644 index 0000000..e083306 --- /dev/null +++ b/test/components/dispatcher/SweepTokens.t.sol @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +import {SWEEP_TOKENS_ID} from "wormhole-sdk/components/dispatcher/Ids.sol"; +import {UpgradeTester} from "wormhole-sdk/testing/UpgradeTester.sol"; +import {ERC20Mock} from "wormhole-sdk/testing/ERC20Mock.sol"; +import {DispatcherTestBase} from "./utils/DispatcherTestBase.sol"; + +contract SweepTokensTest is DispatcherTestBase { + ERC20Mock token; + + function _setUp1() internal override { + token = new ERC20Mock("FakeToken", "FT"); + } + + function testSweepTokens_erc20() public { + uint tokenAmount = 1e6; + deal(address(token), address(dispatcher), tokenAmount); + + assertEq(token.balanceOf(owner), 0); + assertEq(token.balanceOf(address(dispatcher)), tokenAmount); + + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + SWEEP_TOKENS_ID, address(token), tokenAmount + ) + ); + + assertEq(token.balanceOf(address(dispatcher)), 0); + assertEq(token.balanceOf(owner), tokenAmount); + } + + function testSweepTokens_eth() public { + uint ethAmount = 1 ether; + vm.deal(address(dispatcher), ethAmount); + uint ownerEthBalance = address(owner).balance; + assertEq(address(dispatcher).balance, ethAmount); + + vm.prank(owner); + invokeDispatcher( + abi.encodePacked( + SWEEP_TOKENS_ID, address(0), ethAmount + ) + ); + + assertEq(address(dispatcher).balance, 0); + assertEq(address(owner).balance, ownerEthBalance + ethAmount); + } +} diff --git a/test/components/dispatcher/Upgrade.t.sol b/test/components/dispatcher/Upgrade.t.sol new file mode 100644 index 0000000..6ffa59c --- /dev/null +++ b/test/components/dispatcher/Upgrade.t.sol @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +import {BytesParsing} from "wormhole-sdk/libraries/BytesParsing.sol"; +import {IdempotentUpgrade} from "wormhole-sdk/proxy/ProxyBase.sol"; +import {NotAuthorized} from "wormhole-sdk/components/dispatcher/AccessControl.sol"; +import { + UPGRADE_CONTRACT_ID, + IMPLEMENTATION_ID +} from "wormhole-sdk/components/dispatcher/Ids.sol"; +import {UpgradeTester} from "wormhole-sdk/testing/UpgradeTester.sol"; +import {DispatcherTestBase} from "./utils/DispatcherTestBase.sol"; + +contract UpgradeTest is DispatcherTestBase { + using BytesParsing for bytes; + + function testContractUpgrade_NotAuthorized() public { + address fakeAddress = makeAddr("fakeAddress"); + + vm.expectRevert(NotAuthorized.selector); + invokeDispatcher( + abi.encodePacked( + UPGRADE_CONTRACT_ID, + address(fakeAddress) + ) + ); + + vm.prank(admin); + vm.expectRevert(NotAuthorized.selector); + invokeDispatcher( + abi.encodePacked( + UPGRADE_CONTRACT_ID, + address(fakeAddress) + ) + ); + } + + function testContractUpgrade_IdempotentUpgrade() public { + UpgradeTester upgradeTester = new UpgradeTester(); + + vm.startPrank(owner); + invokeDispatcher( + abi.encodePacked( + UPGRADE_CONTRACT_ID, + address(upgradeTester) + ) + ); + + vm.expectRevert(IdempotentUpgrade.selector); + UpgradeTester(address(dispatcher)).upgradeTo(address(upgradeTester), new bytes(0)); + } + + function testContractUpgrade() public { + UpgradeTester upgradeTester = new UpgradeTester(); + + bytes memory response = invokeStaticDispatcher( + abi.encodePacked( + IMPLEMENTATION_ID + ) + ); + assertEq(response.length, 20); + (address implementation,) = response.asAddressUnchecked(0); + + vm.startPrank(owner); + invokeDispatcher( + abi.encodePacked( + UPGRADE_CONTRACT_ID, + address(upgradeTester) + ) + ); + + UpgradeTester(address(dispatcher)).upgradeTo(implementation, new bytes(0)); + + response = invokeStaticDispatcher( + abi.encodePacked( + IMPLEMENTATION_ID + ) + ); + assertEq(response.length, 20); + (address restoredImplementation,) = response.asAddressUnchecked(0); + assertEq(restoredImplementation, implementation); + } + + function testExternalContractUpgrade_NotAuthorized() public { + address fakeAddress = makeAddr("fakeAddress"); + + vm.expectRevert(NotAuthorized.selector); + dispatcher.upgrade(address(fakeAddress), new bytes(0)); + + vm.prank(admin); + vm.expectRevert(NotAuthorized.selector); + dispatcher.upgrade(address(fakeAddress), new bytes(0)); + } + + function testExternalContractUpgrade_IdempotentUpgrade() public { + UpgradeTester upgradeTester = new UpgradeTester(); + + vm.startPrank(owner); + dispatcher.upgrade(address(upgradeTester), new bytes(0)); + + vm.expectRevert(IdempotentUpgrade.selector); + UpgradeTester(address(dispatcher)).upgradeTo(address(upgradeTester), new bytes(0)); + } + + function testExternalContractUpgrade() public { + UpgradeTester upgradeTester = new UpgradeTester(); + + bytes memory response = invokeStaticDispatcher( + abi.encodePacked( + IMPLEMENTATION_ID + ) + ); + assertEq(response.length, 20); + (address implementation,) = response.asAddressUnchecked(0); + + vm.startPrank(owner); + dispatcher.upgrade(address(upgradeTester), new bytes(0)); + + UpgradeTester(address(dispatcher)).upgradeTo(implementation, new bytes(0)); + + response = invokeStaticDispatcher( + abi.encodePacked( + IMPLEMENTATION_ID + ) + ); + assertEq(response.length, 20); + (address restoredImplementation,) = response.asAddressUnchecked(0); + assertEq(restoredImplementation, implementation); + } +} diff --git a/test/components/dispatcher/utils/Dispatcher.sol b/test/components/dispatcher/utils/Dispatcher.sol new file mode 100644 index 0000000..973f80a --- /dev/null +++ b/test/components/dispatcher/utils/Dispatcher.sol @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +import {BytesParsing} from "wormhole-sdk/libraries/BytesParsing.sol"; +import {RawDispatcher} from "wormhole-sdk/RawDispatcher.sol"; +import {AccessControl} from "wormhole-sdk/components/dispatcher/AccessControl.sol"; +import {SweepTokens} from "wormhole-sdk/components/dispatcher/SweepTokens.sol"; +import {Upgrade} from "wormhole-sdk/components/dispatcher/Upgrade.sol"; + +contract Dispatcher is RawDispatcher, AccessControl, SweepTokens, Upgrade { + using BytesParsing for bytes; + + function _proxyConstructor(bytes calldata args) internal override { + uint offset = 0; + + address owner; + (owner, offset) = args.asAddressCdUnchecked(offset); + + uint8 adminCount; + (adminCount, offset) = args.asUint8CdUnchecked(offset); + address[] memory admins = new address[](adminCount); + for (uint i = 0; i < adminCount; ++i) { + (admins[i], offset) = args.asAddressCdUnchecked(offset); + } + args.checkLengthCd(offset); + + _accessControlConstruction(owner, admins); + } + + function _exec(bytes calldata data) internal override returns (bytes memory) { unchecked { + uint offset = 0; + + while (offset < data.length) { + uint8 command; + (command, offset) = data.asUint8CdUnchecked(offset); + + bool dispatched; + (dispatched, offset) = dispatchExecAccessControl(data, offset, command); + if (!dispatched) + (dispatched, offset) = dispatchExecUpgrade(data, offset, command); + if (!dispatched) + (dispatched, offset) = dispatchExecSweepTokens(data, offset, command); + } + + data.checkLengthCd(offset); + return new bytes(0); + }} + + function _get(bytes calldata data) internal view override returns (bytes memory) { unchecked { + bytes memory ret; + uint offset = 0; + + while (offset < data.length) { + uint8 query; + (query, offset) = data.asUint8CdUnchecked(offset); + + bytes memory result; + bool dispatched; + (dispatched, result, offset) = dispatchQueryAccessControl(data, offset, query); + if (!dispatched) + (dispatched, result, offset) = dispatchQueryUpgrade(data, offset, query); + + ret = abi.encodePacked(ret, result); + } + data.checkLengthCd(offset); + return ret; + }} +} \ No newline at end of file diff --git a/test/components/dispatcher/utils/DispatcherTestBase.sol b/test/components/dispatcher/utils/DispatcherTestBase.sol new file mode 100644 index 0000000..38dfe4e --- /dev/null +++ b/test/components/dispatcher/utils/DispatcherTestBase.sol @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache 2 + +pragma solidity ^0.8.4; + +import "forge-std/Test.sol"; + +import {BytesParsing} from "wormhole-sdk/libraries/BytesParsing.sol"; +import {Proxy} from "wormhole-sdk/proxy/Proxy.sol"; +import {reRevert} from "wormhole-sdk/Utils.sol"; +import {Dispatcher} from "./Dispatcher.sol"; + +contract DispatcherTestBase is Test { + using BytesParsing for bytes; + + address immutable owner; + address immutable admin; + + address dispatcherImplementation; + Dispatcher dispatcher; + + constructor() { + owner = makeAddr("owner"); + admin = makeAddr("admin"); + } + + function _setUp1() internal virtual { } + + function setUp() public { + uint8 adminCount = 1; + + dispatcherImplementation = address(new Dispatcher()); + dispatcher = Dispatcher(address(new Proxy( + dispatcherImplementation, + abi.encodePacked( + owner, + adminCount, + admin + ) + ))); + + _setUp1(); + } + + function invokeStaticDispatcher(bytes memory messages) view internal returns (bytes memory data) { + bytes memory getCall = abi.encodePacked(dispatcher.get1959.selector, messages); + (bool success, bytes memory result) = address(dispatcher).staticcall(getCall); + return decodeBytesResult(success, result); + } + + function invokeDispatcher(bytes memory messages) internal returns (bytes memory data) { + return invokeDispatcher(messages, 0); + } + + function invokeDispatcher(bytes memory messages, uint value) internal returns (bytes memory data) { + bytes memory execCall = abi.encodePacked(dispatcher.exec768.selector, messages); + (bool success, bytes memory result) = address(dispatcher).call{value: value}(execCall); + return decodeBytesResult(success, result); + } + + function decodeBytesResult(bool success, bytes memory result) pure private returns (bytes memory data) { + if (!success) { + reRevert(result); + } + data = abi.decode(result, (bytes)); + } +} \ No newline at end of file