diff --git a/omnichain/staking/contracts/Staking.sol b/omnichain/staking/contracts/Staking.sol index 14ab5057..ca5663db 100644 --- a/omnichain/staking/contracts/Staking.sol +++ b/omnichain/staking/contracts/Staking.sol @@ -15,18 +15,25 @@ contract Staking is ERC20, zContract { SystemContract public immutable systemContract; uint256 public immutable chainID; - mapping(address => uint256) public stakes; - mapping(address => address) public beneficiaries; - mapping(address => uint256) public lastStakeTime; + mapping(bytes => uint256) public stakes; + mapping(bytes => address) public beneficiaries; + mapping(bytes => uint256) public lastStakeTime; uint256 public rewardRate = 1; event Staked( - address indexed staker, + bytes indexed staker, address indexed beneficiary, uint256 amount ); - event RewardsClaimed(address indexed staker, uint256 rewardAmount); - event Unstaked(address indexed staker, uint256 amount); + event RewardsClaimed(bytes indexed staker, uint256 rewardAmount); + event Unstaked(bytes indexed staker, uint256 amount); + event OnCrossChainCallEvent( + bytes staker, + address beneficiary, + uint32 action + ); + + event Logging(string); constructor( string memory name_, @@ -38,6 +45,27 @@ contract Staking is ERC20, zContract { chainID = chainID_; } + function decode( + bytes memory message + ) public pure returns (bytes memory bech32, address hexAddr, uint32 value) { + require(message.length == 66, "Invalid message length"); // 42 (bech32) + 20 (address) + 4 (uint32) + + bytes memory bech32Bytes = new bytes(42); + + for (uint256 i = 0; i < bech32Bytes.length; i++) { + bech32Bytes[i] = message[i]; + } + + bech32 = bech32Bytes; + + assembly { + hexAddr := mload(add(message, 42)) + value := mload(add(message, 62)) + } + + return (bech32, hexAddr, value); + } + function onCrossChainCall( zContext calldata context, address zrc20, @@ -48,20 +76,26 @@ contract Staking is ERC20, zContract { revert SenderNotSystemContract(); } + emit Logging("onCrossChainCall"); + address acceptedZRC20 = systemContract.gasCoinZRC20ByChainId(chainID); if (zrc20 != acceptedZRC20) revert WrongChain(); - address staker = BytesHelperLib.bytesToAddress(context.origin, 0); + bytes memory staker; address beneficiary; uint32 action; if (context.chainID == 18332) { - beneficiary = BytesHelperLib.bytesToAddress(message, 0); - action = BytesHelperLib.bytesToUint32(message, 20); + (staker, beneficiary, action) = decode(message); } else { - (beneficiary, action) = abi.decode(message, (address, uint32)); + (staker, beneficiary, action) = abi.decode( + message, + (bytes, address, uint32) + ); } + emit OnCrossChainCallEvent(staker, beneficiary, action); + if (action == 1) { stakeZRC(staker, beneficiary, amount); } else if (action == 2) { @@ -72,10 +106,11 @@ contract Staking is ERC20, zContract { } function stakeZRC( - address staker, + bytes memory staker, address beneficiary, uint256 amount ) internal { + emit Logging("stakeZRC"); stakes[staker] += amount; require(stakes[staker] >= amount, "Overflow detected"); @@ -89,7 +124,7 @@ contract Staking is ERC20, zContract { emit Staked(staker, beneficiary, amount); } - function updateRewards(address staker) internal { + function updateRewards(bytes memory staker) internal { uint256 timeDifference = block.timestamp - lastStakeTime[staker]; uint256 rewardAmount = timeDifference * stakes[staker] * rewardRate; require(rewardAmount >= timeDifference, "Overflow detected"); @@ -98,7 +133,7 @@ contract Staking is ERC20, zContract { lastStakeTime[staker] = block.timestamp; } - function claimRewards(address staker) external { + function claimRewards(bytes memory staker) external { require( beneficiaries[staker] == msg.sender, "Not authorized to claim rewards" @@ -112,7 +147,7 @@ contract Staking is ERC20, zContract { emit RewardsClaimed(staker, rewardAmount); } - function unstakeZRC(address staker) internal { + function unstakeZRC(bytes memory staker) internal { uint256 amount = stakes[staker]; updateRewards(staker); @@ -126,7 +161,7 @@ contract Staking is ERC20, zContract { bytes memory recipient; if (chainID == 18332) { - recipient = abi.encodePacked(BytesHelperLib.addressToBytes(staker)); + recipient = abi.encodePacked(staker); } else { recipient = abi.encodePacked(staker); } @@ -140,9 +175,9 @@ contract Staking is ERC20, zContract { emit Unstaked(staker, amount); } - function queryRewards(address account) public view returns (uint256) { - uint256 timeDifference = block.timestamp - lastStakeTime[account]; - uint256 rewardAmount = timeDifference * stakes[account] * rewardRate; + function queryRewards(bytes memory staker) public view returns (uint256) { + uint256 timeDifference = block.timestamp - lastStakeTime[staker]; + uint256 rewardAmount = timeDifference * stakes[staker] * rewardRate; return rewardAmount; } } diff --git a/omnichain/staking/tasks/rewards.ts b/omnichain/staking/tasks/rewards.ts index 16d03fef..941405c4 100644 --- a/omnichain/staking/tasks/rewards.ts +++ b/omnichain/staking/tasks/rewards.ts @@ -1,7 +1,23 @@ import { task } from "hardhat/config"; import { HardhatRuntimeEnvironment } from "hardhat/types"; import { convertToHexAddress } from "../lib/convertToHexAddress"; +const { decode } = require("bech32"); +const { arrayify } = require("@ethersproject/bytes"); +function bech32ToHex(bech32Address: string) { + const { words } = decode(bech32Address); + + // Convert 5-bit words to bytes + let bytes = []; + for (let i = 0; i < words.length; i++) { + let word = words[i]; + for (let j = 4; j >= 0; j--) { + bytes.push((word >> j) & 1); + } + } + + return arrayify(bytes); +} const main = async (args: any, hre: HardhatRuntimeEnvironment) => { const [signer] = await hre.ethers.getSigners(); console.log(`🔑 Using account: ${signer.address}\n`); @@ -11,7 +27,7 @@ const main = async (args: any, hre: HardhatRuntimeEnvironment) => { const factory = await hre.ethers.getContractFactory("Staking"); const contract = factory.attach(args.contract); - console.log(await contract.queryRewards(staker)); + console.log(await contract.queryRewards(bech32ToHex(args.staker))); }; task("rewards", "Query staking rewards", main)