Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
Merge pull request #29 from ca11ab1e/style-sourcery
Browse files Browse the repository at this point in the history
style: clean-up Sourcery issues
  • Loading branch information
antazoey authored Jun 18, 2022
2 parents 577a862 + eb5dc27 commit 01c091b
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 54 deletions.
2 changes: 1 addition & 1 deletion ape_starknet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ape_starknet.utils import NETWORKS, PLUGIN_NAME

tokens = TokenManager()
network_names = [LOCAL_NETWORK_NAME] + [k for k in NETWORKS.keys()]
network_names = [LOCAL_NETWORK_NAME] + list(NETWORKS.keys())


@plugins.register(plugins.ConversionPlugin)
Expand Down
37 changes: 17 additions & 20 deletions ape_starknet/accounts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import json
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -53,9 +54,7 @@ def _key_file_paths(self) -> Iterator[Path]:

@property
def aliases(self) -> Iterator[str]:
for key in self.ephemeral_accounts.keys():
yield key

yield from self.ephemeral_accounts.keys()
for key_file in self._key_file_paths:
yield key_file.stem

Expand Down Expand Up @@ -108,10 +107,9 @@ def get_account(self, address: Union[AddressType, int]) -> "BaseStarknetAccount"

def load(self, alias: str) -> "BaseStarknetAccount":
if alias in self.ephemeral_accounts:
account = StarknetEphemeralAccount(
return StarknetEphemeralAccount(
raw_account_data=self.ephemeral_accounts[alias], account_key=alias
)
return account

return self.load_key_file_account(alias)

Expand Down Expand Up @@ -182,13 +180,11 @@ def import_account(
# Add account contract to cache
address = self.starknet.decode_address(contract_address)
if self.network_manager.active_provider and self.provider.network.explorer:
try:
# Skip errors when unable to store contract type.
with contextlib.suppress(ProviderError, BadRequest):
contract_type = self.provider.network.explorer.get_contract_type(address)
if contract_type:
self.chain_manager.contracts[address] = contract_type
except (ProviderError, BadRequest):
# Unable to store contract type.
pass

def deploy_account(
self, alias: str, private_key: Optional[int] = None, token: Optional[str] = None
Expand Down Expand Up @@ -390,11 +386,14 @@ def deploy(self, contract: ContractContainer, *args, **kwargs) -> ContractInstan
return contract.deploy(sender=self)

def get_deployment(self, network_name: str) -> Optional[StarknetAccountDeployment]:
for deployment in self.get_deployments():
if deployment.network_name in network_name:
return deployment

return None
return next(
(
deployment
for deployment in self.get_deployments()
if deployment.network_name in network_name
),
None,
)

def check_signature( # type: ignore
self,
Expand Down Expand Up @@ -493,14 +492,12 @@ def delete(self, network: str, passphrase: Optional[str] = None):
remaining_deployments = [
vars(d) for d in self.get_deployments() if d.network_name != network
]
if not remaining_deployments:
if remaining_deployments:
self.write(passphrase=passphrase, deployments=remaining_deployments)
elif click.confirm(f"Completely delete local key for account '{self.address}'?"):
# Delete entire account JSON if no more deployments.
# The user has to agree to an additional prompt since this may be very destructive.

if click.confirm(f"Completely delete local key for account '{self.address}'?"):
self.key_file_path.unlink()
else:
self.write(passphrase=passphrase, deployments=remaining_deployments)
self.key_file_path.unlink()

def sign_message(
self, msg: SignableMessage, passphrase: Optional[str] = None
Expand Down
6 changes: 2 additions & 4 deletions ape_starknet/accounts/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def create(cli_ctx, alias, network, token):
def _list(cli_ctx):
"""List your Starknet accounts"""

starknet_accounts = cast(
List[StarknetKeyfileAccount], [a for a in _get_container(cli_ctx).accounts]
)
starknet_accounts = cast(List[StarknetKeyfileAccount], list(_get_container(cli_ctx).accounts))

if len(starknet_accounts) == 0:
cli_ctx.logger.warning("No accounts found.")
Expand All @@ -78,7 +76,7 @@ def _list(cli_ctx):
key = f"Contract address ({deployment.network_name})"
output_dict[key] = deployment.contract_address

output_keys = add_padding_to_strings([k for k in output_dict.keys()])
output_keys = add_padding_to_strings(list(output_dict.keys()))
output_dict = {k: output_dict[k.rstrip()] for k in output_keys}

for k, v in output_dict.items():
Expand Down
11 changes: 4 additions & 7 deletions ape_starknet/ecosystems.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,9 @@ def encode_calldata(
did_process_array_during_arr_len = True

elif isinstance(call_arg, dict):
encoded_struct = {}
for key, value in call_arg.items():
encoded_struct[key] = self.encode_primitive_value(value)

encoded_struct = {
key: self.encode_primitive_value(value) for key, value in call_arg.items()
}
encoded_args.append(encoded_struct)

else:
Expand Down Expand Up @@ -222,8 +221,7 @@ def create_transaction(self, **kwargs) -> TransactionAPI:
return txn_cls(**kwargs)

def decode_logs(self, abi: EventABI, raw_logs: List[Dict]) -> Iterator[ContractLog]:
index = 0
for log in raw_logs:
for index, log in enumerate(raw_logs):
event_args = dict(zip([a.name for a in abi.inputs], log["data"]))
yield ContractLog( # type: ignore
name=abi.name,
Expand All @@ -233,4 +231,3 @@ def decode_logs(self, abi: EventABI, raw_logs: List[Dict]) -> Iterator[ContractL
block_hash=log["block_hash"],
block_number=log["block_number"],
) # type: ignore
index += 1
6 changes: 2 additions & 4 deletions ape_starknet/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ def estimate_gas_cost(self, txn: TransactionAPI) -> int:
if not self.client:
raise ProviderNotConnectedError()

result = self.client.estimate_fee_sync(starknet_object)
return result
return self.client.estimate_fee_sync(starknet_object)

@property
def gas_price(self) -> int:
Expand Down Expand Up @@ -190,8 +189,7 @@ def send_call(self, txn: TransactionAPI) -> bytes:

starknet_obj = txn.as_starknet_object()
return_value = self.client.call_contract_sync(starknet_obj)
decoded_return_value = self.starknet.decode_returndata(txn.method_abi, return_value)
return decoded_return_value # type: ignore
return self.starknet.decode_returndata(txn.method_abi, return_value) # type: ignore

@handle_client_errors
def get_transaction(self, txn_hash: str) -> ReceiptAPI:
Expand Down
6 changes: 2 additions & 4 deletions ape_starknet/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def get_balance(self, account: AddressType, token: str = "eth") -> int:
raise ContractError(f"Contract has no method '{abi_name}'.")

method_abi_obj = MethodABI.parse_obj(method_abi)
balance = ContractCall(method_abi_obj, contract_address)()
return balance
return ContractCall(method_abi_obj, contract_address)()

def transfer(self, sender: int, receiver: int, amount: int, token: str = "eth"):
contract_address = self._get_contract_address(token=token)
Expand Down Expand Up @@ -88,5 +87,4 @@ def _get_method_abi(self, method_name: str, token: str = "eth") -> Optional[Dict
address_int = ContractCall(method_abi, contract_address)()
actual_contract_address = self.starknet.decode_address(address_int)
actual_abi = self.provider.get_abi(actual_contract_address)
selected_abi = _select_method_abi(method_name, actual_abi)
return selected_abi
return _select_method_abi(method_name, actual_abi)
10 changes: 6 additions & 4 deletions ape_starknet/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,12 @@ def decode_logs(self, abi: Union[EventABI, ContractEvent]) -> Iterator[ContractL

log_data_items: List[Dict] = []
for log in self.logs:
log_data = {**log}
log_data["block_hash"] = self.block_hash
log_data["transaction_hash"] = self.txn_hash
log_data["block_number"] = self.block_number
log_data = {
**log,
"block_hash": self.block_hash,
"transaction_hash": self.txn_hash,
"block_number": self.block_number,
}
log_data_items.append(log_data)

yield from self.starknet.decode_logs(abi, log_data_items)
Expand Down
12 changes: 4 additions & 8 deletions ape_starknet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@ def get_chain_id(network_id: Union[str, int]) -> StarknetChainId:
def to_checksum_address(address: RawAddress) -> AddressType:
try:
hex_address = hexstr_if_str(to_hex, address)
except AttributeError:
raise ValueError(
f"Value must be any string, int, or bytes, instead got type {type(address)}"
)
except AttributeError as exc:
msg = f"Value must be any string, int, or bytes, instead got type {type(address)}"
raise ValueError(msg) from exc

cleaned_address = remove_0x_prefix(HexStr(hex_address))
address_hash = encode_hex(keccak(text=cleaned_address))
Expand All @@ -66,10 +65,7 @@ def to_checksum_address(address: RawAddress) -> AddressType:


def is_hex_address(value: Any) -> bool:
if not is_text(value):
return False

return _HEX_ADDRESS_REG_EXP.fullmatch(value) is not None
return _HEX_ADDRESS_REG_EXP.fullmatch(value) is not None if is_text(value) else False


def handle_client_errors(f):
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_unsigned_contract_transaction(contract, account, initial_balance):
def test_decode_logs(contract, account, ecosystem):
increase_amount = 9933
receipt = contract.increase_balance(account.address, increase_amount, sender=account)
logs = [log for log in receipt.decode_logs(contract.balance_increased)]
logs = list(receipt.decode_logs(contract.balance_increased))
assert len(logs) == 1
assert logs[0].amount == increase_amount

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/test_ecosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ def test_encode_and_decode_address(value, ecosystem):


def test_decode_logs(ecosystem, event_abi, raw_logs):
actual = [log for log in ecosystem.decode_logs(event_abi, raw_logs)]
actual = list(ecosystem.decode_logs(event_abi, raw_logs))
assert len(actual) == 1
assert actual[0].amount == "4321"

0 comments on commit 01c091b

Please sign in to comment.