From 8c08b1089544e509c59fc0d84463bf70035f04ac Mon Sep 17 00:00:00 2001 From: unparalleled-js Date: Thu, 16 Jun 2022 09:54:06 -0500 Subject: [PATCH 1/6] test: put test --- tests/functional/test_contract.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/functional/test_contract.py b/tests/functional/test_contract.py index f19b33f0..ae595fa7 100644 --- a/tests/functional/test_contract.py +++ b/tests/functional/test_contract.py @@ -91,6 +91,9 @@ def test_external_call_array_outputs(contract, account): receipt = contract.get_array() assert receipt.return_value == [1, 2, 3] + receipt = contract.get_array(sender=account) + assert receipt.return_value == [1, 2, 3] + def test_view_call_array_outputs(contract, account): array = contract.view_array() From 0172f9be93033177ce126c789d347cc9a3cc99ed Mon Sep 17 00:00:00 2001 From: unparalleled-js Date: Fri, 17 Jun 2022 13:39:47 -0500 Subject: [PATCH 2/6] fix: correctly use accounts --- ape_starknet/__init__.py | 2 +- ape_starknet/accounts/__init__.py | 188 ++++++++++-------- ape_starknet/accounts/_cli.py | 2 +- ape_starknet/conversion.py | 2 +- ape_starknet/ecosystems.py | 30 ++- ape_starknet/explorer.py | 14 +- ape_starknet/provider.py | 90 ++++----- ape_starknet/tokens.py | 41 +--- ape_starknet/transactions.py | 38 ++-- ape_starknet/{_utils.py => utils/__init__.py} | 3 + ape_starknet/utils/basemodel.py | 22 ++ setup.py | 3 +- tests/conftest.py | 2 +- tests/functional/test_accounts.py | 2 +- tests/functional/test_contract.py | 2 + 15 files changed, 231 insertions(+), 210 deletions(-) rename ape_starknet/{_utils.py => utils/__init__.py} (96%) create mode 100644 ape_starknet/utils/basemodel.py diff --git a/ape_starknet/__init__.py b/ape_starknet/__init__.py index 13500860..3b6aa35b 100644 --- a/ape_starknet/__init__.py +++ b/ape_starknet/__init__.py @@ -2,7 +2,6 @@ from ape.api.networks import LOCAL_NETWORK_NAME, NetworkAPI, create_network_type from ape.types import AddressType -from ape_starknet._utils import NETWORKS, PLUGIN_NAME from ape_starknet.accounts import StarknetAccountContracts, StarknetKeyfileAccount from ape_starknet.config import StarknetConfig from ape_starknet.conversion import StarknetAddressConverter @@ -10,6 +9,7 @@ from ape_starknet.explorer import StarknetExplorer from ape_starknet.provider import StarknetProvider from ape_starknet.tokens import TokenManager +from ape_starknet.utils import NETWORKS, PLUGIN_NAME tokens = TokenManager() network_names = [LOCAL_NETWORK_NAME] + [k for k in NETWORKS.keys()] diff --git a/ape_starknet/accounts/__init__.py b/ape_starknet/accounts/__init__.py index 02cba6c7..19488595 100644 --- a/ape_starknet/accounts/__init__.py +++ b/ape_starknet/accounts/__init__.py @@ -1,5 +1,4 @@ import json -import os from dataclasses import dataclass from pathlib import Path from typing import Dict, Iterator, List, Optional, Union @@ -9,33 +8,29 @@ from ape.api.address import BaseAddress from ape.api.networks import LOCAL_NETWORK_NAME from ape.contracts import ContractContainer, ContractInstance -from ape.exceptions import AccountsError, ProviderError +from ape.exceptions import AccountsError, ProviderError, SignatureError from ape.logging import logger from ape.types import AddressType, SignableMessage -from ape.utils import abstractmethod +from ape.utils import abstractmethod, cached_property from eth_keyfile import create_keyfile_json, decode_keyfile_json # type: ignore from eth_utils import text_if_str, to_bytes +from ethpm_types import ContractType +from ethpm_types.abi import MethodABI from hexbytes import HexBytes from services.external_api.client import BadRequest # type: ignore from starknet_py.net import KeyPair # type: ignore -from starknet_py.net.account.account_client import AccountClient # type: ignore from starknet_py.net.account.compiled_account_contract import ( # type: ignore COMPILED_ACCOUNT_CONTRACT, ) +from starknet_py.net.signer.stark_curve_signer import StarkCurveSigner # type: ignore from starknet_py.utils.crypto.facade import ECSignature, sign_calldata # type: ignore from starkware.cairo.lang.vm.cairo_runner import verify_ecdsa_sig # type: ignore from starkware.crypto.signature.signature import get_random_private_key # type: ignore -from ape_starknet._utils import ( - ALPHA_MAINNET_WL_DEPLOY_TOKEN_KEY, - PLUGIN_NAME, - get_chain_id, - handle_client_errors, -) -from ape_starknet.ecosystems import Starknet -from ape_starknet.provider import StarknetProvider from ape_starknet.tokens import TokenManager -from ape_starknet.transactions import InvokeFunctionTransaction, StarknetTransaction +from ape_starknet.transactions import InvokeFunctionTransaction +from ape_starknet.utils import PLUGIN_NAME, get_chain_id +from ape_starknet.utils.basemodel import StarknetMixin APP_KEY_FILE_KEY = "ape-starknet" """ @@ -45,7 +40,7 @@ APP_KEY_FILE_VERSION = "0.1.0" -class StarknetAccountContracts(AccountContainerAPI): +class StarknetAccountContracts(AccountContainerAPI, StarknetMixin): ephemeral_accounts: Dict[str, Dict] = {} """Local-network accounts that do not persist.""" @@ -108,6 +103,9 @@ def __getitem__(self, item: Union[AddressType, int]) -> AccountAPI: # First, use the account's public key (what Ape is used to). return super().__getitem__(address) + def get_account(self, address: Union[AddressType, int]) -> "BaseStarknetAccount": + return self[address] # type: ignore + def load(self, alias: str) -> "BaseStarknetAccount": if alias in self.ephemeral_accounts: account = StarknetEphemeralAccount( @@ -182,8 +180,7 @@ def import_account( new_account.write(passphrase=None, private_key=private_key, deployments=deployments) # Add account contract to cache - ecosystem = self.network_manager.starknet - address = ecosystem.decode_address(contract_address) + address = self.starknet.decode_address(contract_address) if self.network_manager.active_provider and self.provider.network.explorer: try: contract_type = self.provider.network.explorer.get_contract_type(address) @@ -240,7 +237,7 @@ class StarknetAccountDeployment: contract_address: AddressType -class BaseStarknetAccount(AccountAPI): +class BaseStarknetAccount(AccountAPI, StarknetMixin): token_manager: TokenManager = TokenManager() @abstractmethod @@ -251,41 +248,114 @@ def _get_key(self) -> int: def get_account_data(self) -> Dict: ... - def __repr__(self): - return f"<{self.__class__.__name__} {self.contract_address}>" - @property def contract_address(self) -> Optional[AddressType]: - ecosystem = self.network_manager.ecosystems[PLUGIN_NAME] for deployment in self.get_deployments(): network_name = deployment.network_name - network = ecosystem.networks[network_name] + network = self.starknet.networks[network_name] if network_name == network.name: address = deployment.contract_address - return ecosystem.decode_address(address) + return self.starknet.decode_address(address) return None @property def address(self) -> AddressType: public_key = self.get_account_data()["address"] - return self.network_manager.starknet.decode_address(public_key) + return self.starknet.decode_address(public_key) - @property - def provider(self) -> StarknetProvider: - provider = super().provider - if not isinstance(provider, StarknetProvider): - # Mostly for mypy - raise AccountsError("Must use a Starknet provider.") + @cached_property + def signer(self) -> StarkCurveSigner: + key_pair = KeyPair.from_private_key(self._get_key()) + network = self.provider.network + chain_id = get_chain_id(network.name) + return StarkCurveSigner( + account_address=self.contract_address, key_pair=key_pair, chain_id=chain_id + ) + + @cached_property + def contract_type(self) -> Optional[ContractType]: + if not self.contract_address: + # Contract not deployed to this network yet + return None + + contract_type = self.chain_manager.contracts.get(self.contract_address) + if not contract_type: + raise AccountsError(f"Account '{self.contract_address}' is was expected but not found.") - return provider + return contract_type + + @cached_property + def execute_abi(self) -> Optional[MethodABI]: + contract_address = self.contract_address + contract_type = self.contract_type + if not contract_address or not contract_type: + return None + + execute_abi_ls = [ + abi for abi in contract_type.abi if getattr(abi, "name", "") == "__execute__" + ] + if not execute_abi_ls: + raise AccountsError(f"Account '{contract_address}' does not have __execute__ method.") + + return execute_abi_ls[0] + + def __repr__(self): + return f"<{self.__class__.__name__} {self.contract_address}>" + + def call(self, txn: TransactionAPI, send_everything: bool = False) -> ReceiptAPI: + if send_everything: + raise NotImplementedError("send_everything currently isn't implemented in Starknet.") + + if not isinstance(txn, InvokeFunctionTransaction): + raise AccountsError("Can only call Starknet transactions.") + + txn = self.prepare_transaction(txn) + if not txn.signature: + raise SignatureError("The transaction was not signed.") + + return self.provider.send_transaction(txn) + + def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI: + contract_address = self.contract_address + execute_abi = self.execute_abi + if not contract_address or not execute_abi: + raise AccountsError( + f"Account is not deployed to network '{self.provider.network.name}'." + ) + + if not isinstance(txn, InvokeFunctionTransaction): + raise AccountsError("Can only prepare invoke transactions.") + + txn: InvokeFunctionTransaction = super().prepare_transaction(txn) # type: ignore + stark_tx = txn.as_starknet_object() + account_call = { + "to": stark_tx.contract_address, + "selector": stark_tx.entry_point_selector, + "data_offset": 0, + "data_len": len(stark_tx.calldata), + } + txn.data = [[account_call], stark_tx.calldata, self.nonce] + # txn.data = self.starknet.encode_calldata( + # self.contract_type.abi, execute_abi, [[account_call], stark_tx.calldata, self.nonce] + # ) + txn.receiver = contract_address + txn.sender = None + txn.method_abi = execute_abi + sign_result = self.sign_transaction(txn) + if not sign_result: + raise SignatureError("Failed to sign transaction.") + + r, s = sign_result + txn.signature = (0, r, s) + return txn def sign_transaction(self, txn: TransactionAPI) -> Optional[ECSignature]: if not isinstance(txn, InvokeFunctionTransaction): raise AccountsError("This account can only sign Starknet transactions.") - starknet_object = txn.as_starknet_object() - return self.sign_message(starknet_object.calldata) + # NOTE: 'v' is not used + return self.signer.sign_transaction(txn.as_starknet_object()) def transfer( self, @@ -306,65 +376,17 @@ def transfer( account = account.contract_address # type: ignore if not isinstance(account, int): - account = self.provider.network.ecosystem.encode_address(account) # type: ignore + account = self.starknet.encode_address(account) # type: ignore if self.contract_address is None: raise ValueError("Contract address cannot be None") - sender = self.provider.network.ecosystem.encode_address(self.contract_address) + sender = self.starknet.encode_address(self.contract_address) return self.token_manager.transfer(sender, account, value, **kwargs) # type: ignore def deploy(self, contract: ContractContainer, *args, **kwargs) -> ContractInstance: return contract.deploy(sender=self) - @handle_client_errors - def send_transaction(self, txn: TransactionAPI, token: Optional[str] = None) -> ReceiptAPI: - if not token and hasattr(txn, "token") and txn.token: # type: ignore - token = txn.token # type: ignore - else: - token = os.environ.get(ALPHA_MAINNET_WL_DEPLOY_TOKEN_KEY) - - if not isinstance(txn, StarknetTransaction): - # Mostly for mypy - raise AccountsError("Can only send Starknet transactions.") - - account_client = self.create_account_client() - starknet_txn = txn.as_starknet_object() - txn_info = account_client.add_transaction_sync(starknet_txn, token=token) - - error = txn_info.get("error", {}) - if error: - message = error.get("message", error) - raise AccountsError(message) - - txn_hash = txn_info["transaction_hash"] - - starknet: Starknet = self.provider.network.ecosystem # type: ignore - return_value = [starknet.encode_primitive_value(v) for v in txn_info.get("result", [])] - - if return_value and isinstance(txn, InvokeFunctionTransaction): - return_value = starknet.decode_returndata(txn.method_abi, return_value) - if isinstance(return_value, (list, tuple)) and len(return_value) == 1: - return_value = return_value[0] - - receipt = self.provider.get_transaction(txn_hash) - receipt.return_value = return_value - return receipt - - def create_account_client(self) -> AccountClient: - network = self.provider.network - key_pair = KeyPair( - public_key=network.ecosystem.encode_address(self.address), - private_key=self._get_key(), - ) - chain_id = get_chain_id(network.name) - return AccountClient( - self.contract_address, - self.provider.uri, - key_pair=key_pair, - chain=chain_id, - ) - def get_deployment(self, network_name: str) -> Optional[StarknetAccountDeployment]: for deployment in self.get_deployments(): if deployment.network_name in network_name: diff --git a/ape_starknet/accounts/_cli.py b/ape_starknet/accounts/_cli.py index 81924498..32ae998c 100644 --- a/ape_starknet/accounts/_cli.py +++ b/ape_starknet/accounts/_cli.py @@ -11,12 +11,12 @@ from ape.cli.options import ApeCliContextObject from ape.utils import add_padding_to_strings -from ape_starknet._utils import PLUGIN_NAME from ape_starknet.accounts import ( BaseStarknetAccount, StarknetAccountContracts, StarknetKeyfileAccount, ) +from ape_starknet.utils import PLUGIN_NAME def _get_container(cli_ctx: ApeCliContextObject) -> StarknetAccountContracts: diff --git a/ape_starknet/conversion.py b/ape_starknet/conversion.py index 99e91bc4..54a5c4c6 100644 --- a/ape_starknet/conversion.py +++ b/ape_starknet/conversion.py @@ -4,7 +4,7 @@ from ape.types import AddressType from eth_utils import is_checksum_address -from ape_starknet._utils import is_hex_address, to_checksum_address +from ape_starknet.utils import is_hex_address, to_checksum_address # NOTE: This utility converter ensures that all bytes args can accept hex too diff --git a/ape_starknet/ecosystems.py b/ape_starknet/ecosystems.py index daf93ab0..1ec555d9 100644 --- a/ape_starknet/ecosystems.py +++ b/ape_starknet/ecosystems.py @@ -20,7 +20,6 @@ from starkware.starknet.public.abi_structs import identifier_manager_from_abi # type: ignore from starkware.starknet.services.api.contract_class import ContractClass # type: ignore -from ape_starknet._utils import to_checksum_address from ape_starknet.exceptions import StarknetEcosystemError from ape_starknet.transactions import ( DeployTransaction, @@ -28,6 +27,7 @@ StarknetReceipt, StarknetTransaction, ) +from ape_starknet.utils import to_checksum_address NETWORKS = { # chain_id, network_id @@ -101,24 +101,37 @@ def encode_calldata( method_abi: Union[ConstructorABI, MethodABI], call_args: Union[List, Tuple], ) -> List: + full_abi = [abi.dict() if hasattr(abi, "dict") else abi for abi in full_abi] id_manager = identifier_manager_from_abi(full_abi) transformer = DataTransformer(method_abi.dict(), id_manager) - encoded_args = [] + encoded_args: List[Any] = [] index = 0 last_index = len(method_abi.inputs) - 1 + did_process_array_during_arr_len = False + for call_arg, input_type in zip(call_args, method_abi.inputs): if str(input_type.type).endswith("*"): - # (arrays) Was processed the iteration before. - continue + if did_process_array_during_arr_len: + did_process_array_during_arr_len = False + continue + array_arg = [self.encode_primitive_value(v) for v in call_arg] + encoded_args.append(array_arg) elif ( input_type.name == "arr_len" and index < last_index and str(method_abi.inputs[index + 1].type).endswith("*") ): - # Handle arrays. array_arg = [self.encode_primitive_value(v) for v in call_args[index + 1]] encoded_args.append(array_arg) + did_process_array_during_arr_len = True + + elif isinstance(call_arg, dict): + encoded_struct = {} + for key, value in call_arg.items(): + encoded_struct[key] = self.encode_primitive_value(value) + + encoded_args.append(encoded_struct) else: encoded_arg = self.encode_primitive_value(call_arg) @@ -130,10 +143,15 @@ def encode_calldata( return calldata def encode_primitive_value(self, value: Any) -> Any: - if isinstance(value, (list, tuple)): + if isinstance(value, int): + return value + + elif isinstance(value, (list, tuple)): return [self.encode_primitive_value(v) for v in value] + if isinstance(value, str) and is_0x_prefixed(value): return int(value, 16) + elif isinstance(value, HexBytes): return int(value.hex(), 16) diff --git a/ape_starknet/explorer.py b/ape_starknet/explorer.py index c56cd5a1..8b36c25b 100644 --- a/ape_starknet/explorer.py +++ b/ape_starknet/explorer.py @@ -1,28 +1,18 @@ from typing import Iterator, Optional from ape.api import ExplorerAPI, ReceiptAPI -from ape.exceptions import ProviderError from ape.types import AddressType from ethpm_types import ContractType -from ape_starknet.provider import StarknetProvider +from ape_starknet.utils.basemodel import StarknetMixin -class StarknetExplorer(ExplorerAPI): +class StarknetExplorer(ExplorerAPI, StarknetMixin): BASE_URIS = { "testnet": "https://goerli.voyager.online", "mainnet": "https://voyager.online", } - @property - def provider(self) -> StarknetProvider: - provider = super().provider - if not isinstance(provider, StarknetProvider): - # Mostly for mypy - raise ProviderError("Must use a Starknet provider.") - - return provider - @property def base_uri(self) -> str: network_name = self.provider.network.name diff --git a/ape_starknet/provider.py b/ape_starknet/provider.py index 991118d0..c42d95c8 100644 --- a/ape_starknet/provider.py +++ b/ape_starknet/provider.py @@ -21,22 +21,22 @@ InvokeSpecificInfo, ) -from ape_starknet._utils import ( +from ape_starknet.config import StarknetConfig +from ape_starknet.tokens import TokenManager +from ape_starknet.transactions import InvokeFunctionTransaction, StarknetTransaction +from ape_starknet.utils import ( ALPHA_MAINNET_WL_DEPLOY_TOKEN_KEY, PLUGIN_NAME, get_chain_id, get_virtual_machine_error, handle_client_errors, ) -from ape_starknet.config import StarknetConfig -from ape_starknet.ecosystems import Starknet -from ape_starknet.tokens import TokenManager -from ape_starknet.transactions import InvokeFunctionTransaction, StarknetTransaction +from ape_starknet.utils.basemodel import StarknetMixin DEFAULT_PORT = 8545 -class StarknetProvider(SubprocessProvider, ProviderAPI): +class StarknetProvider(SubprocessProvider, ProviderAPI, StarknetMixin): """ A Starknet provider. """ @@ -105,13 +105,13 @@ def chain_id(self) -> int: return get_chain_id(self.network.name).value @handle_client_errors - def get_balance(self, address: str) -> int: + def get_balance(self, address: AddressType) -> int: network = self.network.name if network == LOCAL_NETWORK_NAME: # Fees / balances are currently not supported in local return 0 - account = self.account_manager.containers["starknet"][address] # type: ignore + account = self.account_contracts[address] account_contract_address = account.contract_address # type: ignore return self.token_manager.get_balance(account_contract_address) @@ -124,13 +124,14 @@ def get_abi(self, address: str) -> List[Dict]: return self.get_code_and_abi(address)["abi"] # type: ignore @handle_client_errors - def get_nonce(self, address: str) -> int: + def get_nonce(self, address: AddressType) -> int: # Check if passing a public-key address of a local account - container = self.account_manager.containers["starknet"] - if address in container.public_key_addresses: # type: ignore - address = container[address].contract_address # type: ignore + if address in self.account_contracts.public_key_addresses: + contract_address = self.account_contracts.get_account(address).contract_address + if contract_address: + address = contract_address - checksum_address = self.network.ecosystem.decode_address(address) + checksum_address = self.starknet.decode_address(address) contract = self.chain_manager.contracts.instance_at(checksum_address) if not isinstance(contract, ContractInstance): @@ -174,7 +175,7 @@ def get_block(self, block_id: BlockID) -> BlockAPI: raise ValueError(f"Unsupported BlockID type '{type(block_id)}'.") block = self.starknet_client.get_block_sync(**{kwarg: block_id}) - return self.network.ecosystem.decode_block(block.dump()) + return self.starknet.decode_block(block.dump()) @handle_client_errors def send_call(self, txn: TransactionAPI) -> bytes: @@ -189,10 +190,8 @@ def send_call(self, txn: TransactionAPI) -> bytes: starknet_obj = txn.as_starknet_object() return_value = self.client.call_contract_sync(starknet_obj) - decoded_return_value = self.provider.network.ecosystem.decode_returndata( - txn.method_abi, return_value - ) - return decoded_return_value + decoded_return_value = self.starknet.decode_returndata(txn.method_abi, return_value) + return decoded_return_value # type: ignore @handle_client_errors def get_transaction(self, txn_hash: str) -> ReceiptAPI: @@ -208,11 +207,10 @@ def get_transaction(self, txn_hash: str) -> ReceiptAPI: else: raise ValueError(f"No value found for '{txn_info}'.") - ecosystem = self.provider.network.ecosystem - receipt_dict["contract_address"] = ecosystem.decode_address(txn_info.contract_address) + receipt_dict["contract_address"] = self.starknet.decode_address(txn_info.contract_address) receipt_dict["type"] = txn_type receipt_dict["events"] = [vars(e) for e in receipt_dict["events"]] - return self.network.ecosystem.decode_receipt(receipt_dict) + return self.starknet.decode_receipt(receipt_dict) @handle_client_errors def send_transaction(self, txn: TransactionAPI, token: Optional[str] = None) -> ReceiptAPI: @@ -227,32 +225,28 @@ def send_transaction(self, txn: TransactionAPI, token: Optional[str] = None) -> "Unable to send non-Starknet transaction using a Starknet provider." ) - if txn.sender: - # If using a sender, send the transaction from your sender's account contract. - container = self.account_manager.containers["starknet"] - result = container[txn.sender].send_transaction(txn, token=token) # type: ignore - return result - else: - starknet_txn = txn.as_starknet_object() - txn_info = self.starknet_client.add_transaction_sync(starknet_txn, token=token) - - error = txn_info.get("error", {}) - if error: - message = error.get("message", error) - raise ProviderError(message) - - starknet: Starknet = self.provider.network.ecosystem # type: ignore - return_value = [starknet.encode_primitive_value(v) for v in txn_info.get("result", [])] - - if return_value and isinstance(txn, InvokeFunctionTransaction): - return_value = starknet.decode_returndata(txn.method_abi, return_value) - if isinstance(return_value, (list, tuple)) and len(return_value) == 1: - return_value = return_value[0] - - txn_hash = txn_info["transaction_hash"] - receipt = self.get_transaction(txn_hash) - receipt.return_value = return_value - return receipt + starknet_txn = txn.as_starknet_object() + txn_info = self.starknet_client.add_transaction_sync(starknet_txn, token=token) + + error = txn_info.get("error", {}) + if error: + message = error.get("message", error) + raise ProviderError(message) + + # Return felts as ints and let calling context decide if hexstr is more appropriate. + return_value = [ + self.starknet.encode_primitive_value(v) if isinstance(v, str) else v + for v in txn_info.get("result", []) + ] + if return_value and isinstance(txn, InvokeFunctionTransaction): + return_value = self.starknet.decode_returndata(txn.method_abi, return_value) + if isinstance(return_value, (list, tuple)) and len(return_value) == 1: + return_value = return_value[0] + + txn_hash = txn_info["transaction_hash"] + receipt = self.get_transaction(txn_hash) + receipt.return_value = return_value + return receipt @handle_client_errors def get_contract_logs( @@ -295,7 +289,7 @@ def _deploy(self, contract_data: Union[str, Dict], *args, token: Optional[str] = {}, ) ctor_abi = ConstructorABI(**data) - transaction = self.network.ecosystem.encode_deployment( + transaction = self.starknet.encode_deployment( HexBytes(contract.serialize()), ctor_abi, *args ) wl_token = token or os.environ.get(ALPHA_MAINNET_WL_DEPLOY_TOKEN_KEY) diff --git a/ape_starknet/tokens.py b/ape_starknet/tokens.py index b50b452f..a082d4e2 100644 --- a/ape_starknet/tokens.py +++ b/ape_starknet/tokens.py @@ -1,14 +1,12 @@ -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import Dict, List, Optional from ape.contracts import ContractInstance from ape.contracts.base import ContractCall -from ape.exceptions import ContractError, ProviderError +from ape.exceptions import ContractError from ape.types import AddressType -from ape.utils import ManagerAccessMixin from ethpm_types.abi import MethodABI -if TYPE_CHECKING: - from ape_starknet.provider import StarknetProvider +from ape_starknet.utils.basemodel import StarknetMixin def missing_contract_error(token: str, contract_address: AddressType) -> ContractError: @@ -26,7 +24,7 @@ def _select_method_abi(name: str, abi: List[Dict]) -> Optional[Dict]: return None -class TokenManager(ManagerAccessMixin): +class TokenManager(StarknetMixin): TOKEN_ADDRESS_MAP = { "eth": { "testnet": "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7", @@ -37,16 +35,6 @@ class TokenManager(ManagerAccessMixin): }, } - @property - def provider(self) -> "StarknetProvider": - from ape_starknet.provider import StarknetProvider - - provider = super().provider - if not isinstance(provider, StarknetProvider): - raise ProviderError("Must be using a Starknet provider.") - - return provider - def get_balance(self, account: AddressType, token: str = "eth") -> int: contract_address = self._get_contract_address(token=token) if not contract_address: @@ -75,27 +63,12 @@ def transfer(self, sender: int, receiver: int, amount: int, token: str = "eth"): return contract = self.chain_manager.contracts.instance_at(contract_address) - if not isinstance(contract, ContractInstance): raise missing_contract_error(token, contract_address) - sender_address = self.provider.network.ecosystem.decode_address(sender) - + sender_account = self.account_contracts[sender] if "transfer" in [m.name for m in contract.contract_type.mutable_methods]: - return contract.transfer(receiver, amount, sender=sender_address) - - # Handle proxy-implementation (not yet supported in ape-core) - abi_name = "transfer" - method_abi = self._get_method_abi(abi_name, token=token) - if not method_abi: - raise ContractError(f"Contract has no method named '{abi_name}'.") - - method_abi_obj = MethodABI.parse_obj(method_abi) - transaction = self.provider.network.ecosystem.encode_transaction( - contract_address, method_abi_obj, receiver, amount - ) - account = self.account_manager.containers["starknet"][sender_address] # type: ignore - return account.send_transaction(transaction) # type: ignore + return contract.transfer(receiver, amount, sender=sender_account) def _get_contract_address(self, token: str = "eth") -> Optional[AddressType]: network = self.provider.network.name @@ -113,7 +86,7 @@ def _get_method_abi(self, method_name: str, token: str = "eth") -> Optional[Dict method_abi = MethodABI.parse_obj(implementation_abi) address_int = ContractCall(method_abi, contract_address)() - actual_contract_address = self.provider.network.ecosystem.decode_address(address_int) + actual_contract_address = self.starknet.decode_address(address_int) actual_abi = self.provider.get_abi(actual_contract_address) selected_abi = _select_method_abi(method_name, actual_abi) return selected_abi diff --git a/ape_starknet/transactions.py b/ape_starknet/transactions.py index b820c96f..ff33d929 100644 --- a/ape_starknet/transactions.py +++ b/ape_starknet/transactions.py @@ -2,7 +2,7 @@ from ape.api import ReceiptAPI, TransactionAPI from ape.contracts import ContractEvent -from ape.exceptions import ProviderError +from ape.exceptions import ProviderError, TransactionError from ape.types import AddressType, ContractLog from ape.utils import abstractmethod from ethpm_types.abi import EventABI, MethodABI @@ -18,6 +18,8 @@ from starkware.starknet.public.abi import get_selector_from_name # type: ignore from starkware.starknet.services.api.contract_class import ContractClass # type: ignore +from ape_starknet.utils.basemodel import StarknetMixin + class StarknetTransaction(TransactionAPI): """ @@ -70,7 +72,7 @@ def as_starknet_object(self) -> Deploy: ) -class InvokeFunctionTransaction(StarknetTransaction): +class InvokeFunctionTransaction(StarknetTransaction, StarknetMixin): type: TransactionType = TransactionType.INVOKE_FUNCTION method_abi: MethodABI max_fee: int = 0 @@ -81,29 +83,23 @@ class InvokeFunctionTransaction(StarknetTransaction): receiver: AddressType = Field(alias="contract_address") def as_starknet_object(self) -> InvokeFunction: - from ape_starknet.ecosystems import Starknet - from ape_starknet.provider import StarknetProvider - - ecosystem = self.provider.network.ecosystem - if ( - not isinstance(self.provider, StarknetProvider) - or not isinstance(ecosystem, Starknet) - or not self.provider.client - ): + if not self.provider.client: # **NOTE**: This check is mostly done for mypy. raise ProviderError("Must be connected to a Starknet provider.") - method_abi = self.method_abi - contract_address = ecosystem.encode_address(self.receiver) - contract_abi = self.provider.get_abi(contract_address) + contract_address_int = self.starknet.encode_address(self.receiver) + contract_type = self.chain_manager.contracts.get(self.receiver) + if not contract_type: + raise TransactionError(message=f"Unknown contract '{self.receiver}'.") - call_data = ecosystem.encode_calldata(contract_abi, method_abi, self.data) - selector = get_selector_from_name(method_abi.name) + contract_abi = [a.dict() for a in contract_type.abi] + selector = get_selector_from_name(self.method_abi.name) + encoded_call_data = self.starknet.encode_calldata(contract_abi, self.method_abi, self.data) return InvokeFunction( - contract_address=contract_address, + contract_address=contract_address_int, entry_point_selector=selector, - calldata=call_data, - signature=[], # NOTE: Signatures are not supported on signing transactions + calldata=encoded_call_data, + signature=[self.signature[1], self.signature[2]] if self.signature else [], max_fee=self.max_fee, version=self.version, ) @@ -131,7 +127,7 @@ def convert(item: Any) -> int: return call_data -class StarknetReceipt(ReceiptAPI): +class StarknetReceipt(ReceiptAPI, StarknetMixin): """ An object represented a confirmed transaction in Starknet. """ @@ -180,7 +176,7 @@ def decode_logs(self, abi: Union[EventABI, ContractEvent]) -> Iterator[ContractL log_data["block_number"] = self.block_number log_data_items.append(log_data) - yield from self.provider.network.ecosystem.decode_logs(abi, log_data_items) + yield from self.starknet.decode_logs(abi, log_data_items) __all__ = [ diff --git a/ape_starknet/_utils.py b/ape_starknet/utils/__init__.py similarity index 96% rename from ape_starknet/_utils.py rename to ape_starknet/utils/__init__.py index 16dfbfb0..a6a01079 100644 --- a/ape_starknet/_utils.py +++ b/ape_starknet/utils/__init__.py @@ -116,4 +116,7 @@ def get_virtual_machine_error(err: Exception) -> Optional[VirtualMachineError]: # Fix escaping newline issue with error message. err_msg = err_msg.replace("\\n", "").strip() + err_msg = err_msg.replace( + "Transaction was rejected with following starknet error: ", "" + ).strip() return ContractLogicError(revert_message=err_msg) diff --git a/ape_starknet/utils/basemodel.py b/ape_starknet/utils/basemodel.py new file mode 100644 index 00000000..1b4ba9d2 --- /dev/null +++ b/ape_starknet/utils/basemodel.py @@ -0,0 +1,22 @@ +from typing import TYPE_CHECKING + +from ape.utils import ManagerAccessMixin + +if TYPE_CHECKING: + from ape_starknet.accounts import StarknetAccountContracts + from ape_starknet.ecosystems import Starknet + from ape_starknet.provider import StarknetProvider + + +class StarknetMixin(ManagerAccessMixin): + @property + def starknet(self) -> "Starknet": + return self.network_manager.starknet # type: ignore + + @property + def provider(self) -> "StarknetProvider": + return super().provider # type: ignore + + @property + def account_contracts(self) -> "StarknetAccountContracts": + return self.account_manager.containers["starknet"] # type: ignore diff --git a/setup.py b/setup.py index df20ab90..e10e7043 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ "lint": [ "black>=22.3.0,<23.0", # auto-formatter and linter "mypy>=0.961,<1.0", # Static type analyzer + "types-requests", # NOTE: Needed due to mypy typeshed "flake8>=4.0.1,<5.0", # Style linter "isort>=5.10.1,<6.0", # Import sorting linter "types-pkg-resources>=0.1.3,<0.2", @@ -60,7 +61,7 @@ "click>=8.1.0,<8.2", "hexbytes>=0.2.2,<0.3", "pydantic>=1.9.0,<2.0", - "eth-ape>=0.2.8,<0.3.0", + "eth-ape==0.2.8", "ethpm-types", # Use same as `eth-ape`. "starknet.py>=0.3.2a0,<0.4", "starknet-devnet>=0.2.3,<0.3", diff --git a/tests/conftest.py b/tests/conftest.py index bc3db132..b3ade529 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,13 +9,13 @@ from ape.api import EcosystemAPI from ape.api.networks import LOCAL_NETWORK_NAME -from ape_starknet._utils import PLUGIN_NAME from ape_starknet.accounts import ( StarknetAccountContracts, StarknetEphemeralAccount, StarknetKeyfileAccount, ) from ape_starknet.provider import StarknetProvider +from ape_starknet.utils import PLUGIN_NAME # NOTE: Ensure that we don't use local paths for these ape.config.DATA_FOLDER = Path(mkdtemp()).resolve() diff --git a/tests/functional/test_accounts.py b/tests/functional/test_accounts.py index dedeac37..ca6dc747 100644 --- a/tests/functional/test_accounts.py +++ b/tests/functional/test_accounts.py @@ -2,7 +2,7 @@ from eth_utils import remove_0x_prefix from starkware.cairo.lang.vm.cairo_runner import pedersen_hash # type: ignore -from ape_starknet._utils import is_hex_address +from ape_starknet.utils import is_hex_address def test_address(existing_key_file_account, public_key): diff --git a/tests/functional/test_contract.py b/tests/functional/test_contract.py index ae595fa7..43122f6a 100644 --- a/tests/functional/test_contract.py +++ b/tests/functional/test_contract.py @@ -91,6 +91,8 @@ def test_external_call_array_outputs(contract, account): receipt = contract.get_array() assert receipt.return_value == [1, 2, 3] + +def test_external_call_array_outputs_from_account(contract, account): receipt = contract.get_array(sender=account) assert receipt.return_value == [1, 2, 3] From de56dc5bc04c17e64644585266778f72cf66f0e5 Mon Sep 17 00:00:00 2001 From: unparalleled-js Date: Fri, 17 Jun 2022 13:40:35 -0500 Subject: [PATCH 3/6] fix: rm accidental code --- ape_starknet/accounts/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/ape_starknet/accounts/__init__.py b/ape_starknet/accounts/__init__.py index 19488595..f17e540c 100644 --- a/ape_starknet/accounts/__init__.py +++ b/ape_starknet/accounts/__init__.py @@ -336,9 +336,6 @@ def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI: "data_len": len(stark_tx.calldata), } txn.data = [[account_call], stark_tx.calldata, self.nonce] - # txn.data = self.starknet.encode_calldata( - # self.contract_type.abi, execute_abi, [[account_call], stark_tx.calldata, self.nonce] - # ) txn.receiver = contract_address txn.sender = None txn.method_abi = execute_abi From 04abd787fe5d13ed3c69ee9d5ee63566278cc5a6 Mon Sep 17 00:00:00 2001 From: unparalleled-js Date: Fri, 17 Jun 2022 14:12:30 -0500 Subject: [PATCH 4/6] fix: issue with sig again --- ape_starknet/accounts/__init__.py | 19 ++++++++++--------- ape_starknet/transactions.py | 6 +++++- ape_starknet/utils/__init__.py | 14 +++++--------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/ape_starknet/accounts/__init__.py b/ape_starknet/accounts/__init__.py index f17e540c..572c9efc 100644 --- a/ape_starknet/accounts/__init__.py +++ b/ape_starknet/accounts/__init__.py @@ -10,7 +10,7 @@ from ape.contracts import ContractContainer, ContractInstance from ape.exceptions import AccountsError, ProviderError, SignatureError from ape.logging import logger -from ape.types import AddressType, SignableMessage +from ape.types import AddressType, SignableMessage, TransactionSignature from ape.utils import abstractmethod, cached_property from eth_keyfile import create_keyfile_json, decode_keyfile_json # type: ignore from eth_utils import text_if_str, to_bytes @@ -339,20 +339,21 @@ def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI: txn.receiver = contract_address txn.sender = None txn.method_abi = execute_abi - sign_result = self.sign_transaction(txn) - if not sign_result: - raise SignatureError("Failed to sign transaction.") - - r, s = sign_result - txn.signature = (0, r, s) + txn.signature = self.sign_transaction(txn) return txn - def sign_transaction(self, txn: TransactionAPI) -> Optional[ECSignature]: + def sign_transaction(self, txn: TransactionAPI) -> TransactionSignature: if not isinstance(txn, InvokeFunctionTransaction): raise AccountsError("This account can only sign Starknet transactions.") # NOTE: 'v' is not used - return self.signer.sign_transaction(txn.as_starknet_object()) + sign_result = self.signer.sign_transaction(txn.as_starknet_object()) + if not sign_result: + raise SignatureError("Failed to sign transaction.") + + r = to_bytes(sign_result[0]) + s = to_bytes(sign_result[1]) + return TransactionSignature(v=0, r=r, s=s) # type: ignore def transfer( self, diff --git a/ape_starknet/transactions.py b/ape_starknet/transactions.py index ff33d929..9b92cab2 100644 --- a/ape_starknet/transactions.py +++ b/ape_starknet/transactions.py @@ -5,6 +5,7 @@ from ape.exceptions import ProviderError, TransactionError from ape.types import AddressType, ContractLog from ape.utils import abstractmethod +from eth_utils import to_int from ethpm_types.abi import EventABI, MethodABI from hexbytes import HexBytes from pydantic import Field @@ -95,11 +96,14 @@ def as_starknet_object(self) -> InvokeFunction: contract_abi = [a.dict() for a in contract_type.abi] selector = get_selector_from_name(self.method_abi.name) encoded_call_data = self.starknet.encode_calldata(contract_abi, self.method_abi, self.data) + return InvokeFunction( contract_address=contract_address_int, entry_point_selector=selector, calldata=encoded_call_data, - signature=[self.signature[1], self.signature[2]] if self.signature else [], + signature=[to_int(self.signature.r), to_int(self.signature.s)] + if self.signature + else [], max_fee=self.max_fee, version=self.version, ) diff --git a/ape_starknet/utils/__init__.py b/ape_starknet/utils/__init__.py index a6a01079..667125cd 100644 --- a/ape_starknet/utils/__init__.py +++ b/ape_starknet/utils/__init__.py @@ -2,13 +2,7 @@ from typing import Any, Optional, Union from ape.api.networks import LOCAL_NETWORK_NAME -from ape.exceptions import ( - AddressError, - ApeException, - ContractLogicError, - ProviderError, - VirtualMachineError, -) +from ape.exceptions import ApeException, ContractLogicError, ProviderError, VirtualMachineError from ape.types import AddressType, RawAddress from eth_typing import HexAddress, HexStr from eth_utils import ( @@ -49,9 +43,11 @@ def get_chain_id(network_id: Union[str, int]) -> StarknetChainId: def to_checksum_address(address: RawAddress) -> AddressType: try: - hex_address = hexstr_if_str(to_hex, address).lower() + hex_address = hexstr_if_str(to_hex, address) except AttributeError: - raise AddressError(f"Value must be any string, instead got type {type(address)}") + raise ValueError( + f"Value must be any string, int, or bytes, instead got type {type(address)}" + ) cleaned_address = remove_0x_prefix(HexStr(hex_address)) address_hash = encode_hex(keccak(text=cleaned_address)) From 80102827d20b80cf94bc8eb458f296de329c129f Mon Sep 17 00:00:00 2001 From: unparalleled-js Date: Fri, 17 Jun 2022 14:18:54 -0500 Subject: [PATCH 5/6] chore: address my pr feedback --- ape_starknet/accounts/__init__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ape_starknet/accounts/__init__.py b/ape_starknet/accounts/__init__.py index 572c9efc..2ab3ea1f 100644 --- a/ape_starknet/accounts/__init__.py +++ b/ape_starknet/accounts/__init__.py @@ -281,7 +281,7 @@ def contract_type(self) -> Optional[ContractType]: contract_type = self.chain_manager.contracts.get(self.contract_address) if not contract_type: - raise AccountsError(f"Account '{self.contract_address}' is was expected but not found.") + raise AccountsError(f"Account '{self.contract_address}' was expected but not found.") return contract_type diff --git a/setup.py b/setup.py index e10e7043..ca8fd327 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "click>=8.1.0,<8.2", "hexbytes>=0.2.2,<0.3", "pydantic>=1.9.0,<2.0", - "eth-ape==0.2.8", + "eth-ape>=0.2.8,<0.3.0", "ethpm-types", # Use same as `eth-ape`. "starknet.py>=0.3.2a0,<0.4", "starknet-devnet>=0.2.3,<0.3", From beb1e64418e242a9f411ed266e0b433ceb686b0e Mon Sep 17 00:00:00 2001 From: unparalleled-js Date: Fri, 17 Jun 2022 14:20:28 -0500 Subject: [PATCH 6/6] fix: remote mypy attempt to fix --- ape_starknet/accounts/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ape_starknet/accounts/__init__.py b/ape_starknet/accounts/__init__.py index 2ab3ea1f..9ce6b6a8 100644 --- a/ape_starknet/accounts/__init__.py +++ b/ape_starknet/accounts/__init__.py @@ -298,7 +298,11 @@ def execute_abi(self) -> Optional[MethodABI]: if not execute_abi_ls: raise AccountsError(f"Account '{contract_address}' does not have __execute__ method.") - return execute_abi_ls[0] + abi = execute_abi_ls[0] + if not isinstance(abi, MethodABI): + raise AccountsError("ABI for '__execute__' is not a method.") + + return abi def __repr__(self): return f"<{self.__class__.__name__} {self.contract_address}>"