From 3413f8c093a8b60c55eb0801a5172b7dca3ecb83 Mon Sep 17 00:00:00 2001 From: Alexandru Popenta Date: Mon, 29 Jul 2024 11:22:23 +0300 Subject: [PATCH] fixes after review --- multiversx_sdk_cli/cli_contracts.py | 2 +- multiversx_sdk_cli/cli_wallet.py | 42 +++++++++---------- multiversx_sdk_cli/contracts.py | 37 +---------------- multiversx_sdk_cli/dns.py | 64 ++++++++++++++--------------- multiversx_sdk_cli/utils.py | 2 - 5 files changed, 55 insertions(+), 92 deletions(-) diff --git a/multiversx_sdk_cli/cli_contracts.py b/multiversx_sdk_cli/cli_contracts.py index da4235ca..0ca7149e 100644 --- a/multiversx_sdk_cli/cli_contracts.py +++ b/multiversx_sdk_cli/cli_contracts.py @@ -444,7 +444,7 @@ def upgrade(args: Any): def query(args: Any): logger.debug("query") - # workaround so we can use the function bellow to set chainID + # workaround so we can use the function below to set chainID args.chain = "" cli_shared.prepare_chain_id_in_args(args) diff --git a/multiversx_sdk_cli/cli_wallet.py b/multiversx_sdk_cli/cli_wallet.py index 1ccbe199..06205f68 100644 --- a/multiversx_sdk_cli/cli_wallet.py +++ b/multiversx_sdk_cli/cli_wallet.py @@ -3,7 +3,7 @@ import logging import sys from pathlib import Path -from typing import Any, List, Optional, Tuple, cast +from typing import Any, List, Optional, Tuple from multiversx_sdk import (Address, Mnemonic, UserPEM, UserSecretKey, UserWallet) @@ -115,27 +115,7 @@ def wallet_new(args: Any): address_hrp = args.address_hrp shard = args.shard - if shard is not None: - if shard not in CURRENT_SHARDS: - raise BadUserInput(f"Wrong shard provided. Choose between {CURRENT_SHARDS}") - - is_wallet_generated = False - for _ in range(MAX_ITERATIONS_FOR_GENERATING_WALLET): - mnemonic = Mnemonic.generate() - pubkey = mnemonic.derive_key().generate_public_key() - generated_address_shard = get_shard_of_pubkey(pubkey.buffer, NUMBER_OF_SHARDS) - - if shard == generated_address_shard: - is_wallet_generated = True - break - - if not is_wallet_generated: - raise WalletGenerationError(f"Couldn't generate wallet in shard {shard}") - else: - mnemonic = Mnemonic.generate() - - # this is done to get rid of the Pylance error: possibly unbound - mnemonic = cast(Mnemonic, mnemonic) # type: ignore + mnemonic = _generate_mnemonic_with_shard_constraint(shard) print(f"Mnemonic: {mnemonic.get_text()}") print(f"Wallet address: {mnemonic.derive_key().generate_public_key().to_address(address_hrp).to_bech32()}") @@ -172,6 +152,24 @@ def wallet_new(args: Any): logger.info(f"Wallet ({format}) saved: {outfile}") +def _generate_mnemonic_with_shard_constraint(shard: Optional[int] = None) -> Mnemonic: + if shard is not None: + if shard not in CURRENT_SHARDS: + raise BadUserInput(f"Wrong shard provided. Choose between {CURRENT_SHARDS}") + + for _ in range(MAX_ITERATIONS_FOR_GENERATING_WALLET): + mnemonic = Mnemonic.generate() + pubkey = mnemonic.derive_key().generate_public_key() + generated_address_shard = get_shard_of_pubkey(pubkey.buffer, NUMBER_OF_SHARDS) + + if shard == generated_address_shard: + return mnemonic + + raise WalletGenerationError(f"Couldn't generate wallet in shard {shard}") + + return Mnemonic.generate() + + def convert_wallet(args: Any): infile = Path(args.infile).expanduser().resolve() if args.infile else None outfile = Path(args.outfile).expanduser().resolve() if args.outfile else None diff --git a/multiversx_sdk_cli/contracts.py b/multiversx_sdk_cli/contracts.py index 8fa20a38..98885dc4 100644 --- a/multiversx_sdk_cli/contracts.py +++ b/multiversx_sdk_cli/contracts.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import Any, List, Optional, Protocol, Sequence, Union +from typing import Any, List, Optional, Protocol, Union from multiversx_sdk import (Address, QueryRunnerAdapter, SmartContractQueriesController, @@ -8,13 +8,11 @@ TokenComputer, TokenTransfer, Transaction, TransactionPayload) from multiversx_sdk.abi import Abi -from multiversx_sdk.network_providers.interface import IContractQuery from multiversx_sdk_cli import errors from multiversx_sdk_cli.accounts import Account from multiversx_sdk_cli.constants import DEFAULT_HRP from multiversx_sdk_cli.interfaces import IAddress -from multiversx_sdk_cli.utils import Object logger = logging.getLogger("contracts") @@ -29,37 +27,6 @@ def query_contract(self, query: Any) -> 'IContractQueryResponse': ... -class QueryResult(Object): - def __init__(self, as_base64: str, as_hex: str, as_number: Optional[int]): - self.base64 = as_base64 - self.hex = as_hex - self.number = as_number - - -class ContractQuery(IContractQuery): - def __init__(self, address: IAddress, function: str, value: int, arguments: List[bytes], caller: Optional[IAddress] = None): - self.contract = address - self.function = function - self.caller = caller - self.value = value - self.encoded_arguments = [item.hex() for item in arguments] - - def get_contract(self) -> IAddress: - return self.contract - - def get_function(self) -> str: - return self.function - - def get_encoded_arguments(self) -> Sequence[str]: - return self.encoded_arguments - - def get_caller(self) -> Optional[IAddress]: - return self.caller - - def get_value(self) -> int: - return self.value - - class IContractQueryResponse(Protocol): return_data: List[str] return_code: str @@ -201,7 +168,7 @@ def query_contract(self, contract_address: IAddress, proxy: INetworkProvider, function: str, - arguments: Union[List[Any], None], + arguments: Optional[List[Any]], should_prepare_args: bool) -> List[Any]: args = arguments if arguments else [] if should_prepare_args: diff --git a/multiversx_sdk_cli/dns.py b/multiversx_sdk_cli/dns.py index 5859a4ef..746bf398 100644 --- a/multiversx_sdk_cli/dns.py +++ b/multiversx_sdk_cli/dns.py @@ -28,16 +28,11 @@ def resolve(name: str, proxy: INetworkProvider) -> Address: name_arg = "0x{}".format(str.encode(name).hex()) dns_address = dns_address_for_name(name) - chain_id = proxy.get_network_config().chain_id - config = TransactionsFactoryConfig(chain_id) - contract = SmartContract(config) - - response = contract.query_contract( + response = _query_contract( contract_address=dns_address, proxy=proxy, function="resolve", - arguments=[name_arg], - args_from_file=False + args=[name_arg] ) if len(response) == 0: @@ -51,17 +46,14 @@ def validate_name(name: str, shard_id: int, proxy: INetworkProvider): name_arg = "0x{}".format(str.encode(name).hex()) dns_address = compute_dns_address_for_shard_id(shard_id) - chain_id = proxy.get_network_config().chain_id - config = TransactionsFactoryConfig(chain_id) - contract = SmartContract(config) - - response = contract.query_contract( + response = _query_contract( contract_address=dns_address, proxy=proxy, function="validateName", - arguments=[name_arg], - args_from_file=False - )[0] + args=[name_arg] + ) + + response = response[0] return_code = response["returnCode"] if return_code == "ok": @@ -105,17 +97,14 @@ def name_hash(name: str) -> bytes: def registration_cost(shard_id: int, proxy: INetworkProvider) -> int: dns_address = compute_dns_address_for_shard_id(shard_id) - chain_id = proxy.get_network_config().chain_id - config = TransactionsFactoryConfig(chain_id) - contract = SmartContract(config) - - response = contract.query_contract( + response = _query_contract( contract_address=dns_address, proxy=proxy, - function="getRegistrationCost", - arguments=[], - args_from_file=False - )[0] + function="versgetRegistrationCostion", + args=[] + ) + + response = response[0] data = response["returnDataParts"][0] if not data: @@ -127,17 +116,14 @@ def registration_cost(shard_id: int, proxy: INetworkProvider) -> int: def version(shard_id: int, proxy: INetworkProvider) -> str: dns_address = compute_dns_address_for_shard_id(shard_id) - chain_id = proxy.get_network_config().chain_id - config = TransactionsFactoryConfig(chain_id) - contract = SmartContract(config) - - response = contract.query_contract( + response = _query_contract( contract_address=dns_address, proxy=proxy, function="version", - arguments=[], - args_from_file=False - )[0] + args=[] + ) + + response = response[0] return bytearray.fromhex(response["returnDataParts"][0]).decode() @@ -161,3 +147,17 @@ def compute_dns_address_for_shard_id(shard_id: int) -> Address: def dns_register_data(name: str) -> str: name_enc: bytes = str.encode(name) return "register@{}".format(name_enc.hex()) + + +def _query_contract(contract_address: Address, proxy: INetworkProvider, function: str, args: List[Any]) -> List[Any]: + chain_id = proxy.get_network_config().chain_id + config = TransactionsFactoryConfig(chain_id) + contract = SmartContract(config) + + return contract.query_contract( + contract_address=contract_address, + proxy=proxy, + function=function, + arguments=args, + should_prepare_args=False + ) diff --git a/multiversx_sdk_cli/utils.py b/multiversx_sdk_cli/utils.py index bc4d5b9c..a61e41fa 100644 --- a/multiversx_sdk_cli/utils.py +++ b/multiversx_sdk_cli/utils.py @@ -42,8 +42,6 @@ def default(self, o: Any) -> Any: return o.to_dictionary() if isinstance(o, bytes): return o.hex() - if isinstance(o, list): - return [self.default(item) for item in o] # type: ignore return super().default(o)