From 14efcbd4e6b353b5a6e151bb6a2906498451508a Mon Sep 17 00:00:00 2001 From: Doctor Date: Sat, 13 Jan 2024 20:34:06 +0300 Subject: [PATCH] Refactor: remove all parameters from Inject marker --- aioinject/containers.py | 60 +++++++++++-------------------- aioinject/context.py | 49 +++++++++---------------- aioinject/markers.py | 8 ++--- aioinject/providers.py | 4 --- aioinject/validation/_builtin.py | 41 ++++++++++----------- tests/container/test_container.py | 50 ++++++-------------------- tests/container/test_override.py | 18 ++++++++-- tests/providers/test_callable.py | 15 +++----- tests/test_inject.py | 45 +++-------------------- tests/utils/test_utils.py | 4 +-- 10 files changed, 95 insertions(+), 199 deletions(-) diff --git a/aioinject/containers.py b/aioinject/containers.py index 0f1e167..f2178b3 100644 --- a/aioinject/containers.py +++ b/aioinject/containers.py @@ -1,5 +1,4 @@ import contextlib -from collections import defaultdict from collections.abc import Iterator from typing import Any, TypeAlias, TypeVar @@ -8,52 +7,29 @@ _T = TypeVar("_T") -_Providers: TypeAlias = dict[type[_T], list[Provider[_T]]] +_Providers: TypeAlias = dict[type[_T], Provider[_T]] class Container: def __init__(self) -> None: - self.providers: _Providers[Any] = defaultdict(list) - self._overrides: _Providers[Any] = defaultdict(list) + self.providers: _Providers[Any] = {} + self._overrides: _Providers[Any] = {} def register( self, provider: Provider[Any], ) -> None: - self.providers[provider.type_].append(provider) + if provider.type_ in self.providers: + msg = f"Provider for type {provider.type_} is already registered" + raise ValueError(msg) - @staticmethod - def _get_provider( - providers: _Providers[_T], - type_: type[_T], - impl: Any | None = None, - ) -> Provider[_T] | None: - type_providers = providers[type_] - if not type_providers: - return None + self.providers[provider.type_] = provider - if impl is None and len(type_providers) == 1: - return type_providers[0] - - if impl is None: - err_msg = ( - f"Multiple providers for type {type_.__qualname__} were found, " - f"you have to specify implementation using 'impl' parameter: " - f"Annotated[IService, Inject(impl=Service)]" - ) - raise ValueError(err_msg) - return next((p for p in type_providers if p.impl is impl), None) - - def get_provider( - self, - type_: type[_T], - impl: Any | None = None, - ) -> Provider[_T]: - overridden_provider = self._get_provider(self._overrides, type_, impl) - if overridden_provider is not None: + def get_provider(self, type_: type[_T]) -> Provider[_T]: + if (overridden_provider := self._overrides.get(type_)) is not None: return overridden_provider - provider = self._get_provider(self.providers, type_, impl) + provider = self.providers.get(type_) if provider is None: err_msg = f"Provider for type {type_.__qualname__} not found" raise ValueError(err_msg) @@ -70,12 +46,16 @@ def override( self, provider: Provider[Any], ) -> Iterator[None]: - self._overrides[provider.type_].append(provider) + previous = self._overrides.get(provider.type_) + self._overrides[provider.type_] = provider + yield - self._overrides[provider.type_].remove(provider) + + del self._overrides[provider.type_] + if previous is not None: + self._overrides[provider.type_] = previous async def aclose(self) -> None: - for providers in self.providers.values(): - for provider in providers: - if isinstance(provider, Singleton): - await provider.aclose() + for provider in self.providers.values(): + if isinstance(provider, Singleton): + await provider.aclose() diff --git a/aioinject/context.py b/aioinject/context.py index a663685..a7a7e5c 100644 --- a/aioinject/context.py +++ b/aioinject/context.py @@ -29,7 +29,6 @@ _T = TypeVar("_T") _AnyCtx: TypeAlias = Union["InjectionContext", "SyncInjectionContext"] -_TypeAndImpl: TypeAlias = tuple[type[_T], _T | None] _ExitStackT = TypeVar("_ExitStackT") context_var: ContextVar[_AnyCtx] = ContextVar("aioinject_context") @@ -43,7 +42,7 @@ class _BaseInjectionContext(Generic[_ExitStackT]): def __init__(self, container: Container) -> None: self._container = container self._exit_stack = self._exit_stack_type() - self._cache: dict[_TypeAndImpl[Any], Any] = {} + self._cache: dict[type[Any], Any] = {} self._token = None def __class_getitem__( @@ -61,20 +60,15 @@ class InjectionContext(_BaseInjectionContext[AsyncExitStack]): async def resolve( self, type_: type[_T], - impl: Any | None = None, - *, - use_cache: bool = True, ) -> _T: - if use_cache and (type_, impl) in self._cache: - return self._cache[type_, impl] + provider = self._container.get_provider(type_) + use_cache = True + + if use_cache and type_ in self._cache: + return self._cache[type_] - provider = self._container.get_provider(type_, impl) dependencies = { - dep.name: await self.resolve( - type_=dep.type_, - impl=dep.implementation, - use_cache=dep.use_cache, - ) + dep.name: await self.resolve(type_=dep.type_) for dep in provider.dependencies } @@ -83,7 +77,7 @@ async def resolve( stack=self._exit_stack, ) if use_cache: - self._cache[type_, impl] = resolved + self._cache[type_] = resolved return resolved @overload @@ -117,10 +111,9 @@ async def execute( for dependency in dependencies: if dependency.name in kwargs: continue + resolved[dependency.name] = await self.resolve( type_=dependency.type_, - impl=dependency.implementation, - use_cache=dependency.use_cache, ) if inspect.iscoroutinefunction(function): return await function(*args, **kwargs, **resolved) @@ -144,20 +137,14 @@ class SyncInjectionContext(_BaseInjectionContext[ExitStack]): def resolve( self, type_: type[_T], - impl: Any | None = None, - *, - use_cache: bool = True, ) -> _T: - if use_cache and (type_, impl) in self._cache: - return self._cache[type_, impl] + use_cache = True + if use_cache and type_ in self._cache: + return self._cache[type_] - provider = self._container.get_provider(type_, impl) + provider = self._container.get_provider(type_) dependencies = { - dep.name: self.resolve( - type_=dep.type_, - impl=dep.implementation, - use_cache=dep.use_cache, - ) + dep.name: self.resolve(type_=dep.type_) for dep in provider.dependencies } @@ -165,7 +152,7 @@ def resolve( if isinstance(resolved, contextlib.ContextDecorator): resolved = self._exit_stack.enter_context(resolved) # type: ignore[arg-type] if use_cache: - self._cache[type_, impl] = resolved + self._cache[type_] = resolved return resolved def execute( @@ -179,11 +166,7 @@ def execute( for dependency in dependencies: if dependency.name in kwargs: continue - resolved[dependency.name] = self.resolve( - type_=dependency.type_, - impl=dependency.implementation, - use_cache=dependency.use_cache, - ) + resolved[dependency.name] = self.resolve(type_=dependency.type_) return function(*args, **kwargs, **resolved) def __enter__(self) -> Self: diff --git a/aioinject/markers.py b/aioinject/markers.py index 6eba15a..4589f67 100644 --- a/aioinject/markers.py +++ b/aioinject/markers.py @@ -1,8 +1,6 @@ -from dataclasses import dataclass -from typing import Any +import dataclasses -@dataclass(slots=True) +@dataclasses.dataclass(slots=True) class Inject: - impl: Any | None = None - cache: bool = True + pass diff --git a/aioinject/providers.py b/aioinject/providers.py index 08825b9..35d4b00 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -37,8 +37,6 @@ class Dependency(Generic[_T]): name: str type_: type[_T] - implementation: Any - use_cache: bool def _get_annotation_args(type_hint: Any) -> tuple[type, tuple[Any, ...]]: @@ -82,8 +80,6 @@ def collect_dependencies( yield Dependency( name=name, type_=dep_type, - implementation=inject_marker.impl, - use_cache=inject_marker.cache, ) diff --git a/aioinject/validation/_builtin.py b/aioinject/validation/_builtin.py index 9c9c5b3..8ec956f 100644 --- a/aioinject/validation/_builtin.py +++ b/aioinject/validation/_builtin.py @@ -14,15 +14,14 @@ def all_dependencies_are_present( container: aioinject.Container, ) -> Sequence[ContainerValidationError]: errors = [] - for providers in container.providers.values(): - for provider in providers: - for dependency in provider.dependencies: - if dependency.type_ not in container.providers: - error = DependencyNotFoundError( - message=f"Provider for type {dependency.type_} not found", - dependency=dependency.type_, - ) - errors.append(error) + for provider in container.providers.values(): + for dependency in provider.dependencies: + if dependency.type_ not in container.providers: + error = DependencyNotFoundError( + message=f"Provider for type {dependency.type_} not found", + dependency=dependency.type_, + ) + errors.append(error) return errors @@ -41,17 +40,15 @@ def __call__( container: aioinject.Container, ) -> Sequence[ContainerValidationError]: errors = [] - for providers in container.providers.values(): - for provider in providers: - if not self.dependant(provider): - continue - - for dependency in provider.dependencies: - dependency_provider = container.get_provider( - type_=dependency.type_, - impl=dependency.implementation, - ) - if self.dependency(dependency_provider): - msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}" - errors.append(ContainerValidationError(msg)) + for provider in container.providers.values(): + if not self.dependant(provider): + continue + + for dependency in provider.dependencies: + dependency_provider = container.get_provider( + type_=dependency.type_, + ) + if self.dependency(dependency_provider): + msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}" + errors.append(ContainerValidationError(msg)) return errors diff --git a/tests/container/test_container.py b/tests/container/test_container.py index 2b6971c..9db70a1 100644 --- a/tests/container/test_container.py +++ b/tests/container/test_container.py @@ -3,7 +3,7 @@ import pytest -from aioinject import Singleton, providers +from aioinject import Scoped, Singleton, providers from aioinject.containers import Container from aioinject.context import InjectionContext @@ -38,18 +38,20 @@ def test_can_register_single(container: Container) -> None: provider = providers.Scoped(_ServiceA) container.register(provider) - expected = {_ServiceA: [provider]} + expected = {_ServiceA: provider} assert container.providers == expected -def test_can_register_multi(container: Container) -> None: - provider_a = providers.Scoped(_ServiceA) - provider_b = providers.Scoped(_ServiceB) - container.register(provider_a) - container.register(provider_b) +def test_cant_register_multiple_providers_for_same_type( + container: Container, +) -> None: + container.register(Scoped(int)) - expected = {_ServiceA: [provider_a], _ServiceB: [provider_b]} - assert container.providers == expected + with pytest.raises( + ValueError, + match="^Provider for type is already registered$", + ): + container.register(Scoped(int)) def test_can_retrieve_single_provider(container: Container) -> None: @@ -58,36 +60,6 @@ def test_can_retrieve_single_provider(container: Container) -> None: assert container.get_provider(int) -@pytest.fixture -def multi_provider_container(container: Container) -> Container: - a_provider = providers.Scoped(_ServiceA, type_=_AbstractService) - b_provider = providers.Scoped(_ServiceB, type_=_AbstractService) - container.register(a_provider) - container.register(b_provider) - return container - - -def test_get_provider_raises_error_if_multiple_providers( - multi_provider_container: Container, -) -> None: - with pytest.raises(ValueError) as exc_info: # noqa: PT011 - assert multi_provider_container.get_provider(_AbstractService) - - msg = ( - f"Multiple providers for type {_AbstractService.__qualname__} were found, " - f"you have to specify implementation using 'impl' parameter: " - f"Annotated[IService, Inject(impl=Service)]" - ) - assert str(exc_info.value) == msg - - -def test_can_get_multi_provider_if__specified( - multi_provider_container: Container, -) -> None: - assert multi_provider_container.get_provider(_AbstractService, _ServiceA) - assert multi_provider_container.get_provider(_AbstractService, _ServiceB) - - def test_missing_provider() -> None: container = Container() with pytest.raises(ValueError) as exc_info: # noqa: PT011 diff --git a/tests/container/test_override.py b/tests/container/test_override.py index ec54d58..b7e1623 100644 --- a/tests/container/test_override.py +++ b/tests/container/test_override.py @@ -10,10 +10,22 @@ def test_provider_override() -> None: container = Container() container.register(Object(a)) with container.sync_context() as ctx: - assert ctx.resolve(_A, use_cache=False) is a + assert ctx.resolve(_A) is a a_overridden = _A() assert a_overridden is not a - with container.override(Object(a_overridden)): - assert ctx.resolve(_A, use_cache=False) is a_overridden is not a + with container.sync_context() as ctx, container.override( + Object(a_overridden), + ): + assert ctx.resolve(_A) is a_overridden is not a + + +def test_override_multiple_times() -> None: + container = Container() + with container.override(Object(1)): + with container.override(Object(2)), container.sync_context() as ctx: + assert ctx.resolve(int) == 2 # noqa: PLR2004 + + with container.sync_context() as ctx: + assert ctx.resolve(int) == 1 diff --git a/tests/providers/test_callable.py b/tests/providers/test_callable.py index 4527a60..318b7eb 100644 --- a/tests/providers/test_callable.py +++ b/tests/providers/test_callable.py @@ -1,4 +1,3 @@ -from collections import defaultdict from collections.abc import AsyncGenerator, AsyncIterator, Generator, Iterator from typing import Annotated from unittest.mock import patch @@ -85,12 +84,12 @@ def __init__(self, a: int, b: str) -> None: def test_annotated_type_hint() -> None: def factory( - a: Annotated[int, Inject(cache=False)], # noqa: ARG001 + a: Annotated[int, Inject()], # noqa: ARG001 ) -> None: pass provider = providers.Scoped(factory) - assert provider.type_hints == {"a": Annotated[int, Inject(cache=False)]} + assert provider.type_hints == {"a": Annotated[int, Inject()]} def test_is_async_on_sync() -> None: @@ -122,9 +121,9 @@ def factory( a: int, # noqa: ARG001 service: Annotated[ # noqa: ARG001 dict[str, int], - Inject(defaultdict), + Inject(), ], - string: Annotated[str, Inject(cache=False)], # noqa: ARG001 + string: Annotated[str, Inject()], # noqa: ARG001 ) -> None: pass @@ -133,20 +132,14 @@ def factory( Dependency( name="a", type_=int, - implementation=None, - use_cache=True, ), Dependency( name="service", type_=dict[str, int], - implementation=defaultdict, - use_cache=True, ), Dependency( name="string", type_=str, - implementation=None, - use_cache=False, ), ) assert provider.dependencies == expected diff --git a/tests/test_inject.py b/tests/test_inject.py index 7151689..7de3660 100644 --- a/tests/test_inject.py +++ b/tests/test_inject.py @@ -12,12 +12,12 @@ class _Session: class _Service: - def __init__(self, session: Annotated[_Session, Inject]) -> None: + def __init__(self, session: _Session) -> None: self.session = session class _Interface: - def __init__(self, session: Annotated[_Session, Inject]) -> None: + def __init__(self, session: _Session) -> None: self.session = session @@ -25,19 +25,11 @@ class _ImplementationA(_Interface): pass -class _ImplementationB(_Interface): - pass - - -def _get_impl_b(session: Annotated[_Session, Inject]) -> _ImplementationB: - return _ImplementationB(session) - - class _NeedsMultipleImplementations: def __init__( self, - a: Annotated[_Interface, Inject(_ImplementationA)], - b: Annotated[_Interface, Inject(_get_impl_b)], + a: _Interface, + b: _Interface, ) -> None: self.a = a self.b = b @@ -49,7 +41,6 @@ def container() -> Container: container.register(providers.Scoped(_Session)) container.register(providers.Scoped(_Service)) container.register(providers.Scoped(_ImplementationA, type_=_Interface)) - container.register(providers.Scoped(_get_impl_b, type_=_Interface)) container.register(providers.Scoped(_NeedsMultipleImplementations)) return container @@ -57,7 +48,7 @@ def container() -> Container: @inject def _injectee( test: Annotated[_Session, Inject], - test_no_cache: Annotated[_Session, Inject(cache=False)], + test_no_cache: Annotated[_Session, Inject()], ) -> tuple[_Session, _Session]: return test, test_no_cache @@ -82,17 +73,6 @@ def test_simple_inject(container: Container) -> None: assert isinstance(session, _Session) -def test_no_cache_marker(container: Container) -> None: - with container.sync_context(): - test_first, no_cache_first = _injectee() # type: ignore[call-arg] - test_second, no_cache_second = _injectee() # type: ignore[call-arg] - - assert test_first is test_second - assert test_first is not no_cache_first - assert test_second is not no_cache_second - assert no_cache_first is not no_cache_second - - @pytest.mark.anyio async def test_simple_service(container: Container) -> None: with container.sync_context() as ctx: @@ -119,21 +99,6 @@ async def test_retrieve_service_with_dependencies( assert isinstance(service.session, _Session) -@pytest.mark.anyio -async def test_service_with_multiple_dependencies_with_same_type( - container: Container, -) -> None: - with container.sync_context() as ctx: - service = ctx.resolve(_NeedsMultipleImplementations) - assert isinstance(service.a, _ImplementationA) - assert isinstance(service.b, _ImplementationB) - - async with container.context() as ctx: - service = await ctx.resolve(_NeedsMultipleImplementations) - assert isinstance(service.a, _ImplementationA) - assert isinstance(service.b, _ImplementationB) - - @pytest.mark.anyio async def test_inject_using_container( container: Container, diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index ba349d3..a1ead5e 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -8,11 +8,11 @@ def test_inject_annotations_returns_all_inject_markers() -> None: def func( a: int, # noqa: ARG001 b: Annotated[int, Inject], # noqa: ARG001 - c: Annotated[int, Inject(impl=bool)], # noqa: ARG001 + c: Annotated[int, Inject()], # noqa: ARG001 ) -> None: pass assert get_inject_annotations(func) == { "b": Annotated[int, Inject], - "c": Annotated[int, Inject(impl=bool)], + "c": Annotated[int, Inject()], }