Skip to content

Commit

Permalink
Refactor: remove all parameters from Inject marker
Browse files Browse the repository at this point in the history
  • Loading branch information
ThirVondukr committed Jan 13, 2024
1 parent 5fcd07d commit 14efcbd
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 199 deletions.
60 changes: 20 additions & 40 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import contextlib
from collections import defaultdict
from collections.abc import Iterator
from typing import Any, TypeAlias, TypeVar

Expand All @@ -8,52 +7,29 @@


_T = TypeVar("_T")
_Providers: TypeAlias = dict[type[_T], list[Provider[_T]]]
_Providers: TypeAlias = dict[type[_T], Provider[_T]]


class Container:
def __init__(self) -> None:
self.providers: _Providers[Any] = defaultdict(list)
self._overrides: _Providers[Any] = defaultdict(list)
self.providers: _Providers[Any] = {}
self._overrides: _Providers[Any] = {}

def register(
self,
provider: Provider[Any],
) -> None:
self.providers[provider.type_].append(provider)
if provider.type_ in self.providers:
msg = f"Provider for type {provider.type_} is already registered"
raise ValueError(msg)

@staticmethod
def _get_provider(
providers: _Providers[_T],
type_: type[_T],
impl: Any | None = None,
) -> Provider[_T] | None:
type_providers = providers[type_]
if not type_providers:
return None
self.providers[provider.type_] = provider

if impl is None and len(type_providers) == 1:
return type_providers[0]

if impl is None:
err_msg = (
f"Multiple providers for type {type_.__qualname__} were found, "
f"you have to specify implementation using 'impl' parameter: "
f"Annotated[IService, Inject(impl=Service)]"
)
raise ValueError(err_msg)
return next((p for p in type_providers if p.impl is impl), None)

def get_provider(
self,
type_: type[_T],
impl: Any | None = None,
) -> Provider[_T]:
overridden_provider = self._get_provider(self._overrides, type_, impl)
if overridden_provider is not None:
def get_provider(self, type_: type[_T]) -> Provider[_T]:
if (overridden_provider := self._overrides.get(type_)) is not None:
return overridden_provider

provider = self._get_provider(self.providers, type_, impl)
provider = self.providers.get(type_)
if provider is None:
err_msg = f"Provider for type {type_.__qualname__} not found"
raise ValueError(err_msg)
Expand All @@ -70,12 +46,16 @@ def override(
self,
provider: Provider[Any],
) -> Iterator[None]:
self._overrides[provider.type_].append(provider)
previous = self._overrides.get(provider.type_)
self._overrides[provider.type_] = provider

yield
self._overrides[provider.type_].remove(provider)

del self._overrides[provider.type_]
if previous is not None:
self._overrides[provider.type_] = previous

async def aclose(self) -> None:
for providers in self.providers.values():
for provider in providers:
if isinstance(provider, Singleton):
await provider.aclose()
for provider in self.providers.values():
if isinstance(provider, Singleton):
await provider.aclose()
49 changes: 16 additions & 33 deletions aioinject/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
_T = TypeVar("_T")

_AnyCtx: TypeAlias = Union["InjectionContext", "SyncInjectionContext"]
_TypeAndImpl: TypeAlias = tuple[type[_T], _T | None]
_ExitStackT = TypeVar("_ExitStackT")

context_var: ContextVar[_AnyCtx] = ContextVar("aioinject_context")
Expand All @@ -43,7 +42,7 @@ class _BaseInjectionContext(Generic[_ExitStackT]):
def __init__(self, container: Container) -> None:
self._container = container
self._exit_stack = self._exit_stack_type()
self._cache: dict[_TypeAndImpl[Any], Any] = {}
self._cache: dict[type[Any], Any] = {}
self._token = None

def __class_getitem__(
Expand All @@ -61,20 +60,15 @@ class InjectionContext(_BaseInjectionContext[AsyncExitStack]):
async def resolve(
self,
type_: type[_T],
impl: Any | None = None,
*,
use_cache: bool = True,
) -> _T:
if use_cache and (type_, impl) in self._cache:
return self._cache[type_, impl]
provider = self._container.get_provider(type_)
use_cache = True

if use_cache and type_ in self._cache:
return self._cache[type_]

provider = self._container.get_provider(type_, impl)
dependencies = {
dep.name: await self.resolve(
type_=dep.type_,
impl=dep.implementation,
use_cache=dep.use_cache,
)
dep.name: await self.resolve(type_=dep.type_)
for dep in provider.dependencies
}

Expand All @@ -83,7 +77,7 @@ async def resolve(
stack=self._exit_stack,
)
if use_cache:
self._cache[type_, impl] = resolved
self._cache[type_] = resolved
return resolved

@overload
Expand Down Expand Up @@ -117,10 +111,9 @@ async def execute(
for dependency in dependencies:
if dependency.name in kwargs:
continue

resolved[dependency.name] = await self.resolve(
type_=dependency.type_,
impl=dependency.implementation,
use_cache=dependency.use_cache,
)
if inspect.iscoroutinefunction(function):
return await function(*args, **kwargs, **resolved)
Expand All @@ -144,28 +137,22 @@ class SyncInjectionContext(_BaseInjectionContext[ExitStack]):
def resolve(
self,
type_: type[_T],
impl: Any | None = None,
*,
use_cache: bool = True,
) -> _T:
if use_cache and (type_, impl) in self._cache:
return self._cache[type_, impl]
use_cache = True
if use_cache and type_ in self._cache:
return self._cache[type_]

provider = self._container.get_provider(type_, impl)
provider = self._container.get_provider(type_)
dependencies = {
dep.name: self.resolve(
type_=dep.type_,
impl=dep.implementation,
use_cache=dep.use_cache,
)
dep.name: self.resolve(type_=dep.type_)
for dep in provider.dependencies
}

resolved = provider.provide_sync(**dependencies)
if isinstance(resolved, contextlib.ContextDecorator):
resolved = self._exit_stack.enter_context(resolved) # type: ignore[arg-type]
if use_cache:
self._cache[type_, impl] = resolved
self._cache[type_] = resolved
return resolved

def execute(
Expand All @@ -179,11 +166,7 @@ def execute(
for dependency in dependencies:
if dependency.name in kwargs:
continue
resolved[dependency.name] = self.resolve(
type_=dependency.type_,
impl=dependency.implementation,
use_cache=dependency.use_cache,
)
resolved[dependency.name] = self.resolve(type_=dependency.type_)
return function(*args, **kwargs, **resolved)

def __enter__(self) -> Self:
Expand Down
8 changes: 3 additions & 5 deletions aioinject/markers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from dataclasses import dataclass
from typing import Any
import dataclasses


@dataclass(slots=True)
@dataclasses.dataclass(slots=True)
class Inject:
impl: Any | None = None
cache: bool = True
pass
4 changes: 0 additions & 4 deletions aioinject/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
class Dependency(Generic[_T]):
name: str
type_: type[_T]
implementation: Any
use_cache: bool


def _get_annotation_args(type_hint: Any) -> tuple[type, tuple[Any, ...]]:
Expand Down Expand Up @@ -82,8 +80,6 @@ def collect_dependencies(
yield Dependency(
name=name,
type_=dep_type,
implementation=inject_marker.impl,
use_cache=inject_marker.cache,
)


Expand Down
41 changes: 19 additions & 22 deletions aioinject/validation/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ def all_dependencies_are_present(
container: aioinject.Container,
) -> Sequence[ContainerValidationError]:
errors = []
for providers in container.providers.values():
for provider in providers:
for dependency in provider.dependencies:
if dependency.type_ not in container.providers:
error = DependencyNotFoundError(
message=f"Provider for type {dependency.type_} not found",
dependency=dependency.type_,
)
errors.append(error)
for provider in container.providers.values():
for dependency in provider.dependencies:
if dependency.type_ not in container.providers:
error = DependencyNotFoundError(
message=f"Provider for type {dependency.type_} not found",
dependency=dependency.type_,
)
errors.append(error)

return errors

Expand All @@ -41,17 +40,15 @@ def __call__(
container: aioinject.Container,
) -> Sequence[ContainerValidationError]:
errors = []
for providers in container.providers.values():
for provider in providers:
if not self.dependant(provider):
continue

for dependency in provider.dependencies:
dependency_provider = container.get_provider(
type_=dependency.type_,
impl=dependency.implementation,
)
if self.dependency(dependency_provider):
msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}"
errors.append(ContainerValidationError(msg))
for provider in container.providers.values():
if not self.dependant(provider):
continue

for dependency in provider.dependencies:
dependency_provider = container.get_provider(
type_=dependency.type_,
)
if self.dependency(dependency_provider):
msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}"
errors.append(ContainerValidationError(msg))
return errors
50 changes: 11 additions & 39 deletions tests/container/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from aioinject import Singleton, providers
from aioinject import Scoped, Singleton, providers
from aioinject.containers import Container
from aioinject.context import InjectionContext

Expand Down Expand Up @@ -38,18 +38,20 @@ def test_can_register_single(container: Container) -> None:
provider = providers.Scoped(_ServiceA)
container.register(provider)

expected = {_ServiceA: [provider]}
expected = {_ServiceA: provider}
assert container.providers == expected


def test_can_register_multi(container: Container) -> None:
provider_a = providers.Scoped(_ServiceA)
provider_b = providers.Scoped(_ServiceB)
container.register(provider_a)
container.register(provider_b)
def test_cant_register_multiple_providers_for_same_type(
container: Container,
) -> None:
container.register(Scoped(int))

expected = {_ServiceA: [provider_a], _ServiceB: [provider_b]}
assert container.providers == expected
with pytest.raises(
ValueError,
match="^Provider for type <class 'int'> is already registered$",
):
container.register(Scoped(int))


def test_can_retrieve_single_provider(container: Container) -> None:
Expand All @@ -58,36 +60,6 @@ def test_can_retrieve_single_provider(container: Container) -> None:
assert container.get_provider(int)


@pytest.fixture
def multi_provider_container(container: Container) -> Container:
a_provider = providers.Scoped(_ServiceA, type_=_AbstractService)
b_provider = providers.Scoped(_ServiceB, type_=_AbstractService)
container.register(a_provider)
container.register(b_provider)
return container


def test_get_provider_raises_error_if_multiple_providers(
multi_provider_container: Container,
) -> None:
with pytest.raises(ValueError) as exc_info: # noqa: PT011
assert multi_provider_container.get_provider(_AbstractService)

msg = (
f"Multiple providers for type {_AbstractService.__qualname__} were found, "
f"you have to specify implementation using 'impl' parameter: "
f"Annotated[IService, Inject(impl=Service)]"
)
assert str(exc_info.value) == msg


def test_can_get_multi_provider_if__specified(
multi_provider_container: Container,
) -> None:
assert multi_provider_container.get_provider(_AbstractService, _ServiceA)
assert multi_provider_container.get_provider(_AbstractService, _ServiceB)


def test_missing_provider() -> None:
container = Container()
with pytest.raises(ValueError) as exc_info: # noqa: PT011
Expand Down
18 changes: 15 additions & 3 deletions tests/container/test_override.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,22 @@ def test_provider_override() -> None:
container = Container()
container.register(Object(a))
with container.sync_context() as ctx:
assert ctx.resolve(_A, use_cache=False) is a
assert ctx.resolve(_A) is a

a_overridden = _A()
assert a_overridden is not a

with container.override(Object(a_overridden)):
assert ctx.resolve(_A, use_cache=False) is a_overridden is not a
with container.sync_context() as ctx, container.override(
Object(a_overridden),
):
assert ctx.resolve(_A) is a_overridden is not a


def test_override_multiple_times() -> None:
container = Container()
with container.override(Object(1)):
with container.override(Object(2)), container.sync_context() as ctx:
assert ctx.resolve(int) == 2 # noqa: PLR2004

with container.sync_context() as ctx:
assert ctx.resolve(int) == 1
Loading

0 comments on commit 14efcbd

Please sign in to comment.