Skip to content

Commit

Permalink
⚗️ Add new experimental ABI coder
Browse files Browse the repository at this point in the history
  • Loading branch information
michprev committed Feb 12, 2024
1 parent 4e7b4ab commit 625cc5a
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 2 deletions.
2 changes: 1 addition & 1 deletion wake/deployment/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from rich import print

from wake.development.core import Abi, Account, Address, Eip712Domain, Wei
from wake.development.core import Abi, Account, Address, Eip712Domain, Wei, abi
from wake.development.internal import UnknownEvent
from wake.development.primitive_types import *
from wake.development.transactions import (
Expand Down
181 changes: 181 additions & 0 deletions wake/development/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,187 @@ def fix_library_abi(args: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return ret


class abi:
@classmethod
def _normalize_input(cls, arguments: Iterable) -> List:
ret = []
for arg in arguments:
if isinstance(arg, Address):
ret.append(str(arg))
elif isinstance(arg, Account):
ret.append(str(arg.address))
elif isinstance(arg, (list, tuple)):
ret.append(cls._normalize_input(arg))
elif dataclasses.is_dataclass(arg):
ret.append(cls._normalize_input(dataclasses.astuple(arg)))
else:
ret.append(arg)
return ret

@classmethod
def _normalize_output(cls, types: Sequence[Type], arguments: Sequence) -> Tuple:
ret = []
assert len(types) == len(arguments)
for t, arg in zip(types, arguments):
origin = get_origin(t)

if isinstance(origin, type) and issubclass(origin, list):
ret.append(
origin(cls._normalize_output([get_args(t)[0]] * len(arg), arg))
)
elif isinstance(t, type) and issubclass(t, int):
ret.append(t(arg))
elif isinstance(t, type) and issubclass(t, (bytes, bytearray)):
ret.append(t(arg))
elif issubclass(t, Enum):
ret.append(t(arg))
elif issubclass(t, (Account, Address)):
ret.append(t(arg))
elif dataclasses.is_dataclass(t):
assert isinstance(arg, tuple)
resolved_types = get_type_hints(
t # pyright: ignore reportGeneralTypeIssues
)
field_types = [
resolved_types[field.name]
for field in dataclasses.fields(t)
if field.init
]
assert len(arg) == len(field_types)
ret.append(t(*cls._normalize_output(field_types, arg)))
else:
# int, str, bool does not need to be normalized
ret.append(arg)

return tuple(ret)

@classmethod
def _types_from_type(cls, t: Type) -> str:
origin = get_origin(t)

if isinstance(origin, type) and issubclass(origin, list):
if hasattr(origin, "length"):
return f"{cls._types_from_type(get_args(t)[0])}[{getattr(origin, 'length')}]"
else:
return f"{cls._types_from_type(get_args(t)[0])}[]"
elif isinstance(t, type) and issubclass(t, Integer):
if t.min == 0:
bits = math.ceil(math.log2(t.max + 1))
return f"uint{bits}"
else:
bits = math.ceil(math.log2(t.max - t.min + 1))
return f"int{bits}"
elif isinstance(t, type) and issubclass(t, FixedSizeBytes):
return f"bytes{t.length}"
elif t is int:
# fallback for int used directly
return "int256"
elif t is bytes or t is bytearray:
return "bytes"
elif t is str:
return "string"
elif issubclass(t, Enum):
return "uint8"
elif t is bool:
return "bool"
elif issubclass(t, (Account, Address)):
return "address"
elif dataclasses.is_dataclass(t):
hints = get_type_hints(
t, # pyright: ignore reportGeneralTypeIssues
include_extras=True,
)
return f"({','.join(cls._types_from_type(hints[f.name]) for f in dataclasses.fields(t))})"
else:
raise ValueError(f"Unsupported type {t}")

@classmethod
def _types_from_args(cls, args: Iterable) -> str:
if isinstance(args, tuple):
return f"({','.join(cls._types_from_args(arg) for arg in args)})"
elif isinstance(args, list):
for arg in args:
try:
arg_type = cls._types_from_args(arg)
if hasattr(args, "length"):
return f"{arg_type}[{getattr(args, 'length')}]"
else:
return f"{arg_type}[]"
except ValueError:
pass

raise ValueError("Could not determine type of list")
elif isinstance(args, (Address, Account)):
return "address"
elif isinstance(args, str):
return "string"
elif isinstance(args, (bytes, bytearray)):
if hasattr(args, "length"):
return f"bytes{getattr(args, 'length')}"
else:
return "bytes"
elif isinstance(args, bool):
return "bool"
elif callable(args):
return "function"
elif isinstance(args, IntEnum):
return "uint8"
elif dataclasses.is_dataclass(args):
return cls._types_from_type(type(args))
elif isinstance(args, int):
if not hasattr(args, "min") or not hasattr(args, "max"):
raise ValueError(
"Integer cannot be directly ABI-encoded. Use typecast to intN or uintN instead."
)
min = getattr(args, "min")
max = getattr(args, "max")
if min == 0:
bits = math.ceil(math.log2(max + 1))
return f"uint{bits}"
else:
bits = math.ceil(math.log2(max - min + 1))
return f"int{bits}"
else:
raise ValueError(f"Unsupported type {type(args)}")

@classmethod
def encode(cls, *args) -> bytes:
return eth_abi.abi.encode(
[cls._types_from_args(a) for a in args], cls._normalize_input(args)
)

@classmethod
def encode_packed(cls, *args) -> bytes:
return eth_abi.packed.encode_packed(
[cls._types_from_args(a) for a in args], cls._normalize_input(args)
)

@classmethod
def encode_with_selector(cls, selector: bytes, *args) -> bytes:
return selector + cls.encode(*args)

@classmethod
def encode_with_signature(cls, signature: str, *args) -> bytes:
selector = keccak.new(data=signature.encode("utf-8"), digest_bits=256).digest()[
:4
]
return cls.encode_with_selector(selector, *args)

@classmethod
def encode_call(cls, func: Callable, args: Iterable) -> bytes:
selector = func.selector
return cls.encode_with_selector(selector, *args)

@classmethod
def decode(cls, data: bytes, types: Sequence[Type]) -> Any:
ret = cls._normalize_output(
types, eth_abi.abi.decode([cls._types_from_type(t) for t in types], data)
)
if len(ret) == 1:
return ret[0]
return ret


class Abi:
@staticmethod
def _normalize_input(arguments: Iterable) -> List:
Expand Down
2 changes: 1 addition & 1 deletion wake/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from rich import print

from wake.development.core import Abi, Account, Address, Eip712Domain, Wei
from wake.development.core import Abi, Account, Address, Eip712Domain, Wei, abi
from wake.development.internal import UnknownEvent
from wake.development.primitive_types import *
from wake.development.transactions import (
Expand Down

0 comments on commit 625cc5a

Please sign in to comment.