Skip to content

Commit

Permalink
refactor: use BytesAugmentedAssignment for arc4.String
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-makerx committed Nov 7, 2024
1 parent 6933902 commit 94f2531
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 22 deletions.
8 changes: 6 additions & 2 deletions src/puya/awst/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,11 +1337,15 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
@attrs.frozen
class BytesAugmentedAssignment(Statement):
target: Lvalue = attrs.field(
validator=[expression_has_wtype(wtypes.bytes_wtype, wtypes.string_wtype)]
validator=[
expression_has_wtype(wtypes.bytes_wtype, wtypes.string_wtype, wtypes.arc4_string_alias)
]
)
op: BytesBinaryOperator
value: Expression = attrs.field(
validator=[expression_has_wtype(wtypes.bytes_wtype, wtypes.string_wtype)]
validator=[
expression_has_wtype(wtypes.bytes_wtype, wtypes.string_wtype, wtypes.arc4_string_alias)
]
)

@value.validator
Expand Down
15 changes: 11 additions & 4 deletions src/puya/ir/builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,14 +1053,21 @@ def visit_biguint_augmented_assignment(
def visit_bytes_augmented_assignment(
self, statement: awst_nodes.BytesAugmentedAssignment
) -> TStatement:
target_value = self.visit_and_materialise_single(statement.target)
rhs = self.visit_and_materialise_single(statement.value)
expr = create_bytes_binary_op(statement.op, target_value, rhs, statement.source_location)
if statement.target.wtype == wtypes.arc4_string_alias:
value: ValueProvider = arc4.concat_values(
self.context, statement.target, statement.value, statement.source_location
)
else:
target_value = self.visit_and_materialise_single(statement.target)
rhs = self.visit_and_materialise_single(statement.value)
value = create_bytes_binary_op(
statement.op, target_value, rhs, statement.source_location
)

handle_assignment(
self.context,
target=statement.target,
value=expr,
value=value,
is_nested_update=False,
assignment_location=statement.source_location,
)
Expand Down
18 changes: 6 additions & 12 deletions src/puyapy/awst_build/eb/arc4/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
ARC4Decode,
ARC4Encode,
ArrayConcat,
AssignmentStatement,
BytesAugmentedAssignment,
BytesBinaryOperator,
Expression,
Statement,
StringConstant,
Expand Down Expand Up @@ -103,17 +104,10 @@ def augmented_assignment(
else:
value = expect.argument_of_type_else_dummy(rhs, self.pytype).resolve()

# TODO: does this actually need to be a AugmentedAssignment node to ensure LHS is only
# evaluated once
lhs = self.single_eval().resolve_lvalue()
return AssignmentStatement(
target=lhs,
value=ArrayConcat(
left=lhs,
right=value,
wtype=wtypes.arc4_string_alias,
source_location=location,
),
return BytesAugmentedAssignment(
target=self.resolve_lvalue(),
op=BytesBinaryOperator.add,
value=value,
source_location=location,
)

Expand Down
8 changes: 4 additions & 4 deletions test_cases/arc4_types/out/module.awst
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ contract Arc4StringTypesContract
world: arc4.dynamic_array<arc4.uint8> = arc4_encode('World!', arc4.dynamic_array<arc4.uint8>)
assert(arc4_encode('Hello World!', arc4.dynamic_array<arc4.uint8>) == hello + space + world)
thing: arc4.dynamic_array<arc4.uint8> = arc4_encode('hi', arc4.dynamic_array<arc4.uint8>)
thing: arc4.dynamic_array<arc4.uint8> = thing + thing
thing += thing
assert(thing == arc4_encode('hihi', arc4.dynamic_array<arc4.uint8>))
value: arc4.dynamic_array<arc4.uint8> = arc4_encode('a', arc4.dynamic_array<arc4.uint8>) + arc4_encode('b', arc4.dynamic_array<arc4.uint8>) + arc4_encode('cd', arc4.dynamic_array<arc4.uint8>)
value: arc4.dynamic_array<arc4.uint8> = value + arc4_encode('e', arc4.dynamic_array<arc4.uint8>)
value: arc4.dynamic_array<arc4.uint8> = value + arc4_encode('f', arc4.dynamic_array<arc4.uint8>)
value: arc4.dynamic_array<arc4.uint8> = value + arc4_encode('g', arc4.dynamic_array<arc4.uint8>)
value += arc4_encode('e', arc4.dynamic_array<arc4.uint8>)
value += arc4_encode('f', arc4.dynamic_array<arc4.uint8>)
value += arc4_encode('g', arc4.dynamic_array<arc4.uint8>)
assert(arc4_encode('abcdefg', arc4.dynamic_array<arc4.uint8>) == value)
assert(arc4_decode(arc4_encode('', arc4.dynamic_array<arc4.uint8>), string) == '')
assert(arc4_decode(arc4_encode('hello', arc4.dynamic_array<arc4.uint8>), string) == 'hello')
Expand Down

0 comments on commit 94f2531

Please sign in to comment.