Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support deffered types #2

Merged
merged 9 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@ coverage.xml
# virtualenv
.venv
venv*/


# python cached files
*.py[cod]
24 changes: 14 additions & 10 deletions aioinject/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,27 @@ def __init__(self) -> None:
self._sync_exit_stack = contextlib.ExitStack()

def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]:
return self._cache.get(provider.type_, NotInCache.sentinel)
return self._cache.get(provider.resolve_type(), NotInCache.sentinel)

def add(self, provider: Provider[T], obj: T) -> None:
if provider.lifetime is not DependencyLifetime.transient:
self._cache[provider.type_] = obj
self._cache[provider.resolve_type()] = obj

def lock(
self,
provider: Provider[Any],
) -> AbstractAsyncContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(
provider.resolve_type() not in self._cache,
)

def sync_lock(
self,
provider: Provider[Any],
) -> AbstractContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(
provider.resolve_type() not in self._cache,
)

@typing.overload
async def enter_context(
Expand Down Expand Up @@ -124,9 +128,9 @@ def __init__(self) -> None:

@contextlib.asynccontextmanager
async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]:
if provider.type_ not in self._cache:
async with self._locks[provider.type_]:
yield provider.type_ not in self._cache
if provider.resolve_type() not in self._cache:
async with self._locks[provider.resolve_type()]:
yield provider.resolve_type() not in self._cache
return
yield False

Expand All @@ -135,8 +139,8 @@ def sync_lock(
self,
provider: Provider[Any],
) -> Iterator[bool]:
if provider.type_ not in self._cache:
with self._sync_locks[provider.type_]:
yield provider.type_ not in self._cache
if provider.resolve_type() not in self._cache:
with self._sync_locks[provider.resolve_type()]:
yield provider.resolve_type() not in self._cache
return
yield False
35 changes: 26 additions & 9 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,33 @@ class Container:
def __init__(self) -> None:
self.providers: _Providers[Any] = {}
self._singletons = SingletonStore()
self._unresolved_providers: list[Provider[Any]] = []
self.type_context: dict[str, type[Any]] = {}

def _resolve_unresolved_provider(self) -> None:
for provider in self._unresolved_providers:
with contextlib.suppress(NameError):
self._register_impl(provider)
self._unresolved_providers.remove(provider)

def _register_impl(self, provider: Provider[Any]) -> None:
provider_type = provider.resolve_type(self.type_context)
if provider_type in self.providers:
msg = f"Provider for type {provider_type} is already registered"
raise ValueError(msg)
self.providers[provider_type] = provider
if klass_name := getattr(provider_type, "__name__", None):
self.type_context[klass_name] = provider_type
Comment on lines +36 to +37
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, but I think it's safe to assume that all python objects have __name__ attribute? 🤔

Suggested change
if klass_name := getattr(provider_type, "__name__", None):
self.type_context[klass_name] = provider_type
self.type_context[provider_type.__name__] = provider_type


def register(
self,
provider: Provider[Any],
) -> None:
if provider.type_ in self.providers:
msg = f"Provider for type {provider.type_} is already registered"
raise ValueError(msg)

self.providers[provider.type_] = provider
try:
self._register_impl(provider)
except NameError:
self._unresolved_providers.append(provider)
self._resolve_unresolved_provider()

def get_provider(self, type_: type[_T]) -> Provider[_T]:
try:
Expand All @@ -53,14 +70,14 @@ def override(
self,
provider: Provider[Any],
) -> Iterator[None]:
previous = self.providers.get(provider.type_)
self.providers[provider.type_] = provider
previous = self.providers.get(provider.resolve_type())
self.providers[provider.resolve_type()] = provider

yield

del self.providers[provider.type_]
del self.providers[provider.resolve_type()]
if previous is not None:
self.providers[provider.type_] = previous
self.providers[provider.resolve_type()] = previous

async def __aenter__(self) -> Self:
return self
Expand Down
6 changes: 4 additions & 2 deletions aioinject/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ async def resolve(
return cached

dependencies = {}
for dependency in provider.dependencies:
for dependency in provider.resolve_dependencies(
self._container.type_context,
):
dependencies[dependency.name] = await self.resolve(
type_=dependency.type_,
)
Expand Down Expand Up @@ -151,7 +153,7 @@ def resolve(
return cached

dependencies = {}
for dependency in provider.dependencies:
for dependency in provider.resolve_dependencies():
dependencies[dependency.name] = self.resolve(
type_=dependency.type_,
)
Expand Down
84 changes: 65 additions & 19 deletions aioinject/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from aioinject.markers import Inject
from aioinject.utils import (
_get_type_hints,
is_context_manager_function,
remove_annotation,
)
Expand Down Expand Up @@ -62,13 +63,13 @@ def _find_inject_marker_in_annotation_args(

def collect_dependencies(
dependant: typing.Callable[..., object] | dict[str, Any],
ctx: dict[str, type[Any]] | None = None,
) -> typing.Iterable[Dependency[object]]:
if not isinstance(dependant, dict):
with remove_annotation(dependant.__annotations__, "return"):
type_hints = typing.get_type_hints(dependant, include_extras=True)
type_hints = _get_type_hints(dependant, context=ctx)
else:
type_hints = dependant

for name, hint in type_hints.items():
dep_type, args = _get_annotation_args(hint)
inject_marker = _find_inject_marker_in_annotation_args(args)
Expand All @@ -81,15 +82,18 @@ def collect_dependencies(
)


def _get_provider_type_hints(provider: Provider[Any]) -> dict[str, Any]:
def _get_provider_type_hints(
provider: Provider[Any],
context: dict[str, Any] | None = None,
) -> dict[str, Any]:
source = provider.impl
if inspect.isclass(source):
source = source.__init__

if isinstance(source, functools.partial):
return {}

type_hints = typing.get_type_hints(source, include_extras=True)
type_hints = _get_type_hints(source, context=context)
for key, value in type_hints.items():
_, args = _get_annotation_args(value)
for arg in args:
Expand Down Expand Up @@ -120,11 +124,14 @@ def _get_provider_type_hints(provider: Provider[Any]) -> dict[str, Any]:
)


def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]:
def _guess_return_type(
factory: _FactoryType[_T],
context: dict[str, type[Any]] | None = None,
) -> type[_T]:
if isclass(factory):
return typing.cast(type[_T], factory)

type_hints = typing.get_type_hints(factory)
type_hints = _get_type_hints(factory, context=context)
try:
return_type = type_hints["return"]
except KeyError as e:
Expand Down Expand Up @@ -159,34 +166,59 @@ class DependencyLifetime(enum.Enum):

@runtime_checkable
class Provider(Protocol[_T]):
type_: type[_T]
impl: Any
lifetime: DependencyLifetime
_cached_dependecies: tuple[Dependency[object], ...]
_cached_type: type[_T] | None

async def provide(self, kwargs: Mapping[str, Any]) -> _T:
...

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T:
...

@property
def type_hints(self) -> dict[str, Any]:
def _resolve_type_impl(
self,
context: dict[str, Any] | None = None,
) -> type[_T]:
...

def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]:
if resolved := getattr(self, "_cached_type", None):
return resolved
ret = self._resolve_type_impl(context)
self._cached_type = ret
return ret
nrbnlulu marked this conversation as resolved.
Show resolved Hide resolved

def type_hints(self, context: dict[str, Any] | None) -> dict[str, Any]:
...

@property
def is_async(self) -> bool:
...

@functools.cached_property
def dependencies(self) -> tuple[Dependency[object], ...]:
return tuple(collect_dependencies(self.type_hints))
def resolve_dependencies(
self,
context: dict[str, Any] | None = None,
) -> tuple[Dependency[object], ...]:
if deps := getattr(self, "_cached_dependecies", None):
return deps
ret = tuple(
collect_dependencies(self.type_hints(context), ctx=context),
)
self._cached_dependecies = ret
return ret
nrbnlulu marked this conversation as resolved.
Show resolved Hide resolved

@functools.cached_property
def is_generator(self) -> bool:
return is_context_manager_function(self.impl)

def __repr__(self) -> str:
return f"{self.__class__.__qualname__}(type={self.type_}, implementation={self.impl})"
def __repr__(self) -> str: # pragma: no cover
try:
type_ = repr(self.resolve_type())
except NameError:
type_ = "UNKNOWN"
return f"{self.__class__.__qualname__}(type={type_}, implementation={self.impl})"


class Scoped(Provider[_T]):
Expand All @@ -197,9 +229,15 @@ def __init__(
factory: _FactoryType[_T],
type_: type[_T] | None = None,
) -> None:
self.type_ = type_ or _guess_return_type(factory)
self.type_ = type_
self.impl = factory

def _resolve_type_impl(
self,
context: dict[str, Any] | None = None,
) -> type[_T]:
return self.type_ or _guess_return_type(self.impl, context=context)

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T:
return self.impl(**kwargs) # type: ignore[return-value]

Expand All @@ -209,9 +247,11 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T:

return self.provide_sync(kwargs)

@functools.cached_property
def type_hints(self) -> dict[str, Any]:
type_hints = _get_provider_type_hints(self)
def type_hints(
self,
context: dict[str, Any] | None = None,
) -> dict[str, Any]:
type_hints = _get_provider_type_hints(self, context=context)
if "return" in type_hints:
del type_hints["return"]
return type_hints
Expand Down Expand Up @@ -244,7 +284,7 @@ class Transient(Scoped[_T]):


class Object(Provider[_T]):
type_hints: ClassVar[dict[str, Any]] = {}
_type_hints: ClassVar[dict[str, Any]] = {}
is_async = False
impl: _T
lifetime = DependencyLifetime.scoped # It's ok to cache it
Expand All @@ -257,8 +297,14 @@ def __init__(
self.type_ = type_ or type(object_)
self.impl = object_

def _resolve_type_impl(self, _: dict[str, Any] | None = None) -> type[_T]:
return self.type_

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002
return self.impl

async def provide(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002
return self.impl

def type_hints(self, _: dict[str, Any] | None = None) -> dict[str, Any]:
return self._type_hints
9 changes: 9 additions & 0 deletions aioinject/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,12 @@ def remove_annotation(
yield
if annotation is not sentinel:
annotations[name] = annotation


def _get_type_hints(
obj: Any,
context: dict[str, type[Any]] | None = None,
) -> dict[str, Any]:
if not context:
context = {}
return typing.get_type_hints(obj, include_extras=True, localns=context)
18 changes: 12 additions & 6 deletions aioinject/validation/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ def all_dependencies_are_present(
) -> Sequence[ContainerValidationError]:
errors = []
for provider in container.providers.values():
for dependency in provider.dependencies:
if dependency.type_ not in container.providers:
for dependency in provider.resolve_dependencies(
container.type_context,
):
dep_type = dependency.type_
if dep_type not in container.providers:
error = DependencyNotFoundError(
message=f"Provider for type {dependency.type_} not found",
dependency=dependency.type_,
message=f"Provider for type {dep_type} not found",
dependency=dep_type,
)
errors.append(error)

Expand All @@ -44,9 +47,12 @@ def __call__(
if not self.dependant(provider):
continue

for dependency in provider.dependencies:
for dependency in provider.resolve_dependencies(
container.type_context,
):
dep_type = dependency.type_
dependency_provider = container.get_provider(
type_=dependency.type_,
type_=dep_type,
)
if self.dependency(dependency_provider):
msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}"
Expand Down
Loading