Skip to content

Commit

Permalink
add support for autowiring with Annotated (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
lozanocampillod authored Jul 4, 2024
1 parent 7cbd7bb commit 85f5d25
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
23 changes: 20 additions & 3 deletions injectable/autowiring/autowired_decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from functools import wraps
from typing import TypeVar, Callable, Any
from typing import TypeVar, Callable, Any, get_args, _AnnotatedAlias

from injectable.autowiring.autowired_type import _Autowired
from injectable.errors import AutowiringError
Expand Down Expand Up @@ -43,7 +43,9 @@ def autowired(func: T) -> T:
signature = inspect.signature(func)
autowired_parameters = []
for index, parameter in enumerate(signature.parameters.values()):
if not isinstance(parameter.annotation, _Autowired):
annotation = _get_parameter_annotation(parameter)

if not isinstance(annotation, _Autowired):
if len(autowired_parameters) == 0 or parameter.kind in [
parameter.KEYWORD_ONLY,
parameter.VAR_KEYWORD,
Expand All @@ -68,7 +70,8 @@ def wrapper(*args, **kwargs):
for parameter in autowired_parameters:
if parameter.name in bound_arguments:
continue
dependency = parameter.annotation.inject()
annotation = _get_parameter_annotation(parameter)
dependency = annotation.inject()
if parameter.kind is parameter.POSITIONAL_ONLY:
args.append(dependency)
else:
Expand All @@ -77,3 +80,17 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


def _get_parameter_annotation(parameter) -> type:
if isinstance(parameter.annotation, _AnnotatedAlias):
autowired_annotations = list(
filter(lambda t: isinstance(t, _Autowired), get_args(parameter.annotation))
)
if len(autowired_annotations) == 0:
return parameter.annotation
if len(autowired_annotations) > 1:
raise AutowiringError("Multiple Autowired annotations found")
return autowired_annotations[0]

return parameter.annotation
43 changes: 43 additions & 0 deletions tests/unit/autowiring/autowired_decorator_unit_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Annotated
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -185,3 +186,45 @@ def f(a: AutowiredMockA, b: AutowiredMockB, c: AutowiredMockC):
assert parameters["a"] is None
assert parameters["b"] is None
assert parameters["c"] is AutowiredMockC.inject()

def test__autowired__with_one_autowired_in_annotated(self):
# given
AutowiredMock = MagicMock(spec=_Autowired)

@autowired
def f(a: Annotated[type, AutowiredMock]):
return {"a": a}

# when
parameters = f()

# then
assert AutowiredMock.inject.called is True
assert parameters["a"] is AutowiredMock.inject()

def test__autowired__with_two_autowired_in_annotation_raises(self):
# given
AutowiredMockA = MagicMock(spec=_Autowired)
AutowiredMockB = MagicMock(spec=_Autowired)

with pytest.raises(AutowiringError):

@autowired
def f(a: Annotated[type, AutowiredMockA, AutowiredMockB]):
return {"a": a}

def test__autowired_with_no_autowire_in_annotation_continues(self):
# given
AutowiredMock = MagicMock(spec=_Autowired)

@autowired
def f(a: Annotated[str, str], b: Annotated[type, AutowiredMock]):
return {"a": a, "b": b}

# when
parameters = f("test")

# then
assert AutowiredMock.inject.called is True
assert parameters["a"] == "test"
assert parameters["b"] is AutowiredMock.inject()

0 comments on commit 85f5d25

Please sign in to comment.