diff --git a/precompiles/assets/IAssets.sol b/precompiles/assets/IAssets.sol index 2c80fa5e..ff918a9b 100644 --- a/precompiles/assets/IAssets.sol +++ b/precompiles/assets/IAssets.sol @@ -8,6 +8,12 @@ address constant ASSETS_PRECOMPILE_ADDRESS = 0x000000000000000000000000000000000 IAssets constant ASSETS_CONTRACT = IAssets(ASSETS_PRECOMPILE_ADDRESS); /// @dev The TokenInfo struct. +/// @param name The name of the token +/// @param symbol The symbol of the token +/// @param clientChainID The client chain ID +/// @param tokenID The token ID, typically the token address encoded in bytes +/// @param decimals The number of decimals of the token +/// @param totalStaked The total staked amount of the token struct TokenInfo { string name; string symbol; @@ -17,6 +23,26 @@ struct TokenInfo { uint256 totalStaked; } +/// @dev The StakerBalance struct. +/// @param clientChainID The client chain ID +/// @param stakerAddress The staker address, typically the staker's address encoded in bytes +/// @param tokenID The token ID, typically the token address encoded in bytes +/// @param balance The balance of the staker, balance = withdrawable + delegated + pendingUndelegated +/// @param withdrawable The withdrawable balance +/// @param delegated The delegated balance +/// @param pendingUndelegated The pending undelegated balance, during the unboding period and would become withdrawable after the unboding period +/// @param totalDeposited The total deposited balance +struct StakerBalance { + uint32 clientChainID; + bytes stakerAddress; + bytes tokenID; + uint256 balance; + uint256 withdrawable; + uint256 delegated; + uint256 pendingUndelegated; + uint256 totalDeposited; +} + /// @author Exocore Team /// @title Assets Precompile Contract /// @dev The interface through which solidity contracts will interact with assets module @@ -156,6 +182,13 @@ interface IAssets { /// @return success true if the query is successful /// @return assetInfo the asset info function getTokenInfo(uint32 clientChainId, bytes calldata tokenId) external view returns (bool success, TokenInfo memory assetInfo); + + /// @dev Returns the staker's balance for a given token. + /// @param clientChainId is the ID of the client chain + /// @param tokenId is the ID of the token, typically the token address + /// @return success true if the query is successful + /// @return stakerBalance the staker's balance + function getStakerBalanceByToken(uint32 clientChainId, bytes calldata stakerAddress, bytes calldata tokenId) external view returns (bool success, StakerBalance memory stakerBalance); } diff --git a/precompiles/assets/abi.json b/precompiles/assets/abi.json index d40e14a0..5ceb9c5e 100644 --- a/precompiles/assets/abi.json +++ b/precompiles/assets/abi.json @@ -95,6 +95,82 @@ ], "stateMutability": "view" }, + { + "type": "function", + "name": "getStakerBalanceByToken", + "inputs": [ + { + "name": "clientChainId", + "type": "uint32", + "internalType": "uint32" + }, + { + "name": "stakerAddress", + "type": "bytes", + "internalType": "bytes" + }, + { + "name": "tokenId", + "type": "bytes", + "internalType": "bytes" + } + ], + "outputs": [ + { + "name": "success", + "type": "bool", + "internalType": "bool" + }, + { + "name": "stakerBalance", + "type": "tuple", + "internalType": "struct StakerBalance", + "components": [ + { + "name": "clientChainID", + "type": "uint32", + "internalType": "uint32" + }, + { + "name": "stakerAddress", + "type": "bytes", + "internalType": "bytes" + }, + { + "name": "tokenID", + "type": "bytes", + "internalType": "bytes" + }, + { + "name": "balance", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "withdrawable", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "delegated", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "pendingUndelegated", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "totalDeposited", + "type": "uint256", + "internalType": "uint256" + } + ] + } + ], + "stateMutability": "view" + }, { "type": "function", "name": "getTokenInfo", diff --git a/precompiles/assets/abi_types.go b/precompiles/assets/abi_types.go index a859df98..edc72ecc 100644 --- a/precompiles/assets/abi_types.go +++ b/precompiles/assets/abi_types.go @@ -21,3 +21,27 @@ func NewEmptyTokenInfo() TokenInfo { TotalStaked: big.NewInt(0), } } + +type StakerBalance struct { + ClientChainID uint32 + StakerAddress []byte + TokenID []byte + Balance *big.Int + Withdrawable *big.Int + Delegated *big.Int + PendingUndelegated *big.Int + TotalDeposited *big.Int +} + +func NewEmptyStakerBalance() StakerBalance { + return StakerBalance{ + ClientChainID: 0, + StakerAddress: []byte{}, + TokenID: []byte{}, + Balance: big.NewInt(0), + Withdrawable: big.NewInt(0), + Delegated: big.NewInt(0), + PendingUndelegated: big.NewInt(0), + TotalDeposited: big.NewInt(0), + } +} diff --git a/precompiles/assets/assets.go b/precompiles/assets/assets.go index 28dc1d69..eb37b162 100644 --- a/precompiles/assets/assets.go +++ b/precompiles/assets/assets.go @@ -160,6 +160,12 @@ func (p Precompile) Run(evm *vm.EVM, contract *vm.Contract, readOnly bool) (bz [ ctx.Logger().Error("internal error when calling assets precompile", "module", "assets precompile", "method", method.Name, "err", err) bz, err = method.Outputs.Pack(false, NewEmptyTokenInfo()) } + case MethodGetStakerBalanceByToken: + bz, err = p.GetStakerBalanceByToken(ctx, method, args) + if err != nil { + ctx.Logger().Error("internal error when calling assets precompile", "module", "assets precompile", "method", method.Name, "err", err) + bz, err = method.Outputs.Pack(false, NewEmptyStakerBalance()) + } default: return nil, fmt.Errorf(cmn.ErrUnknownMethod, method.Name) } @@ -186,7 +192,7 @@ func (Precompile) IsTransaction(methodID string) bool { MethodRegisterOrUpdateClientChain, MethodRegisterToken, MethodUpdateToken, MethodUpdateAuthorizedGateways: return true - case MethodGetClientChains, MethodIsRegisteredClientChain, MethodIsAuthorizedGateway, MethodGetTokenInfo: + case MethodGetClientChains, MethodIsRegisteredClientChain, MethodIsAuthorizedGateway, MethodGetTokenInfo, MethodGetStakerBalanceByToken: return false default: return false diff --git a/precompiles/assets/assets_test.go b/precompiles/assets/assets_test.go index b7a23ad9..797c6aea 100644 --- a/precompiles/assets/assets_test.go +++ b/precompiles/assets/assets_test.go @@ -4,6 +4,7 @@ import ( "math/big" "strings" + "cosmossdk.io/math" sdkmath "cosmossdk.io/math" assetsprecompile "github.com/ExocoreNetwork/exocore/precompiles/assets" assetskeeper "github.com/ExocoreNetwork/exocore/x/assets/keeper" @@ -14,6 +15,7 @@ import ( "github.com/ExocoreNetwork/exocore/app" assetstype "github.com/ExocoreNetwork/exocore/x/assets/types" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" ethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" evmtypes "github.com/evmos/evmos/v16/x/evm/types" @@ -513,3 +515,574 @@ func (s *AssetsPrecompileSuite) TestGetClientChains() { ) }) } + +func (s *AssetsPrecompileSuite) TestUpdateAuthorizedGateways() { + testcases := []struct { + name string + malleate func() (common.Address, []byte) + readOnly bool + expPass bool + errContains string + expResult bool + }{ + { + name: "fail - update gateways for mainnet, authority mismatch", + malleate: func() (common.Address, []byte) { + newGateways := []common.Address{ + common.HexToAddress("0x3fC91A3afd70395Cd496C647d5a6CC9D4B2b7FAD"), + common.HexToAddress("0x4fC91A3afd70395Cd496C647d5a6CC9D4B2b7FAE"), + } + input, err := s.precompile.Pack( + "updateAuthorizedGateways", + newGateways, + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: false, + expPass: true, + expResult: false, + }, + } + + for _, tc := range testcases { + tc := tc + s.Run(tc.name, func() { + s.SetupTest() // Reset state for each test case + + // Get caller address and input data + caller, input := tc.malleate() + + // set up EVM environment + contract, evm, err := s.setupEVMEnvironment(caller, input, big.NewInt(0)) + s.Require().NoError(err) + + // Execute precompile + bz, err := s.precompile.Run(evm, contract, tc.readOnly) + + s.Require().NoError(err) + // we expect no error for both success and failure cases + success, err := s.precompile.Unpack("updateAuthorizedGateways", bz) + + if tc.expPass { + s.Require().NoError(err) + s.Require().Equal(tc.expResult, success[0].(bool)) + } else { + s.Require().Error(err) + } + }) + } +} + +func (s *AssetsPrecompileSuite) TestIsAuthorizedGateway() { + testcases := []struct { + name string + malleate func() (common.Address, []byte) + readOnly bool + expPass bool + expResult bool + errContains string + }{ + { + name: "pass - gateway is authorized", + malleate: func() (common.Address, []byte) { + // First set up an authorized gateway + params := &assetstypes.Params{ + Gateways: []string{s.Address.String()}, + } + err := s.App.AssetsKeeper.SetParams(s.Ctx, params) + s.Require().NoError(err) + + input, err := s.precompile.Pack( + "isAuthorizedGateway", + s.Address, + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + expResult: true, + }, + { + name: "pass - gateway is not authorized", + malleate: func() (common.Address, []byte) { + // Set up different authorized gateway + params := &assetstypes.Params{ + Gateways: []string{ + "0x1234567890123456789012345678901234567890", + }, + } + err := s.App.AssetsKeeper.SetParams(s.Ctx, params) + s.Require().NoError(err) + + input, err := s.precompile.Pack( + "isAuthorizedGateway", + s.Address, + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + expResult: false, + }, + { + name: "pass - check with empty gateway list", + malleate: func() (common.Address, []byte) { + // Set empty gateway list + params := &assetstypes.Params{ + Gateways: []string{}, + } + err := s.App.AssetsKeeper.SetParams(s.Ctx, params) + s.Require().NoError(err) + + input, err := s.precompile.Pack( + "isAuthorizedGateway", + s.Address, + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + expResult: false, + }, + { + name: "pass - check zero address", + malleate: func() (common.Address, []byte) { + input, err := s.precompile.Pack( + "isAuthorizedGateway", + common.Address{}, + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + expResult: false, + }, + } + + for _, tc := range testcases { + tc := tc + s.Run(tc.name, func() { + s.SetupTest() // Reset state for each test case + + // Get caller address and input data + caller, input := tc.malleate() + + // Setup EVM environment + contract, evm, err := s.setupEVMEnvironment(caller, input, big.NewInt(0)) + s.Require().NoError(err) + + // Execute precompile + bz, err := s.precompile.Run(evm, contract, tc.readOnly) + + if tc.expPass { + s.Require().NoError(err) + // Unpack and verify the result + result, err := s.precompile.Unpack("isAuthorizedGateway", bz) + s.Require().NoError(err) + s.Require().Equal(true, result[0].(bool)) + s.Require().Equal(tc.expResult, result[1].(bool)) + } else { + s.Require().Error(err) + s.Require().Contains(err.Error(), tc.errContains) + } + }) + } +} + +func (s *AssetsPrecompileSuite) TestGetTokenInfo() { + testcases := []struct { + name string + malleate func() (common.Address, []byte) + readOnly bool + expPass bool + returnCheck func([]byte) bool + errContains string + }{ + { + name: "pass - get existing token info (NST)", + malleate: func() (common.Address, []byte) { + // NST token is already set up in SetupTest() + tokenAddr := paddingClientChainAddress( + common.FromHex("0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"), + assetstype.GeneralClientChainAddrLength, + ) + + input, err := s.precompile.Pack( + "getTokenInfo", + uint32(s.ClientChains[0].LayerZeroChainID), + tokenAddr, + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + returnCheck: func(bz []byte) bool { + result, err := s.precompile.Unpack("getTokenInfo", bz) + s.Require().NoError(err) + success := result[0].(bool) + s.Require().True(success) + + tokenInfo := result[1].(struct { + Name string `json:"name"` + Symbol string `json:"symbol"` + ClientChainID uint32 `json:"clientChainID"` + TokenID []byte `json:"tokenID"` + Decimals uint8 `json:"decimals"` + TotalStaked *big.Int `json:"totalStaked"` + }) + return tokenInfo.Name == "Native Restaking ETH" && + tokenInfo.Symbol == "NSTETH" && + tokenInfo.Decimals == 18 && + tokenInfo.TotalStaked.Cmp(s.nstStaked.BigInt()) == 0 + }, + }, + { + name: "pass - get existing token info (LST)", + malleate: func() (common.Address, []byte) { + // Setup LST token first + s.lstStaked = math.NewInt(100) + lstToken := &assetstypes.StakingAssetInfo{ + AssetBasicInfo: assetstypes.AssetInfo{ + Name: "Liquid Staking Token", + Symbol: "LST", + Address: "0x1234567890123456789012345678901234567890", + Decimals: 6, + LayerZeroChainID: uint64(101), + MetaInfo: "liquid staking token", + }, + StakingTotalAmount: s.lstStaked, + } + err := s.App.AssetsKeeper.SetStakingAssetInfo(s.Ctx, lstToken) + s.Require().NoError(err) + + tokenAddr := paddingClientChainAddress( + common.FromHex(lstToken.AssetBasicInfo.Address), + assetstype.GeneralClientChainAddrLength, + ) + + input, err := s.precompile.Pack( + "getTokenInfo", + uint32(lstToken.AssetBasicInfo.LayerZeroChainID), + tokenAddr, + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + returnCheck: func(bz []byte) bool { + result, err := s.precompile.Unpack("getTokenInfo", bz) + s.Require().NoError(err) + success := result[0].(bool) + s.Require().True(success) + + tokenInfo := result[1].(struct { + Name string `json:"name"` + Symbol string `json:"symbol"` + ClientChainID uint32 `json:"clientChainID"` + TokenID []byte `json:"tokenID"` + Decimals uint8 `json:"decimals"` + TotalStaked *big.Int `json:"totalStaked"` + }) + return tokenInfo.Name == "Liquid Staking Token" && + tokenInfo.Symbol == "LST" && + tokenInfo.Decimals == 6 && + tokenInfo.TotalStaked.Cmp(s.lstStaked.BigInt()) == 0 + }, + }, + { + name: "fail - non-existent token", + malleate: func() (common.Address, []byte) { + input, err := s.precompile.Pack( + "getTokenInfo", + uint32(999), + paddingClientChainAddress( + common.FromHex("0x1234567890123456789012345678901234567890"), + assetstype.GeneralClientChainAddrLength, + ), + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, // The call succeeds but returns false + returnCheck: func(bz []byte) bool { + result, err := s.precompile.Unpack("getTokenInfo", bz) + s.Require().NoError(err) + success := result[0].(bool) + return !success // Expect false for non-existent token + }, + }, + { + name: "fail - invalid chain ID", + malleate: func() (common.Address, []byte) { + input, err := s.precompile.Pack( + "getTokenInfo", + uint32(0), // invalid chain ID + paddingClientChainAddress( + common.FromHex("0xeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee"), + assetstype.GeneralClientChainAddrLength, + ), + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, // The call succeeds but returns false + returnCheck: func(bz []byte) bool { + result, err := s.precompile.Unpack("getTokenInfo", bz) + s.Require().NoError(err) + success := result[0].(bool) + return !success // Expect false for invalid chain ID + }, + }, + } + + for _, tc := range testcases { + tc := tc + s.Run(tc.name, func() { + s.SetupTest() // Reset state for each test case + + // Get caller address and input data + caller, input := tc.malleate() + + // Setup EVM environment + contract, evm, err := s.setupEVMEnvironment(caller, input, big.NewInt(0)) + s.Require().NoError(err) + + // Execute precompile + bz, err := s.precompile.Run(evm, contract, tc.readOnly) + s.Require().NoError(err) // All calls should succeed + + if tc.returnCheck != nil { + s.Require().True(tc.returnCheck(bz)) + } + }) + } +} + +func (s *AssetsPrecompileSuite) TestGetStakerBalanceByToken() { + testcases := []struct { + name string + malleate func() (common.Address, []byte) + readOnly bool + expPass bool + returnCheck func([]byte) bool + }{ + { + name: "pass - get balance with only asset state", + malleate: func() (common.Address, []byte) { + clientChainID := uint32(101) + tokenAddr := paddingClientChainAddress( + common.FromHex("0x1234567890123456789012345678901234567890"), + 20, + ) + stakerAddr := paddingClientChainAddress( + s.Address.Bytes(), + 20, + ) + + // Setup token + token := &assetstypes.StakingAssetInfo{ + AssetBasicInfo: assetstypes.AssetInfo{ + Name: "Test Token", + Symbol: "TEST", + Address: hexutil.Encode(tokenAddr), + Decimals: 18, + LayerZeroChainID: uint64(clientChainID), + }, + StakingTotalAmount: sdkmath.NewInt(100), + } + err := s.App.AssetsKeeper.SetStakingAssetInfo(s.Ctx, token) + s.Require().NoError(err) + + // Setup staker asset state + stakerID, assetID := assetstypes.GetStakerIDAndAssetID(uint64(clientChainID), stakerAddr, tokenAddr) + assetDelta := assetstypes.DeltaStakerSingleAsset{ + TotalDepositAmount: sdkmath.NewInt(100), + WithdrawableAmount: sdkmath.NewInt(70), + PendingUndelegationAmount: sdkmath.NewInt(30), + } + err = s.App.AssetsKeeper.UpdateStakerAssetState(s.Ctx, stakerID, assetID, assetDelta) + s.Require().NoError(err) + + input, err := s.precompile.Pack( + "getStakerBalanceByToken", + clientChainID, + stakerAddr, + tokenAddr, + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + returnCheck: func(bz []byte) bool { + result, err := s.precompile.Unpack("getStakerBalanceByToken", bz) + s.Require().NoError(err) + success := result[0].(bool) + s.Require().True(success) + + balance := result[1].(struct { + ClientChainID uint32 `json:"clientChainID"` + StakerAddress []byte `json:"stakerAddress"` + TokenID []byte `json:"tokenID"` + Balance *big.Int `json:"balance"` + Withdrawable *big.Int `json:"withdrawable"` + Delegated *big.Int `json:"delegated"` + PendingUndelegated *big.Int `json:"pendingUndelegated"` + TotalDeposited *big.Int `json:"totalDeposited"` + }) + + return balance.Balance.Cmp(big.NewInt(100)) == 0 && // TotalDepositAmount + balance.Withdrawable.Cmp(big.NewInt(70)) == 0 && // WithdrawableAmount + balance.Delegated.Sign() == 0 && // No delegations + balance.PendingUndelegated.Cmp(big.NewInt(30)) == 0 && // PendingUndelegationAmount + balance.TotalDeposited.Cmp(big.NewInt(100)) == 0 // TotalDepositAmount + }, + }, + { + name: "pass - non-existent token", + malleate: func() (common.Address, []byte) { + input, err := s.precompile.Pack( + "getStakerBalanceByToken", + uint32(999), + paddingClientChainAddress( + common.FromHex("0x1234567890123456789012345678901234567890"), + assetstype.GeneralClientChainAddrLength, + ), + paddingClientChainAddress( + common.FromHex("0xdAC17F958D2ee523a2206206994597C13D831ec7"), + assetstype.GeneralClientChainAddrLength, + ), + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + returnCheck: func(bz []byte) bool { + result, err := s.precompile.Unpack("getStakerBalanceByToken", bz) + s.Require().NoError(err) + success := result[0].(bool) + return !success // Should return false for non-existent token + }, + }, + { + name: "pass - invalid chain ID", + malleate: func() (common.Address, []byte) { + input, err := s.precompile.Pack( + "getStakerBalanceByToken", + uint32(0), // invalid chain ID + paddingClientChainAddress( + s.Address.Bytes(), + assetstype.GeneralClientChainAddrLength, + ), + paddingClientChainAddress( + common.FromHex("0xdAC17F958D2ee523a2206206994597C13D831ec7"), + assetstype.GeneralClientChainAddrLength, + ), + ) + s.Require().NoError(err) + return s.Address, input + }, + readOnly: true, + expPass: true, + returnCheck: func(bz []byte) bool { + result, err := s.precompile.Unpack("getStakerBalanceByToken", bz) + s.Require().NoError(err) + success := result[0].(bool) + return !success // Should return false for invalid chain ID + }, + }, + } + + for _, tc := range testcases { + tc := tc + s.Run(tc.name, func() { + s.SetupTest() // Reset state for each test case + + // Get caller address and input data + caller, input := tc.malleate() + + // Setup EVM environment + contract, evm, err := s.setupEVMEnvironment(caller, input, big.NewInt(0)) + s.Require().NoError(err) + + // Execute precompile + bz, err := s.precompile.Run(evm, contract, tc.readOnly) + s.Require().NoError(err) // All calls should succeed + + if tc.returnCheck != nil { + s.Require().True(tc.returnCheck(bz)) + } + }) + } +} + +// setupEVMEnvironment creates an EVM environment and returns the contract and EVM instance +func (s *AssetsPrecompileSuite) setupEVMEnvironment( + caller common.Address, + input []byte, + value *big.Int, +) (*vm.Contract, *vm.EVM, error) { + // Create contract + contract := vm.NewPrecompile(vm.AccountRef(caller), s.precompile, value, uint64(1e6)) + contract.Input = input + + // Get base fee for tx execution + baseFee := s.App.FeeMarketKeeper.GetBaseFee(s.Ctx) + + // Build and sign Ethereum transaction + contractAddr := contract.Address() + txArgs := evmtypes.EvmTxArgs{ + ChainID: s.App.EvmKeeper.ChainID(), + Nonce: 0, + To: &contractAddr, + Amount: nil, + GasLimit: 100000, + GasPrice: app.MainnetMinGasPrices.BigInt(), + GasFeeCap: baseFee, + GasTipCap: big.NewInt(1), + Accesses: ðtypes.AccessList{}, + } + msgEthereumTx := evmtypes.NewTx(&txArgs) + msgEthereumTx.From = caller.String() + err := msgEthereumTx.Sign(s.EthSigner, s.Signer) + if err != nil { + return nil, nil, err + } + + // Prepare EVM execution + proposerAddress := s.Ctx.BlockHeader().ProposerAddress + cfg, err := s.App.EvmKeeper.EVMConfig(s.Ctx, proposerAddress, s.App.EvmKeeper.ChainID()) + if err != nil { + return nil, nil, err + } + + msg, err := msgEthereumTx.AsMessage(s.EthSigner, baseFee) + if err != nil { + return nil, nil, err + } + + // Create EVM instance + evm := s.App.EvmKeeper.NewEVM(s.Ctx, msg, cfg, nil, s.StateDB) + + // Setup precompiles + params := s.App.EvmKeeper.GetParams(s.Ctx) + activePrecompiles := params.GetActivePrecompilesAddrs() + precompileMap := s.App.EvmKeeper.Precompiles(activePrecompiles...) + err = vm.ValidatePrecompiles(precompileMap, activePrecompiles) + if err != nil { + return nil, nil, err + } + evm.WithPrecompiles(precompileMap, activePrecompiles) + + return contract, evm, nil +} diff --git a/precompiles/assets/decode_helper.go b/precompiles/assets/decode_helper.go index 33d67ddc..502636cd 100644 --- a/precompiles/assets/decode_helper.go +++ b/precompiles/assets/decode_helper.go @@ -1,6 +1,7 @@ package assets import ( + "encoding/hex" "fmt" "math/big" @@ -81,7 +82,7 @@ func (ta *TypedArgs) GetBigInt(index int) (*big.Int, error) { return val, nil } -func (ta *TypedArgs) GetAddress(index int) (common.Address, error) { +func (ta *TypedArgs) GetEVMAddress(index int) (common.Address, error) { if index >= len(ta.args) { return common.Address{}, fmt.Errorf(exocmn.ErrIndexOutOfRange, index, len(ta.args)) } @@ -92,7 +93,7 @@ func (ta *TypedArgs) GetAddress(index int) (common.Address, error) { return val, nil } -func (ta *TypedArgs) GetAddressSlice(index int) ([]common.Address, error) { +func (ta *TypedArgs) GetEVMAddressSlice(index int) ([]common.Address, error) { if index >= len(ta.args) { return nil, fmt.Errorf(exocmn.ErrIndexOutOfRange, index, len(ta.args)) } @@ -169,8 +170,8 @@ func (ta *TypedArgs) GetRequiredBytesPrefix(index int, length uint32) ([]byte, e return val[:length], nil } -func (ta *TypedArgs) GetRequiredAddressSlice(index int) ([]common.Address, error) { - val, err := ta.GetAddressSlice(index) +func (ta *TypedArgs) GetRequiredEVMAddressSlice(index int) ([]common.Address, error) { + val, err := ta.GetEVMAddressSlice(index) if err != nil { return nil, err } @@ -179,3 +180,11 @@ func (ta *TypedArgs) GetRequiredAddressSlice(index int) ([]common.Address, error } return val, nil } + +func (ta *TypedArgs) GetRequiredHexAddress(index int, addressLength uint32) (string, error) { + val, err := ta.GetRequiredBytesPrefix(index, addressLength) + if err != nil { + return "", err + } + return hex.EncodeToString(val), nil +} diff --git a/precompiles/assets/query.go b/precompiles/assets/query.go index 97a27d7b..5a3e4bd3 100644 --- a/precompiles/assets/query.go +++ b/precompiles/assets/query.go @@ -14,6 +14,7 @@ const ( MethodIsRegisteredClientChain = "isRegisteredClientChain" MethodIsAuthorizedGateway = "isAuthorizedGateway" MethodGetTokenInfo = "getTokenInfo" + MethodGetStakerBalanceByToken = "getStakerBalanceByToken" ) func (p Precompile) GetClientChains( @@ -69,7 +70,7 @@ func (p Precompile) IsAuthorizedGateway( if err := ta.RequireLen(len(p.ABI.Methods[MethodIsAuthorizedGateway].Inputs)); err != nil { return nil, err } - gateway, err := ta.GetAddress(0) + gateway, err := ta.GetEVMAddress(0) if err != nil { return nil, err } @@ -93,11 +94,17 @@ func (p Precompile) GetTokenInfo( if err != nil { return nil, err } - tokenID, err := ta.GetRequiredBytes(1) + + info, err := p.assetsKeeper.GetClientChainInfoByIndex(ctx, uint64(clientChainID)) + if err != nil { + return nil, err + } + + assetAddress, err := ta.GetRequiredBytesPrefix(1, info.AddressLength) if err != nil { return nil, err } - _, assetID := assetstype.GetStakerIDAndAssetIDFromStr(uint64(clientChainID), "", string(tokenID)) + _, assetID := assetstype.GetStakerIDAndAssetID(uint64(clientChainID), nil, assetAddress) tokenInfo, err := p.assetsKeeper.GetStakingAssetInfo(ctx, assetID) if err != nil { return nil, err @@ -111,10 +118,61 @@ func (p Precompile) GetTokenInfo( Name: tokenInfo.AssetBasicInfo.Name, Symbol: tokenInfo.AssetBasicInfo.Symbol, ClientChainID: clientChainID, - TokenID: tokenID, + TokenID: assetAddress, Decimals: uint8(tokenInfo.AssetBasicInfo.Decimals), TotalStaked: tokenInfo.StakingTotalAmount.BigInt(), } return method.Outputs.Pack(true, result) } + +func (p Precompile) GetStakerBalanceByToken( + ctx sdk.Context, + method *abi.Method, + args []interface{}, +) ([]byte, error) { + ta := NewTypedArgs(args) + if err := ta.RequireLen(len(p.ABI.Methods[MethodGetStakerBalanceByToken].Inputs)); err != nil { + return nil, err + } + + clientChainID, err := ta.GetUint32(0) + if err != nil { + return nil, err + } + + info, err := p.assetsKeeper.GetClientChainInfoByIndex(ctx, uint64(clientChainID)) + if err != nil { + return nil, err + } + + stakerAddress, err := ta.GetRequiredBytesPrefix(1, info.AddressLength) + if err != nil { + return nil, err + } + + assetAddress, err := ta.GetRequiredBytesPrefix(2, info.AddressLength) + if err != nil { + return nil, err + } + + stakerID, assetID := assetstype.GetStakerIDAndAssetID(uint64(clientChainID), stakerAddress, assetAddress) + + balance, err := p.assetsKeeper.GetStakerBalanceByAsset(ctx, stakerID, assetID) + if err != nil { + return nil, err + } + + result := StakerBalance{ + ClientChainID: clientChainID, + StakerAddress: stakerAddress, + TokenID: assetAddress, + Balance: balance.Balance, + Withdrawable: balance.Withdrawable, + Delegated: balance.Delegated, + PendingUndelegated: balance.PendingUndelegated, + TotalDeposited: balance.TotalDeposited, + } + + return method.Outputs.Pack(true, result) +} diff --git a/precompiles/assets/setup_test.go b/precompiles/assets/setup_test.go index 2bf804f0..60defaab 100644 --- a/precompiles/assets/setup_test.go +++ b/precompiles/assets/setup_test.go @@ -21,6 +21,8 @@ type AssetsPrecompileSuite struct { testutil.BaseTestSuite precompile *assets.Precompile + nstStaked math.Int + lstStaked math.Int } func TestPrecompileTestSuite(t *testing.T) { @@ -49,4 +51,5 @@ func (s *AssetsPrecompileSuite) SetupTest() { }, StakingTotalAmount: depositAmountNST, }) + s.nstStaked = depositAmountNST } diff --git a/precompiles/assets/tx.go b/precompiles/assets/tx.go index 05155092..0409c039 100644 --- a/precompiles/assets/tx.go +++ b/precompiles/assets/tx.go @@ -174,6 +174,9 @@ func (p Precompile) UpdateToken( return method.Outputs.Pack(true) } +// UpdateAuthorizedGateways updates the authorized gateways for the assets module. +// For mainnet, if the authority of the assets module is the governance module, this method would not work. +// So it is mainly used for testing purposes. func (p Precompile) UpdateAuthorizedGateways( ctx sdk.Context, contract *vm.Contract, @@ -181,7 +184,7 @@ func (p Precompile) UpdateAuthorizedGateways( args []interface{}, ) ([]byte, error) { ta := NewTypedArgs(args) - gateways, err := ta.GetRequiredAddressSlice(0) + gateways, err := ta.GetRequiredEVMAddressSlice(0) if err != nil { return nil, err } @@ -198,7 +201,6 @@ func (p Precompile) UpdateAuthorizedGateways( if err != nil { return nil, err } - fmt.Printf("precompile successfully updated gateways") return method.Outputs.Pack(true) } diff --git a/precompiles/assets/types.go b/precompiles/assets/types.go index c13dc3f9..58654dd5 100644 --- a/precompiles/assets/types.go +++ b/precompiles/assets/types.go @@ -7,9 +7,6 @@ import ( "strings" "github.com/ethereum/go-ethereum/accounts/abi" - "github.com/ethereum/go-ethereum/common" - - "github.com/ethereum/go-ethereum/common/hexutil" sdkmath "cosmossdk.io/math" exocmn "github.com/ExocoreNetwork/exocore/precompiles/common" @@ -156,7 +153,7 @@ func (p Precompile) TokenFromInputs(ctx sdk.Context, args []interface{}) (*asset return nil, nil, err } - assetAddr, err := ta.GetRequiredBytesPrefix(1, info.AddressLength) // Must not be empty and must match length + assetAddr, err := ta.GetRequiredHexAddress(1, info.AddressLength) // Must not be empty and must match length if err != nil { return nil, nil, err } @@ -190,7 +187,7 @@ func (p Precompile) TokenFromInputs(ctx sdk.Context, args []interface{}) (*asset // Assign values to asset asset := &assetstypes.AssetInfo{ LayerZeroChainID: uint64(clientChainID), - Address: hexutil.Encode(assetAddr), + Address: assetAddr, Decimals: uint32(decimal), Name: name, MetaInfo: metaInfo, @@ -235,8 +232,6 @@ func (p Precompile) UpdateTokenFromInputs(ctx sdk.Context, args []interface{}) ( return 0, "", "", err } - var assetAddr []byte - clientChainID, err = ta.GetPositiveUint32(0) if err != nil { return 0, "", "", err @@ -247,7 +242,7 @@ func (p Precompile) UpdateTokenFromInputs(ctx sdk.Context, args []interface{}) ( return 0, "", "", err } - assetAddr, err = ta.GetRequiredBytesPrefix(1, info.AddressLength) + hexAssetAddr, err = ta.GetRequiredHexAddress(1, info.AddressLength) if err != nil { return 0, "", "", err } @@ -260,48 +255,5 @@ func (p Precompile) UpdateTokenFromInputs(ctx sdk.Context, args []interface{}) ( return 0, "", "", fmt.Errorf(exocmn.ErrInvalidMetaInfoLength, metadata, len(metadata), assetstypes.MaxChainTokenMetaInfoLength) } - hexAssetAddr = hexutil.Encode(assetAddr) return clientChainID, hexAssetAddr, metadata, nil } - -func (p Precompile) ClientChainIDFromInputs(_ sdk.Context, args []interface{}) (uint32, error) { - inputsLen := len(p.ABI.Methods[MethodIsRegisteredClientChain].Inputs) - if len(args) != inputsLen { - return 0, fmt.Errorf(cmn.ErrInvalidNumberOfArgs, inputsLen, len(args)) - } - clientChainID, ok := args[0].(uint32) - if !ok { - return 0, fmt.Errorf(exocmn.ErrContractInputParaOrType, 0, "uint32", args[0]) - } - return clientChainID, nil -} - -func (p Precompile) GatewaysFromInputs(_ sdk.Context, args []interface{}) ([]string, error) { - ta := NewTypedArgs(args) - if err := ta.RequireLen(1); err != nil { - return nil, err - } - - gateways, ok := args[0].([]common.Address) - if !ok { - return nil, fmt.Errorf(exocmn.ErrContractInputParaOrType, 0, "[]common.Address", args[0]) - } - if len(gateways) == 0 { - return nil, fmt.Errorf(exocmn.ErrEmptyGateways) - } - - gatewaysStr := make([]string, len(gateways)) - for i, gateway := range gateways { - gatewaysStr[i] = gateway.Hex() - } - return gatewaysStr, nil -} - -func (p Precompile) GatewayFromInputs(_ sdk.Context, args []interface{}) (common.Address, error) { - ta := NewTypedArgs(args) - if err := ta.RequireLen(1); err != nil { - return common.Address{}, err - } - - return ta.GetAddress(0) -} diff --git a/x/assets/keeper/expected_keepers.go b/x/assets/keeper/expected_keepers.go index a2570704..2dc0ba7b 100644 --- a/x/assets/keeper/expected_keepers.go +++ b/x/assets/keeper/expected_keepers.go @@ -1,6 +1,7 @@ package keeper import ( + sdkmath "cosmossdk.io/math" delegationtype "github.com/ExocoreNetwork/exocore/x/delegation/types" sdk "github.com/cosmos/cosmos-sdk/types" ) @@ -8,4 +9,5 @@ import ( // this keeper interface is defined here to avoid a circular dependency type delegationKeeper interface { GetDelegationInfo(ctx sdk.Context, stakerID, assetID string) (*delegationtype.QueryDelegationInfoResponse, error) + TotalDelegatedAmountForStakerAsset(ctx sdk.Context, stakerID, assetID string) (amount sdkmath.Int, err error) } diff --git a/x/assets/keeper/staker_asset.go b/x/assets/keeper/staker_asset.go index 05d041f9..81648602 100644 --- a/x/assets/keeper/staker_asset.go +++ b/x/assets/keeper/staker_asset.go @@ -161,3 +161,29 @@ func (k Keeper) UpdateStakerAssetState(ctx sdk.Context, stakerID string, assetID return nil } + +func (k Keeper) GetStakerBalanceByAsset(ctx sdk.Context, stakerID string, assetID string) (balance assetstype.StakerBalance, err error) { + stakerAssetInfo, err := k.GetStakerSpecifiedAssetInfo(ctx, stakerID, assetID) + if err != nil { + return assetstype.StakerBalance{}, err + } + + delegatedAmount, err := k.dk.TotalDelegatedAmountForStakerAsset(ctx, stakerID, assetID) + if err != nil { + return assetstype.StakerBalance{}, err + } + + totalBalance := stakerAssetInfo.WithdrawableAmount.Add(stakerAssetInfo.PendingUndelegationAmount).Add(delegatedAmount) + + balance = assetstype.StakerBalance{ + StakerID: stakerID, + AssetID: assetID, + Balance: totalBalance.BigInt(), + Withdrawable: stakerAssetInfo.WithdrawableAmount.BigInt(), + Delegated: delegatedAmount.BigInt(), + PendingUndelegated: stakerAssetInfo.PendingUndelegationAmount.BigInt(), + TotalDeposited: stakerAssetInfo.TotalDepositAmount.BigInt(), + } + + return balance, nil +} diff --git a/x/assets/types/general.go b/x/assets/types/general.go index 0c9b8bb8..b671fdc3 100644 --- a/x/assets/types/general.go +++ b/x/assets/types/general.go @@ -2,6 +2,7 @@ package types import ( "fmt" + "math/big" "strings" "github.com/ExocoreNetwork/exocore/utils" @@ -68,6 +69,18 @@ type DeltaOperatorSingleAsset OperatorAssetInfo type CreateQueryContext func(height int64, prove bool) (sdk.Context, error) +// StakerBalance is a struct to describe the balance of a staker for a specific asset +// balance = withdrawable + delegated + pendingUndelegated +type StakerBalance struct { + StakerID string + AssetID string + Balance *big.Int + Withdrawable *big.Int + Delegated *big.Int + PendingUndelegated *big.Int + TotalDeposited *big.Int +} + // GetStakerIDAndAssetID stakerID = stakerAddress+'_'+clientChainLzID,assetID = // assetAddress+'_'+clientChainLzID func GetStakerIDAndAssetID( @@ -87,7 +100,7 @@ func GetStakerIDAndAssetID( } // GetStakerIDAndAssetIDFromStr stakerID = stakerAddress+'_'+clientChainLzID,assetID = -// assetAddress+'_'+clientChainLzID +// assetAddress+'_'+clientChainLzID, NOTE: the stakerAddress and assetsAddress should be in hex format func GetStakerIDAndAssetIDFromStr( clientChainLzID uint64, stakerAddress string,