From 95b0eb24d1585c241ca513dd085857520109e73b Mon Sep 17 00:00:00 2001 From: Francois Beutin Date: Wed, 27 Mar 2024 11:16:19 +0100 Subject: [PATCH] Use sync or async exchange calls in ethereum client --- .../src/ledger_app_clients/ethereum/client.py | 152 +++++++++++------- tests/ragger/test_blind_sign.py | 10 +- tests/ragger/test_domain_name.py | 67 +++----- tests/ragger/test_eip712.py | 22 ++- tests/ragger/test_get_address.py | 7 +- tests/ragger/test_nft.py | 19 +-- 6 files changed, 136 insertions(+), 141 deletions(-) diff --git a/client/src/ledger_app_clients/ethereum/client.py b/client/src/ledger_app_clients/ethereum/client.py index aa2b27b9ff..079c89320a 100644 --- a/client/src/ledger_app_clients/ethereum/client.py +++ b/client/src/ledger_app_clients/ethereum/client.py @@ -1,4 +1,5 @@ import rlp +from contextlib import contextmanager from enum import IntEnum from ragger.backend import BackendInterface from ragger.utils import RAPDU @@ -40,61 +41,88 @@ def __init__(self, client: BackendInterface): self._client = client self._cmd_builder = CommandBuilder() - def _send(self, payload: bytes): - return self._client.exchange_async_raw(payload) + @contextmanager + def _exchange_async(self, payload: bytes): + with self._client.exchange_async_raw(payload) as response: + yield response + + def _exchange(self, payload: bytes): + return self._client.exchange_raw(payload) def response(self) -> Optional[RAPDU]: return self._client.last_async_response + @contextmanager def eip712_send_struct_def_struct_name(self, name: str): - return self._send(self._cmd_builder.eip712_send_struct_def_struct_name(name)) + with self._exchange_async(self._cmd_builder.eip712_send_struct_def_struct_name(name)) as response: + yield response + @contextmanager def eip712_send_struct_def_struct_field(self, field_type: EIP712FieldType, type_name: str, type_size: int, array_levels: list, key_name: str): - return self._send(self._cmd_builder.eip712_send_struct_def_struct_field( + with self._exchange_async(self._cmd_builder.eip712_send_struct_def_struct_field( field_type, type_name, type_size, array_levels, - key_name)) + key_name)) as response: + yield response + @contextmanager def eip712_send_struct_impl_root_struct(self, name: str): - return self._send(self._cmd_builder.eip712_send_struct_impl_root_struct(name)) + with self._exchange_async(self._cmd_builder.eip712_send_struct_impl_root_struct(name)) as response: + yield response + @contextmanager def eip712_send_struct_impl_array(self, size: int): - return self._send(self._cmd_builder.eip712_send_struct_impl_array(size)) + with self._exchange_async(self._cmd_builder.eip712_send_struct_impl_array(size)) as response: + yield response + @contextmanager def eip712_send_struct_impl_struct_field(self, raw_value: bytes): chunks = self._cmd_builder.eip712_send_struct_impl_struct_field(bytearray(raw_value)) for chunk in chunks[:-1]: - with self._send(chunk): - pass - return self._send(chunks[-1]) + self._exchange(chunk) + with self._exchange_async(chunks[-1]) as response: + yield response + @contextmanager def eip712_sign_new(self, bip32_path: str): - return self._send(self._cmd_builder.eip712_sign_new(bip32_path)) + with self._exchange_async(self._cmd_builder.eip712_sign_new(bip32_path)) as response: + yield response + @contextmanager def eip712_sign_legacy(self, bip32_path: str, domain_hash: bytes, message_hash: bytes): - return self._send(self._cmd_builder.eip712_sign_legacy(bip32_path, - domain_hash, - message_hash)) + with self._exchange_async(self._cmd_builder.eip712_sign_legacy(bip32_path, + domain_hash, + message_hash)) as response: + yield response + @contextmanager def eip712_filtering_activate(self): - return self._send(self._cmd_builder.eip712_filtering_activate()) + with self._exchange_async(self._cmd_builder.eip712_filtering_activate()) as response: + yield response + @contextmanager def eip712_filtering_message_info(self, name: str, filters_count: int, sig: bytes): - return self._send(self._cmd_builder.eip712_filtering_message_info(name, filters_count, sig)) + with self._exchange_async(self._cmd_builder.eip712_filtering_message_info(name, + filters_count, + sig)) as response: + yield response + @contextmanager def eip712_filtering_show_field(self, name: str, sig: bytes): - return self._send(self._cmd_builder.eip712_filtering_show_field(name, sig)) + with self._exchange_async(self._cmd_builder.eip712_filtering_show_field(name, sig)) as response: + yield response + @contextmanager def sign(self, bip32_path: str, tx_params: dict): @@ -111,24 +139,26 @@ def sign(self, tx = prefix + rlp.encode(decoded + suffix) chunks = self._cmd_builder.sign(bip32_path, tx, suffix) for chunk in chunks[:-1]: - with self._send(chunk): - pass - return self._send(chunks[-1]) + self._exchange(chunk) + with self._exchange_async(chunks[-1]) as response: + yield response def get_challenge(self): - return self._send(self._cmd_builder.get_challenge()) + return self._exchange(self._cmd_builder.get_challenge()) + @contextmanager def get_public_addr(self, display: bool = True, chaincode: bool = False, bip32_path: str = "m/44'/60'/0'/0/0", - chain_id: Optional[int] = None): - return self._send(self._cmd_builder.get_public_addr(display, - chaincode, - bip32_path, - chain_id)) + chain_id: Optional[int] = None) -> RAPDU: + with self._exchange_async(self._cmd_builder.get_public_addr(display, + chaincode, + bip32_path, + chain_id)) as response: + yield response - def provide_domain_name(self, challenge: int, name: str, addr: bytes): + def provide_domain_name(self, challenge: int, name: str, addr: bytes) -> RAPDU: payload = format_tlv(DomainNameTag.STRUCTURE_TYPE, 3) # TrustedDomainName payload += format_tlv(DomainNameTag.STRUCTURE_VERSION, 1) payload += format_tlv(DomainNameTag.SIGNER_KEY_ID, 0) # test key @@ -142,9 +172,9 @@ def provide_domain_name(self, challenge: int, name: str, addr: bytes): chunks = self._cmd_builder.provide_domain_name(payload) for chunk in chunks[:-1]: - with self._send(chunk): + with self._exchange(chunk): pass - return self._send(chunks[-1]) + return self._exchange(chunks[-1]) def set_plugin(self, plugin_name: str, @@ -155,7 +185,7 @@ def set_plugin(self, version: int = 1, key_id: int = 2, algo_id: int = 1, - sig: Optional[bytes] = None): + sig: Optional[bytes] = None) -> RAPDU: if sig is None: # Temporarily get a command with an empty signature to extract the payload and # compute the signature on it @@ -170,15 +200,15 @@ def set_plugin(self, bytes()) # skip APDU header & empty sig sig = sign_data(Key.SET_PLUGIN, tmp[5:-1]) - return self._send(self._cmd_builder.set_plugin(type_, - version, - plugin_name, - contract_addr, - selector, - chain_id, - key_id, - algo_id, - sig)) + return self._exchange(self._cmd_builder.set_plugin(type_, + version, + plugin_name, + contract_addr, + selector, + chain_id, + key_id, + algo_id, + sig)) def provide_nft_metadata(self, collection: str, @@ -188,7 +218,7 @@ def provide_nft_metadata(self, version: int = 1, key_id: int = 1, algo_id: int = 1, - sig: Optional[bytes] = None): + sig: Optional[bytes] = None) -> RAPDU: if sig is None: # Temporarily get a command with an empty signature to extract the payload and # compute the signature on it @@ -202,20 +232,20 @@ def provide_nft_metadata(self, bytes()) # skip APDU header & empty sig sig = sign_data(Key.NFT, tmp[5:-1]) - return self._send(self._cmd_builder.provide_nft_information(type_, - version, - collection, - addr, - chain_id, - key_id, - algo_id, - sig)) + return self._exchange(self._cmd_builder.provide_nft_information(type_, + version, + collection, + addr, + chain_id, + key_id, + algo_id, + sig)) def set_external_plugin(self, plugin_name: str, contract_address: bytes, method_selelector: bytes, - sig: Optional[bytes] = None): + sig: Optional[bytes] = None) -> RAPDU: if sig is None: # Temporarily get a command with an empty signature to extract the payload and # compute the signature on it @@ -223,21 +253,25 @@ def set_external_plugin(self, # skip APDU header & empty sig sig = sign_data(Key.CAL, tmp[5:]) - return self._send(self._cmd_builder.set_external_plugin(plugin_name, contract_address, method_selelector, sig)) + return self._exchange(self._cmd_builder.set_external_plugin(plugin_name, + contract_address, + method_selelector, + sig)) + @contextmanager def personal_sign(self, path: str, msg: bytes): chunks = self._cmd_builder.personal_sign(path, msg) for chunk in chunks[:-1]: - with self._send(chunk): - pass - return self._send(chunks[-1]) + self._exchange(chunk) + with self._exchange_async(chunks[-1]) as response: + yield response def provide_token_metadata(self, ticker: str, addr: bytes, decimals: int, chain_id: int, - sig: Optional[bytes] = None): + sig: Optional[bytes] = None) -> RAPDU: if sig is None: # Temporarily get a command with an empty signature to extract the payload and # compute the signature on it @@ -248,8 +282,8 @@ def provide_token_metadata(self, bytes()) # skip APDU header & empty sig sig = sign_data(Key.CAL, tmp[6:]) - return self._send(self._cmd_builder.provide_erc20_token_information(ticker, - addr, - decimals, - chain_id, - sig)) + return self._exchange(self._cmd_builder.provide_erc20_token_information(ticker, + addr, + decimals, + chain_id, + sig)) diff --git a/tests/ragger/test_blind_sign.py b/tests/ragger/test_blind_sign.py index a4ed7679c2..f3a9acb85a 100644 --- a/tests/ragger/test_blind_sign.py +++ b/tests/ragger/test_blind_sign.py @@ -1,9 +1,10 @@ import json +import pytest from ragger.backend import BackendInterface from ragger.firmware import Firmware from ragger.navigator import Navigator, NavInsID from ragger.error import ExceptionRAPDU -from ledger_app_clients.ethereum.client import EthAppClient +from ledger_app_clients.ethereum.client import EthAppClient, StatusWord from web3 import Web3 from constants import ROOT_SNAPSHOT_PATH, ABIS_FOLDER @@ -35,13 +36,10 @@ def test_blind_sign(firmware: Firmware, "data": data, "chainId": 1 } - try: + with pytest.raises(ExceptionRAPDU) as e: with app_client.sign("m/44'/60'/0'/0/0", tx_params): pass - except ExceptionRAPDU: - pass - else: - assert False + assert e.value.status == StatusWord.INVALID_DATA moves = list() if firmware.device.startswith("nano"): diff --git a/tests/ragger/test_domain_name.py b/tests/ragger/test_domain_name.py index 71eea7bdb6..90f5419b05 100644 --- a/tests/ragger/test_domain_name.py +++ b/tests/ragger/test_domain_name.py @@ -33,9 +33,8 @@ def verbose(request) -> bool: def common(app_client: EthAppClient) -> int: if app_client._client.firmware.device == "nanos": pytest.skip("Not supported on LNS") - with app_client.get_challenge(): - pass - return ResponseParser.challenge(app_client.response().data) + challenge = app_client.get_challenge() + return ResponseParser.challenge(challenge.data) def test_send_fund(firmware: Firmware, @@ -49,8 +48,7 @@ def test_send_fund(firmware: Firmware, if verbose: settings_toggle(firmware, navigator, [SettingID.VERBOSE_ENS]) - with app_client.provide_domain_name(challenge, NAME, ADDR): - pass + app_client.provide_domain_name(challenge, NAME, ADDR) with app_client.sign(BIP32_PATH, { @@ -83,13 +81,9 @@ def test_send_fund_wrong_challenge(firmware: Firmware, app_client = EthAppClient(backend) challenge = common(app_client) - try: - with app_client.provide_domain_name(~challenge & 0xffffffff, NAME, ADDR): - pass - except ExceptionRAPDU as e: - assert e.status == StatusWord.INVALID_DATA - else: - assert False # An exception should have been raised + with pytest.raises(ExceptionRAPDU) as e: + app_client.provide_domain_name(~challenge & 0xffffffff, NAME, ADDR) + assert e.value.status == StatusWord.INVALID_DATA def test_send_fund_wrong_addr(firmware: Firmware, @@ -99,8 +93,7 @@ def test_send_fund_wrong_addr(firmware: Firmware, app_client = EthAppClient(backend) challenge = common(app_client) - with app_client.provide_domain_name(challenge, NAME, ADDR): - pass + app_client.provide_domain_name(challenge, NAME, ADDR) addr = bytearray(ADDR) addr.reverse() @@ -133,8 +126,7 @@ def test_send_fund_non_mainnet(firmware: Firmware, app_client = EthAppClient(backend) challenge = common(app_client) - with app_client.provide_domain_name(challenge, NAME, ADDR): - pass + app_client.provide_domain_name(challenge, NAME, ADDR) with app_client.sign(BIP32_PATH, { @@ -164,8 +156,7 @@ def test_send_fund_unknown_chain(firmware: Firmware, app_client = EthAppClient(backend) challenge = common(app_client) - with app_client.provide_domain_name(challenge, NAME, ADDR): - pass + app_client.provide_domain_name(challenge, NAME, ADDR) with app_client.sign(BIP32_PATH, { @@ -194,13 +185,9 @@ def test_send_fund_domain_too_long(firmware: Firmware, app_client = EthAppClient(backend) challenge = common(app_client) - try: - with app_client.provide_domain_name(challenge, "ledger" + "0"*25 + ".eth", ADDR): - pass - except ExceptionRAPDU as e: - assert e.status == StatusWord.INVALID_DATA - else: - assert False # An exception should have been raised + with pytest.raises(ExceptionRAPDU) as e: + app_client.provide_domain_name(challenge, "ledger" + "0"*25 + ".eth", ADDR) + assert e.value.status == StatusWord.INVALID_DATA def test_send_fund_domain_invalid_character(firmware: Firmware, @@ -209,13 +196,9 @@ def test_send_fund_domain_invalid_character(firmware: Firmware, app_client = EthAppClient(backend) challenge = common(app_client) - try: - with app_client.provide_domain_name(challenge, "l\xe8dger.eth", ADDR): - pass - except ExceptionRAPDU as e: - assert e.status == StatusWord.INVALID_DATA - else: - assert False # An exception should have been raised + with pytest.raises(ExceptionRAPDU) as e: + app_client.provide_domain_name(challenge, "l\xe8dger.eth", ADDR) + assert e.value.status == StatusWord.INVALID_DATA def test_send_fund_uppercase(firmware: Firmware, @@ -224,13 +207,9 @@ def test_send_fund_uppercase(firmware: Firmware, app_client = EthAppClient(backend) challenge = common(app_client) - try: - with app_client.provide_domain_name(challenge, NAME.upper(), ADDR): - pass - except ExceptionRAPDU as e: - assert e.status == StatusWord.INVALID_DATA - else: - assert False # An exception should have been raised + with pytest.raises(ExceptionRAPDU) as e: + app_client.provide_domain_name(challenge, NAME.upper(), ADDR) + assert e.value.status == StatusWord.INVALID_DATA def test_send_fund_domain_non_ens(firmware: Firmware, @@ -239,10 +218,6 @@ def test_send_fund_domain_non_ens(firmware: Firmware, app_client = EthAppClient(backend) challenge = common(app_client) - try: - with app_client.provide_domain_name(challenge, "ledger.hte", ADDR): - pass - except ExceptionRAPDU as e: - assert e.status == StatusWord.INVALID_DATA - else: - assert False # An exception should have been raised + with pytest.raises(ExceptionRAPDU) as e: + app_client.provide_domain_name(challenge, "ledger.hte", ADDR) + assert e.value.status == StatusWord.INVALID_DATA diff --git a/tests/ragger/test_eip712.py b/tests/ragger/test_eip712.py index b4551885e1..1aedc00247 100644 --- a/tests/ragger/test_eip712.py +++ b/tests/ragger/test_eip712.py @@ -204,19 +204,15 @@ def test_eip712_address_substitution(firmware: Firmware, with open("%s/address_substitution.json" % (eip712_json_path())) as file: data = json.load(file) - with app_client.provide_token_metadata("DAI", - bytes.fromhex(data["message"]["token"][2:]), - 18, - 1): - pass - - with app_client.get_challenge(): - pass - challenge = ResponseParser.challenge(app_client.response().data) - with app_client.provide_domain_name(challenge, - "vitalik.eth", - bytes.fromhex(data["message"]["to"][2:])): - pass + app_client.provide_token_metadata("DAI", + bytes.fromhex(data["message"]["token"][2:]), + 18, + 1) + + challenge = ResponseParser.challenge(app_client.get_challenge().data) + app_client.provide_domain_name(challenge, + "vitalik.eth", + bytes.fromhex(data["message"]["to"][2:])) if verbose: settings_toggle(firmware, navigator, [SettingID.VERBOSE_EIP712]) diff --git a/tests/ragger/test_get_address.py b/tests/ragger/test_get_address.py index 60cf4bb415..de933c2877 100644 --- a/tests/ragger/test_get_address.py +++ b/tests/ragger/test_get_address.py @@ -52,15 +52,12 @@ def test_get_pk_rejected(firmware: Firmware, navigator: Navigator): app_client = EthAppClient(backend) - try: + with pytest.raises(ExceptionRAPDU) as e: with app_client.get_public_addr(): navigator.navigate_and_compare(ROOT_SNAPSHOT_PATH, "get_pk_rejected", get_moves(firmware, navigator, reject=True)) - except ExceptionRAPDU as e: - assert e.status == StatusWord.CONDITION_NOT_SATISFIED - else: - assert False # An exception should have been raised + assert e.value.status == StatusWord.CONDITION_NOT_SATISFIED def test_get_pk(firmware: Firmware, diff --git a/tests/ragger/test_nft.py b/tests/ragger/test_nft.py index fd39871222..9046e39d48 100644 --- a/tests/ragger/test_nft.py +++ b/tests/ragger/test_nft.py @@ -96,13 +96,11 @@ def common_test_nft(fw: Firmware, _, DEVICE_ADDR, _ = ResponseParser.pk_addr(app_client.response().data) data = collec.contract.encodeABI(action.fn_name, action.fn_args) - with app_client.set_plugin(plugin_name, - collec.addr, - get_selector_from_data(data), - collec.chain_id): - pass - with app_client.provide_nft_metadata(collec.name, collec.addr, collec.chain_id): - pass + app_client.set_plugin(plugin_name, + collec.addr, + get_selector_from_data(data), + collec.chain_id) + app_client.provide_nft_metadata(collec.name, collec.addr, collec.chain_id) tx_params = { "nonce": NONCE, "gasPrice": Web3.to_wei(GAS_PRICE, "gwei"), @@ -133,12 +131,9 @@ def common_test_nft_reject(test_fn: Callable, nav: Navigator, collec: NFTCollection, action: Action): - try: + with pytest.raises(ExceptionRAPDU) as e: test_fn(fw, back, nav, collec, action, True) - except ExceptionRAPDU as e: - assert e.status == StatusWord.CONDITION_NOT_SATISFIED - else: - assert False # An exception should have been raised + assert e.value.status == StatusWord.CONDITION_NOT_SATISFIED # ERC-721