Skip to content

Commit

Permalink
WIP return Annotated on Autowired new cls call
Browse files Browse the repository at this point in the history
  • Loading branch information
roo-oliv committed Jul 12, 2024
1 parent 4a21131 commit 1a29a29
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
12 changes: 10 additions & 2 deletions injectable/autowiring/autowired_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def wrapper(*args, **kwargs):
def _get_parameter_annotation(parameter) -> Union[type, _Autowired]:
if isinstance(parameter.annotation, _AnnotatedAlias):
autowired_annotations = list(
filter(lambda t: _is_autowired(t), get_args(parameter.annotation))
filter(lambda t: _is_autowired(t), _get_args_flattened(parameter.annotation))
)
if len(autowired_annotations) == 0:
return parameter.annotation
if len(autowired_annotations) > 1:
raise AutowiringError("Multiple Autowired annotations found")
autowired_annotation = autowired_annotations[0]
if not isinstance(autowired_annotation, _Autowired):
return autowired_annotation(dependency=get_args(parameter.annotation)[0])
return get_args(autowired_annotation(dependency=get_args(parameter.annotation)[0]))[1]
if autowired_annotation.dependency is None:
return type(autowired_annotation)(
dependency=get_args(parameter.annotation)[0],
Expand All @@ -111,3 +111,11 @@ def _is_autowired(annotation) -> bool:
return isinstance(annotation, _Autowired) or (
inspect.isclass(annotation) and issubclass(annotation, Autowired)
)


def _get_args_flattened(annotation) -> list:
return [
arg
for e in get_args(annotation)
for arg in (_get_args_flattened(e) if isinstance(e, _AnnotatedAlias) else [e])
]
7 changes: 5 additions & 2 deletions injectable/autowiring/autowired_type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, TypeVar, Sequence
from typing import Union, TypeVar, Sequence, Annotated

import typing_inspect

Expand Down Expand Up @@ -144,10 +144,13 @@ def __new__(
exclude_groups: Sequence[str] = None,
lazy: bool = False,
) -> T:
return _Autowired(
instance = _Autowired(
dependency,
namespace=namespace,
group=group,
exclude_groups=exclude_groups,
lazy=lazy,
)
if dependency is None:
return instance
return Annotated[dependency, instance]

0 comments on commit 1a29a29

Please sign in to comment.