diff --git a/rocketry/core/parameters/parameters.py b/rocketry/core/parameters/parameters.py index dbb0ec00..72743129 100644 --- a/rocketry/core/parameters/parameters.py +++ b/rocketry/core/parameters/parameters.py @@ -3,6 +3,8 @@ from functools import partial import inspect +from typing_extensions import Annotated, get_args, get_origin + from rocketry._base import RedBase from rocketry.core.utils import is_pickleable from rocketry.core.utils import filter_keyword_args @@ -55,7 +57,17 @@ def _from_signature(cls, __func:Callable, **kwargs) -> 'Parameters': func_params = inspect.signature(__func).parameters params = cls() for name, param in func_params.items(): - default = param.default + if get_origin(param.annotation) is Annotated: + annotated_args = get_args(param.annotation) + arguments = annotated_args[1:] + assert ( + len(arguments) == 1 + ), f"Cannot specify multiple `Annotated` Custom arguments for `{name}`!" + + default = arguments[0] + else: + default = param.default + if isinstance(default, BaseArgument): params[name] = default return params diff --git a/rocketry/test/parameters/test_construct.py b/rocketry/test/parameters/test_construct.py index e96f0ade..253a56f4 100644 --- a/rocketry/test/parameters/test_construct.py +++ b/rocketry/test/parameters/test_construct.py @@ -1,7 +1,8 @@ import pytest +from typing_extensions import Annotated import rocketry -from rocketry.args import FuncArg +from rocketry.args import FuncArg, SimpleArg from rocketry.core import Parameters from rocketry.args import Private @@ -76,3 +77,19 @@ def my_param(): assert isinstance(session.parameters._params['a_param'], FuncArg) assert session.parameters['a_param'] == 5 assert session.parameters._params['a_param'].func is my_param + + +def test_annotated(): + arg = SimpleArg(1) + + def func(arg: Annotated[int, arg]): + ... + + params = Parameters._from_signature(func).materialize() + assert params["arg"] == 1 + + def func2(arg: Annotated[int, arg, str]): + ... + + with pytest.raises(AssertionError): + Parameters._from_signature(func2)