Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARC4 operators & improved optimisation #104

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion src/puya/awst_build/eb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@
import structlog

from puya.awst import wtypes
from puya.awst.nodes import BoolConstant, Expression, IntrinsicCall, Literal
from puya.awst.nodes import (
BigUIntBinaryOperator,
BoolConstant,
Expression,
IntrinsicCall,
Literal,
UInt64BinaryOperator,
)
from puya.awst_build.eb.base import BuilderBinaryOp
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import expect_operand_wtype
from puya.errors import CodeError

if TYPE_CHECKING:
from puya.awst_build.eb.base import ExpressionBuilder
Expand Down Expand Up @@ -37,3 +46,41 @@ def uint64_to_biguint(
stack_args=[arg],
)
return itob_call


def translate_uint64_math_operator(
operator: BuilderBinaryOp, loc: SourceLocation
) -> UInt64BinaryOperator:
if operator is BuilderBinaryOp.div:
logger.error(
(
"To maintain semantic compatibility with Python, "
"only the truncating division operator (//) is supported "
),
location=loc,
)
# continue traversing code to generate any further errors
operator = BuilderBinaryOp.floor_div
try:
return UInt64BinaryOperator(operator.value)
except ValueError as ex:
raise CodeError(f"Unsupported UInt64 math operator {operator.value}", loc) from ex


def translate_biguint_math_operator(
operator: BuilderBinaryOp, loc: SourceLocation
) -> BigUIntBinaryOperator:
if operator is BuilderBinaryOp.div:
logger.error(
(
"To maintain semantic compatibility with Python, "
"only the truncating division operator (//) is supported "
),
location=loc,
)
# continue traversing code to generate any further errors
operator = BuilderBinaryOp.floor_div
try:
return BigUIntBinaryOperator(operator.value)
except ValueError as ex:
raise CodeError(f"Unsupported BigUInt math operator {operator.value}", loc) from ex
110 changes: 108 additions & 2 deletions src/puya/awst_build/eb/arc4/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,31 @@

from puya.awst import wtypes
from puya.awst.nodes import (
ARC4Decode,
ARC4Encode,
BigUIntBinaryOperation,
DecimalConstant,
Expression,
IntegerConstant,
Literal,
NumericComparison,
NumericComparisonExpression,
ReinterpretCast,
Statement,
UInt64BinaryOperation,
)
from puya.awst_build.eb._utils import (
translate_biguint_math_operator,
translate_uint64_math_operator,
uint64_to_biguint,
)
from puya.awst_build.eb._utils import uint64_to_biguint
from puya.awst_build.eb.arc4.base import (
ARC4ClassExpressionBuilder,
ARC4EncodedExpressionBuilder,
arc4_bool_bytes,
get_integer_literal_value,
)
from puya.awst_build.eb.base import BuilderComparisonOp, ExpressionBuilder
from puya.awst_build.eb.base import BuilderBinaryOp, BuilderComparisonOp, ExpressionBuilder
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import convert_literal_to_expr
from puya.errors import CodeError, InternalError, TodoError
Expand Down Expand Up @@ -185,6 +193,14 @@ def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> Expres
negate=negate,
)

def unary_plus(self, location: SourceLocation) -> ExpressionBuilder:
# unary + is allowed, but for the current types it has no real impact
# so just expand the existing expression to include the unary operator
raise TodoError(location)

def bitwise_invert(self, location: SourceLocation) -> ExpressionBuilder:
raise TodoError(location)

def compare(
self, other: ExpressionBuilder | Literal, op: BuilderComparisonOp, location: SourceLocation
) -> ExpressionBuilder:
Expand Down Expand Up @@ -216,6 +232,96 @@ def compare(
)
return var_expression(cmp_expr)

def binary_op(
self,
other: ExpressionBuilder | Literal,
op: BuilderBinaryOp,
location: SourceLocation,
*,
reverse: bool,
) -> ExpressionBuilder:
other_expr = convert_literal_to_expr(other, self.wtype)
if self.wtype.n <= 64:
result_expr = self._uint64_binary_op(other_expr, op, location, reverse=reverse)
else:
result_expr = self._biguint_binary_op(other_expr, op, location, reverse=reverse)
encoded_result = ARC4Encode(value=result_expr, source_location=location, wtype=self.wtype)
return var_expression(encoded_result)

def _uint64_binary_op(
self, other: Expression, op: BuilderBinaryOp, location: SourceLocation, *, reverse: bool
) -> Expression:
if other.wtype == self.wtype:
other = ARC4Decode(
value=other,
wtype=wtypes.uint64_wtype,
source_location=location,
)
elif isinstance(other.wtype, wtypes.ARC4UIntN):
raise TodoError(location, "TODO: support mixed size operators with arc4 numerics")
elif other.wtype == wtypes.uint64_wtype:
pass
elif other.wtype == wtypes.bool_wtype:
raise TodoError(location, "TODO: support upcast from bool to arc4.UIntN")
else:
return NotImplemented
lhs: Expression = ARC4Decode(
value=self.expr,
wtype=wtypes.uint64_wtype,
source_location=self.source_location,
)
rhs = other
if reverse:
(lhs, rhs) = (rhs, lhs)
uint64_op = translate_uint64_math_operator(op, location)
bin_op_expr = UInt64BinaryOperation(
source_location=location, left=lhs, op=uint64_op, right=rhs
)
return bin_op_expr

def _biguint_binary_op(
self, other: Expression, op: BuilderBinaryOp, location: SourceLocation, *, reverse: bool
) -> Expression:
if other.wtype == self.wtype:
other = ReinterpretCast(
expr=other,
wtype=wtypes.biguint_wtype,
source_location=other.source_location,
)
elif isinstance(other.wtype, wtypes.ARC4UIntN):
raise TodoError(location, "TODO: support mixed size operators with arc4 numerics")
elif other.wtype == wtypes.uint64_wtype:
other = uint64_to_biguint(other, location)
elif other.wtype == wtypes.biguint_wtype:
pass
elif other.wtype == wtypes.bool_wtype:
raise TodoError(location, "TODO: support upcast from bool to arc4.UIntN")
else:
return NotImplemented
lhs: Expression = ReinterpretCast(
expr=self.expr,
wtype=wtypes.biguint_wtype,
source_location=self.source_location,
)
rhs = other
if reverse:
(lhs, rhs) = (rhs, lhs)
biguint_op = translate_biguint_math_operator(op, location)
bin_op_expr = BigUIntBinaryOperation(
source_location=location, left=lhs, op=biguint_op, right=rhs
)
return bin_op_expr

def augmented_assignment(
self, op: BuilderBinaryOp, rhs: ExpressionBuilder | Literal, location: SourceLocation
) -> Statement:
raise TodoError(location)
# rhs_expr = convert_literal_to_expr(rhs, self.wtype)
# if self.wtype.n <= 64:
# return self._uint64_augmented_assignment(rhs_expr, op, location)
# else:
# return self._biguint_augmented_assignment(rhs_expr, op, location)


class UFixedNxMExpressionBuilder(ARC4EncodedExpressionBuilder):
def __init__(self, expr: Expression):
Expand Down
28 changes: 4 additions & 24 deletions src/puya/awst_build/eb/biguint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,14 @@
from puya.awst.nodes import (
BigUIntAugmentedAssignment,
BigUIntBinaryOperation,
BigUIntBinaryOperator,
BigUIntConstant,
Literal,
NumericComparison,
NumericComparisonExpression,
ReinterpretCast,
Statement,
)
from puya.awst_build.eb._utils import uint64_to_biguint
from puya.awst_build.eb._utils import translate_biguint_math_operator, uint64_to_biguint
from puya.awst_build.eb.base import (
BuilderBinaryOp,
BuilderComparisonOp,
Expand Down Expand Up @@ -124,7 +123,7 @@ def binary_op(
if other_expr.wtype == self.wtype:
pass
elif other_expr.wtype == wtypes.uint64_wtype:
other_expr = uint64_to_biguint(other, location)
other_expr = uint64_to_biguint(other_expr, location)
elif other_expr.wtype == wtypes.bool_wtype:
raise TodoError(location, "TODO: support upcast from bool to biguint")
else:
Expand All @@ -133,7 +132,7 @@ def binary_op(
rhs = other_expr
if reverse:
(lhs, rhs) = (rhs, lhs)
biguint_op = _translate_biguint_math_operator(op, location)
biguint_op = translate_biguint_math_operator(op, location)
bin_op_expr = BigUIntBinaryOperation(
source_location=location, left=lhs, op=biguint_op, right=rhs
)
Expand All @@ -154,29 +153,10 @@ def augmented_assignment(
f"Invalid operand type {value.wtype} for {op.value}= with {self.wtype}", location
)
target = self.lvalue()
biguint_op = _translate_biguint_math_operator(op, location)
biguint_op = translate_biguint_math_operator(op, location)
return BigUIntAugmentedAssignment(
source_location=location,
target=target,
value=value,
op=biguint_op,
)


def _translate_biguint_math_operator(
operator: BuilderBinaryOp, loc: SourceLocation
) -> BigUIntBinaryOperator:
if operator is BuilderBinaryOp.div:
logger.error(
(
"To maintain semantic compatibility with Python, "
"only the truncating division operator (//) is supported "
),
location=loc,
)
# continue traversing code to generate any further errors
operator = BuilderBinaryOp.floor_div
try:
return BigUIntBinaryOperator(operator.value)
except ValueError as ex:
raise CodeError(f"Unsupported BigUInt math operator {operator.value}", loc) from ex
25 changes: 3 additions & 22 deletions src/puya/awst_build/eb/uint64.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
Statement,
UInt64AugmentedAssignment,
UInt64BinaryOperation,
UInt64BinaryOperator,
UInt64Constant,
UInt64UnaryOperation,
UInt64UnaryOperator,
)
from puya.awst_build.eb._utils import translate_uint64_math_operator
from puya.awst_build.eb.base import (
BuilderBinaryOp,
BuilderComparisonOp,
Expand Down Expand Up @@ -130,7 +130,7 @@ def binary_op(
rhs = other_expr
if reverse:
(lhs, rhs) = (rhs, lhs)
uint64_op = _translate_uint64_math_operator(op, location)
uint64_op = translate_uint64_math_operator(op, location)
bin_op_expr = UInt64BinaryOperation(
source_location=location, left=lhs, op=uint64_op, right=rhs
)
Expand All @@ -149,26 +149,7 @@ def augmented_assignment(
f"Invalid operand type {value.wtype} for {op.value}= with {self.wtype}", location
)
target = self.lvalue()
uint64_op = _translate_uint64_math_operator(op, location)
uint64_op = translate_uint64_math_operator(op, location)
return UInt64AugmentedAssignment(
source_location=location, target=target, value=value, op=uint64_op
)


def _translate_uint64_math_operator(
operator: BuilderBinaryOp, loc: SourceLocation
) -> UInt64BinaryOperator:
if operator is BuilderBinaryOp.div:
logger.error(
(
"To maintain semantic compatibility with Python, "
"only the truncating division operator (//) is supported "
),
location=loc,
)
# continue traversing code to generate any further errors
operator = BuilderBinaryOp.floor_div
try:
return UInt64BinaryOperator(operator.value)
except ValueError as ex:
raise CodeError(f"Unsupported UInt64 math operator {operator.value}", loc) from ex
Loading
Loading