Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: check target type for ARC4Decode #336

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/box_storage/out/module.awst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
arc4_types/Arc4DynamicStringArray 283 124 - | 172 53 -
arc4_types/Arc4MutableParams 471 286 - | 292 141 -
arc4_types/Arc4Mutation 2958 1426 - | 1977 593 -
arc4_types/Arc4NumericTypes 740 186 - | 239 26 -
arc4_types/Arc4NumericTypes 749 186 - | 243 26 -
arc4_types/Arc4RefTypes 85 46 - | 32 27 -
arc4_types/Arc4StringTypes 455 35 - | 245 13 -
arc4_types/Arc4StructsFromAnotherModule 67 12 - | 49 6 -
Expand Down Expand Up @@ -130,4 +130,4 @@
unssa/UnSSA 432 368 - | 241 204 -
voting/VotingRoundApp 1593 1483 - | 734 649 -
with_reentrancy/WithReentrancy 255 242 - | 132 122 -
Total 69191 53576 53517 | 32839 21764 21720
Total 69200 53576 53517 | 32843 21764 21720
14 changes: 7 additions & 7 deletions examples/tictactoe/out/module.awst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out/module.awst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions src/puya/awst/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,15 +484,18 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:

@attrs.frozen
class ARC4Decode(Expression):
value: Expression = attrs.field()
value: Expression = attrs.field(
validator=expression_has_wtype(
wtypes.arc4_bool_wtype,
wtypes.ARC4UIntN,
wtypes.ARC4Tuple,
wtypes.ARC4DynamicArray, # only if element type is bytes for now
)
)

@value.validator
def _value_wtype_validator(self, _attribute: object, value: Expression) -> None:
if not isinstance(value.wtype, wtypes.ARC4Type):
raise InternalError(
f"ARC4Decode should only be used with expressions of ARC4Type, got {value.wtype}",
self.source_location,
)
assert isinstance(value.wtype, wtypes.ARC4Type) # validated by `value`
if not value.wtype.can_encode_type(self.wtype):
raise InternalError(
f"ARC4Decode from {value.wtype} should have non ARC4 target type {self.wtype}",
Expand Down
8 changes: 2 additions & 6 deletions src/puya/awst/to_code_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,8 @@ def visit_integer_constant(self, expr: nodes.IntegerConstant) -> str:
suffix = "u"
case wtypes.biguint_wtype:
suffix = "n"
case wtypes.ARC4UIntN(n=n, decode_type=decode_type):
if decode_type == wtypes.uint64_wtype:
suffix = f"arc4u{n}"
else:
assert decode_type == wtypes.biguint_wtype
suffix = f"arc4n{n}"
case wtypes.ARC4UIntN(n=n):
suffix = f"_arc4u{n}"
case _:
raise InternalError(
f"Numeric type not implemented: {expr.wtype}", expr.source_location
Expand Down
41 changes: 12 additions & 29 deletions src/puya/awst/wtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,25 +257,25 @@ def name_to_index(self, name: str, source_location: SourceLocation) -> int:
class ARC4Type(WType):
scalar_type: typing.Literal[AVMType.bytes] = attrs.field(default=AVMType.bytes, init=False)
arc4_name: str = attrs.field(eq=False) # exclude from equality in case of aliasing
decode_type: WType | None
native_type: WType | None

def can_encode_type(self, wtype: WType) -> bool:
return wtype == self.decode_type
return wtype == self.native_type


arc4_bool_wtype: typing.Final = ARC4Type(
name="arc4.bool",
arc4_name="bool",
immutable=True,
decode_type=bool_wtype,
native_type=bool_wtype,
)


@typing.final
@attrs.frozen(kw_only=True)
class ARC4UIntN(ARC4Type):
immutable: bool = attrs.field(default=True, init=False)
decode_type: WType = attrs.field()
native_type: WType = attrs.field(default=biguint_wtype, init=False)
n: int = attrs.field()
arc4_name: str = attrs.field(eq=False)
name: str = attrs.field(init=False)
Expand All @@ -288,22 +288,6 @@ def _n_validator(self, _attribute: object, n: int) -> None:
if not (8 <= n <= 512):
raise CodeError("Bit size must be between 8 and 512 inclusive", self.source_location)

@decode_type.validator
def _decode_type_validator(self, _attribute: object, decode_type: WType) -> None:
if decode_type == uint64_wtype:
if self.n > 64:
raise InternalError(
f"ARC-4 UIntN type received decode type {decode_type},"
f" which is smaller than size {self.n}",
self.source_location,
)
elif decode_type == biguint_wtype:
pass
else:
raise InternalError(
f"Unhandled decode_type for ARC-4 UIntN: {decode_type}", self.source_location
)

@arc4_name.default
def _arc4_name(self) -> str:
return f"uint{self.n}"
Expand All @@ -325,7 +309,7 @@ class ARC4UFixedNxM(ARC4Type):
arc4_name: str = attrs.field(init=False, eq=False)
name: str = attrs.field(init=False)
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
decode_type = attrs.field(default=None, init=False)
native_type: None = attrs.field(default=None, init=False)

@arc4_name.default
def _arc4_name(self) -> str:
Expand Down Expand Up @@ -365,7 +349,7 @@ class ARC4Tuple(ARC4Type):
name: str = attrs.field(init=False)
arc4_name: str = attrs.field(init=False, eq=False)
immutable: bool = attrs.field(init=False)
decode_type: WTuple = attrs.field(init=False)
native_type: WTuple = attrs.field(init=False)

@name.default
def _name(self) -> str:
Expand All @@ -379,8 +363,8 @@ def _arc4_name(self) -> str:
def _immutable(self) -> bool:
return all(typ.immutable for typ in self.types)

@decode_type.default
def _decode_type(self) -> WTuple:
@native_type.default
def _native_type(self) -> WTuple:
return WTuple(self.types, self.source_location)

def can_encode_type(self, wtype: WType) -> bool:
Expand Down Expand Up @@ -409,7 +393,7 @@ def _expect_arc4_type(wtype: WType) -> ARC4Type:
@attrs.frozen(kw_only=True)
class ARC4Array(ARC4Type):
element_type: ARC4Type = attrs.field(converter=_expect_arc4_type)
decode_type: WType | None = None
native_type: WType | None = None
immutable: bool = False


Expand Down Expand Up @@ -469,7 +453,7 @@ class ARC4Struct(ARC4Type):
immutable: bool = attrs.field()
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
arc4_name: str = attrs.field(init=False, eq=False)
decode_type: WType | None = None
native_type: None = attrs.field(default=None, init=False)

@immutable.default
def _immutable(self) -> bool:
Expand Down Expand Up @@ -497,22 +481,21 @@ def can_encode_type(self, wtype: WType) -> bool:
arc4_byte_alias: typing.Final = ARC4UIntN(
n=8,
arc4_name="byte",
decode_type=uint64_wtype,
source_location=None,
)

arc4_string_alias: typing.Final = ARC4DynamicArray(
arc4_name="string",
element_type=arc4_byte_alias,
decode_type=string_wtype,
native_type=string_wtype,
immutable=True,
source_location=None,
)

arc4_address_alias: typing.Final = ARC4StaticArray(
arc4_name="address",
element_type=arc4_byte_alias,
decode_type=account_wtype,
native_type=account_wtype,
array_size=32,
immutable=True,
source_location=None,
Expand Down
6 changes: 3 additions & 3 deletions src/puya/ir/arc4_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,12 +716,12 @@ def _maybe_avm_to_arc4_equivalent_type(wtype: wtypes.WType) -> wtypes.ARC4Type |
case wtypes.bool_wtype:
return wtypes.arc4_bool_wtype
case wtypes.uint64_wtype:
return wtypes.ARC4UIntN(n=64, decode_type=wtype, source_location=None)
return wtypes.ARC4UIntN(n=64, source_location=None)
case wtypes.biguint_wtype:
return wtypes.ARC4UIntN(n=512, decode_type=wtype, source_location=None)
return wtypes.ARC4UIntN(n=512, source_location=None)
case wtypes.bytes_wtype:
return wtypes.ARC4DynamicArray(
element_type=wtypes.arc4_byte_alias, decode_type=wtype, source_location=None
element_type=wtypes.arc4_byte_alias, native_type=wtype, source_location=None
)
case wtypes.string_wtype:
return wtypes.arc4_string_alias
Expand Down
44 changes: 21 additions & 23 deletions src/puya/ir/builder/arc4.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,39 +72,40 @@ def _decode_arc4_value(
target_wtype: wtypes.WType,
loc: SourceLocation,
) -> ValueProvider:
match arc4_wtype:
case wtypes.ARC4UIntN(n=scale) | wtypes.ARC4UFixedNxM(n=scale):
if scale > 64:
return value
else:
return Intrinsic(
op=AVMOp.btoi,
args=[value],
source_location=loc,
)
case wtypes.arc4_bool_wtype:
match arc4_wtype, target_wtype:
case wtypes.ARC4UIntN(), wtypes.biguint_wtype:
return value
case wtypes.ARC4UIntN(), (wtypes.uint64_wtype | wtypes.bool_wtype):
return Intrinsic(
op=AVMOp.btoi,
args=[value],
source_location=loc,
)
case wtypes.arc4_bool_wtype, wtypes.bool_wtype:
return Intrinsic(
op=AVMOp.getbit,
args=[value, UInt64Constant(value=0, source_location=None)],
source_location=loc,
types=(IRType.bool,),
)
case wtypes.ARC4DynamicArray(element_type=wtypes.ARC4UIntN(n=8)):
case wtypes.ARC4DynamicArray(element_type=wtypes.ARC4UIntN(n=8)), (
wtypes.bytes_wtype | wtypes.string_wtype
):
return Intrinsic(
op=AVMOp.extract,
immediates=[2, 0],
args=[value],
source_location=loc,
)
case wtypes.ARC4Tuple() as arc4_tuple:
case wtypes.ARC4Tuple() as arc4_tuple, wtypes.WTuple() as native_tuple if (
len(arc4_tuple.types) == len(native_tuple.types)
):
return _visit_arc4_tuple_decode(
context, arc4_tuple, value, target_wtype=target_wtype, source_location=loc
)
case _:
raise InternalError(
f"Unsupported wtype for ARC4Decode: {arc4_wtype}",
location=loc,
context, arc4_tuple, value, target_wtype=native_tuple, source_location=loc
)
raise InternalError(
f"unsupported ARC4Decode operation from {arc4_wtype} to {target_wtype}", loc
)


def encode_arc4_struct(
Expand Down Expand Up @@ -608,13 +609,10 @@ def _visit_arc4_tuple_decode(
context: IRFunctionBuildContext,
wtype: wtypes.ARC4Tuple,
value: Value,
target_wtype: wtypes.WType,
target_wtype: wtypes.WTuple,
source_location: SourceLocation,
) -> ValueProvider:
items = list[Value]()
if not isinstance(target_wtype, wtypes.WTuple):
raise InternalError("expected ARC4Decode of a tuple to target a WTuple", source_location)

for index, (target_item_wtype, item_wtype) in enumerate(
zip(target_wtype.types, wtype.types, strict=True)
):
Expand Down
6 changes: 2 additions & 4 deletions src/puyapy/awst_build/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,7 @@ def parameterise(
name=name,
bits=bits,
native_type=native_type,
wtype=wtypes.ARC4UIntN(
n=bits, decode_type=native_type.wtype, source_location=source_location
),
wtype=wtypes.ARC4UIntN(n=bits, source_location=source_location),
)

return parameterise
Expand Down Expand Up @@ -790,7 +788,7 @@ def parameterise(
name="algopy.arc4.DynamicBytes",
wtype=wtypes.ARC4DynamicArray(
element_type=ARC4ByteType.wtype,
decode_type=wtypes.bytes_wtype,
native_type=wtypes.bytes_wtype,
source_location=None,
),
size=None,
Expand Down
Loading
Loading