From 2bb6f7be245fd7b91343fbf8cc8c2eaff1d9a0b7 Mon Sep 17 00:00:00 2001 From: Alex Gartner Date: Thu, 20 Jun 2024 11:52:45 -0700 Subject: [PATCH 1/6] refactor: use docker healthcheck for localnet e2e (#2353) * refactor: use docker healthcheck for localnet e2e * tune parameters * remove e2e sleep * tune zetaclient1 seed curl command * add zetaclientd gen-pre-params command * preparams caching and ssh config fix --- Dockerfile-localnet | 5 ++- cmd/zetaclientd/gen_pre_params.go | 40 +++++++++++++++++++ cmd/zetae2e/local/local.go | 9 +---- contrib/localnet/docker-compose.yml | 19 +++++++-- contrib/localnet/scripts/start-zetaclientd.sh | 13 ++++-- contrib/localnet/scripts/start-zetacored.sh | 2 +- contrib/localnet/ssh_config | 4 +- zetaclient/config/config.go | 3 +- 8 files changed, 74 insertions(+), 21 deletions(-) create mode 100644 cmd/zetaclientd/gen_pre_params.go diff --git a/Dockerfile-localnet b/Dockerfile-localnet index 40ab112af1..bacf8c19a4 100644 --- a/Dockerfile-localnet +++ b/Dockerfile-localnet @@ -44,10 +44,11 @@ RUN mkdir -p /root/.zetacored/cosmovisor/genesis/bin && \ ENV PATH /root/.zetacored/cosmovisor/current/bin/:/root/.zetaclientd/upgrades/current/:${PATH} COPY contrib/localnet/scripts /root -COPY contrib/localnet/ssh_config /root/.ssh/config +COPY contrib/localnet/ssh_config /etc/ssh/ssh_config.d/localnet.conf COPY contrib/localnet/zetacored /root/zetacored -RUN chmod 755 /root/*.sh +RUN chmod 755 /root/*.sh && \ + chmod 644 /etc/ssh/ssh_config.d/localnet.conf WORKDIR /usr/local/bin EXPOSE 22 diff --git a/cmd/zetaclientd/gen_pre_params.go b/cmd/zetaclientd/gen_pre_params.go new file mode 100644 index 0000000000..c797a9f206 --- /dev/null +++ b/cmd/zetaclientd/gen_pre_params.go @@ -0,0 +1,40 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "time" + + "github.com/binance-chain/tss-lib/ecdsa/keygen" + "github.com/spf13/cobra" +) + +func init() { + RootCmd.AddCommand(GenPrePramsCmd) +} + +var GenPrePramsCmd = &cobra.Command{ + Use: "gen-pre-params ", + Short: "Generate pre parameters for TSS", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + startTime := time.Now() + preParams, err := keygen.GeneratePreParams(time.Second * 300) + if err != nil { + return err + } + + file, err := os.OpenFile(args[0], os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + return err + } + defer file.Close() + err = json.NewEncoder(file).Encode(preParams) + if err != nil { + return err + } + fmt.Printf("Generated new pre-parameters in %v\n", time.Since(startTime)) + return nil + }, +} diff --git a/cmd/zetae2e/local/local.go b/cmd/zetae2e/local/local.go index f8ea2d4fba..364bdf80a4 100644 --- a/cmd/zetae2e/local/local.go +++ b/cmd/zetae2e/local/local.go @@ -161,13 +161,6 @@ func localE2ETest(cmd *cobra.Command, _ []string) { // set account prefix to zeta setCosmosConfig() - // wait for Genesis - // if setup is skip, we assume that the genesis is already created - if !skipSetup { - logger.Print("⏳ wait 70s for genesis") - time.Sleep(70 * time.Second) - } - zetaTxServer, err := txserver.NewZetaTxServer( conf.RPCs.ZetaCoreRPC, []string{utils.FungibleAdminName}, @@ -399,7 +392,7 @@ func waitKeygenHeight( logger *runner.Logger, ) { // wait for keygen to be completed - keygenHeight := int64(60) + keygenHeight := int64(35) logger.Print("⏳ wait height %v for keygen to be completed", keygenHeight) for { time.Sleep(2 * time.Second) diff --git a/contrib/localnet/docker-compose.yml b/contrib/localnet/docker-compose.yml index 1c930d29d3..f176c30f00 100644 --- a/contrib/localnet/docker-compose.yml +++ b/contrib/localnet/docker-compose.yml @@ -40,6 +40,14 @@ services: - "26657:26657" - "6060:6060" - "9090:9090" + healthcheck: + # use the zevm endpoint for the healthcheck as it is the slowest to come up + test: ["CMD", "curl", "-f", "-X", "POST", "--data", '{"jsonrpc":"2.0","method":"web3_clientVersion","params":[],"id":67}', "-H", "Content-Type: application/json", "http://localhost:8545"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 30s + start_interval: 1s networks: mynetwork: ipv4_address: 172.20.0.11 @@ -78,6 +86,7 @@ services: - HOTKEY_PASSWORD=password # test purposes only volumes: - ssh:/root/.ssh + - preparams:/root/preparams zetaclient1: image: zetanode:latest @@ -93,6 +102,7 @@ services: - HOTKEY_PASSWORD=password # test purposes only volumes: - ssh:/root/.ssh + - preparams:/root/preparams eth: image: ethereum/client-go:v1.10.26 @@ -127,8 +137,10 @@ services: tty: true container_name: orchestrator depends_on: - - zetacore0 - - eth + zetacore0: + condition: service_healthy + eth: + condition: service_started hostname: orchestrator networks: mynetwork: @@ -137,4 +149,5 @@ services: volumes: - ssh:/root/.ssh volumes: - ssh: \ No newline at end of file + ssh: + preparams: \ No newline at end of file diff --git a/contrib/localnet/scripts/start-zetaclientd.sh b/contrib/localnet/scripts/start-zetaclientd.sh index ff2a3b9b99..01667df450 100755 --- a/contrib/localnet/scripts/start-zetaclientd.sh +++ b/contrib/localnet/scripts/start-zetaclientd.sh @@ -15,6 +15,13 @@ set_sepolia_endpoint() { jq '.EVMChainConfigs."11155111".Endpoint = "http://eth2:8545"' /root/.zetacored/config/zetaclient_config.json > tmp.json && mv tmp.json /root/.zetacored/config/zetaclient_config.json } +# generate pre-params as early as possible +# to reach keygen height on schedule +PREPARAMS_PATH="/root/preparams/${HOSTNAME}.json" +if [ ! -f "$PREPARAMS_PATH" ]; then + zetaclientd gen-pre-params "$PREPARAMS_PATH" +fi + # Wait for authorized_keys file to exist (generated by zetacore0) while [ ! -f ~/.ssh/authorized_keys ]; do echo "Waiting for authorized_keys file to exist..." @@ -42,7 +49,7 @@ echo "Start zetaclientd" if [[ $HOSTNAME == "zetaclient0" && ! -f ~/.zetacored/config/zetaclient_config.json ]] then MYIP=$(/sbin/ip -o -4 addr list eth0 | awk '{print $4}' | cut -d/ -f1) - zetaclientd init --zetacore-url zetacore0 --chain-id athens_101-1 --operator "$operatorAddress" --log-format=text --public-ip "$MYIP" --keyring-backend "$BACKEND" + zetaclientd init --zetacore-url zetacore0 --chain-id athens_101-1 --operator "$operatorAddress" --log-format=text --public-ip "$MYIP" --keyring-backend "$BACKEND" --pre-params "$PREPARAMS_PATH" # check if the option is additional-evm # in this case, the additional evm is represented with the sepolia chain, we set manually the eth2 endpoint to the sepolia chain (11155111 -> http://eth2:8545) @@ -59,9 +66,9 @@ then SEED="" while [ -z "$SEED" ] do - SEED=$(curl --retry 10 --retry-delay 5 --retry-connrefused -s zetaclient0:8123/p2p) + SEED=$(curl --retry 30 --retry-delay 1 --max-time 1 --retry-connrefused -s zetaclient0:8123/p2p) done - zetaclientd init --peer "/ip4/172.20.0.21/tcp/6668/p2p/${SEED}" --zetacore-url "$node" --chain-id athens_101-1 --operator "$operatorAddress" --log-format=text --public-ip "$MYIP" --log-level 1 --keyring-backend "$BACKEND" + zetaclientd init --peer "/ip4/172.20.0.21/tcp/6668/p2p/${SEED}" --zetacore-url "$node" --chain-id athens_101-1 --operator "$operatorAddress" --log-format=text --public-ip "$MYIP" --log-level 1 --keyring-backend "$BACKEND" --pre-params "$PREPARAMS_PATH" # check if the option is additional-evm # in this case, the additional evm is represented with the sepolia chain, we set manually the eth2 endpoint to the sepolia chain (11155111 -> http://eth2:8545) diff --git a/contrib/localnet/scripts/start-zetacored.sh b/contrib/localnet/scripts/start-zetacored.sh index 1791742e66..8e16ffbc4a 100755 --- a/contrib/localnet/scripts/start-zetacored.sh +++ b/contrib/localnet/scripts/start-zetacored.sh @@ -189,7 +189,7 @@ then # 2. Add the observers, authorizations, required params and accounts to the genesis.json zetacored collect-observer-info - zetacored add-observer-list --keygen-block 55 + zetacored add-observer-list --keygen-block 25 # Check for the existence of "AddToOutTxTracker" string in the genesis file # If this message is found in the genesis, it means add-observer-list has been run with the v16 binary for upgrade tests diff --git a/contrib/localnet/ssh_config b/contrib/localnet/ssh_config index dd05b90161..64779c2887 100644 --- a/contrib/localnet/ssh_config +++ b/contrib/localnet/ssh_config @@ -1,4 +1,2 @@ - Host * - StrictHostKeyChecking no - IdentityFile ~/.ssh/localtest.pem \ No newline at end of file + StrictHostKeyChecking no \ No newline at end of file diff --git a/zetaclient/config/config.go b/zetaclient/config/config.go index 8680071c3c..7dcf1b00f3 100644 --- a/zetaclient/config/config.go +++ b/zetaclient/config/config.go @@ -91,9 +91,10 @@ func GetPath(inputPath string) string { return "" } path[0] = home + return filepath.Join(path...) } } - return filepath.Join(path...) + return inputPath } // ContainRestrictedAddress returns true if any one of the addresses is restricted From ca9b90f8e953415c4babeea76ecb23848892ac2e Mon Sep 17 00:00:00 2001 From: Tanmay Date: Thu, 20 Jun 2024 17:15:19 -0400 Subject: [PATCH 2/6] refactor: use CheckAuthorization instead of IsAuthorized (#2319) --- changelog.md | 115 +++++----- cmd/zetacored/parsers_test.go | 32 +-- rpc/backend/mocks/client.go | 2 +- testutil/keeper/authority.go | 6 +- testutil/keeper/mocks/crosschain/authority.go | 17 +- testutil/keeper/mocks/fungible/authority.go | 17 +- .../keeper/mocks/lightclient/authority.go | 17 +- testutil/keeper/mocks/observer/authority.go | 16 +- x/authority/keeper/authorization_list.go | 17 +- .../keeper/msg_server_add_authorization.go | 11 +- .../msg_server_add_authorization_test.go | 71 ++++--- .../keeper/msg_server_remove_authorization.go | 11 +- .../msg_server_remove_authorization_test.go | 11 +- .../keeper/msg_server_update_chain_info.go | 9 +- .../msg_server_update_chain_info_test.go | 17 +- x/authority/keeper/policies_test.go | 26 --- x/authority/types/authorization_list.go | 1 + x/authority/types/authorization_list_test.go | 1 + x/authority/types/genesis.go | 5 +- .../keeper/msg_server_abort_stuck_cctx.go | 6 +- .../msg_server_abort_stuck_cctx_test.go | 43 ++-- .../keeper/msg_server_add_inbound_tracker.go | 12 +- .../msg_server_add_inbound_tracker_test.go | 130 ++++++------ .../keeper/msg_server_add_outbound_tracker.go | 10 +- .../msg_server_add_outbound_tracker_test.go | 154 +++++++------- .../keeper/msg_server_migrate_tss_funds.go | 11 +- .../msg_server_migrate_tss_funds_test.go | 200 +++++++++++------- .../keeper/msg_server_refund_aborted_tx.go | 5 +- .../msg_server_refund_aborted_tx_test.go | 155 ++++++++------ .../msg_server_remove_outbound_tracker.go | 6 +- ...msg_server_remove_outbound_tracker_test.go | 17 +- .../msg_server_update_rate_limiter_flags.go | 7 +- ...g_server_update_rate_limiter_flags_test.go | 22 +- x/crosschain/keeper/msg_server_update_tss.go | 10 +- .../keeper/msg_server_update_tss_test.go | 94 ++++---- .../keeper/msg_server_vote_inbound_tx_test.go | 6 +- .../keeper/msg_server_whitelist_erc20.go | 8 +- .../keeper/msg_server_whitelist_erc20_test.go | 66 +++--- x/crosschain/types/expected_keepers.go | 3 +- .../types/message_vote_inbound_test.go | 84 ++++---- .../types/message_vote_outbound_test.go | 72 +++---- .../msg_server_deploy_fungible_coin_zrc20.go | 8 +- ..._server_deploy_fungible_coin_zrc20_test.go | 57 ++--- .../msg_server_deploy_system_contract.go | 22 +- .../msg_server_deploy_system_contract_test.go | 45 ++-- x/fungible/keeper/msg_server_pause_zrc20.go | 10 +- .../keeper/msg_server_pause_zrc20_test.go | 39 ++-- .../keeper/msg_server_remove_foreign_coin.go | 8 +- .../msg_server_remove_foreign_coin_test.go | 17 +- .../msg_server_udpate_zrc20_liquidity_cap.go | 8 +- ..._server_udpate_zrc20_liquidity_cap_test.go | 47 ++-- x/fungible/keeper/msg_server_unpause_zrc20.go | 10 +- .../keeper/msg_server_unpause_zrc20_test.go | 39 ++-- .../msg_server_update_contract_bytecode.go | 26 +-- ...sg_server_update_contract_bytecode_test.go | 76 +++---- .../msg_server_update_system_contract.go | 8 +- .../msg_server_update_system_contract_test.go | 46 ++-- .../msg_server_update_zrc20_withdraw_fee.go | 8 +- ...g_server_update_zrc20_withdraw_fee_test.go | 68 +++--- x/fungible/types/expected_keepers.go | 3 +- ...disable_block_header_verification._test.go | 22 +- ...erver_disable_block_header_verification.go | 7 +- ...server_enable_block_header_verification.go | 7 +- ...r_enable_block_header_verification_test.go | 21 +- x/lightclient/types/expected_keepers.go | 4 +- x/observer/keeper/msg_server_add_observer.go | 6 +- .../keeper/msg_server_add_observer_test.go | 52 +++-- .../keeper/msg_server_disable_cctx_flags.go | 11 +- .../msg_server_disable_cctx_flags_test.go | 50 ++--- .../keeper/msg_server_enable_cctx_flags.go | 10 +- .../msg_server_enable_cctx_flags_test.go | 29 +-- .../keeper/msg_server_remove_chain_params.go | 7 +- .../msg_server_remove_chain_params_test.go | 47 ++-- .../keeper/msg_server_reset_chain_nonces.go | 6 +- .../msg_server_reset_chain_nonces_test.go | 48 +++-- .../keeper/msg_server_update_chain_params.go | 6 +- .../msg_server_update_chain_params_test.go | 47 ++-- ..._server_update_gas_price_increase_flags.go | 10 +- ...er_update_gas_price_increase_flags_test.go | 53 +++-- x/observer/keeper/msg_server_update_keygen.go | 6 +- .../keeper/msg_server_update_keygen_test.go | 38 ++-- .../keeper/msg_server_update_observer.go | 5 +- .../keeper/msg_server_update_observer_test.go | 10 +- x/observer/types/expected_keepers.go | 2 +- 84 files changed, 1358 insertions(+), 1214 deletions(-) diff --git a/changelog.md b/changelog.md index 2f15b653ba..fb156220c4 100644 --- a/changelog.md +++ b/changelog.md @@ -23,8 +23,9 @@ * [2291](https://github.com/zeta-chain/node/pull/2291) - initialize cctx gateway interface * [2289](https://github.com/zeta-chain/node/pull/2289) - add an authorization list to keep track of all authorizations on the chain * [2305](https://github.com/zeta-chain/node/pull/2305) - add new messages `MsgAddAuthorization` and `MsgRemoveAuthorization` that can be used to update the authorization list -* [2313](https://github.com/zeta-chain/node/pull/2313) - add `CheckAuthorization` function to replace the `IsAuthorized` function. The new function uses the authorization list to verify the signer's authorization. +* [2313](https://github.com/zeta-chain/node/pull/2313) - add `CheckAuthorization` function to replace the `IsAuthorized` function. The new function uses the authorization list to verify the signer's authorization * [2312](https://github.com/zeta-chain/node/pull/2312) - add queries `ShowAuthorization` and `ListAuthorizations` +* [2319](https://github.com/zeta-chain/node/pull/2319) - use `CheckAuthorization` function in all messages * [2325](https://github.com/zeta-chain/node/pull/2325) - revert telemetry server changes * [2339](https://github.com/zeta-chain/node/pull/2339) - add binaries related question to syncing issue form @@ -75,14 +76,14 @@ ### CI +* [2285](https://github.com/zeta-chain/node/pull/2285) - added nightly EVM performance testing pipeline, modified localnet testing docker image to utilitze debian:bookworm, removed build-jet runners where applicable, removed deprecated/removed upgrade path testing pipeline +* [2268](https://github.com/zeta-chain/node/pull/2268) - updated the publish-release pipeline to utilize the Github Actions Ubuntu 20.04 Runners +* [2070](https://github.com/zeta-chain/node/pull/2070) - Added commands to build binaries from the working branch as a live full node rpc to test non-governance changes +* [2119](https://github.com/zeta-chain/node/pull/2119) - Updated the release pipeline to only run on hotfix/ and release/ branches. Added option to only run pre-checks and not cut release as well. Switched approval steps to use environments +* [2189](https://github.com/zeta-chain/node/pull/2189) - Updated the docker tag when a release trigger runs to be the github event for the release name which should be the version. Removed mac specific build as the arm build should handle that +* [2191](https://github.com/zeta-chain/node/pull/2191) - Fixed conditional logic for the docker build step for non release builds to not overwrite the github tag +* [2192](https://github.com/zeta-chain/node/pull/2192) - Added release status checker and updater pipeline that will update release statuses when they go live on network * [2335](https://github.com/zeta-chain/node/pull/2335) - ci: updated the artillery report to publish to artillery cloud -* [2285](https://github.com/zeta-chain/node/pull/2285) - added nightly EVM performance testing pipeline, modified localnet testing docker image to utilitze debian:bookworm, removed build-jet runners where applicable, removed deprecated/removed upgrade path testing pipeline. -* [2268](https://github.com/zeta-chain/node/pull/2268) - updated the publish-release pipeline to utilize the Github Actions Ubuntu 20.04 Runners. -* [2070](https://github.com/zeta-chain/node/pull/2070) - Added commands to build binaries from the working branch as a live full node rpc to test non-governance changes. -* [2119](https://github.com/zeta-chain/node/pull/2119) - Updated the release pipeline to only run on hotfix/ and release/ branches. Added option to only run pre-checks and not cut release as well. Switched approval steps to use environments. -* [2189](https://github.com/zeta-chain/node/pull/2189) - Updated the docker tag when a release trigger runs to be the github event for the release name which should be the version. Removed mac specific build as the arm build should handle that. -* [2191](https://github.com/zeta-chain/node/pull/2191) - Fixed conditional logic for the docker build step for non release builds to not overwrite the github tag. -* [2192](https://github.com/zeta-chain/node/pull/2192) - Added release status checker and updater pipeline that will update release statuses when they go live on network. ## v17.0.0 @@ -95,25 +96,25 @@ ### Breaking Changes -* Admin policies have been moved from `observer` to a new module `authority`. - * Updating admin policies now requires to send a governance proposal executing the `UpdatePolicies` message in the `authority` module. - * The `Policies` query of the `authority` module must be used to get the current admin policies. - * `PolicyType_group1` has been renamed into `PolicyType_groupEmergency` and `PolicyType_group2` has been renamed into `PolicyType_groupAdmin`. +* Admin policies have been moved from `observer` to a new module `authority` + * Updating admin policies now requires to send a governance proposal executing the `UpdatePolicies` message in the `authority` module + * The `Policies` query of the `authority` module must be used to get the current admin policies + * `PolicyType_group1` has been renamed into `PolicyType_groupEmergency` and `PolicyType_group2` has been renamed into `PolicyType_groupAdmin` * A new module called `lightclient` has been created for the blocker header and proof functionality to add inbound and outbound trackers in a permissionless manner (currently deactivated on live networks) - * The list of block headers are now stored in the `lightclient` module instead of the `observer` module. - * The message to vote on new block headers is still in the `observer` module but has been renamed to `MsgVoteBlockHeader` instead of `MsgAddBlockHeader`. - * The `GetAllBlockHeaders` query has been moved to the `lightclient` module and renamed to `BlockHeaderAll`. - * The `GetBlockHeaderByHash` query has been moved to the `lightclient` module and renamed to `BlockHeader`. - * The `GetBlockHeaderStateByChain` query has been moved to the `lightclient` module and renamed to `ChainState`. - * The `Prove` query has been moved to the `lightclient` module. - * The `BlockHeaderVerificationFlags` has been deprecated in `CrosschainFlags`, `VerificationFlags` should be used instead. + * The list of block headers are now stored in the `lightclient` module instead of the `observer` module + * The message to vote on new block headers is still in the `observer` module but has been renamed to `MsgVoteBlockHeader` instead of `MsgAddBlockHeader` + * The `GetAllBlockHeaders` query has been moved to the `lightclient` module and renamed to `BlockHeaderAll` + * The `GetBlockHeaderByHash` query has been moved to the `lightclient` module and renamed to `BlockHeader` + * The `GetBlockHeaderStateByChain` query has been moved to the `lightclient` module and renamed to `ChainState` + * The `Prove` query has been moved to the `lightclient` module + * The `BlockHeaderVerificationFlags` has been deprecated in `CrosschainFlags`, `VerificationFlags` should be used instead -* `MsgGasPriceVoter` message in the `crosschain` module has been renamed to `MsgVoteGasPrice`. - * The structure of the message remains the same. +* `MsgGasPriceVoter` message in the `crosschain` module has been renamed to `MsgVoteGasPrice` + * The structure of the message remains the same -* `MsgCreateTSSVoter` message in the `crosschain` module has been moved to the `observer` module and renamed to `MsgVoteTSS`. - * The structure of the message remains the same. +* `MsgCreateTSSVoter` message in the `crosschain` module has been moved to the `observer` module and renamed to `MsgVoteTSS` + * The structure of the message remains the same ### Refactor @@ -137,13 +138,13 @@ * [2013](https://github.com/zeta-chain/node/pull/2013) - rename `GasPriceVoter` message to `VoteGasPrice` * [2059](https://github.com/zeta-chain/node/pull/2059) - Remove unused params from all functions in zetanode * [2071](https://github.com/zeta-chain/node/pull/2071) - Modify chains struct to add all chain related information -* [2076](https://github.com/zeta-chain/node/pull/2076) - automatically deposit native zeta to an address if it doesn't exist on ZEVM. +* [2076](https://github.com/zeta-chain/node/pull/2076) - automatically deposit native zeta to an address if it doesn't exist on ZEVM * [2169](https://github.com/zeta-chain/node/pull/2169) - Limit zEVM revert transactions to coin type ZETA ### Features * [1789](https://github.com/zeta-chain/node/issues/1789) - block cross-chain transactions that involve restricted addresses -* [1755](https://github.com/zeta-chain/node/issues/1755) - use evm JSON RPC for inbound tx (including blob tx) observation. +* [1755](https://github.com/zeta-chain/node/issues/1755) - use evm JSON RPC for inbound tx (including blob tx) observation * [1884](https://github.com/zeta-chain/node/pull/1884) - added zetatool cmd, added subcommand to filter deposits * [1942](https://github.com/zeta-chain/node/pull/1982) - support Bitcoin P2TR, P2WSH, P2SH, P2PKH addresses * [1935](https://github.com/zeta-chain/node/pull/1935) - add an operational authority group @@ -179,14 +180,14 @@ * [1992](https://github.com/zeta-chain/node/pull/1992) - remove setupKeeper from crosschain module * [2008](https://github.com/zeta-chain/node/pull/2008) - add test for connector bytecode update * [2047](https://github.com/zeta-chain/node/pull/2047) - fix liquidity cap advanced test -* [2076](https://github.com/zeta-chain/node/pull/2076) - automatically deposit native zeta to an address if it doesn't exist on ZEVM. +* [2076](https://github.com/zeta-chain/node/pull/2076) - automatically deposit native zeta to an address if it doesn't exist on ZEVM ### Fixes * [1861](https://github.com/zeta-chain/node/pull/1861) - fix `ObserverSlashAmount` invalid read -* [1880](https://github.com/zeta-chain/node/issues/1880) - lower the gas price multiplier for EVM chains. +* [1880](https://github.com/zeta-chain/node/issues/1880) - lower the gas price multiplier for EVM chains * [1883](https://github.com/zeta-chain/node/issues/1883) - zetaclient should check 'IsSupported' flag to pause/unpause a specific chain -* * [2076](https://github.com/zeta-chain/node/pull/2076) - automatically deposit native zeta to an address if it doesn't exist on ZEVM. +* * [2076](https://github.com/zeta-chain/node/pull/2076) - automatically deposit native zeta to an address if it doesn't exist on ZEVM * [1633](https://github.com/zeta-chain/node/issues/1633) - zetaclient should be able to pick up new connector and erc20Custody addresses * [1944](https://github.com/zeta-chain/node/pull/1944) - fix evm signer unit tests * [1888](https://github.com/zeta-chain/node/issues/1888) - zetaclient should stop inbound/outbound txs according to cross-chain flags @@ -198,11 +199,11 @@ ### CI -* [1958](https://github.com/zeta-chain/node/pull/1958) - Fix e2e advanced test debug checkbox. -* [1945](https://github.com/zeta-chain/node/pull/1945) - update advanced testing pipeline to not execute tests that weren't selected so they show skipped instead of skipping steps. +* [1958](https://github.com/zeta-chain/node/pull/1958) - Fix e2e advanced test debug checkbox +* [1945](https://github.com/zeta-chain/node/pull/1945) - update advanced testing pipeline to not execute tests that weren't selected so they show skipped instead of skipping steps * [1940](https://github.com/zeta-chain/node/pull/1940) - adjust release pipeline to be created as pre-release instead of latest -* [1867](https://github.com/zeta-chain/node/pull/1867) - default restore_type for full node docker-compose to snapshot instead of statesync for reliability. -* [1891](https://github.com/zeta-chain/node/pull/1891) - fix typo that was introduced to docker-compose and a typo in start.sh for the docker start script for full nodes. +* [1867](https://github.com/zeta-chain/node/pull/1867) - default restore_type for full node docker-compose to snapshot instead of statesync for reliability +* [1891](https://github.com/zeta-chain/node/pull/1891) - fix typo that was introduced to docker-compose and a typo in start.sh for the docker start script for full nodes * [1894](https://github.com/zeta-chain/node/pull/1894) - added download binaries and configs to the start sequence so it will download binaries that don't exist * [1953](https://github.com/zeta-chain/node/pull/1953) - run E2E tests for all PRs @@ -220,7 +221,7 @@ ### Breaking Changes * `zetaclientd start`: now requires 2 inputs from stdin: hotkey password and tss keyshare password - Starting zetaclient now requires two passwords to be input; one for the hotkey and another for the tss key-share. + Starting zetaclient now requires two passwords to be input; one for the hotkey and another for the tss key-share ### Features @@ -228,11 +229,11 @@ ### Docs -* [1731](https://github.com/zeta-chain/node/pull/1731) added doc for hotkey and tss key-share password prompts. +* [1731](https://github.com/zeta-chain/node/pull/1731) added doc for hotkey and tss key-share password prompts ### Features -* [1728] (https://github.com/zeta-chain/node/pull/1728) - allow aborted transactions to be refunded by minting tokens to zEvm. +* [1728] (https://github.com/zeta-chain/node/pull/1728) - allow aborted transactions to be refunded by minting tokens to zEvm ### Refactor @@ -252,7 +253,7 @@ * [1712](https://github.com/zeta-chain/node/issues/1712) - increase EVM outtx inclusion timeout to 20 minutes * [1733](https://github.com/zeta-chain/node/pull/1733) - remove the unnecessary 2x multiplier in the convertGasToZeta RPC * [1721](https://github.com/zeta-chain/node/issues/1721) - zetaclient should provide bitcoin_chain_id when querying TSS address -* [1744](https://github.com/zeta-chain/node/pull/1744) - added cmd to encrypt tss keyshare file, allowing empty tss password for backward compatibility. +* [1744](https://github.com/zeta-chain/node/pull/1744) - added cmd to encrypt tss keyshare file, allowing empty tss password for backward compatibility ### Tests @@ -264,18 +265,18 @@ ### CI -* Adjusted the release pipeline to be a manually executed pipeline with an approver step. The pipeline now executes all the required tests run before the approval step unless skipped. -* Added pipeline to build and push docker images into dockerhub on release for ubuntu and macos. -* Adjusted the pipeline for building and pushing docker images for MacOS to install and run docker. +* Adjusted the release pipeline to be a manually executed pipeline with an approver step. The pipeline now executes all the required tests run before the approval step unless skipped +* Added pipeline to build and push docker images into dockerhub on release for ubuntu and macos +* Adjusted the pipeline for building and pushing docker images for MacOS to install and run docker * Added docker-compose and make commands for launching full nodes. `make mainnet-zetarpc-node` `make mainnet-bitcoind-node` -* Made adjustments to the docker-compose for launching mainnet full nodes to include examples of using the docker images build from the docker image build pipeline. +* Made adjustments to the docker-compose for launching mainnet full nodes to include examples of using the docker images build from the docker image build pipeline * [1736](https://github.com/zeta-chain/node/pull/1736) - chore: add Ethermint endpoints to OpenAPI -* Re-wrote Dockerfile for building Zetacored docker images. -* Adjusted the docker-compose files for Zetacored nodes to utilize the new docker image. -* Added scripts for the new docker image that facilitate the start up automation. -* Adjusted the docker pipeline slightly to pull the version on PR from the app.go file. +* Re-wrote Dockerfile for building Zetacored docker images +* Adjusted the docker-compose files for Zetacored nodes to utilize the new docker image +* Added scripts for the new docker image that facilitate the start up automation +* Adjusted the docker pipeline slightly to pull the version on PR from the app.go file * [1781](https://github.com/zeta-chain/node/pull/1781) - add codecov coverage report in CI -* fixed the download binary script to use relative pathing from binary_list file. +* fixed the download binary script to use relative pathing from binary_list file ### Features @@ -320,13 +321,13 @@ * [1535](https://github.com/zeta-chain/node/issues/1535) - Avoid voting on wrong ballots due to false blockNumber in EVM tx receipt * [1588](https://github.com/zeta-chain/node/pull/1588) - fix chain params comparison logic * [1650](https://github.com/zeta-chain/node/pull/1605) - exempt (discounted) *system txs* from min gas price check and gas fee deduction -* [1632](https://github.com/zeta-chain/node/pull/1632) - set keygen to `KeygenStatus_KeyGenSuccess` if its in `KeygenStatus_PendingKeygen`. -* [1576](https://github.com/zeta-chain/node/pull/1576) - Fix zetaclient crash due to out of bound integer conversion and log prints. +* [1632](https://github.com/zeta-chain/node/pull/1632) - set keygen to `KeygenStatus_KeyGenSuccess` if its in `KeygenStatus_PendingKeygen` +* [1576](https://github.com/zeta-chain/node/pull/1576) - Fix zetaclient crash due to out of bound integer conversion and log prints * [1575](https://github.com/zeta-chain/node/issues/1575) - Skip unsupported chain parameters by IsSupported flag ### CI -* [1580](https://github.com/zeta-chain/node/pull/1580) - Fix release pipelines cleanup step. +* [1580](https://github.com/zeta-chain/node/pull/1580) - Fix release pipelines cleanup step ### Chores @@ -349,21 +350,21 @@ ### Breaking Changes TSS and chain validation related queries have been moved from `crosschain` module to `observer` module: -* `PendingNonces` :Changed from `/zeta-chain/crosschain/pendingNonces/{chain_id}/{address}` to `/zeta-chain/observer/pendingNonces/{chain_id}/{address}` . It returns all the pending nonces for a chain id and address. This returns the current pending nonces for the chain. -* `ChainNonces` : Changed from `/zeta-chain/crosschain/chainNonces/{chain_id}` to`/zeta-chain/observer/chainNonces/{chain_id}` . It returns all the chain nonces for a chain id. This returns the current nonce of the TSS address for the chain. -* `ChainNoncesAll` :Changed from `/zeta-chain/crosschain/chainNonces` to `/zeta-chain/observer/chainNonces` . It returns all the chain nonces for all chains. This returns the current nonce of the TSS address for all chains. +* `PendingNonces` :Changed from `/zeta-chain/crosschain/pendingNonces/{chain_id}/{address}` to `/zeta-chain/observer/pendingNonces/{chain_id}/{address}` . It returns all the pending nonces for a chain id and address. This returns the current pending nonces for the chain +* `ChainNonces` : Changed from `/zeta-chain/crosschain/chainNonces/{chain_id}` to`/zeta-chain/observer/chainNonces/{chain_id}` . It returns all the chain nonces for a chain id. This returns the current nonce of the TSS address for the chain +* `ChainNoncesAll` :Changed from `/zeta-chain/crosschain/chainNonces` to `/zeta-chain/observer/chainNonces` . It returns all the chain nonces for all chains. This returns the current nonce of the TSS address for all chains All chains now have the same observer set: -* `ObserversByChain`: `/zeta-chain/observer/observers_by_chain/{observation_chain}` has been removed and replaced with `/zeta-chain/observer/observer_set`. All chains have the same observer set. +* `ObserversByChain`: `/zeta-chain/observer/observers_by_chain/{observation_chain}` has been removed and replaced with `/zeta-chain/observer/observer_set`. All chains have the same observer set * `AllObserverMappers`: `/zeta-chain/observer/all_observer_mappers` has been removed. `/zeta-chain/observer/observer_set` should be used to get observers. Observer params and core params have been merged into chain params: * `Params`: `/zeta-chain/observer/params` no longer returns observer params. Observer params data have been moved to chain params described below. -* `GetCoreParams`: Renamed into `GetChainParams`. `/zeta-chain/observer/get_core_params` moved to `/zeta-chain/observer/get_chain_params`. -* `GetCoreParamsByChain`: Renamed into `GetChainParamsForChain`. `/zeta-chain/observer/get_core_params_by_chain` moved to `/zeta-chain/observer/get_chain_params_by_chain`. +* `GetCoreParams`: Renamed into `GetChainParams`. `/zeta-chain/observer/get_core_params` moved to `/zeta-chain/observer/get_chain_params` +* `GetCoreParamsByChain`: Renamed into `GetChainParamsForChain`. `/zeta-chain/observer/get_core_params_by_chain` moved to `/zeta-chain/observer/get_chain_params_by_chain` Getting the correct TSS address for Bitcoin now requires proviidng the Bitcoin chain id: -* `GetTssAddress` : Changed from `/zeta-chain/observer/get_tss_address/` to `/zeta-chain/observer/getTssAddress/{bitcoin_chain_id}` . Optional bitcoin chain id can now be passed as a parameter to fetch the correct tss for required BTC chain. This parameter only affects the BTC tss address in the response. +* `GetTssAddress` : Changed from `/zeta-chain/observer/get_tss_address/` to `/zeta-chain/observer/getTssAddress/{bitcoin_chain_id}` . Optional bitcoin chain id can now be passed as a parameter to fetch the correct tss for required BTC chain. This parameter only affects the BTC tss address in the response ### Features @@ -423,7 +424,7 @@ Getting the correct TSS address for Bitcoin now requires proviidng the Bitcoin c ### Chores -* [1446](https://github.com/zeta-chain/node/pull/1446) - renamed file `zetaclientd/aux.go` to `zetaclientd/utils.go` to avoid complaints from go package resolver. +* [1446](https://github.com/zeta-chain/node/pull/1446) - renamed file `zetaclientd/aux.go` to `zetaclientd/utils.go` to avoid complaints from go package resolver * [1499](https://github.com/zeta-chain/node/pull/1499) - Add scripts to localnet to help test gov proposals * [1442](https://github.com/zeta-chain/node/pull/1442) - remove build types in `.goreleaser.yaml` * [1504](https://github.com/zeta-chain/node/pull/1504) - remove `-race` in the `make install` commmand @@ -443,7 +444,7 @@ Getting the correct TSS address for Bitcoin now requires proviidng the Bitcoin c * [1387](https://github.com/zeta-chain/node/pull/1387) - Add HSM capability for zetaclient hot key * add a new thread to zetaclient which checks zeta supply in all connected chains in every block -* add a new tx to update an observer, this can be either be run a tombstoned observer/validator or via admin_policy_group_2. +* add a new tx to update an observer, this can be either be run a tombstoned observer/validator or via admin_policy_group_2 ### Fixes diff --git a/cmd/zetacored/parsers_test.go b/cmd/zetacored/parsers_test.go index 3b8e0a05d8..8750f0f61b 100644 --- a/cmd/zetacored/parsers_test.go +++ b/cmd/zetacored/parsers_test.go @@ -6,11 +6,10 @@ import ( "os" "testing" - "github.com/cometbft/cometbft/crypto" - sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" "github.com/zeta-chain/zetacore/app" + "github.com/zeta-chain/zetacore/testutil/sample" ) func TestParsefileToObserverMapper(t *testing.T) { @@ -20,31 +19,38 @@ func TestParsefileToObserverMapper(t *testing.T) { require.NoError(t, err) }(t, file) app.SetConfig() - createObserverList(file) + + observerAddress := sample.AccAddress() + commonGrantAddress := sample.AccAddress() + validatorAddress := sample.AccAddress() + + createObserverList(file, observerAddress, commonGrantAddress, validatorAddress) obsListReadFromFile, err := ParsefileToObserverDetails(file) require.NoError(t, err) for _, obs := range obsListReadFromFile { + require.Equal( + t, + obs.ObserverAddress, + observerAddress, + ) require.Equal( t, obs.ZetaClientGranteeAddress, - sdk.AccAddress(crypto.AddressHash([]byte("ObserverGranteeAddress"))).String(), + commonGrantAddress, ) } } -func createObserverList(fp string) { +func createObserverList(fp string, observerAddress, commonGrantAddress, validatorAddress string) { var listReader []ObserverInfoReader - commonGrantAddress := sdk.AccAddress(crypto.AddressHash([]byte("ObserverGranteeAddress"))) - observerAddress := sdk.AccAddress(crypto.AddressHash([]byte("ObserverAddress"))) - validatorAddress := sdk.ValAddress(crypto.AddressHash([]byte("ValidatorAddress"))) info := ObserverInfoReader{ - ObserverAddress: observerAddress.String(), - ZetaClientGranteeAddress: commonGrantAddress.String(), - StakingGranteeAddress: commonGrantAddress.String(), + ObserverAddress: observerAddress, + ZetaClientGranteeAddress: commonGrantAddress, + StakingGranteeAddress: commonGrantAddress, StakingMaxTokens: "100000000", - StakingValidatorAllowList: []string{validatorAddress.String()}, + StakingValidatorAllowList: []string{validatorAddress}, SpendMaxTokens: "100000000", - GovGranteeAddress: commonGrantAddress.String(), + GovGranteeAddress: commonGrantAddress, ZetaClientGranteePubKey: "zetapub1addwnpepqggtjvkmj6apcqr6ynyc5edxf2mpf5fxp2d3kwupemxtfwvg6gm7qv79fw0", } listReader = append(listReader, info) diff --git a/rpc/backend/mocks/client.go b/rpc/backend/mocks/client.go index 16090fcc69..d93d9c5a59 100644 --- a/rpc/backend/mocks/client.go +++ b/rpc/backend/mocks/client.go @@ -884,4 +884,4 @@ func NewClient(t mockConstructorTestingTNewClient) *Client { t.Cleanup(func() { mock.AssertExpectations(t) }) return mock -} \ No newline at end of file +} diff --git a/testutil/keeper/authority.go b/testutil/keeper/authority.go index be9101d1df..1e61bf9881 100644 --- a/testutil/keeper/authority.go +++ b/testutil/keeper/authority.go @@ -74,9 +74,9 @@ func AuthorityKeeper(t testing.TB) (*keeper.Keeper, sdk.Context) { return &k, ctx } -// MockIsAuthorized mocks the IsAuthorized method of an authority keeper mock -func MockIsAuthorized(m *mock.Mock, address string, policyType types.PolicyType, isAuthorized bool) { - m.On("IsAuthorized", mock.Anything, address, policyType).Return(isAuthorized).Once() +// MockCheckAuthorization mocks the CheckAuthorization method of the authority keeper. +func MockCheckAuthorization(m *mock.Mock, msg sdk.Msg, authorizationResult error) { + m.On("CheckAuthorization", mock.Anything, msg).Return(authorizationResult).Once() } func SetAdminPolicies(ctx sdk.Context, ak *keeper.Keeper) string { diff --git a/testutil/keeper/mocks/crosschain/authority.go b/testutil/keeper/mocks/crosschain/authority.go index 9f08c9d673..59edbd6d2d 100644 --- a/testutil/keeper/mocks/crosschain/authority.go +++ b/testutil/keeper/mocks/crosschain/authority.go @@ -4,7 +4,6 @@ package mocks import ( mock "github.com/stretchr/testify/mock" - authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" types "github.com/cosmos/cosmos-sdk/types" ) @@ -14,19 +13,19 @@ type CrosschainAuthorityKeeper struct { mock.Mock } -// IsAuthorized provides a mock function with given fields: ctx, address, policyType -func (_m *CrosschainAuthorityKeeper) IsAuthorized(ctx types.Context, address string, policyType authoritytypes.PolicyType) bool { - ret := _m.Called(ctx, address, policyType) +// CheckAuthorization provides a mock function with given fields: ctx, msg +func (_m *CrosschainAuthorityKeeper) CheckAuthorization(ctx types.Context, msg types.Msg) error { + ret := _m.Called(ctx, msg) if len(ret) == 0 { - panic("no return value specified for IsAuthorized") + panic("no return value specified for CheckAuthorization") } - var r0 bool - if rf, ok := ret.Get(0).(func(types.Context, string, authoritytypes.PolicyType) bool); ok { - r0 = rf(ctx, address, policyType) + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, types.Msg) error); ok { + r0 = rf(ctx, msg) } else { - r0 = ret.Get(0).(bool) + r0 = ret.Error(0) } return r0 diff --git a/testutil/keeper/mocks/fungible/authority.go b/testutil/keeper/mocks/fungible/authority.go index 929a99021c..b3ed6cae5a 100644 --- a/testutil/keeper/mocks/fungible/authority.go +++ b/testutil/keeper/mocks/fungible/authority.go @@ -4,7 +4,6 @@ package mocks import ( mock "github.com/stretchr/testify/mock" - authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" types "github.com/cosmos/cosmos-sdk/types" ) @@ -14,19 +13,19 @@ type FungibleAuthorityKeeper struct { mock.Mock } -// IsAuthorized provides a mock function with given fields: ctx, address, policyType -func (_m *FungibleAuthorityKeeper) IsAuthorized(ctx types.Context, address string, policyType authoritytypes.PolicyType) bool { - ret := _m.Called(ctx, address, policyType) +// CheckAuthorization provides a mock function with given fields: ctx, msg +func (_m *FungibleAuthorityKeeper) CheckAuthorization(ctx types.Context, msg types.Msg) error { + ret := _m.Called(ctx, msg) if len(ret) == 0 { - panic("no return value specified for IsAuthorized") + panic("no return value specified for CheckAuthorization") } - var r0 bool - if rf, ok := ret.Get(0).(func(types.Context, string, authoritytypes.PolicyType) bool); ok { - r0 = rf(ctx, address, policyType) + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, types.Msg) error); ok { + r0 = rf(ctx, msg) } else { - r0 = ret.Get(0).(bool) + r0 = ret.Error(0) } return r0 diff --git a/testutil/keeper/mocks/lightclient/authority.go b/testutil/keeper/mocks/lightclient/authority.go index 03dd6972de..8592036260 100644 --- a/testutil/keeper/mocks/lightclient/authority.go +++ b/testutil/keeper/mocks/lightclient/authority.go @@ -4,7 +4,6 @@ package mocks import ( mock "github.com/stretchr/testify/mock" - authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" types "github.com/cosmos/cosmos-sdk/types" ) @@ -14,19 +13,19 @@ type LightclientAuthorityKeeper struct { mock.Mock } -// IsAuthorized provides a mock function with given fields: ctx, address, policyType -func (_m *LightclientAuthorityKeeper) IsAuthorized(ctx types.Context, address string, policyType authoritytypes.PolicyType) bool { - ret := _m.Called(ctx, address, policyType) +// CheckAuthorization provides a mock function with given fields: ctx, msg +func (_m *LightclientAuthorityKeeper) CheckAuthorization(ctx types.Context, msg types.Msg) error { + ret := _m.Called(ctx, msg) if len(ret) == 0 { - panic("no return value specified for IsAuthorized") + panic("no return value specified for CheckAuthorization") } - var r0 bool - if rf, ok := ret.Get(0).(func(types.Context, string, authoritytypes.PolicyType) bool); ok { - r0 = rf(ctx, address, policyType) + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, types.Msg) error); ok { + r0 = rf(ctx, msg) } else { - r0 = ret.Get(0).(bool) + r0 = ret.Error(0) } return r0 diff --git a/testutil/keeper/mocks/observer/authority.go b/testutil/keeper/mocks/observer/authority.go index 76e5e0566c..30a0d21b6d 100644 --- a/testutil/keeper/mocks/observer/authority.go +++ b/testutil/keeper/mocks/observer/authority.go @@ -14,19 +14,19 @@ type ObserverAuthorityKeeper struct { mock.Mock } -// IsAuthorized provides a mock function with given fields: ctx, address, policyType -func (_m *ObserverAuthorityKeeper) IsAuthorized(ctx types.Context, address string, policyType authoritytypes.PolicyType) bool { - ret := _m.Called(ctx, address, policyType) +// CheckAuthorization provides a mock function with given fields: ctx, msg +func (_m *ObserverAuthorityKeeper) CheckAuthorization(ctx types.Context, msg types.Msg) error { + ret := _m.Called(ctx, msg) if len(ret) == 0 { - panic("no return value specified for IsAuthorized") + panic("no return value specified for CheckAuthorization") } - var r0 bool - if rf, ok := ret.Get(0).(func(types.Context, string, authoritytypes.PolicyType) bool); ok { - r0 = rf(ctx, address, policyType) + var r0 error + if rf, ok := ret.Get(0).(func(types.Context, types.Msg) error); ok { + r0 = rf(ctx, msg) } else { - r0 = ret.Get(0).(bool) + r0 = ret.Error(0) } return r0 diff --git a/x/authority/keeper/authorization_list.go b/x/authority/keeper/authorization_list.go index 24255bd089..4464b7b640 100644 --- a/x/authority/keeper/authorization_list.go +++ b/x/authority/keeper/authorization_list.go @@ -28,22 +28,7 @@ func (k Keeper) GetAuthorizationList(ctx sdk.Context) (val types.AuthorizationLi return val, true } -// IsAuthorized checks if the address is authorized for the given policy type -func (k Keeper) IsAuthorized(ctx sdk.Context, address string, policyType types.PolicyType) bool { - policies, found := k.GetPolicies(ctx) - if !found { - return false - } - for _, policy := range policies.Items { - if policy.Address == address && policy.PolicyType == policyType { - return true - } - } - return false -} - -// CheckAuthorization checks if the signer is authorized to sign the message -// It uses both the authorization list and the policies to check if the signer is authorized +// CheckAuthorization uses both the authorization list and the policies to check if the signer is authorized func (k Keeper) CheckAuthorization(ctx sdk.Context, msg sdk.Msg) error { // Policy transactions must have only one signer if len(msg.GetSigners()) != 1 { diff --git a/x/authority/keeper/msg_server_add_authorization.go b/x/authority/keeper/msg_server_add_authorization.go index 252884cd3f..3c3247fc4e 100644 --- a/x/authority/keeper/msg_server_add_authorization.go +++ b/x/authority/keeper/msg_server_add_authorization.go @@ -17,11 +17,10 @@ func (k msgServer) AddAuthorization( ) (*types.MsgAddAuthorizationResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.IsAuthorized(ctx, msg.Creator, types.PolicyType_groupAdmin) { - return nil, errorsmod.Wrap( - types.ErrUnauthorized, - "AddAuthorization can only be executed by the admin policy account", - ) + // check if the caller is authorized to add an authorization + err := k.CheckAuthorization(ctx, msg) + if err != nil { + return nil, errorsmod.Wrap(types.ErrUnauthorized, err.Error()) } authorizationList, found := k.GetAuthorizationList(ctx) @@ -31,7 +30,7 @@ func (k msgServer) AddAuthorization( authorizationList.SetAuthorization(types.Authorization{MsgUrl: msg.MsgUrl, AuthorizedPolicy: msg.AuthorizedPolicy}) // validate the authorization list after adding the authorization as a precautionary measure. - err := authorizationList.Validate() + err = authorizationList.Validate() if err != nil { return nil, errorsmod.Wrap(err, "authorization list is invalid") } diff --git a/x/authority/keeper/msg_server_add_authorization_test.go b/x/authority/keeper/msg_server_add_authorization_test.go index 95c962951e..f761935a0c 100644 --- a/x/authority/keeper/msg_server_add_authorization_test.go +++ b/x/authority/keeper/msg_server_add_authorization_test.go @@ -14,6 +14,10 @@ import ( func TestMsgServer_AddAuthorization(t *testing.T) { const url = "/zetachain.zetacore.sample.ABC" + var AddAuthorization = types.Authorization{ + MsgUrl: "/zetachain.zetacore.authority.MsgAddAuthorization", + AuthorizedPolicy: types.PolicyType_groupAdmin, + } t.Run("successfully add authorization of type admin to existing authorization list", func(t *testing.T) { k, ctx := keepertest.AuthorityKeeper(t) admin := keepertest.SetAdminPolicies(ctx, k) @@ -86,35 +90,41 @@ func TestMsgServer_AddAuthorization(t *testing.T) { require.Equal(t, prevLen+1, len(authorizationList.Authorizations)) }) - t.Run("successfully add authorization to empty authorization list", func(t *testing.T) { - k, ctx := keepertest.AuthorityKeeper(t) - admin := keepertest.SetAdminPolicies(ctx, k) - k.SetAuthorizationList(ctx, types.AuthorizationList{}) - msgServer := keeper.NewMsgServerImpl(*k) - - msg := &types.MsgAddAuthorization{ - Creator: admin, - MsgUrl: url, - AuthorizedPolicy: types.PolicyType_groupAdmin, - } - - _, err := msgServer.AddAuthorization(sdk.WrapSDKContext(ctx), msg) - require.NoError(t, err) - - authorizationList, found := k.GetAuthorizationList(ctx) - require.True(t, found) - policy, err := authorizationList.GetAuthorizedPolicy(url) - require.NoError(t, err) - require.Equal(t, types.PolicyType_groupAdmin, policy) - require.Equal(t, 1, len(authorizationList.Authorizations)) - }) + t.Run( + "successfully add authorization to list containing only authorization for AddAuthorization", + func(t *testing.T) { + k, ctx := keepertest.AuthorityKeeper(t) + admin := keepertest.SetAdminPolicies(ctx, k) + k.SetAuthorizationList(ctx, types.AuthorizationList{ + Authorizations: []types.Authorization{ + AddAuthorization, + }, + }) + msgServer := keeper.NewMsgServerImpl(*k) + + msg := &types.MsgAddAuthorization{ + Creator: admin, + MsgUrl: url, + AuthorizedPolicy: types.PolicyType_groupAdmin, + } + + _, err := msgServer.AddAuthorization(sdk.WrapSDKContext(ctx), msg) + require.NoError(t, err) + + authorizationList, found := k.GetAuthorizationList(ctx) + require.True(t, found) + policy, err := authorizationList.GetAuthorizedPolicy(url) + require.NoError(t, err) + require.Equal(t, types.PolicyType_groupAdmin, policy) + require.Equal(t, 2, len(authorizationList.Authorizations)) + }, + ) - t.Run("successfully set authorization when list is not found ", func(t *testing.T) { + t.Run("unable to add authorization to empty authorization list", func(t *testing.T) { k, ctx := keepertest.AuthorityKeeper(t) admin := keepertest.SetAdminPolicies(ctx, k) + k.SetAuthorizationList(ctx, types.AuthorizationList{}) msgServer := keeper.NewMsgServerImpl(*k) - authorizationList, found := k.GetAuthorizationList(ctx) - require.False(t, found) msg := &types.MsgAddAuthorization{ Creator: admin, @@ -123,14 +133,7 @@ func TestMsgServer_AddAuthorization(t *testing.T) { } _, err := msgServer.AddAuthorization(sdk.WrapSDKContext(ctx), msg) - require.NoError(t, err) - - authorizationList, found = k.GetAuthorizationList(ctx) - require.True(t, found) - policy, err := authorizationList.GetAuthorizedPolicy(url) - require.NoError(t, err) - require.Equal(t, types.PolicyType_groupAdmin, policy) - require.Equal(t, 1, len(authorizationList.Authorizations)) + require.ErrorIs(t, err, types.ErrUnauthorized) }) t.Run("update existing authorization", func(t *testing.T) { @@ -141,6 +144,7 @@ func TestMsgServer_AddAuthorization(t *testing.T) { MsgUrl: "/zetachain.zetacore.sample.ABC", AuthorizedPolicy: types.PolicyType_groupOperational, }, + AddAuthorization, }, } k.SetAuthorizationList(ctx, authorizationList) @@ -198,6 +202,7 @@ func TestMsgServer_AddAuthorization(t *testing.T) { MsgUrl: url, AuthorizedPolicy: types.PolicyType_groupOperational, }, + AddAuthorization, }} k.SetAuthorizationList(ctx, authorizationList) prevLen := len(authorizationList.Authorizations) diff --git a/x/authority/keeper/msg_server_remove_authorization.go b/x/authority/keeper/msg_server_remove_authorization.go index f84e2fdffe..671406429c 100644 --- a/x/authority/keeper/msg_server_remove_authorization.go +++ b/x/authority/keeper/msg_server_remove_authorization.go @@ -18,11 +18,10 @@ func (k msgServer) RemoveAuthorization( ) (*types.MsgRemoveAuthorizationResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.IsAuthorized(ctx, msg.Creator, types.PolicyType_groupAdmin) { - return nil, errorsmod.Wrap( - types.ErrUnauthorized, - "RemoveAuthorization can only be executed by the admin policy account", - ) + // check if the caller is authorized to remove an authorization + err := k.CheckAuthorization(ctx, msg) + if err != nil { + return nil, errorsmod.Wrap(types.ErrUnauthorized, err.Error()) } // check if the authorization list exists, we can return early if there is no list. @@ -32,7 +31,7 @@ func (k msgServer) RemoveAuthorization( } // check if the authorization exists, we can return early if the authorization does not exist. - _, err := authorizationList.GetAuthorizedPolicy(msg.MsgUrl) + _, err = authorizationList.GetAuthorizedPolicy(msg.MsgUrl) if err != nil { return nil, errorsmod.Wrap(err, fmt.Sprintf("msg url %s", msg.MsgUrl)) } diff --git a/x/authority/keeper/msg_server_remove_authorization_test.go b/x/authority/keeper/msg_server_remove_authorization_test.go index 566369839a..6e39ab4701 100644 --- a/x/authority/keeper/msg_server_remove_authorization_test.go +++ b/x/authority/keeper/msg_server_remove_authorization_test.go @@ -13,6 +13,10 @@ import ( ) func TestMsgServer_RemoveAuthorization(t *testing.T) { + var removeAuthorization = types.Authorization{ + MsgUrl: "/zetachain.zetacore.authority.MsgRemoveAuthorization", + AuthorizedPolicy: types.PolicyType_groupAdmin, + } t.Run("successfully remove operational policy authorization", func(t *testing.T) { k, ctx := keepertest.AuthorityKeeper(t) admin := keepertest.SetAdminPolicies(ctx, k) @@ -135,7 +139,7 @@ func TestMsgServer_RemoveAuthorization(t *testing.T) { } _, err := msgServer.RemoveAuthorization(sdk.WrapSDKContext(ctx), msg) - require.ErrorIs(t, err, types.ErrAuthorizationListNotFound) + require.ErrorContains(t, err, types.ErrAuthorizationListNotFound.Error()) }) t.Run("unable to remove authorization if authorization does not exist", func(t *testing.T) { @@ -162,7 +166,7 @@ func TestMsgServer_RemoveAuthorization(t *testing.T) { require.ErrorIs(t, err, types.ErrAuthorizationNotFound) _, err = msgServer.RemoveAuthorization(sdk.WrapSDKContext(ctx), msg) - require.ErrorIs(t, err, types.ErrAuthorizationNotFound) + require.ErrorContains(t, err, types.ErrAuthorizationNotFound.Error()) authorizationListNew, found := k.GetAuthorizationList(ctx) require.True(t, found) @@ -185,6 +189,7 @@ func TestMsgServer_RemoveAuthorization(t *testing.T) { MsgUrl: "ABC", AuthorizedPolicy: types.PolicyType_groupOperational, }, + removeAuthorization, }} k.SetAuthorizationList(ctx, authorizationList) msgServer := keeper.NewMsgServerImpl(*k) @@ -195,7 +200,7 @@ func TestMsgServer_RemoveAuthorization(t *testing.T) { } _, err := msgServer.RemoveAuthorization(sdk.WrapSDKContext(ctx), msg) - require.ErrorIs(t, err, types.ErrInvalidAuthorizationList) + require.ErrorContains(t, err, types.ErrInvalidAuthorizationList.Error()) authorizationListNew, found := k.GetAuthorizationList(ctx) require.True(t, found) diff --git a/x/authority/keeper/msg_server_update_chain_info.go b/x/authority/keeper/msg_server_update_chain_info.go index 11cd064c2a..5200e073e3 100644 --- a/x/authority/keeper/msg_server_update_chain_info.go +++ b/x/authority/keeper/msg_server_update_chain_info.go @@ -2,9 +2,8 @@ package keeper import ( "context" - "fmt" - cosmoserror "cosmossdk.io/errors" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/zeta-chain/zetacore/x/authority/types" @@ -21,10 +20,10 @@ func (k msgServer) UpdateChainInfo( // This message is only allowed to be called by group admin // Group admin because this functionality would rarely be called // and overwriting false chain info can have undesired effects - if !k.IsAuthorized(ctx, msg.Creator, types.PolicyType_groupAdmin) { - return nil, cosmoserror.Wrap(types.ErrUnauthorized, fmt.Sprintf("creator %s", msg.Creator)) + err := k.CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(types.ErrUnauthorized, err.Error()) } - // set chain info k.SetChainInfo(ctx, msg.ChainInfo) diff --git a/x/authority/keeper/msg_server_update_chain_info_test.go b/x/authority/keeper/msg_server_update_chain_info_test.go index e9c5762940..47073bf70b 100644 --- a/x/authority/keeper/msg_server_update_chain_info_test.go +++ b/x/authority/keeper/msg_server_update_chain_info_test.go @@ -16,10 +16,17 @@ func TestMsgServer_UpdateChainInfo(t *testing.T) { t.Run("can't update chain info if not authorized", func(t *testing.T) { k, ctx := keepertest.AuthorityKeeper(t) msgServer := keeper.NewMsgServerImpl(*k) - - _, err := msgServer.UpdateChainInfo(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainInfo{ + msg := types.MsgUpdateChainInfo{ Creator: sample.AccAddress(), - }) + } + k.SetAuthorizationList(ctx, types.AuthorizationList{Authorizations: []types.Authorization{ + { + MsgUrl: sdk.MsgTypeURL(&msg), + AuthorizedPolicy: types.PolicyType_groupAdmin, + }, + }}) + + _, err := msgServer.UpdateChainInfo(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, types.ErrUnauthorized) }) @@ -42,6 +49,8 @@ func TestMsgServer_UpdateChainInfo(t *testing.T) { }) chainInfo := sample.ChainInfo(42) + k.SetAuthorizationList(ctx, types.DefaultAuthorizationsList()) + _, err := msgServer.UpdateChainInfo(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainInfo{ Creator: admin, ChainInfo: chainInfo, @@ -71,6 +80,7 @@ func TestMsgServer_UpdateChainInfo(t *testing.T) { }, }) chainInfo := sample.ChainInfo(84) + k.SetAuthorizationList(ctx, types.DefaultAuthorizationsList()) _, err := msgServer.UpdateChainInfo(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainInfo{ Creator: admin, @@ -101,6 +111,7 @@ func TestMsgServer_UpdateChainInfo(t *testing.T) { }, }) chainInfo := types.ChainInfo{} + k.SetAuthorizationList(ctx, types.DefaultAuthorizationsList()) _, err := msgServer.UpdateChainInfo(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainInfo{ Creator: admin, diff --git a/x/authority/keeper/policies_test.go b/x/authority/keeper/policies_test.go index 5e2347bd05..3cb7f486a6 100644 --- a/x/authority/keeper/policies_test.go +++ b/x/authority/keeper/policies_test.go @@ -7,7 +7,6 @@ import ( keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" - "github.com/zeta-chain/zetacore/x/authority/types" ) func TestKeeper_SetPolicies(t *testing.T) { @@ -32,28 +31,3 @@ func TestKeeper_SetPolicies(t *testing.T) { require.True(t, found) require.Equal(t, newPolicies, got) } - -func TestKeeper_IsAuthorized(t *testing.T) { - k, ctx := keepertest.AuthorityKeeper(t) - - // Not authorized if no policies - require.False(t, k.IsAuthorized(ctx, sample.AccAddress(), types.PolicyType_groupAdmin)) - require.False(t, k.IsAuthorized(ctx, sample.AccAddress(), types.PolicyType_groupEmergency)) - - policies := sample.Policies() - k.SetPolicies(ctx, policies) - - // Check policy is set - got, found := k.GetPolicies(ctx) - require.True(t, found) - require.Equal(t, policies, got) - - // Check policy is authorized - for _, policy := range policies.Items { - require.True(t, k.IsAuthorized(ctx, policy.Address, policy.PolicyType)) - } - - // Check policy is not authorized - require.False(t, k.IsAuthorized(ctx, sample.AccAddress(), types.PolicyType_groupAdmin)) - require.False(t, k.IsAuthorized(ctx, sample.AccAddress(), types.PolicyType_groupEmergency)) -} diff --git a/x/authority/types/authorization_list.go b/x/authority/types/authorization_list.go index 848b1da1d2..5f72408833 100644 --- a/x/authority/types/authorization_list.go +++ b/x/authority/types/authorization_list.go @@ -36,6 +36,7 @@ var ( "/zetachain.zetacore.observer.MsgUpdateObserver", "/zetachain.zetacore.authority.MsgAddAuthorization", "/zetachain.zetacore.authority.MsgRemoveAuthorization", + "/zetachain.zetacore.authority.MsgUpdateChainInfo", } // EmergencyPolicyMessages keeps track of the message URLs that can, by default, only be executed by emergency policy address EmergencyPolicyMessages = []string{ diff --git a/x/authority/types/authorization_list_test.go b/x/authority/types/authorization_list_test.go index 009e958c15..ee301c92d6 100644 --- a/x/authority/types/authorization_list_test.go +++ b/x/authority/types/authorization_list_test.go @@ -422,6 +422,7 @@ func TestDefaultAuthorizationsList(t *testing.T) { sdk.MsgTypeURL(&observertypes.MsgUpdateObserver{}), sdk.MsgTypeURL(&types.MsgAddAuthorization{}), sdk.MsgTypeURL(&types.MsgRemoveAuthorization{}), + sdk.MsgTypeURL(&types.MsgUpdateChainInfo{}), } defaultList := types.DefaultAuthorizationsList() for _, msgUrl := range OperationalPolicyMessageList { diff --git a/x/authority/types/genesis.go b/x/authority/types/genesis.go index 13a01030bf..8a35955104 100644 --- a/x/authority/types/genesis.go +++ b/x/authority/types/genesis.go @@ -3,8 +3,9 @@ package types // DefaultGenesis returns the default authority genesis state func DefaultGenesis() *GenesisState { return &GenesisState{ - Policies: DefaultPolicies(), - ChainInfo: DefaultChainInfo(), + Policies: DefaultPolicies(), + ChainInfo: DefaultChainInfo(), + AuthorizationList: DefaultAuthorizationsList(), } } diff --git a/x/crosschain/keeper/msg_server_abort_stuck_cctx.go b/x/crosschain/keeper/msg_server_abort_stuck_cctx.go index ab6038d30a..5ea695568f 100644 --- a/x/crosschain/keeper/msg_server_abort_stuck_cctx.go +++ b/x/crosschain/keeper/msg_server_abort_stuck_cctx.go @@ -3,6 +3,7 @@ package keeper import ( "context" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -23,8 +24,9 @@ func (k msgServer) AbortStuckCCTX( ctx := sdk.UnwrapSDKContext(goCtx) // check if authorized - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // check if the cctx exists diff --git a/x/crosschain/keeper/msg_server_abort_stuck_cctx_test.go b/x/crosschain/keeper/msg_server_abort_stuck_cctx_test.go index aa1fc26786..cdc38b36f2 100644 --- a/x/crosschain/keeper/msg_server_abort_stuck_cctx_test.go +++ b/x/crosschain/keeper/msg_server_abort_stuck_cctx_test.go @@ -21,7 +21,6 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { msgServer := crosschainkeeper.NewMsgServerImpl(*k) admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // create a cctx cctx := sample.CrossChainTx(t, "cctx_index") @@ -32,10 +31,12 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { k.SetCrossChainTx(ctx, *cctx) // abort the cctx - _, err := msgServer.AbortStuckCCTX(ctx, &crosschaintypes.MsgAbortStuckCCTX{ + msg := crosschaintypes.MsgAbortStuckCCTX{ Creator: admin, CctxIndex: sample.GetCctxIndexFromString("cctx_index"), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AbortStuckCCTX(ctx, &msg) require.NoError(t, err) cctxFound, found := k.GetCrossChainTx(ctx, sample.GetCctxIndexFromString("cctx_index")) @@ -52,8 +53,6 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { msgServer := crosschainkeeper.NewMsgServerImpl(*k) admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // create a cctx cctx := sample.CrossChainTx(t, "cctx_index") cctx.CctxStatus = &crosschaintypes.Status{ @@ -63,10 +62,12 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { k.SetCrossChainTx(ctx, *cctx) // abort the cctx - _, err := msgServer.AbortStuckCCTX(ctx, &crosschaintypes.MsgAbortStuckCCTX{ + msg := crosschaintypes.MsgAbortStuckCCTX{ Creator: admin, CctxIndex: sample.GetCctxIndexFromString("cctx_index"), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AbortStuckCCTX(ctx, &msg) require.NoError(t, err) cctxFound, found := k.GetCrossChainTx(ctx, sample.GetCctxIndexFromString("cctx_index")) @@ -83,7 +84,6 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { msgServer := crosschainkeeper.NewMsgServerImpl(*k) admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // create a cctx cctx := sample.CrossChainTx(t, "cctx_index") @@ -94,10 +94,12 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { k.SetCrossChainTx(ctx, *cctx) // abort the cctx - _, err := msgServer.AbortStuckCCTX(ctx, &crosschaintypes.MsgAbortStuckCCTX{ + msg := crosschaintypes.MsgAbortStuckCCTX{ Creator: admin, CctxIndex: sample.GetCctxIndexFromString("cctx_index"), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AbortStuckCCTX(ctx, &msg) require.NoError(t, err) cctxFound, found := k.GetCrossChainTx(ctx, sample.GetCctxIndexFromString("cctx_index")) @@ -114,7 +116,6 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) // create a cctx cctx := sample.CrossChainTx(t, "cctx_index") @@ -125,10 +126,12 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { k.SetCrossChainTx(ctx, *cctx) // abort the cctx - _, err := msgServer.AbortStuckCCTX(ctx, &crosschaintypes.MsgAbortStuckCCTX{ + msg := crosschaintypes.MsgAbortStuckCCTX{ Creator: admin, CctxIndex: sample.GetCctxIndexFromString("cctx_index"), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AbortStuckCCTX(ctx, &msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -140,13 +143,14 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // abort the cctx - _, err := msgServer.AbortStuckCCTX(ctx, &crosschaintypes.MsgAbortStuckCCTX{ + msg := crosschaintypes.MsgAbortStuckCCTX{ Creator: admin, CctxIndex: sample.GetCctxIndexFromString("cctx_index"), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AbortStuckCCTX(ctx, &msg) require.ErrorIs(t, err, crosschaintypes.ErrCannotFindCctx) }) @@ -158,7 +162,6 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // create a cctx cctx := sample.CrossChainTx(t, "cctx_index") @@ -169,10 +172,12 @@ func TestMsgServer_AbortStuckCCTX(t *testing.T) { k.SetCrossChainTx(ctx, *cctx) // abort the cctx - _, err := msgServer.AbortStuckCCTX(ctx, &crosschaintypes.MsgAbortStuckCCTX{ + msg := crosschaintypes.MsgAbortStuckCCTX{ Creator: admin, CctxIndex: sample.GetCctxIndexFromString("cctx_index"), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AbortStuckCCTX(ctx, &msg) require.ErrorIs(t, err, crosschaintypes.ErrStatusNotPending) }) } diff --git a/x/crosschain/keeper/msg_server_add_inbound_tracker.go b/x/crosschain/keeper/msg_server_add_inbound_tracker.go index 5f2c85e96a..fd9b8fa2bd 100644 --- a/x/crosschain/keeper/msg_server_add_inbound_tracker.go +++ b/x/crosschain/keeper/msg_server_add_inbound_tracker.go @@ -23,13 +23,19 @@ func (k msgServer) AddInboundTracker( return nil, observertypes.ErrSupportedChains } - // emergency or observer group can submit tracker without proof - isEmergencyGroup := k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupEmergency) + // check if the msg signer is from the emergency group policy address.It is okay to ignore the error as the sender can also be an observer + isAuthorizedPolicy := false + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err == nil { + isAuthorizedPolicy = true + } + + // check if the msg signer is an observer isObserver := k.GetObserverKeeper().IsNonTombstonedObserver(ctx, msg.Creator) // only emergency group and observer can submit tracker without proof // if the sender is not from the emergency group or observer, the inbound proof must be provided - if !(isEmergencyGroup || isObserver) { + if !(isAuthorizedPolicy || isObserver) { if msg.Proof == nil { return nil, errorsmod.Wrap(authoritytypes.ErrUnauthorized, fmt.Sprintf("Creator %s", msg.Creator)) } diff --git a/x/crosschain/keeper/msg_server_add_inbound_tracker_test.go b/x/crosschain/keeper/msg_server_add_inbound_tracker_test.go index 394ea622f4..9d135d513d 100644 --- a/x/crosschain/keeper/msg_server_add_inbound_tracker_test.go +++ b/x/crosschain/keeper/msg_server_add_inbound_tracker_test.go @@ -30,15 +30,13 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) - - keepertest.MockIsAuthorized(&authorityMock.Mock, nonAdmin, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) - observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) - txHash := "string" chainID := getValidEthChainID() - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) + observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) + + msg := types.MsgAddInboundTracker{ Creator: nonAdmin, ChainId: chainID, TxHash: txHash, @@ -46,7 +44,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: nil, BlockHash: "", TxIndex: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddInboundTracker(ctx, &msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) _, found := k.GetInboundTracker(ctx, chainID, txHash) require.False(t, found) @@ -58,14 +58,13 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { UseObserverMock: true, }) msgServer := keeper.NewMsgServerImpl(*k) - - observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockFailedGetSupportedChainFromChainID(observerMock, nil) - txHash := "string" chainID := getValidEthChainID() - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + observerMock := keepertest.GetCrosschainObserverMock(t, k) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(nil) + + msg := types.MsgAddInboundTracker{ Creator: sample.AccAddress(), ChainId: chainID + 1, TxHash: txHash, @@ -73,7 +72,8 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: nil, BlockHash: "", TxIndex: 0, - }) + } + _, err := msgServer.AddInboundTracker(ctx, &msg) require.ErrorIs(t, err, observertypes.ErrSupportedChains) _, found := k.GetInboundTracker(ctx, chainID, txHash) require.False(t, found) @@ -89,16 +89,15 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) + txHash := "string" + chainID := getValidEthChainID() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) - txHash := "string" - chainID := getValidEthChainID() setSupportedChain(ctx, zk, chainID) - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + msg := types.MsgAddInboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -106,7 +105,10 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: nil, BlockHash: "", TxIndex: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AddInboundTracker(ctx, &msg) + require.NoError(t, err) _, found := k.GetInboundTracker(ctx, chainID, txHash) require.True(t, found) @@ -122,15 +124,13 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) - - keepertest.MockIsAuthorized(&authorityMock.Mock, mock.Anything, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) - observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(true) - txHash := "string" chainID := getValidEthChainID() - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) + observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(true) + + msg := types.MsgAddInboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -138,7 +138,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: nil, BlockHash: "", TxIndex: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddInboundTracker(ctx, &msg) require.NoError(t, err) _, found := k.GetInboundTracker(ctx, chainID, txHash) require.True(t, found) @@ -151,22 +153,20 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { UseObserverMock: true, }) msgServer := keeper.NewMsgServerImpl(*k) - - admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, mock.Anything, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + admin := sample.AccAddress() + txHash := "string" + chainID := getValidEthChainID() + + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil, errors.New("error")) - txHash := "string" - chainID := getValidEthChainID() - - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + msg := types.MsgAddInboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -174,7 +174,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: &proofs.Proof{}, BlockHash: "", TxIndex: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddInboundTracker(ctx, &msg) require.ErrorIs(t, err, types.ErrProofVerificationFail) }) @@ -190,18 +192,16 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) + txHash := "string" + chainID := getValidEthChainID() - keepertest.MockIsAuthorized(&authorityMock.Mock, mock.Anything, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(sample.Bytes(), nil) observerMock.On("GetChainParamsByChainID", mock.Anything, mock.Anything).Return(nil, false) - txHash := "string" - chainID := getValidEthChainID() - - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + msg := types.MsgAddInboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -209,7 +209,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: &proofs.Proof{}, BlockHash: "", TxIndex: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddInboundTracker(ctx, &msg) require.ErrorIs(t, err, types.ErrUnsupportedChain) }) @@ -221,13 +223,15 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { }) msgServer := keeper.NewMsgServerImpl(*k) - admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, mock.Anything, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + admin := sample.AccAddress() + txHash := "string" + chainID := getValidEthChainID() + + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(sample.Bytes(), nil) @@ -235,11 +239,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Return(sample.ChainParams(chains.Ethereum.ChainId), true) observerMock.On("GetTssAddress", mock.Anything, mock.Anything).Return(nil, errors.New("error")) - txHash := "string" - chainID := getValidEthChainID() setSupportedChain(ctx, zk, chainID) - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + msg := types.MsgAddInboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -247,7 +249,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: &proofs.Proof{}, BlockHash: "", TxIndex: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddInboundTracker(ctx, &msg) require.ErrorIs(t, err, observertypes.ErrTssNotFound) }) @@ -259,13 +263,15 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { }) msgServer := keeper.NewMsgServerImpl(*k) - admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, mock.Anything, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + admin := sample.AccAddress() + txHash := "string" + chainID := getValidEthChainID() + + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) observerMock.On("GetChainParamsByChainID", mock.Anything, mock.Anything). Return(sample.ChainParams(chains.Ethereum.ChainId), true) @@ -277,11 +283,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return([]byte("invalid"), nil) - txHash := "string" - chainID := getValidEthChainID() setSupportedChain(ctx, zk, chainID) - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + msg := types.MsgAddInboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -289,7 +293,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: &proofs.Proof{}, BlockHash: "", TxIndex: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddInboundTracker(ctx, &msg) require.ErrorIs(t, err, types.ErrTxBodyVerificationFail) }) @@ -301,19 +307,17 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { }) msgServer := keeper.NewMsgServerImpl(*k) - admin := sample.AccAddress() - chainID := chains.Ethereum.ChainId tssAddress := sample.EthAddress() ethTx, ethTxBytes := sample.EthTx(t, chainID, tssAddress, 42) + admin := sample.AccAddress() txHash := ethTx.Hash().Hex() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, mock.Anything, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) observerMock.On("GetChainParamsByChainID", mock.Anything, mock.Anything). Return(sample.ChainParams(chains.Ethereum.ChainId), true) @@ -323,7 +327,7 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(ethTxBytes, nil) - _, err := msgServer.AddInboundTracker(ctx, &types.MsgAddInboundTracker{ + msg := types.MsgAddInboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -331,7 +335,9 @@ func TestMsgServer_AddToInboundTracker(t *testing.T) { Proof: &proofs.Proof{}, BlockHash: "", TxIndex: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddInboundTracker(ctx, &msg) require.NoError(t, err) _, found := k.GetInboundTracker(ctx, chainID, txHash) require.True(t, found) diff --git a/x/crosschain/keeper/msg_server_add_outbound_tracker.go b/x/crosschain/keeper/msg_server_add_outbound_tracker.go index 977343265e..ce7b2f0c03 100644 --- a/x/crosschain/keeper/msg_server_add_outbound_tracker.go +++ b/x/crosschain/keeper/msg_server_add_outbound_tracker.go @@ -57,13 +57,19 @@ func (k msgServer) AddOutboundTracker( return &types.MsgAddOutboundTrackerResponse{IsRemoved: true}, nil } - isEmergencyGroup := k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupEmergency) + // check if the msg signer is from the emergency group policy address.It is okay to ignore the error as the sender can also be an observer + isAuthorizedPolicy := false + if k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) == nil { + isAuthorizedPolicy = true + } + + // check if the msg signer is an observer isObserver := k.GetObserverKeeper().IsNonTombstonedObserver(ctx, msg.Creator) isProven := false // only emergency group and observer can submit tracker without proof // if the sender is not from the emergency group or observer, the outbound proof must be provided - if !(isEmergencyGroup || isObserver) { + if !(isAuthorizedPolicy || isObserver) { if msg.Proof == nil { return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, fmt.Sprintf("Creator %s", msg.Creator)) } diff --git a/x/crosschain/keeper/msg_server_add_outbound_tracker_test.go b/x/crosschain/keeper/msg_server_add_outbound_tracker_test.go index b8a213ee36..e43800b2ac 100644 --- a/x/crosschain/keeper/msg_server_add_outbound_tracker_test.go +++ b/x/crosschain/keeper/msg_server_add_outbound_tracker_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/pkg/proofs" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" @@ -32,19 +33,17 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() + chainID := getEthereumChainID() + hash := sample.Hash().Hex() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) - chainID := getEthereumChainID() - hash := sample.Hash().Hex() - - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: hash, @@ -52,7 +51,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.NoError(t, err) tracker, found := k.GetOutboundTracker(ctx, chainID, 0) require.True(t, found) @@ -66,20 +67,18 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { }) msgServer := keeper.NewMsgServerImpl(*k) - admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) - observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(true) - keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) - + admin := sample.AccAddress() chainID := getEthereumChainID() hash := sample.Hash().Hex() - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) + observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(true) + keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) + + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: hash, @@ -87,7 +86,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.NoError(t, err) tracker, found := k.GetOutboundTracker(ctx, chainID, 0) require.True(t, found) @@ -101,20 +102,18 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { }) msgServer := keeper.NewMsgServerImpl(*k) - admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) - observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) - keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) - + admin := sample.AccAddress() chainID := getEthereumChainID() existinghHash := sample.Hash().Hex() newHash := sample.Hash().Hex() + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) + observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) + keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) + k.SetOutboundTracker(ctx, types.OutboundTracker{ ChainId: chainID, Nonce: 42, @@ -125,7 +124,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { }, }) - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: newHash, @@ -133,7 +132,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 42, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.NoError(t, err) tracker, found := k.GetOutboundTracker(ctx, chainID, 42) require.True(t, found) @@ -150,16 +151,8 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - - observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) - - // set cctx status to outbound mined - keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_OutboundMined, false) - chainID := getEthereumChainID() - - res, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: sample.Hash().Hex(), @@ -167,7 +160,15 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 0, - }) + } + + observerMock := keepertest.GetCrosschainObserverMock(t, k) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) + + // set cctx status to outbound mined + keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_OutboundMined, false) + + res, err := msgServer.AddOutboundTracker(ctx, &msg) require.NoError(t, err) require.Equal(t, &types.MsgAddOutboundTrackerResponse{IsRemoved: true}, res) @@ -186,7 +187,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { admin := sample.AccAddress() observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockFailedGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(nil) chainID := getEthereumChainID() @@ -213,7 +214,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, true) chainID := getEthereumChainID() @@ -237,13 +238,14 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { }) msgServer := keeper.NewMsgServerImpl(*k) - admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + admin := sample.AccAddress() + chainID := getEthereumChainID() + newHash := sample.Hash().Hex() + + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) @@ -254,16 +256,13 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { } } - chainID := getEthereumChainID() - newHash := sample.Hash().Hex() - k.SetOutboundTracker(ctx, types.OutboundTracker{ ChainId: chainID, Nonce: 42, HashList: hashes, }) - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: newHash, @@ -271,7 +270,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 42, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.ErrorIs(t, err, types.ErrMaxTxOutTrackerHashesReached) }) @@ -284,17 +285,16 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { admin := sample.AccAddress() + chainID := getEthereumChainID() + existinghHash := sample.Hash().Hex() + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) - chainID := getEthereumChainID() - existinghHash := sample.Hash().Hex() - k.SetOutboundTracker(ctx, types.OutboundTracker{ ChainId: chainID, Nonce: 42, @@ -305,7 +305,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { }, }) - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: existinghHash, @@ -313,7 +313,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 42, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.NoError(t, err) tracker, found := k.GetOutboundTracker(ctx, chainID, 42) require.True(t, found) @@ -338,8 +340,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) observerMock.On("GetTssAddress", mock.Anything, mock.Anything).Return(&observertypes.QueryGetTssAddressResponse{ @@ -348,7 +349,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(ethTxBytes, nil) - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -356,7 +357,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 42, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.NoError(t, err) tracker, found := k.GetOutboundTracker(ctx, chainID, 42) require.True(t, found) @@ -381,8 +384,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) observerMock.On("GetTssAddress", mock.Anything, mock.Anything).Return(&observertypes.QueryGetTssAddressResponse{ @@ -406,7 +408,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { }, }) - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -414,7 +416,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 42, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.NoError(t, err) tracker, found := k.GetOutboundTracker(ctx, chainID, 42) require.True(t, found) @@ -440,14 +444,13 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(ethTxBytes, errors.New("error")) - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -455,7 +458,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 42, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.ErrorIs(t, err, types.ErrProofVerificationFail) }) @@ -476,8 +481,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). @@ -486,7 +490,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { Eth: tssAddress.Hex(), }, errors.New("error")) - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -494,7 +498,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 42, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.ErrorIs(t, err, observertypes.ErrTssNotFound) }) @@ -510,13 +516,11 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { chainID := getEthereumChainID() ethTx, _, tssAddress := sample.EthTxSigned(t, chainID, sample.EthAddress(), 42) txHash := ethTx.Hash().Hex() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) observerMock := keepertest.GetCrosschainObserverMock(t, k) lightclientMock := keepertest.GetCrosschainLightclientMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - keepertest.MockGetSupportedChainFromChainID(observerMock, nil) + observerMock.On("GetSupportedChainFromChainID", mock.Anything, mock.Anything).Return(&chains.Chain{}) observerMock.On("IsNonTombstonedObserver", mock.Anything, mock.Anything).Return(false) keepertest.MockCctxByNonce(t, ctx, *k, observerMock, types.CctxStatus_PendingOutbound, false) observerMock.On("GetTssAddress", mock.Anything, mock.Anything).Return(&observertypes.QueryGetTssAddressResponse{ @@ -527,7 +531,7 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { lightclientMock.On("VerifyProof", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(sample.Bytes(), nil) - _, err := msgServer.AddOutboundTracker(ctx, &types.MsgAddOutboundTracker{ + msg := types.MsgAddOutboundTracker{ Creator: admin, ChainId: chainID, TxHash: txHash, @@ -535,7 +539,9 @@ func TestMsgServer_AddToOutboundTracker(t *testing.T) { BlockHash: "", TxIndex: 0, Nonce: 42, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.AddOutboundTracker(ctx, &msg) require.ErrorIs(t, err, types.ErrTxBodyVerificationFail) }) } diff --git a/x/crosschain/keeper/msg_server_migrate_tss_funds.go b/x/crosschain/keeper/msg_server_migrate_tss_funds.go index b651178312..1bff3b8443 100644 --- a/x/crosschain/keeper/msg_server_migrate_tss_funds.go +++ b/x/crosschain/keeper/msg_server_migrate_tss_funds.go @@ -10,6 +10,7 @@ import ( tmbytes "github.com/cometbft/cometbft/libs/bytes" tmtypes "github.com/cometbft/cometbft/types" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/types/errors" "github.com/ethereum/go-ethereum/crypto" "github.com/zeta-chain/zetacore/pkg/chains" @@ -30,11 +31,9 @@ func (k msgServer) MigrateTssFunds( ctx := sdk.UnwrapSDKContext(goCtx) // check if authorized - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupAdmin) { - return nil, errorsmod.Wrap( - authoritytypes.ErrUnauthorized, - "Update can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } if k.zetaObserverKeeper.IsInboundEnabled(ctx) { @@ -73,7 +72,7 @@ func (k msgServer) MigrateTssFunds( return nil, errorsmod.Wrap(types.ErrCannotMigrateTssFunds, "cannot migrate funds when there are pending nonces") } - err := k.MigrateTSSFundsForChain(ctx, msg.ChainId, msg.Amount, tss, tssHistory) + err = k.MigrateTSSFundsForChain(ctx, msg.ChainId, msg.Amount, tss, tssHistory) if err != nil { return nil, errorsmod.Wrap(types.ErrCannotMigrateTssFunds, err.Error()) } diff --git a/x/crosschain/keeper/msg_server_migrate_tss_funds_test.go b/x/crosschain/keeper/msg_server_migrate_tss_funds_test.go index 0d1207058e..fbd3c870bf 100644 --- a/x/crosschain/keeper/msg_server_migrate_tss_funds_test.go +++ b/x/crosschain/keeper/msg_server_migrate_tss_funds_test.go @@ -92,20 +92,23 @@ func TestKeeper_MigrateTSSFundsForChain(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") + indexString, _ := setupTssMigrationParams(zk, k, ctx, *chain, amount, true, true) gp, found := k.GetMedianGasPriceInUint(ctx, chain.ChainId) require.True(t, found) - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.NoError(t, err) hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() @@ -122,20 +125,23 @@ func TestKeeper_MigrateTSSFundsForChain(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidBTCChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidBTCChain() - amount := sdkmath.NewUintFromString("10000000000000000000") indexString, _ := setupTssMigrationParams(zk, k, ctx, *chain, amount, true, true) gp, found := k.GetMedianGasPriceInUint(ctx, chain.ChainId) require.True(t, found) - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.NoError(t, err) hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() @@ -152,17 +158,20 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, false) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.Error(t, err) }) @@ -173,19 +182,22 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) observerMock := keepertest.GetCrosschainObserverMock(t, k) observerMock.On("IsInboundEnabled", mock.Anything).Return(true) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.Error(t, err) }) @@ -196,21 +208,24 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) observerMock := keepertest.GetCrosschainObserverMock(t, k) observerMock.On("IsInboundEnabled", mock.Anything).Return(false) observerMock.On("GetTSS", mock.Anything).Return(observertypes.TSS{}, false) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.Error(t, err) }) @@ -221,8 +236,10 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) observerMock := keepertest.GetCrosschainObserverMock(t, k) observerMock.On("IsInboundEnabled", mock.Anything).Return(false) @@ -230,13 +247,14 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { observerMock.On("GetAllTSS", mock.Anything).Return([]observertypes.TSS{}) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.Error(t, err) }) @@ -247,8 +265,10 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) observerMock := keepertest.GetCrosschainObserverMock(t, k) observerMock.On("IsInboundEnabled", mock.Anything).Return(false) @@ -257,13 +277,14 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { observerMock.On("GetAllTSS", mock.Anything).Return([]observertypes.TSS{tss}) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.Error(t, err) }) @@ -274,8 +295,10 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) observerMock := keepertest.GetCrosschainObserverMock(t, k) observerMock.On("IsInboundEnabled", mock.Anything).Return(false) @@ -287,13 +310,14 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { observerMock.On("GetAllTSS", mock.Anything).Return([]observertypes.TSS{tss2}) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.Error(t, err) }) @@ -304,8 +328,10 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) observerMock := keepertest.GetCrosschainObserverMock(t, k) observerMock.On("IsInboundEnabled", mock.Anything).Return(false) @@ -319,13 +345,14 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { Return(observertypes.PendingNonces{}, false) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.Error(t, err) }) @@ -335,18 +362,21 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") + indexString, _ := setupTssMigrationParams(zk, k, ctx, *chain, amount, true, true) - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.NoError(t, err) hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() @@ -363,18 +393,20 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - - msgServer := keeper.NewMsgServerImpl(*k) chain := getValidEthChain() amount := sdkmath.NewUintFromString("100") + + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + msgServer := keeper.NewMsgServerImpl(*k) indexString, _ := setupTssMigrationParams(zk, k, ctx, *chain, amount, true, true) - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.ErrorContains(t, err, crosschaintypes.ErrCannotMigrateTssFunds.Error()) hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() @@ -388,18 +420,20 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - - msgServer := keeper.NewMsgServerImpl(*k) chain := getValidEthChain() amount := sdkmath.NewUintFromString("10000000000000000000") + + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + msgServer := keeper.NewMsgServerImpl(*k) indexString, _ := setupTssMigrationParams(zk, k, ctx, *chain, amount, false, true) - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.ErrorContains(t, err, "no new tss address has been generated") hash := crypto.Keccak256Hash([]byte(indexString)) index := hash.Hex() @@ -413,12 +447,11 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - - msgServer := keeper.NewMsgServerImpl(*k) chain := getValidEthChain() amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + msgServer := keeper.NewMsgServerImpl(*k) + indexString, tssPubkey := setupTssMigrationParams(zk, k, ctx, *chain, amount, true, true) k.GetObserverKeeper().SetPendingNonces(ctx, observertypes.PendingNonces{ NonceLow: 1, @@ -426,11 +459,14 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { ChainId: chain.ChainId, Tss: tssPubkey, }) - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) require.ErrorContains(t, err, "cannot migrate funds when there are pending nonces") hash := crypto.Keccak256Hash([]byte(indexString)) @@ -445,12 +481,11 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - - msgServer := keeper.NewMsgServerImpl(*k) chain := getValidEthChain() amount := sdkmath.NewUintFromString("10000000000000000000") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + msgServer := keeper.NewMsgServerImpl(*k) + indexString, tssPubkey := setupTssMigrationParams(zk, k, ctx, *chain, amount, true, true) k.GetObserverKeeper().SetPendingNonces(ctx, observertypes.PendingNonces{ NonceLow: 1, @@ -465,11 +500,14 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { ChainId: chain.ChainId, MigrationCctxIndex: existingCctx.Index, }) - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) require.ErrorContains(t, err, "cannot migrate funds while there are pending migrations") hash := crypto.Keccak256Hash([]byte(indexString)) @@ -488,12 +526,12 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { }) admin := sample.AccAddress() + chain := getValidEthChain() + amount := sdkmath.NewUintFromString("10000000000000000000") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) msgServer := keeper.NewMsgServerImpl(*k) - chain := getValidEthChain() - amount := sdkmath.NewUintFromString("10000000000000000000") + indexString, _ := setupTssMigrationParams(zk, k, ctx, *chain, amount, false, false) currentTss, found := k.GetObserverKeeper().GetTSS(ctx) require.True(t, found) @@ -501,11 +539,15 @@ func TestMsgServer_MigrateTssFunds(t *testing.T) { newTss.FinalizedZetaHeight = currentTss.FinalizedZetaHeight - 10 newTss.KeyGenZetaHeight = currentTss.KeyGenZetaHeight - 10 k.GetObserverKeeper().SetTSSHistory(ctx, newTss) - _, err := msgServer.MigrateTssFunds(ctx, &crosschaintypes.MsgMigrateTssFunds{ + + msg := crosschaintypes.MsgMigrateTssFunds{ Creator: admin, ChainId: chain.ChainId, Amount: amount, - }) + } + + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.MigrateTssFunds(ctx, &msg) require.ErrorIs(t, err, crosschaintypes.ErrCannotMigrateTssFunds) require.ErrorContains(t, err, "current tss is the latest") hash := crypto.Keccak256Hash([]byte(indexString)) diff --git a/x/crosschain/keeper/msg_server_refund_aborted_tx.go b/x/crosschain/keeper/msg_server_refund_aborted_tx.go index f33650ed96..1efd3227ec 100644 --- a/x/crosschain/keeper/msg_server_refund_aborted_tx.go +++ b/x/crosschain/keeper/msg_server_refund_aborted_tx.go @@ -26,8 +26,9 @@ func (k msgServer) RefundAbortedCCTX( ctx := sdk.UnwrapSDKContext(goCtx) // check if authorized - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errorsmod.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // check if the cctx exists diff --git a/x/crosschain/keeper/msg_server_refund_aborted_tx_test.go b/x/crosschain/keeper/msg_server_refund_aborted_tx_test.go index 2ac6b8c363..bf3678ecaf 100644 --- a/x/crosschain/keeper/msg_server_refund_aborted_tx_test.go +++ b/x/crosschain/keeper/msg_server_refund_aborted_tx_test.go @@ -44,13 +44,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -68,11 +67,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { "foobar", ) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.Sender, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.NoError(t, err) refundAddress := ethcommon.HexToAddress(cctx.InboundParams.TxOrigin) @@ -91,13 +92,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -110,11 +110,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { ) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.Sender, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.NoError(t, err) refundAddress := ethcommon.HexToAddress(cctx.InboundParams.TxOrigin) @@ -133,13 +135,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -153,11 +154,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { ) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.Sender, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.NoError(t, err) refundAddress := ethcommon.HexToAddress(cctx.InboundParams.TxOrigin) @@ -176,13 +179,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = true cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -196,11 +198,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { ) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.Sender, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.Error(t, err) }) @@ -211,13 +215,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -231,11 +234,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { ) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.Sender, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.Error(t, err) }) @@ -246,13 +251,14 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") + refundAddress := sample.EthAddress() + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -262,12 +268,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { k.SetZetaAccounting(ctx, crosschaintypes.ZetaAccounting{AbortedZetaAmount: cctx.InboundParams.Amount}) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - refundAddress := sample.EthAddress() - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: refundAddress.String(), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.NoError(t, err) refundAddressCosmos := sdk.AccAddress(refundAddress.Bytes()) @@ -286,13 +293,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() asset := sample.EthAddress().String() + cctx := sample.CrossChainTx(t, "sample-index") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.SenderChainId = chainID @@ -312,11 +319,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { "bar", ) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.Sender, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.NoError(t, err) refundAddress := ethcommon.HexToAddress(cctx.InboundParams.Sender) @@ -335,18 +344,18 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidBtcChainID() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - - msgServer := keeper.NewMsgServerImpl(*k) - k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") + cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false - cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender cctx.InboundParams.SenderChainId = chainID cctx.InboundParams.CoinType = coin.CoinType_Gas + + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + + msgServer := keeper.NewMsgServerImpl(*k) + k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) + k.SetCrossChainTx(ctx, *cctx) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) zrc20 := setupGasCoin( @@ -359,11 +368,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { "foobar", ) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.TxOrigin, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.NoError(t, err) refundAddress := ethcommon.HexToAddress(cctx.InboundParams.TxOrigin) @@ -382,13 +393,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -398,11 +409,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { k.SetZetaAccounting(ctx, crosschaintypes.ZetaAccounting{AbortedZetaAmount: cctx.InboundParams.Amount}) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: "invalid-address", - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.ErrorContains(t, err, "invalid refund address") }) @@ -413,13 +426,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -429,11 +442,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { k.SetZetaAccounting(ctx, crosschaintypes.ZetaAccounting{AbortedZetaAmount: cctx.InboundParams.Amount}) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: "0x0000000000000000000000000000000000000000", - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.ErrorContains(t, err, "invalid refund address") }) @@ -445,12 +460,11 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + cctx := sample.CrossChainTx(t, "sample-index") msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_PendingOutbound cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -459,11 +473,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { k.SetCrossChainTx(ctx, *cctx) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: "", - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.ErrorContains(t, err, "CCTX is not aborted") c, found := k.GetCrossChainTx(ctx, cctx.Index) require.True(t, found) @@ -477,13 +493,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_PendingOutbound cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -491,11 +506,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { cctx.InboundParams.CoinType = coin.CoinType_Gas deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: "", - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.ErrorContains(t, err, "cannot find cctx") }) @@ -506,13 +523,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidBtcChainID() + cctx := sample.CrossChainTx(t, "sample-index") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -530,11 +547,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { "foobar", ) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: "", - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.ErrorContains(t, err, "refund address is required") }) @@ -545,13 +564,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + cctx := sample.CrossChainTx(t, "sample-index") + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -560,11 +578,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { k.SetCrossChainTx(ctx, *cctx) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.Sender, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.ErrorContains(t, err, "unable to find zeta accounting") }) @@ -575,13 +595,12 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { admin := sample.AccAddress() chainID := getValidEthChainID() + cctx := sample.CrossChainTx(t, "sample-index") authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - cctx := sample.CrossChainTx(t, "sample-index") cctx.CctxStatus.Status = crosschaintypes.CctxStatus_Aborted cctx.CctxStatus.IsAbortRefunded = false cctx.InboundParams.TxOrigin = cctx.InboundParams.Sender @@ -599,11 +618,13 @@ func TestMsgServer_RefundAbortedCCTX(t *testing.T) { "foobar", ) - _, err := msgServer.RefundAbortedCCTX(ctx, &crosschaintypes.MsgRefundAbortedCCTX{ + msg := crosschaintypes.MsgRefundAbortedCCTX{ Creator: admin, CctxIndex: cctx.Index, RefundAddress: cctx.InboundParams.Sender, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.RefundAbortedCCTX(ctx, &msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) } diff --git a/x/crosschain/keeper/msg_server_remove_outbound_tracker.go b/x/crosschain/keeper/msg_server_remove_outbound_tracker.go index bd0b8d0465..ef20dfc6f7 100644 --- a/x/crosschain/keeper/msg_server_remove_outbound_tracker.go +++ b/x/crosschain/keeper/msg_server_remove_outbound_tracker.go @@ -3,6 +3,7 @@ package keeper import ( "context" + errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -17,8 +18,9 @@ func (k msgServer) RemoveOutboundTracker( msg *types.MsgRemoveOutboundTracker, ) (*types.MsgRemoveOutboundTrackerResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupEmergency) { - return &types.MsgRemoveOutboundTrackerResponse{}, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errorsmod.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } k.RemoveOutboundTrackerFromStore(ctx, msg.ChainId, msg.Nonce) diff --git a/x/crosschain/keeper/msg_server_remove_outbound_tracker_test.go b/x/crosschain/keeper/msg_server_remove_outbound_tracker_test.go index 6044a276c4..34c40044b5 100644 --- a/x/crosschain/keeper/msg_server_remove_outbound_tracker_test.go +++ b/x/crosschain/keeper/msg_server_remove_outbound_tracker_test.go @@ -23,14 +23,15 @@ func TestMsgServer_RemoveFromOutboundTracker(t *testing.T) { }) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) msgServer := keeper.NewMsgServerImpl(*k) - res, err := msgServer.RemoveOutboundTracker(ctx, &types.MsgRemoveOutboundTracker{ + msg := types.MsgRemoveOutboundTracker{ Creator: admin, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + res, err := msgServer.RemoveOutboundTracker(ctx, &msg) require.Error(t, err) require.Empty(t, res) @@ -49,15 +50,15 @@ func TestMsgServer_RemoveFromOutboundTracker(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - msgServer := keeper.NewMsgServerImpl(*k) - res, err := msgServer.RemoveOutboundTracker(ctx, &types.MsgRemoveOutboundTracker{ + msg := types.MsgRemoveOutboundTracker{ Creator: admin, ChainId: 1, Nonce: 1, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + res, err := msgServer.RemoveOutboundTracker(ctx, &msg) require.NoError(t, err) require.Empty(t, res) diff --git a/x/crosschain/keeper/msg_server_update_rate_limiter_flags.go b/x/crosschain/keeper/msg_server_update_rate_limiter_flags.go index d312fd22f0..874338f205 100644 --- a/x/crosschain/keeper/msg_server_update_rate_limiter_flags.go +++ b/x/crosschain/keeper/msg_server_update_rate_limiter_flags.go @@ -2,7 +2,6 @@ package keeper import ( "context" - "fmt" errorsmod "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" @@ -19,10 +18,10 @@ func (k msgServer) UpdateRateLimiterFlags( ) (*types.MsgUpdateRateLimiterFlagsResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, errorsmod.Wrap(authoritytypes.ErrUnauthorized, fmt.Sprintf("Creator %s", msg.Creator)) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errorsmod.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } - k.SetRateLimiterFlags(ctx, msg.RateLimiterFlags) return &types.MsgUpdateRateLimiterFlagsResponse{}, nil diff --git a/x/crosschain/keeper/msg_server_update_rate_limiter_flags_test.go b/x/crosschain/keeper/msg_server_update_rate_limiter_flags_test.go index 8a2ee87900..6872479973 100644 --- a/x/crosschain/keeper/msg_server_update_rate_limiter_flags_test.go +++ b/x/crosschain/keeper/msg_server_update_rate_limiter_flags_test.go @@ -19,19 +19,18 @@ func TestMsgServer_UpdateRateLimiterFlags(t *testing.T) { }) msgServer := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() + flags := sample.RateLimiterFlags() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, found := k.GetRateLimiterFlags(ctx) require.False(t, found) - flags := sample.RateLimiterFlags() - - _, err := msgServer.UpdateRateLimiterFlags(ctx, types.NewMsgUpdateRateLimiterFlags( + msg := types.NewMsgUpdateRateLimiterFlags( admin, flags, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateRateLimiterFlags(ctx, msg) require.NoError(t, err) storedFlags, found := k.GetRateLimiterFlags(ctx) @@ -45,14 +44,15 @@ func TestMsgServer_UpdateRateLimiterFlags(t *testing.T) { }) msgServer := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - + flags := sample.RateLimiterFlags() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - _, err := msgServer.UpdateRateLimiterFlags(ctx, types.NewMsgUpdateRateLimiterFlags( + msg := types.NewMsgUpdateRateLimiterFlags( admin, - sample.RateLimiterFlags(), - )) + flags, + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.UpdateRateLimiterFlags(ctx, msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) } diff --git a/x/crosschain/keeper/msg_server_update_tss.go b/x/crosschain/keeper/msg_server_update_tss.go index c2a1a490e0..672d64e462 100644 --- a/x/crosschain/keeper/msg_server_update_tss.go +++ b/x/crosschain/keeper/msg_server_update_tss.go @@ -18,13 +18,9 @@ func (k msgServer) UpdateTssAddress( ctx := sdk.UnwrapSDKContext(goCtx) // check if authorized - // TODO : Add a new policy type for updating the TSS address - // https://github.com/zeta-chain/node/issues/1715 - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupAdmin) { - return nil, errorsmod.Wrap( - authoritytypes.ErrUnauthorized, - "Update can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errorsmod.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } currentTss, found := k.zetaObserverKeeper.GetTSS(ctx) diff --git a/x/crosschain/keeper/msg_server_update_tss_test.go b/x/crosschain/keeper/msg_server_update_tss_test.go index f644982d3a..4e894e1586 100644 --- a/x/crosschain/keeper/msg_server_update_tss_test.go +++ b/x/crosschain/keeper/msg_server_update_tss_test.go @@ -18,17 +18,18 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { k, ctx, _, _ := keepertest.CrosschainKeeperWithMocks(t, keepertest.CrosschainMockOptions{ UseAuthorityMock: true, }) - admin := sample.AccAddress() + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, false) msgServer := keeper.NewMsgServerImpl(*k) - _, err := msgServer.UpdateTssAddress(ctx, &crosschaintypes.MsgUpdateTssAddress{ + msg := crosschaintypes.MsgUpdateTssAddress{ Creator: admin, TssPubkey: "", - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.UpdateTssAddress(ctx, &msg) require.Error(t, err) }) @@ -39,14 +40,14 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - msgServer := keeper.NewMsgServerImpl(*k) - _, err := msgServer.UpdateTssAddress(ctx, &crosschaintypes.MsgUpdateTssAddress{ + msg := crosschaintypes.MsgUpdateTssAddress{ Creator: admin, TssPubkey: "", - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.UpdateTssAddress(ctx, &msg) require.Error(t, err) }) @@ -56,12 +57,11 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { }) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - - msgServer := keeper.NewMsgServerImpl(*k) tssOld := sample.Tss() tssNew := sample.Tss() + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + msgServer := keeper.NewMsgServerImpl(*k) + k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSSHistory(ctx, tssNew) k.GetObserverKeeper().SetTSS(ctx, tssOld) @@ -80,10 +80,13 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), len(k.GetObserverKeeper().GetSupportedChains(ctx)), ) - _, err := msgServer.UpdateTssAddress(ctx, &crosschaintypes.MsgUpdateTssAddress{ + + msg := crosschaintypes.MsgUpdateTssAddress{ Creator: admin, TssPubkey: tssNew.TssPubkey, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.UpdateTssAddress(ctx, &msg) require.NoError(t, err) tss, found := k.GetObserverKeeper().GetTSS(ctx) require.True(t, found) @@ -98,12 +101,12 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { }) admin := sample.AccAddress() + tssOld := sample.Tss() + tssNew := sample.Tss() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) msgServer := keeper.NewMsgServerImpl(*k) - tssOld := sample.Tss() - tssNew := sample.Tss() + k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSS(ctx, tssOld) for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { @@ -121,10 +124,13 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), len(k.GetObserverKeeper().GetSupportedChains(ctx)), ) - _, err := msgServer.UpdateTssAddress(ctx, &crosschaintypes.MsgUpdateTssAddress{ + + msg := crosschaintypes.MsgUpdateTssAddress{ Creator: admin, TssPubkey: tssNew.TssPubkey, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.UpdateTssAddress(ctx, &msg) require.ErrorContains(t, err, "tss pubkey has not been generated") require.ErrorIs(t, err, crosschaintypes.ErrUnableToUpdateTss) tss, found := k.GetObserverKeeper().GetTSS(ctx) @@ -143,11 +149,11 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { }) admin := sample.AccAddress() + tssOld := sample.Tss() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) msgServer := keeper.NewMsgServerImpl(*k) - tssOld := sample.Tss() + k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSS(ctx, tssOld) for _, chain := range k.GetObserverKeeper().GetSupportedChains(ctx) { @@ -165,10 +171,13 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), len(k.GetObserverKeeper().GetSupportedChains(ctx)), ) - _, err := msgServer.UpdateTssAddress(ctx, &crosschaintypes.MsgUpdateTssAddress{ + + msg := crosschaintypes.MsgUpdateTssAddress{ Creator: admin, TssPubkey: tssOld.TssPubkey, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.UpdateTssAddress(ctx, &msg) require.ErrorContains(t, err, "no new tss address has been generated") require.ErrorIs(t, err, crosschaintypes.ErrUnableToUpdateTss) tss, found := k.GetObserverKeeper().GetTSS(ctx) @@ -187,12 +196,10 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { }) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - - msgServer := keeper.NewMsgServerImpl(*k) tssOld := sample.Tss() tssNew := sample.Tss() + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + msgServer := keeper.NewMsgServerImpl(*k) k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSSHistory(ctx, tssNew) @@ -209,12 +216,14 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { cctx := sample.CrossChainTx(t, index) cctx.CctxStatus.Status = crosschaintypes.CctxStatus_OutboundMined k.SetCrossChainTx(ctx, *cctx) - require.Equal(t, len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), 1) - _, err := msgServer.UpdateTssAddress(ctx, &crosschaintypes.MsgUpdateTssAddress{ + + msg := crosschaintypes.MsgUpdateTssAddress{ Creator: admin, TssPubkey: tssNew.TssPubkey, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.UpdateTssAddress(ctx, &msg) require.ErrorContains(t, err, "cannot update tss address not enough migrations have been created and completed") require.ErrorIs(t, err, crosschaintypes.ErrUnableToUpdateTss) tss, found := k.GetObserverKeeper().GetTSS(ctx) @@ -230,12 +239,11 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { }) admin := sample.AccAddress() + tssOld := sample.Tss() + tssNew := sample.Tss() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) msgServer := keeper.NewMsgServerImpl(*k) - tssOld := sample.Tss() - tssNew := sample.Tss() k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSSHistory(ctx, tssNew) @@ -257,10 +265,13 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), len(k.GetObserverKeeper().GetSupportedChains(ctx)), ) - _, err := msgServer.UpdateTssAddress(ctx, &crosschaintypes.MsgUpdateTssAddress{ + + msg := crosschaintypes.MsgUpdateTssAddress{ Creator: admin, TssPubkey: tssNew.TssPubkey, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.UpdateTssAddress(ctx, &msg) require.ErrorContains(t, err, "cannot update tss address while there are pending migrations") require.ErrorIs(t, err, crosschaintypes.ErrUnableToUpdateTss) tss, found := k.GetObserverKeeper().GetTSS(ctx) @@ -276,12 +287,10 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { }) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - - msgServer := keeper.NewMsgServerImpl(*k) tssOld := sample.Tss() tssNew := sample.Tss() + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + msgServer := keeper.NewMsgServerImpl(*k) k.GetObserverKeeper().SetTSSHistory(ctx, tssOld) k.GetObserverKeeper().SetTSSHistory(ctx, tssNew) @@ -300,10 +309,13 @@ func TestMsgServer_UpdateTssAddress(t *testing.T) { len(k.GetObserverKeeper().GetAllTssFundMigrators(ctx)), len(k.GetObserverKeeper().GetSupportedChains(ctx)), ) - _, err := msgServer.UpdateTssAddress(ctx, &crosschaintypes.MsgUpdateTssAddress{ + + msg := crosschaintypes.MsgUpdateTssAddress{ Creator: admin, TssPubkey: tssNew.TssPubkey, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.UpdateTssAddress(ctx, &msg) require.ErrorContains(t, err, "migration cross chain tx not found") require.ErrorIs(t, err, crosschaintypes.ErrUnableToUpdateTss) tss, found := k.GetObserverKeeper().GetTSS(ctx) diff --git a/x/crosschain/keeper/msg_server_vote_inbound_tx_test.go b/x/crosschain/keeper/msg_server_vote_inbound_tx_test.go index 1a94c3f10d..332a47f270 100644 --- a/x/crosschain/keeper/msg_server_vote_inbound_tx_test.go +++ b/x/crosschain/keeper/msg_server_vote_inbound_tx_test.go @@ -140,7 +140,7 @@ func TestKeeper_VoteInbound(t *testing.T) { require.True(t, found) require.Equal(t, ballot.BallotStatus, observertypes.BallotStatus_BallotFinalized_SuccessObservation) //Perform the SAME event. Except, this time, we resubmit the event. - msg2 := &types.MsgVoteInbound{ + msg = &types.MsgVoteInbound{ Creator: validatorAddr, Sender: "0x954598965C2aCdA2885B037561526260764095B8", SenderChainId: 1337, @@ -159,11 +159,11 @@ func TestKeeper_VoteInbound(t *testing.T) { _, err = msgServer.VoteInbound( ctx, - msg2, + msg, ) require.Error(t, err) require.ErrorIs(t, err, types.ErrObservedTxAlreadyFinalized) - _, found = zk.ObserverKeeper.GetBallot(ctx, msg2.Digest()) + _, found = zk.ObserverKeeper.GetBallot(ctx, msg.Digest()) require.False(t, found) }) diff --git a/x/crosschain/keeper/msg_server_whitelist_erc20.go b/x/crosschain/keeper/msg_server_whitelist_erc20.go index 47f4007eb5..c8fbd1248b 100644 --- a/x/crosschain/keeper/msg_server_whitelist_erc20.go +++ b/x/crosschain/keeper/msg_server_whitelist_erc20.go @@ -30,11 +30,9 @@ func (k msgServer) WhitelistERC20( ctx := sdk.UnwrapSDKContext(goCtx) // check if authorized - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, errorsmod.Wrap( - authoritytypes.ErrUnauthorized, - "Deploy can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errorsmod.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } erc20Addr := ethcommon.HexToAddress(msg.Erc20Address) diff --git a/x/crosschain/keeper/msg_server_whitelist_erc20_test.go b/x/crosschain/keeper/msg_server_whitelist_erc20_test.go index ba0cfa2609..397fd7aac1 100644 --- a/x/crosschain/keeper/msg_server_whitelist_erc20_test.go +++ b/x/crosschain/keeper/msg_server_whitelist_erc20_test.go @@ -30,8 +30,8 @@ func TestKeeper_WhitelistERC20(t *testing.T) { setSupportedChain(ctx, zk, chainID) admin := sample.AccAddress() + erc20Address := sample.EthAddress().Hex() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) deploySystemContracts(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper) setupGasCoin(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper, chainID, "foobar", "FOOBAR") @@ -42,8 +42,7 @@ func TestKeeper_WhitelistERC20(t *testing.T) { Prices: []uint64{1}, }) - erc20Address := sample.EthAddress().Hex() - res, err := msgServer.WhitelistERC20(ctx, &types.MsgWhitelistERC20{ + msg := types.MsgWhitelistERC20{ Creator: admin, Erc20Address: erc20Address, ChainId: chainID, @@ -51,7 +50,9 @@ func TestKeeper_WhitelistERC20(t *testing.T) { Symbol: "FOO", Decimals: 18, GasLimit: 100000, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + res, err := msgServer.WhitelistERC20(ctx, &msg) require.NoError(t, err) require.NotNil(t, res) zrc20 := res.Zrc20Address @@ -72,10 +73,7 @@ func TestKeeper_WhitelistERC20(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(100000), gasLimit.Uint64()) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - - // Ensure that whitelist a new erc20 create a cctx with a different index - res, err = msgServer.WhitelistERC20(ctx, &types.MsgWhitelistERC20{ + msgNew := types.MsgWhitelistERC20{ Creator: admin, Erc20Address: sample.EthAddress().Hex(), ChainId: chainID, @@ -83,7 +81,11 @@ func TestKeeper_WhitelistERC20(t *testing.T) { Symbol: "BAR", Decimals: 18, GasLimit: 100000, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msgNew, nil) + + // Ensure that whitelist a new erc20 create a cctx with a different index + res, err = msgServer.WhitelistERC20(ctx, &msgNew) require.NoError(t, err) require.NotNil(t, res) require.NotEqual(t, cctxIndex, res.CctxIndex) @@ -99,9 +101,8 @@ func TestKeeper_WhitelistERC20(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - _, err := msgServer.WhitelistERC20(ctx, &types.MsgWhitelistERC20{ + msg := types.MsgWhitelistERC20{ Creator: admin, Erc20Address: sample.EthAddress().Hex(), ChainId: getValidEthChainID(), @@ -109,7 +110,9 @@ func TestKeeper_WhitelistERC20(t *testing.T) { Symbol: "FOO", Decimals: 18, GasLimit: 100000, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.WhitelistERC20(ctx, &msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -123,9 +126,8 @@ func TestKeeper_WhitelistERC20(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := msgServer.WhitelistERC20(ctx, &types.MsgWhitelistERC20{ + msg := types.MsgWhitelistERC20{ Creator: admin, Erc20Address: "invalid", ChainId: getValidEthChainID(), @@ -133,7 +135,10 @@ func TestKeeper_WhitelistERC20(t *testing.T) { Symbol: "FOO", Decimals: 18, GasLimit: 100000, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + + _, err := msgServer.WhitelistERC20(ctx, &msg) require.ErrorIs(t, err, sdkerrors.ErrInvalidAddress) }) @@ -146,17 +151,16 @@ func TestKeeper_WhitelistERC20(t *testing.T) { k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) admin := sample.AccAddress() - authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - chainID := getValidEthChainID() asset := sample.EthAddress().Hex() + authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) + fc := sample.ForeignCoins(t, sample.EthAddress().Hex()) fc.Asset = asset fc.ForeignChainId = chainID zk.FungibleKeeper.SetForeignCoins(ctx, fc) - _, err := msgServer.WhitelistERC20(ctx, &types.MsgWhitelistERC20{ + msg := types.MsgWhitelistERC20{ Creator: admin, Erc20Address: asset, ChainId: chainID, @@ -164,7 +168,9 @@ func TestKeeper_WhitelistERC20(t *testing.T) { Symbol: "FOO", Decimals: 18, GasLimit: 100000, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.WhitelistERC20(ctx, &msg) require.ErrorIs(t, err, fungibletypes.ErrForeignCoinAlreadyExist) }) @@ -176,13 +182,12 @@ func TestKeeper_WhitelistERC20(t *testing.T) { msgServer := crosschainkeeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) - chainID := getValidEthChainID() admin := sample.AccAddress() + chainID := getValidEthChainID() + erc20Address := sample.EthAddress().Hex() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - erc20Address := sample.EthAddress().Hex() - _, err := msgServer.WhitelistERC20(ctx, &types.MsgWhitelistERC20{ + msg := types.MsgWhitelistERC20{ Creator: admin, Erc20Address: erc20Address, ChainId: chainID, @@ -190,7 +195,9 @@ func TestKeeper_WhitelistERC20(t *testing.T) { Symbol: "FOO", Decimals: 18, GasLimit: 100000, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.WhitelistERC20(ctx, &msg) require.ErrorIs(t, err, types.ErrCannotFindTSSKeys) }) @@ -203,13 +210,12 @@ func TestKeeper_WhitelistERC20(t *testing.T) { k.GetAuthKeeper().GetModuleAccount(ctx, fungibletypes.ModuleName) admin := sample.AccAddress() + erc20Address := sample.EthAddress().Hex() authorityMock := keepertest.GetCrosschainAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) k.GetObserverKeeper().SetTssAndUpdateNonce(ctx, sample.Tss()) - erc20Address := sample.EthAddress().Hex() - _, err := msgServer.WhitelistERC20(ctx, &types.MsgWhitelistERC20{ + msg := types.MsgWhitelistERC20{ Creator: admin, Erc20Address: erc20Address, ChainId: 10000, @@ -217,7 +223,9 @@ func TestKeeper_WhitelistERC20(t *testing.T) { Symbol: "FOO", Decimals: 18, GasLimit: 100000, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := msgServer.WhitelistERC20(ctx, &msg) require.ErrorIs(t, err, types.ErrInvalidChainID) }) } diff --git a/x/crosschain/types/expected_keepers.go b/x/crosschain/types/expected_keepers.go index 7b564fc78a..6f35afa6c5 100644 --- a/x/crosschain/types/expected_keepers.go +++ b/x/crosschain/types/expected_keepers.go @@ -13,7 +13,6 @@ import ( "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/pkg/coin" "github.com/zeta-chain/zetacore/pkg/proofs" - authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" fungibletypes "github.com/zeta-chain/zetacore/x/fungible/types" observertypes "github.com/zeta-chain/zetacore/x/observer/types" ) @@ -209,7 +208,7 @@ type FungibleKeeper interface { } type AuthorityKeeper interface { - IsAuthorized(ctx sdk.Context, address string, policyType authoritytypes.PolicyType) bool + CheckAuthorization(ctx sdk.Context, msg sdk.Msg) error } type LightclientKeeper interface { diff --git a/x/crosschain/types/message_vote_inbound_test.go b/x/crosschain/types/message_vote_inbound_test.go index 749f94f946..f96f640b9e 100644 --- a/x/crosschain/types/message_vote_inbound_test.go +++ b/x/crosschain/types/message_vote_inbound_test.go @@ -158,87 +158,87 @@ func TestMsgVoteInbound_Digest(t *testing.T) { require.NotEmpty(t, hash, "hash should not be empty") // creator not used - msg2 := msg - msg2.Creator = sample.AccAddress() - hash2 := msg2.Digest() + msg = msg + msg.Creator = sample.AccAddress() + hash2 := msg.Digest() require.Equal(t, hash, hash2, "creator should not change hash") // in block height not used - msg2 = msg - msg2.InboundBlockHeight = 43 - hash2 = msg2.Digest() + msg = msg + msg.InboundBlockHeight = 43 + hash2 = msg.Digest() require.Equal(t, hash, hash2, "in block height should not change hash") // sender used - msg2 = msg - msg2.Sender = sample.AccAddress() - hash2 = msg2.Digest() + msg = msg + msg.Sender = sample.AccAddress() + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "sender should change hash") // sender chain ID used - msg2 = msg - msg2.SenderChainId = 43 - hash2 = msg2.Digest() + msg = msg + msg.SenderChainId = 43 + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "sender chain ID should change hash") // tx origin used - msg2 = msg - msg2.TxOrigin = sample.StringRandom(r, 32) - hash2 = msg2.Digest() + msg = msg + msg.TxOrigin = sample.StringRandom(r, 32) + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "tx origin should change hash") // receiver used - msg2 = msg - msg2.Receiver = sample.StringRandom(r, 32) - hash2 = msg2.Digest() + msg = msg + msg.Receiver = sample.StringRandom(r, 32) + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "receiver should change hash") // receiver chain ID used - msg2 = msg - msg2.ReceiverChain = 43 - hash2 = msg2.Digest() + msg = msg + msg.ReceiverChain = 43 + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "receiver chain ID should change hash") // amount used - msg2 = msg - msg2.Amount = math.NewUint(43) - hash2 = msg2.Digest() + msg = msg + msg.Amount = math.NewUint(43) + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "amount should change hash") // message used - msg2 = msg - msg2.Message = sample.StringRandom(r, 32) - hash2 = msg2.Digest() + msg = msg + msg.Message = sample.StringRandom(r, 32) + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "message should change hash") // in tx hash used - msg2 = msg - msg2.InboundHash = sample.StringRandom(r, 32) - hash2 = msg2.Digest() + msg = msg + msg.InboundHash = sample.StringRandom(r, 32) + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "in tx hash should change hash") // gas limit used - msg2 = msg - msg2.GasLimit = 43 - hash2 = msg2.Digest() + msg = msg + msg.GasLimit = 43 + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "gas limit should change hash") // coin type used - msg2 = msg - msg2.CoinType = coin.CoinType_ERC20 - hash2 = msg2.Digest() + msg = msg + msg.CoinType = coin.CoinType_ERC20 + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "coin type should change hash") // asset used - msg2 = msg - msg2.Asset = sample.StringRandom(r, 32) - hash2 = msg2.Digest() + msg = msg + msg.Asset = sample.StringRandom(r, 32) + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "asset should change hash") // event index used - msg2 = msg - msg2.EventIndex = 43 - hash2 = msg2.Digest() + msg = msg + msg.EventIndex = 43 + hash2 = msg.Digest() require.NotEqual(t, hash, hash2, "event index should change hash") } diff --git a/x/crosschain/types/message_vote_outbound_test.go b/x/crosschain/types/message_vote_outbound_test.go index 31d44e8f81..8e782f1e66 100644 --- a/x/crosschain/types/message_vote_outbound_test.go +++ b/x/crosschain/types/message_vote_outbound_test.go @@ -109,75 +109,75 @@ func TestMsgVoteOutbound_Digest(t *testing.T) { require.NotEmpty(t, hash, "hash should not be empty") // creator not used - msg2 := msg - msg2.Creator = sample.AccAddress() - hash2 := msg2.Digest() + msgNew := msg + msgNew.Creator = sample.AccAddress() + hash2 := msgNew.Digest() require.Equal(t, hash, hash2, "creator should not change hash") // status not used - msg2 = msg - msg2.Status = chains.ReceiveStatus_failed - hash2 = msg2.Digest() + msgNew = msg + msgNew.Status = chains.ReceiveStatus_failed + hash2 = msgNew.Digest() require.Equal(t, hash, hash2, "status should not change hash") // cctx hash used - msg2 = msg - msg2.CctxHash = sample.StringRandom(r, 32) - hash2 = msg2.Digest() + msgNew = msg + msgNew.CctxHash = sample.StringRandom(r, 32) + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "cctx hash should change hash") // observed outbound tx hash used - msg2 = msg - msg2.ObservedOutboundHash = sample.StringRandom(r, 32) - hash2 = msg2.Digest() + msgNew = msg + msgNew.ObservedOutboundHash = sample.StringRandom(r, 32) + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "observed outbound tx hash should change hash") // observed outbound tx block height used - msg2 = msg - msg2.ObservedOutboundBlockHeight = 43 - hash2 = msg2.Digest() + msgNew = msg + msgNew.ObservedOutboundBlockHeight = 43 + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "observed outbound tx block height should change hash") // observed outbound tx gas used used - msg2 = msg - msg2.ObservedOutboundGasUsed = 43 - hash2 = msg2.Digest() + msgNew = msg + msgNew.ObservedOutboundGasUsed = 43 + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "observed outbound tx gas used should change hash") // observed outbound tx effective gas price used - msg2 = msg - msg2.ObservedOutboundEffectiveGasPrice = math.NewInt(43) - hash2 = msg2.Digest() + msgNew = msg + msgNew.ObservedOutboundEffectiveGasPrice = math.NewInt(43) + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "observed outbound tx effective gas price should change hash") // observed outbound tx effective gas limit used - msg2 = msg - msg2.ObservedOutboundEffectiveGasLimit = 43 - hash2 = msg2.Digest() + msgNew = msg + msgNew.ObservedOutboundEffectiveGasLimit = 43 + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "observed outbound tx effective gas limit should change hash") // zeta minted used - msg2 = msg - msg2.ValueReceived = math.NewUint(43) - hash2 = msg2.Digest() + msgNew = msg + msgNew.ValueReceived = math.NewUint(43) + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "zeta minted should change hash") // out tx chain used - msg2 = msg - msg2.OutboundChain = 43 - hash2 = msg2.Digest() + msgNew = msg + msgNew.OutboundChain = 43 + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "out tx chain should change hash") // out tx tss nonce used - msg2 = msg - msg2.OutboundTssNonce = 43 - hash2 = msg2.Digest() + msgNew = msg + msgNew.OutboundTssNonce = 43 + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "out tx tss nonce should change hash") // coin type used - msg2 = msg - msg2.CoinType = coin.CoinType_ERC20 - hash2 = msg2.Digest() + msgNew = msg + msgNew.CoinType = coin.CoinType_ERC20 + hash2 = msgNew.Digest() require.NotEqual(t, hash, hash2, "coin type should change hash") } diff --git a/x/fungible/keeper/msg_server_deploy_fungible_coin_zrc20.go b/x/fungible/keeper/msg_server_deploy_fungible_coin_zrc20.go index 8792817fb7..46ecb550d0 100644 --- a/x/fungible/keeper/msg_server_deploy_fungible_coin_zrc20.go +++ b/x/fungible/keeper/msg_server_deploy_fungible_coin_zrc20.go @@ -44,11 +44,9 @@ func (k msgServer) DeployFungibleCoinZRC20( return nil, err } - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, cosmoserrors.Wrap( - authoritytypes.ErrUnauthorized, - "Deploy can only be executed by the correct policy account", - ) + err = k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } if msg.CoinType == coin.CoinType_Gas { diff --git a/x/fungible/keeper/msg_server_deploy_fungible_coin_zrc20_test.go b/x/fungible/keeper/msg_server_deploy_fungible_coin_zrc20_test.go index 2bc23b7c32..b37dd82cf7 100644 --- a/x/fungible/keeper/msg_server_deploy_fungible_coin_zrc20_test.go +++ b/x/fungible/keeper/msg_server_deploy_fungible_coin_zrc20_test.go @@ -25,15 +25,13 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() + chainID := getValidChainID(t) authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - - chainID := getValidChainID(t) deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) - res, err := msgServer.DeployFungibleCoinZRC20(ctx, types.NewMsgDeployFungibleCoinZRC20( + msg := types.NewMsgDeployFungibleCoinZRC20( admin, sample.EthAddress().Hex(), chainID, @@ -42,7 +40,9 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { "foo", coin.CoinType_Gas, 1000000, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + res, err := msgServer.DeployFungibleCoinZRC20(ctx, msg) require.NoError(t, err) gasAddress := res.Address assertContractDeployment(t, sdkk.EvmKeeper, ctx, ethcommon.HexToAddress(gasAddress)) @@ -62,10 +62,8 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { require.NoError(t, err) require.Equal(t, gasAddress, gas.Hex()) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can deploy non-gas zrc20 - res, err = msgServer.DeployFungibleCoinZRC20(ctx, types.NewMsgDeployFungibleCoinZRC20( + msg = types.NewMsgDeployFungibleCoinZRC20( admin, sample.EthAddress().Hex(), chainID, @@ -74,8 +72,11 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { "bar", coin.CoinType_ERC20, 2000000, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + res, err = msgServer.DeployFungibleCoinZRC20(ctx, msg) require.NoError(t, err) + assertContractDeployment(t, sdkk.EvmKeeper, ctx, ethcommon.HexToAddress(res.Address)) foreignCoin, found = k.GetForeignCoins(ctx, res.Address) @@ -101,15 +102,13 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { }) k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) chainID := getValidChainID(t) - admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) // should not deploy a new zrc20 if not admin - _, err := keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, types.NewMsgDeployFungibleCoinZRC20( + msg := types.NewMsgDeployFungibleCoinZRC20( admin, sample.EthAddress().Hex(), chainID, @@ -118,7 +117,9 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { "foo", coin.CoinType_Gas, 1000000, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -130,13 +131,8 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() - chainID := getValidChainID(t) - - deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) - - // should not deploy a new zrc20 if not admin - _, err := keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, types.NewMsgDeployFungibleCoinZRC20( + msg := types.NewMsgDeployFungibleCoinZRC20( admin, sample.EthAddress().Hex(), chainID, @@ -145,7 +141,12 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { "foo", coin.CoinType_Gas, 1000000, - )) + ) + + deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) + + // should not deploy a new zrc20 if not admin + _, err := keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, sdkerrors.ErrInvalidRequest) }) @@ -158,12 +159,11 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) // should not deploy a new zrc20 if not admin - _, err := keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, types.NewMsgDeployFungibleCoinZRC20( + msg := types.NewMsgDeployFungibleCoinZRC20( admin, sample.EthAddress().Hex(), 9999999, @@ -172,7 +172,9 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { "foo", coin.CoinType_Gas, 1000000, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, observertypes.ErrSupportedChains) }) @@ -200,14 +202,13 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { coin.CoinType_Gas, 1000000, ) - - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + keepertest.MockCheckAuthorization(&authorityMock.Mock, deployMsg, nil) // Attempt to deploy the same gas token twice should result in error _, err := keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, deployMsg) require.NoError(t, err) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + keepertest.MockCheckAuthorization(&authorityMock.Mock, deployMsg, nil) _, err = keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, deployMsg) require.Error(t, err) @@ -215,12 +216,12 @@ func TestMsgServer_DeployFungibleCoinZRC20(t *testing.T) { // Similar to above, redeploying existing erc20 should also fail deployMsg.CoinType = coin.CoinType_ERC20 - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + keepertest.MockCheckAuthorization(&authorityMock.Mock, deployMsg, nil) _, err = keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, deployMsg) require.NoError(t, err) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + keepertest.MockCheckAuthorization(&authorityMock.Mock, deployMsg, nil) _, err = keeper.NewMsgServerImpl(*k).DeployFungibleCoinZRC20(ctx, deployMsg) require.Error(t, err) diff --git a/x/fungible/keeper/msg_server_deploy_system_contract.go b/x/fungible/keeper/msg_server_deploy_system_contract.go index 68814a4c6b..b7926b75c8 100644 --- a/x/fungible/keeper/msg_server_deploy_system_contract.go +++ b/x/fungible/keeper/msg_server_deploy_system_contract.go @@ -3,7 +3,7 @@ package keeper import ( "context" - cosmoserror "cosmossdk.io/errors" + cosmoserrors "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -19,41 +19,39 @@ func (k msgServer) DeploySystemContracts( ) (*types.MsgDeploySystemContractsResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, cosmoserror.Wrap( - authoritytypes.ErrUnauthorized, - "System contract deployment can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // uniswap v2 factory factory, err := k.DeployUniswapV2Factory(ctx) if err != nil { - return nil, cosmoserror.Wrapf(err, "failed to deploy UniswapV2Factory") + return nil, cosmoserrors.Wrapf(err, "failed to deploy UniswapV2Factory") } // wzeta contract wzeta, err := k.DeployWZETA(ctx) if err != nil { - return nil, cosmoserror.Wrapf(err, "failed to DeployWZetaContract") + return nil, cosmoserrors.Wrapf(err, "failed to DeployWZetaContract") } // uniswap v2 router router, err := k.DeployUniswapV2Router02(ctx, factory, wzeta) if err != nil { - return nil, cosmoserror.Wrapf(err, "failed to deploy UniswapV2Router02") + return nil, cosmoserrors.Wrapf(err, "failed to deploy UniswapV2Router02") } // connector zevm connector, err := k.DeployConnectorZEVM(ctx, wzeta) if err != nil { - return nil, cosmoserror.Wrapf(err, "failed to deploy ConnectorZEVM") + return nil, cosmoserrors.Wrapf(err, "failed to deploy ConnectorZEVM") } // system contract systemContract, err := k.DeploySystemContract(ctx, wzeta, factory, router) if err != nil { - return nil, cosmoserror.Wrapf(err, "failed to deploy SystemContract") + return nil, cosmoserrors.Wrapf(err, "failed to deploy SystemContract") } err = ctx.EventManager().EmitTypedEvent( @@ -72,7 +70,7 @@ func (k msgServer) DeploySystemContracts( "event", "EventSystemContractsDeployed", "error", err.Error(), ) - return nil, cosmoserror.Wrapf(types.ErrEmitEvent, "failed to emit event (%s)", err.Error()) + return nil, cosmoserrors.Wrapf(types.ErrEmitEvent, "failed to emit event (%s)", err.Error()) } return &types.MsgDeploySystemContractsResponse{ diff --git a/x/fungible/keeper/msg_server_deploy_system_contract_test.go b/x/fungible/keeper/msg_server_deploy_system_contract_test.go index 117ee0cc33..882b8fecf1 100644 --- a/x/fungible/keeper/msg_server_deploy_system_contract_test.go +++ b/x/fungible/keeper/msg_server_deploy_system_contract_test.go @@ -26,11 +26,12 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) _ = k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - res, err := msgServer.DeploySystemContracts(ctx, types.NewMsgDeploySystemContracts(admin)) + msg := types.NewMsgDeploySystemContracts(admin) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + res, err := msgServer.DeploySystemContracts(ctx, msg) + require.NoError(t, err) require.NotNil(t, res) assertContractDeployment(t, sdkk.EvmKeeper, ctx, ethcommon.HexToAddress(res.UniswapV2Factory)) @@ -48,11 +49,12 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) _ = k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) nonadmin := sample.AccAddress() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, nonadmin, authoritytypes.PolicyType_groupOperational, false) - _, err := msgServer.DeploySystemContracts(ctx, types.NewMsgDeploySystemContracts(nonadmin)) + msg := types.NewMsgDeploySystemContracts(nonadmin) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.DeploySystemContracts(ctx, msg) + require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -64,14 +66,14 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) _ = k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // mock failed uniswapv2factory deployment mockFailedContractDeployment(ctx, t, k) - _, err := msgServer.DeploySystemContracts(ctx, types.NewMsgDeploySystemContracts(admin)) + msg := types.NewMsgDeploySystemContracts(admin) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.DeploySystemContracts(ctx, msg) require.ErrorContains(t, err, "failed to deploy") }) @@ -83,16 +85,16 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) _ = k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // mock successful uniswapv2factory deployment mockSuccessfulContractDeployment(ctx, t, k) // mock failed wzeta deployment deployment mockFailedContractDeployment(ctx, t, k) - _, err := msgServer.DeploySystemContracts(ctx, types.NewMsgDeploySystemContracts(admin)) + msg := types.NewMsgDeploySystemContracts(admin) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.DeploySystemContracts(ctx, msg) require.Error(t, err) require.ErrorContains(t, err, "failed to deploy") }) @@ -107,15 +109,15 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // mock successful uniswapv2factory and wzeta deployments mockSuccessfulContractDeployment(ctx, t, k) mockSuccessfulContractDeployment(ctx, t, k) // mock failed uniswapv2router deployment mockFailedContractDeployment(ctx, t, k) - _, err := msgServer.DeploySystemContracts(ctx, types.NewMsgDeploySystemContracts(admin)) + msg := types.NewMsgDeploySystemContracts(admin) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.DeploySystemContracts(ctx, msg) require.Error(t, err) require.ErrorContains(t, err, "failed to deploy") }) @@ -128,9 +130,7 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) _ = k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // mock successful uniswapv2factory, wzeta and uniswapv2router deployments mockSuccessfulContractDeployment(ctx, t, k) @@ -139,7 +139,10 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { // mock failed connectorzevm deployment mockFailedContractDeployment(ctx, t, k) - _, err := msgServer.DeploySystemContracts(ctx, types.NewMsgDeploySystemContracts(admin)) + msg := types.NewMsgDeploySystemContracts(admin) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + + _, err := msgServer.DeploySystemContracts(ctx, msg) require.Error(t, err) require.ErrorContains(t, err, "failed to deploy") }) @@ -152,9 +155,7 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) _ = k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // mock successful uniswapv2factory, wzeta, uniswapv2router and connectorzevm deployments mockSuccessfulContractDeployment(ctx, t, k) @@ -164,7 +165,9 @@ func TestMsgServer_DeploySystemContracts(t *testing.T) { // mock failed system contract deployment mockFailedContractDeployment(ctx, t, k) - _, err := msgServer.DeploySystemContracts(ctx, types.NewMsgDeploySystemContracts(admin)) + msg := types.NewMsgDeploySystemContracts(admin) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.DeploySystemContracts(ctx, msg) require.Error(t, err) require.ErrorContains(t, err, "failed to deploy") }) diff --git a/x/fungible/keeper/msg_server_pause_zrc20.go b/x/fungible/keeper/msg_server_pause_zrc20.go index 536da472a1..0901cce964 100644 --- a/x/fungible/keeper/msg_server_pause_zrc20.go +++ b/x/fungible/keeper/msg_server_pause_zrc20.go @@ -18,11 +18,9 @@ func (k msgServer) PauseZRC20( ) (*types.MsgPauseZRC20Response, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupEmergency) { - return nil, cosmoserrors.Wrap( - authoritytypes.ErrUnauthorized, - "PauseZRC20 can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // iterate all foreign coins and set paused status @@ -36,7 +34,7 @@ func (k msgServer) PauseZRC20( k.SetForeignCoins(ctx, fc) } - err := ctx.EventManager().EmitTypedEvent( + err = ctx.EventManager().EmitTypedEvent( &types.EventZRC20Paused{ MsgTypeUrl: sdk.MsgTypeURL(&types.MsgPauseZRC20{}), Zrc20Addresses: msg.Zrc20Addresses, diff --git a/x/fungible/keeper/msg_server_pause_zrc20_test.go b/x/fungible/keeper/msg_server_pause_zrc20_test.go index 9f9edc68fe..4b672f3745 100644 --- a/x/fungible/keeper/msg_server_pause_zrc20_test.go +++ b/x/fungible/keeper/msg_server_pause_zrc20_test.go @@ -47,46 +47,46 @@ func TestKeeper_PauseZRC20(t *testing.T) { assertUnpaused(zrc20B) assertUnpaused(zrc20C) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - // can pause zrc20 - _, err := msgServer.PauseZRC20(ctx, types.NewMsgPauseZRC20( + msg := types.NewMsgPauseZRC20( admin, []string{ zrc20A, zrc20B, }, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.PauseZRC20(ctx, msg) require.NoError(t, err) assertPaused(zrc20A) assertPaused(zrc20B) assertUnpaused(zrc20C) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - // can pause already paused zrc20 - _, err = msgServer.PauseZRC20(ctx, types.NewMsgPauseZRC20( + msg = types.NewMsgPauseZRC20( admin, []string{ zrc20B, }, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.PauseZRC20(ctx, msg) require.NoError(t, err) assertPaused(zrc20A) assertPaused(zrc20B) assertUnpaused(zrc20C) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - // can pause all zrc20 - _, err = msgServer.PauseZRC20(ctx, types.NewMsgPauseZRC20( + msg = types.NewMsgPauseZRC20( admin, []string{ zrc20A, zrc20B, zrc20C, }, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.PauseZRC20(ctx, msg) require.NoError(t, err) assertPaused(zrc20A) assertPaused(zrc20B) @@ -102,12 +102,13 @@ func TestKeeper_PauseZRC20(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - _, err := msgServer.PauseZRC20(ctx, types.NewMsgPauseZRC20( + msg := types.NewMsgPauseZRC20( admin, []string{sample.EthAddress().String()}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.PauseZRC20(ctx, msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -119,21 +120,23 @@ func TestKeeper_PauseZRC20(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() + authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) zrc20A, zrc20B := sample.EthAddress().String(), sample.EthAddress().String() k.SetForeignCoins(ctx, sample.ForeignCoins(t, zrc20A)) k.SetForeignCoins(ctx, sample.ForeignCoins(t, zrc20B)) - _, err := msgServer.PauseZRC20(ctx, types.NewMsgPauseZRC20( + msg := types.NewMsgPauseZRC20( admin, []string{ zrc20A, sample.EthAddress().String(), zrc20B, }, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.PauseZRC20(ctx, msg) require.ErrorIs(t, err, types.ErrForeignCoinNotFound) }) } diff --git a/x/fungible/keeper/msg_server_remove_foreign_coin.go b/x/fungible/keeper/msg_server_remove_foreign_coin.go index 3f46997e3d..6903531cd3 100644 --- a/x/fungible/keeper/msg_server_remove_foreign_coin.go +++ b/x/fungible/keeper/msg_server_remove_foreign_coin.go @@ -20,11 +20,9 @@ func (k msgServer) RemoveForeignCoin( msg *types.MsgRemoveForeignCoin, ) (*types.MsgRemoveForeignCoinResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, cosmoserrors.Wrap( - authoritytypes.ErrUnauthorized, - "Removal can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } index := msg.Name _, found := k.GetForeignCoins(ctx, index) diff --git a/x/fungible/keeper/msg_server_remove_foreign_coin_test.go b/x/fungible/keeper/msg_server_remove_foreign_coin_test.go index fafb01b597..01ac286e9e 100644 --- a/x/fungible/keeper/msg_server_remove_foreign_coin_test.go +++ b/x/fungible/keeper/msg_server_remove_foreign_coin_test.go @@ -23,8 +23,8 @@ func TestMsgServer_RemoveForeignCoin(t *testing.T) { k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() + authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) chainID := getValidChainID(t) @@ -34,7 +34,9 @@ func TestMsgServer_RemoveForeignCoin(t *testing.T) { _, found := k.GetForeignCoins(ctx, zrc20.Hex()) require.True(t, found) - _, err := msgServer.RemoveForeignCoin(ctx, types.NewMsgRemoveForeignCoin(admin, zrc20.Hex())) + msg := types.NewMsgRemoveForeignCoin(admin, zrc20.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.RemoveForeignCoin(ctx, msg) require.NoError(t, err) _, found = k.GetForeignCoins(ctx, zrc20.Hex()) require.False(t, found) @@ -51,12 +53,13 @@ func TestMsgServer_RemoveForeignCoin(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) zrc20 := setupGasCoin(t, ctx, k, sdkk.EvmKeeper, chainID, "foo", "foo") - _, err := msgServer.RemoveForeignCoin(ctx, types.NewMsgRemoveForeignCoin(admin, zrc20.Hex())) + msg := types.NewMsgRemoveForeignCoin(admin, zrc20.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.RemoveForeignCoin(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -69,11 +72,11 @@ func TestMsgServer_RemoveForeignCoin(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := msgServer.RemoveForeignCoin(ctx, types.NewMsgRemoveForeignCoin(admin, sample.EthAddress().Hex())) + msg := types.NewMsgRemoveForeignCoin(admin, sample.EthAddress().Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.RemoveForeignCoin(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, sdkerrors.ErrInvalidRequest) }) diff --git a/x/fungible/keeper/msg_server_udpate_zrc20_liquidity_cap.go b/x/fungible/keeper/msg_server_udpate_zrc20_liquidity_cap.go index e519ea509b..e978d2a907 100644 --- a/x/fungible/keeper/msg_server_udpate_zrc20_liquidity_cap.go +++ b/x/fungible/keeper/msg_server_udpate_zrc20_liquidity_cap.go @@ -20,11 +20,9 @@ func (k msgServer) UpdateZRC20LiquidityCap( ctx := sdk.UnwrapSDKContext(goCtx) // check authorization - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, cosmoserrors.Wrap( - authoritytypes.ErrUnauthorized, - "update can only be executed by group 2 policy group", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // fetch the foreign coin diff --git a/x/fungible/keeper/msg_server_udpate_zrc20_liquidity_cap_test.go b/x/fungible/keeper/msg_server_udpate_zrc20_liquidity_cap_test.go index 416ba3115e..7282346d33 100644 --- a/x/fungible/keeper/msg_server_udpate_zrc20_liquidity_cap_test.go +++ b/x/fungible/keeper/msg_server_udpate_zrc20_liquidity_cap_test.go @@ -28,28 +28,29 @@ func TestMsgServer_UpdateZRC20LiquidityCap(t *testing.T) { k.SetForeignCoins(ctx, foreignCoin) authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // can update liquidity cap - _, err := msgServer.UpdateZRC20LiquidityCap(ctx, types.NewMsgUpdateZRC20LiquidityCap( + msg := types.NewMsgUpdateZRC20LiquidityCap( admin, coinAddress, math.NewUint(42), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateZRC20LiquidityCap(ctx, msg) require.NoError(t, err) coin, found := k.GetForeignCoins(ctx, coinAddress) require.True(t, found) require.True(t, coin.LiquidityCap.Equal(math.NewUint(42)), "invalid liquidity cap", coin.LiquidityCap.String()) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can update liquidity cap again - _, err = msgServer.UpdateZRC20LiquidityCap(ctx, types.NewMsgUpdateZRC20LiquidityCap( + msg = types.NewMsgUpdateZRC20LiquidityCap( admin, coinAddress, math.NewUint(4200000), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateZRC20LiquidityCap(ctx, msg) require.NoError(t, err) coin, found = k.GetForeignCoins(ctx, coinAddress) @@ -61,28 +62,28 @@ func TestMsgServer_UpdateZRC20LiquidityCap(t *testing.T) { coin.LiquidityCap.String(), ) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can set liquidity cap to 0 - _, err = msgServer.UpdateZRC20LiquidityCap(ctx, types.NewMsgUpdateZRC20LiquidityCap( + msg = types.NewMsgUpdateZRC20LiquidityCap( admin, coinAddress, math.NewUint(0), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateZRC20LiquidityCap(ctx, msg) require.NoError(t, err) coin, found = k.GetForeignCoins(ctx, coinAddress) require.True(t, found) require.True(t, coin.LiquidityCap.Equal(math.ZeroUint()), "invalid liquidity cap", coin.LiquidityCap.String()) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can set liquidity cap to nil - _, err = msgServer.UpdateZRC20LiquidityCap(ctx, types.NewMsgUpdateZRC20LiquidityCap( + msg = types.NewMsgUpdateZRC20LiquidityCap( admin, coinAddress, math.Uint{}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateZRC20LiquidityCap(ctx, msg) require.NoError(t, err) coin, found = k.GetForeignCoins(ctx, coinAddress) @@ -102,15 +103,15 @@ func TestMsgServer_UpdateZRC20LiquidityCap(t *testing.T) { foreignCoin := sample.ForeignCoins(t, coinAddress) foreignCoin.LiquidityCap = math.Uint{} k.SetForeignCoins(ctx, foreignCoin) - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - _, err := msgServer.UpdateZRC20LiquidityCap(ctx, types.NewMsgUpdateZRC20LiquidityCap( + msg := types.NewMsgUpdateZRC20LiquidityCap( admin, coinAddress, math.NewUint(42), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.UpdateZRC20LiquidityCap(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -123,15 +124,15 @@ func TestMsgServer_UpdateZRC20LiquidityCap(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() coinAddress := sample.EthAddress().String() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := msgServer.UpdateZRC20LiquidityCap(ctx, types.NewMsgUpdateZRC20LiquidityCap( + msg := types.NewMsgUpdateZRC20LiquidityCap( admin, coinAddress, math.NewUint(42), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateZRC20LiquidityCap(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, types.ErrForeignCoinNotFound) }) diff --git a/x/fungible/keeper/msg_server_unpause_zrc20.go b/x/fungible/keeper/msg_server_unpause_zrc20.go index f00afd6909..80a5541a73 100644 --- a/x/fungible/keeper/msg_server_unpause_zrc20.go +++ b/x/fungible/keeper/msg_server_unpause_zrc20.go @@ -18,11 +18,9 @@ func (k msgServer) UnpauseZRC20( ) (*types.MsgUnpauseZRC20Response, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, cosmoserrors.Wrap( - authoritytypes.ErrUnauthorized, - "UnPauseZRC20 can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // iterate all foreign coins and set unpaused status @@ -36,7 +34,7 @@ func (k msgServer) UnpauseZRC20( k.SetForeignCoins(ctx, fc) } - err := ctx.EventManager().EmitTypedEvent( + err = ctx.EventManager().EmitTypedEvent( &types.EventZRC20Unpaused{ MsgTypeUrl: sdk.MsgTypeURL(&types.MsgUnpauseZRC20{}), Zrc20Addresses: msg.Zrc20Addresses, diff --git a/x/fungible/keeper/msg_server_unpause_zrc20_test.go b/x/fungible/keeper/msg_server_unpause_zrc20_test.go index ea90f01e85..e4e4a482f4 100644 --- a/x/fungible/keeper/msg_server_unpause_zrc20_test.go +++ b/x/fungible/keeper/msg_server_unpause_zrc20_test.go @@ -49,45 +49,45 @@ func TestKeeper_UnpauseZRC20(t *testing.T) { assertPaused(zrc20B) assertUnpaused(zrc20C) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can unpause zrc20 - _, err := msgServer.UnpauseZRC20(ctx, types.NewMsgUnpauseZRC20( + msg := types.NewMsgUnpauseZRC20( admin, []string{ zrc20A, }, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UnpauseZRC20(ctx, msg) require.NoError(t, err) assertUnpaused(zrc20A) assertPaused(zrc20B) assertUnpaused(zrc20C) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can unpause already unpaused zrc20 - _, err = msgServer.UnpauseZRC20(ctx, types.NewMsgUnpauseZRC20( + msg = types.NewMsgUnpauseZRC20( admin, []string{ zrc20C, }, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UnpauseZRC20(ctx, msg) require.NoError(t, err) assertUnpaused(zrc20A) assertPaused(zrc20B) assertUnpaused(zrc20C) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can unpause all zrc20 - _, err = msgServer.UnpauseZRC20(ctx, types.NewMsgUnpauseZRC20( + msg = types.NewMsgUnpauseZRC20( admin, []string{ zrc20A, zrc20B, zrc20C, }, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UnpauseZRC20(ctx, msg) require.NoError(t, err) assertUnpaused(zrc20A) assertUnpaused(zrc20B) @@ -104,12 +104,12 @@ func TestKeeper_UnpauseZRC20(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - - _, err := msgServer.UnpauseZRC20(ctx, types.NewMsgUnpauseZRC20( + msg := types.NewMsgUnpauseZRC20( admin, []string{sample.EthAddress().String()}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.UnpauseZRC20(ctx, msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -123,20 +123,21 @@ func TestKeeper_UnpauseZRC20(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) zrc20A, zrc20B := sample.EthAddress().String(), sample.EthAddress().String() k.SetForeignCoins(ctx, sample.ForeignCoins(t, zrc20A)) k.SetForeignCoins(ctx, sample.ForeignCoins(t, zrc20B)) - _, err := msgServer.UnpauseZRC20(ctx, types.NewMsgUnpauseZRC20( + msg := types.NewMsgUnpauseZRC20( admin, []string{ zrc20A, sample.EthAddress().String(), zrc20B, }, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UnpauseZRC20(ctx, msg) require.ErrorIs(t, err, types.ErrForeignCoinNotFound) }) } diff --git a/x/fungible/keeper/msg_server_update_contract_bytecode.go b/x/fungible/keeper/msg_server_update_contract_bytecode.go index 2bbb24e7f3..8dfa6a0ada 100644 --- a/x/fungible/keeper/msg_server_update_contract_bytecode.go +++ b/x/fungible/keeper/msg_server_update_contract_bytecode.go @@ -3,7 +3,7 @@ package keeper import ( "context" - cosmoserror "cosmossdk.io/errors" + cosmoserrors "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" ethcommon "github.com/ethereum/go-ethereum/common" @@ -26,21 +26,23 @@ func (k msgServer) UpdateContractBytecode( ctx := sdk.UnwrapSDKContext(goCtx) // check authorization - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupAdmin) { - return nil, cosmoserror.Wrap( - authoritytypes.ErrUnauthorized, - "Deploy can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // fetch account to update if !ethcommon.IsHexAddress(msg.ContractAddress) { - return nil, cosmoserror.Wrapf(sdkerrors.ErrInvalidAddress, "invalid contract address (%s)", msg.ContractAddress) + return nil, cosmoserrors.Wrapf( + sdkerrors.ErrInvalidAddress, + "invalid contract address (%s)", + msg.ContractAddress, + ) } contractAddress := ethcommon.HexToAddress(msg.ContractAddress) acct := k.evmKeeper.GetAccount(ctx, contractAddress) if acct == nil { - return nil, cosmoserror.Wrapf(types.ErrContractNotFound, "contract (%s) not found", contractAddress.Hex()) + return nil, cosmoserrors.Wrapf(types.ErrContractNotFound, "contract (%s) not found", contractAddress.Hex()) } // check the contract is a zrc20 @@ -53,7 +55,7 @@ func (k msgServer) UpdateContractBytecode( } if msg.ContractAddress != systemContract.ConnectorZevm { // not a zrc20 or wzeta connector contract, can't be updated - return nil, cosmoserror.Wrapf( + return nil, cosmoserrors.Wrapf( types.ErrInvalidContract, "contract (%s) is neither a zrc20 nor wzeta connector", msg.ContractAddress, @@ -64,9 +66,9 @@ func (k msgServer) UpdateContractBytecode( // set the new CodeHash to the account oldCodeHash := acct.CodeHash acct.CodeHash = ethcommon.HexToHash(msg.NewCodeHash).Bytes() - err := k.evmKeeper.SetAccount(ctx, contractAddress, *acct) + err = k.evmKeeper.SetAccount(ctx, contractAddress, *acct) if err != nil { - return nil, cosmoserror.Wrapf( + return nil, cosmoserrors.Wrapf( types.ErrSetBytecode, "failed to update contract (%s) bytecode (%s)", contractAddress.Hex(), @@ -85,7 +87,7 @@ func (k msgServer) UpdateContractBytecode( ) if err != nil { k.Logger(ctx).Error("failed to emit event", "error", err.Error()) - return nil, cosmoserror.Wrapf(types.ErrEmitEvent, "failed to emit event (%s)", err.Error()) + return nil, cosmoserrors.Wrapf(types.ErrEmitEvent, "failed to emit event (%s)", err.Error()) } return &types.MsgUpdateContractBytecodeResponse{}, nil diff --git a/x/fungible/keeper/msg_server_update_contract_bytecode_test.go b/x/fungible/keeper/msg_server_update_contract_bytecode_test.go index 18d095493a..287cc18046 100644 --- a/x/fungible/keeper/msg_server_update_contract_bytecode_test.go +++ b/x/fungible/keeper/msg_server_update_contract_bytecode_test.go @@ -95,14 +95,14 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { require.NoError(t, err) codeHash := codeHashFromAddress(t, ctx, k, newCodeAddress.Hex()) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - // update the bytecode - _, err = msgServer.UpdateContractBytecode(ctx, types.NewMsgUpdateContractBytecode( + msg := types.NewMsgUpdateContractBytecode( admin, zrc20.Hex(), codeHash, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateContractBytecode(ctx, msg) require.NoError(t, err) // check the returned new bytecode hash matches the one in the account @@ -141,13 +141,13 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { codeHash = codeHashFromAddress(t, ctx, k, newCodeAddress.Hex()) require.NoError(t, err) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - - _, err = msgServer.UpdateContractBytecode(ctx, types.NewMsgUpdateContractBytecode( + msg = types.NewMsgUpdateContractBytecode( admin, zrc20.Hex(), codeHash, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateContractBytecode(ctx, msg) require.NoError(t, err) balance, err = k.BalanceOfZRC4(ctx, zrc20, addr1) require.NoError(t, err) @@ -180,14 +180,14 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { require.NotEmpty(t, newConnector) assertContractDeployment(t, sdkk.EvmKeeper, ctx, newConnector) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - // can update the bytecode of the new connector with the old connector contract - _, err = msgServer.UpdateContractBytecode(ctx, types.NewMsgUpdateContractBytecode( + msg := types.NewMsgUpdateContractBytecode( admin, newConnector.Hex(), codeHash, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateContractBytecode(ctx, msg) require.NoError(t, err) }) @@ -199,13 +199,14 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, false) - _, err := msgServer.UpdateContractBytecode(ctx, types.NewMsgUpdateContractBytecode( + msg := types.NewMsgUpdateContractBytecode( admin, sample.EthAddress().Hex(), sample.Hash().Hex(), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.UpdateContractBytecode(ctx, msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -217,13 +218,15 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - _, err := msgServer.UpdateContractBytecode(ctx, &types.MsgUpdateContractBytecode{ - Creator: admin, - ContractAddress: "invalid", - NewCodeHash: sample.Hash().Hex(), - }) + msg := types.NewMsgUpdateContractBytecode( + admin, + "invalid", + sample.Hash().Hex(), + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + + _, err := msgServer.UpdateContractBytecode(ctx, msg) require.ErrorIs(t, err, sdkerrors.ErrInvalidAddress) }) @@ -238,7 +241,6 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { contractAddr := sample.EthAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) mockEVMKeeper.On( "GetAccount", @@ -246,11 +248,13 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { contractAddr, ).Return(nil) - _, err := msgServer.UpdateContractBytecode(ctx, types.NewMsgUpdateContractBytecode( + msg := types.NewMsgUpdateContractBytecode( admin, contractAddr.Hex(), sample.Hash().Hex(), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateContractBytecode(ctx, msg) require.ErrorIs(t, err, types.ErrContractNotFound) mockEVMKeeper.AssertExpectations(t) @@ -263,17 +267,17 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) k.GetAuthKeeper().GetModuleAccount(ctx, types.ModuleName) admin := sample.AccAddress() - authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) wzeta, _, _, _, _ := deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) // can't update the bytecode of the wzeta contract - _, err := msgServer.UpdateContractBytecode(ctx, types.NewMsgUpdateContractBytecode( + msg := types.NewMsgUpdateContractBytecode( admin, wzeta.Hex(), sample.Hash().Hex(), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateContractBytecode(ctx, msg) require.ErrorIs(t, err, types.ErrInvalidContract) }) @@ -286,19 +290,19 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - _, _, _, connector, _ := deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) // remove system contract k.RemoveSystemContract(ctx) // can't update the bytecode of the wzeta contract - _, err := msgServer.UpdateContractBytecode(ctx, types.NewMsgUpdateContractBytecode( + msg := types.NewMsgUpdateContractBytecode( admin, connector.Hex(), sample.Hash().Hex(), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateContractBytecode(ctx, msg) require.ErrorIs(t, err, types.ErrSystemContractNotFound) }) @@ -313,8 +317,6 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - contractAddr := sample.EthAddress() newCodeHash := sample.Hash().Hex() @@ -336,11 +338,13 @@ func TestKeeper_UpdateContractBytecode(t *testing.T) { mock.Anything, ).Return(errors.New("can't set account")) - _, err := msgServer.UpdateContractBytecode(ctx, types.NewMsgUpdateContractBytecode( + msg := types.NewMsgUpdateContractBytecode( admin, contractAddr.Hex(), newCodeHash, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateContractBytecode(ctx, msg) require.ErrorIs(t, err, types.ErrSetBytecode) mockEVMKeeper.AssertExpectations(t) diff --git a/x/fungible/keeper/msg_server_update_system_contract.go b/x/fungible/keeper/msg_server_update_system_contract.go index 411a2c07a0..5477396eeb 100644 --- a/x/fungible/keeper/msg_server_update_system_contract.go +++ b/x/fungible/keeper/msg_server_update_system_contract.go @@ -22,11 +22,9 @@ func (k msgServer) UpdateSystemContract( msg *types.MsgUpdateSystemContract, ) (*types.MsgUpdateSystemContractResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupAdmin) { - return nil, cosmoserrors.Wrap( - authoritytypes.ErrUnauthorized, - "Deploy can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } newSystemContractAddr := ethcommon.HexToAddress(msg.NewSystemContractAddress) if newSystemContractAddr == (ethcommon.Address{}) { diff --git a/x/fungible/keeper/msg_server_update_system_contract_test.go b/x/fungible/keeper/msg_server_update_system_contract_test.go index 9849ae6523..bdea95980b 100644 --- a/x/fungible/keeper/msg_server_update_system_contract_test.go +++ b/x/fungible/keeper/msg_server_update_system_contract_test.go @@ -30,7 +30,6 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) queryZRC20SystemContract := func(contract common.Address) string { abi, err := zrc20.ZRC20MetaData.GetAbi() @@ -76,7 +75,9 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { require.NotEqual(t, oldSystemContract, newSystemContract) // can update the system contract - _, err = msgServer.UpdateSystemContract(ctx, types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex())) + msg := types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateSystemContract(ctx, msg) require.NoError(t, err) // can retrieve the system contract @@ -120,10 +121,10 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { newSystemContract, err := k.DeployContract(ctx, systemcontract.SystemContractMetaData, wzeta, factory, router) require.NoError(t, err) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - // can update the system contract - _, err = msgServer.UpdateSystemContract(ctx, types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex())) + msg := types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateSystemContract(ctx, msg) require.NoError(t, err) // can retrieve the system contract @@ -135,10 +136,10 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { newSystemContract, err = k.DeployContract(ctx, systemcontract.SystemContractMetaData, wzeta, factory, router) require.NoError(t, err) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - // can overwrite the previous system contract - _, err = msgServer.UpdateSystemContract(ctx, types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex())) + msg = types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateSystemContract(ctx, msg) require.NoError(t, err) // can retrieve the system contract @@ -157,7 +158,6 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, false) // deploy a new system contracts wzeta, factory, router, _, oldSystemContract := deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) @@ -166,7 +166,9 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { require.NotEqual(t, oldSystemContract, newSystemContract) // should not update the system contract if not admin - _, err = msgServer.UpdateSystemContract(ctx, types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex())) + msg := types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err = msgServer.UpdateSystemContract(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -181,8 +183,6 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - // deploy a new system contracts wzeta, factory, router, _, oldSystemContract := deploySystemContracts(t, ctx, k, sdkk.EvmKeeper) newSystemContract, err := k.DeployContract(ctx, systemcontract.SystemContractMetaData, wzeta, factory, router) @@ -190,7 +190,9 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { require.NotEqual(t, oldSystemContract, newSystemContract) // should not update the system contract if invalid address - _, err = msgServer.UpdateSystemContract(ctx, types.NewMsgUpdateSystemContract(admin, "invalid")) + msg := types.NewMsgUpdateSystemContract(admin, "invalid") + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateSystemContract(ctx, msg) require.Error(t, err) require.ErrorIs(t, err, sdkerrors.ErrInvalidAddress) }) @@ -240,30 +242,30 @@ func TestKeeper_UpdateSystemContract(t *testing.T) { // fail on first evm call mockEVMKeeper.MockEVMFailCallOnce() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - // can't update the system contract - _, err = msgServer.UpdateSystemContract(ctx, types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex())) + msg := types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateSystemContract(ctx, msg) require.ErrorIs(t, err, types.ErrContractCall) // fail on second evm call mockEVMKeeper.MockEVMSuccessCallOnce() mockEVMKeeper.MockEVMFailCallOnce() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - // can't update the system contract - _, err = msgServer.UpdateSystemContract(ctx, types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex())) + msg = types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateSystemContract(ctx, msg) require.ErrorIs(t, err, types.ErrContractCall) // fail on third evm call mockEVMKeeper.MockEVMSuccessCallTimes(2) mockEVMKeeper.MockEVMFailCallOnce() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) - // can't update the system contract - _, err = msgServer.UpdateSystemContract(ctx, types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex())) + msg = types.NewMsgUpdateSystemContract(admin, newSystemContract.Hex()) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateSystemContract(ctx, msg) require.ErrorIs(t, err, types.ErrContractCall) }) } diff --git a/x/fungible/keeper/msg_server_update_zrc20_withdraw_fee.go b/x/fungible/keeper/msg_server_update_zrc20_withdraw_fee.go index 67d4821836..479877b84d 100644 --- a/x/fungible/keeper/msg_server_update_zrc20_withdraw_fee.go +++ b/x/fungible/keeper/msg_server_update_zrc20_withdraw_fee.go @@ -20,11 +20,9 @@ func (k msgServer) UpdateZRC20WithdrawFee( ctx := sdk.UnwrapSDKContext(goCtx) // check signer permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, cosmoserrors.Wrap( - authoritytypes.ErrUnauthorized, - "deploy can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // check the zrc20 exists diff --git a/x/fungible/keeper/msg_server_update_zrc20_withdraw_fee_test.go b/x/fungible/keeper/msg_server_update_zrc20_withdraw_fee_test.go index 2d6c9a95a3..fb156a1c36 100644 --- a/x/fungible/keeper/msg_server_update_zrc20_withdraw_fee_test.go +++ b/x/fungible/keeper/msg_server_update_zrc20_withdraw_fee_test.go @@ -41,15 +41,15 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { require.NoError(t, err) require.Zero(t, protocolFee.Uint64()) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can update the protocol fee and gas limit - _, err = msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg := types.NewMsgUpdateZRC20WithdrawFee( admin, zrc20Addr.String(), math.NewUint(42), math.NewUint(42), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.NoError(t, err) // can query the updated fee @@ -60,15 +60,15 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(42), gasLimit.Uint64()) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can update protocol fee only - _, err = msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg = types.NewMsgUpdateZRC20WithdrawFee( admin, zrc20Addr.String(), math.NewUint(43), math.Uint{}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.NoError(t, err) protocolFee, err = k.QueryProtocolFlatFee(ctx, zrc20Addr) require.NoError(t, err) @@ -77,15 +77,15 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(42), gasLimit.Uint64()) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // can update gas limit only - _, err = msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg = types.NewMsgUpdateZRC20WithdrawFee( admin, zrc20Addr.String(), math.Uint{}, math.NewUint(44), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.NoError(t, err) protocolFee, err = k.QueryProtocolFlatFee(ctx, zrc20Addr) require.NoError(t, err) @@ -103,14 +103,15 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - _, err := msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg := types.NewMsgUpdateZRC20WithdrawFee( admin, sample.EthAddress().String(), math.NewUint(42), math.Uint{}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) + _, err := msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -122,14 +123,15 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg := types.NewMsgUpdateZRC20WithdrawFee( admin, "invalid_address", math.NewUint(42), math.Uint{}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.ErrorIs(t, err, sdkerrors.ErrInvalidAddress) }) @@ -141,14 +143,15 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { msgServer := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg := types.NewMsgUpdateZRC20WithdrawFee( admin, sample.EthAddress().String(), math.NewUint(42), math.Uint{}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.ErrorIs(t, err, types.ErrForeignCoinNotFound) }) @@ -163,18 +166,19 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { // setup admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - zrc20 := sample.EthAddress() + k.SetForeignCoins(ctx, sample.ForeignCoins(t, zrc20.String())) // the method shall fail since we only set the foreign coin manually in the store but didn't deploy the contract - _, err := msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg := types.NewMsgUpdateZRC20WithdrawFee( admin, zrc20.String(), math.NewUint(42), math.Uint{}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err := msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.ErrorIs(t, err, types.ErrContractCall) }) @@ -190,7 +194,6 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { // setup admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) zrc20Addr := sample.EthAddress() k.SetForeignCoins(ctx, sample.ForeignCoins(t, zrc20Addr.String())) @@ -217,12 +220,14 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { // this is the update call (commit == true) mockEVMKeeper.MockEVMFailCallOnce() - _, err = msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg := types.NewMsgUpdateZRC20WithdrawFee( admin, zrc20Addr.String(), math.NewUint(42), math.Uint{}, - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.ErrorIs(t, err, types.ErrContractCall) mockEVMKeeper.AssertExpectations(t) @@ -240,7 +245,6 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { // setup admin := sample.AccAddress() authorityMock := keepertest.GetFungibleAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) zrc20Addr := sample.EthAddress() k.SetForeignCoins(ctx, sample.ForeignCoins(t, zrc20Addr.String())) @@ -264,12 +268,14 @@ func TestKeeper_UpdateZRC20WithdrawFee(t *testing.T) { require.NoError(t, err) mockEVMKeeper.MockEVMFailCallOnce() - _, err = msgServer.UpdateZRC20WithdrawFee(ctx, types.NewMsgUpdateZRC20WithdrawFee( + msg := types.NewMsgUpdateZRC20WithdrawFee( admin, zrc20Addr.String(), math.Uint{}, math.NewUint(42), - )) + ) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) + _, err = msgServer.UpdateZRC20WithdrawFee(ctx, msg) require.ErrorIs(t, err, types.ErrContractCall) mockEVMKeeper.AssertExpectations(t) diff --git a/x/fungible/types/expected_keepers.go b/x/fungible/types/expected_keepers.go index 514e7867f3..b2997566ff 100644 --- a/x/fungible/types/expected_keepers.go +++ b/x/fungible/types/expected_keepers.go @@ -13,7 +13,6 @@ import ( evmtypes "github.com/evmos/ethermint/x/evm/types" "github.com/zeta-chain/zetacore/pkg/chains" - authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" ) // AccountKeeper defines the expected account keeper used for simulations (noalias) @@ -60,5 +59,5 @@ type EVMKeeper interface { } type AuthorityKeeper interface { - IsAuthorized(ctx sdk.Context, address string, policyType authoritytypes.PolicyType) bool + CheckAuthorization(ctx sdk.Context, msg sdk.Msg) error } diff --git a/x/lightclient/keeper/msg_server_disable_block_header_verification._test.go b/x/lightclient/keeper/msg_server_disable_block_header_verification._test.go index bfb300a344..c3f9f9d52a 100644 --- a/x/lightclient/keeper/msg_server_disable_block_header_verification._test.go +++ b/x/lightclient/keeper/msg_server_disable_block_header_verification._test.go @@ -39,12 +39,14 @@ func TestMsgServer_DisableVerificationFlags(t *testing.T) { }) // enable eth type chain - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - _, err := srv.DisableHeaderVerification(sdk.WrapSDKContext(ctx), &types.MsgDisableHeaderVerification{ + msg := types.MsgDisableHeaderVerification{ Creator: admin, ChainIdList: []int64{chains.Ethereum.ChainId, chains.BitcoinMainnet.ChainId}, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.DisableHeaderVerification(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) + bhv, found := k.GetBlockHeaderVerification(ctx) require.True(t, found) require.False(t, bhv.IsChainEnabled(chains.Ethereum.ChainId)) @@ -75,11 +77,12 @@ func TestMsgServer_DisableVerificationFlags(t *testing.T) { }, }) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - _, err := srv.DisableHeaderVerification(sdk.WrapSDKContext(ctx), &types.MsgDisableHeaderVerification{ + msg := types.MsgDisableHeaderVerification{ Creator: admin, ChainIdList: []int64{chains.Ethereum.ChainId}, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := srv.DisableHeaderVerification(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -94,11 +97,12 @@ func TestMsgServer_DisableVerificationFlags(t *testing.T) { authorityMock := keepertest.GetLightclientAuthorityMock(t, k) // enable eth type chain - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - _, err := srv.DisableHeaderVerification(sdk.WrapSDKContext(ctx), &types.MsgDisableHeaderVerification{ + msg := types.MsgDisableHeaderVerification{ Creator: admin, ChainIdList: []int64{chains.Ethereum.ChainId, chains.BitcoinMainnet.ChainId}, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.DisableHeaderVerification(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) bhv, found := k.GetBlockHeaderVerification(ctx) require.True(t, found) diff --git a/x/lightclient/keeper/msg_server_disable_block_header_verification.go b/x/lightclient/keeper/msg_server_disable_block_header_verification.go index 023bc17515..c38bac327b 100644 --- a/x/lightclient/keeper/msg_server_disable_block_header_verification.go +++ b/x/lightclient/keeper/msg_server_disable_block_header_verification.go @@ -3,6 +3,7 @@ package keeper import ( "context" + cosmoserrors "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -18,8 +19,10 @@ func (k msgServer) DisableHeaderVerification( ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupEmergency) { - return nil, authoritytypes.ErrUnauthorized + + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } bhv, found := k.GetBlockHeaderVerification(ctx) diff --git a/x/lightclient/keeper/msg_server_enable_block_header_verification.go b/x/lightclient/keeper/msg_server_enable_block_header_verification.go index 84b3ff3a69..0b0613c8e2 100644 --- a/x/lightclient/keeper/msg_server_enable_block_header_verification.go +++ b/x/lightclient/keeper/msg_server_enable_block_header_verification.go @@ -3,6 +3,7 @@ package keeper import ( "context" + cosmoserrors "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -18,8 +19,10 @@ func (k msgServer) EnableHeaderVerification(goCtx context.Context, msg *types.Ms ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return nil, authoritytypes.ErrUnauthorized + + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } bhv, found := k.GetBlockHeaderVerification(ctx) diff --git a/x/lightclient/keeper/msg_server_enable_block_header_verification_test.go b/x/lightclient/keeper/msg_server_enable_block_header_verification_test.go index 2025904b9a..96c5444123 100644 --- a/x/lightclient/keeper/msg_server_enable_block_header_verification_test.go +++ b/x/lightclient/keeper/msg_server_enable_block_header_verification_test.go @@ -39,11 +39,12 @@ func TestMsgServer_EnableVerificationFlags(t *testing.T) { }) // enable both eth and btc type chain together - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := srv.EnableHeaderVerification(sdk.WrapSDKContext(ctx), &types.MsgEnableHeaderVerification{ + msg := types.MsgEnableHeaderVerification{ Creator: admin, ChainIdList: []int64{chains.Ethereum.ChainId, chains.BitcoinMainnet.ChainId}, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.EnableHeaderVerification(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) bhv, found := k.GetBlockHeaderVerification(ctx) require.True(t, found) @@ -62,11 +63,12 @@ func TestMsgServer_EnableVerificationFlags(t *testing.T) { authorityMock := keepertest.GetLightclientAuthorityMock(t, k) // enable both eth and btc type chain together - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := srv.EnableHeaderVerification(sdk.WrapSDKContext(ctx), &types.MsgEnableHeaderVerification{ + msg := types.MsgEnableHeaderVerification{ Creator: admin, ChainIdList: []int64{chains.Ethereum.ChainId, chains.BitcoinMainnet.ChainId}, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.EnableHeaderVerification(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) bhv, found := k.GetBlockHeaderVerification(ctx) require.True(t, found) @@ -97,11 +99,12 @@ func TestMsgServer_EnableVerificationFlags(t *testing.T) { }, }) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - _, err := srv.EnableHeaderVerification(sdk.WrapSDKContext(ctx), &types.MsgEnableHeaderVerification{ + msg := types.MsgEnableHeaderVerification{ Creator: admin, ChainIdList: []int64{chains.Ethereum.ChainId}, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := srv.EnableHeaderVerification(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) } diff --git a/x/lightclient/types/expected_keepers.go b/x/lightclient/types/expected_keepers.go index 57dec2bd11..bee335d628 100644 --- a/x/lightclient/types/expected_keepers.go +++ b/x/lightclient/types/expected_keepers.go @@ -2,10 +2,8 @@ package types import ( sdk "github.com/cosmos/cosmos-sdk/types" - - authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" ) type AuthorityKeeper interface { - IsAuthorized(ctx sdk.Context, address string, policyType authoritytypes.PolicyType) bool + CheckAuthorization(ctx sdk.Context, msg sdk.Msg) error } diff --git a/x/observer/keeper/msg_server_add_observer.go b/x/observer/keeper/msg_server_add_observer.go index 034cd5e22e..11f7701dde 100644 --- a/x/observer/keeper/msg_server_add_observer.go +++ b/x/observer/keeper/msg_server_add_observer.go @@ -21,10 +21,10 @@ func (k msgServer) AddObserver( ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return &types.MsgAddObserverResponse{}, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, cosmoserrors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } - pubkey, err := crypto.NewPubKey(msg.ZetaclientGranteePubkey) if err != nil { return &types.MsgAddObserverResponse{}, cosmoserrors.Wrap(sdkerrors.ErrInvalidPubKey, err.Error()) diff --git a/x/observer/keeper/msg_server_add_observer_test.go b/x/observer/keeper/msg_server_add_observer_test.go index e8006f4e42..b26a2fe5d0 100644 --- a/x/observer/keeper/msg_server_add_observer_test.go +++ b/x/observer/keeper/msg_server_add_observer_test.go @@ -4,7 +4,6 @@ import ( "math" "testing" - "github.com/cometbft/cometbft/crypto" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/require" @@ -22,15 +21,16 @@ func TestMsgServer_AddObserver(t *testing.T) { }) authorityMock := keepertest.GetObserverAuthorityMock(t, k) admin := sample.AccAddress() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) wctx := sdk.WrapSDKContext(ctx) - srv := keeper.NewMsgServerImpl(*k) - res, err := srv.AddObserver(wctx, &types.MsgAddObserver{ + + msg := types.MsgAddObserver{ Creator: admin, - }) - require.Error(t, err) - require.Equal(t, &types.MsgAddObserverResponse{}, res) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + res, err := srv.AddObserver(wctx, &msg) + require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) + require.Nil(t, res) }) t.Run("should error if pub key not valid", func(t *testing.T) { @@ -39,14 +39,15 @@ func TestMsgServer_AddObserver(t *testing.T) { }) authorityMock := keepertest.GetObserverAuthorityMock(t, k) admin := sample.AccAddress() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) wctx := sdk.WrapSDKContext(ctx) - srv := keeper.NewMsgServerImpl(*k) - res, err := srv.AddObserver(wctx, &types.MsgAddObserver{ + + msg := types.MsgAddObserver{ Creator: admin, ZetaclientGranteePubkey: "invalid", - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + res, err := srv.AddObserver(wctx, &msg) require.Error(t, err) require.Equal(t, &types.MsgAddObserverResponse{}, res) }) @@ -57,19 +58,21 @@ func TestMsgServer_AddObserver(t *testing.T) { }) authorityMock := keepertest.GetObserverAuthorityMock(t, k) admin := sample.AccAddress() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + observerAddress := sample.AccAddress() wctx := sdk.WrapSDKContext(ctx) _, found := k.GetLastObserverCount(ctx) require.False(t, found) srv := keeper.NewMsgServerImpl(*k) - observerAddress := sdk.AccAddress(crypto.AddressHash([]byte("ObserverAddress"))) - res, err := srv.AddObserver(wctx, &types.MsgAddObserver{ + + msg := types.MsgAddObserver{ Creator: admin, ZetaclientGranteePubkey: sample.PubKeyString(), AddNodeAccountOnly: false, - ObserverAddress: observerAddress.String(), - }) + ObserverAddress: observerAddress, + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + res, err := srv.AddObserver(wctx, &msg) require.NoError(t, err) require.Equal(t, &types.MsgAddObserverResponse{}, res) @@ -84,24 +87,27 @@ func TestMsgServer_AddObserver(t *testing.T) { }) authorityMock := keepertest.GetObserverAuthorityMock(t, k) admin := sample.AccAddress() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + observerAddress := sample.AccAddress() + wctx := sdk.WrapSDKContext(ctx) _, found := k.GetLastObserverCount(ctx) require.False(t, found) srv := keeper.NewMsgServerImpl(*k) - observerAddress := sdk.AccAddress(crypto.AddressHash([]byte("ObserverAddress"))) + _, found = k.GetKeygen(ctx) require.False(t, found) - _, found = k.GetNodeAccount(ctx, observerAddress.String()) + _, found = k.GetNodeAccount(ctx, observerAddress) require.False(t, found) - res, err := srv.AddObserver(wctx, &types.MsgAddObserver{ + msg := types.MsgAddObserver{ Creator: admin, ZetaclientGranteePubkey: sample.PubKeyString(), AddNodeAccountOnly: true, - ObserverAddress: observerAddress.String(), - }) + ObserverAddress: observerAddress, + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + res, err := srv.AddObserver(wctx, &msg) require.NoError(t, err) require.Equal(t, &types.MsgAddObserverResponse{}, res) @@ -112,7 +118,7 @@ func TestMsgServer_AddObserver(t *testing.T) { require.True(t, found) require.Equal(t, types.Keygen{BlockNumber: math.MaxInt64}, keygen) - _, found = k.GetNodeAccount(ctx, observerAddress.String()) + _, found = k.GetNodeAccount(ctx, observerAddress) require.True(t, found) }) } diff --git a/x/observer/keeper/msg_server_disable_cctx_flags.go b/x/observer/keeper/msg_server_disable_cctx_flags.go index 0645db2cc9..82b1438652 100644 --- a/x/observer/keeper/msg_server_disable_cctx_flags.go +++ b/x/observer/keeper/msg_server_disable_cctx_flags.go @@ -3,6 +3,7 @@ package keeper import ( "context" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -18,12 +19,10 @@ func (k msgServer) DisableCCTX( ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupEmergency) { - return &types.MsgDisableCCTXResponse{}, authoritytypes.ErrUnauthorized.Wrap( - "DisableCCTX can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } - // check if the value exists, // if not, set the default value for the Inbound and Outbound flags only flags, isFound := k.GetCrosschainFlags(ctx) @@ -41,7 +40,7 @@ func (k msgServer) DisableCCTX( k.SetCrosschainFlags(ctx, flags) - err := ctx.EventManager().EmitTypedEvents(&types.EventCCTXDisabled{ + err = ctx.EventManager().EmitTypedEvents(&types.EventCCTXDisabled{ MsgTypeUrl: sdk.MsgTypeURL(&types.MsgDisableCCTX{}), IsInboundEnabled: flags.IsInboundEnabled, IsOutboundEnabled: flags.IsOutboundEnabled, diff --git a/x/observer/keeper/msg_server_disable_cctx_flags_test.go b/x/observer/keeper/msg_server_disable_cctx_flags_test.go index 2dcab71f51..74ec5cb14a 100644 --- a/x/observer/keeper/msg_server_disable_cctx_flags_test.go +++ b/x/observer/keeper/msg_server_disable_cctx_flags_test.go @@ -20,15 +20,15 @@ func TestMsgServer_MsgDisableCCTX(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - msg := &types.MsgDisableCCTX{ + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + + msg := types.MsgDisableCCTX{ Creator: admin, DisableOutbound: true, DisableInbound: true, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - - _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), msg) + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) flags, found := k.GetCrosschainFlags(ctx) @@ -50,15 +50,15 @@ func TestMsgServer_MsgDisableCCTX(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - msg := &types.MsgDisableCCTX{ + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + + msg := types.MsgDisableCCTX{ Creator: admin, DisableOutbound: true, DisableInbound: true, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - - _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), msg) + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) flags, found := k.GetCrosschainFlags(ctx) @@ -80,15 +80,15 @@ func TestMsgServer_MsgDisableCCTX(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - msg := &types.MsgDisableCCTX{ + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + + msg := types.MsgDisableCCTX{ Creator: admin, DisableOutbound: true, DisableInbound: false, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - - _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), msg) + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) flags, found := k.GetCrosschainFlags(ctx) @@ -110,15 +110,15 @@ func TestMsgServer_MsgDisableCCTX(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - msg := &types.MsgDisableCCTX{ + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + + msg := types.MsgDisableCCTX{ Creator: admin, DisableOutbound: false, DisableInbound: true, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - - _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), msg) + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) flags, found := k.GetCrosschainFlags(ctx) @@ -135,15 +135,15 @@ func TestMsgServer_MsgDisableCCTX(t *testing.T) { srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - msg := &types.MsgDisableCCTX{ + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + + msg := types.MsgDisableCCTX{ Creator: admin, DisableOutbound: true, DisableInbound: false, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) - - _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), msg) + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := srv.DisableCCTX(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, authoritytypes.ErrUnauthorized, err) _, found := k.GetCrosschainFlags(ctx) diff --git a/x/observer/keeper/msg_server_enable_cctx_flags.go b/x/observer/keeper/msg_server_enable_cctx_flags.go index a361af6708..4df76a29a8 100644 --- a/x/observer/keeper/msg_server_enable_cctx_flags.go +++ b/x/observer/keeper/msg_server_enable_cctx_flags.go @@ -3,6 +3,7 @@ package keeper import ( "context" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -18,10 +19,9 @@ func (k msgServer) EnableCCTX( ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return &types.MsgEnableCCTXResponse{}, authoritytypes.ErrUnauthorized.Wrap( - "EnableCCTX can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // check if the value exists, @@ -41,7 +41,7 @@ func (k msgServer) EnableCCTX( k.SetCrosschainFlags(ctx, flags) - err := ctx.EventManager().EmitTypedEvents(&types.EventCCTXEnabled{ + err = ctx.EventManager().EmitTypedEvents(&types.EventCCTXEnabled{ MsgTypeUrl: sdk.MsgTypeURL(&types.MsgEnableCCTX{}), IsInboundEnabled: flags.IsInboundEnabled, IsOutboundEnabled: flags.IsOutboundEnabled, diff --git a/x/observer/keeper/msg_server_enable_cctx_flags_test.go b/x/observer/keeper/msg_server_enable_cctx_flags_test.go index 8bb3346f91..d09ea79262 100644 --- a/x/observer/keeper/msg_server_enable_cctx_flags_test.go +++ b/x/observer/keeper/msg_server_enable_cctx_flags_test.go @@ -20,14 +20,14 @@ func TestMsgServer_EnableCCTX(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + msg := &types.MsgEnableCCTX{ Creator: admin, EnableInbound: true, EnableOutbound: true, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) _, err := srv.EnableCCTX(sdk.WrapSDKContext(ctx), msg) require.NoError(t, err) @@ -50,14 +50,14 @@ func TestMsgServer_EnableCCTX(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + msg := &types.MsgEnableCCTX{ Creator: admin, EnableInbound: true, EnableOutbound: true, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) _, err := srv.EnableCCTX(sdk.WrapSDKContext(ctx), msg) require.NoError(t, err) @@ -80,14 +80,14 @@ func TestMsgServer_EnableCCTX(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + msg := &types.MsgEnableCCTX{ Creator: admin, EnableInbound: true, EnableOutbound: false, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) _, err := srv.EnableCCTX(sdk.WrapSDKContext(ctx), msg) require.NoError(t, err) @@ -110,14 +110,14 @@ func TestMsgServer_EnableCCTX(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + msg := &types.MsgEnableCCTX{ Creator: admin, EnableInbound: false, EnableOutbound: true, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, nil) _, err := srv.EnableCCTX(sdk.WrapSDKContext(ctx), msg) require.NoError(t, err) @@ -135,13 +135,14 @@ func TestMsgServer_EnableCCTX(t *testing.T) { srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + msg := &types.MsgEnableCCTX{ Creator: admin, EnableInbound: true, EnableOutbound: false, } - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) + keepertest.MockCheckAuthorization(&authorityMock.Mock, msg, authoritytypes.ErrUnauthorized) _, err := srv.EnableCCTX(sdk.WrapSDKContext(ctx), msg) require.ErrorIs(t, authoritytypes.ErrUnauthorized, err) diff --git a/x/observer/keeper/msg_server_remove_chain_params.go b/x/observer/keeper/msg_server_remove_chain_params.go index 30f7a3e91b..b8e419b121 100644 --- a/x/observer/keeper/msg_server_remove_chain_params.go +++ b/x/observer/keeper/msg_server_remove_chain_params.go @@ -3,6 +3,7 @@ package keeper import ( "context" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -17,10 +18,10 @@ func (k msgServer) RemoveChainParams( ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return &types.MsgRemoveChainParamsResponse{}, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } - // find current core params list or initialize a new one chainParamsList, found := k.GetChainParamsList(ctx) if !found { diff --git a/x/observer/keeper/msg_server_remove_chain_params_test.go b/x/observer/keeper/msg_server_remove_chain_params_test.go index e4d5b457ee..5139f84086 100644 --- a/x/observer/keeper/msg_server_remove_chain_params_test.go +++ b/x/observer/keeper/msg_server_remove_chain_params_test.go @@ -30,7 +30,6 @@ func TestMsgServer_RemoveChainParams(t *testing.T) { // set admin admin := sample.AccAddress() - // add chain params k.SetChainParamsList(ctx, types.ChainParamsList{ ChainParams: []*types.ChainParams{ @@ -41,11 +40,12 @@ func TestMsgServer_RemoveChainParams(t *testing.T) { }) // remove chain params - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &types.MsgRemoveChainParams{ + msg := types.MsgRemoveChainParams{ Creator: admin, ChainId: chain2, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check list has two chain params @@ -55,12 +55,13 @@ func TestMsgServer_RemoveChainParams(t *testing.T) { require.Equal(t, chain1, chainParamsList.ChainParams[0].ChainId) require.Equal(t, chain3, chainParamsList.ChainParams[1].ChainId) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // remove chain params - _, err = srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &types.MsgRemoveChainParams{ + msg = types.MsgRemoveChainParams{ Creator: admin, ChainId: chain1, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err = srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check list has one chain params @@ -69,13 +70,13 @@ func TestMsgServer_RemoveChainParams(t *testing.T) { require.Len(t, chainParamsList.ChainParams, 1) require.Equal(t, chain3, chainParamsList.ChainParams[0].ChainId) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // remove chain params - _, err = srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &types.MsgRemoveChainParams{ + msg = types.MsgRemoveChainParams{ Creator: admin, ChainId: chain3, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err = srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check list has no chain params @@ -92,12 +93,13 @@ func TestMsgServer_RemoveChainParams(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - _, err := srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &types.MsgRemoveChainParams{ + msg := types.MsgRemoveChainParams{ Creator: admin, ChainId: chains.ExternalChainList()[0].ChainId, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -106,20 +108,21 @@ func TestMsgServer_RemoveChainParams(t *testing.T) { UseAuthorityMock: true, }) srv := keeper.NewMsgServerImpl(*k) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) // set admin admin := sample.AccAddress() - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) // not found if no chain params _, found := k.GetChainParamsList(ctx) require.False(t, found) - _, err := srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &types.MsgRemoveChainParams{ + msg := types.MsgRemoveChainParams{ Creator: admin, ChainId: chains.ExternalChainList()[0].ChainId, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, types.ErrChainParamsNotFound) // add chain params @@ -131,13 +134,13 @@ func TestMsgServer_RemoveChainParams(t *testing.T) { }, }) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // not found if chain ID not in list - _, err = srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &types.MsgRemoveChainParams{ + msg = types.MsgRemoveChainParams{ Creator: admin, ChainId: chains.ExternalChainList()[3].ChainId, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err = srv.RemoveChainParams(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, types.ErrChainParamsNotFound) }) } diff --git a/x/observer/keeper/msg_server_reset_chain_nonces.go b/x/observer/keeper/msg_server_reset_chain_nonces.go index 3c03ea6f82..a3ea3ecc81 100644 --- a/x/observer/keeper/msg_server_reset_chain_nonces.go +++ b/x/observer/keeper/msg_server_reset_chain_nonces.go @@ -3,6 +3,7 @@ package keeper import ( "context" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/zeta-chain/zetacore/pkg/chains" @@ -16,8 +17,9 @@ func (k msgServer) ResetChainNonces( msg *types.MsgResetChainNonces, ) (*types.MsgResetChainNoncesResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return &types.MsgResetChainNoncesResponse{}, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } tss, found := k.GetTSS(ctx) diff --git a/x/observer/keeper/msg_server_reset_chain_nonces_test.go b/x/observer/keeper/msg_server_reset_chain_nonces_test.go index 5306b98bcc..53009d9608 100644 --- a/x/observer/keeper/msg_server_reset_chain_nonces_test.go +++ b/x/observer/keeper/msg_server_reset_chain_nonces_test.go @@ -21,17 +21,17 @@ func TestMsgServer_ResetChainNonces(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) chainId := chains.GoerliLocalnet.ChainId - admin := sample.AccAddress() authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - _, err := srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &types.MsgResetChainNonces{ + msg := types.MsgResetChainNonces{ Creator: admin, ChainId: chainId, ChainNonceLow: 1, ChainNonceHigh: 5, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) @@ -43,15 +43,16 @@ func TestMsgServer_ResetChainNonces(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - chainId := chains.GoerliLocalnet.ChainId - _, err := srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &types.MsgResetChainNonces{ + + msg := types.MsgResetChainNonces{ Creator: admin, ChainId: chainId, ChainNonceLow: 1, ChainNonceHigh: 5, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, types.ErrTssNotFound) }) @@ -65,14 +66,16 @@ func TestMsgServer_ResetChainNonces(t *testing.T) { admin := sample.AccAddress() authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &types.MsgResetChainNonces{ + msg := types.MsgResetChainNonces{ Creator: admin, ChainId: 999, ChainNonceLow: 1, ChainNonceHigh: 5, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + + _, err := srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &msg) require.ErrorIs(t, err, types.ErrSupportedChains) }) @@ -85,11 +88,10 @@ func TestMsgServer_ResetChainNonces(t *testing.T) { k.SetTSS(ctx, tss) admin := sample.AccAddress() - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - chainId := chains.GoerliLocalnet.ChainId + nonceLow := 1 + nonceHigh := 5 + authorityMock := keepertest.GetObserverAuthorityMock(t, k) index := chains.GoerliLocalnet.ChainName.String() // check existing chain nonces @@ -99,14 +101,15 @@ func TestMsgServer_ResetChainNonces(t *testing.T) { require.False(t, found) // reset chain nonces - nonceLow := 1 - nonceHigh := 5 - _, err := srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &types.MsgResetChainNonces{ + // Reset nonces to nonceLow and nonceHigh + msg := types.MsgResetChainNonces{ Creator: admin, ChainId: chainId, ChainNonceLow: int64(nonceLow), ChainNonceHigh: int64(nonceHigh), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check updated chain nonces @@ -124,12 +127,15 @@ func TestMsgServer_ResetChainNonces(t *testing.T) { require.Equal(t, int64(nonceHigh), pendingNonces.NonceHigh) // reset nonces back to 0 - _, err = srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &types.MsgResetChainNonces{ + // Reset nonces back to 0 + msg = types.MsgResetChainNonces{ Creator: admin, ChainId: chainId, ChainNonceLow: 0, ChainNonceHigh: 0, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err = srv.ResetChainNonces(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check updated chain nonces diff --git a/x/observer/keeper/msg_server_update_chain_params.go b/x/observer/keeper/msg_server_update_chain_params.go index c7895016fc..3e4e3faf4f 100644 --- a/x/observer/keeper/msg_server_update_chain_params.go +++ b/x/observer/keeper/msg_server_update_chain_params.go @@ -3,6 +3,7 @@ package keeper import ( "context" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -20,8 +21,9 @@ func (k msgServer) UpdateChainParams( ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return &types.MsgUpdateChainParamsResponse{}, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // find current chain params list or initialize a new one diff --git a/x/observer/keeper/msg_server_update_chain_params_test.go b/x/observer/keeper/msg_server_update_chain_params_test.go index d842ab9ac7..fc38200764 100644 --- a/x/observer/keeper/msg_server_update_chain_params_test.go +++ b/x/observer/keeper/msg_server_update_chain_params_test.go @@ -27,20 +27,20 @@ func TestMsgServer_UpdateChainParams(t *testing.T) { // set admin admin := sample.AccAddress() + chainParams1 := sample.ChainParams(chain1) authorityMock := keepertest.GetObserverAuthorityMock(t, k) // check list initially empty _, found := k.GetChainParamsList(ctx) require.False(t, found) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // a new chain params can be added - chainParams1 := sample.ChainParams(chain1) - _, err := srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainParams{ + msg := types.MsgUpdateChainParams{ Creator: admin, ChainParams: chainParams1, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check list has one chain params @@ -48,15 +48,15 @@ func TestMsgServer_UpdateChainParams(t *testing.T) { require.True(t, found) require.Len(t, chainParamsList.ChainParams, 1) require.Equal(t, chainParams1, chainParamsList.ChainParams[0]) - - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + chainParams2 := sample.ChainParams(chain2) // a new chian params can be added - chainParams2 := sample.ChainParams(chain2) - _, err = srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainParams{ + msg = types.MsgUpdateChainParams{ Creator: admin, ChainParams: chainParams2, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err = srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check list has two chain params @@ -65,15 +65,15 @@ func TestMsgServer_UpdateChainParams(t *testing.T) { require.Len(t, chainParamsList.ChainParams, 2) require.Equal(t, chainParams1, chainParamsList.ChainParams[0]) require.Equal(t, chainParams2, chainParamsList.ChainParams[1]) - - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) + chainParams3 := sample.ChainParams(chain3) // a new chain params can be added - chainParams3 := sample.ChainParams(chain3) - _, err = srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainParams{ + msg = types.MsgUpdateChainParams{ Creator: admin, ChainParams: chainParams3, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err = srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check list has three chain params @@ -84,14 +84,14 @@ func TestMsgServer_UpdateChainParams(t *testing.T) { require.Equal(t, chainParams2, chainParamsList.ChainParams[1]) require.Equal(t, chainParams3, chainParamsList.ChainParams[2]) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - // chain params can be updated chainParams2.ConfirmationCount = chainParams2.ConfirmationCount + 1 - _, err = srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainParams{ + msg = types.MsgUpdateChainParams{ Creator: admin, ChainParams: chainParams2, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err = srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) // check list has three chain params @@ -108,15 +108,16 @@ func TestMsgServer_UpdateChainParams(t *testing.T) { UseAuthorityMock: true, }) srv := keeper.NewMsgServerImpl(*k) - admin := sample.AccAddress() authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - _, err := srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &types.MsgUpdateChainParams{ + msg := types.MsgUpdateChainParams{ Creator: admin, ChainParams: sample.ChainParams(chains.ExternalChainList()[0].ChainId), - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := srv.UpdateChainParams(sdk.WrapSDKContext(ctx), &msg) + require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) }) } diff --git a/x/observer/keeper/msg_server_update_gas_price_increase_flags.go b/x/observer/keeper/msg_server_update_gas_price_increase_flags.go index 7402ab0cc7..203e289c43 100644 --- a/x/observer/keeper/msg_server_update_gas_price_increase_flags.go +++ b/x/observer/keeper/msg_server_update_gas_price_increase_flags.go @@ -3,6 +3,7 @@ package keeper import ( "context" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -18,10 +19,9 @@ func (k msgServer) UpdateGasPriceIncreaseFlags( ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupOperational) { - return &types.MsgUpdateGasPriceIncreaseFlagsResponse{}, authoritytypes.ErrUnauthorized.Wrap( - "UpdateGasPriceIncreaseFlags can only be executed by the correct policy account", - ) + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } // check if the value exists, // if not, set the default value for the GasPriceIncreaseFlags only @@ -33,7 +33,7 @@ func (k msgServer) UpdateGasPriceIncreaseFlags( flags.IsOutboundEnabled = false } - err := msg.GasPriceIncreaseFlags.Validate() + err = msg.GasPriceIncreaseFlags.Validate() if err != nil { return &types.MsgUpdateGasPriceIncreaseFlagsResponse{}, err } diff --git a/x/observer/keeper/msg_server_update_gas_price_increase_flags_test.go b/x/observer/keeper/msg_server_update_gas_price_increase_flags_test.go index 63fcec7acf..7fe99dc47d 100644 --- a/x/observer/keeper/msg_server_update_gas_price_increase_flags_test.go +++ b/x/observer/keeper/msg_server_update_gas_price_increase_flags_test.go @@ -21,16 +21,16 @@ func TestKeeper_UpdateGasPriceIncreaseFlags(t *testing.T) { srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() updatedFlags := sample.GasPriceIncreaseFlags() - msg := &types.MsgUpdateGasPriceIncreaseFlags{ - Creator: admin, - GasPriceIncreaseFlags: updatedFlags, - } + authorityMock := keepertest.GetObserverAuthorityMock(t, k) // mock the authority keeper for authorization - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := srv.UpdateGasPriceIncreaseFlags(sdk.WrapSDKContext(ctx), msg) + msg := types.MsgUpdateGasPriceIncreaseFlags{ + Creator: admin, + GasPriceIncreaseFlags: updatedFlags, + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.UpdateGasPriceIncreaseFlags(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) flags, found := k.GetCrosschainFlags(ctx) @@ -49,16 +49,16 @@ func TestKeeper_UpdateGasPriceIncreaseFlags(t *testing.T) { defaultCrosschainFlags := types.DefaultCrosschainFlags() k.SetCrosschainFlags(ctx, *defaultCrosschainFlags) updatedFlags := sample.GasPriceIncreaseFlags() - msg := &types.MsgUpdateGasPriceIncreaseFlags{ - Creator: admin, - GasPriceIncreaseFlags: updatedFlags, - } // mock the authority keeper for authorization authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - _, err := srv.UpdateGasPriceIncreaseFlags(sdk.WrapSDKContext(ctx), msg) + msg := types.MsgUpdateGasPriceIncreaseFlags{ + Creator: admin, + GasPriceIncreaseFlags: updatedFlags, + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.UpdateGasPriceIncreaseFlags(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) flags, found := k.GetCrosschainFlags(ctx) @@ -74,7 +74,11 @@ func TestKeeper_UpdateGasPriceIncreaseFlags(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - msg := &types.MsgUpdateGasPriceIncreaseFlags{ + authorityMock := keepertest.GetObserverAuthorityMock(t, k) + + // mock the authority keeper for authorization + + msg := types.MsgUpdateGasPriceIncreaseFlags{ Creator: admin, GasPriceIncreaseFlags: types.GasPriceIncreaseFlags{ EpochLength: -1, @@ -82,12 +86,8 @@ func TestKeeper_UpdateGasPriceIncreaseFlags(t *testing.T) { GasPriceIncreasePercent: 1, }, } - - // mock the authority keeper for authorization - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, true) - - _, err := srv.UpdateGasPriceIncreaseFlags(sdk.WrapSDKContext(ctx), msg) + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err := srv.UpdateGasPriceIncreaseFlags(sdk.WrapSDKContext(ctx), &msg) require.ErrorContains(t, err, "epoch length must be positive") _, found := k.GetCrosschainFlags(ctx) @@ -101,16 +101,15 @@ func TestKeeper_UpdateGasPriceIncreaseFlags(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - msg := &types.MsgUpdateGasPriceIncreaseFlags{ - Creator: admin, - GasPriceIncreaseFlags: sample.GasPriceIncreaseFlags(), - } // mock the authority keeper for authorization authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupOperational, false) - - _, err := srv.UpdateGasPriceIncreaseFlags(sdk.WrapSDKContext(ctx), msg) + msg := types.MsgUpdateGasPriceIncreaseFlags{ + Creator: admin, + GasPriceIncreaseFlags: sample.GasPriceIncreaseFlags(), + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + _, err := srv.UpdateGasPriceIncreaseFlags(sdk.WrapSDKContext(ctx), &msg) require.ErrorContains(t, err, "sender not authorized") }) } diff --git a/x/observer/keeper/msg_server_update_keygen.go b/x/observer/keeper/msg_server_update_keygen.go index 6ee1faca62..343c8a8019 100644 --- a/x/observer/keeper/msg_server_update_keygen.go +++ b/x/observer/keeper/msg_server_update_keygen.go @@ -3,6 +3,7 @@ package keeper import ( "context" + "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" @@ -20,8 +21,9 @@ func (k msgServer) UpdateKeygen( ctx := sdk.UnwrapSDKContext(goCtx) // check permission - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupEmergency) { - return &types.MsgUpdateKeygenResponse{}, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return nil, errors.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } keygen, found := k.GetKeygen(ctx) diff --git a/x/observer/keeper/msg_server_update_keygen_test.go b/x/observer/keeper/msg_server_update_keygen_test.go index 4cf6237ce3..966f4a2588 100644 --- a/x/observer/keeper/msg_server_update_keygen_test.go +++ b/x/observer/keeper/msg_server_update_keygen_test.go @@ -20,15 +20,16 @@ func TestMsgServer_UpdateKeygen(t *testing.T) { }) authorityMock := keepertest.GetObserverAuthorityMock(t, k) admin := sample.AccAddress() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, false) wctx := sdk.WrapSDKContext(ctx) srv := keeper.NewMsgServerImpl(*k) - res, err := srv.UpdateKeygen(wctx, &types.MsgUpdateKeygen{ + msg := types.MsgUpdateKeygen{ Creator: admin, - }) - require.Error(t, err) - require.Equal(t, &types.MsgUpdateKeygenResponse{}, res) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, authoritytypes.ErrUnauthorized) + res, err := srv.UpdateKeygen(wctx, &msg) + require.ErrorIs(t, err, authoritytypes.ErrUnauthorized) + require.Nil(t, res) }) t.Run("should error if keygen not found", func(t *testing.T) { @@ -37,13 +38,15 @@ func TestMsgServer_UpdateKeygen(t *testing.T) { }) authorityMock := keepertest.GetObserverAuthorityMock(t, k) admin := sample.AccAddress() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) - wctx := sdk.WrapSDKContext(ctx) + wctx := sdk.WrapSDKContext(ctx) srv := keeper.NewMsgServerImpl(*k) - res, err := srv.UpdateKeygen(wctx, &types.MsgUpdateKeygen{ + + msg := types.MsgUpdateKeygen{ Creator: admin, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + res, err := srv.UpdateKeygen(wctx, &msg) require.Error(t, err) require.Nil(t, res) }) @@ -54,17 +57,20 @@ func TestMsgServer_UpdateKeygen(t *testing.T) { }) authorityMock := keepertest.GetObserverAuthorityMock(t, k) admin := sample.AccAddress() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) + wctx := sdk.WrapSDKContext(ctx) item := types.Keygen{ BlockNumber: 10, } k.SetKeygen(ctx, item) srv := keeper.NewMsgServerImpl(*k) - res, err := srv.UpdateKeygen(wctx, &types.MsgUpdateKeygen{ + + msg := types.MsgUpdateKeygen{ Creator: admin, Block: 2, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + res, err := srv.UpdateKeygen(wctx, &msg) require.Error(t, err) require.Nil(t, res) }) @@ -75,7 +81,7 @@ func TestMsgServer_UpdateKeygen(t *testing.T) { }) authorityMock := keepertest.GetObserverAuthorityMock(t, k) admin := sample.AccAddress() - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupEmergency, true) + wctx := sdk.WrapSDKContext(ctx) item := types.Keygen{ BlockNumber: 10, @@ -89,10 +95,12 @@ func TestMsgServer_UpdateKeygen(t *testing.T) { GranteePubkey: granteePubKey, }) - res, err := srv.UpdateKeygen(wctx, &types.MsgUpdateKeygen{ + msg := types.MsgUpdateKeygen{ Creator: admin, Block: ctx.BlockHeight() + 30, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + res, err := srv.UpdateKeygen(wctx, &msg) require.NoError(t, err) require.Equal(t, &types.MsgUpdateKeygenResponse{}, res) diff --git a/x/observer/keeper/msg_server_update_observer.go b/x/observer/keeper/msg_server_update_observer.go index 81e76ac2e2..a1e4c134d1 100644 --- a/x/observer/keeper/msg_server_update_observer.go +++ b/x/observer/keeper/msg_server_update_observer.go @@ -98,8 +98,9 @@ func (k Keeper) CheckUpdateReason(ctx sdk.Context, msg *types.MsgUpdateObserver) case types.ObserverUpdateReason_AdminUpdate: { // Operational policy is required to update an observer for admin update - if !k.GetAuthorityKeeper().IsAuthorized(ctx, msg.Creator, authoritytypes.PolicyType_groupAdmin) { - return false, authoritytypes.ErrUnauthorized + err := k.GetAuthorityKeeper().CheckAuthorization(ctx, msg) + if err != nil { + return false, errorsmod.Wrap(authoritytypes.ErrUnauthorized, err.Error()) } return true, nil } diff --git a/x/observer/keeper/msg_server_update_observer_test.go b/x/observer/keeper/msg_server_update_observer_test.go index afdb4d4907..2a4308590f 100644 --- a/x/observer/keeper/msg_server_update_observer_test.go +++ b/x/observer/keeper/msg_server_update_observer_test.go @@ -12,7 +12,6 @@ import ( keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" - authoritytypes "github.com/zeta-chain/zetacore/x/authority/types" "github.com/zeta-chain/zetacore/x/observer/keeper" "github.com/zeta-chain/zetacore/x/observer/types" ) @@ -325,8 +324,6 @@ func TestMsgServer_UpdateObserver(t *testing.T) { }) srv := keeper.NewMsgServerImpl(*k) admin := sample.AccAddress() - authorityMock := keepertest.GetObserverAuthorityMock(t, k) - keepertest.MockIsAuthorized(&authorityMock.Mock, admin, authoritytypes.PolicyType_groupAdmin, true) // #nosec G404 test purpose - weak randomness is not an issue here r := rand.New(rand.NewSource(9)) @@ -365,13 +362,16 @@ func TestMsgServer_UpdateObserver(t *testing.T) { k.SetLastObserverCount(ctx, &types.LastObserverCount{ Count: count, }) + authorityMock := keepertest.GetObserverAuthorityMock(t, k) - _, err = srv.UpdateObserver(sdk.WrapSDKContext(ctx), &types.MsgUpdateObserver{ + msg := types.MsgUpdateObserver{ Creator: admin, OldObserverAddress: accAddressOfValidator.String(), NewObserverAddress: newOperatorAddress.String(), UpdateReason: types.ObserverUpdateReason_AdminUpdate, - }) + } + keepertest.MockCheckAuthorization(&authorityMock.Mock, &msg, nil) + _, err = srv.UpdateObserver(sdk.WrapSDKContext(ctx), &msg) require.NoError(t, err) acc, found := k.GetNodeAccount(ctx, newOperatorAddress.String()) diff --git a/x/observer/types/expected_keepers.go b/x/observer/types/expected_keepers.go index 63abdd7887..5eb8b03468 100644 --- a/x/observer/types/expected_keepers.go +++ b/x/observer/types/expected_keepers.go @@ -32,7 +32,7 @@ type StakingHooks interface { } type AuthorityKeeper interface { - IsAuthorized(ctx sdk.Context, address string, policyType authoritytypes.PolicyType) bool + CheckAuthorization(ctx sdk.Context, msg sdk.Msg) error // SetPolicies is solely used for the migration of policies from observer to authority SetPolicies(ctx sdk.Context, policies authoritytypes.Policies) From e6287e26bdff17d2aa3895e4492d76fc3a6ca0b0 Mon Sep 17 00:00:00 2001 From: skosito Date: Fri, 21 Jun 2024 03:30:43 +0100 Subject: [PATCH 3/6] refactor: cctx validate inbound (#2340) --- app/app.go | 9 --- changelog.md | 3 +- testutil/keeper/crosschain.go | 8 -- x/crosschain/keeper/cctx_gateway_observers.go | 54 ++++++++----- x/crosschain/keeper/cctx_gateway_zevm.go | 15 ++-- x/crosschain/keeper/cctx_gateways.go | 29 +++++++ .../cctx_orchestrator_validate_inbound.go | 50 +++++++++++++ .../cctx_orchestrator_validate_outbound.go | 2 +- x/crosschain/keeper/cctx_utils.go | 4 +- x/crosschain/keeper/cctx_utils_test.go | 14 ++-- x/crosschain/keeper/evm_deposit.go | 4 +- x/crosschain/keeper/evm_hooks.go | 75 +++++-------------- x/crosschain/keeper/evm_hooks_test.go | 56 ++++++-------- x/crosschain/keeper/initiate_outbound.go | 25 +++++-- x/crosschain/keeper/initiate_outbound_test.go | 54 ++++--------- x/crosschain/keeper/keeper.go | 20 +---- .../keeper/msg_server_migrate_tss_funds.go | 2 +- .../keeper/msg_server_vote_inbound_tx.go | 23 ++---- .../keeper/msg_server_vote_inbound_tx_test.go | 6 +- .../keeper/msg_server_whitelist_erc20.go | 2 +- x/crosschain/types/cctx.go | 2 +- 21 files changed, 229 insertions(+), 228 deletions(-) create mode 100644 x/crosschain/keeper/cctx_gateways.go create mode 100644 x/crosschain/keeper/cctx_orchestrator_validate_inbound.go diff --git a/app/app.go b/app/app.go index 243a0e7665..0688ca75db 100644 --- a/app/app.go +++ b/app/app.go @@ -105,7 +105,6 @@ import ( "github.com/zeta-chain/zetacore/app/ante" "github.com/zeta-chain/zetacore/docs/openapi" - "github.com/zeta-chain/zetacore/pkg/chains" zetamempool "github.com/zeta-chain/zetacore/pkg/mempool" srvflags "github.com/zeta-chain/zetacore/server/flags" authoritymodule "github.com/zeta-chain/zetacore/x/authority" @@ -598,14 +597,6 @@ func New( app.LightclientKeeper, ) - // initializing map of cctx gateways so crosschain module can decide which one to use - // based on chain info of destination chain - cctxGateways := map[chains.CCTXGateway]crosschainkeeper.CCTXGateway{ - chains.CCTXGateway_observers: crosschainkeeper.NewCCTXGatewayObservers(app.CrosschainKeeper), - chains.CCTXGateway_zevm: crosschainkeeper.NewCCTXGatewayZEVM(app.CrosschainKeeper), - } - app.CrosschainKeeper.SetCCTXGateways(cctxGateways) - // initialize ibccrosschain keeper and set it to the crosschain keeper // there is a circular dependency between the two keepers, crosschain keeper must be initialized first diff --git a/changelog.md b/changelog.md index fb156220c4..583acd57c9 100644 --- a/changelog.md +++ b/changelog.md @@ -47,8 +47,9 @@ * [2269](https://github.com/zeta-chain/node/pull/2269) - refactor MsgUpdateCrosschainFlags into MsgEnableCCTX, MsgDisableCCTX and MsgUpdateGasPriceIncreaseFlags * [2306](https://github.com/zeta-chain/node/pull/2306) - refactor zetaclient outbound transaction signing logic * [2296](https://github.com/zeta-chain/node/pull/2296) - move `testdata` package to `testutil` to organize test-related utilities -* [2344](https://github.com/zeta-chain/node/pull/2344) - group common data of EVM/Bitcoin signer and observer using base structs * [2317](https://github.com/zeta-chain/node/pull/2317) - add ValidateOutbound method for cctx orchestrator +* [2340](https://github.com/zeta-chain/node/pull/2340) - add ValidateInbound method for cctx orchestrator +* [2344](https://github.com/zeta-chain/node/pull/2344) - group common data of EVM/Bitcoin signer and observer using base structs ### Tests diff --git a/testutil/keeper/crosschain.go b/testutil/keeper/crosschain.go index 11fda8128f..aec1ac7838 100644 --- a/testutil/keeper/crosschain.go +++ b/testutil/keeper/crosschain.go @@ -174,16 +174,8 @@ func CrosschainKeeperWithMocks( lightclientKeeper, ) - cctxGateways := map[chains.CCTXGateway]keeper.CCTXGateway{ - chains.CCTXGateway_observers: keeper.NewCCTXGatewayObservers(*k), - chains.CCTXGateway_zevm: keeper.NewCCTXGatewayZEVM(*k), - } - - k.SetCCTXGateways(cctxGateways) - // initialize ibccrosschain keeper and set it to the crosschain keeper // there is a circular dependency between the two keepers, crosschain keeper must be initialized first - var ibcCrosschainKeeperTmp types.IBCCrosschainKeeper = initIBCCrosschainKeeper( cdc, db, diff --git a/x/crosschain/keeper/cctx_gateway_observers.go b/x/crosschain/keeper/cctx_gateway_observers.go index e9603b3903..7155e61cd0 100644 --- a/x/crosschain/keeper/cctx_gateway_observers.go +++ b/x/crosschain/keeper/cctx_gateway_observers.go @@ -1,8 +1,11 @@ package keeper import ( + "fmt" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/x/crosschain/types" ) @@ -32,28 +35,45 @@ InitiateOutbound updates the store so observers can use the PendingCCTX query: */ func (c CCTXGatewayObservers) InitiateOutbound( ctx sdk.Context, - cctx *types.CrossChainTx, -) (newCCTXStatus types.CctxStatus) { + config InitiateOutboundConfig, +) (newCCTXStatus types.CctxStatus, err error) { tmpCtx, commit := ctx.CacheContext() - outboundReceiverChainID := cctx.GetCurrentOutboundParam().ReceiverChainId - err := func() error { - err := c.crosschainKeeper.PayGasAndUpdateCctx( - tmpCtx, - outboundReceiverChainID, - cctx, - cctx.InboundParams.Amount, - false, - ) - if err != nil { - return err + outboundReceiverChainID := config.CCTX.GetCurrentOutboundParam().ReceiverChainId + // TODO (https://github.com/zeta-chain/node/issues/1010): workaround for this bug + noEthereumTxEvent := false + if chains.IsZetaChain(config.CCTX.InboundParams.SenderChainId) { + noEthereumTxEvent = true + } + + err = func() error { + // If ShouldPayGas flag is set during ValidateInbound PayGasAndUpdateCctx should be called + // which will set GasPrice and Amount. Otherwise, use median gas price and InboundParams amount. + if config.ShouldPayGas { + err := c.crosschainKeeper.PayGasAndUpdateCctx( + tmpCtx, + outboundReceiverChainID, + config.CCTX, + config.CCTX.InboundParams.Amount, + noEthereumTxEvent, + ) + if err != nil { + return err + } + } else { + gasPrice, found := c.crosschainKeeper.GetMedianGasPriceInUint(ctx, config.CCTX.GetCurrentOutboundParam().ReceiverChainId) + if !found { + return fmt.Errorf("gasprice not found for %d", config.CCTX.GetCurrentOutboundParam().ReceiverChainId) + } + config.CCTX.GetCurrentOutboundParam().GasPrice = gasPrice.String() + config.CCTX.GetCurrentOutboundParam().Amount = config.CCTX.InboundParams.Amount } - return c.crosschainKeeper.UpdateNonce(tmpCtx, outboundReceiverChainID, cctx) + return c.crosschainKeeper.SetObserverOutboundInfo(tmpCtx, outboundReceiverChainID, config.CCTX) }() if err != nil { // do not commit anything here as the CCTX should be aborted - cctx.SetAbort(err.Error()) - return types.CctxStatus_Aborted + config.CCTX.SetAbort(err.Error()) + return types.CctxStatus_Aborted, err } commit() - return types.CctxStatus_PendingOutbound + return types.CctxStatus_PendingOutbound, nil } diff --git a/x/crosschain/keeper/cctx_gateway_zevm.go b/x/crosschain/keeper/cctx_gateway_zevm.go index c6cadf7f8f..3a6f9a8135 100644 --- a/x/crosschain/keeper/cctx_gateway_zevm.go +++ b/x/crosschain/keeper/cctx_gateway_zevm.go @@ -19,20 +19,23 @@ func NewCCTXGatewayZEVM(crosschainKeeper Keeper) CCTXGatewayZEVM { } // InitiateOutbound handles evm deposit and immediately validates pending outbound -func (c CCTXGatewayZEVM) InitiateOutbound(ctx sdk.Context, cctx *types.CrossChainTx) (newCCTXStatus types.CctxStatus) { +func (c CCTXGatewayZEVM) InitiateOutbound( + ctx sdk.Context, + config InitiateOutboundConfig, +) (newCCTXStatus types.CctxStatus, err error) { tmpCtx, commit := ctx.CacheContext() - isContractReverted, err := c.crosschainKeeper.HandleEVMDeposit(tmpCtx, cctx) + isContractReverted, err := c.crosschainKeeper.HandleEVMDeposit(tmpCtx, config.CCTX) if err != nil && !isContractReverted { // exceptional case; internal error; should abort CCTX - cctx.SetAbort(err.Error()) - return types.CctxStatus_Aborted + config.CCTX.SetAbort(err.Error()) + return types.CctxStatus_Aborted, err } - newCCTXStatus = c.crosschainKeeper.ValidateOutboundZEVM(ctx, cctx, err, isContractReverted) + newCCTXStatus = c.crosschainKeeper.ValidateOutboundZEVM(ctx, config.CCTX, err, isContractReverted) if newCCTXStatus == types.CctxStatus_OutboundMined { commit() } - return newCCTXStatus + return newCCTXStatus, nil } diff --git a/x/crosschain/keeper/cctx_gateways.go b/x/crosschain/keeper/cctx_gateways.go new file mode 100644 index 0000000000..9f8e79c0d9 --- /dev/null +++ b/x/crosschain/keeper/cctx_gateways.go @@ -0,0 +1,29 @@ +package keeper + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/zeta-chain/zetacore/pkg/chains" + "github.com/zeta-chain/zetacore/x/crosschain/types" +) + +// CCTXGateway is interface implemented by every gateway. It is one of interfaces used for communication +// between CCTX gateways and crosschain module, and it is called by crosschain module. +type CCTXGateway interface { + // Initiate a new outbound, this tells the CCTXGateway to carry out the action to execute the outbound. + // It is the only entry point to initiate an outbound and it returns new CCTX status after it is completed. + InitiateOutbound(ctx sdk.Context, config InitiateOutboundConfig) (newCCTXStatus types.CctxStatus, err error) +} + +var cctxGateways map[chains.CCTXGateway]CCTXGateway + +// ResolveCCTXGateway respolves cctx gateway implementation based on provided cctx gateway +func ResolveCCTXGateway(c chains.CCTXGateway, keeper Keeper) (CCTXGateway, bool) { + cctxGateways = map[chains.CCTXGateway]CCTXGateway{ + chains.CCTXGateway_observers: NewCCTXGatewayObservers(keeper), + chains.CCTXGateway_zevm: NewCCTXGatewayZEVM(keeper), + } + + cctxGateway, ok := cctxGateways[c] + return cctxGateway, ok +} diff --git a/x/crosschain/keeper/cctx_orchestrator_validate_inbound.go b/x/crosschain/keeper/cctx_orchestrator_validate_inbound.go new file mode 100644 index 0000000000..4bb90ec915 --- /dev/null +++ b/x/crosschain/keeper/cctx_orchestrator_validate_inbound.go @@ -0,0 +1,50 @@ +package keeper + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/zeta-chain/zetacore/x/crosschain/types" + observertypes "github.com/zeta-chain/zetacore/x/observer/types" +) + +// ValidateInbound is the only entry-point to create new CCTX (eg. when observers voting is done or new inbound event is detected). +// It creates new CCTX object and calls InitiateOutbound method. +func (k Keeper) ValidateInbound( + ctx sdk.Context, + msg *types.MsgVoteInbound, + shouldPayGas bool, +) (*types.CrossChainTx, error) { + tss, tssFound := k.zetaObserverKeeper.GetTSS(ctx) + if !tssFound { + return nil, types.ErrCannotFindTSSKeys + } + + // Do not process if inbound is disabled + if !k.zetaObserverKeeper.IsInboundEnabled(ctx) { + return nil, observertypes.ErrInboundDisabled + } + + // create a new CCTX from the inbound message. The status of the new CCTX is set to PendingInbound. + cctx, err := types.NewCCTX(ctx, *msg, tss.TssPubkey) + if err != nil { + return nil, err + } + + // Initiate outbound, the process function manages the state commit and cctx status change. + // If the process fails, the changes to the evm state are rolled back. + _, err = k.InitiateOutbound(ctx, InitiateOutboundConfig{ + CCTX: &cctx, + ShouldPayGas: shouldPayGas, + }) + if err != nil { + return nil, err + } + + inCctxIndex, ok := ctx.Value(InCCTXIndexKey).(string) + if ok { + cctx.InboundParams.ObservedHash = inCctxIndex + } + k.SetCctxAndNonceToCctxAndInboundHashToCctx(ctx, cctx) + + return &cctx, nil +} diff --git a/x/crosschain/keeper/cctx_orchestrator_validate_outbound.go b/x/crosschain/keeper/cctx_orchestrator_validate_outbound.go index 16b61a3a0a..77989b88fe 100644 --- a/x/crosschain/keeper/cctx_orchestrator_validate_outbound.go +++ b/x/crosschain/keeper/cctx_orchestrator_validate_outbound.go @@ -189,7 +189,7 @@ func (k Keeper) validateFailedOutbound( if err != nil { return err } - err = k.UpdateNonce(ctx, cctx.InboundParams.SenderChainId, cctx) + err = k.SetObserverOutboundInfo(ctx, cctx.InboundParams.SenderChainId, cctx) if err != nil { return err } diff --git a/x/crosschain/keeper/cctx_utils.go b/x/crosschain/keeper/cctx_utils.go index f85b243c1c..3d3cf31cd5 100644 --- a/x/crosschain/keeper/cctx_utils.go +++ b/x/crosschain/keeper/cctx_utils.go @@ -16,9 +16,9 @@ import ( zetaObserverTypes "github.com/zeta-chain/zetacore/x/observer/types" ) -// UpdateNonce sets the CCTX outbound nonce to the next nonce, and updates the nonce of blockchain state. +// SetObserverOutboundInfo sets the CCTX outbound nonce to the next available nonce for the TSS address, and updates the nonce of blockchain state. // It also updates the PendingNonces that is used to track the unfulfilled outbound txs. -func (k Keeper) UpdateNonce(ctx sdk.Context, receiveChainID int64, cctx *types.CrossChainTx) error { +func (k Keeper) SetObserverOutboundInfo(ctx sdk.Context, receiveChainID int64, cctx *types.CrossChainTx) error { chain := k.GetObserverKeeper().GetSupportedChainFromChainID(ctx, receiveChainID) if chain == nil { return zetaObserverTypes.ErrSupportedChains diff --git a/x/crosschain/keeper/cctx_utils_test.go b/x/crosschain/keeper/cctx_utils_test.go index edeefabfa1..798fcdd8c4 100644 --- a/x/crosschain/keeper/cctx_utils_test.go +++ b/x/crosschain/keeper/cctx_utils_test.go @@ -226,7 +226,7 @@ func Test_IsPending(t *testing.T) { } } -func TestKeeper_UpdateNonce(t *testing.T) { +func TestKeeper_SetObserverOutboundInfo(t *testing.T) { t.Run("should error if supported chain is nil", func(t *testing.T) { k, ctx, _, _ := keepertest.CrosschainKeeperWithMocks(t, keepertest.CrosschainMockOptions{ UseObserverMock: true, @@ -236,7 +236,7 @@ func TestKeeper_UpdateNonce(t *testing.T) { // mock failed GetSupportedChainFromChainID keepertest.MockFailedGetSupportedChainFromChainID(observerMock, nil) - err := k.UpdateNonce(ctx, 5, nil) + err := k.SetObserverOutboundInfo(ctx, 5, nil) require.Error(t, err) }) @@ -262,7 +262,7 @@ func TestKeeper_UpdateNonce(t *testing.T) { {Amount: sdkmath.NewUint(1)}, }, } - err := k.UpdateNonce(ctx, 5, &cctx) + err := k.SetObserverOutboundInfo(ctx, 5, &cctx) require.Error(t, err) }) @@ -291,7 +291,7 @@ func TestKeeper_UpdateNonce(t *testing.T) { {Amount: sdkmath.NewUint(1)}, }, } - err := k.UpdateNonce(ctx, 5, &cctx) + err := k.SetObserverOutboundInfo(ctx, 5, &cctx) require.Error(t, err) require.Equal(t, uint64(100), cctx.GetCurrentOutboundParam().TssNonce) }) @@ -324,7 +324,7 @@ func TestKeeper_UpdateNonce(t *testing.T) { {Amount: sdkmath.NewUint(1)}, }, } - err := k.UpdateNonce(ctx, 5, &cctx) + err := k.SetObserverOutboundInfo(ctx, 5, &cctx) require.Error(t, err) }) @@ -358,7 +358,7 @@ func TestKeeper_UpdateNonce(t *testing.T) { {Amount: sdkmath.NewUint(1)}, }, } - err := k.UpdateNonce(ctx, 5, &cctx) + err := k.SetObserverOutboundInfo(ctx, 5, &cctx) require.Error(t, err) }) @@ -395,7 +395,7 @@ func TestKeeper_UpdateNonce(t *testing.T) { {Amount: sdkmath.NewUint(1)}, }, } - err := k.UpdateNonce(ctx, 5, &cctx) + err := k.SetObserverOutboundInfo(ctx, 5, &cctx) require.NoError(t, err) }) } diff --git a/x/crosschain/keeper/evm_deposit.go b/x/crosschain/keeper/evm_deposit.go index 927e9f5eb1..858c709f18 100644 --- a/x/crosschain/keeper/evm_deposit.go +++ b/x/crosschain/keeper/evm_deposit.go @@ -18,6 +18,8 @@ import ( fungibletypes "github.com/zeta-chain/zetacore/x/fungible/types" ) +const InCCTXIndexKey = "inCctxIndex" + // HandleEVMDeposit handles a deposit from an inbound tx // returns (isContractReverted, err) // (true, non-nil) means CallEVM() reverted @@ -102,7 +104,7 @@ func (k Keeper) HandleEVMDeposit(ctx sdk.Context, cctx *types.CrossChainTx) (boo if !evmTxResponse.Failed() && contractCall { logs := evmtypes.LogsToEthereum(evmTxResponse.Logs) if len(logs) > 0 { - ctx = ctx.WithValue("inCctxIndex", cctx.Index) + ctx = ctx.WithValue(InCCTXIndexKey, cctx.Index) txOrigin := cctx.InboundParams.TxOrigin if txOrigin == "" { txOrigin = inboundSender diff --git a/x/crosschain/keeper/evm_hooks.go b/x/crosschain/keeper/evm_hooks.go index 863bdd991b..dbc2d2dbb1 100644 --- a/x/crosschain/keeper/evm_hooks.go +++ b/x/crosschain/keeper/evm_hooks.go @@ -3,6 +3,7 @@ package keeper import ( "encoding/base64" "encoding/hex" + "errors" "fmt" "math/big" @@ -76,6 +77,7 @@ func (k Keeper) ProcessLogs( if connectorZEVMAddr == (ethcommon.Address{}) { return fmt.Errorf("connectorZEVM address is empty") } + for _, log := range logs { eventZrc20Withdrawal, errZrc20 := ParseZRC20WithdrawalEvent(*log) eventZetaSent, errZetaSent := ParseZetaSentEvent(*log, connectorZEVMAddr) @@ -90,18 +92,6 @@ func (k Keeper) ProcessLogs( continue } - // We have found either eventZrc20Withdrawal or eventZetaSent - // These cannot be processed without TSS keys, return an error if TSS is not found - tss, found := k.zetaObserverKeeper.GetTSS(ctx) - if !found { - return errorsmod.Wrap(types.ErrCannotFindTSSKeys, "Cannot process logs without TSS keys") - } - - // Do not process withdrawal events if inbound is disabled - if !k.zetaObserverKeeper.IsInboundEnabled(ctx) { - return observertypes.ErrInboundDisabled - } - // if eventZrc20Withdrawal is not nil we will try to validate it and see if it can be processed if eventZrc20Withdrawal != nil { // Check if the contract is a registered ZRC20 contract. If its not a registered ZRC20 contract, we can discard this event as it is not relevant @@ -119,13 +109,13 @@ func (k Keeper) ProcessLogs( } // If the event is valid, we will process it and create a new CCTX // If the process fails, we will return an error and roll back the transaction - if err := k.ProcessZRC20WithdrawalEvent(ctx, eventZrc20Withdrawal, emittingContract, txOrigin, tss); err != nil { + if err := k.ProcessZRC20WithdrawalEvent(ctx, eventZrc20Withdrawal, emittingContract, txOrigin); err != nil { return err } } // if eventZetaSent is not nil we will try to validate it and see if it can be processed if eventZetaSent != nil { - if err := k.ProcessZetaSentEvent(ctx, eventZetaSent, emittingContract, txOrigin, tss); err != nil { + if err := k.ProcessZetaSentEvent(ctx, eventZetaSent, emittingContract, txOrigin); err != nil { return err } } @@ -140,9 +130,7 @@ func (k Keeper) ProcessZRC20WithdrawalEvent( event *zrc20.ZRC20Withdrawal, emittingContract ethcommon.Address, txOrigin string, - tss observertypes.TSS, ) error { - ctx.Logger().Info(fmt.Sprintf("ZRC20 withdrawal to %s amount %d", hex.EncodeToString(event.To), event.Value)) foreignCoin, found := k.fungibleKeeper.GetForeignCoins(ctx, event.Raw.Address.Hex()) if !found { @@ -188,22 +176,18 @@ func (k Keeper) ProcessZRC20WithdrawalEvent( event.Raw.Index, ) - // Create a new cctx with status as pending Inbound, this is created directly from the event without waiting for any observer votes - cctx, err := types.NewCCTX(ctx, *msg, tss.TssPubkey) + cctx, err := k.ValidateInbound(ctx, msg, false) if err != nil { - return fmt.Errorf("ProcessZRC20WithdrawalEvent: failed to initialize cctx: %s", err.Error()) + return err } - cctx.SetPendingOutbound("ZRC20 withdrawal event setting to pending outbound directly") - // Get gas price and amount - gasprice, found := k.GetGasPrice(ctx, receiverChain.ChainId) - if !found { - return fmt.Errorf("gasprice not found for %s", receiverChain) + + if cctx.CctxStatus.Status == types.CctxStatus_Aborted { + return errors.New("cctx aborted") } - cctx.GetCurrentOutboundParam().GasPrice = fmt.Sprintf("%d", gasprice.Prices[gasprice.MedianIndex]) - cctx.GetCurrentOutboundParam().Amount = cctx.InboundParams.Amount - EmitZRCWithdrawCreated(ctx, cctx) - return k.ProcessCCTX(ctx, cctx, receiverChain) + EmitZRCWithdrawCreated(ctx, *cctx) + + return nil } func (k Keeper) ProcessZetaSentEvent( @@ -211,7 +195,6 @@ func (k Keeper) ProcessZetaSentEvent( event *connectorzevm.ZetaConnectorZEVMZetaSent, emittingContract ethcommon.Address, txOrigin string, - tss observertypes.TSS, ) error { ctx.Logger().Info(fmt.Sprintf( "Zeta withdrawal to %s amount %d to chain with chainId %d", @@ -267,40 +250,16 @@ func (k Keeper) ProcessZetaSentEvent( event.Raw.Index, ) - // create a new cctx with status as pending Inbound, - // this is created directly from the event without waiting for any observer votes - cctx, err := types.NewCCTX(ctx, *msg, tss.TssPubkey) + cctx, err := k.ValidateInbound(ctx, msg, true) if err != nil { - return fmt.Errorf("ProcessZetaSentEvent: failed to initialize cctx: %s", err.Error()) - } - cctx.SetPendingOutbound("ZetaSent event setting to pending outbound directly") - - if err := k.PayGasAndUpdateCctx( - ctx, - receiverChain.ChainId, - &cctx, - amount, - true, - ); err != nil { - return fmt.Errorf("ProcessWithdrawalEvent: pay gas failed: %s", err.Error()) - } - - EmitZetaWithdrawCreated(ctx, cctx) - return k.ProcessCCTX(ctx, cctx, receiverChain) -} - -func (k Keeper) ProcessCCTX(ctx sdk.Context, cctx types.CrossChainTx, receiverChain *chains.Chain) error { - inCctxIndex, ok := ctx.Value("inCctxIndex").(string) - if ok { - cctx.InboundParams.ObservedHash = inCctxIndex + return err } - if err := k.UpdateNonce(ctx, receiverChain.ChainId, &cctx); err != nil { - return fmt.Errorf("ProcessWithdrawalEvent: update nonce failed: %s", err.Error()) + if cctx.CctxStatus.Status == types.CctxStatus_Aborted { + return errors.New("cctx aborted") } - k.SetCctxAndNonceToCctxAndInboundHashToCctx(ctx, cctx) - ctx.Logger().Debug("ProcessCCTX successful \n") + EmitZetaWithdrawCreated(ctx, *cctx) return nil } diff --git a/x/crosschain/keeper/evm_hooks_test.go b/x/crosschain/keeper/evm_hooks_test.go index 80f8c3b28b..bc67f34a99 100644 --- a/x/crosschain/keeper/evm_hooks_test.go +++ b/x/crosschain/keeper/evm_hooks_test.go @@ -211,9 +211,8 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { event.Raw.Address = zrc20 emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) require.NoError(t, err) cctxList := k.GetAllCrossChainTx(ctx) require.Len(t, cctxList, 1) @@ -237,9 +236,8 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { event.Raw.Address = zrc20 emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) require.NoError(t, err) cctxList := k.GetAllCrossChainTx(ctx) require.Len(t, cctxList, 1) @@ -262,9 +260,8 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { setupGasCoin(t, ctx, zk.FungibleKeeper, sdkk.EvmKeeper, chainID, "ethereum", "ETH") emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "cannot find foreign coin with emittingContract address") require.Empty(t, k.GetAllCrossChainTx(ctx)) }) @@ -283,9 +280,8 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { event.Raw.Address = zrc20 emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "chain not supported") require.Empty(t, k.GetAllCrossChainTx(ctx)) }) @@ -305,10 +301,10 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { event.Raw.Address = zrc20 emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() + ctx = ctx.WithChainID("test_21-1") - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "failed to convert chainID: chain 21 not found") require.Empty(t, k.GetAllCrossChainTx(ctx)) }) @@ -329,9 +325,8 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { event.To = ethcommon.Address{}.Bytes() emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "cannot encode address") require.Empty(t, k.GetAllCrossChainTx(ctx)) }) @@ -354,13 +349,13 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { event.Raw.Address = zrc20 emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() + fc, _ := zk.FungibleKeeper.GetForeignCoins(ctx, zrc20.Hex()) fungibleMock.On("GetForeignCoins", mock.Anything, mock.Anything).Return(fc, true) fungibleMock.On("QueryGasLimit", mock.Anything, mock.Anything). Return(big.NewInt(0), fmt.Errorf("error querying gas limit")) - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "error querying gas limit") require.Empty(t, k.GetAllCrossChainTx(ctx)) }) @@ -381,9 +376,8 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { event.Raw.Address = zrc20 emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "gasprice not found") require.Empty(t, k.GetAllCrossChainTx(ctx)) }) @@ -408,10 +402,9 @@ func TestKeeper_ProcessZRC20WithdrawalEvent(t *testing.T) { event.Raw.Address = zrc20 emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) - require.ErrorContains(t, err, "ProcessWithdrawalEvent: update nonce failed") + err = k.ProcessZRC20WithdrawalEvent(ctx, event, emittingContract, txOrigin.Hex()) + require.ErrorContains(t, err, "nonce mismatch") require.Empty(t, k.GetAllCrossChainTx(ctx)) }) } @@ -489,9 +482,8 @@ func TestKeeper_ProcessZetaSentEvent(t *testing.T) { require.NoError(t, err) emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex()) require.NoError(t, err) cctxList := k.GetAllCrossChainTx(ctx) require.Len(t, cctxList, 1) @@ -526,9 +518,8 @@ func TestKeeper_ProcessZetaSentEvent(t *testing.T) { require.NoError(t, err) emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "ProcessZetaSentEvent: failed to burn coins from fungible") }) @@ -557,8 +548,8 @@ func TestKeeper_ProcessZetaSentEvent(t *testing.T) { require.NoError(t, err) emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + + err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "chain not supported") }) @@ -589,9 +580,9 @@ func TestKeeper_ProcessZetaSentEvent(t *testing.T) { require.NoError(t, err) emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() + ctx = ctx.WithChainID("test-21-1") - err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) + err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex()) require.ErrorContains(t, err, "ProcessZetaSentEvent: failed to convert chainID") }) @@ -619,10 +610,9 @@ func TestKeeper_ProcessZetaSentEvent(t *testing.T) { require.NoError(t, err) emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) - require.ErrorContains(t, err, "ProcessWithdrawalEvent: pay gas failed") + err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex()) + require.ErrorContains(t, err, "gas coin contract invalid address") }) t.Run("unable to process ZetaSentEvent if process cctx fails", func(t *testing.T) { @@ -658,9 +648,9 @@ func TestKeeper_ProcessZetaSentEvent(t *testing.T) { require.NoError(t, err) emittingContract := sample.EthAddress() txOrigin := sample.EthAddress() - tss := sample.Tss() - err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex(), tss) - require.ErrorContains(t, err, "ProcessWithdrawalEvent: update nonce failed") + + err = k.ProcessZetaSentEvent(ctx, event, emittingContract, txOrigin.Hex()) + require.ErrorContains(t, err, "nonce mismatch") }) } diff --git a/x/crosschain/keeper/initiate_outbound.go b/x/crosschain/keeper/initiate_outbound.go index dc3fcf4c24..954db18778 100644 --- a/x/crosschain/keeper/initiate_outbound.go +++ b/x/crosschain/keeper/initiate_outbound.go @@ -10,14 +10,23 @@ import ( "github.com/zeta-chain/zetacore/x/crosschain/types" ) +// TODO (https://github.com/zeta-chain/node/issues/2345): this is just a tmp solution because some flows require gas payment and others don't. +// TBD during implementation of issue above if info can be passed to CCTX constructor somehow. +// and not initialize CCTX using MsgVoteInbound and instead use something like (InboundParams, OutboundParams). +// Also check if msg.Digest can be replaced to calculate index +type InitiateOutboundConfig struct { + CCTX *types.CrossChainTx + ShouldPayGas bool +} + // InitiateOutbound initiates the outbound for the CCTX depending on the CCTX gateway. // It does a conditional dispatch to correct CCTX gateway based on the receiver chain // which handles the state changes and error handling. -func (k Keeper) InitiateOutbound(ctx sdk.Context, cctx *types.CrossChainTx) (types.CctxStatus, error) { - receiverChainID := cctx.GetCurrentOutboundParam().ReceiverChainId +func (k Keeper) InitiateOutbound(ctx sdk.Context, config InitiateOutboundConfig) (types.CctxStatus, error) { + receiverChainID := config.CCTX.GetCurrentOutboundParam().ReceiverChainId chainInfo := chains.GetChainFromChainID(receiverChainID) if chainInfo == nil { - return cctx.CctxStatus.Status, cosmoserrors.Wrap( + return config.CCTX.CctxStatus.Status, cosmoserrors.Wrap( types.ErrInitiatitingOutbound, fmt.Sprintf( "chain info not found for %d", receiverChainID, @@ -25,9 +34,9 @@ func (k Keeper) InitiateOutbound(ctx sdk.Context, cctx *types.CrossChainTx) (typ ) } - cctxGateway, ok := k.cctxGateways[chainInfo.CctxGateway] - if !ok { - return cctx.CctxStatus.Status, cosmoserrors.Wrap( + cctxGateway, found := ResolveCCTXGateway(chainInfo.CctxGateway, k) + if !found { + return config.CCTX.CctxStatus.Status, cosmoserrors.Wrap( types.ErrInitiatitingOutbound, fmt.Sprintf( "CCTXGateway not defined for receiver chain %d", receiverChainID, @@ -35,6 +44,6 @@ func (k Keeper) InitiateOutbound(ctx sdk.Context, cctx *types.CrossChainTx) (typ ) } - cctx.SetPendingOutbound("") - return cctxGateway.InitiateOutbound(ctx, cctx), nil + config.CCTX.SetPendingOutbound("") + return cctxGateway.InitiateOutbound(ctx, config) } diff --git a/x/crosschain/keeper/initiate_outbound_test.go b/x/crosschain/keeper/initiate_outbound_test.go index 25032e4a09..78175a76a8 100644 --- a/x/crosschain/keeper/initiate_outbound_test.go +++ b/x/crosschain/keeper/initiate_outbound_test.go @@ -43,7 +43,7 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { cctx.GetInboundParams().Amount = sdkmath.NewUintFromBigInt(amount) cctx.InboundParams.CoinType = coin.CoinType_Zeta cctx.GetInboundParams().SenderChainId = 0 - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_OutboundMined, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_OutboundMined, newStatus) @@ -72,8 +72,8 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { cctx.GetInboundParams().Amount = sdkmath.NewUintFromBigInt(amount) cctx.InboundParams.CoinType = coin.CoinType_Zeta cctx.GetInboundParams().SenderChainId = 0 - newStatus, err := k.InitiateOutbound(ctx, cctx) - require.NoError(t, err) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) + require.ErrorContains(t, err, "deposit error") require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) require.Equal(t, "deposit error", cctx.CctxStatus.StatusMessage) @@ -105,7 +105,7 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *senderChain, "", amount) cctx.GetCurrentOutboundParam().ReceiverChainId = chains.ZetaChainPrivnet.ChainId - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) @@ -145,7 +145,7 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *senderChain, asset, amount) cctx.GetCurrentOutboundParam().ReceiverChainId = chains.ZetaChainPrivnet.ChainId - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) @@ -188,7 +188,7 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *senderChain, asset, amount) cctx.GetCurrentOutboundParam().ReceiverChainId = chains.ZetaChainPrivnet.ChainId - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) @@ -233,7 +233,7 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *senderChain, asset, amount) cctx.GetCurrentOutboundParam().ReceiverChainId = chains.ZetaChainPrivnet.ChainId - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) @@ -280,7 +280,7 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *senderChain, asset, amount) cctx.GetCurrentOutboundParam().ReceiverChainId = chains.ZetaChainPrivnet.ChainId - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) @@ -320,7 +320,7 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *senderChain, asset, amount) cctx.GetCurrentOutboundParam().ReceiverChainId = chains.ZetaChainPrivnet.ChainId - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_PendingRevert, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_PendingRevert, newStatus) @@ -358,7 +358,7 @@ func TestKeeper_InitiateOutboundZEVMDeposit(t *testing.T) { cctx := GetERC20Cctx(t, receiver, *senderChain, asset, amount) cctx.GetCurrentOutboundParam().ReceiverChainId = chains.ZetaChainPrivnet.ChainId cctx.OutboundParams = append(cctx.OutboundParams, cctx.GetCurrentOutboundParam()) - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) @@ -393,7 +393,7 @@ func TestKeeper_InitiateOutboundProcessCrosschainMsgPassing(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *receiverChain, "", amount) - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.NoError(t, err) require.Equal(t, types.CctxStatus_PendingOutbound, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_PendingOutbound, newStatus) @@ -417,8 +417,8 @@ func TestKeeper_InitiateOutboundProcessCrosschainMsgPassing(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *receiverChain, "", amount) - newStatus, err := k.InitiateOutbound(ctx, cctx) - require.NoError(t, err) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) + require.ErrorIs(t, err, observertypes.ErrSupportedChains) require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) require.Equal(t, observertypes.ErrSupportedChains.Error(), cctx.CctxStatus.StatusMessage) @@ -446,8 +446,8 @@ func TestKeeper_InitiateOutboundProcessCrosschainMsgPassing(t *testing.T) { // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *receiverChain, "", amount) - newStatus, err := k.InitiateOutbound(ctx, cctx) - require.NoError(t, err) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) + require.ErrorContains(t, err, "cannot find receiver chain nonce") require.Equal(t, types.CctxStatus_Aborted, cctx.CctxStatus.Status) require.Equal(t, types.CctxStatus_Aborted, newStatus) require.Contains(t, cctx.CctxStatus.StatusMessage, "cannot find receiver chain nonce") @@ -468,31 +468,9 @@ func TestKeeper_InitiateOutboundFailures(t *testing.T) { receiverChain.ChainId = 123 // call InitiateOutbound cctx := GetERC20Cctx(t, receiver, *receiverChain, "", amount) - newStatus, err := k.InitiateOutbound(ctx, cctx) + newStatus, err := k.InitiateOutbound(ctx, keeper.InitiateOutboundConfig{CCTX: cctx, ShouldPayGas: true}) require.Error(t, err) require.Equal(t, types.CctxStatus_PendingInbound, newStatus) require.ErrorContains(t, err, "chain info not found") }) - - t.Run("should fail if cctx gateway not found for receiver chain id", func(t *testing.T) { - k, ctx, _, _ := keepertest.CrosschainKeeperWithMocks(t, keepertest.CrosschainMockOptions{ - UseFungibleMock: true, - UseObserverMock: true, - }) - - // reset cctx gateways - k.SetCCTXGateways(map[chains.CCTXGateway]keeper.CCTXGateway{}) - - // Setup mock data - receiver := sample.EthAddress() - amount := big.NewInt(42) - receiverChain := getValidEthChain() - // call InitiateOutbound - cctx := GetERC20Cctx(t, receiver, *receiverChain, "", amount) - newStatus, err := k.InitiateOutbound(ctx, cctx) - require.Equal(t, types.CctxStatus_PendingInbound, newStatus) - require.NotNil(t, err) - require.ErrorContains(t, err, "CCTXGateway not defined for receiver chain") - }) - } diff --git a/x/crosschain/keeper/keeper.go b/x/crosschain/keeper/keeper.go index b32543126c..fc483689b3 100644 --- a/x/crosschain/keeper/keeper.go +++ b/x/crosschain/keeper/keeper.go @@ -8,24 +8,14 @@ import ( storetypes "github.com/cosmos/cosmos-sdk/store/types" sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/x/crosschain/types" ) -// CCTXGateway is interface implemented by every gateway. It is one of interfaces used for communication -// between CCTX gateways and crosschain module, and it is called by crosschain module. -type CCTXGateway interface { - // Initiate a new outbound, this tells the CCTXGateway to carry out the action to execute the outbound. - // It is the only entry point to initiate an outbound and it returns new CCTX status after it is completed. - InitiateOutbound(ctx sdk.Context, cctx *types.CrossChainTx) (newCCTXStatus types.CctxStatus) -} - type ( Keeper struct { - cdc codec.Codec - storeKey storetypes.StoreKey - memKey storetypes.StoreKey - cctxGateways map[chains.CCTXGateway]CCTXGateway + cdc codec.Codec + storeKey storetypes.StoreKey + memKey storetypes.StoreKey stakingKeeper types.StakingKeeper authKeeper types.AccountKeeper @@ -110,10 +100,6 @@ func (k *Keeper) SetIBCCrosschainKeeper(ibcCrosschainKeeper types.IBCCrosschainK k.ibcCrosschainKeeper = ibcCrosschainKeeper } -func (k *Keeper) SetCCTXGateways(cctxGateways map[chains.CCTXGateway]CCTXGateway) { - k.cctxGateways = cctxGateways -} - func (k Keeper) GetStoreKey() storetypes.StoreKey { return k.storeKey } diff --git a/x/crosschain/keeper/msg_server_migrate_tss_funds.go b/x/crosschain/keeper/msg_server_migrate_tss_funds.go index 1bff3b8443..05f218162f 100644 --- a/x/crosschain/keeper/msg_server_migrate_tss_funds.go +++ b/x/crosschain/keeper/msg_server_migrate_tss_funds.go @@ -200,7 +200,7 @@ func (k Keeper) MigrateTSSFundsForChain( return errorsmod.Wrap(types.ErrReceiverIsEmpty, fmt.Sprintf("chain %d is not supported", chainID)) } - err := k.UpdateNonce(ctx, chainID, &cctx) + err := k.SetObserverOutboundInfo(ctx, chainID, &cctx) if err != nil { return err } diff --git a/x/crosschain/keeper/msg_server_vote_inbound_tx.go b/x/crosschain/keeper/msg_server_vote_inbound_tx.go index b118d2bc43..ecf87ef9fd 100644 --- a/x/crosschain/keeper/msg_server_vote_inbound_tx.go +++ b/x/crosschain/keeper/msg_server_vote_inbound_tx.go @@ -97,27 +97,18 @@ func (k msgServer) VoteInbound( if !finalized { return &types.MsgVoteInboundResponse{}, nil } - tss, tssFound := k.zetaObserverKeeper.GetTSS(ctx) - if !tssFound { - return nil, types.ErrCannotFindTSSKeys - } - // create a new CCTX from the inbound message.The status of the new CCTX is set to PendingInbound. - cctx, err := types.NewCCTX(ctx, *msg, tss.TssPubkey) - if err != nil { - return nil, err - } - // Initiate outbound, the process function manages the state commit and cctx status change. - // If the process fails, the changes to the evm state are rolled back. - _, err = k.InitiateOutbound(ctx, &cctx) + + cctx, err := k.ValidateInbound(ctx, msg, true) if err != nil { return nil, err } + // Save the inbound CCTX to the store. This is called irrespective of the status of the CCTX or the outcome of the process function. - k.SaveInbound(ctx, &cctx, msg.EventIndex) + k.SaveObservedInboundInformation(ctx, cctx, msg.EventIndex) return &types.MsgVoteInboundResponse{}, nil } -/* SaveInbound saves the inbound CCTX to the store.It does the following: +/* SaveObservedInboundInformation saves the inbound CCTX to the store.It does the following: - Emits an event for the finalized inbound CCTX. - Adds the inbound CCTX to the finalized inbound CCTX store.This is done to prevent double spending, using the same inbound tx hash and event index. - Updates the CCTX with the finalized height and finalization status. @@ -125,7 +116,7 @@ func (k msgServer) VoteInbound( - Sets the CCTX and nonce to the CCTX and inbound transaction hash to CCTX store. */ -func (k Keeper) SaveInbound(ctx sdk.Context, cctx *types.CrossChainTx, eventIndex uint64) { +func (k Keeper) SaveObservedInboundInformation(ctx sdk.Context, cctx *types.CrossChainTx, eventIndex uint64) { EmitEventInboundFinalized(ctx, cctx) k.AddFinalizedInbound(ctx, cctx.GetInboundParams().ObservedHash, @@ -135,5 +126,5 @@ func (k Keeper) SaveInbound(ctx sdk.Context, cctx *types.CrossChainTx, eventInde cctx.InboundParams.FinalizedZetaHeight = uint64(ctx.BlockHeight()) cctx.InboundParams.TxFinalizationStatus = types.TxFinalizationStatus_Executed k.RemoveInboundTrackerIfExists(ctx, cctx.InboundParams.SenderChainId, cctx.InboundParams.ObservedHash) - k.SetCctxAndNonceToCctxAndInboundHashToCctx(ctx, *cctx) + k.SetCrossChainTx(ctx, *cctx) } diff --git a/x/crosschain/keeper/msg_server_vote_inbound_tx_test.go b/x/crosschain/keeper/msg_server_vote_inbound_tx_test.go index 332a47f270..64a2153611 100644 --- a/x/crosschain/keeper/msg_server_vote_inbound_tx_test.go +++ b/x/crosschain/keeper/msg_server_vote_inbound_tx_test.go @@ -300,7 +300,7 @@ func TestStatus_ChangeStatus(t *testing.T) { } } -func TestKeeper_SaveInbound(t *testing.T) { +func TestKeeper_SaveObservedInboundInformation(t *testing.T) { t.Run("should save the cctx", func(t *testing.T) { k, ctx, _, zk := keepertest.CrosschainKeeper(t) zk.ObserverKeeper.SetTSS(ctx, sample.Tss()) @@ -309,7 +309,7 @@ func TestKeeper_SaveInbound(t *testing.T) { senderChain := getValidEthChain() cctx := GetERC20Cctx(t, receiver, *senderChain, "", amount) eventIndex := sample.Uint64InRange(1, 100) - k.SaveInbound(ctx, cctx, eventIndex) + k.SaveObservedInboundInformation(ctx, cctx, eventIndex) require.Equal(t, types.TxFinalizationStatus_Executed, cctx.InboundParams.TxFinalizationStatus) require.True( t, @@ -340,7 +340,7 @@ func TestKeeper_SaveInbound(t *testing.T) { eventIndex := sample.Uint64InRange(1, 100) zk.ObserverKeeper.SetTSS(ctx, sample.Tss()) - k.SaveInbound(ctx, cctx, eventIndex) + k.SaveObservedInboundInformation(ctx, cctx, eventIndex) require.Equal(t, types.TxFinalizationStatus_Executed, cctx.InboundParams.TxFinalizationStatus) require.True( t, diff --git a/x/crosschain/keeper/msg_server_whitelist_erc20.go b/x/crosschain/keeper/msg_server_whitelist_erc20.go index c8fbd1248b..04ceb4a1a4 100644 --- a/x/crosschain/keeper/msg_server_whitelist_erc20.go +++ b/x/crosschain/keeper/msg_server_whitelist_erc20.go @@ -162,7 +162,7 @@ func (k msgServer) WhitelistERC20( }, }, } - err = k.UpdateNonce(ctx, msg.ChainId, &cctx) + err = k.SetObserverOutboundInfo(ctx, msg.ChainId, &cctx) if err != nil { return nil, err } diff --git a/x/crosschain/types/cctx.go b/x/crosschain/types/cctx.go index f99fbd6cbf..a4080cc968 100644 --- a/x/crosschain/types/cctx.go +++ b/x/crosschain/types/cctx.go @@ -168,7 +168,7 @@ func GetCctxIndexFromBytes(sendHash [32]byte) string { return fmt.Sprintf("0x%s", hex.EncodeToString(sendHash[:])) } -// NewCCTX creates a new CCTX.From a MsgVoteInbound message and a TSS pubkey. +// NewCCTX creates a new CCTX from a MsgVoteInbound message and a TSS pubkey. // It also validates the created cctx func NewCCTX(ctx sdk.Context, msg MsgVoteInbound, tssPubkey string) (CrossChainTx, error) { index := msg.Digest() From f95f0e9e752f3ee6be7304562b9c04f2569670a0 Mon Sep 17 00:00:00 2001 From: Tanmay Date: Fri, 21 Jun 2024 04:11:15 -0400 Subject: [PATCH 4/6] test: add stateful end to end test (#2360) --- Makefile | 5 + changelog.md | 1 + cmd/zetacored/parse_genesis.go | 46 ++++++-- cmd/zetacored/parse_genesis_test.go | 103 +++++++++++++----- .../localnet/docker-compose-import-data.yml | 30 +++++ .../localnet/orchestrator/start-zetae2e.sh | 6 + contrib/localnet/scripts/import-data.sh | 15 +++ contrib/localnet/scripts/start-rosetta.sh | 14 +++ contrib/localnet/scripts/start-zetaclientd.sh | 9 ++ contrib/localnet/scripts/start-zetacored.sh | 8 ++ .../zetacored/zetacored_parse-genesis-file.md | 3 +- 11 files changed, 197 insertions(+), 43 deletions(-) create mode 100644 contrib/localnet/docker-compose-import-data.yml create mode 100644 contrib/localnet/scripts/import-data.sh create mode 100644 contrib/localnet/scripts/start-rosetta.sh diff --git a/Makefile b/Makefile index a97bde5c3a..0520f1f319 100644 --- a/Makefile +++ b/Makefile @@ -255,6 +255,11 @@ start-localnet-skip-build: cd contrib/localnet/ && $(DOCKER) compose -f docker-compose.yml -f docker-compose-setup-only.yml up -d stop-localnet: +start-e2e-import-mainnet-test: zetanode + @echo "--> Starting e2e import-data test" + cd contrib/localnet/ && ./scripts/import-data.sh mainnet && $(DOCKER) compose -f docker-compose.yml -f docker-compose-import-data.yml up -d + +stop-test: cd contrib/localnet/ && $(DOCKER) compose down --remove-orphans ############################################################################### diff --git a/changelog.md b/changelog.md index 583acd57c9..176355afc1 100644 --- a/changelog.md +++ b/changelog.md @@ -63,6 +63,7 @@ * [2329](https://github.com/zeta-chain/node/pull/2329) - fix TODOs in rpc unit tests * [2342](https://github.com/zeta-chain/node/pull/2342) - extend rpc unit tests with testing extension to include synthetic ethereum txs * [2299](https://github.com/zeta-chain/node/pull/2299) - add `zetae2e` command to deploy test contracts +* [2360](https://github.com/zeta-chain/node/pull/2360) - add stateful e2e tests. * [2349](https://github.com/zeta-chain/node/pull/2349) - add TestBitcoinDepositRefund and WithdrawBitcoinMultipleTimes E2E tests ### Fixes diff --git a/cmd/zetacored/parse_genesis.go b/cmd/zetacored/parse_genesis.go index b2a4258995..ec0a9bd313 100644 --- a/cmd/zetacored/parse_genesis.go +++ b/cmd/zetacored/parse_genesis.go @@ -41,7 +41,6 @@ const MaxItemsForList = 10 // Copy represents a set of modules for which, the entire state is copied without any modifications var Copy = map[string]bool{ slashingtypes.ModuleName: true, - govtypes.ModuleName: true, crisistypes.ModuleName: true, feemarkettypes.ModuleName: true, paramstypes.ModuleName: true, @@ -50,24 +49,37 @@ var Copy = map[string]bool{ vestingtypes.ModuleName: true, fungibletypes.ModuleName: true, emissionstypes.ModuleName: true, - authz.ModuleName: true, } // Skip represents a set of modules for which, the entire state is skipped and nothing gets imported var Skip = map[string]bool{ - evmtypes.ModuleName: true, - stakingtypes.ModuleName: true, - genutiltypes.ModuleName: true, - authtypes.ModuleName: true, - banktypes.ModuleName: true, + // Skipping evm this is done to reduce the size of the genesis file evm module uses the majority of the space due to smart contract data + evmtypes.ModuleName: true, + // Skipping staking as new validators would be created for the new chain + stakingtypes.ModuleName: true, + // Skipping genutil as new gentxs would be created + genutiltypes.ModuleName: true, + // Skipping auth as new accounts would be created for the new chain. This also needs to be done as we are skipping evm module + authtypes.ModuleName: true, + // Skipping bank module as it is not used when starting a new chain this is done to make sure the total supply invariant is maintained. + // This would need modification but might be possible to add in non evm based modules in the future + banktypes.ModuleName: true, + // Skipping distribution module as it is not used when starting a new chain , rewards are based on validators and delegators , and so rewards from a different chain do not hold any value distributiontypes.ModuleName: true, - group.ModuleName: true, + // Skipping group module as it is not used when starting a new chain, new groups should be created based on the validator operator keys + group.ModuleName: true, + // Skipping authz as it is not used when starting a new chain, new grants should be created based on the validator hotkeys abd operator keys + authz.ModuleName: true, + // Skipping fungible module as new fungible tokens would be created and system contract would be deployed + fungibletypes.ModuleName: true, + // Skipping gov types as new parameters are set for the new chain + govtypes.ModuleName: true, } // Modify represents a set of modules for which, the state is modified before importing. Each Module should have a corresponding Modify function var Modify = map[string]bool{ - crosschaintypes.ModuleName: true, observertypes.ModuleName: true, + crosschaintypes.ModuleName: true, } func CmdParseGenesisFile() *cobra.Command { @@ -78,6 +90,10 @@ func CmdParseGenesisFile() *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { clientCtx := client.GetClientContextFromCmd(cmd) cdc := clientCtx.Codec + modifyEnabled, err := cmd.Flags().GetBool("modify") + if err != nil { + return err + } genesisFilePath := filepath.Join(app.DefaultNodeHome, "config", "genesis.json") if len(args) == 2 { genesisFilePath = args[1] @@ -90,7 +106,7 @@ func CmdParseGenesisFile() *cobra.Command { if err != nil { return err } - err = ImportDataIntoFile(genDoc, importData, cdc) + err = ImportDataIntoFile(genDoc, importData, cdc, modifyEnabled) if err != nil { return err } @@ -103,10 +119,16 @@ func CmdParseGenesisFile() *cobra.Command { return nil }, } + cmd.PersistentFlags().Bool("modify", false, "Modify the genesis file before importing") return cmd } -func ImportDataIntoFile(genDoc *types.GenesisDoc, importFile *types.GenesisDoc, cdc codec.Codec) error { +func ImportDataIntoFile( + genDoc *types.GenesisDoc, + importFile *types.GenesisDoc, + cdc codec.Codec, + modifyEnabled bool, +) error { appState, err := genutiltypes.GenesisStateFromGenDoc(*genDoc) if err != nil { @@ -124,7 +146,7 @@ func ImportDataIntoFile(genDoc *types.GenesisDoc, importFile *types.GenesisDoc, if Copy[m] { appState[m] = importAppState[m] } - if Modify[m] { + if Modify[m] && modifyEnabled { switch m { case crosschaintypes.ModuleName: err := ModifyCrosschainState(appState, importAppState, cdc) diff --git a/cmd/zetacored/parse_genesis_test.go b/cmd/zetacored/parse_genesis_test.go index 7eb75f4a0f..689075eb32 100644 --- a/cmd/zetacored/parse_genesis_test.go +++ b/cmd/zetacored/parse_genesis_test.go @@ -92,43 +92,86 @@ func Test_ModifyObserverState(t *testing.T) { } func Test_ImportDataIntoFile(t *testing.T) { - setConfig(t) - cdc := keepertest.NewCodec() - genDoc := sample.GenDoc(t) - importGenDoc := ImportGenDoc(t, cdc, 100) + t.Run("successfully import data into file and modify data", func(t *testing.T) { + setConfig(t) + cdc := keepertest.NewCodec() + genDoc := sample.GenDoc(t) + importGenDoc := ImportGenDoc(t, cdc, 100) - err := zetacored.ImportDataIntoFile(genDoc, importGenDoc, cdc) - require.NoError(t, err) + err := zetacored.ImportDataIntoFile(genDoc, importGenDoc, cdc, true) + require.NoError(t, err) - appState, err := genutiltypes.GenesisStateFromGenDoc(*genDoc) - require.NoError(t, err) + appState, err := genutiltypes.GenesisStateFromGenDoc(*genDoc) + require.NoError(t, err) - // Crosschain module is in Modify list - crosschainStateAfterImport := crosschaintypes.GetGenesisStateFromAppState(cdc, appState) - require.Len(t, crosschainStateAfterImport.CrossChainTxs, zetacored.MaxItemsForList) - require.Len(t, crosschainStateAfterImport.InboundHashToCctxList, zetacored.MaxItemsForList) - require.Len(t, crosschainStateAfterImport.FinalizedInbounds, zetacored.MaxItemsForList) + // Crosschain module is in Modify list + crosschainStateAfterImport := crosschaintypes.GetGenesisStateFromAppState(cdc, appState) + require.Len(t, crosschainStateAfterImport.CrossChainTxs, zetacored.MaxItemsForList) + require.Len(t, crosschainStateAfterImport.InboundHashToCctxList, zetacored.MaxItemsForList) + require.Len(t, crosschainStateAfterImport.FinalizedInbounds, zetacored.MaxItemsForList) - // Bank module is in Skip list - var bankStateAfterImport banktypes.GenesisState - if appState[banktypes.ModuleName] != nil { - err := cdc.UnmarshalJSON(appState[banktypes.ModuleName], &bankStateAfterImport) - if err != nil { - panic(fmt.Sprintf("Failed to get genesis state from app state: %s", err.Error())) + // Bank module is in Skip list + var bankStateAfterImport banktypes.GenesisState + if appState[banktypes.ModuleName] != nil { + err := cdc.UnmarshalJSON(appState[banktypes.ModuleName], &bankStateAfterImport) + if err != nil { + panic(fmt.Sprintf("Failed to get genesis state from app state: %s", err.Error())) + } } - } - // 4 balances were present in the original genesis state - require.Len(t, bankStateAfterImport.Balances, 4) + // 4 balances were present in the original genesis state + require.Len(t, bankStateAfterImport.Balances, 4) - // Emissions module is in Copy list - var emissionStateAfterImport emissionstypes.GenesisState - if appState[emissionstypes.ModuleName] != nil { - err := cdc.UnmarshalJSON(appState[emissionstypes.ModuleName], &emissionStateAfterImport) - if err != nil { - panic(fmt.Sprintf("Failed to get genesis state from app state: %s", err.Error())) + // Emissions module is in Copy list + var emissionStateAfterImport emissionstypes.GenesisState + if appState[emissionstypes.ModuleName] != nil { + err := cdc.UnmarshalJSON(appState[emissionstypes.ModuleName], &emissionStateAfterImport) + if err != nil { + panic(fmt.Sprintf("Failed to get genesis state from app state: %s", err.Error())) + } } - } - require.Len(t, emissionStateAfterImport.WithdrawableEmissions, 100) + require.Len(t, emissionStateAfterImport.WithdrawableEmissions, 100) + }) + + t.Run("successfully import data into file without modifying data", func(t *testing.T) { + setConfig(t) + cdc := keepertest.NewCodec() + genDoc := sample.GenDoc(t) + importGenDoc := ImportGenDoc(t, cdc, 8) + + err := zetacored.ImportDataIntoFile(genDoc, importGenDoc, cdc, false) + require.NoError(t, err) + + appState, err := genutiltypes.GenesisStateFromGenDoc(*genDoc) + require.NoError(t, err) + + // Crosschain module is in Modify list + crosschainStateAfterImport := crosschaintypes.GetGenesisStateFromAppState(cdc, appState) + require.Len(t, crosschainStateAfterImport.CrossChainTxs, 0) + require.Len(t, crosschainStateAfterImport.InboundHashToCctxList, 0) + require.Len(t, crosschainStateAfterImport.FinalizedInbounds, 0) + + // Bank module is in Skip list + var bankStateAfterImport banktypes.GenesisState + if appState[banktypes.ModuleName] != nil { + err := cdc.UnmarshalJSON(appState[banktypes.ModuleName], &bankStateAfterImport) + if err != nil { + panic(fmt.Sprintf("Failed to get genesis state from app state: %s", err.Error())) + } + } + // 4 balances were present in the original genesis state + require.Len(t, bankStateAfterImport.Balances, 4) + + // Emissions module is in Copy list + var emissionStateAfterImport emissionstypes.GenesisState + if appState[emissionstypes.ModuleName] != nil { + err := cdc.UnmarshalJSON(appState[emissionstypes.ModuleName], &emissionStateAfterImport) + if err != nil { + panic(fmt.Sprintf("Failed to get genesis state from app state: %s", err.Error())) + } + } + require.Len(t, emissionStateAfterImport.WithdrawableEmissions, 8) + + }) } func Test_GetGenDoc(t *testing.T) { diff --git a/contrib/localnet/docker-compose-import-data.yml b/contrib/localnet/docker-compose-import-data.yml new file mode 100644 index 0000000000..19f4068ad5 --- /dev/null +++ b/contrib/localnet/docker-compose-import-data.yml @@ -0,0 +1,30 @@ +version: "3" + +# This docker-compose file configures the localnet environment +# it contains the following services: +# - ZetaChain with 2 nodes (zetacore0, zetacore1) +# - A observer set with 2 clients (zetaclient0, zetaclient1) +# - An Ethereum node (eth) +# - A Bitcoin node (bitcoin) +# - A Rosetta API (rosetta) +# - An orchestrator to manage interaction with the localnet (orchestrator) +services: + rosetta: + entrypoint: ["/root/start-rosetta.sh"] + + zetacore0: + entrypoint: ["/root/start-zetacored.sh", "2","import-data"] + volumes: + - ~/genesis_export/:/root/genesis_data + + zetacore1: + entrypoint: ["/root/start-zetacored.sh", "2","import-data"] + + zetaclient0: + entrypoint: ["/root/start-zetaclientd.sh"] + + zetaclient1: + entrypoint: ["/root/start-zetaclientd.sh"] + + orchestrator: + entrypoint: ["/work/start-zetae2e.sh", "local"] \ No newline at end of file diff --git a/contrib/localnet/orchestrator/start-zetae2e.sh b/contrib/localnet/orchestrator/start-zetae2e.sh index a297be595b..8dcb4ea271 100644 --- a/contrib/localnet/orchestrator/start-zetae2e.sh +++ b/contrib/localnet/orchestrator/start-zetae2e.sh @@ -29,6 +29,12 @@ while [ ! -f ~/.ssh/authorized_keys ]; do sleep 1 done +# need to wait for zetacore0 to be up +while ! curl -s -o /dev/null zetacore0:26657/status ; do + echo "Waiting for zetacore0 rpc" + sleep 10 +done + echo "waiting for geth RPC to start..." sleep 2 diff --git a/contrib/localnet/scripts/import-data.sh b/contrib/localnet/scripts/import-data.sh new file mode 100644 index 0000000000..d71d5c3656 --- /dev/null +++ b/contrib/localnet/scripts/import-data.sh @@ -0,0 +1,15 @@ +#!/bin/bash +if [ $# -lt 1 ] +then + echo "Usage: import-data.sh [network]" + exit 1 +fi + +NETWORK=$1 +echo "NETWORK: ${NETWORK}" +rm -rf ~/genesis_export/ +mkdir ~/genesis_export/ +echo "Download Latest State Export" +LATEST_EXPORT_URL=$(curl https://snapshots.zetachain.com/latest-state-export | jq -r ."${NETWORK}") +echo "LATEST EXPORT URL: ${LATEST_EXPORT_URL}" +wget -q ${LATEST_EXPORT_URL} -O ~/genesis_export/exported-genesis.json \ No newline at end of file diff --git a/contrib/localnet/scripts/start-rosetta.sh b/contrib/localnet/scripts/start-rosetta.sh new file mode 100644 index 0000000000..e675da6a6a --- /dev/null +++ b/contrib/localnet/scripts/start-rosetta.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# This script is used to start the Rosetta API server for the Zetacore network. + +echo "Waiting for network to start producing blocks" +CURRENT_HEIGHT=0 +WAIT_HEIGHT=1 +while [[ $CURRENT_HEIGHT -lt $WAIT_HEIGHT ]] +do + CURRENT_HEIGHT=$(curl -s zetacore0:26657/status | jq '.result.sync_info.latest_block_height' | tr -d '"') + sleep 5 +done + +zetacored rosetta --tendermint zetacore0:26657 --grpc zetacore0:9090 --network athens_101-1 --blockchain zetacore \ No newline at end of file diff --git a/contrib/localnet/scripts/start-zetaclientd.sh b/contrib/localnet/scripts/start-zetaclientd.sh index 01667df450..1b018de4ef 100755 --- a/contrib/localnet/scripts/start-zetaclientd.sh +++ b/contrib/localnet/scripts/start-zetaclientd.sh @@ -28,6 +28,15 @@ while [ ! -f ~/.ssh/authorized_keys ]; do sleep 1 done + + +# need to wait for zetacore0 to be up +while ! curl -s -o /dev/null zetacore0:26657/status ; do + echo "Waiting for zetacore0 rpc" + sleep 10 +done + + # read HOTKEY_BACKEND env var for hotkey keyring backend and set default to test BACKEND="test" if [ "$HOTKEY_BACKEND" == "file" ]; then diff --git a/contrib/localnet/scripts/start-zetacored.sh b/contrib/localnet/scripts/start-zetacored.sh index 8e16ffbc4a..d390b99083 100755 --- a/contrib/localnet/scripts/start-zetacored.sh +++ b/contrib/localnet/scripts/start-zetacored.sh @@ -4,6 +4,8 @@ # It initializes the nodes and creates the genesis.json file # It also starts the nodes # The number of nodes is passed as an first argument to the script +# The second argument is optional and can have the following value: +# - import-data: import data into the genesis file /usr/sbin/sshd @@ -71,6 +73,7 @@ then exit 1 fi NUMOFNODES=$1 +OPTION=$2 # create keys CHAINID="athens_101-1" @@ -254,6 +257,11 @@ then scp $NODE:~/.zetacored/config/gentx/* ~/.zetacored/config/gentx/z2gentx/ done + if [[ "$OPTION" == "import-data" || "$OPTION" == "import-data-upgrade" ]]; then + echo "Importing data" + zetacored parse-genesis-file /root/genesis_data/exported-genesis.json + fi + # 4. Collect all the gentx files in zetacore0 and create the final genesis.json zetacored collect-gentxs zetacored validate-genesis diff --git a/docs/cli/zetacored/zetacored_parse-genesis-file.md b/docs/cli/zetacored/zetacored_parse-genesis-file.md index 44bcb41424..a7b66ebb73 100644 --- a/docs/cli/zetacored/zetacored_parse-genesis-file.md +++ b/docs/cli/zetacored/zetacored_parse-genesis-file.md @@ -9,7 +9,8 @@ zetacored parse-genesis-file [import-genesis-file] [optional-genesis-file] [flag ### Options ``` - -h, --help help for parse-genesis-file + -h, --help help for parse-genesis-file + --modify Modify the genesis file before importing ``` ### Options inherited from parent commands From 2946eddb5d9f6db59f2668f55ece7fd8cd56482c Mon Sep 17 00:00:00 2001 From: Charlie Chen <34498985+ws4charlie@users.noreply.github.com> Date: Fri, 21 Jun 2024 10:03:01 -0500 Subject: [PATCH 5/6] fix: set 1000 sats as minimum amount that can be withdrawn (#2362) * set 1000 sats as minimum amount that can be withdrawn * added changelog entry * allow 1000 satoshis to be withdrawn * added minimum withdraw amount to error message * define an error for invalid withdrawal amount * fix unit test * fix code format --- changelog.md | 1 + x/crosschain/keeper/evm_hooks.go | 14 ++++++++++---- x/crosschain/keeper/evm_hooks_test.go | 16 ++++++++++++---- x/crosschain/types/errors.go | 1 + 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/changelog.md b/changelog.md index 176355afc1..42de5f2e74 100644 --- a/changelog.md +++ b/changelog.md @@ -75,6 +75,7 @@ * [2243](https://github.com/zeta-chain/node/pull/2243) - fix incorrect bitcoin outbound height in the CCTX outbound parameter * [2256](https://github.com/zeta-chain/node/pull/2256) - fix rate limiter falsely included reverted non-withdraw cctxs * [2327](https://github.com/zeta-chain/node/pull/2327) - partially cherry picked the fix to Bitcoin outbound dust amount +* [2362](https://github.com/zeta-chain/node/pull/2362) - set 1000 satoshis as minimum BTC amount that can be withdrawn from zEVM ### CI diff --git a/x/crosschain/keeper/evm_hooks.go b/x/crosschain/keeper/evm_hooks.go index dbc2d2dbb1..2fc6310800 100644 --- a/x/crosschain/keeper/evm_hooks.go +++ b/x/crosschain/keeper/evm_hooks.go @@ -21,6 +21,7 @@ import ( "github.com/zeta-chain/zetacore/cmd/zetacored/config" "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/pkg/coin" + "github.com/zeta-chain/zetacore/pkg/constant" "github.com/zeta-chain/zetacore/x/crosschain/types" fungibletypes "github.com/zeta-chain/zetacore/x/fungible/types" observertypes "github.com/zeta-chain/zetacore/x/observer/types" @@ -286,15 +287,20 @@ func ValidateZrc20WithdrawEvent(event *zrc20.ZRC20Withdrawal, chainID int64) err // The event was parsed; that means the user has deposited tokens to the contract. if chains.IsBitcoinChain(chainID) { - if event.Value.Cmp(big.NewInt(0)) <= 0 { - return fmt.Errorf("ParseZRC20WithdrawalEvent: invalid amount %s", event.Value.String()) + if event.Value.Cmp(big.NewInt(constant.BTCWithdrawalDustAmount)) < 0 { + return errorsmod.Wrapf( + types.ErrInvalidWithdrawalAmount, + "withdraw amount %s is less than minimum amount %d", + event.Value.String(), + constant.BTCWithdrawalDustAmount, + ) } addr, err := chains.DecodeBtcAddress(string(event.To), chainID) if err != nil { - return fmt.Errorf("ParseZRC20WithdrawalEvent: invalid address %s: %s", event.To, err) + return errorsmod.Wrapf(types.ErrInvalidAddress, "invalid address %s", string(event.To)) } if !chains.IsBtcAddressSupported(addr) { - return fmt.Errorf("ParseZRC20WithdrawalEvent: unsupported address %s", string(event.To)) + return errorsmod.Wrapf(types.ErrInvalidAddress, "unsupported address %s", string(event.To)) } } return nil diff --git a/x/crosschain/keeper/evm_hooks_test.go b/x/crosschain/keeper/evm_hooks_test.go index bc67f34a99..20b1c87e47 100644 --- a/x/crosschain/keeper/evm_hooks_test.go +++ b/x/crosschain/keeper/evm_hooks_test.go @@ -15,6 +15,7 @@ import ( "github.com/zeta-chain/zetacore/cmd/zetacored/config" "github.com/zeta-chain/zetacore/pkg/chains" + "github.com/zeta-chain/zetacore/pkg/constant" keepertest "github.com/zeta-chain/zetacore/testutil/keeper" "github.com/zeta-chain/zetacore/testutil/sample" crosschainkeeper "github.com/zeta-chain/zetacore/x/crosschain/keeper" @@ -164,14 +165,21 @@ func TestValidateZrc20WithdrawEvent(t *testing.T) { require.NoError(t, err) }) - t.Run("unable to validate a event with an invalid amount", func(t *testing.T) { + t.Run("unable to validate a btc withdrawal event with an invalid amount", func(t *testing.T) { btcMainNetWithdrawalEvent, err := crosschainkeeper.ParseZRC20WithdrawalEvent( *sample.GetValidZRC20WithdrawToBTC(t).Logs[3], ) require.NoError(t, err) - btcMainNetWithdrawalEvent.Value = big.NewInt(0) + + // 1000 satoshis is the minimum amount that can be withdrawn + btcMainNetWithdrawalEvent.Value = big.NewInt(constant.BTCWithdrawalDustAmount) err = crosschainkeeper.ValidateZrc20WithdrawEvent(btcMainNetWithdrawalEvent, chains.BitcoinMainnet.ChainId) - require.ErrorContains(t, err, "ParseZRC20WithdrawalEvent: invalid amount") + require.NoError(t, err) + + // 999 satoshis cannot be withdrawn + btcMainNetWithdrawalEvent.Value = big.NewInt(constant.BTCWithdrawalDustAmount - 1) + err = crosschainkeeper.ValidateZrc20WithdrawEvent(btcMainNetWithdrawalEvent, chains.BitcoinMainnet.ChainId) + require.ErrorContains(t, err, "less than minimum amount") }) t.Run("unable to validate a event with an invalid chain ID", func(t *testing.T) { @@ -822,7 +830,7 @@ func TestKeeper_ProcessLogs(t *testing.T) { } err := k.ProcessLogs(ctx, block.Logs, sample.EthAddress(), "") - require.ErrorContains(t, err, "ParseZRC20WithdrawalEvent: invalid address") + require.ErrorContains(t, err, "invalid address") cctxList := k.GetAllCrossChainTx(ctx) require.Len(t, cctxList, 0) }) diff --git a/x/crosschain/types/errors.go b/x/crosschain/types/errors.go index 232bf229db..6f430e5e1a 100644 --- a/x/crosschain/types/errors.go +++ b/x/crosschain/types/errors.go @@ -49,4 +49,5 @@ var ( ErrInvalidRateLimiterFlags = errorsmod.Register(ModuleName, 1152, "invalid rate limiter flags") ErrMaxTxOutTrackerHashesReached = errorsmod.Register(ModuleName, 1153, "max tx out tracker hashes reached") ErrInitiatitingOutbound = errorsmod.Register(ModuleName, 1154, "cannot initiate outbound") + ErrInvalidWithdrawalAmount = errorsmod.Register(ModuleName, 1155, "invalid withdrawal amount") ) From f9ca2be3111b5351925e694c5ec828cb68944394 Mon Sep 17 00:00:00 2001 From: Charlie Chen <34498985+ws4charlie@users.noreply.github.com> Date: Fri, 21 Jun 2024 10:20:00 -0500 Subject: [PATCH 6/6] refactor: integrated base signer structure into existing EVM/BTC signers (#2357) * save local new files to remote * initiated base observer * move base to chains folder * moved logger to base package * added base signer and logger * added changelog entry * integrated base signer into evm/bitcoin; integrated base observer into evm * integrated base observer to evm and bitcoin chain * added changelog entry * cherry pick base Signer structure integration * updated PR number in changelog * updated PR number in changelog * Update changelog.md Co-authored-by: Lucas Bertrand * move Mutex to base signer; improve log print and unit test * fix code format --------- Co-authored-by: Lucas Bertrand --- changelog.md | 1 + cmd/zetaclientd/debug.go | 5 - cmd/zetaclientd/utils.go | 17 ++- testutil/sample/os.go | 15 ++ zetaclient/chains/base/logger_test.go | 9 +- zetaclient/chains/base/observer_test.go | 25 ++-- zetaclient/chains/base/signer.go | 12 ++ zetaclient/chains/base/signer_test.go | 4 + zetaclient/chains/bitcoin/signer/signer.go | 78 +++++----- .../bitcoin/signer/signer_keysign_test.go | 3 +- .../chains/bitcoin/signer/signer_test.go | 28 ++-- zetaclient/chains/evm/signer/signer.go | 141 +++++++++--------- zetaclient/chains/evm/signer/signer_test.go | 13 +- zetaclient/chains/interfaces/interfaces.go | 5 +- zetaclient/testutils/mocks/btc_rpc.go | 23 ++- zetaclient/testutils/mocks/evm_rpc.go | 64 +++++++- zetaclient/testutils/mocks/tss_signer.go | 13 +- zetaclient/tss/tss_signer.go | 12 +- 18 files changed, 291 insertions(+), 177 deletions(-) create mode 100644 testutil/sample/os.go diff --git a/changelog.md b/changelog.md index 42de5f2e74..3e1299e55d 100644 --- a/changelog.md +++ b/changelog.md @@ -50,6 +50,7 @@ * [2317](https://github.com/zeta-chain/node/pull/2317) - add ValidateOutbound method for cctx orchestrator * [2340](https://github.com/zeta-chain/node/pull/2340) - add ValidateInbound method for cctx orchestrator * [2344](https://github.com/zeta-chain/node/pull/2344) - group common data of EVM/Bitcoin signer and observer using base structs +* [2357](https://github.com/zeta-chain/node/pull/2357) - integrate base Signer structure into EVM/Bitcoin Signer ### Tests diff --git a/cmd/zetaclientd/debug.go b/cmd/zetaclientd/debug.go index 881997d0ec..4e10b24c25 100644 --- a/cmd/zetaclientd/debug.go +++ b/cmd/zetaclientd/debug.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "io" "strconv" "strings" "sync" @@ -13,7 +12,6 @@ import ( ethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/ethclient" "github.com/onrik/ethrpc" - "github.com/rs/zerolog" "github.com/spf13/cobra" "github.com/zeta-chain/zetacore/pkg/chains" @@ -61,7 +59,6 @@ func DebugCmd() *cobra.Command { } inboundHash := args[0] var ballotIdentifier string - chainLogger := zerolog.New(io.Discard).Level(zerolog.Disabled) // create a new zetacore client client, err := zetacore.NewClient( @@ -93,7 +90,6 @@ func DebugCmd() *cobra.Command { Mu: &sync.Mutex{}, } evmObserver.WithZetacoreClient(client) - evmObserver.WithLogger(chainLogger) var ethRPC *ethrpc.EthRPC var client *ethclient.Client coinType := coin.CoinType_Cmd @@ -172,7 +168,6 @@ func DebugCmd() *cobra.Command { Mu: &sync.Mutex{}, } btcObserver.WithZetacoreClient(client) - btcObserver.WithLogger(chainLogger) btcObserver.WithChain(*chains.GetChainFromChainID(chainID)) connCfg := &rpcclient.ConnConfig{ Host: cfg.BitcoinConfig.RPCHost, diff --git a/cmd/zetaclientd/utils.go b/cmd/zetaclientd/utils.go index 521c6dc858..64db1c3efa 100644 --- a/cmd/zetaclientd/utils.go +++ b/cmd/zetaclientd/utils.go @@ -54,13 +54,14 @@ func CreateZetacoreClient( return client, nil } +// CreateSignerMap creates a map of ChainSigners for all chains in the config func CreateSignerMap( appContext *context.AppContext, tss interfaces.TSSSigner, logger base.Logger, ts *metrics.TelemetryServer, ) (map[int64]interfaces.ChainSigner, error) { - coreContext := appContext.ZetacoreContext() + zetacoreContext := appContext.ZetacoreContext() signerMap := make(map[int64]interfaces.ChainSigner) // EVM signers @@ -68,7 +69,7 @@ func CreateSignerMap( if evmConfig.Chain.IsZetaChain() { continue } - evmChainParams, found := coreContext.GetEVMChainParams(evmConfig.Chain.ChainId) + evmChainParams, found := zetacoreContext.GetEVMChainParams(evmConfig.Chain.ChainId) if !found { logger.Std.Error().Msgf("ChainParam not found for chain %s", evmConfig.Chain.String()) continue @@ -77,15 +78,15 @@ func CreateSignerMap( erc20CustodyAddress := ethcommon.HexToAddress(evmChainParams.Erc20CustodyContractAddress) signer, err := evmsigner.NewSigner( evmConfig.Chain, - evmConfig.Endpoint, + zetacoreContext, tss, + ts, + logger, + evmConfig.Endpoint, config.GetConnectorABI(), config.GetERC20CustodyABI(), mpiAddress, - erc20CustodyAddress, - coreContext, - logger, - ts) + erc20CustodyAddress) if err != nil { logger.Std.Error().Err(err).Msgf("NewEVMSigner error for chain %s", evmConfig.Chain.String()) continue @@ -95,7 +96,7 @@ func CreateSignerMap( // BTC signer btcChain, btcConfig, enabled := appContext.GetBTCChainAndConfig() if enabled { - signer, err := btcsigner.NewSigner(btcConfig, tss, logger, ts, coreContext) + signer, err := btcsigner.NewSigner(btcChain, zetacoreContext, tss, ts, logger, btcConfig) if err != nil { logger.Std.Error().Err(err).Msgf("NewBTCSigner error for chain %s", btcChain.String()) } else { diff --git a/testutil/sample/os.go b/testutil/sample/os.go new file mode 100644 index 0000000000..b6519adf69 --- /dev/null +++ b/testutil/sample/os.go @@ -0,0 +1,15 @@ +package sample + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +// create a temporary directory for testing +func CreateTempDir(t *testing.T) string { + tempPath, err := os.MkdirTemp("", "tempdir-") + require.NoError(t, err) + return tempPath +} diff --git a/zetaclient/chains/base/logger_test.go b/zetaclient/chains/base/logger_test.go index 07c2859b0e..07c0941f5a 100644 --- a/zetaclient/chains/base/logger_test.go +++ b/zetaclient/chains/base/logger_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/zeta-chain/zetacore/testutil/sample" "github.com/zeta-chain/zetacore/zetaclient/chains/base" "github.com/zeta-chain/zetacore/zetaclient/config" ) @@ -23,7 +24,7 @@ func TestInitLogger(t *testing.T) { LogFormat: "json", LogLevel: 1, // zerolog.InfoLevel, ComplianceConfig: config.ComplianceConfig{ - LogPath: createTempDir(t), + LogPath: sample.CreateTempDir(t), }, }, fail: false, @@ -34,7 +35,7 @@ func TestInitLogger(t *testing.T) { LogFormat: "text", LogLevel: 2, // zerolog.WarnLevel, ComplianceConfig: config.ComplianceConfig{ - LogPath: createTempDir(t), + LogPath: sample.CreateTempDir(t), }, }, fail: false, @@ -45,7 +46,7 @@ func TestInitLogger(t *testing.T) { LogFormat: "unknown", LogLevel: 3, // zerolog.ErrorLevel, ComplianceConfig: config.ComplianceConfig{ - LogPath: createTempDir(t), + LogPath: sample.CreateTempDir(t), }, }, fail: false, @@ -57,7 +58,7 @@ func TestInitLogger(t *testing.T) { LogLevel: 4, // zerolog.DebugLevel, LogSampler: true, ComplianceConfig: config.ComplianceConfig{ - LogPath: createTempDir(t), + LogPath: sample.CreateTempDir(t), }, }, }, diff --git a/zetaclient/chains/base/observer_test.go b/zetaclient/chains/base/observer_test.go index 7dd2f18081..6972478190 100644 --- a/zetaclient/chains/base/observer_test.go +++ b/zetaclient/chains/base/observer_test.go @@ -18,13 +18,6 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/testutils/mocks" ) -// create a temporary directory for testing -func createTempDir(t *testing.T) string { - tempPath, err := os.MkdirTemp("", "tempdir-") - require.NoError(t, err) - return tempPath -} - // createObserver creates a new observer for testing func createObserver(t *testing.T, dbPath string) *base.Observer { // constructor parameters @@ -62,7 +55,7 @@ func TestNewObserver(t *testing.T) { tss := mocks.NewTSSMainnet() blockCacheSize := base.DefaultBlockCacheSize headersCacheSize := base.DefaultHeadersCacheSize - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) // test cases tests := []struct { @@ -159,7 +152,7 @@ func TestNewObserver(t *testing.T) { } func TestObserverGetterAndSetter(t *testing.T) { - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) t.Run("should be able to update chain", func(t *testing.T) { ob := createObserver(t, dbPath) @@ -258,7 +251,7 @@ func TestObserverGetterAndSetter(t *testing.T) { } func TestOpenDB(t *testing.T) { - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) ob := createObserver(t, dbPath) t.Run("should be able to open db", func(t *testing.T) { @@ -277,7 +270,7 @@ func TestLoadLastBlockScanned(t *testing.T) { t.Run("should be able to load last block scanned", func(t *testing.T) { // create db and write 100 as last block scanned - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) ob := createObserver(t, dbPath) ob.WriteLastBlockScannedToDB(100) @@ -289,7 +282,7 @@ func TestLoadLastBlockScanned(t *testing.T) { }) t.Run("should use latest block if last block scanned not found", func(t *testing.T) { // create empty db - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) ob := createObserver(t, dbPath) // read last block scanned @@ -299,7 +292,7 @@ func TestLoadLastBlockScanned(t *testing.T) { }) t.Run("should overwrite last block scanned if env var is set", func(t *testing.T) { // create db and write 100 as last block scanned - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) ob := createObserver(t, dbPath) ob.WriteLastBlockScannedToDB(100) @@ -322,7 +315,7 @@ func TestLoadLastBlockScanned(t *testing.T) { }) t.Run("should return error on invalid env var", func(t *testing.T) { // create db and write 100 as last block scanned - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) ob := createObserver(t, dbPath) // set invalid env var @@ -338,7 +331,7 @@ func TestLoadLastBlockScanned(t *testing.T) { func TestReadWriteLastBlockScannedToDB(t *testing.T) { t.Run("should be able to write and read last block scanned to db", func(t *testing.T) { // create db and write 100 as last block scanned - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) ob := createObserver(t, dbPath) err := ob.WriteLastBlockScannedToDB(100) require.NoError(t, err) @@ -349,7 +342,7 @@ func TestReadWriteLastBlockScannedToDB(t *testing.T) { }) t.Run("should return error when last block scanned not found in db", func(t *testing.T) { // create empty db - dbPath := createTempDir(t) + dbPath := sample.CreateTempDir(t) ob := createObserver(t, dbPath) lastScannedBlock, err := ob.ReadLastBlockScannedFromDB() diff --git a/zetaclient/chains/base/signer.go b/zetaclient/chains/base/signer.go index 2585218767..a33f116e48 100644 --- a/zetaclient/chains/base/signer.go +++ b/zetaclient/chains/base/signer.go @@ -1,6 +1,8 @@ package base import ( + "sync" + "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" "github.com/zeta-chain/zetacore/zetaclient/context" @@ -24,6 +26,10 @@ type Signer struct { // logger contains the loggers used by signer logger Logger + + // mu protects fields from concurrent access + // Note: base signer simply provides the mutex. It's the sub-struct's responsibility to use it to be thread-safe + mu *sync.Mutex } // NewSigner creates a new base signer @@ -43,6 +49,7 @@ func NewSigner( Std: logger.Std.With().Int64("chain", chain.ChainId).Str("module", "signer").Logger(), Compliance: logger.Compliance, }, + mu: &sync.Mutex{}, } } @@ -94,3 +101,8 @@ func (s *Signer) WithTelemetryServer(ts *metrics.TelemetryServer) *Signer { func (s *Signer) Logger() *Logger { return &s.logger } + +// Mu returns the mutex for the signer +func (s *Signer) Mu() *sync.Mutex { + return s.mu +} diff --git a/zetaclient/chains/base/signer_test.go b/zetaclient/chains/base/signer_test.go index 960c508d6e..3de1d18d4a 100644 --- a/zetaclient/chains/base/signer_test.go +++ b/zetaclient/chains/base/signer_test.go @@ -71,4 +71,8 @@ func TestSignerGetterAndSetter(t *testing.T) { logger.Std.Info().Msg("print standard log") logger.Compliance.Info().Msg("print compliance log") }) + t.Run("should be able to get mutex", func(t *testing.T) { + signer := createSigner(t) + require.NotNil(t, signer.Mu()) + }) } diff --git a/zetaclient/chains/bitcoin/signer/signer.go b/zetaclient/chains/bitcoin/signer/signer.go index f2d47fd988..d39b8a8f9d 100644 --- a/zetaclient/chains/bitcoin/signer/signer.go +++ b/zetaclient/chains/bitcoin/signer/signer.go @@ -15,7 +15,6 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" ethcommon "github.com/ethereum/go-ethereum/common" - "github.com/rs/zerolog" "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/pkg/coin" @@ -30,7 +29,6 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/context" "github.com/zeta-chain/zetacore/zetaclient/metrics" "github.com/zeta-chain/zetacore/zetaclient/outboundprocessor" - "github.com/zeta-chain/zetacore/zetaclient/tss" ) const ( @@ -45,20 +43,25 @@ var _ interfaces.ChainSigner = &Signer{} // Signer deals with signing BTC transactions and implements the ChainSigner interface type Signer struct { - tssSigner interfaces.TSSSigner - rpcClient interfaces.BTCRPCClient - logger zerolog.Logger - loggerCompliance zerolog.Logger - ts *metrics.TelemetryServer - coreContext *context.ZetacoreContext + // base.Signer implements the base chain signer + base.Signer + + // client is the RPC client to interact with the Bitcoin chain + client interfaces.BTCRPCClient } +// NewSigner creates a new Bitcoin signer func NewSigner( - cfg config.BTCConfig, - tssSigner interfaces.TSSSigner, - logger base.Logger, + chain chains.Chain, + zetacoreContext *context.ZetacoreContext, + tss interfaces.TSSSigner, ts *metrics.TelemetryServer, - coreContext *context.ZetacoreContext) (*Signer, error) { + logger base.Logger, + cfg config.BTCConfig) (*Signer, error) { + // create base signer + baseSigner := base.NewSigner(chain, zetacoreContext, tss, ts, logger) + + // create the bitcoin rpc client using the provided config connCfg := &rpcclient.ConnConfig{ Host: cfg.RPCHost, User: cfg.RPCUsername, @@ -73,12 +76,8 @@ func NewSigner( } return &Signer{ - tssSigner: tssSigner, - rpcClient: client, - logger: logger.Std.With().Str("chain", "BTC").Str("module", "BTCSigner").Logger(), - loggerCompliance: logger.Compliance, - ts: ts, - coreContext: coreContext, + Signer: *baseSigner, + client: client, }, nil } @@ -130,12 +129,12 @@ func (signer *Signer) AddWithdrawTxOutputs( if remainingSats < 0 { return fmt.Errorf("remainder value is negative: %d", remainingSats) } else if remainingSats == nonceMark { - signer.logger.Info().Msgf("adjust remainder value to avoid duplicate nonce-mark: %d", remainingSats) + signer.Logger().Std.Info().Msgf("adjust remainder value to avoid duplicate nonce-mark: %d", remainingSats) remainingSats-- } // 1st output: the nonce-mark btc to TSS self - tssAddrP2WPKH := signer.tssSigner.BTCAddressWitnessPubkeyHash() + tssAddrP2WPKH := signer.TSS().BTCAddressWitnessPubkeyHash() payToSelfScript, err := bitcoin.PayToAddrScript(tssAddrP2WPKH) if err != nil { return err @@ -182,7 +181,10 @@ func (signer *Signer) SignWithdrawTx( // refresh unspent UTXOs and continue with keysign regardless of error err := observer.FetchUTXOs() if err != nil { - signer.logger.Error().Err(err).Msgf("SignWithdrawTx: FetchUTXOs error: nonce %d chain %d", nonce, chain.ChainId) + signer.Logger(). + Std.Error(). + Err(err). + Msgf("SignWithdrawTx: FetchUTXOs error: nonce %d chain %d", nonce, chain.ChainId) } // select N UTXOs to cover the total expense @@ -216,16 +218,16 @@ func (signer *Signer) SignWithdrawTx( return nil, err } if sizeLimit < bitcoin.BtcOutboundBytesWithdrawer { // ZRC20 'withdraw' charged less fee from end user - signer.logger.Info(). + signer.Logger().Std.Info(). Msgf("sizeLimit %d is less than BtcOutboundBytesWithdrawer %d for nonce %d", sizeLimit, txSize, nonce) } if txSize < bitcoin.OutboundBytesMin { // outbound shouldn't be blocked a low sizeLimit - signer.logger.Warn(). + signer.Logger().Std.Warn(). Msgf("txSize %d is less than outboundBytesMin %d; use outboundBytesMin", txSize, bitcoin.OutboundBytesMin) txSize = bitcoin.OutboundBytesMin } if txSize > bitcoin.OutboundBytesMax { // in case of accident - signer.logger.Warn(). + signer.Logger().Std.Warn(). Msgf("txSize %d is greater than outboundBytesMax %d; use outboundBytesMax", txSize, bitcoin.OutboundBytesMax) txSize = bitcoin.OutboundBytesMax } @@ -233,8 +235,10 @@ func (signer *Signer) SignWithdrawTx( // fee calculation // #nosec G701 always in range (checked above) fees := new(big.Int).Mul(big.NewInt(int64(txSize)), gasPrice) - signer.logger.Info().Msgf("bitcoin outbound nonce %d gasPrice %s size %d fees %s consolidated %d utxos of value %v", - nonce, gasPrice.String(), txSize, fees.String(), consolidatedUtxo, consolidatedValue) + signer.Logger(). + Std.Info(). + Msgf("bitcoin outbound nonce %d gasPrice %s size %d fees %s consolidated %d utxos of value %v", + nonce, gasPrice.String(), txSize, fees.String(), consolidatedUtxo, consolidatedValue) // add tx outputs err = signer.AddWithdrawTxOutputs(tx, to, total, amount, nonceMark, fees, cancelTx) @@ -260,11 +264,7 @@ func (signer *Signer) SignWithdrawTx( } } - tssSigner, ok := signer.tssSigner.(*tss.TSS) - if !ok { - return nil, fmt.Errorf("tssSigner is not a TSS") - } - sig65Bs, err := tssSigner.SignBatch(witnessHashes, height, nonce, &chain) + sig65Bs, err := signer.TSS().SignBatch(witnessHashes, height, nonce, chain.ChainId) if err != nil { return nil, fmt.Errorf("SignBatch error: %v", err) } @@ -278,7 +278,7 @@ func (signer *Signer) SignWithdrawTx( S: S, } - pkCompressed := signer.tssSigner.PubKeyCompressedBytes() + pkCompressed := signer.TSS().PubKeyCompressedBytes() hashType := txscript.SigHashAll txWitness := wire.TxWitness{append(sig.Serialize(), byte(hashType)), pkCompressed} tx.TxIn[ix].Witness = txWitness @@ -298,12 +298,12 @@ func (signer *Signer) Broadcast(signedTx *wire.MsgTx) error { str := hex.EncodeToString(outBuff.Bytes()) fmt.Printf("BTCSigner: Transaction Data: %s\n", str) - hash, err := signer.rpcClient.SendRawTransaction(signedTx, true) + hash, err := signer.client.SendRawTransaction(signedTx, true) if err != nil { return err } - signer.logger.Info().Msgf("Broadcasting BTC tx , hash %s ", hash) + signer.Logger().Std.Info().Msgf("Broadcasting BTC tx , hash %s ", hash) return nil } @@ -318,11 +318,11 @@ func (signer *Signer) TryProcessOutbound( defer func() { outboundProcessor.EndTryProcess(outboundID) if err := recover(); err != nil { - signer.logger.Error().Msgf("BTC TryProcessOutbound: %s, caught panic error: %v", cctx.Index, err) + signer.Logger().Std.Error().Msgf("BTC TryProcessOutbound: %s, caught panic error: %v", cctx.Index, err) } }() - logger := signer.logger.With(). + logger := signer.Logger().Std.With(). Str("OutboundID", outboundID). Str("SendHash", cctx.Index). Logger() @@ -341,7 +341,7 @@ func (signer *Signer) TryProcessOutbound( logger.Error().Msgf("chain observer is not a bitcoin observer") return } - flags := signer.coreContext.GetCrossChainFlags() + flags := signer.ZetacoreContext().GetCrossChainFlags() if !flags.IsOutboundEnabled { logger.Info().Msgf("outbound is disabled") return @@ -375,7 +375,7 @@ func (signer *Signer) TryProcessOutbound( amount := float64(params.Amount.Uint64()) / 1e8 // Add 1 satoshi/byte to gasPrice to avoid minRelayTxFee issue - networkInfo, err := signer.rpcClient.GetNetworkInfo() + networkInfo, err := signer.client.GetNetworkInfo() if err != nil { logger.Error().Err(err).Msgf("cannot get bitcoin network info") return @@ -386,7 +386,7 @@ func (signer *Signer) TryProcessOutbound( // compliance check cancelTx := compliance.IsCctxRestricted(cctx) if cancelTx { - compliance.PrintComplianceLog(logger, signer.loggerCompliance, + compliance.PrintComplianceLog(logger, signer.Logger().Compliance, true, chain.ChainId, cctx.Index, cctx.InboundParams.Sender, params.Receiver, "BTC") amount = 0.0 // zero out the amount to cancel the tx } diff --git a/zetaclient/chains/bitcoin/signer/signer_keysign_test.go b/zetaclient/chains/bitcoin/signer/signer_keysign_test.go index 8cc90cf3eb..2506f57059 100644 --- a/zetaclient/chains/bitcoin/signer/signer_keysign_test.go +++ b/zetaclient/chains/bitcoin/signer/signer_keysign_test.go @@ -15,7 +15,6 @@ import ( "github.com/btcsuite/btcutil" "github.com/stretchr/testify/suite" - "github.com/zeta-chain/zetacore/pkg/chains" "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin" "github.com/zeta-chain/zetacore/zetaclient/chains/interfaces" "github.com/zeta-chain/zetacore/zetaclient/testutils/mocks" @@ -147,7 +146,7 @@ func getTSSTX( return "", err } - sig65B, err := tss.Sign(witnessHash, 10, 10, &chains.Chain{}, "") + sig65B, err := tss.Sign(witnessHash, 10, 10, 0, "") R := big.NewInt(0).SetBytes(sig65B[:32]) S := big.NewInt(0).SetBytes(sig65B[32:64]) sig := btcec.Signature{ diff --git a/zetaclient/chains/bitcoin/signer/signer_test.go b/zetaclient/chains/bitcoin/signer/signer_test.go index 74628211a4..e351d4a65c 100644 --- a/zetaclient/chains/bitcoin/signer/signer_test.go +++ b/zetaclient/chains/bitcoin/signer/signer_test.go @@ -21,8 +21,6 @@ import ( "github.com/zeta-chain/zetacore/zetaclient/chains/base" "github.com/zeta-chain/zetacore/zetaclient/chains/bitcoin" "github.com/zeta-chain/zetacore/zetaclient/config" - "github.com/zeta-chain/zetacore/zetaclient/context" - "github.com/zeta-chain/zetacore/zetaclient/metrics" "github.com/zeta-chain/zetacore/zetaclient/testutils/mocks" ) @@ -48,13 +46,14 @@ func (s *BTCSignerSuite) SetUpTest(c *C) { tss := &mocks.TSS{ PrivKey: privateKey, } - cfg := config.NewConfig() s.btcSigner, err = NewSigner( - config.BTCConfig{}, + chains.Chain{}, + nil, tss, + nil, base.DefaultLogger(), - &metrics.TelemetryServer{}, - context.NewZetacoreContext(cfg)) + config.BTCConfig{}, + ) c.Assert(err, IsNil) } @@ -231,16 +230,17 @@ func (s *BTCSignerSuite) TestP2WPH(c *C) { func TestAddWithdrawTxOutputs(t *testing.T) { // Create test signer and receiver address signer, err := NewSigner( - config.BTCConfig{}, + chains.Chain{}, + nil, mocks.NewTSSMainnet(), - base.DefaultLogger(), - &metrics.TelemetryServer{}, nil, + base.DefaultLogger(), + config.BTCConfig{}, ) require.NoError(t, err) // tss address and script - tssAddr := signer.tssSigner.BTCAddressWitnessPubkeyHash() + tssAddr := signer.TSS().BTCAddressWitnessPubkeyHash() tssScript, err := bitcoin.PayToAddrScript(tssAddr) require.NoError(t, err) fmt.Printf("tss address: %s", tssAddr.EncodeAddress()) @@ -392,13 +392,13 @@ func TestNewBTCSigner(t *testing.T) { tss := &mocks.TSS{ PrivKey: privateKey, } - cfg := config.NewConfig() btcSigner, err := NewSigner( - config.BTCConfig{}, + chains.Chain{}, + nil, tss, + nil, base.DefaultLogger(), - &metrics.TelemetryServer{}, - context.NewZetacoreContext(cfg)) + config.BTCConfig{}) require.NoError(t, err) require.NotNil(t, btcSigner) } diff --git a/zetaclient/chains/evm/signer/signer.go b/zetaclient/chains/evm/signer/signer.go index d49895e2a3..4fa29a015b 100644 --- a/zetaclient/chains/evm/signer/signer.go +++ b/zetaclient/chains/evm/signer/signer.go @@ -5,10 +5,8 @@ import ( "encoding/hex" "fmt" "math/big" - "math/rand" "strconv" "strings" - "sync" "time" sdk "github.com/cosmos/cosmos-sdk/types" @@ -48,39 +46,54 @@ var ( // Signer deals with the signing EVM transactions and implements the ChainSigner interface type Signer struct { - client interfaces.EVMRPCClient - chain *chains.Chain - tssSigner interfaces.TSSSigner - ethSigner ethtypes.Signer - logger base.Logger - ts *metrics.TelemetryServer - coreContext *clientcontext.ZetacoreContext - - // mu protects below fields from concurrent access - mu *sync.Mutex - zetaConnectorABI abi.ABI - erc20CustodyABI abi.ABI - zetaConnectorAddress ethcommon.Address - er20CustodyAddress ethcommon.Address + // base.Signer implements the base chain signer + base.Signer + + // client is the EVM RPC client to interact with the EVM chain + client interfaces.EVMRPCClient + + // ethSigner encapsulates EVM transaction signature handling + ethSigner ethtypes.Signer + + // zetaConnectorABI is the ABI of the ZetaConnector contract + zetaConnectorABI abi.ABI + + // erc20CustodyABI is the ABI of the ERC20Custody contract + erc20CustodyABI abi.ABI + + // zetaConnectorAddress is the address of the ZetaConnector contract + zetaConnectorAddress ethcommon.Address + + // er20CustodyAddress is the address of the ERC20Custody contract + er20CustodyAddress ethcommon.Address + + // outboundHashBeingReported is a map of outboundHash being reported outboundHashBeingReported map[string]bool } +// NewSigner creates a new EVM signer func NewSigner( chain chains.Chain, + zetacoreContext *clientcontext.ZetacoreContext, + tss interfaces.TSSSigner, + ts *metrics.TelemetryServer, + logger base.Logger, endpoint string, - tssSigner interfaces.TSSSigner, zetaConnectorABI string, erc20CustodyABI string, zetaConnectorAddress ethcommon.Address, erc20CustodyAddress ethcommon.Address, - coreContext *clientcontext.ZetacoreContext, - logger base.Logger, - ts *metrics.TelemetryServer, ) (*Signer, error) { + // create base signer + baseSigner := base.NewSigner(chain, zetacoreContext, tss, ts, logger) + + // create EVM client client, ethSigner, err := getEVMRPC(endpoint) if err != nil { return nil, err } + + // prepare ABIs connectorABI, err := abi.JSON(strings.NewReader(zetaConnectorABI)) if err != nil { return nil, err @@ -91,50 +104,42 @@ func NewSigner( } return &Signer{ - client: client, - chain: &chain, - tssSigner: tssSigner, - ethSigner: ethSigner, - zetaConnectorABI: connectorABI, - erc20CustodyABI: custodyABI, - zetaConnectorAddress: zetaConnectorAddress, - er20CustodyAddress: erc20CustodyAddress, - coreContext: coreContext, - logger: base.Logger{ - Std: logger.Std.With().Str("chain", chain.ChainName.String()).Str("module", "EVMSigner").Logger(), - Compliance: logger.Compliance, - }, - ts: ts, - mu: &sync.Mutex{}, + Signer: *baseSigner, + client: client, + ethSigner: ethSigner, + zetaConnectorABI: connectorABI, + erc20CustodyABI: custodyABI, + zetaConnectorAddress: zetaConnectorAddress, + er20CustodyAddress: erc20CustodyAddress, outboundHashBeingReported: make(map[string]bool), }, nil } // SetZetaConnectorAddress sets the zeta connector address func (signer *Signer) SetZetaConnectorAddress(addr ethcommon.Address) { - signer.mu.Lock() - defer signer.mu.Unlock() + signer.Mu().Lock() + defer signer.Mu().Unlock() signer.zetaConnectorAddress = addr } // SetERC20CustodyAddress sets the erc20 custody address func (signer *Signer) SetERC20CustodyAddress(addr ethcommon.Address) { - signer.mu.Lock() - defer signer.mu.Unlock() + signer.Mu().Lock() + defer signer.Mu().Unlock() signer.er20CustodyAddress = addr } // GetZetaConnectorAddress returns the zeta connector address func (signer *Signer) GetZetaConnectorAddress() ethcommon.Address { - signer.mu.Lock() - defer signer.mu.Unlock() + signer.Mu().Lock() + defer signer.Mu().Unlock() return signer.zetaConnectorAddress } // GetERC20CustodyAddress returns the erc20 custody address func (signer *Signer) GetERC20CustodyAddress() ethcommon.Address { - signer.mu.Lock() - defer signer.mu.Unlock() + signer.Mu().Lock() + defer signer.Mu().Unlock() return signer.er20CustodyAddress } @@ -149,15 +154,14 @@ func (signer *Signer) Sign( nonce uint64, height uint64, ) (*ethtypes.Transaction, []byte, []byte, error) { - log.Debug().Msgf("TSS SIGNER: %s", signer.tssSigner.Pubkey()) + log.Debug().Msgf("Sign: TSS signer: %s", signer.TSS().Pubkey()) // TODO: use EIP-1559 transaction type // https://github.com/zeta-chain/node/issues/1952 tx := ethtypes.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) - hashBytes := signer.ethSigner.Hash(tx).Bytes() - sig, err := signer.tssSigner.Sign(hashBytes, height, nonce, signer.chain, "") + sig, err := signer.TSS().Sign(hashBytes, height, nonce, signer.Chain().ChainId, "") if err != nil { return nil, nil, nil, err } @@ -165,11 +169,11 @@ func (signer *Signer) Sign( log.Debug().Msgf("Sign: Signature: %s", hex.EncodeToString(sig[:])) pubk, err := crypto.SigToPub(hashBytes, sig[:]) if err != nil { - signer.logger.Std.Error().Err(err).Msgf("SigToPub error") + signer.Logger().Std.Error().Err(err).Msgf("SigToPub error") } addr := crypto.PubkeyToAddress(*pubk) - signer.logger.Std.Info().Msgf("Sign: Ecrecovery of signature: %s", addr.Hex()) + signer.Logger().Std.Info().Msgf("Sign: Ecrecovery of signature: %s", addr.Hex()) signedTX, err := tx.WithSignature(signer.ethSigner, sig[:]) if err != nil { return nil, nil, nil, err @@ -269,7 +273,7 @@ func (signer *Signer) SignRevertTx(txData *OutboundData) (*ethtypes.Transaction, func (signer *Signer) SignCancelTx(txData *OutboundData) (*ethtypes.Transaction, error) { tx, _, _, err := signer.Sign( nil, - signer.tssSigner.EVMAddress(), + signer.TSS().EVMAddress(), zeroValue, // zero out the amount to cancel the tx evm.EthTransferGasLimit, txData.gasPrice, @@ -326,7 +330,7 @@ func (signer *Signer) TryProcessOutbound( zetacoreClient interfaces.ZetacoreClient, height uint64, ) { - logger := signer.logger.Std.With(). + logger := signer.Logger().Std.With(). Str("outboundID", outboundID). Str("SendHash", cctx.Index). Logger() @@ -363,16 +367,16 @@ func (signer *Signer) TryProcessOutbound( toChain := chains.GetChainFromChainID(txData.toChainID.Int64()) // Get cross-chain flags - crossChainflags := signer.coreContext.GetCrossChainFlags() + crossChainflags := signer.ZetacoreContext().GetCrossChainFlags() // https://github.com/zeta-chain/node/issues/2050 var tx *ethtypes.Transaction // compliance check goes first if compliance.IsCctxRestricted(cctx) { compliance.PrintComplianceLog( logger, - signer.logger.Compliance, + signer.Logger().Compliance, true, - signer.chain.ChainId, + signer.Chain().ChainId, cctx.Index, cctx.InboundParams.Sender, txData.to.Hex(), @@ -529,22 +533,21 @@ func (signer *Signer) BroadcastOutbound( if tx == nil { logger.Warn().Msgf("BroadcastOutbound: no tx to broadcast %s", cctx.Index) } - // Try to broadcast transaction + + // broadcast transaction if tx != nil { outboundHash := tx.Hash().Hex() - logger.Info(). - Msgf("on chain %s nonce %d, outboundHash %s signer %s", signer.chain, cctx.GetCurrentOutboundParam().TssNonce, outboundHash, myID) - //if len(signers) == 0 || myid == signers[send.OutboundParams.Broadcaster] || myid == signers[int(send.OutboundParams.Broadcaster+1)%len(signers)] { + + // try broacasting tx with increasing backoff (1s, 2s, 4s, 8s, 16s) in case of RPC error backOff := 1000 * time.Millisecond - // retry loop: 1s, 2s, 4s, 8s, 16s in case of RPC error for i := 0; i < 5; i++ { - logger.Info(). - Msgf("broadcasting tx %s to chain %s: nonce %d, retry %d", outboundHash, toChain, cctx.GetCurrentOutboundParam().TssNonce, i) - // #nosec G404 randomness is not a security issue here - time.Sleep(time.Duration(rand.Intn(1500)) * time.Millisecond) // FIXME: use backoff + time.Sleep(backOff) err := signer.Broadcast(tx) if err != nil { - log.Warn().Err(err).Msgf("Outbound Broadcast error") + log.Warn(). + Err(err). + Msgf("BroadcastOutbound: error broadcasting tx %s on chain %d nonce %d retry %d signer %s", + outboundHash, toChain.ChainId, cctx.GetCurrentOutboundParam().TssNonce, i, myID) retry, report := zetacore.HandleBroadcastError( err, strconv.FormatUint(cctx.GetCurrentOutboundParam().TssNonce, 10), @@ -560,8 +563,8 @@ func (signer *Signer) BroadcastOutbound( backOff *= 2 continue } - logger.Info(). - Msgf("Broadcast success: nonce %d to chain %s outboundHash %s", cctx.GetCurrentOutboundParam().TssNonce, toChain, outboundHash) + logger.Info().Msgf("BroadcastOutbound: broadcasted tx %s on chain %d nonce %d signer %s", + outboundHash, toChain.ChainId, cctx.GetCurrentOutboundParam().TssNonce, myID) signer.reportToOutboundTracker(zetacoreClient, toChain.ChainId, tx.Nonce(), outboundHash, logger) break // successful broadcast; no need to retry } @@ -686,8 +689,8 @@ func (signer *Signer) reportToOutboundTracker( logger zerolog.Logger, ) { // skip if already being reported - signer.mu.Lock() - defer signer.mu.Unlock() + signer.Mu().Lock() + defer signer.Mu().Unlock() if _, found := signer.outboundHashBeingReported[outboundHash]; found { logger.Info(). Msgf("reportToOutboundTracker: outboundHash %s for chain %d nonce %d is being reported", outboundHash, chainID, nonce) @@ -698,9 +701,9 @@ func (signer *Signer) reportToOutboundTracker( // report to outbound tracker with goroutine go func() { defer func() { - signer.mu.Lock() + signer.Mu().Lock() delete(signer.outboundHashBeingReported, outboundHash) - signer.mu.Unlock() + signer.Mu().Unlock() }() // try monitoring tx inclusion status for 10 minutes diff --git a/zetaclient/chains/evm/signer/signer_test.go b/zetaclient/chains/evm/signer/signer_test.go index 186582dd24..410ea5adf3 100644 --- a/zetaclient/chains/evm/signer/signer_test.go +++ b/zetaclient/chains/evm/signer/signer_test.go @@ -43,20 +43,19 @@ func getNewEvmSigner(tss interfaces.TSSSigner) (*Signer, error) { mpiAddress := ConnectorAddress erc20CustodyAddress := ERC20CustodyAddress logger := base.Logger{} - ts := &metrics.TelemetryServer{} cfg := config.NewConfig() return NewSigner( chains.BscMainnet, - mocks.EVMRPCEnabled, + context.NewZetacoreContext(cfg), tss, + nil, + logger, + mocks.EVMRPCEnabled, config.GetConnectorABI(), config.GetERC20CustodyABI(), mpiAddress, - erc20CustodyAddress, - context.NewZetacoreContext(cfg), - logger, - ts) + erc20CustodyAddress) } // getNewEvmChainObserver creates a new EVM chain observer for testing @@ -256,7 +255,7 @@ func TestSigner_SignCancelTx(t *testing.T) { // Verify tx body basics // Note: Cancel tx sends 0 gas token to TSS self address - verifyTxBodyBasics(t, tx, evmSigner.tssSigner.EVMAddress(), txData.nonce, big.NewInt(0)) + verifyTxBodyBasics(t, tx, evmSigner.TSS().EVMAddress(), txData.nonce, big.NewInt(0)) }) t.Run("SignCancelTx - should fail if keysign fails", func(t *testing.T) { // Pause tss to make keysign fail diff --git a/zetaclient/chains/interfaces/interfaces.go b/zetaclient/chains/interfaces/interfaces.go index 1d901d4f53..1ef94ec8de 100644 --- a/zetaclient/chains/interfaces/interfaces.go +++ b/zetaclient/chains/interfaces/interfaces.go @@ -167,7 +167,10 @@ type TSSSigner interface { // Note: it specifies optionalPubkey to use a different pubkey than the current pubkey set during keygen // TODO: check if optionalPubkey is needed // https://github.com/zeta-chain/node/issues/2085 - Sign(data []byte, height uint64, nonce uint64, chain *chains.Chain, optionalPubkey string) ([65]byte, error) + Sign(data []byte, height uint64, nonce uint64, chainID int64, optionalPubkey string) ([65]byte, error) + + // SignBatch signs the data in batch + SignBatch(digests [][]byte, height uint64, nonce uint64, chainID int64) ([][65]byte, error) EVMAddress() ethcommon.Address BTCAddress() string diff --git a/zetaclient/testutils/mocks/btc_rpc.go b/zetaclient/testutils/mocks/btc_rpc.go index cfd63ef87b..01d286d31a 100644 --- a/zetaclient/testutils/mocks/btc_rpc.go +++ b/zetaclient/testutils/mocks/btc_rpc.go @@ -17,7 +17,9 @@ var _ interfaces.BTCRPCClient = &MockBTCRPCClient{} // MockBTCRPCClient is a mock implementation of the BTCRPCClient interface type MockBTCRPCClient struct { - Txs []*btcutil.Tx + err error + blockCount int64 + Txs []*btcutil.Tx } // NewMockBTCRPCClient creates a new mock BTC RPC client @@ -28,6 +30,10 @@ func NewMockBTCRPCClient() *MockBTCRPCClient { // Reset clears the mock data func (c *MockBTCRPCClient) Reset() *MockBTCRPCClient { + if c.err != nil { + return nil + } + c.Txs = []*btcutil.Tx{} return c } @@ -95,7 +101,10 @@ func (c *MockBTCRPCClient) GetRawTransactionVerbose(_ *chainhash.Hash) (*btcjson } func (c *MockBTCRPCClient) GetBlockCount() (int64, error) { - return 0, errors.New("not implemented") + if c.err != nil { + return 0, c.err + } + return c.blockCount, nil } func (c *MockBTCRPCClient) GetBlockHash(_ int64) (*chainhash.Hash, error) { @@ -118,6 +127,16 @@ func (c *MockBTCRPCClient) GetBlockHeader(_ *chainhash.Hash) (*wire.BlockHeader, // Feed data to the mock BTC RPC client for testing // ---------------------------------------------------------------------------- +func (c *MockBTCRPCClient) WithError(err error) *MockBTCRPCClient { + c.err = err + return c +} + +func (c *MockBTCRPCClient) WithBlockCount(blkCnt int64) *MockBTCRPCClient { + c.blockCount = blkCnt + return c +} + func (c *MockBTCRPCClient) WithRawTransaction(tx *btcutil.Tx) *MockBTCRPCClient { c.Txs = append(c.Txs, tx) return c diff --git a/zetaclient/testutils/mocks/evm_rpc.go b/zetaclient/testutils/mocks/evm_rpc.go index 926be6ec28..fa40357592 100644 --- a/zetaclient/testutils/mocks/evm_rpc.go +++ b/zetaclient/testutils/mocks/evm_rpc.go @@ -31,7 +31,9 @@ func (s subscription) Err() <-chan error { var _ interfaces.EVMRPCClient = &MockEvmClient{} type MockEvmClient struct { - Receipts []*ethtypes.Receipt + err error + blockNumber uint64 + Receipts []*ethtypes.Receipt } func NewMockEvmClient() *MockEvmClient { @@ -44,56 +46,92 @@ func (e *MockEvmClient) SubscribeFilterLogs( _ ethereum.FilterQuery, _ chan<- ethtypes.Log, ) (ethereum.Subscription, error) { + if e.err != nil { + return subscription{}, e.err + } return subscription{}, nil } func (e *MockEvmClient) CodeAt(_ context.Context, _ ethcommon.Address, _ *big.Int) ([]byte, error) { + if e.err != nil { + return nil, e.err + } return []byte{}, nil } func (e *MockEvmClient) CallContract(_ context.Context, _ ethereum.CallMsg, _ *big.Int) ([]byte, error) { + if e.err != nil { + return nil, e.err + } return []byte{}, nil } func (e *MockEvmClient) HeaderByNumber(_ context.Context, _ *big.Int) (*ethtypes.Header, error) { + if e.err != nil { + return nil, e.err + } return ðtypes.Header{}, nil } func (e *MockEvmClient) PendingCodeAt(_ context.Context, _ ethcommon.Address) ([]byte, error) { + if e.err != nil { + return nil, e.err + } return []byte{}, nil } func (e *MockEvmClient) PendingNonceAt(_ context.Context, _ ethcommon.Address) (uint64, error) { + if e.err != nil { + return 0, e.err + } return 0, nil } func (e *MockEvmClient) SuggestGasPrice(_ context.Context) (*big.Int, error) { + if e.err != nil { + return nil, e.err + } return big.NewInt(0), nil } func (e *MockEvmClient) SuggestGasTipCap(_ context.Context) (*big.Int, error) { + if e.err != nil { + return nil, e.err + } return big.NewInt(0), nil } func (e *MockEvmClient) EstimateGas(_ context.Context, _ ethereum.CallMsg) (gas uint64, err error) { + if e.err != nil { + return 0, e.err + } gas = 0 err = nil return } func (e *MockEvmClient) SendTransaction(_ context.Context, _ *ethtypes.Transaction) error { - return nil + return e.err } func (e *MockEvmClient) FilterLogs(_ context.Context, _ ethereum.FilterQuery) ([]ethtypes.Log, error) { + if e.err != nil { + return nil, e.err + } return []ethtypes.Log{}, nil } func (e *MockEvmClient) BlockNumber(_ context.Context) (uint64, error) { - return 0, nil + if e.err != nil { + return 0, e.err + } + return e.blockNumber, nil } func (e *MockEvmClient) BlockByNumber(_ context.Context, _ *big.Int) (*ethtypes.Block, error) { + if e.err != nil { + return nil, e.err + } return ðtypes.Block{}, nil } @@ -101,10 +139,17 @@ func (e *MockEvmClient) TransactionByHash( _ context.Context, _ ethcommon.Hash, ) (tx *ethtypes.Transaction, isPending bool, err error) { + if e.err != nil { + return nil, false, e.err + } return ðtypes.Transaction{}, false, nil } func (e *MockEvmClient) TransactionReceipt(_ context.Context, _ ethcommon.Hash) (*ethtypes.Receipt, error) { + if e.err != nil { + return nil, e.err + } + // pop a receipt from the list if len(e.Receipts) > 0 { receipt := e.Receipts[len(e.Receipts)-1] @@ -120,6 +165,9 @@ func (e *MockEvmClient) TransactionSender( _ ethcommon.Hash, _ uint, ) (ethcommon.Address, error) { + if e.err != nil { + return ethcommon.Address{}, e.err + } return ethcommon.Address{}, nil } @@ -131,6 +179,16 @@ func (e *MockEvmClient) Reset() *MockEvmClient { // ---------------------------------------------------------------------------- // Feed data to the mock evm client for testing // ---------------------------------------------------------------------------- +func (e *MockEvmClient) WithError(err error) *MockEvmClient { + e.err = err + return e +} + +func (e *MockEvmClient) WithBlockNumber(blockNumber uint64) *MockEvmClient { + e.blockNumber = blockNumber + return e +} + func (e *MockEvmClient) WithReceipt(receipt *ethtypes.Receipt) *MockEvmClient { e.Receipts = append(e.Receipts, receipt) return e diff --git a/zetaclient/testutils/mocks/tss_signer.go b/zetaclient/testutils/mocks/tss_signer.go index 87d2f2ac3b..ea439e23b6 100644 --- a/zetaclient/testutils/mocks/tss_signer.go +++ b/zetaclient/testutils/mocks/tss_signer.go @@ -67,7 +67,7 @@ func (s *TSS) WithPrivKey(privKey *ecdsa.PrivateKey) *TSS { } // Sign uses test key unrelated to any tss key in production -func (s *TSS) Sign(data []byte, _ uint64, _ uint64, _ *chains.Chain, _ string) ([65]byte, error) { +func (s *TSS) Sign(data []byte, _ uint64, _ uint64, _ int64, _ string) ([65]byte, error) { // return error if tss is paused if s.paused { return [65]byte{}, fmt.Errorf("tss is paused") @@ -83,6 +83,17 @@ func (s *TSS) Sign(data []byte, _ uint64, _ uint64, _ *chains.Chain, _ string) ( return sigbyte, nil } +// SignBatch uses test key unrelated to any tss key in production +func (s *TSS) SignBatch(_ [][]byte, _ uint64, _ uint64, _ int64) ([][65]byte, error) { + // return error if tss is paused + if s.paused { + return nil, fmt.Errorf("tss is paused") + } + + // mock not implemented yet + return nil, fmt.Errorf("not implemented") +} + func (s *TSS) Pubkey() []byte { publicKeyBytes := crypto.FromECDSAPub(&s.PrivKey.PublicKey) return publicKeyBytes diff --git a/zetaclient/tss/tss_signer.go b/zetaclient/tss/tss_signer.go index 2915dd6d0a..8fdd0384ee 100644 --- a/zetaclient/tss/tss_signer.go +++ b/zetaclient/tss/tss_signer.go @@ -218,7 +218,7 @@ func (tss *TSS) Sign( digest []byte, height uint64, nonce uint64, - chain *chains.Chain, + chainID int64, optionalPubKey string, ) ([65]byte, error) { H := digest @@ -250,8 +250,8 @@ func (tss *TSS) Sign( // post blame data if enabled if IsEnvFlagEnabled(envFlagPostBlame) { digest := hex.EncodeToString(digest) - index := observertypes.GetBlameIndex(chain.ChainId, nonce, digest, height) - zetaHash, err := tss.ZetacoreClient.PostBlameData(&ksRes.Blame, chain.ChainId, index) + index := observertypes.GetBlameIndex(chainID, nonce, digest, height) + zetaHash, err := tss.ZetacoreClient.PostBlameData(&ksRes.Blame, chainID, index) if err != nil { log.Error().Err(err).Msg("error sending blame data to core") return [65]byte{}, err @@ -304,7 +304,7 @@ func (tss *TSS) Sign( // SignBatch is hash of some data // digest should be batch of hashes of some data -func (tss *TSS) SignBatch(digests [][]byte, height uint64, nonce uint64, chain *chains.Chain) ([][65]byte, error) { +func (tss *TSS) SignBatch(digests [][]byte, height uint64, nonce uint64, chainID int64) ([][65]byte, error) { tssPubkey := tss.CurrentPubkey digestBase64 := make([]string, len(digests)) for i, digest := range digests { @@ -326,8 +326,8 @@ func (tss *TSS) SignBatch(digests [][]byte, height uint64, nonce uint64, chain * // post blame data if enabled if IsEnvFlagEnabled(envFlagPostBlame) { digest := combineDigests(digestBase64) - index := observertypes.GetBlameIndex(chain.ChainId, nonce, hex.EncodeToString(digest), height) - zetaHash, err := tss.ZetacoreClient.PostBlameData(&ksRes.Blame, chain.ChainId, index) + index := observertypes.GetBlameIndex(chainID, nonce, hex.EncodeToString(digest), height) + zetaHash, err := tss.ZetacoreClient.PostBlameData(&ksRes.Blame, chainID, index) if err != nil { log.Error().Err(err).Msg("error sending blame data to core") return [][65]byte{}, err