From 0bd4f7626d07504f129f2ae24a58902bc89f61bc Mon Sep 17 00:00:00 2001 From: Joaquin Gonzalez Date: Mon, 20 May 2024 11:59:03 -0300 Subject: [PATCH] feat: Add SimpleGuardianModule to SimplePlusAccount --- src/SimpleGuardianModule.sol | 87 +++++++++++++++++++++++ src/SimplePlusAccount.sol | 131 ++++++++++++++++++++--------------- test/SimplePlusAccount.t.sol | 25 +++++++ 3 files changed, 186 insertions(+), 57 deletions(-) create mode 100644 src/SimpleGuardianModule.sol diff --git a/src/SimpleGuardianModule.sol b/src/SimpleGuardianModule.sol new file mode 100644 index 0000000..3c840e5 --- /dev/null +++ b/src/SimpleGuardianModule.sol @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: GPL-3.0 +pragma solidity ^0.8.25; + +import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; +// import { console2 } from "forge-std/src/console2.sol"; + +abstract contract SimpleGuardianModule { + using ECDSA for bytes32; + + bytes32 public constant _RECOVER_TYPEHASH = + keccak256("Recover(address currentOwner, address newOwner, uint256 nonce)"); + + event NonceConsumed(address indexed owner, uint256 idx); + event GuardianUpdated(address indexed previousGuardian, address indexed newGuardian); + + address public guardian; + mapping(address => uint256) private _nonces; + + modifier onlyGuardian() { + require(msg.sender == guardian, "Not the guardian"); + _; + } + + /** + * @notice Retuns a nonce for a given address. + * @param from Address. + * @return uint256 Nonce Value. + */ + function getNonce(address from) external view virtual returns (uint256) { + return _nonces[from]; + } + + function _verifyAndConsumeNonce(address owner, uint256 nonde) internal virtual { + require(nonde == _nonces[owner]++, "invalid nonce"); + emit NonceConsumed(owner, nonde); + } + + function initGuardian(address newGuardian) external { + require(guardian == address(0)); + _updateGuardian(newGuardian); + } + + function updateGuardian(address newGuardian) external { + require(_onlyAuthorized(), "Not authorized"); + _updateGuardian(newGuardian); + } + + function _updateGuardian(address newGuardian) internal { + require( + newGuardian != address(0) && guardian != newGuardian && newGuardian != address(this), + "Invalid guardian address" + ); + address oldGuardian = guardian; + guardian = newGuardian; + emit GuardianUpdated(oldGuardian, newGuardian); + } + + function recoverAccount(address newOwner, uint256 nonce, bytes calldata signature) external { + require( + newOwner != address(0) && _owner() != newOwner && newOwner != address(this), "Invalid new owner address" + ); + + _verifyAndConsumeNonce(newOwner, nonce); + bytes32 structHash = keccak256(abi.encode(_RECOVER_TYPEHASH, _owner(), newOwner, nonce)); + bytes32 digest = _hashTypedDataV4(structHash); + + address recoveredAddress = digest.recover(signature); + + require(recoveredAddress == guardian, "Invalid guardian signature"); + + _transferOwnership(newOwner); + } + + function _transferOwnership(address newOwner) internal virtual; + + function _hashTypedDataV4(bytes32 structHash) internal view virtual returns (bytes32); + + function _onlyAuthorized() internal view virtual returns (bool); + + function _owner() internal view virtual returns (address); + + /** + * @dev This empty reserved space is put in place to allow future versions to add new + * variables without shifting down storage in the inheritance chain. + */ + uint256[49] private __gap; +} diff --git a/src/SimplePlusAccount.sol b/src/SimplePlusAccount.sol index 8169955..e2ea4b1 100644 --- a/src/SimplePlusAccount.sol +++ b/src/SimplePlusAccount.sol @@ -9,17 +9,14 @@ import { MessageHashUtils } from "@openzeppelin/contracts/utils/cryptography/Mes import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; import { IERC1271 } from "@openzeppelin/contracts/interfaces/IERC1271.sol"; import { EIP712 } from "@openzeppelin/contracts/utils/cryptography/EIP712.sol"; // TODO: use upgradable version +import { SimpleGuardianModule } from "./SimpleGuardianModule.sol"; +// import { console2 } from "forge-std/src/console2.sol"; -contract SimplePlusAccount is SimpleAccount, IERC1271, EIP712 { +contract SimplePlusAccount is SimpleAccount, SimpleGuardianModule, IERC1271, EIP712 { using ECDSA for bytes32; using MessageHashUtils for bytes32; - bytes32 internal constant _MESSAGE_TYPEHASH = keccak256("SimplePlusAccount(bytes message)"); - - modifier onlyAuthorized() { - _onlyAuthorized(); - _; - } + bytes32 public constant _MESSAGE_TYPEHASH = keccak256("SimplePlusAccount(bytes message)"); // @notice Signature types used for user operation validation and ERC-1271 signature validation. enum SignatureType { @@ -48,16 +45,18 @@ contract SimplePlusAccount is SimpleAccount, IERC1271, EIP712 { /// 1. The entry point /// 2. The account itself (when redirected through `execute`, etc.) /// 3. An owner - function _onlyAuthorized() internal view { + function _onlyAuthorized() internal view virtual override returns (bool) { if (msg.sender != address(entryPoint()) && msg.sender != address(this) && msg.sender != owner) { revert NotAuthorized(); } + return true; } /// @notice Transfers ownership of the contract to a new account (`newOwner`). Can only be called by the current /// owner or from the entry point via a user operation signed by the current owner. /// @param newOwner The new owner. - function transferOwnership(address newOwner) external onlyAuthorized { + function transferOwnership(address newOwner) public { + require(_onlyAuthorized()); if (newOwner == address(0) || newOwner == address(this) || owner == newOwner) { revert InvalidOwner(newOwner); } @@ -80,62 +79,80 @@ contract SimplePlusAccount is SimpleAccount, IERC1271, EIP712 { * "Ethereum Signed Message" envelope before checking the signature for the EOA-owner case. */ function isValidSignature(bytes32 hash, bytes calldata _signature) public view virtual returns (bytes4) { - if (_signature.length == 0) { - revert InvalidSignatureType(); + if (_signature.length == 0) { + revert InvalidSignatureType(); + } + + bytes32 structHash = keccak256(abi.encode(_MESSAGE_TYPEHASH, keccak256(abi.encode(hash)))); + bytes32 replaySafeHash = MessageHashUtils.toTypedDataHash(_domainSeparatorV4(), structHash); + + return _validateSignatureWithType(uint8(_signature[0]), replaySafeHash, _signature[1:]) + ? this.isValidSignature.selector + : bytes4(0xffffffff); } - bytes32 structHash = keccak256(abi.encode(_MESSAGE_TYPEHASH, keccak256(abi.encode(hash)))); - bytes32 replaySafeHash = MessageHashUtils.toTypedDataHash(_domainSeparatorV4(), structHash); + function _validateSignature( + PackedUserOperation calldata userOp, + bytes32 userOpHash + ) + internal + virtual + override + returns (uint256 validationData) + { + if (userOp.signature.length == 0) { + revert InvalidSignatureType(); + } - return _validateSignatureWithType(uint8(_signature[0]), replaySafeHash, _signature[1:]) - ? this.isValidSignature.selector - : bytes4(0xffffffff); -} + return _validateSignatureWithType( + uint8(userOp.signature[0]), userOpHash.toEthSignedMessageHash(), userOp.signature[1:] + ) ? SIG_VALIDATION_SUCCESS : SIG_VALIDATION_FAILED; + } + + function _validateSignatureWithType( + uint8 signatureType, + bytes32 hash, + bytes memory signature + ) + private + view + returns (bool) + { + if (signatureType == uint8(SignatureType.EOA)) { + return _validateEOASignature(hash, signature) == SIG_VALIDATION_SUCCESS; + } else if (signatureType == uint8(SignatureType.CONTRACT)) { + return _validateContractSignature(hash, signature) == SIG_VALIDATION_SUCCESS; + } else { + revert InvalidSignatureType(); + } + } -function _validateSignature( - PackedUserOperation calldata userOp, - bytes32 userOpHash -) - internal - virtual - override - returns (uint256 validationData) -{ - if (userOp.signature.length == 0) { - revert InvalidSignatureType(); + function _validateEOASignature(bytes32 hash, bytes memory signature) private view returns (uint256) { + address recovered = hash.recover(signature); + return recovered == owner ? SIG_VALIDATION_SUCCESS : SIG_VALIDATION_FAILED; } - return _validateSignatureWithType(uint8(userOp.signature[0]), userOpHash.toEthSignedMessageHash(), userOp.signature[1:]) - ? SIG_VALIDATION_SUCCESS - : SIG_VALIDATION_FAILED; -} + function _validateContractSignature(bytes32 userOpHash, bytes memory signature) private view returns (uint256) { + return SignatureChecker.isValidERC1271SignatureNow(owner, userOpHash, signature) + ? SIG_VALIDATION_SUCCESS + : SIG_VALIDATION_FAILED; + } -function _validateSignatureWithType( - uint8 signatureType, - bytes32 hash, - bytes memory signature -) - private - view - returns (bool) -{ - if (signatureType == uint8(SignatureType.EOA)) { - return _validateEOASignature(hash, signature) == SIG_VALIDATION_SUCCESS; - } else if (signatureType == uint8(SignatureType.CONTRACT)) { - return _validateContractSignature(hash, signature) == SIG_VALIDATION_SUCCESS; - } else { - revert InvalidSignatureType(); + function _transferOwnership(address newOwner) internal virtual override { + this.transferOwnership(newOwner); } -} -function _validateEOASignature(bytes32 hash, bytes memory signature) private view returns (uint256) { - address recovered = hash.recover(signature); - return recovered == owner ? SIG_VALIDATION_SUCCESS : SIG_VALIDATION_FAILED; -} + function _hashTypedDataV4(bytes32 structHash) + internal + view + virtual + override(EIP712, SimpleGuardianModule) + returns (bytes32) + { + return super._hashTypedDataV4(structHash); + } -function _validateContractSignature(bytes32 userOpHash, bytes memory signature) private view returns (uint256) { - return SignatureChecker.isValidERC1271SignatureNow(owner, userOpHash, signature) - ? SIG_VALIDATION_SUCCESS - : SIG_VALIDATION_FAILED; -} + function _owner() internal view virtual override returns (address) { + return owner; + } } diff --git a/test/SimplePlusAccount.t.sol b/test/SimplePlusAccount.t.sol index 65aba44..ae2e377 100644 --- a/test/SimplePlusAccount.t.sol +++ b/test/SimplePlusAccount.t.sol @@ -8,18 +8,25 @@ import { SimplePlusAccountFactory } from "../src/SimplePlusAccountFactory.sol"; import { EntryPoint } from "@account-abstraction/contracts/core/EntryPoint.sol"; import { SimpleAccount } from "@account-abstraction/contracts/samples/SimpleAccount.sol"; import { AccountTest } from "./AccountTest.sol"; +import { MessageHashUtils } from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol"; +// import { console2 } from "forge-std/src/console2.sol"; contract SimplePlusAccountTest is AccountTest { uint256 public constant EOA_PRIVATE_KEY = 1; + uint256 public constant GUARDIAN_PRIVATE_KEY = 2; address payable public constant BENEFICIARY = payable(address(0xbe9ef1c1a2ee)); address public eoaAddress; + address public guardianAddress; + SimplePlusAccount public account; EntryPoint public entryPoint; event OwnershipTransferred(address indexed previousOwner, address indexed newOwner); + event GuardianUpdated(address indexed previousOwner, address indexed newOwner); function setUp() public { eoaAddress = vm.addr(EOA_PRIVATE_KEY); + guardianAddress = vm.addr(GUARDIAN_PRIVATE_KEY); entryPoint = new EntryPoint(); SimplePlusAccountFactory factory = new SimplePlusAccountFactory(entryPoint); account = factory.createAccount(eoaAddress, 1); @@ -97,6 +104,24 @@ contract SimplePlusAccountTest is AccountTest { assertEq(account.isValidSignature(message, signature), bytes4(keccak256("isValidSignature(bytes32,bytes)"))); } + function testGuardianCanTransferOwnership() public { + vm.prank(guardianAddress); + emit GuardianUpdated(address(0), guardianAddress); + account.initGuardian(guardianAddress); + uint256 nonce = account.getNonce(eoaAddress); + + address newOwner = address(0x100); + bytes32 structHash = keccak256(abi.encode(account._RECOVER_TYPEHASH(), eoaAddress, newOwner, nonce)); + bytes32 digest = MessageHashUtils.toTypedDataHash(domainSeparator(address(account)), structHash); + + bytes memory signature = sign(GUARDIAN_PRIVATE_KEY, digest); + + vm.expectEmit(true, true, false, false); + emit OwnershipTransferred(eoaAddress, newOwner); + account.recoverAccount(newOwner, nonce, signature); + assertEq(account.owner(), newOwner); + } + function _transferOwnership(address currentOwner, address newOwner) internal { vm.prank(currentOwner); vm.expectEmit(true, true, false, false);