diff --git a/src/puya/arc32.py b/src/puya/arc32.py index 5d8f44503a..a89b598be7 100644 --- a/src/puya/arc32.py +++ b/src/puya/arc32.py @@ -6,6 +6,8 @@ from puya import log from puya.errors import InternalError from puya.models import ( + ABIMethodArgConstantDefault, + ABIMethodArgDefault, ARC4ABIMethod, ARC4BareMethod, ARC4CreateOption, @@ -91,26 +93,27 @@ def _get_signature(method: ARC4ABIMethod) -> str: def _encode_default_arg( - metadata: ContractMetaData, source: str, loc: SourceLocation | None + metadata: ContractMetaData, source: ABIMethodArgDefault, loc: SourceLocation | None ) -> JSONDict: - if state := metadata.global_state.get(source): + if isinstance(source, ABIMethodArgConstantDefault): + return {"source": "constant", "data": source.value} + if state := metadata.global_state.get(source.name): return { "source": "global-state", # TODO: handle non utf-8 bytes "data": state.key.decode("utf-8"), } - if state := metadata.local_state.get(source): + if state := metadata.local_state.get(source.name): return { "source": "local-state", "data": state.key.decode("utf-8"), } for method in metadata.arc4_methods: - if isinstance(method, ARC4ABIMethod) and method.name == source: + if isinstance(method, ARC4ABIMethod) and method.name == source.name: return { "source": "abi-method", "data": _encode_abi_method(method), } - # TODO: constants raise InternalError(f"Cannot find source '{source}' on {metadata.ref}", loc) diff --git a/src/puya/ir/arc4_router.py b/src/puya/ir/arc4_router.py index 9b235a57d4..9ba95bbebf 100644 --- a/src/puya/ir/arc4_router.py +++ b/src/puya/ir/arc4_router.py @@ -9,6 +9,7 @@ ) from puya.errors import CodeError, InternalError from puya.models import ( + ABIMethodArgConstantDefault, ARC4ABIMethod, ARC4ABIMethodConfig, ARC4BareMethod, @@ -480,7 +481,7 @@ def _validate_default_args( args_by_name = {a.name: a for a in method.args} for ( parameter_name, - source_name, + default_source, ) in method.arc4_method_config.default_args.items(): # any invalid parameter matches should have been caught earlier parameter = args_by_name[parameter_name] @@ -492,6 +493,17 @@ def _validate_default_args( case "account": param_arc4_type = "address" + if isinstance(default_source, ABIMethodArgConstantDefault): + if not _is_valid_client_literal_for_arc4_type( + default_source.value, param_arc4_type + ): + logger.warning( + f"'{default_source.value}' is not a valid" + f" default value for parameter '{parameter_name}'" + ) + continue + + source_name = default_source.name try: source = known_sources[source_name] except KeyError as ex: @@ -681,6 +693,18 @@ def _get_abi_signature(subroutine: awst_nodes.ContractMethod, config: ARC4ABIMet return f"{config.name}({','.join(arg_types)}){return_type}" +def _is_valid_client_literal_for_arc4_type(literal: str | int, arc4_type_alias: str) -> bool: + if arc4_type_alias.startswith(("uint", "ufixed")): + return isinstance(literal, int) + + match arc4_type_alias: + case "byte" | "bool": + return isinstance(literal, int) + case "address" | "string": + return isinstance(literal, str) + return False + + def _wtype_to_arc4(wtype: wtypes.WType, loc: SourceLocation | None = None) -> str: match wtype: case wtypes.ARC4Type(arc4_name=arc4_name): diff --git a/src/puya/models.py b/src/puya/models.py index f469b9b591..36e99e5987 100644 --- a/src/puya/models.py +++ b/src/puya/models.py @@ -66,6 +66,19 @@ class ARC4BareMethodConfig: create: ARC4CreateOption = ARC4CreateOption.disallow +@attrs.frozen(kw_only=True) +class ABIMethodArgConstantDefault: + value: int | str + + +@attrs.frozen(kw_only=True) +class ABIMethodArgMemberDefault: + name: str + + +ABIMethodArgDefault = ABIMethodArgMemberDefault | ABIMethodArgConstantDefault + + @attrs.frozen(kw_only=True) class ARC4ABIMethodConfig: source_location: SourceLocation @@ -77,7 +90,7 @@ class ARC4ABIMethodConfig: create: ARC4CreateOption = ARC4CreateOption.disallow name: str readonly: bool = False - default_args: immutabledict[str, str] = immutabledict() + default_args: immutabledict[str, ABIMethodArgDefault] = immutabledict() """Mapping is from parameter -> source""" structs: immutabledict[str, ARC32StructDef] = immutabledict() diff --git a/src/puyapy/awst_build/arc4_utils.py b/src/puyapy/awst_build/arc4_utils.py index 7cee4f749e..c232226e79 100644 --- a/src/puyapy/awst_build/arc4_utils.py +++ b/src/puyapy/awst_build/arc4_utils.py @@ -12,6 +12,8 @@ from puya.awst import wtypes from puya.errors import CodeError, InternalError from puya.models import ( + ABIMethodArgDefault, + ABIMethodArgMemberDefault, ARC4ABIMethodConfig, ARC4BareMethodConfig, ARC4CreateOption, @@ -183,7 +185,7 @@ def get_arc4_abimethod_data( readonly = default_readonly # map "default_args" param - default_args = dict[str, str]() + default_args = dict[str, ABIMethodArgDefault]() match evaluated_args.pop("default_args", {}): case {**options}: method_arg_names = func_types.keys() - {"output"} @@ -198,7 +200,7 @@ def get_arc4_abimethod_data( else: # if it's in method_arg_names, it's a str assert isinstance(parameter, str) - default_args[parameter] = value + default_args[parameter] = ABIMethodArgMemberDefault(name=value) case invalid_default_args_option: context.error(f"invalid default_args option: {invalid_default_args_option}", dec_loc)