Skip to content

Commit

Permalink
Merge pull request #155 from rhinestonewtf/feat/unistall-storage-check
Browse files Browse the repository at this point in the history
feat(ModuleKitHelpers): add verifyModuleStorageWasCleared
  • Loading branch information
kopy-kat authored Nov 4, 2024
2 parents cf8df88 + a43ebe1 commit e6b8aa5
Show file tree
Hide file tree
Showing 8 changed files with 319 additions and 5 deletions.
115 changes: 114 additions & 1 deletion src/test/ModuleKitHelpers.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,23 @@ import { PackedUserOperation } from "../external/ERC4337.sol";
import { ERC4337Helpers } from "./utils/ERC4337Helpers.sol";
import { HelperBase } from "./helpers/HelperBase.sol";
import { Execution, MODULE_TYPE_HOOK } from "../external/ERC7579.sol";
import { prank } from "src/test/utils/Vm.sol";
import {
prank,
VmSafe,
startStateDiffRecording as vmStartStateDiffRecording,
stopAndReturnStateDiff as vmStopAndReturnStateDiff,
getMappingKeyAndParentOf,
envOr
} from "src/test/utils/Vm.sol";
import {
getAccountType as getAccountTypeFromStorage,
writeAccountType,
writeExpectRevert,
writeGasIdentifier,
writeSimulateUserOp,
writeStorageCompliance,
getStorageCompliance,
getSimulateUserOp,
writeAccountEnv,
getFactory,
getHelper as getHelperFromStorage,
Expand Down Expand Up @@ -140,6 +150,27 @@ library ModuleKitHelpers {
return exec(instance, target, 0, callData);
}

/*//////////////////////////////////////////////////////////////
HOOKS
//////////////////////////////////////////////////////////////*/

function preEnvHook() internal {
if (envOr("COMPLIANCE", false) || getStorageCompliance()) {
// Start state diff recording
vmStartStateDiffRecording();
}
}

function postEnvHook(AccountInstance memory instance, bytes memory data) internal {
if (envOr("COMPLIANCE", false) || getStorageCompliance()) {
address module = abi.decode(data, (address));
// Stop state diff recording and return account accesses
VmSafe.AccountAccess[] memory accountAccesses = vmStopAndReturnStateDiff();
// Check if storage was cleared
verifyModuleStorageWasCleared(instance, accountAccesses, module);
}
}

/*//////////////////////////////////////////////////////////////////////////
MODULE CONFIG
//////////////////////////////////////////////////////////////////////////*/
Expand All @@ -153,6 +184,8 @@ library ModuleKitHelpers {
internal
returns (UserOpData memory userOpData)
{
// Run preEnvHook
preEnvHook();
userOpData = instance.getInstallModuleOps(
moduleTypeId, module, data, address(instance.defaultValidator)
);
Expand Down Expand Up @@ -181,6 +214,8 @@ library ModuleKitHelpers {

// send userOp to entrypoint
userOpData.execUserOps();
// Run postEnvHook
postEnvHook(instance, abi.encode(module));
}

function isModuleInstalled(
Expand Down Expand Up @@ -323,6 +358,80 @@ library ModuleKitHelpers {
}
}
}

/// Start recording the state diff
function startStateDiffRecording(AccountInstance memory) internal {
vmStartStateDiffRecording();
}

/// Stop recording the state diff and return the account accesses
function stopAndReturnStateDiff(AccountInstance memory)
internal
returns (VmSafe.AccountAccess[] memory)
{
return vmStopAndReturnStateDiff();
}

/// Verifies from an accountAccesses array that storage was correctly cleared after uninstalling
/// a module
function verifyModuleStorageWasCleared(
AccountInstance memory,
VmSafe.AccountAccess[] memory accountAccesses,
address module
)
internal
view
{
bytes32[] memory seenSlots = new bytes32[](1000);
bytes32[] memory finalValues = new bytes32[](1000);
uint256 numSlots;

// Loop through account accesses
for (uint256 i; i < accountAccesses.length; i++) {
// Skip tests
if (accountAccesses[i].accessor == address(this)) {
continue;
}

// If we are accessing the storage of the module check writes and clears
if (accountAccesses[i].account == module) {
// Process all storage accesses for this module
for (uint256 j; j < accountAccesses[i].storageAccesses.length; j++) {
VmSafe.StorageAccess memory access = accountAccesses[i].storageAccesses[j];

// Skip reads
if (!access.isWrite) {
continue;
}

// Find if we've seen this slot
bool found;
for (uint256 k; k < numSlots; k++) {
if (seenSlots[k] == access.slot) {
finalValues[k] = access.newValue;
found = true;
break;
}
}

// If not seen, add it
if (!found) {
seenSlots[numSlots] = access.slot;
finalValues[numSlots] = access.newValue;
numSlots++;
}
}
}
}

// Check if any slot's final value is non-zero
for (uint256 i; i < numSlots; i++) {
if (finalValues[i] != bytes32(0)) {
revert("Storage not cleared after uninstalling module");
}
}
}

/*//////////////////////////////////////////////////////////////////////////
CONTROL FLOW
//////////////////////////////////////////////////////////////////////////*/
Expand Down Expand Up @@ -355,6 +464,10 @@ library ModuleKitHelpers {
writeSimulateUserOp(value);
}

function storageCompliance(AccountInstance memory, bool value) internal {
writeStorageCompliance(value);
}

/*//////////////////////////////////////////////////////////////////////////
ACCOUNT UTILS
//////////////////////////////////////////////////////////////////////////*/
Expand Down
12 changes: 12 additions & 0 deletions src/test/RhinestoneModuleKit.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
writeHelper
} from "./utils/Storage.sol";
import { ModuleKitHelpers } from "./ModuleKitHelpers.sol";
import { VmSafe } from "./utils/Vm.sol";

enum AccountType {
DEFAULT,
Expand Down Expand Up @@ -327,4 +328,15 @@ contract RhinestoneModuleKit is AuxiliaryFactory {
defaultSessionValidator: ISessionValidator(sessionValidator)
});
}

/*//////////////////////////////////////////////////////////////
STORAGE CLEARING
//////////////////////////////////////////////////////////////*/

modifier withModuleStorageClearValidation(AccountInstance memory instance, address module) {
instance.startStateDiffRecording();
_;
VmSafe.AccountAccess[] memory accountAccess = instance.stopAndReturnStateDiff();
instance.verifyModuleStorageWasCleared(accountAccess, module);
}
}
18 changes: 18 additions & 0 deletions src/test/utils/Storage.sol
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,24 @@ function getSimulateUserOp() view returns (bool value) {
}
}

/*//////////////////////////////////////////////////////////////
STORAGE COMPLIANCE
//////////////////////////////////////////////////////////////*/

function writeStorageCompliance(bool value) {
bytes32 slot = keccak256("ModuleKit.StorageCompliance");
assembly {
sstore(slot, value)
}
}

function getStorageCompliance() view returns (bool value) {
bytes32 slot = keccak256("ModuleKit.StorageCompliance");
assembly {
value := sload(slot)
}
}

/*//////////////////////////////////////////////////////////////
ACCOUNT ENV
//////////////////////////////////////////////////////////////*/
Expand Down
101 changes: 101 additions & 0 deletions test/Diff.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ import {
import { getAccountType, InstalledModule } from "src/test/utils/Storage.sol";
import { toString } from "src/test/utils/Vm.sol";
import { MockValidatorFalse } from "test/mocks/MockValidatorFalse.sol";
import { MockK1Validator, VALIDATION_SUCCESS } from "test/mocks/MockK1Validator.sol";
import { MockK1ValidatorUncompliantUninstall } from
"test/mocks/MockK1ValidatorUncompliantUninstall.sol";
import { VALIDATION_SUCCESS, VALIDATION_FAILED } from "erc7579/interfaces/IERC7579Module.sol";
import { VmSafe } from "src/test/utils/Vm.sol";

contract ERC7579DifferentialModuleKitLibTest is BaseTest {
using ModuleKitHelpers for *;
Expand All @@ -27,6 +32,7 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest {
MockTarget internal mockTarget;

MockERC20 internal token;
address module;

function setUp() public override {
super.setUp();
Expand Down Expand Up @@ -626,6 +632,101 @@ contract ERC7579DifferentialModuleKitLibTest is BaseTest {
}
}

function test_verifyModuleStorageWasCleared() public {
// Set simulate mode to false
instance.simulateUserOp(false);
// Install a module
address module = address(new MockK1Validator());
// Start state diff recording
instance.startStateDiffRecording();
instance.installModule({
moduleTypeId: MODULE_TYPE_VALIDATOR,
module: module,
data: abi.encode(instance.account)
});
// Uninstall the module
instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" });
// Stop state diff recording
VmSafe.AccountAccess[] memory accountAccesses = instance.stopAndReturnStateDiff();
// Assert that the module storage was cleared
instance.verifyModuleStorageWasCleared(accountAccesses, module);
}

function test_verifyModuleStorageWasCleared_RevertsWhen_NotCleared_UsingComplianceFlag()
public
{
// Set simulate mode to false
instance.simulateUserOp(false);
// Set compliance flag
instance.storageCompliance(true);

// Install a module
module = address(new MockK1ValidatorUncompliantUninstall());
instance.installModule({
moduleTypeId: MODULE_TYPE_VALIDATOR,
module: module,
data: abi.encode(0xffffffffffffffffffff)
});
// Assert module storage
assertEq(
address(0xffffffffffffffffffff),
MockK1Validator(module).smartAccountOwners(address(instance.account))
);
// Expect revert
vm.expectRevert();
this.__revertWhen_verifyModuleStorageWasCleared_NotCleared();
}

function test_verifyModuleStorageWasCleared_RevertsWhen_NotCleared() public {
// Set simulate mode to false
instance.simulateUserOp(false);
// Install a module
module = address(new MockK1ValidatorUncompliantUninstall());
// Start state diff recording
instance.startStateDiffRecording();
instance.installModule({
moduleTypeId: MODULE_TYPE_VALIDATOR,
module: module,
data: abi.encode(0xffffffffffffffffffff)
});
// Assert module storage
assertEq(
address(0xffffffffffffffffffff),
MockK1Validator(module).smartAccountOwners(address(instance.account))
);
// Uninstall the module
instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" });
// Stop state diff recording
VmSafe.AccountAccess[] memory accountAccesses = instance.stopAndReturnStateDiff();
// Expect revert
vm.expectRevert();
// Assert that the module storage was cleared
instance.verifyModuleStorageWasCleared(accountAccesses, module);
}

function __revertWhen_verifyModuleStorageWasCleared_NotCleared() public {
// Uninstall
instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" });
}

function test_withModuleStorageClearValidation()
public
withModuleStorageClearValidation(instance, module)
{
// Set simulate mode to false
instance.simulateUserOp(false);
// Install a module
module = address(new MockK1Validator());
// Install the module
instance.installModule({
moduleTypeId: MODULE_TYPE_VALIDATOR,
module: module,
data: abi.encode(VALIDATION_FAILED)
});
// Uninstall the module
instance.uninstallModule({ moduleTypeId: MODULE_TYPE_VALIDATOR, module: module, data: "" });
}

/*//////////////////////////////////////////////////////////////
EXPECT REVERT
//////////////////////////////////////////////////////////////*/
Expand Down
4 changes: 2 additions & 2 deletions test/integrations/SmartSession.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ contract SmartSessionTest is BaseTest {
instance.installModule({
moduleTypeId: MODULE_TYPE_VALIDATOR,
module: address(mockK1Validator),
data: abi.encodePacked(owner.addr)
data: abi.encode(owner.addr)
});

// Install smart session
Expand Down Expand Up @@ -288,7 +288,7 @@ contract SmartSessionTest is BaseTest {
instance.installModule({
moduleTypeId: MODULE_TYPE_VALIDATOR,
module: address(mockK1Validator),
data: abi.encodePacked(owner.addr)
data: abi.encode(owner.addr)
});

// Install smart session
Expand Down
2 changes: 1 addition & 1 deletion test/SwapTest.t.sol → test/integrations/SwapTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import "test/BaseTest.t.sol";
import "src/ModuleKit.sol";
import { ERC7579ExecutorBase } from "src/Modules.sol";
import { IERC20 } from "forge-std/interfaces/IERC20.sol";
import { UniswapV3Integration } from "../src/integrations/uniswap/v3/Uniswap.sol";
import { UniswapV3Integration } from "src/integrations/uniswap/v3/Uniswap.sol";

contract TestUniswap is BaseTest {
using ModuleKitHelpers for AccountInstance;
Expand Down
3 changes: 2 additions & 1 deletion test/mocks/MockK1Validator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ contract MockK1Validator is IValidator {
}

function onInstall(bytes calldata data) external {
smartAccountOwners[msg.sender] = address(bytes20(data));
address owner = abi.decode(data, (address));
smartAccountOwners[msg.sender] = owner;
}

function onUninstall(bytes calldata data) external {
Expand Down
Loading

0 comments on commit e6b8aa5

Please sign in to comment.