Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
Merge pull request #33 from unparalleled-js/feat/get-txns-by-block
Browse files Browse the repository at this point in the history
feat: implement new API methods for ape 0.3.0
  • Loading branch information
antazoey authored Jun 20, 2022
2 parents d595ea7 + f27905b commit d127ec0
Show file tree
Hide file tree
Showing 14 changed files with 436 additions and 162 deletions.
2 changes: 1 addition & 1 deletion ape_starknet/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ def cli():
"""Starknet ecosystem commands"""


cli.add_command(accounts) # type: ignore
cli.add_command(accounts)
33 changes: 21 additions & 12 deletions ape_starknet/accounts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from ape_starknet.tokens import TokenManager
from ape_starknet.transactions import InvokeFunctionTransaction
from ape_starknet.utils import PLUGIN_NAME, get_chain_id
from ape_starknet.utils import get_chain_id
from ape_starknet.utils.basemodel import StarknetMixin

APP_KEY_FILE_KEY = "ape-starknet"
Expand Down Expand Up @@ -210,7 +210,7 @@ def deploy_account(
private_key = private_key or get_random_private_key()
key_pair = KeyPair.from_private_key(private_key)

contract_address = self.provider._deploy( # type: ignore
contract_address = self.provider._deploy(
COMPILED_ACCOUNT_CONTRACT, key_pair.public_key, token=token
)
self.import_account(alias, network_name, contract_address, key_pair.private_key)
Expand Down Expand Up @@ -263,10 +263,10 @@ def address(self) -> AddressType:
@cached_property
def signer(self) -> StarkCurveSigner:
key_pair = KeyPair.from_private_key(self._get_key())
network = self.provider.network
chain_id = get_chain_id(network.name)
return StarkCurveSigner(
account_address=self.contract_address, key_pair=key_pair, chain_id=chain_id
account_address=self.contract_address,
key_pair=key_pair,
chain_id=get_chain_id(self.provider.chain_id),
)

@cached_property
Expand Down Expand Up @@ -335,7 +335,10 @@ def prepare_transaction(self, txn: TransactionAPI) -> TransactionAPI:
"data_offset": 0,
"data_len": len(stark_tx.calldata),
}
txn.data = [[account_call], stark_tx.calldata, self.nonce]
contract_type = self.chain_manager.contracts[contract_address]
txn.data = self.starknet.encode_calldata(
contract_type.abi, execute_abi, [[account_call], stark_tx.calldata, self.nonce]
)
txn.receiver = contract_address
txn.sender = None
txn.method_abi = execute_abi
Expand Down Expand Up @@ -371,16 +374,22 @@ def transfer(
raise ValueError("value is not an integer.")

if not isinstance(account, str) and hasattr(account, "contract_address"):
account = account.contract_address # type: ignore
receiver = getattr(account, "contract_address")

elif isinstance(account, str):
checksummed_address = self.starknet.decode_address(account)
receiver = self.starknet.encode_address(checksummed_address)

if not isinstance(account, int):
account = self.starknet.encode_address(account) # type: ignore
elif isinstance(account, int):
receiver = account

else:
raise TypeError(f"Unable to handle account type '{type(account)}'.")

if self.contract_address is None:
raise ValueError("Contract address cannot be None")

sender = self.starknet.encode_address(self.contract_address)
return self.token_manager.transfer(sender, account, value, **kwargs) # type: ignore
return self.token_manager.transfer(self.contract_address, receiver, value, **kwargs)

def deploy(self, contract: ContractContainer, *args, **kwargs) -> ContractInstance:
return contract.deploy(sender=self)
Expand All @@ -400,7 +409,7 @@ def check_signature( # type: ignore
data: int,
signature: Optional[ECSignature] = None, # TransactionAPI doesn't need it
) -> bool:
int_address = self.network_manager.get_ecosystem(PLUGIN_NAME).encode_address(self.address)
int_address = self.starknet.encode_address(self.address)
return verify_ecdsa_sig(int_address, data, signature)

def get_deployments(self) -> List[StarknetAccountDeployment]:
Expand Down
6 changes: 3 additions & 3 deletions ape_starknet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class ProviderConfig(PluginConfig):


class StarknetConfig(PluginConfig):
mainnet: NetworkConfig = NetworkConfig(required_confirmations=7, block_time=13) # type: ignore
testnet: NetworkConfig = NetworkConfig(required_confirmations=2, block_time=15) # type: ignore
local: NetworkConfig = NetworkConfig() # type: ignore
mainnet: NetworkConfig = NetworkConfig(required_confirmations=7, block_time=13)
testnet: NetworkConfig = NetworkConfig(required_confirmations=2, block_time=15)
local: NetworkConfig = NetworkConfig()
default_network: str = LOCAL_NETWORK_NAME
providers: ProviderConfig = ProviderConfig()
3 changes: 1 addition & 2 deletions ape_starknet/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from ape.api import ConverterAPI
from ape.types import AddressType
from eth_utils import is_checksum_address

from ape_starknet.utils import is_hex_address, to_checksum_address
from ape_starknet.utils import is_checksum_address, is_hex_address, to_checksum_address


# NOTE: This utility converter ensures that all bytes args can accept hex too
Expand Down
179 changes: 136 additions & 43 deletions ape_starknet/ecosystems.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
from typing import Any, Dict, Iterator, List, Tuple, Type, Union

from ape.api import (
BlockAPI,
BlockConsensusAPI,
BlockGasAPI,
EcosystemAPI,
ReceiptAPI,
TransactionAPI,
)
from ape.api import BlockAPI, EcosystemAPI, ReceiptAPI, TransactionAPI
from ape.types import AddressType, ContractLog, RawAddress
from eth_utils import is_0x_prefixed
from ethpm_types.abi import ConstructorABI, EventABI, MethodABI
Expand All @@ -17,6 +10,7 @@
from starknet_py.utils.data_transformer import DataTransformer # type: ignore
from starkware.starknet.definitions.fields import ContractAddressSalt # type: ignore
from starkware.starknet.definitions.transaction_type import TransactionType # type: ignore
from starkware.starknet.public.abi import get_selector_from_name # type: ignore
from starkware.starknet.public.abi_structs import identifier_manager_from_abi # type: ignore
from starkware.starknet.services.api.contract_class import ContractClass # type: ignore

Expand All @@ -37,8 +31,9 @@


class StarknetBlock(BlockAPI):
gas_data: BlockGasAPI = None # type: ignore
consensus_data: BlockConsensusAPI = None # type: ignore
"""
A block in Starknet.
"""


class Starknet(EcosystemAPI):
Expand All @@ -65,7 +60,7 @@ def decode_address(cls, raw_address: RawAddress) -> AddressType:
return to_checksum_address(raw_address)

@classmethod
def encode_address(cls, address: AddressType) -> RawAddress:
def encode_address(cls, address: AddressType) -> int:
return parse_address(address)

def serialize_transaction(self, transaction: TransactionAPI) -> bytes:
Expand Down Expand Up @@ -104,7 +99,7 @@ def encode_calldata(
full_abi = [abi.dict() if hasattr(abi, "dict") else abi for abi in full_abi]
id_manager = identifier_manager_from_abi(full_abi)
transformer = DataTransformer(method_abi.dict(), id_manager)
encoded_args: List[Any] = []
pre_encoded_args: List[Any] = []
index = 0
last_index = len(method_abi.inputs) - 1
did_process_array_during_arr_len = False
Expand All @@ -115,40 +110,64 @@ def encode_calldata(
did_process_array_during_arr_len = False
continue

array_arg = [self.encode_primitive_value(v) for v in call_arg]
encoded_args.append(array_arg)
encoded_arg = self._pre_encode_value(call_arg)
pre_encoded_args.append(encoded_arg)
elif (
input_type.name == "arr_len"
input_type.name in ("arr_len", "call_array_len")
and index < last_index
and str(method_abi.inputs[index + 1].type).endswith("*")
):
array_arg = [self.encode_primitive_value(v) for v in call_args[index + 1]]
encoded_args.append(array_arg)
did_process_array_during_arr_len = True
pre_encoded_arg = self._pre_encode_value(call_arg)

elif isinstance(call_arg, dict):
encoded_struct = {
key: self.encode_primitive_value(value) for key, value in call_arg.items()
}
encoded_args.append(encoded_struct)
if isinstance(pre_encoded_arg, int):
# 'arr_len' was provided.
array_index = index + 1
pre_encoded_array = self._pre_encode_array(call_args[array_index])
pre_encoded_args.append(pre_encoded_array)
did_process_array_during_arr_len = True
else:
pre_encoded_args.append(pre_encoded_arg)

else:
encoded_arg = self.encode_primitive_value(call_arg)
encoded_args.append(encoded_arg)
pre_encoded_args.append(self._pre_encode_value(call_arg))

index += 1

calldata, _ = transformer.from_python(*encoded_args)
return calldata
encoded_calldata, _ = transformer.from_python(*pre_encoded_args)
return encoded_calldata

def _pre_encode_value(self, value: Any) -> Any:
if isinstance(value, dict):
return self._pre_encode_struct(value)
elif isinstance(value, (list, tuple)):
return self._pre_encode_array(value)
else:
return self.encode_primitive_value(value)

def _pre_encode_array(self, array: Any) -> List:
if not isinstance(array, (list, tuple)):
# Will handle single item structs and felts.
return self._pre_encode_array([array])

def encode_primitive_value(self, value: Any) -> Any:
encoded_array = []
for item in array:
encoded_value = self._pre_encode_value(item)
encoded_array.append(encoded_value)

return encoded_array

def _pre_encode_struct(self, struct: Dict) -> Dict:
encoded_struct = {}
for key, value in struct.items():
encoded_struct[key] = self._pre_encode_value(value)

return encoded_struct

def encode_primitive_value(self, value: Any) -> int:
if isinstance(value, int):
return value

elif isinstance(value, (list, tuple)):
return [self.encode_primitive_value(v) for v in value]

if isinstance(value, str) and is_0x_prefixed(value):
elif isinstance(value, str) and is_0x_prefixed(value):
return int(value, 16)

elif isinstance(value, HexBytes):
Expand All @@ -162,23 +181,42 @@ def decode_receipt(self, data: dict) -> ReceiptAPI:
if txn_type == TransactionType.INVOKE_FUNCTION.value:
data["receiver"] = data.pop("contract_address")

max_fee = data.get("max_fee", 0) or 0
if isinstance(max_fee, str):
max_fee = int(max_fee, 16)

receiver = data.get("receiver")
if receiver:
receiver = self.decode_address(receiver)

# 'contract_address' is for deploy-txns and refers to the new contract.
contract_address = data.get("contract_address")
if contract_address:
contract_address = self.decode_address(contract_address)

block_hash = data.get("block_hash")
if block_hash:
block_hash = HexBytes(block_hash).hex()

return StarknetReceipt(
provider=data.get("provider"),
type=data["type"],
transaction_hash=data["transaction_hash"],
transaction_hash=HexBytes(data["transaction_hash"]).hex(),
status=data["status"].value,
block_number=data["block_number"],
block_hash=data["block_hash"],
block_hash=block_hash,
events=data.get("events", []),
contract_address=data.get("contract_address"),
receiver=data.get("receiver", ""),
receiver=receiver,
contract_address=contract_address,
actual_fee=data.get("actual_fee", 0),
max_fee=max_fee,
)

def decode_block(self, data: dict) -> BlockAPI:
return StarknetBlock(
number=data["block_number"],
hash=HexBytes(data["block_hash"]),
parent_hash=HexBytes(data["parent_block_hash"]),
parentHash=HexBytes(data["parent_block_hash"]),
size=len(data["transactions"]), # TODO: Figure out size
timestamp=data["timestamp"],
)
Expand All @@ -190,8 +228,9 @@ def encode_deployment(
if not salt:
salt = ContractAddressSalt.get_random_value()

constructor_args = list(args)
contract = ContractClass.deserialize(deployment_bytecode)
calldata = self.encode_calldata(contract.abi, abi, args)
calldata = self.encode_calldata(contract.abi, abi, constructor_args)
return DeployTransaction(
salt=salt,
constructor_calldata=calldata,
Expand All @@ -202,23 +241,77 @@ def encode_deployment(
def encode_transaction(
self, address: AddressType, abi: MethodABI, *args, **kwargs
) -> TransactionAPI:
# NOTE: This method only works for invoke-transactions
contract_type = self.chain_manager.contracts[address]
encoded_calldata = self.encode_calldata(contract_type.abi, abi, list(args))
return InvokeFunctionTransaction(
contract_address=address,
method_abi=abi,
calldata=args,
calldata=encoded_calldata,
sender=kwargs.get("sender"),
max_fee=kwargs.get("max_fee", 0),
)

def create_transaction(self, **kwargs) -> TransactionAPI:
txn_type = kwargs.pop("type")
txn_type = kwargs.pop("type", kwargs.pop("tx_type", ""))
txn_cls: Union[Type[InvokeFunctionTransaction], Type[DeployTransaction]]
if txn_type == TransactionType.INVOKE_FUNCTION:
invoking = txn_type == TransactionType.INVOKE_FUNCTION
if invoking:
txn_cls = InvokeFunctionTransaction
elif txn_type == TransactionType.DEPLOY:
txn_cls = DeployTransaction

return txn_cls(**kwargs)
txn_data: Dict[str, Any] = {**kwargs, "signature": None}
if "chain_id" not in txn_data and self.network_manager.active_provider:
txn_data["chain_id"] = self.provider.chain_id

# For deploy-txns, 'contract_address' is the address of the newly deployed contract.
if "contract_address" in txn_data:
txn_data["contract_address"] = self.decode_address(txn_data["contract_address"])

if not invoking:
return txn_cls(**txn_data)

""" ~ Invoke transactions ~ """

if "receiver" in txn_data:
# Model expects 'contract_address' key during serialization.
# NOTE: Deploy transactions have a different 'contract_address' and that is handled
# above before getting to the 'Invoke transactions' section.
txn_data["contract_address"] = self.decode_address(txn_data["receiver"])

if (
"max_fee" in txn_data
and not isinstance(txn_data["max_fee"], int)
and txn_data["max_fee"] is not None
):
txn_data["max_fee"] = self.encode_primitive_value(txn_data["max_fee"])

if "method_abi" not in txn_data:
contract_int = txn_data["contract_address"]
contract_str = self.decode_address(contract_int)
contract = self.chain_manager.contracts.get(contract_str)
if not contract:
raise ValueError("Unable to create transaction objects from other networks.")

selector = txn_data["entry_point_selector"]
if isinstance(selector, str):
selector = int(selector, 16)

for abi in contract.mutable_methods:
selector_to_check = get_selector_from_name(abi.name)

if selector == selector_to_check:
txn_data["method_abi"] = abi

if "calldata" in txn_data and txn_data["calldata"] is not None:
# Transactions in blocks show calldata as flattened hex-strs
# but elsewhere we expect flattened ints. Convert to ints for
# consistency and testing purposes.
encoded_calldata = [self.encode_primitive_value(v) for v in txn_data["calldata"]]
txn_data["calldata"] = encoded_calldata

return txn_cls(**txn_data)

def decode_logs(self, abi: EventABI, raw_logs: List[Dict]) -> Iterator[ContractLog]:
for index, log in enumerate(raw_logs):
Expand All @@ -230,4 +323,4 @@ def decode_logs(self, abi: EventABI, raw_logs: List[Dict]) -> Iterator[ContractL
transaction_hash=log["transaction_hash"],
block_hash=log["block_hash"],
block_number=log["block_number"],
) # type: ignore
)
Loading

0 comments on commit d127ec0

Please sign in to comment.