diff --git a/client/src/ledger_app_clients/ethereum/eip712/InputData.py b/client/src/ledger_app_clients/ethereum/eip712/InputData.py index a19ccf3e83..aab19f9251 100644 --- a/client/src/ledger_app_clients/ethereum/eip712/InputData.py +++ b/client/src/ledger_app_clients/ethereum/eip712/InputData.py @@ -4,7 +4,8 @@ import signal import sys import copy -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union +import struct from ledger_app_clients.ethereum import keychain from ledger_app_clients.ethereum.client import EthAppClient, EIP712FieldType @@ -118,69 +119,63 @@ def send_struct_def_field(typename, keyname): return (typename, type_enum, typesize, array_lvls) -def encode_integer(value, typesize): - data = bytearray() - +def encode_integer(value: Union[str | int], typesize: int) -> bytes: # Some are already represented as integers in the JSON, but most as strings if isinstance(value, str): - base = 10 - if value.startswith("0x"): - base = 16 - value = int(value, base) + value = int(value, 0) if value == 0: - data.append(0) + data = b'\x00' else: - if value < 0: # negative number, send it as unsigned - mask = 0 - for i in range(typesize): # make a mask as big as the typesize - mask = (mask << 8) | 0xff - value &= mask - while value > 0: - data.append(value & 0xff) - value >>= 8 - data.reverse() + # biggest uint type accepted by struct.pack + uint64_mask = 0xffffffffffffffff + data = struct.pack(">QQQQ", + (value >> 192) & uint64_mask, + (value >> 128) & uint64_mask, + (value >> 64) & uint64_mask, + value & uint64_mask) + data = data[len(data) - typesize:] + data = data.lstrip(b'\x00') return data -def encode_int(value, typesize): +def encode_int(value: str, typesize: int) -> bytes: return encode_integer(value, typesize) -def encode_uint(value, typesize): +def encode_uint(value: str, typesize: int) -> bytes: return encode_integer(value, typesize) -def encode_hex_string(value, size): - data = bytearray() - value = value[2:] # skip 0x - byte_idx = 0 - while byte_idx < size: - data.append(int(value[(byte_idx * 2):(byte_idx * 2 + 2)], 16)) - byte_idx += 1 - return data +def encode_hex_string(value: str, size: int) -> bytes: + assert value.startswith("0x") + value = value[2:] + if len(value) < (size * 2): + value = value.rjust(size * 2, "0") + assert len(value) == (size * 2) + return bytes.fromhex(value) -def encode_address(value, typesize): +def encode_address(value: str, typesize: int) -> bytes: return encode_hex_string(value, 20) -def encode_bool(value, typesize): - return encode_integer(value, typesize) +def encode_bool(value: str, typesize: int) -> bytes: + return encode_integer(value, 1) -def encode_string(value, typesize): +def encode_string(value: str, typesize: int) -> bytes: data = bytearray() for char in value: data.append(ord(char)) return data -def encode_bytes_fix(value, typesize): +def encode_bytes_fix(value: str, typesize: int) -> bytes: return encode_hex_string(value, typesize) -def encode_bytes_dyn(value, typesize): +def encode_bytes_dyn(value: str, typesize: int) -> bytes: # length of the value string # - the length of 0x (2) # / by the length of one byte in a hex string (2)