From 2e5a0b5e2efa46238c4c0731012a1053682eaae0 Mon Sep 17 00:00:00 2001 From: Nir Date: Fri, 26 Jan 2024 00:48:34 +0200 Subject: [PATCH 1/9] add failing test, --- .gitignore | 4 ++++ aioinject/containers.py | 30 +++++++++++++++++++++++++----- aioinject/providers.py | 24 +++++++++++++++++------- aioinject/utils.py | 8 ++++++++ tests/container/test_container.py | 23 +++++++++++++++++++++++ tests/providers/test_scoped.py | 25 ++++++++++++------------- 6 files changed, 89 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index e06aa35..acf64e2 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,7 @@ coverage.xml # virtualenv .venv venv*/ + + +# python cached files +*.py[cod] \ No newline at end of file diff --git a/aioinject/containers.py b/aioinject/containers.py index 07fa557..e4e5b44 100644 --- a/aioinject/containers.py +++ b/aioinject/containers.py @@ -18,16 +18,36 @@ class Container: def __init__(self) -> None: self.providers: _Providers[Any] = {} self._singletons = SingletonStore() + self._unresolved_providers: list[Provider[Any]] = [] + self._type_context: dict[str, type[Any]] = {} + + def _resolve_unresolved_provider(self) -> None: + for provider in self._unresolved_providers: + with contextlib.suppress(NameError): + self._register_impl(provider) + self._unresolved_providers.remove(provider) + + + def _register_impl(self, provider: Provider[Any]) -> None: + provider_type = provider.resolve_type(self._type_context) + if provider_type in self.providers: + msg = f"Provider for type {provider_type} is already registered" + raise ValueError(msg) + self.providers[provider_type] = provider + if klass_name := getattr(provider_type, "__name__" , None): + self._type_context[klass_name] = provider_type + def register( self, provider: Provider[Any], - ) -> None: - if provider.type_ in self.providers: - msg = f"Provider for type {provider.type_} is already registered" - raise ValueError(msg) + ) -> None: + try: + self._register_impl(provider) + except NameError: + self._unresolved_providers.append(provider) + self._resolve_unresolved_provider() - self.providers[provider.type_] = provider def get_provider(self, type_: type[_T]) -> Provider[_T]: try: diff --git a/aioinject/providers.py b/aioinject/providers.py index bb8a3ed..f3355fc 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -23,6 +23,7 @@ from aioinject.markers import Inject from aioinject.utils import ( + _get_type_hints, is_context_manager_function, remove_annotation, ) @@ -31,6 +32,7 @@ _T = TypeVar("_T") + @dataclass(slots=True, kw_only=True, frozen=True) class Dependency(Generic[_T]): name: str @@ -61,14 +63,13 @@ def _find_inject_marker_in_annotation_args( def collect_dependencies( - dependant: typing.Callable[..., object] | dict[str, Any], + dependant: typing.Callable[..., object] | dict[str, Any], ctx: dict[str, type[Any]] | None = None ) -> typing.Iterable[Dependency[object]]: if not isinstance(dependant, dict): with remove_annotation(dependant.__annotations__, "return"): - type_hints = typing.get_type_hints(dependant, include_extras=True) + type_hints = _get_type_hints(dependant, context=ctx) else: type_hints = dependant - for name, hint in type_hints.items(): dep_type, args = _get_annotation_args(hint) inject_marker = _find_inject_marker_in_annotation_args(args) @@ -120,11 +121,11 @@ def _get_provider_type_hints(provider: Provider[Any]) -> dict[str, Any]: ) -def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]: +def _guess_return_type(factory: _FactoryType[_T], context: dict[str, type[Any] ] | None = None) -> type[_T]: if isclass(factory): return typing.cast(type[_T], factory) - type_hints = typing.get_type_hints(factory) + type_hints = _get_type_hints(factory, context=context) try: return_type = type_hints["return"] except KeyError as e: @@ -159,7 +160,6 @@ class DependencyLifetime(enum.Enum): @runtime_checkable class Provider(Protocol[_T]): - type_: type[_T] impl: Any lifetime: DependencyLifetime @@ -168,6 +168,10 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T: def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: ... + + def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: + ... + @property def type_hints(self) -> dict[str, Any]: @@ -197,9 +201,12 @@ def __init__( factory: _FactoryType[_T], type_: type[_T] | None = None, ) -> None: - self.type_ = type_ or _guess_return_type(factory) + self.type_ = type_ self.impl = factory + def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: + return self.type_ or _guess_return_type(self.impl, context=context) + def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: return self.impl(**kwargs) # type: ignore[return-value] @@ -257,6 +264,9 @@ def __init__( self.type_ = type_ or type(object_) self.impl = object_ + def resolve_type(self, _: dict[str, Any] | None = None) -> type[_T]: + return self.type_ + def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002 return self.impl diff --git a/aioinject/utils.py b/aioinject/utils.py index d82e6ca..dae1744 100644 --- a/aioinject/utils.py +++ b/aioinject/utils.py @@ -94,3 +94,11 @@ def remove_annotation( yield if annotation is not sentinel: annotations[name] = annotation + +def _get_type_hints( + obj: Any, + context: dict[str, type[Any]] | None = None +) -> dict[str, Any]: + if not context: + context = {} + return typing.get_type_hints(obj, include_extras=True, localns=context) \ No newline at end of file diff --git a/tests/container/test_container.py b/tests/container/test_container.py index 9db70a1..7249319 100644 --- a/tests/container/test_container.py +++ b/tests/container/test_container.py @@ -1,5 +1,6 @@ import contextlib from collections.abc import AsyncIterator +from typing import TYPE_CHECKING import pytest @@ -91,3 +92,25 @@ async def dependency() -> AsyncIterator[int]: await container.aclose() assert shutdown is True + +@pytest.mark.anyio +async def test_deffered_types() -> None: + if TYPE_CHECKING: + from decimal import Decimal + + def some_deffered_type() -> 'Decimal': + from decimal import Decimal + return Decimal('1.0') + + class DoubledDecimal: + def __init__(self, decimal: 'Decimal') -> None: + self.decimal = decimal * 2 + container = Container() + def register_decimal_scoped() -> None: + from decimal import Decimal + container.register(Scoped(some_deffered_type, Decimal)) + + register_decimal_scoped() + container.register(Scoped(DoubledDecimal)) + async with container.context() as ctx: + assert (await ctx.resolve(DoubledDecimal)).decimal == DoubledDecimal(some_deffered_type()).decimal \ No newline at end of file diff --git a/tests/providers/test_scoped.py b/tests/providers/test_scoped.py index 7a9e125..2f290f8 100644 --- a/tests/providers/test_scoped.py +++ b/tests/providers/test_scoped.py @@ -140,21 +140,20 @@ def factory( ) assert provider.dependencies == expected +def iterable() -> Iterator[int]: + yield 42 -def test_generator_return_types() -> None: - def iterable() -> Iterator[int]: - yield 42 +def gen() -> Generator[int, None, None]: + yield 42 - def gen() -> Generator[int, None, None]: - yield 42 +async def async_iterable() -> AsyncIterator[int]: + yield 42 - async def async_iterable() -> AsyncIterator[int]: - yield 42 +async def async_gen() -> AsyncGenerator[int, None]: + yield 42 - async def async_gen() -> AsyncGenerator[int, None]: - yield 42 +@pytest.mark.parametrize("factory", [iterable, gen, async_iterable, async_gen]) +def test_generator_return_types(factory) -> None: + provider = providers.Scoped(factory) + assert provider.resolve_type() is int - factories = [iterable, gen, async_iterable, async_gen] - for factory in factories: - provider = providers.Scoped(factory) - assert provider.type_ is int From 0d4f7f32bb53e4482c8d2f3c1b946377415a021e Mon Sep 17 00:00:00 2001 From: Nir Date: Fri, 26 Jan 2024 01:09:37 +0200 Subject: [PATCH 2/9] test_deffered_type pass --- aioinject/containers.py | 7 ++++--- aioinject/context.py | 2 +- aioinject/providers.py | 25 ++++++++++++++----------- aioinject/validation/_builtin.py | 14 ++++++++------ tests/container/test_container.py | 8 ++++---- tests/providers/test_object.py | 2 +- tests/providers/test_scoped.py | 2 +- 7 files changed, 33 insertions(+), 27 deletions(-) diff --git a/aioinject/containers.py b/aioinject/containers.py index e4e5b44..ffe4653 100644 --- a/aioinject/containers.py +++ b/aioinject/containers.py @@ -19,7 +19,7 @@ def __init__(self) -> None: self.providers: _Providers[Any] = {} self._singletons = SingletonStore() self._unresolved_providers: list[Provider[Any]] = [] - self._type_context: dict[str, type[Any]] = {} + self.type_context: dict[str, type[Any]] = {} def _resolve_unresolved_provider(self) -> None: for provider in self._unresolved_providers: @@ -29,13 +29,14 @@ def _resolve_unresolved_provider(self) -> None: def _register_impl(self, provider: Provider[Any]) -> None: - provider_type = provider.resolve_type(self._type_context) + provider_type = provider.resolve_type(self.type_context) if provider_type in self.providers: + provider.dependencies msg = f"Provider for type {provider_type} is already registered" raise ValueError(msg) self.providers[provider_type] = provider if klass_name := getattr(provider_type, "__name__" , None): - self._type_context[klass_name] = provider_type + self.type_context[klass_name] = provider_type def register( diff --git a/aioinject/context.py b/aioinject/context.py index b7471e7..f933f4b 100644 --- a/aioinject/context.py +++ b/aioinject/context.py @@ -60,7 +60,7 @@ async def resolve( return cached dependencies = {} - for dependency in provider.dependencies: + for dependency in provider.resolve_dependencies(self._container.type_context): dependencies[dependency.name] = await self.resolve( type_=dependency.type_, ) diff --git a/aioinject/providers.py b/aioinject/providers.py index f3355fc..c7dda77 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -82,7 +82,7 @@ def collect_dependencies( ) -def _get_provider_type_hints(provider: Provider[Any]) -> dict[str, Any]: +def _get_provider_type_hints(provider: Provider[Any], context: dict[str, Any] | None = None) -> dict[str, Any]: source = provider.impl if inspect.isclass(source): source = source.__init__ @@ -90,7 +90,7 @@ def _get_provider_type_hints(provider: Provider[Any]) -> dict[str, Any]: if isinstance(source, functools.partial): return {} - type_hints = typing.get_type_hints(source, include_extras=True) + type_hints = _get_type_hints(source, context=context) for key, value in type_hints.items(): _, args = _get_annotation_args(value) for arg in args: @@ -162,6 +162,7 @@ class DependencyLifetime(enum.Enum): class Provider(Protocol[_T]): impl: Any lifetime: DependencyLifetime + _cached_dependecies: tuple[Dependency[object], ...] async def provide(self, kwargs: Mapping[str, Any]) -> _T: ... @@ -173,18 +174,21 @@ def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: ... - @property - def type_hints(self) -> dict[str, Any]: + + def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]: ... @property def is_async(self) -> bool: ... - @functools.cached_property - def dependencies(self) -> tuple[Dependency[object], ...]: - return tuple(collect_dependencies(self.type_hints)) - + def resolve_dependencies(self, context: dict[str, Any] | None = None) -> tuple[Dependency[object], ...]: + if deps := getattr(self, "_cached_dependecies", None): + return deps + ret = tuple(collect_dependencies(self.type_hints(context), ctx=context)) + self._cached_dependecies = ret + return ret + @functools.cached_property def is_generator(self) -> bool: return is_context_manager_function(self.impl) @@ -216,9 +220,8 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T: return self.provide_sync(kwargs) - @functools.cached_property - def type_hints(self) -> dict[str, Any]: - type_hints = _get_provider_type_hints(self) + def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]: + type_hints = _get_provider_type_hints(self, context=context) if "return" in type_hints: del type_hints["return"] return type_hints diff --git a/aioinject/validation/_builtin.py b/aioinject/validation/_builtin.py index 8ec956f..26ae55c 100644 --- a/aioinject/validation/_builtin.py +++ b/aioinject/validation/_builtin.py @@ -15,11 +15,12 @@ def all_dependencies_are_present( ) -> Sequence[ContainerValidationError]: errors = [] for provider in container.providers.values(): - for dependency in provider.dependencies: - if dependency.type_ not in container.providers: + for dependency in provider.resolve_dependencies(container.type_context): + dep_type = dependency.resolve_type(container.type_context) + if dep_type not in container.providers: error = DependencyNotFoundError( - message=f"Provider for type {dependency.type_} not found", - dependency=dependency.type_, + message=f"Provider for type {dep_type} not found", + dependency=dep_type, ) errors.append(error) @@ -44,9 +45,10 @@ def __call__( if not self.dependant(provider): continue - for dependency in provider.dependencies: + for dependency in provider.resolve_dependencies(container.type_context): + dep_type = dependency.resolve_type(container.type_context) dependency_provider = container.get_provider( - type_=dependency.type_, + type_=dep_type, ) if self.dependency(dependency_provider): msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}" diff --git a/tests/container/test_container.py b/tests/container/test_container.py index 7249319..da123c1 100644 --- a/tests/container/test_container.py +++ b/tests/container/test_container.py @@ -94,16 +94,16 @@ async def dependency() -> AsyncIterator[int]: assert shutdown is True @pytest.mark.anyio -async def test_deffered_types() -> None: +async def test_deffered_type() -> None: if TYPE_CHECKING: from decimal import Decimal - def some_deffered_type() -> 'Decimal': + def some_deffered_type() -> "Decimal": from decimal import Decimal - return Decimal('1.0') + return Decimal("1.0") class DoubledDecimal: - def __init__(self, decimal: 'Decimal') -> None: + def __init__(self, decimal: "Decimal") -> None: self.decimal = decimal * 2 container = Container() def register_decimal_scoped() -> None: diff --git a/tests/providers/test_object.py b/tests/providers/test_object.py index 925d258..567c51b 100644 --- a/tests/providers/test_object.py +++ b/tests/providers/test_object.py @@ -37,7 +37,7 @@ def test_should_have_no_dependencies( ) -> None: for obj in dependencies_test_data: provider = Object(object_=obj) - assert not provider.dependencies + assert not provider.resolve_dependencies() def test_should_have_empty_type_hints( diff --git a/tests/providers/test_scoped.py b/tests/providers/test_scoped.py index 2f290f8..f2cae08 100644 --- a/tests/providers/test_scoped.py +++ b/tests/providers/test_scoped.py @@ -138,7 +138,7 @@ def factory( type_=str, ), ) - assert provider.dependencies == expected + assert provider.resolve_dependencies() == expected def iterable() -> Iterator[int]: yield 42 From 66fdceea7b0e837c6532690fc2351cc47b4ded3a Mon Sep 17 00:00:00 2001 From: Nir Date: Fri, 26 Jan 2024 01:10:12 +0200 Subject: [PATCH 3/9] format code --- aioinject/containers.py | 21 +++++++++------------ aioinject/context.py | 4 +++- aioinject/providers.py | 29 +++++++++++++++++++---------- aioinject/utils.py | 5 +++-- aioinject/validation/_builtin.py | 8 ++++++-- tests/container/test_container.py | 13 ++++++++++--- tests/providers/test_scoped.py | 6 +++++- 7 files changed, 55 insertions(+), 31 deletions(-) diff --git a/aioinject/containers.py b/aioinject/containers.py index ffe4653..2a3dc37 100644 --- a/aioinject/containers.py +++ b/aioinject/containers.py @@ -27,29 +27,26 @@ def _resolve_unresolved_provider(self) -> None: self._register_impl(provider) self._unresolved_providers.remove(provider) - def _register_impl(self, provider: Provider[Any]) -> None: - provider_type = provider.resolve_type(self.type_context) - if provider_type in self.providers: - provider.dependencies - msg = f"Provider for type {provider_type} is already registered" - raise ValueError(msg) - self.providers[provider_type] = provider - if klass_name := getattr(provider_type, "__name__" , None): - self.type_context[klass_name] = provider_type - + provider_type = provider.resolve_type(self.type_context) + if provider_type in self.providers: + provider.dependencies + msg = f"Provider for type {provider_type} is already registered" + raise ValueError(msg) + self.providers[provider_type] = provider + if klass_name := getattr(provider_type, "__name__", None): + self.type_context[klass_name] = provider_type def register( self, provider: Provider[Any], - ) -> None: + ) -> None: try: self._register_impl(provider) except NameError: self._unresolved_providers.append(provider) self._resolve_unresolved_provider() - def get_provider(self, type_: type[_T]) -> Provider[_T]: try: return self.providers[type_] diff --git a/aioinject/context.py b/aioinject/context.py index f933f4b..9dd95e3 100644 --- a/aioinject/context.py +++ b/aioinject/context.py @@ -60,7 +60,9 @@ async def resolve( return cached dependencies = {} - for dependency in provider.resolve_dependencies(self._container.type_context): + for dependency in provider.resolve_dependencies( + self._container.type_context, + ): dependencies[dependency.name] = await self.resolve( type_=dependency.type_, ) diff --git a/aioinject/providers.py b/aioinject/providers.py index c7dda77..b8caf12 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -32,7 +32,6 @@ _T = TypeVar("_T") - @dataclass(slots=True, kw_only=True, frozen=True) class Dependency(Generic[_T]): name: str @@ -63,7 +62,8 @@ def _find_inject_marker_in_annotation_args( def collect_dependencies( - dependant: typing.Callable[..., object] | dict[str, Any], ctx: dict[str, type[Any]] | None = None + dependant: typing.Callable[..., object] | dict[str, Any], + ctx: dict[str, type[Any]] | None = None, ) -> typing.Iterable[Dependency[object]]: if not isinstance(dependant, dict): with remove_annotation(dependant.__annotations__, "return"): @@ -82,7 +82,10 @@ def collect_dependencies( ) -def _get_provider_type_hints(provider: Provider[Any], context: dict[str, Any] | None = None) -> dict[str, Any]: +def _get_provider_type_hints( + provider: Provider[Any], + context: dict[str, Any] | None = None, +) -> dict[str, Any]: source = provider.impl if inspect.isclass(source): source = source.__init__ @@ -121,7 +124,10 @@ def _get_provider_type_hints(provider: Provider[Any], context: dict[str, Any] | ) -def _guess_return_type(factory: _FactoryType[_T], context: dict[str, type[Any] ] | None = None) -> type[_T]: +def _guess_return_type( + factory: _FactoryType[_T], + context: dict[str, type[Any]] | None = None, +) -> type[_T]: if isclass(factory): return typing.cast(type[_T], factory) @@ -169,12 +175,10 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T: def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: ... - + def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: ... - - def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]: ... @@ -182,13 +186,18 @@ def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]: def is_async(self) -> bool: ... - def resolve_dependencies(self, context: dict[str, Any] | None = None) -> tuple[Dependency[object], ...]: + def resolve_dependencies( + self, + context: dict[str, Any] | None = None, + ) -> tuple[Dependency[object], ...]: if deps := getattr(self, "_cached_dependecies", None): return deps - ret = tuple(collect_dependencies(self.type_hints(context), ctx=context)) + ret = tuple( + collect_dependencies(self.type_hints(context), ctx=context), + ) self._cached_dependecies = ret return ret - + @functools.cached_property def is_generator(self) -> bool: return is_context_manager_function(self.impl) diff --git a/aioinject/utils.py b/aioinject/utils.py index dae1744..3c6c336 100644 --- a/aioinject/utils.py +++ b/aioinject/utils.py @@ -95,10 +95,11 @@ def remove_annotation( if annotation is not sentinel: annotations[name] = annotation + def _get_type_hints( obj: Any, - context: dict[str, type[Any]] | None = None + context: dict[str, type[Any]] | None = None, ) -> dict[str, Any]: if not context: context = {} - return typing.get_type_hints(obj, include_extras=True, localns=context) \ No newline at end of file + return typing.get_type_hints(obj, include_extras=True, localns=context) diff --git a/aioinject/validation/_builtin.py b/aioinject/validation/_builtin.py index 26ae55c..c11d6b3 100644 --- a/aioinject/validation/_builtin.py +++ b/aioinject/validation/_builtin.py @@ -15,7 +15,9 @@ def all_dependencies_are_present( ) -> Sequence[ContainerValidationError]: errors = [] for provider in container.providers.values(): - for dependency in provider.resolve_dependencies(container.type_context): + for dependency in provider.resolve_dependencies( + container.type_context, + ): dep_type = dependency.resolve_type(container.type_context) if dep_type not in container.providers: error = DependencyNotFoundError( @@ -45,7 +47,9 @@ def __call__( if not self.dependant(provider): continue - for dependency in provider.resolve_dependencies(container.type_context): + for dependency in provider.resolve_dependencies( + container.type_context, + ): dep_type = dependency.resolve_type(container.type_context) dependency_provider = container.get_provider( type_=dep_type, diff --git a/tests/container/test_container.py b/tests/container/test_container.py index da123c1..6d6010b 100644 --- a/tests/container/test_container.py +++ b/tests/container/test_container.py @@ -93,24 +93,31 @@ async def dependency() -> AsyncIterator[int]: await container.aclose() assert shutdown is True + @pytest.mark.anyio async def test_deffered_type() -> None: if TYPE_CHECKING: from decimal import Decimal - + def some_deffered_type() -> "Decimal": from decimal import Decimal + return Decimal("1.0") - + class DoubledDecimal: def __init__(self, decimal: "Decimal") -> None: self.decimal = decimal * 2 + container = Container() + def register_decimal_scoped() -> None: from decimal import Decimal + container.register(Scoped(some_deffered_type, Decimal)) register_decimal_scoped() container.register(Scoped(DoubledDecimal)) async with container.context() as ctx: - assert (await ctx.resolve(DoubledDecimal)).decimal == DoubledDecimal(some_deffered_type()).decimal \ No newline at end of file + assert (await ctx.resolve(DoubledDecimal)).decimal == DoubledDecimal( + some_deffered_type(), + ).decimal diff --git a/tests/providers/test_scoped.py b/tests/providers/test_scoped.py index f2cae08..f1218ce 100644 --- a/tests/providers/test_scoped.py +++ b/tests/providers/test_scoped.py @@ -140,20 +140,24 @@ def factory( ) assert provider.resolve_dependencies() == expected + def iterable() -> Iterator[int]: yield 42 + def gen() -> Generator[int, None, None]: yield 42 + async def async_iterable() -> AsyncIterator[int]: yield 42 + async def async_gen() -> AsyncGenerator[int, None]: yield 42 + @pytest.mark.parametrize("factory", [iterable, gen, async_iterable, async_gen]) def test_generator_return_types(factory) -> None: provider = providers.Scoped(factory) assert provider.resolve_type() is int - From ac01b7e732d0fa316c683d573b289ec93bb76760 Mon Sep 17 00:00:00 2001 From: Nir Date: Fri, 26 Jan 2024 01:23:52 +0200 Subject: [PATCH 4/9] fix lint errors --- aioinject/_store.py | 24 +++++++++++--------- aioinject/containers.py | 9 ++++---- aioinject/context.py | 2 +- aioinject/providers.py | 38 ++++++++++++++++++++++++++------ aioinject/validation/_builtin.py | 4 ++-- tests/providers/test_object.py | 2 +- tests/providers/test_scoped.py | 12 +++++----- 7 files changed, 59 insertions(+), 32 deletions(-) diff --git a/aioinject/_store.py b/aioinject/_store.py index 22e7552..d90a3fc 100644 --- a/aioinject/_store.py +++ b/aioinject/_store.py @@ -34,23 +34,27 @@ def __init__(self) -> None: self._sync_exit_stack = contextlib.ExitStack() def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]: - return self._cache.get(provider.type_, NotInCache.sentinel) + return self._cache.get(provider.resolve_type(), NotInCache.sentinel) def add(self, provider: Provider[T], obj: T) -> None: if provider.lifetime is not DependencyLifetime.transient: - self._cache[provider.type_] = obj + self._cache[provider.resolve_type()] = obj def lock( self, provider: Provider[Any], ) -> AbstractAsyncContextManager[bool]: - return contextlib.nullcontext(provider.type_ not in self._cache) + return contextlib.nullcontext( + provider.resolve_type() not in self._cache, + ) def sync_lock( self, provider: Provider[Any], ) -> AbstractContextManager[bool]: - return contextlib.nullcontext(provider.type_ not in self._cache) + return contextlib.nullcontext( + provider.resolve_type() not in self._cache, + ) @typing.overload async def enter_context( @@ -124,9 +128,9 @@ def __init__(self) -> None: @contextlib.asynccontextmanager async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]: - if provider.type_ not in self._cache: - async with self._locks[provider.type_]: - yield provider.type_ not in self._cache + if provider.resolve_type() not in self._cache: + async with self._locks[provider.resolve_type()]: + yield provider.resolve_type() not in self._cache return yield False @@ -135,8 +139,8 @@ def sync_lock( self, provider: Provider[Any], ) -> Iterator[bool]: - if provider.type_ not in self._cache: - with self._sync_locks[provider.type_]: - yield provider.type_ not in self._cache + if provider.resolve_type() not in self._cache: + with self._sync_locks[provider.resolve_type()]: + yield provider.resolve_type() not in self._cache return yield False diff --git a/aioinject/containers.py b/aioinject/containers.py index 2a3dc37..67b6b80 100644 --- a/aioinject/containers.py +++ b/aioinject/containers.py @@ -30,7 +30,6 @@ def _resolve_unresolved_provider(self) -> None: def _register_impl(self, provider: Provider[Any]) -> None: provider_type = provider.resolve_type(self.type_context) if provider_type in self.providers: - provider.dependencies msg = f"Provider for type {provider_type} is already registered" raise ValueError(msg) self.providers[provider_type] = provider @@ -71,14 +70,14 @@ def override( self, provider: Provider[Any], ) -> Iterator[None]: - previous = self.providers.get(provider.type_) - self.providers[provider.type_] = provider + previous = self.providers.get(provider.resolve_type()) + self.providers[provider.resolve_type()] = provider yield - del self.providers[provider.type_] + del self.providers[provider.resolve_type()] if previous is not None: - self.providers[provider.type_] = previous + self.providers[provider.resolve_type()] = previous async def __aenter__(self) -> Self: return self diff --git a/aioinject/context.py b/aioinject/context.py index 9dd95e3..477c012 100644 --- a/aioinject/context.py +++ b/aioinject/context.py @@ -153,7 +153,7 @@ def resolve( return cached dependencies = {} - for dependency in provider.dependencies: + for dependency in provider.resolve_dependencies(): dependencies[dependency.name] = self.resolve( type_=dependency.type_, ) diff --git a/aioinject/providers.py b/aioinject/providers.py index b8caf12..53c680b 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -169,6 +169,7 @@ class Provider(Protocol[_T]): impl: Any lifetime: DependencyLifetime _cached_dependecies: tuple[Dependency[object], ...] + _cached_type: type[_T] | None async def provide(self, kwargs: Mapping[str, Any]) -> _T: ... @@ -176,9 +177,19 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T: def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: ... - def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: + def _resolve_type_impl( + self, + context: dict[str, Any] | None = None, + ) -> type[_T]: ... + def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: + if resolved := getattr(self, "_cached_type", None): + return resolved + ret = self._resolve_type_impl(context) + self._cached_type = ret + return ret + def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]: ... @@ -202,8 +213,12 @@ def resolve_dependencies( def is_generator(self) -> bool: return is_context_manager_function(self.impl) - def __repr__(self) -> str: - return f"{self.__class__.__qualname__}(type={self.type_}, implementation={self.impl})" + def __repr__(self) -> str: # pragma: no cover + try: + type_ = repr(self.resolve_type()) + except NameError: + type_ = "UNKNOWN" + return f"{self.__class__.__qualname__}(type={type_}, implementation={self.impl})" class Scoped(Provider[_T]): @@ -217,7 +232,10 @@ def __init__( self.type_ = type_ self.impl = factory - def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: + def _resolve_type_impl( + self, + context: dict[str, Any] | None = None, + ) -> type[_T]: return self.type_ or _guess_return_type(self.impl, context=context) def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: @@ -229,7 +247,10 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T: return self.provide_sync(kwargs) - def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]: + def type_hints( + self, + context: dict[str, Any] | None = None, + ) -> dict[str, Any]: type_hints = _get_provider_type_hints(self, context=context) if "return" in type_hints: del type_hints["return"] @@ -263,7 +284,7 @@ class Transient(Scoped[_T]): class Object(Provider[_T]): - type_hints: ClassVar[dict[str, Any]] = {} + _type_hints: ClassVar[dict[str, Any]] = {} is_async = False impl: _T lifetime = DependencyLifetime.scoped # It's ok to cache it @@ -276,7 +297,7 @@ def __init__( self.type_ = type_ or type(object_) self.impl = object_ - def resolve_type(self, _: dict[str, Any] | None = None) -> type[_T]: + def _resolve_type_impl(self, _: dict[str, Any] | None = None) -> type[_T]: return self.type_ def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002 @@ -284,3 +305,6 @@ def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002 async def provide(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002 return self.impl + + def type_hints(self, _: dict[str, Any] | None = None) -> dict[str, Any]: + return self._type_hints diff --git a/aioinject/validation/_builtin.py b/aioinject/validation/_builtin.py index c11d6b3..6f21c1d 100644 --- a/aioinject/validation/_builtin.py +++ b/aioinject/validation/_builtin.py @@ -18,7 +18,7 @@ def all_dependencies_are_present( for dependency in provider.resolve_dependencies( container.type_context, ): - dep_type = dependency.resolve_type(container.type_context) + dep_type = dependency.type_ if dep_type not in container.providers: error = DependencyNotFoundError( message=f"Provider for type {dep_type} not found", @@ -50,7 +50,7 @@ def __call__( for dependency in provider.resolve_dependencies( container.type_context, ): - dep_type = dependency.resolve_type(container.type_context) + dep_type = dependency.type_ dependency_provider = container.get_provider( type_=dep_type, ) diff --git a/tests/providers/test_object.py b/tests/providers/test_object.py index 567c51b..46637d9 100644 --- a/tests/providers/test_object.py +++ b/tests/providers/test_object.py @@ -45,7 +45,7 @@ def test_should_have_empty_type_hints( ) -> None: for obj in dependencies_test_data: provider = Object(object_=obj) - assert not provider.type_hints + assert not provider.type_hints() def test_should_not_be_async( diff --git a/tests/providers/test_scoped.py b/tests/providers/test_scoped.py index f1218ce..44bfa21 100644 --- a/tests/providers/test_scoped.py +++ b/tests/providers/test_scoped.py @@ -1,5 +1,5 @@ from collections.abc import AsyncGenerator, AsyncIterator, Generator, Iterator -from typing import Annotated +from typing import Annotated, Any from unittest.mock import patch import pytest @@ -60,7 +60,7 @@ def factory(a: int, b: str) -> None: # noqa: ARG001 "a": Annotated[int, Inject], "b": Annotated[str, Inject], } - assert provider.type_hints == expected + assert provider.type_hints() == expected def test_type_hints_on_class() -> None: @@ -73,7 +73,7 @@ def __init__(self, a: int, b: str) -> None: "a": Annotated[int, Inject], "b": Annotated[str, Inject], } - assert provider.type_hints == expected + assert provider.type_hints() == expected def test_annotated_type_hint() -> None: @@ -83,7 +83,7 @@ def factory( pass provider = providers.Scoped(factory) - assert provider.type_hints == { + assert provider.type_hints() == { "a": Annotated[int, Inject()], } @@ -109,7 +109,7 @@ def factory() -> None: pass provider = providers.Scoped(factory) - assert provider.dependencies == () + assert provider.resolve_dependencies() == () def test_dependencies() -> None: @@ -158,6 +158,6 @@ async def async_gen() -> AsyncGenerator[int, None]: @pytest.mark.parametrize("factory", [iterable, gen, async_iterable, async_gen]) -def test_generator_return_types(factory) -> None: +def test_generator_return_types(factory: Any) -> None: provider = providers.Scoped(factory) assert provider.resolve_type() is int From 27d1257d898edc02912235a59b7437abd9486f8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D7=A0=D7=99=D7=A8?= <88795475+nrbnlulu@users.noreply.github.com> Date: Fri, 26 Jan 2024 15:27:17 +0200 Subject: [PATCH 5/9] Update aioinject/providers.py Co-authored-by: Doctor <50728601+ThirVondukr@users.noreply.github.com> --- aioinject/providers.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/aioinject/providers.py b/aioinject/providers.py index 53c680b..f85229c 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -184,11 +184,9 @@ def _resolve_type_impl( ... def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: - if resolved := getattr(self, "_cached_type", None): - return resolved - ret = self._resolve_type_impl(context) - self._cached_type = ret - return ret + if self._cached_type is None: + self._cached_type = self._resolve_type_impl(context) + return self._cached_type def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]: ... From ff72bf1475d202718d896556e1d0d35836de4962 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D7=A0=D7=99=D7=A8?= <88795475+nrbnlulu@users.noreply.github.com> Date: Fri, 26 Jan 2024 15:28:01 +0200 Subject: [PATCH 6/9] Update aioinject/providers.py Co-authored-by: Doctor <50728601+ThirVondukr@users.noreply.github.com> --- aioinject/providers.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/aioinject/providers.py b/aioinject/providers.py index f85229c..da46e10 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -199,13 +199,11 @@ def resolve_dependencies( self, context: dict[str, Any] | None = None, ) -> tuple[Dependency[object], ...]: - if deps := getattr(self, "_cached_dependecies", None): - return deps - ret = tuple( - collect_dependencies(self.type_hints(context), ctx=context), - ) - self._cached_dependecies = ret - return ret + if self._cached_dependencies is None: + self._cached_dependencies = tuple( + collect_dependencies(self.type_hints(context), ctx=context), + ) + return self._cached_dependencies @functools.cached_property def is_generator(self) -> bool: From 7c375cd2dc5840db6d654ee332b4899c7ed56eb4 Mon Sep 17 00:00:00 2001 From: Nir Date: Fri, 26 Jan 2024 15:42:10 +0200 Subject: [PATCH 7/9] use try except instead of getttr --- aioinject/providers.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/aioinject/providers.py b/aioinject/providers.py index da46e10..8491fde 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -168,8 +168,8 @@ class DependencyLifetime(enum.Enum): class Provider(Protocol[_T]): impl: Any lifetime: DependencyLifetime - _cached_dependecies: tuple[Dependency[object], ...] - _cached_type: type[_T] | None + _cached_dependencies: tuple[Dependency[object], ...] + _cached_type: type[_T] # TODO: I think it is redundant. async def provide(self, kwargs: Mapping[str, Any]) -> _T: ... @@ -184,9 +184,11 @@ def _resolve_type_impl( ... def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: - if self._cached_type is None: + try: + return self._cached_type + except AttributeError: self._cached_type = self._resolve_type_impl(context) - return self._cached_type + return self._cached_type def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]: ... @@ -199,11 +201,13 @@ def resolve_dependencies( self, context: dict[str, Any] | None = None, ) -> tuple[Dependency[object], ...]: - if self._cached_dependencies is None: - self._cached_dependencies = tuple( + try: + return self._cached_dependencies + except AttributeError: + self._cached_dependencies = tuple( collect_dependencies(self.type_hints(context), ctx=context), ) - return self._cached_dependencies + return self._cached_dependencies @functools.cached_property def is_generator(self) -> bool: From a4419d97fb87ee592163f954859efff98a65fe84 Mon Sep 17 00:00:00 2001 From: Nir Date: Fri, 26 Jan 2024 15:46:47 +0200 Subject: [PATCH 8/9] Update test_deffered_type function name to test_deffered_dependecies --- tests/container/test_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/container/test_container.py b/tests/container/test_container.py index 6d6010b..6f7258f 100644 --- a/tests/container/test_container.py +++ b/tests/container/test_container.py @@ -95,7 +95,7 @@ async def dependency() -> AsyncIterator[int]: @pytest.mark.anyio -async def test_deffered_type() -> None: +async def test_deffered_dependecies() -> None: if TYPE_CHECKING: from decimal import Decimal From 430c0f659212ae6800a4c7727e9bfe79b44b4513 Mon Sep 17 00:00:00 2001 From: Nir Date: Fri, 26 Jan 2024 15:50:01 +0200 Subject: [PATCH 9/9] Update workflow to trigger on pull requests as well --- .github/workflows/test.yaml | 2 +- aioinject/providers.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 8a3372e..57178e5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,5 +1,5 @@ name: Test -on: [push] +on: [push, pull_request] jobs: test: runs-on: ubuntu-latest diff --git a/aioinject/providers.py b/aioinject/providers.py index 8491fde..99cc497 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -169,7 +169,7 @@ class Provider(Protocol[_T]): impl: Any lifetime: DependencyLifetime _cached_dependencies: tuple[Dependency[object], ...] - _cached_type: type[_T] # TODO: I think it is redundant. + _cached_type: type[_T] # TODO: I think it is redundant. async def provide(self, kwargs: Mapping[str, Any]) -> _T: ... @@ -201,10 +201,10 @@ def resolve_dependencies( self, context: dict[str, Any] | None = None, ) -> tuple[Dependency[object], ...]: - try: + try: return self._cached_dependencies except AttributeError: - self._cached_dependencies = tuple( + self._cached_dependencies = tuple( collect_dependencies(self.type_hints(context), ctx=context), ) return self._cached_dependencies