Skip to content

Commit

Permalink
- create separate pytype for ARC4Tuple (in future probably want to do…
Browse files Browse the repository at this point in the history
… the same for structs)

- introduce pytypes.TupleLikeType that captures both native and ARC4Tuples
- make a NamedTupleType subclass of TupleType that exposes a `fields` property which is easier to create & consume
- use NamedTuple for pre-compiled data structures
- correct implicit arc4 conversion to use iterate_static, which also ensure single evaluation
- introduce pytype for NamedTupleBaseType - this is somewhat of a lie, since typing.NamedTuple is not actually part of MRO at runtime due to Python magic, but it's a good lie because it allows resolving expression and type builders for NamedTuple without any hackery
- fixing some error messages to start with lower case
- expose concrete types on WType and PyType classes, this simplifies some things and Sequence/Mapping was mostly used as shorthand
- add some negative tests for NamedTuple scenarios
- general refactoring of new NamedTuple feature
  • Loading branch information
achidlow authored and daniel-makerx committed Oct 23, 2024
1 parent 6b28f8b commit eaf2700
Show file tree
Hide file tree
Showing 30 changed files with 734 additions and 869 deletions.
4 changes: 1 addition & 3 deletions scripts/generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,9 +844,7 @@ def pytype_repr(typ: pytypes.PyType) -> str:
except KeyError:
pass
match typ:
case pytypes.TupleType(generic=pytypes.GenericTupleType, items=tuple_items) if len(
tuple_items
) > 1:
case pytypes.TupleType(items=tuple_items) if len(tuple_items) > 1:
item_strs = [pytype_repr(item) for item in tuple_items]
return (
f"pytypes.GenericTupleType.parameterise("
Expand Down
12 changes: 4 additions & 8 deletions src/puya/awst/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,16 +727,12 @@ class FieldExpression(Expression):

@wtype.default
def _wtype_factory(self) -> wtypes.WType:
struct_wtype = self.base.wtype
if isinstance(struct_wtype, wtypes.WTuple) and struct_wtype.names:
index = struct_wtype.name_to_index(self.name, self.source_location)
return struct_wtype.types[index]
if not isinstance(struct_wtype, wtypes.WStructType | wtypes.ARC4Struct):
raise InternalError("invalid struct wtype")
dataclass_type = self.base.wtype
assert isinstance(dataclass_type, wtypes.WStructType | wtypes.ARC4Struct | wtypes.WTuple)
try:
return struct_wtype.fields[self.name]
return dataclass_type.fields[self.name]
except KeyError:
raise CodeError(f"invalid field for {struct_wtype}", self.source_location) from None
raise CodeError(f"invalid field for {dataclass_type}", self.source_location) from None

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_field_expression(self)
Expand Down
93 changes: 48 additions & 45 deletions src/puya/awst/wtypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Mapping
from functools import cached_property

import attrs
Expand All @@ -10,6 +10,7 @@
from puya.errors import CodeError, InternalError
from puya.models import TransactionType
from puya.parse import SourceLocation
from puya.utils import unique

logger = log.get_logger(__name__)

Expand Down Expand Up @@ -160,12 +161,12 @@ def from_type(cls, transaction_type: TransactionType | None) -> "WInnerTransacti
@typing.final
@attrs.frozen
class WStructType(WType):
fields: Mapping[str, WType] = attrs.field(converter=immutabledict)
fields: immutabledict[str, WType] = attrs.field(converter=immutabledict)
scalar_type: None = attrs.field(default=None, init=False)
source_location: SourceLocation | None = attrs.field(eq=False)

@fields.validator
def _fields_validator(self, _: object, fields: Mapping[str, WType]) -> None:
def _fields_validator(self, _: object, fields: immutabledict[str, WType]) -> None:
if not fields:
raise CodeError("struct needs fields", self.source_location)
if void_wtype in fields.values():
Expand All @@ -191,26 +192,18 @@ def _name(self) -> str:
return f"array<{self.element_type.name}>"


def _names_converter(val: Iterable[str] | None) -> tuple[str, ...] | None:
return None if val is None else tuple[str, ...](val)


@typing.final
@attrs.frozen
class WTuple(WType):
names: Sequence[str] | None = attrs.field(
default=None,
kw_only=True,
converter=_names_converter,
)
types: Sequence[WType] = attrs.field(converter=tuple[WType, ...])
types: tuple[WType, ...] = attrs.field(converter=tuple[WType, ...])
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
scalar_type: None = attrs.field(default=None, init=False)
immutable: bool = attrs.field(default=True, init=False)
name: str = attrs.field(eq=False, kw_only=True)
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
names: tuple[str, ...] | None = attrs.field(default=None)

@types.validator
def _types_validator(self, _attribute: object, types: Sequence[WType]) -> None:
def _types_validator(self, _attribute: object, types: tuple[WType, ...]) -> None:
if not types:
raise CodeError("empty tuples are not supported", self.source_location)
if void_wtype in types:
Expand All @@ -220,10 +213,26 @@ def _types_validator(self, _attribute: object, types: Sequence[WType]) -> None:
def _name(self) -> str:
return f"tuple<{','.join([t.name for t in self.types])}>"

@names.validator
def _names_validator(self, _attribute: object, names: tuple[str, ...] | None) -> None:
if names is None:
return
if len(names) != len(self.types):
raise InternalError("mismatch between tuple item names length and types")
if len(names) != len(unique(names)):
raise CodeError("tuple item names are not unique", self.source_location)

@cached_property
def fields(self) -> Mapping[str, WType]:
"""Mapping of item names to types if `names` is defined, otherwise empty."""
if self.names is None:
return {}
return dict(zip(self.names, self.types, strict=True))

def name_to_index(self, name: str, source_location: SourceLocation) -> int:
if self.names is None:
raise CodeError(
"Cannot access tuple item by name of an unnamed tuple", source_location
"cannot access tuple item by name of an unnamed tuple", source_location
)
try:
return self.names.index(name)
Expand Down Expand Up @@ -326,7 +335,7 @@ def _m_validator(self, _attribute: object, m: int) -> None:
raise CodeError("Precision must be between 1 and 160 inclusive", self.source_location)


def _required_arc4_wtypes(wtypes: Iterable[WType]) -> Sequence[ARC4Type]:
def _required_arc4_wtypes(wtypes: Iterable[WType]) -> tuple[ARC4Type, ...]:
result = []
for wtype in wtypes:
if not isinstance(wtype, ARC4Type):
Expand All @@ -339,7 +348,7 @@ def _required_arc4_wtypes(wtypes: Iterable[WType]) -> Sequence[ARC4Type]:
@attrs.frozen(kw_only=True)
class ARC4Tuple(ARC4Type):
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
types: Sequence[ARC4Type] = attrs.field(converter=_required_arc4_wtypes)
types: tuple[ARC4Type, ...] = attrs.field(converter=_required_arc4_wtypes)
name: str = attrs.field(init=False)
arc4_name: str = attrs.field(init=False, eq=False)
immutable: bool = attrs.field(init=False)
Expand All @@ -362,14 +371,20 @@ def _decode_type(self) -> WTuple:
return WTuple(self.types, self.source_location)

def can_encode_type(self, wtype: WType) -> bool:
if wtype == self.decode_type:
return True
elif not isinstance(wtype, WTuple) or len(wtype.types) != len(self.types):
return False
return all(
return super().can_encode_type(wtype) or _is_arc4_encodeable_tuple(wtype, self.types)


def _is_arc4_encodeable_tuple(
wtype: WType, target_types: tuple[ARC4Type, ...]
) -> typing.TypeGuard[WTuple]:
return (
isinstance(wtype, WTuple)
and len(wtype.types) == len(target_types)
and all(
arc4_wtype == encode_wtype or arc4_wtype.can_encode_type(encode_wtype)
for arc4_wtype, encode_wtype in zip(self.types, wtype.types, strict=True)
for arc4_wtype, encode_wtype in zip(target_types, wtype.types, strict=True)
)
)


def _expect_arc4_type(wtype: WType) -> ARC4Type:
Expand Down Expand Up @@ -428,7 +443,7 @@ def _require_arc4_fields(fields: Mapping[str, WType]) -> immutabledict[str, ARC4
]
if non_arc4_fields:
raise CodeError(
"Invalid ARC4 Struct declaration,"
"invalid ARC4 Struct declaration,"
f" the following fields are not ARC4 encoded types: {', '.join(non_arc4_fields)}",
)
return immutabledict(fields)
Expand All @@ -437,7 +452,7 @@ def _require_arc4_fields(fields: Mapping[str, WType]) -> immutabledict[str, ARC4
@typing.final
@attrs.frozen(kw_only=True)
class ARC4Struct(ARC4Type):
fields: Mapping[str, ARC4Type] = attrs.field(converter=_require_arc4_fields)
fields: immutabledict[str, ARC4Type] = attrs.field(converter=_require_arc4_fields)
immutable: bool = attrs.field()
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
arc4_name: str = attrs.field(init=False, eq=False)
Expand All @@ -452,29 +467,17 @@ def _arc4_name(self) -> str:
return f"({','.join(item.arc4_name for item in self.types)})"

@cached_property
def names(self) -> Sequence[str]:
return list(self.fields.keys())
def names(self) -> tuple[str, ...]:
return tuple(self.fields.keys())

@cached_property
def types(self) -> Sequence[ARC4Type]:
return list(self.fields.values())
def types(self) -> tuple[ARC4Type, ...]:
return tuple(self.fields.values())

def can_encode_type(self, wtype: WType) -> bool:
if wtype == self.decode_type:
return True
elif not isinstance(wtype, WTuple) or len(wtype.types) != len(self.types):
return False
elif wtype.names is not None:
# Named tuple must have same fields and types
return len(wtype.names) == len(self.fields) and all(
n == f and (t == ft or ft.can_encode_type(t))
for n, t, (f, ft) in zip(
wtype.names, wtype.types, self.fields.items(), strict=True
)
)
return all(
arc4_wtype == encode_wtype or arc4_wtype.can_encode_type(encode_wtype)
for arc4_wtype, encode_wtype in zip(self.types, wtype.types, strict=True)
return super().can_encode_type(wtype) or (
_is_arc4_encodeable_tuple(wtype, self.types)
and (wtype.names is None or wtype.names == self.names)
)


Expand Down
16 changes: 16 additions & 0 deletions src/puya/ir/_puya_lib.awst.json
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,8 @@
}
],
"source_location": null,
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"body": {
Expand Down Expand Up @@ -2270,6 +2272,8 @@
"column": 11,
"end_column": 25
},
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"_type": "TupleExpression"
Expand Down Expand Up @@ -2356,6 +2360,8 @@
}
],
"source_location": null,
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"body": {
Expand Down Expand Up @@ -3091,6 +3097,8 @@
"column": 11,
"end_column": 25
},
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"_type": "TupleExpression"
Expand Down Expand Up @@ -3159,6 +3167,8 @@
}
],
"source_location": null,
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"body": {
Expand Down Expand Up @@ -4244,6 +4254,8 @@
"column": 11,
"end_column": 26
},
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"_type": "TupleExpression"
Expand Down Expand Up @@ -4312,6 +4324,8 @@
}
],
"source_location": null,
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"body": {
Expand Down Expand Up @@ -5652,6 +5666,8 @@
"column": 11,
"end_column": 26
},
"name": "tuple<bytes,bytes>",
"names": null,
"_type": "WTuple"
},
"_type": "TupleExpression"
Expand Down
20 changes: 7 additions & 13 deletions src/puya/ir/builder/arc4.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,16 +397,9 @@ def handle_arc4_assign(
source_location=source_location,
is_mutation=True,
)
case awst_nodes.TupleItemExpression(
base=awst_nodes.Expression(wtype=wtypes.ARC4Tuple() as base_type),
index=str(),
):
raise InternalError(
f"{base_type} does not support indexing by name", target.source_location
)
case awst_nodes.TupleItemExpression(
base=awst_nodes.Expression(wtype=wtypes.ARC4Tuple() as tuple_wtype) as base_expr,
index=int(index_value),
index=index_value,
):
return handle_arc4_assign(
context,
Expand Down Expand Up @@ -448,11 +441,12 @@ def handle_arc4_assign(


def _get_tuple_var_name(expr: awst_nodes.TupleItemExpression) -> str:
if isinstance(expr.base, awst_nodes.TupleItemExpression):
return format_tuple_index(expr.base.wtype, _get_tuple_var_name(expr.base), expr.index)
if isinstance(expr.base, awst_nodes.VarExpression):
return format_tuple_index(expr.base.wtype, expr.base.name, expr.index)
raise CodeError("Invalid assignment target", expr.base.source_location)
if isinstance(expr.base.wtype, wtypes.WTuple):
if isinstance(expr.base, awst_nodes.TupleItemExpression):
return format_tuple_index(expr.base.wtype, _get_tuple_var_name(expr.base), expr.index)
if isinstance(expr.base, awst_nodes.VarExpression):
return format_tuple_index(expr.base.wtype, expr.base.name, expr.index)
raise CodeError("invalid assignment target", expr.base.source_location)


def concat_values(
Expand Down
20 changes: 7 additions & 13 deletions src/puya/ir/builder/itxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,8 @@ def _resolve_inner_txn_params_var_name(self, params: awst_nodes.Expression) -> s
case awst_nodes.VarExpression(name=var_name):
pass
case awst_nodes.TupleItemExpression(
base=awst_nodes.VarExpression(name=name, wtype=base_wtype), index=index
base=awst_nodes.VarExpression(name=name, wtype=wtypes.WTuple() as base_wtype),
index=index,
):
return format_tuple_index(base_wtype, name, index)
case awst_nodes.Copy(value=value):
Expand All @@ -570,7 +571,7 @@ def _get_assignment_target_local_names(
match target:
case awst_nodes.VarExpression(name=var_name) if expected_number == 1:
return [(var_name, target.source_location)]
case awst_nodes.VarExpression(name=var_name, wtype=var_wtype):
case awst_nodes.VarExpression(name=var_name, wtype=wtypes.WTuple() as var_wtype):
return [
(format_tuple_index(var_wtype, var_name, idx), target.source_location)
for idx in range(expected_number)
Expand All @@ -581,8 +582,7 @@ def _get_assignment_target_local_names(
items = typing.cast(Sequence[awst_nodes.VarExpression], items)
return [(expr.name, expr.source_location) for expr in items]
case awst_nodes.TupleItemExpression(
base=awst_nodes.TupleExpression(wtype=tuple_wtype) as base,
index=index,
base=awst_nodes.TupleExpression(wtype=tuple_wtype) as base, index=index
):
tuple_names = _get_assignment_target_local_names(base, len(tuple_wtype.types))
return [tuple_names[index]]
Expand Down Expand Up @@ -646,18 +646,12 @@ def visit_inner_transaction_field(self, itxn_field: awst_nodes.InnerTransactionF
pass

def visit_tuple_item_expression(self, expr: awst_nodes.TupleItemExpression) -> None:
if isinstance(expr.index, str):
assert isinstance(expr.base.wtype, wtypes.WTuple), "Tuple item must index tuple"
idx_num = expr.base.wtype.name_to_index(expr.index, expr.source_location)
else:
idx_num = expr.index

start_len = len(self._actions)
super().visit_tuple_item_expression(expr)
added = self._actions[start_len:]
# only keep the relevant action
if isinstance(expr.wtype, wtypes.WInnerTransaction):
self._actions[start_len:] = [added[idx_num]]
self._actions[start_len:] = [added[expr.index]]

def visit_slice_expression(self, expr: awst_nodes.SliceExpression) -> None:
start_len = len(self._actions)
Expand Down Expand Up @@ -700,13 +694,13 @@ def _is_last_itxn(expr: awst_nodes.Expression) -> bool:
if not isinstance(base.wtype, wtypes.WTuple):
return False

idx_num = (
index = (
expr.index
if isinstance(expr, awst_nodes.TupleItemExpression)
else base.wtype.name_to_index(expr.name, expr.source_location)
)
tuple_size = len(base.wtype.types)
if idx_num == -1 or (idx_num + 1) == tuple_size:
if index == -1 or (index + 1) == tuple_size:
return _is_submit_expr_of_size(base, tuple_size)
else:
return False
Expand Down
Loading

0 comments on commit eaf2700

Please sign in to comment.