Skip to content

Commit

Permalink
fix: check target type for ARC4Decode
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Oct 25, 2024
1 parent 7f4964b commit 1c05162
Show file tree
Hide file tree
Showing 40 changed files with 391 additions and 390 deletions.
10 changes: 5 additions & 5 deletions examples/box_storage/out/module.awst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
arc4_types/Arc4DynamicStringArray 283 124 - | 172 53 -
arc4_types/Arc4MutableParams 471 286 - | 292 141 -
arc4_types/Arc4Mutation 2958 1426 - | 1977 593 -
arc4_types/Arc4NumericTypes 740 186 - | 239 26 -
arc4_types/Arc4NumericTypes 749 186 - | 243 26 -
arc4_types/Arc4RefTypes 85 46 - | 32 27 -
arc4_types/Arc4StringTypes 455 35 - | 245 13 -
arc4_types/Arc4StructsFromAnotherModule 67 12 - | 49 6 -
Expand Down Expand Up @@ -130,4 +130,4 @@
unssa/UnSSA 432 368 - | 241 204 -
voting/VotingRoundApp 1593 1483 - | 734 649 -
with_reentrancy/WithReentrancy 255 242 - | 132 122 -
Total 69191 53576 53517 | 32839 21764 21720
Total 69200 53576 53517 | 32843 21764 21720
14 changes: 7 additions & 7 deletions examples/tictactoe/out/module.awst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/voting/out/module.awst

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions src/puya/awst/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,15 +484,18 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:

@attrs.frozen
class ARC4Decode(Expression):
value: Expression = attrs.field()
value: Expression = attrs.field(
validator=expression_has_wtype(
wtypes.arc4_bool_wtype,
wtypes.ARC4UIntN,
wtypes.ARC4Tuple,
wtypes.ARC4DynamicArray, # only if element type is bytes for now
)
)

@value.validator
def _value_wtype_validator(self, _attribute: object, value: Expression) -> None:
if not isinstance(value.wtype, wtypes.ARC4Type):
raise InternalError(
f"ARC4Decode should only be used with expressions of ARC4Type, got {value.wtype}",
self.source_location,
)
assert isinstance(value.wtype, wtypes.ARC4Type) # validated by `value`
if not value.wtype.can_encode_type(self.wtype):
raise InternalError(
f"ARC4Decode from {value.wtype} should have non ARC4 target type {self.wtype}",
Expand Down
8 changes: 2 additions & 6 deletions src/puya/awst/to_code_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,8 @@ def visit_integer_constant(self, expr: nodes.IntegerConstant) -> str:
suffix = "u"
case wtypes.biguint_wtype:
suffix = "n"
case wtypes.ARC4UIntN(n=n, decode_type=decode_type):
if decode_type == wtypes.uint64_wtype:
suffix = f"arc4u{n}"
else:
assert decode_type == wtypes.biguint_wtype
suffix = f"arc4n{n}"
case wtypes.ARC4UIntN(n=n):
suffix = f"_arc4u{n}"
case _:
raise InternalError(
f"Numeric type not implemented: {expr.wtype}", expr.source_location
Expand Down
41 changes: 12 additions & 29 deletions src/puya/awst/wtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,25 +257,25 @@ def name_to_index(self, name: str, source_location: SourceLocation) -> int:
class ARC4Type(WType):
scalar_type: typing.Literal[AVMType.bytes] = attrs.field(default=AVMType.bytes, init=False)
arc4_name: str = attrs.field(eq=False) # exclude from equality in case of aliasing
decode_type: WType | None
native_type: WType | None

def can_encode_type(self, wtype: WType) -> bool:
return wtype == self.decode_type
return wtype == self.native_type


arc4_bool_wtype: typing.Final = ARC4Type(
name="arc4.bool",
arc4_name="bool",
immutable=True,
decode_type=bool_wtype,
native_type=bool_wtype,
)


@typing.final
@attrs.frozen(kw_only=True)
class ARC4UIntN(ARC4Type):
immutable: bool = attrs.field(default=True, init=False)
decode_type: WType = attrs.field()
native_type: WType = attrs.field(default=biguint_wtype, init=False)
n: int = attrs.field()
arc4_name: str = attrs.field(eq=False)
name: str = attrs.field(init=False)
Expand All @@ -288,22 +288,6 @@ def _n_validator(self, _attribute: object, n: int) -> None:
if not (8 <= n <= 512):
raise CodeError("Bit size must be between 8 and 512 inclusive", self.source_location)

@decode_type.validator
def _decode_type_validator(self, _attribute: object, decode_type: WType) -> None:
if decode_type == uint64_wtype:
if self.n > 64:
raise InternalError(
f"ARC-4 UIntN type received decode type {decode_type},"
f" which is smaller than size {self.n}",
self.source_location,
)
elif decode_type == biguint_wtype:
pass
else:
raise InternalError(
f"Unhandled decode_type for ARC-4 UIntN: {decode_type}", self.source_location
)

@arc4_name.default
def _arc4_name(self) -> str:
return f"uint{self.n}"
Expand All @@ -325,7 +309,7 @@ class ARC4UFixedNxM(ARC4Type):
arc4_name: str = attrs.field(init=False, eq=False)
name: str = attrs.field(init=False)
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
decode_type = attrs.field(default=None, init=False)
native_type: None = attrs.field(default=None, init=False)

@arc4_name.default
def _arc4_name(self) -> str:
Expand Down Expand Up @@ -365,7 +349,7 @@ class ARC4Tuple(ARC4Type):
name: str = attrs.field(init=False)
arc4_name: str = attrs.field(init=False, eq=False)
immutable: bool = attrs.field(init=False)
decode_type: WTuple = attrs.field(init=False)
native_type: WTuple = attrs.field(init=False)

@name.default
def _name(self) -> str:
Expand All @@ -379,8 +363,8 @@ def _arc4_name(self) -> str:
def _immutable(self) -> bool:
return all(typ.immutable for typ in self.types)

@decode_type.default
def _decode_type(self) -> WTuple:
@native_type.default
def _native_type(self) -> WTuple:
return WTuple(self.types, self.source_location)

def can_encode_type(self, wtype: WType) -> bool:
Expand Down Expand Up @@ -409,7 +393,7 @@ def _expect_arc4_type(wtype: WType) -> ARC4Type:
@attrs.frozen(kw_only=True)
class ARC4Array(ARC4Type):
element_type: ARC4Type = attrs.field(converter=_expect_arc4_type)
decode_type: WType | None = None
native_type: WType | None = None
immutable: bool = False


Expand Down Expand Up @@ -469,7 +453,7 @@ class ARC4Struct(ARC4Type):
immutable: bool = attrs.field()
source_location: SourceLocation | None = attrs.field(default=None, eq=False)
arc4_name: str = attrs.field(init=False, eq=False)
decode_type: WType | None = None
native_type: None = attrs.field(default=None, init=False)

@immutable.default
def _immutable(self) -> bool:
Expand Down Expand Up @@ -497,22 +481,21 @@ def can_encode_type(self, wtype: WType) -> bool:
arc4_byte_alias: typing.Final = ARC4UIntN(
n=8,
arc4_name="byte",
decode_type=uint64_wtype,
source_location=None,
)

arc4_string_alias: typing.Final = ARC4DynamicArray(
arc4_name="string",
element_type=arc4_byte_alias,
decode_type=string_wtype,
native_type=string_wtype,
immutable=True,
source_location=None,
)

arc4_address_alias: typing.Final = ARC4StaticArray(
arc4_name="address",
element_type=arc4_byte_alias,
decode_type=account_wtype,
native_type=account_wtype,
array_size=32,
immutable=True,
source_location=None,
Expand Down
6 changes: 3 additions & 3 deletions src/puya/ir/arc4_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,12 +716,12 @@ def _maybe_avm_to_arc4_equivalent_type(wtype: wtypes.WType) -> wtypes.ARC4Type |
case wtypes.bool_wtype:
return wtypes.arc4_bool_wtype
case wtypes.uint64_wtype:
return wtypes.ARC4UIntN(n=64, decode_type=wtype, source_location=None)
return wtypes.ARC4UIntN(n=64, source_location=None)
case wtypes.biguint_wtype:
return wtypes.ARC4UIntN(n=512, decode_type=wtype, source_location=None)
return wtypes.ARC4UIntN(n=512, source_location=None)
case wtypes.bytes_wtype:
return wtypes.ARC4DynamicArray(
element_type=wtypes.arc4_byte_alias, decode_type=wtype, source_location=None
element_type=wtypes.arc4_byte_alias, native_type=wtype, source_location=None
)
case wtypes.string_wtype:
return wtypes.arc4_string_alias
Expand Down
44 changes: 21 additions & 23 deletions src/puya/ir/builder/arc4.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,39 +72,40 @@ def _decode_arc4_value(
target_wtype: wtypes.WType,
loc: SourceLocation,
) -> ValueProvider:
match arc4_wtype:
case wtypes.ARC4UIntN(n=scale) | wtypes.ARC4UFixedNxM(n=scale):
if scale > 64:
return value
else:
return Intrinsic(
op=AVMOp.btoi,
args=[value],
source_location=loc,
)
case wtypes.arc4_bool_wtype:
match arc4_wtype, target_wtype:
case wtypes.ARC4UIntN(), wtypes.biguint_wtype:
return value
case wtypes.ARC4UIntN(), (wtypes.uint64_wtype | wtypes.bool_wtype):
return Intrinsic(
op=AVMOp.btoi,
args=[value],
source_location=loc,
)
case wtypes.arc4_bool_wtype, wtypes.bool_wtype:
return Intrinsic(
op=AVMOp.getbit,
args=[value, UInt64Constant(value=0, source_location=None)],
source_location=loc,
types=(IRType.bool,),
)
case wtypes.ARC4DynamicArray(element_type=wtypes.ARC4UIntN(n=8)):
case wtypes.ARC4DynamicArray(element_type=wtypes.ARC4UIntN(n=8)), (
wtypes.bytes_wtype | wtypes.string_wtype
):
return Intrinsic(
op=AVMOp.extract,
immediates=[2, 0],
args=[value],
source_location=loc,
)
case wtypes.ARC4Tuple() as arc4_tuple:
case wtypes.ARC4Tuple() as arc4_tuple, wtypes.WTuple() as native_tuple if (
len(arc4_tuple.types) == len(native_tuple.types)
):
return _visit_arc4_tuple_decode(
context, arc4_tuple, value, target_wtype=target_wtype, source_location=loc
)
case _:
raise InternalError(
f"Unsupported wtype for ARC4Decode: {arc4_wtype}",
location=loc,
context, arc4_tuple, value, target_wtype=native_tuple, source_location=loc
)
raise InternalError(
f"unsupported ARC4Decode operation from {arc4_wtype} to {target_wtype}", loc
)


def encode_arc4_struct(
Expand Down Expand Up @@ -608,13 +609,10 @@ def _visit_arc4_tuple_decode(
context: IRFunctionBuildContext,
wtype: wtypes.ARC4Tuple,
value: Value,
target_wtype: wtypes.WType,
target_wtype: wtypes.WTuple,
source_location: SourceLocation,
) -> ValueProvider:
items = list[Value]()
if not isinstance(target_wtype, wtypes.WTuple):
raise InternalError("expected ARC4Decode of a tuple to target a WTuple", source_location)

for index, (target_item_wtype, item_wtype) in enumerate(
zip(target_wtype.types, wtype.types, strict=True)
):
Expand Down
6 changes: 2 additions & 4 deletions src/puyapy/awst_build/pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,7 @@ def parameterise(
name=name,
bits=bits,
native_type=native_type,
wtype=wtypes.ARC4UIntN(
n=bits, decode_type=native_type.wtype, source_location=source_location
),
wtype=wtypes.ARC4UIntN(n=bits, source_location=source_location),
)

return parameterise
Expand Down Expand Up @@ -790,7 +788,7 @@ def parameterise(
name="algopy.arc4.DynamicBytes",
wtype=wtypes.ARC4DynamicArray(
element_type=ARC4ByteType.wtype,
decode_type=wtypes.bytes_wtype,
native_type=wtypes.bytes_wtype,
source_location=None,
),
size=None,
Expand Down
Loading

0 comments on commit 1c05162

Please sign in to comment.