From 625cc5abcff86b4d9256a329f76de121b3837211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michal=20P=C5=99evr=C3=A1til?= Date: Sun, 11 Feb 2024 21:39:38 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=97=EF=B8=8F=20Add=20new=20experimental?= =?UTF-8?q?=20ABI=20coder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- wake/deployment/__init__.py | 2 +- wake/development/core.py | 181 ++++++++++++++++++++++++++++++++++++ wake/testing/__init__.py | 2 +- 3 files changed, 183 insertions(+), 2 deletions(-) diff --git a/wake/deployment/__init__.py b/wake/deployment/__init__.py index 19ae2be54..3ea6b2fa1 100644 --- a/wake/deployment/__init__.py +++ b/wake/deployment/__init__.py @@ -1,6 +1,6 @@ from rich import print -from wake.development.core import Abi, Account, Address, Eip712Domain, Wei +from wake.development.core import Abi, Account, Address, Eip712Domain, Wei, abi from wake.development.internal import UnknownEvent from wake.development.primitive_types import * from wake.development.transactions import ( diff --git a/wake/development/core.py b/wake/development/core.py index bb58a05ca..f312e4ef7 100644 --- a/wake/development/core.py +++ b/wake/development/core.py @@ -158,6 +158,187 @@ def fix_library_abi(args: List[Dict[str, Any]]) -> List[Dict[str, Any]]: return ret +class abi: + @classmethod + def _normalize_input(cls, arguments: Iterable) -> List: + ret = [] + for arg in arguments: + if isinstance(arg, Address): + ret.append(str(arg)) + elif isinstance(arg, Account): + ret.append(str(arg.address)) + elif isinstance(arg, (list, tuple)): + ret.append(cls._normalize_input(arg)) + elif dataclasses.is_dataclass(arg): + ret.append(cls._normalize_input(dataclasses.astuple(arg))) + else: + ret.append(arg) + return ret + + @classmethod + def _normalize_output(cls, types: Sequence[Type], arguments: Sequence) -> Tuple: + ret = [] + assert len(types) == len(arguments) + for t, arg in zip(types, arguments): + origin = get_origin(t) + + if isinstance(origin, type) and issubclass(origin, list): + ret.append( + origin(cls._normalize_output([get_args(t)[0]] * len(arg), arg)) + ) + elif isinstance(t, type) and issubclass(t, int): + ret.append(t(arg)) + elif isinstance(t, type) and issubclass(t, (bytes, bytearray)): + ret.append(t(arg)) + elif issubclass(t, Enum): + ret.append(t(arg)) + elif issubclass(t, (Account, Address)): + ret.append(t(arg)) + elif dataclasses.is_dataclass(t): + assert isinstance(arg, tuple) + resolved_types = get_type_hints( + t # pyright: ignore reportGeneralTypeIssues + ) + field_types = [ + resolved_types[field.name] + for field in dataclasses.fields(t) + if field.init + ] + assert len(arg) == len(field_types) + ret.append(t(*cls._normalize_output(field_types, arg))) + else: + # int, str, bool does not need to be normalized + ret.append(arg) + + return tuple(ret) + + @classmethod + def _types_from_type(cls, t: Type) -> str: + origin = get_origin(t) + + if isinstance(origin, type) and issubclass(origin, list): + if hasattr(origin, "length"): + return f"{cls._types_from_type(get_args(t)[0])}[{getattr(origin, 'length')}]" + else: + return f"{cls._types_from_type(get_args(t)[0])}[]" + elif isinstance(t, type) and issubclass(t, Integer): + if t.min == 0: + bits = math.ceil(math.log2(t.max + 1)) + return f"uint{bits}" + else: + bits = math.ceil(math.log2(t.max - t.min + 1)) + return f"int{bits}" + elif isinstance(t, type) and issubclass(t, FixedSizeBytes): + return f"bytes{t.length}" + elif t is int: + # fallback for int used directly + return "int256" + elif t is bytes or t is bytearray: + return "bytes" + elif t is str: + return "string" + elif issubclass(t, Enum): + return "uint8" + elif t is bool: + return "bool" + elif issubclass(t, (Account, Address)): + return "address" + elif dataclasses.is_dataclass(t): + hints = get_type_hints( + t, # pyright: ignore reportGeneralTypeIssues + include_extras=True, + ) + return f"({','.join(cls._types_from_type(hints[f.name]) for f in dataclasses.fields(t))})" + else: + raise ValueError(f"Unsupported type {t}") + + @classmethod + def _types_from_args(cls, args: Iterable) -> str: + if isinstance(args, tuple): + return f"({','.join(cls._types_from_args(arg) for arg in args)})" + elif isinstance(args, list): + for arg in args: + try: + arg_type = cls._types_from_args(arg) + if hasattr(args, "length"): + return f"{arg_type}[{getattr(args, 'length')}]" + else: + return f"{arg_type}[]" + except ValueError: + pass + + raise ValueError("Could not determine type of list") + elif isinstance(args, (Address, Account)): + return "address" + elif isinstance(args, str): + return "string" + elif isinstance(args, (bytes, bytearray)): + if hasattr(args, "length"): + return f"bytes{getattr(args, 'length')}" + else: + return "bytes" + elif isinstance(args, bool): + return "bool" + elif callable(args): + return "function" + elif isinstance(args, IntEnum): + return "uint8" + elif dataclasses.is_dataclass(args): + return cls._types_from_type(type(args)) + elif isinstance(args, int): + if not hasattr(args, "min") or not hasattr(args, "max"): + raise ValueError( + "Integer cannot be directly ABI-encoded. Use typecast to intN or uintN instead." + ) + min = getattr(args, "min") + max = getattr(args, "max") + if min == 0: + bits = math.ceil(math.log2(max + 1)) + return f"uint{bits}" + else: + bits = math.ceil(math.log2(max - min + 1)) + return f"int{bits}" + else: + raise ValueError(f"Unsupported type {type(args)}") + + @classmethod + def encode(cls, *args) -> bytes: + return eth_abi.abi.encode( + [cls._types_from_args(a) for a in args], cls._normalize_input(args) + ) + + @classmethod + def encode_packed(cls, *args) -> bytes: + return eth_abi.packed.encode_packed( + [cls._types_from_args(a) for a in args], cls._normalize_input(args) + ) + + @classmethod + def encode_with_selector(cls, selector: bytes, *args) -> bytes: + return selector + cls.encode(*args) + + @classmethod + def encode_with_signature(cls, signature: str, *args) -> bytes: + selector = keccak.new(data=signature.encode("utf-8"), digest_bits=256).digest()[ + :4 + ] + return cls.encode_with_selector(selector, *args) + + @classmethod + def encode_call(cls, func: Callable, args: Iterable) -> bytes: + selector = func.selector + return cls.encode_with_selector(selector, *args) + + @classmethod + def decode(cls, data: bytes, types: Sequence[Type]) -> Any: + ret = cls._normalize_output( + types, eth_abi.abi.decode([cls._types_from_type(t) for t in types], data) + ) + if len(ret) == 1: + return ret[0] + return ret + + class Abi: @staticmethod def _normalize_input(arguments: Iterable) -> List: diff --git a/wake/testing/__init__.py b/wake/testing/__init__.py index d0585cae4..e586094cf 100644 --- a/wake/testing/__init__.py +++ b/wake/testing/__init__.py @@ -1,6 +1,6 @@ from rich import print -from wake.development.core import Abi, Account, Address, Eip712Domain, Wei +from wake.development.core import Abi, Account, Address, Eip712Domain, Wei, abi from wake.development.internal import UnknownEvent from wake.development.primitive_types import * from wake.development.transactions import (