Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Identify WTuples by name when treating them as structs in an ABI method #358

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
state_totals 65 32 - | 32 16 -
stress_tests/BruteForceRotationSearch 228 163 - | 152 106 -
string_ops 156 154 - | 58 55 -
struct_by_name/Demo 271 217 - | 155 113 -
struct_in_box/Example 242 206 - | 127 99 -
stubs/BigUInt 192 121 - | 126 73 -
stubs/Bytes 944 279 - | 606 153 -
Expand All @@ -134,4 +135,4 @@
unssa/UnSSA 432 368 - | 241 204 -
voting/VotingRoundApp 1580 1475 - | 727 644 -
with_reentrancy/WithReentrancy 245 234 - | 126 117 -
Total 70231 54641 54582 | 33265 22305 22261
Total 70502 54858 54799 | 33420 22418 22374
10 changes: 6 additions & 4 deletions src/puya/arc56.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,15 @@ def _get_source_info(debug_info: DebugInfo) -> Sequence[models.SourceInfo]:

class _StructAliases:
def __init__(self, structs: Iterable[ARC4Struct]) -> None:
self.aliases = dict[str, str]()
alias_to_fullname = dict[str, str]()
for struct in structs:
self.aliases[struct.fullname] = (
alias = (
struct.fullname
if struct.name in self.aliases or struct.name in models.AVMType
if struct.name in alias_to_fullname or struct.name in models.AVMType
else struct.name
)
alias_to_fullname[alias] = struct.fullname
self.aliases = {v: k for k, v in alias_to_fullname.items()}

@typing.overload
def resolve(self, struct: str) -> str: ...
Expand All @@ -203,7 +205,7 @@ def resolve(self, struct: str | None) -> str | None:

def _struct_to_event(structs: _StructAliases, struct: ARC4Struct) -> models.Event:
return models.Event(
name=struct.name,
name=structs.resolve(struct.name),
desc=struct.desc,
args=[
models.EventArg(
Expand Down
12 changes: 5 additions & 7 deletions src/puya/ir/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,20 +573,18 @@ def _wtypes_to_structs(
Will recursively include any structs referenced in fields
"""
structs = list(structs)
struct_results = dict[wtypes.ARC4Struct | wtypes.WTuple, ARC4Struct]()
struct_results = dict[str, ARC4Struct]()
while structs:
struct = structs.pop()
if struct in struct_results:
if struct.name in struct_results:
continue
structs.extend(
wtype
for wtype in struct.fields.values()
if isinstance(wtype, wtypes.ARC4Struct) and wtype not in struct_results
if isinstance(wtype, wtypes.ARC4Struct) and wtype.name not in struct_results
)
struct_results[struct] = _wtype_to_struct(struct)
return {
wtype.name: struct_results[wtype] for wtype in sorted(struct_results, key=lambda s: s.name)
}
struct_results[struct.name] = _wtype_to_struct(struct)
return dict(sorted(struct_results.items(), key=lambda item: item[0]))


def _wtype_to_struct(struct: wtypes.ARC4Struct | wtypes.WTuple) -> ARC4Struct:
Expand Down
3 changes: 2 additions & 1 deletion src/puya/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import enum
import re
import typing
from collections.abc import Mapping, Sequence

Expand Down Expand Up @@ -62,7 +63,7 @@ class ARC4Struct:

@property
def name(self) -> str:
return self.fullname.rsplit(".", maxsplit=1)[-1]
return re.split(r"\W", self.fullname)[-1]


@attrs.frozen(kw_only=True)
Expand Down
15 changes: 11 additions & 4 deletions src/puyapy/awst_build/arc4_client_gen.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import textwrap
from collections.abc import Iterable, Sequence
from pathlib import Path
Expand All @@ -18,6 +19,7 @@

_AUTO_GENERATED_COMMENT = "# This file is auto-generated, do not modify"
_INDENT = " " * 4
_NON_ALPHA_NUMERIC = re.compile(r"\W+")


def write_arc4_client(contract: arc56.Contract, out_dir: Path) -> None:
Expand All @@ -42,7 +44,7 @@ def __init__(self, contract: arc56.Contract):
self.contract = contract
self.python_methods = set[str]()
self.struct_to_class = dict[str, str]()
self.reserved_class_names = {contract.name}
self.reserved_class_names = set[str]()
self.reserved_method_names = set[str]()
self.class_decls = list[str]()

Expand All @@ -53,6 +55,7 @@ def generate(cls, contract: arc56.Contract) -> str:
def _gen(self) -> str:
# generate class definitions for any referenced structs in methods
# don't generate from self.contract.structs as it may contain other struct definitions
client_class = self._unique_class(self.contract.name)
for method in self.contract.methods:
for struct in filter(None, (method.returns.struct, *(a.struct for a in method.args))):
if struct not in self.struct_to_class and (
Expand All @@ -70,7 +73,7 @@ def _gen(self) -> str:
"",
*self.class_decls,
"",
f"class {self.contract.name}(algopy.arc4.ARC4Client, typing.Protocol):",
f"class {client_class}(algopy.arc4.ARC4Client, typing.Protocol):",
*_docstring(self.contract.desc),
*self._gen_methods(),
)
Expand Down Expand Up @@ -110,7 +113,7 @@ def _get_client_type(self, typ: str) -> str:
return str(arc4_to_pytype(typ, None))

def _unique_class(self, name: str) -> str:
base_name = name
base_name = name = _get_python_safe_name(name)
seq = 1
while name in self.reserved_class_names:
seq += 1
Expand All @@ -120,7 +123,7 @@ def _unique_class(self, name: str) -> str:
return name

def _unique_method(self, name: str) -> str:
base_name = name
base_name = name = _get_python_safe_name(name)
seq = 1
while name in self.reserved_method_names:
seq += 1
Expand Down Expand Up @@ -218,3 +221,7 @@ def _indent(lines: Iterable[str] | str) -> str:
if not isinstance(lines, str):
lines = "\n".join(lines)
return textwrap.indent(lines, _INDENT)


def _get_python_safe_name(name: str) -> str:
return _NON_ALPHA_NUMERIC.sub("_", name)
48 changes: 48 additions & 0 deletions test_cases/struct_by_name/contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import typing

from algopy import ARC4Contract, arc4

from test_cases.struct_by_name.mod import StructTwo as StructThree


class StructOne(typing.NamedTuple):
x: arc4.UInt8
y: arc4.UInt8


class StructTwo(typing.NamedTuple):
x: arc4.UInt8
y: arc4.UInt8


class DemoContract(ARC4Contract):
"""
Verify that even though named tuples with different names but the same structure should be
considered 'comparable' in the type system, they should be output separately when being
interpreted as an arc4 Struct in an abi method
"""

@arc4.abimethod()
def get_one(self) -> StructOne:
return StructOne(
x=arc4.UInt8(1),
y=arc4.UInt8(1),
)

@arc4.abimethod()
def get_two(self) -> StructTwo:
return StructTwo(
x=arc4.UInt8(1),
y=arc4.UInt8(1),
)

@arc4.abimethod()
def get_three(self) -> StructThree:
return StructThree(
x=arc4.UInt8(1),
y=arc4.UInt8(1),
)

@arc4.abimethod()
def compare(self) -> bool:
return self.get_one() == self.get_two() and self.get_two() == self.get_three()
8 changes: 8 additions & 0 deletions test_cases/struct_by_name/mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import typing

from algopy import arc4


class StructTwo(typing.NamedTuple):
x: arc4.UInt8
y: arc4.UInt8
Loading
Loading