From 92b38283ad6094e5120209bd431efba1009a1c7d Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Mon, 25 Mar 2024 14:38:51 +0100 Subject: [PATCH] fmt --- README.md | 18 +++--- pyproject.toml | 1 + src/magic_di/_connectable.py | 15 ++--- src/magic_di/_container.py | 58 +++++++++++--------- src/magic_di/_injector.py | 83 +++++++++++++++++----------- src/magic_di/_signature.py | 2 + src/magic_di/_utils.py | 25 +++++---- src/magic_di/celery/__init__.py | 4 +- src/magic_di/celery/_async_utils.py | 5 +- src/magic_di/celery/_loader.py | 21 +++---- src/magic_di/celery/_provide.py | 15 ++--- src/magic_di/celery/_task.py | 31 ++++++----- src/magic_di/exceptions.py | 48 +++++++--------- src/magic_di/fastapi/__init__.py | 1 + src/magic_di/fastapi/_app.py | 32 +++++++---- src/magic_di/fastapi/_provide.py | 18 +++--- src/magic_di/testing.py | 11 ++-- src/magic_di/utils.py | 12 ++-- tests/conftest.py | 22 +++----- tests/test_app.py | 25 +++++---- tests/test_celery.py | 85 ++++++++++++++++------------- tests/test_injector.py | 15 ++--- tests/test_utils.py | 1 + 23 files changed, 299 insertions(+), 249 deletions(-) diff --git a/README.md b/README.md index 714e54d..f8d1f93 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,8 @@ Dependency Injector that makes your life easier with built-in support of FastAPI, Celery (and it can be integrated with everything) -What are the problems with FastAPI’s dependency injector? -1) It forces you to use global variables. +What are the problems with FastAPI’s dependency injector? +1) It forces you to use global variables. 2) You need to write an endless number of fabrics with startup logic 3) It makes your project highly dependent on FastAPI’s injector by using “Depends” everywhere. @@ -211,7 +211,7 @@ def hello_world(service: Provide[MyInterface]) -> dict: } ``` -Using `injector.bind`, you can bind implementations that will be injected everywhere the bound interface is used. +Using `injector.bind`, you can bind implementations that will be injected everywhere the bound interface is used. ## Integration with Celery @@ -247,22 +247,22 @@ app = Celery( class CalculatorTaskDeps(BaseCeleryConnectableDeps): calculator: Calculator - + class CalculatorTask(InjectableCeleryTask): deps: CalculatorTaskDeps - + async def run(self, x: int, y: int, smart_processor: SmartProcessor = Provide()): return smart_processor.process( await self.deps.calculator.calculate(x, y) ) - + app.register_task(CalculatorTask) ``` ### Limitations -You could notice that in these examples tasks are using Python async/await. -`InjectableCeleryTask` provides support for writing async code. However, it still executes code synchronously. +You could notice that in these examples tasks are using Python async/await. +`InjectableCeleryTask` provides support for writing async code. However, it still executes code synchronously. **Due to this, getting results from async tasks is not possible in the following cases:** * When the `task_always_eager` config flag is enabled and task creation occurs inside the running event loop (e.g., inside an async FastAPI endpoint) * When calling the `.apply()` method inside running event loop (e.g., inside an async FastAPI endpoint) @@ -297,7 +297,7 @@ def client(injector: DependencyInjector, service_mock: InjectableMock): with TestClient(app) as client: yield client - + def test_http_handler(client): resp = client.post('/hello-world') diff --git a/pyproject.toml b/pyproject.toml index a8ad34a..1e295b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ build-backend = "poetry.core.masonry.api" [tool.ruff] target-version = "py38" # The lowest supported version +line-length = 100 [tool.ruff.lint] # By default, enable all the lint rules. diff --git a/src/magic_di/_connectable.py b/src/magic_di/_connectable.py index 11d0068..8210b68 100644 --- a/src/magic_di/_connectable.py +++ b/src/magic_di/_connectable.py @@ -6,15 +6,14 @@ class ConnectableProtocol(Protocol): """ Interface for injectable clients. Adding these methods to your class will allow it to be dependency injectable. - The dependency injector uses duck typing to check that the class implements the interface. + The dependency injector uses duck typing to check that the class + implements the interface. This means that you do not need to inherit from this protocol. """ - async def __connect__(self) -> None: - ... + async def __connect__(self) -> None: ... - async def __disconnect__(self) -> None: - ... + async def __disconnect__(self) -> None: ... class Connectable: @@ -23,8 +22,6 @@ class Connectable: without adding these empty methods. """ - async def __connect__(self) -> None: - ... + async def __connect__(self) -> None: ... - async def __disconnect__(self) -> None: - ... + async def __disconnect__(self) -> None: ... diff --git a/src/magic_di/_container.py b/src/magic_di/_container.py index f2ac79f..c069d23 100644 --- a/src/magic_di/_container.py +++ b/src/magic_di/_container.py @@ -1,24 +1,26 @@ +from __future__ import annotations + import functools import inspect from dataclasses import dataclass from threading import Lock -from typing import Generic, Iterable, Type, TypeVar, cast +from typing import Generic, Iterable, TypeVar, cast T = TypeVar("T") @dataclass(frozen=True) class Dependency(Generic[T]): - object: Type[T] + object: type[T] instance: T | None = None class SingletonDependencyContainer: def __init__(self): - self._deps: dict[Type[T], Dependency[T]] = {} + self._deps: dict[type[T], Dependency[T]] = {} self._lock: Lock = Lock() - def add(self, obj: Type[T], **kwargs) -> Type[T]: + def add(self, obj: type[T], **kwargs) -> type[T]: with self._lock: if dep := self._get(obj): return dep @@ -38,33 +40,33 @@ def add(self, obj: Type[T], **kwargs) -> Type[T]: return wrapped - def get(self, obj: Type[T]) -> Type[T] | None: + def get(self, obj: type[T]) -> type[T] | None: with self._lock: return self._get(obj) - def iter_instances(self, *, reverse: bool = False) -> Iterable[tuple[Type[T], T]]: + def iter_instances(self, *, reverse: bool = False) -> Iterable[tuple[type[T], T]]: with self._lock: deps_iter: Iterable = list( - reversed(self._deps.values()) if reverse else self._deps.values() + reversed(self._deps.values()) if reverse else self._deps.values(), ) for dep in deps_iter: if dep.instance: yield dep.object, dep.instance - def _get(self, obj: Type[T]) -> Type[T] | None: + def _get(self, obj: type[T]) -> type[T] | None: dep = self._deps.get(obj) return dep.object if dep else None -def _wrap(obj: Type[T], *args, **kwargs) -> Type[T]: +def _wrap(obj: type[T], *args, **kwargs) -> type[T]: if not inspect.isclass(obj): partial = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) - return cast(Type[T], partial) + return cast(type[T], partial) _instance = None - def __new__(cls): + def new(_): nonlocal _instance if _instance is not None: @@ -73,20 +75,28 @@ def __new__(cls): _instance = obj(*args, **kwargs) return _instance - # This class wraps obj and creates a singleton class that uses *args and **kwargs for initialization. - # Please note that it also doesn’t allow passing additional args after creating this partial. - # It was made to make it more obvious that you can’t create two different instances with different parameters. + # This class wraps obj and creates a singleton class + # that uses *args and **kwargs for initialization. + # + # Please note that it also doesn't allow passing additional args + # after creating this partial. + # It was made to make it more obvious that you can't create + # two different instances with different parameters. + # # Example: - # injected_redis = injector.inject(Redis) - # injected_redis() # works - # injected_redis(timeout=10) # doesn’t work + # >>> injected_redis = injector.inject(Redis) + # >>> injected_redis() # works + # >>> injected_redis(timeout=10) # doesn't work # - # Here we manually create a new class using the `type` metaclass to prevent possible overrides of it. - # We copy all object attributes (__dict__) so that upon inspection, + # Here we manually create a new class using the `type` metaclass + # to prevent possible overrides of it. + # We copy all object attributes (__dict__) + # so that upon inspection, # the class should look exactly like the wrapped class. - # However, we override the __new__ method to return an instance of the original class. + # However, we override the __new__ method + # to return an instance of the original class. # Since the original class was not modified, it will use its own metaclass. - SingletonPartialCls = functools.wraps( + return functools.wraps( obj, updated=(), )( @@ -95,9 +105,7 @@ def __new__(cls): (obj,), { **obj.__dict__, - "__new__": __new__, + "__new__": new, }, - ) + ), ) - - return SingletonPartialCls diff --git a/src/magic_di/_injector.py b/src/magic_di/_injector.py index a021a6b..8d089a7 100644 --- a/src/magic_di/_injector.py +++ b/src/magic_di/_injector.py @@ -1,10 +1,20 @@ +from __future__ import annotations + import inspect import logging from contextlib import contextmanager from threading import Lock -from typing import Annotated, Any, Callable, Iterable, Type, TypeVar, cast, get_origin +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Iterable, + TypeVar, + cast, + get_origin, +) -from magic_di._connectable import ConnectableProtocol from magic_di._container import SingletonDependencyContainer from magic_di._signature import Signature from magic_di._utils import ( @@ -15,19 +25,23 @@ ) from magic_di.exceptions import InjectionError, InspectionError +if TYPE_CHECKING: + from magic_di._connectable import ConnectableProtocol + # flag to use in typing.Annotated # to forcefully mark dependency as injectable # even if dependency is not connectable -# Example: typing.Annotated[MyDependency, Injectable] +# Example: +# >>> typing.Annotated[MyDependency, Injectable] Injectable = type("Injectable", (), {})() -def is_injector(cls: Type) -> bool: +def is_injector(cls: type) -> bool: """ Check if a class is a subclass of DependencyInjector. Args: - cls (Type): The class to check. + cls (type): The class to check. Returns: bool: True if the class is a subclass of DependencyInjector, False otherwise. @@ -45,12 +59,14 @@ def is_forcefully_marked_as_injectable(cls: Any) -> bool: T = TypeVar("T") AnyObject = TypeVar("AnyObject") +logger = logging.getLogger("injector") + class DependencyInjector: def __init__( self, bindings: dict | None = None, - logger: logging.Logger = logging.getLogger("injector"), + logger: logging.Logger = logger, ): self.bindings: dict = bindings or {} self.logger: logging.Logger = logger @@ -104,21 +120,20 @@ def inspect(self, obj: AnyObject) -> Signature[AnyObject]: signature = Signature(obj, is_injectable=is_injectable(obj)) - for name, hint in hints.items(): - hint = self._unwrap_type_hint(hint) + for name, hint_ in hints.items(): + hint = self._unwrap_type_hint(hint_) hint_with_extra = hints_with_extras[name] if is_injector(hint): signature.injector_arg = name - continue elif not is_injectable(hint) and not is_forcefully_marked_as_injectable( - hint_with_extra + hint_with_extra, ): signature.kwargs[name] = hint - continue + else: + signature.deps[name] = hint - signature.deps[name] = hint - except Exception as exc: + except Exception as exc: # noqa: BLE001 raise InspectionError(obj) from exc return signature @@ -129,12 +144,13 @@ async def connect(self): """ # unpack to create copy of list for postponed in [*self._postponed]: - # First, we need to create instances of postponed injection in order to connect them. + # First, we need to create instances of postponed injection + # in order to connect them. self.inject(postponed) for cls, instance in self._deps.iter_instances(): if is_injectable(cls): - self.logger.debug(f"Connecting {cls}...") + self.logger.debug("Connecting %s...", cls.__name__) await instance.__connect__() async def disconnect(self): @@ -146,7 +162,7 @@ async def disconnect(self): try: await instance.__disconnect__() except Exception: - self.logger.exception(f"Failed to disconnect {cls}") + self.logger.exception("Failed to disconnect %s", cls.__name__) async def __aenter__(self): await self.connect() @@ -170,7 +186,8 @@ def lazy_inject(self, obj: Callable[..., T]) -> Callable[..., T]: Returns: Callable[..., T]: A function that, when called, - will inject the dependencies and return the injected class instance or function. + will inject the dependencies and + return the injected class instance or function. """ with self._lock: @@ -188,31 +205,36 @@ def inject(): injected = self.inject(obj)() return injected - return cast(Type[T], inject) + return cast(type[T], inject) - def bind(self, bindings: dict[Type, Type]): + def bind(self, bindings: dict[type, type]): """ Bind new bindings to the injector. - This method is used to add new bindings to the injector. Bindings are a dictionary where the keys are the - classes used in dependencies type hints, and the values are the classes that should replace them. + This method is used to add new bindings to the injector. + Bindings are a dictionary where the keys are the + classes used in dependencies type hints, + and the values are the classes that should replace them. - For example, if you have a class `Foo` that depends on an interface `Bar`, and you have a class `BarImpl` - that implements `Bar`, you would add a binding like this: `injector.bind({Bar: BarImpl})`. Then, whenever - `Foo` is injected, it will receive an instance of `BarImpl` instead of `Bar`. + For example, if you have a class `Foo` that depends on an interface `Bar`, + and you have a class `BarImpl` that implements `Bar`, + you would add a binding like this: `injector.bind({Bar: BarImpl})`. + Then, whenever `Foo` is injected, + it will receive an instance of `BarImpl` instead of `Bar`. - If a binding for a particular class or type already exists, this method will update that binding with the - new value + If a binding for a particular class or type already exists, + this method will update that binding with the new value Args: - bindings (dict[Type, Type]): The bindings to add. This should be a dictionary where the keys are + bindings (dict[type, type]): The bindings to add. + This should be a dictionary where the keys are classes and the values are the classes that should replace them. """ with self._lock: self.bindings = self.bindings | bindings @contextmanager - def override(self, bindings: dict[Type, Type]): + def override(self, bindings: dict[type, type]): """ Temporarily override the bindings and dependencies of the injector. @@ -233,10 +255,9 @@ def override(self, bindings: dict[Type, Type]): self._deps = actual_deps self.bindings = actual_bindings - def _unwrap_type_hint(self, obj: Type[AnyObject]) -> Type[AnyObject]: + def _unwrap_type_hint(self, obj: type[AnyObject]) -> type[AnyObject]: obj = get_cls_from_optional(obj) - obj = self.bindings.get(obj, obj) - return obj + return self.bindings.get(obj, obj) def __hash__(self) -> int: """Injector is always unique""" diff --git a/src/magic_di/_signature.py b/src/magic_di/_signature.py index 48193f6..50de05e 100644 --- a/src/magic_di/_signature.py +++ b/src/magic_di/_signature.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, field from typing import Generic, TypeVar diff --git a/src/magic_di/_utils.py b/src/magic_di/_utils.py index 3d8c2a6..d0e6cfd 100644 --- a/src/magic_di/_utils.py +++ b/src/magic_di/_utils.py @@ -1,4 +1,6 @@ -from typing import Any, Type, TypeVar, Union, cast, get_args +from __future__ import annotations + +from typing import Any, TypeVar, Union, cast, get_args from typing import get_type_hints as _get_type_hints from magic_di import ConnectableProtocol @@ -6,14 +8,14 @@ LegacyUnionType = type(Union[object, None]) try: - from types import UnionType # type: ignore + from types import UnionType # type: ignore[import-error] except ImportError: UnionType = LegacyUnionType # type: ignore[misc, assignment] T = TypeVar("T") -NONE_TYPE: Type = type(None) +NONE_TYPE: type = type(None) def is_class(obj: Any) -> bool: @@ -42,23 +44,22 @@ def get_cls_from_optional(cls: T) -> T: return cls args = get_args(cls) - if len(args) != 2: + + optional_type_hint_args_len = 2 + if len(args) != optional_type_hint_args_len: return cast(T, cls) if NONE_TYPE not in args: return cast(T, cls) for typo in args: - try: - if not issubclass(typo, NONE_TYPE): - return cast(T, typo) - except TypeError: - ... + if not safe_is_subclass(typo, NONE_TYPE): + return cast(T, typo) return cast(T, cls) -def safe_is_subclass(sub_cls: Any, cls: Type) -> bool: +def safe_is_subclass(sub_cls: Any, cls: type) -> bool: try: return issubclass(sub_cls, cls) except TypeError: @@ -82,7 +83,7 @@ def get_type_hints(obj: Any, *, include_extras=False) -> dict[str, type]: try: if is_class(obj): return _get_type_hints(obj.__init__, include_extras=include_extras) # type: ignore[misc] - else: - return _get_type_hints(obj, include_extras=include_extras) + + return _get_type_hints(obj, include_extras=include_extras) except TypeError: return {} diff --git a/src/magic_di/celery/__init__.py b/src/magic_di/celery/__init__.py index 4da2fc8..87b2045 100644 --- a/src/magic_di/celery/__init__.py +++ b/src/magic_di/celery/__init__.py @@ -1,6 +1,6 @@ from ._async_utils import EventLoopGetter from ._loader import get_celery_loader -from ._provide import Provide +from ._provide import PROVIDE from ._task import BaseCeleryConnectableDeps, InjectableCeleryTask __all__ = ( @@ -8,5 +8,5 @@ "BaseCeleryConnectableDeps", "get_celery_loader", "EventLoopGetter", - "Provide", + "PROVIDE", ) diff --git a/src/magic_di/celery/_async_utils.py b/src/magic_di/celery/_async_utils.py index 575348a..80f5f0e 100644 --- a/src/magic_di/celery/_async_utils.py +++ b/src/magic_di/celery/_async_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import threading from dataclasses import dataclass @@ -25,7 +27,8 @@ def start(self) -> asyncio.AbstractEventLoop: self._event_loop = asyncio.get_event_loop() self._loop_thread = threading.Thread( - target=self._event_loop.run_forever, daemon=True + target=self._event_loop.run_forever, + daemon=True, ) self._loop_thread.start() return self._event_loop diff --git a/src/magic_di/celery/_loader.py b/src/magic_di/celery/_loader.py index 6c36bff..28fab6c 100644 --- a/src/magic_di/celery/_loader.py +++ b/src/magic_di/celery/_loader.py @@ -1,25 +1,27 @@ +from __future__ import annotations + import asyncio import os import threading -from typing import Any, Callable, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable from celery import signals from celery.loaders.app import AppLoader # type: ignore[import] -from celery.loaders.base import BaseLoader from magic_di import DependencyInjector from magic_di.celery._async_utils import EventLoop, EventLoopGetter, run_in_event_loop +if TYPE_CHECKING: + from celery.loaders.base import BaseLoader + @runtime_checkable class InjectedCeleryLoaderProtocol(Protocol): injector: DependencyInjector loaded: bool - def on_worker_process_init(self) -> None: - ... + def on_worker_process_init(self) -> None: ... - def get_event_loop(self) -> EventLoop: - ... + def get_event_loop(self) -> EventLoop: ... def get_celery_loader( @@ -73,14 +75,13 @@ def get_event_loop(self) -> EventLoop: with self._event_loop_lock: if self._event_loop is None: self._event_loop = EventLoop( - _event_loop_getter(), running_outside=False + _event_loop_getter(), + running_outside=False, ) return self._event_loop - def _disconnect( - self, *args: Any, pid: str, **kwargs: Any - ) -> None: # noqa: ARG004, ANN401 + def _disconnect(self, *_: Any, pid: str, **__: Any) -> None: log_fn(f"Trying to disconnect celery deps for pid {pid}...") run_in_event_loop(self.injector.disconnect(), self.get_event_loop()) log_fn(f"Celery deps are disconnected for pid {pid}") diff --git a/src/magic_di/celery/_provide.py b/src/magic_di/celery/_provide.py index 67bd8a1..eb268a6 100644 --- a/src/magic_di/celery/_provide.py +++ b/src/magic_di/celery/_provide.py @@ -1,10 +1,5 @@ -from typing import Any - - -def Provide() -> Any: # noqa: ANN401, N802 - """ - An empty function that returns an ellipsis - It's used only to pass celery’s strict_typing check. - It is not necessary to use Provide() if you disable the `strict_typing` option in Celery’s configuration. - """ - return ... +# An empty object that returns an ellipsis +# It's used only to pass celery's strict_typing check. +# It is not necessary to use Provide() +# if you disable the `strict_typing` option in Celery's configuration. +PROVIDE = type("DependencyProvider", (), {})() diff --git a/src/magic_di/celery/_task.py b/src/magic_di/celery/_task.py index 2004d81..752fbe1 100644 --- a/src/magic_di/celery/_task.py +++ b/src/magic_di/celery/_task.py @@ -10,8 +10,7 @@ from magic_di.celery._loader import InjectedCeleryLoaderProtocol -class BaseCeleryConnectableDeps(Connectable): - ... +class BaseCeleryConnectableDeps(Connectable): ... class InjectableCeleryTaskMetaclass(type): @@ -49,10 +48,11 @@ def __init__( return if not self.injector: - raise ValueError( + error_msg = ( "Celery is not injected by magic_di.celery.get_celery_loader " "and injector or deps_instance is not passed in task args" ) + raise ValueError(error_msg) for dep in self.injector.inspect(self.run).deps.values(): self.injector.lazy_inject(dep) @@ -69,7 +69,8 @@ def injector(self) -> DependencyInjector | None: if isinstance(self.app.loader, InjectedCeleryLoaderProtocol): loader: InjectedCeleryLoaderProtocol = cast( - InjectedCeleryLoaderProtocol, self.app.loader + InjectedCeleryLoaderProtocol, + self.app.loader, ) return loader.injector @@ -79,7 +80,8 @@ def injector(self) -> DependencyInjector | None: def load(self) -> None: if isinstance(self.app.loader, InjectedCeleryLoaderProtocol): loader: InjectedCeleryLoaderProtocol = cast( - InjectedCeleryLoaderProtocol, self.app.loader + InjectedCeleryLoaderProtocol, + self.app.loader, ) return loader.on_worker_process_init() @@ -91,14 +93,16 @@ def loaded(self) -> bool: return True loader: InjectedCeleryLoaderProtocol = cast( - InjectedCeleryLoaderProtocol, self.app.loader + InjectedCeleryLoaderProtocol, + self.app.loader, ) return loader.loaded def get_event_loop(self) -> EventLoop | None: if isinstance(self.app.loader, InjectedCeleryLoaderProtocol): loader: InjectedCeleryLoaderProtocol = cast( - InjectedCeleryLoaderProtocol, self.app.loader + InjectedCeleryLoaderProtocol, + self.app.loader, ) return loader.get_event_loop() @@ -108,14 +112,11 @@ def get_event_loop(self) -> EventLoop | None: @staticmethod def run_wrapper(orig_run: Callable) -> Callable: @wraps(orig_run) - def runner(self: "InjectableCeleryTask", *args, **kwargs) -> Any: + def runner(self: InjectableCeleryTask, *args, **kwargs) -> Any: if not self.loaded: self.load() - if self.injector: - run = self.injector.inject(orig_run) - else: - run = orig_run + run = self.injector.inject(orig_run) if self.injector else orig_run if "self" in inspect.signature(run).parameters: args = (self, *args) @@ -125,9 +126,11 @@ def runner(self: "InjectableCeleryTask", *args, **kwargs) -> Any: if inspect.isawaitable(result): event_loop = self.get_event_loop() if not event_loop: - raise ValueError( - "Cannot run async tasks. Celery is not injected by magic_di.celery.get_celery_loader" + error_msg = ( + "Cannot run async tasks. " + "Celery is not injected by magic_di.celery.get_celery_loader" ) + raise ValueError(error_msg) return run_in_event_loop( result, # type: ignore[arg-type] diff --git a/src/magic_di/exceptions.py b/src/magic_di/exceptions.py index b3c09a9..00263f3 100644 --- a/src/magic_di/exceptions.py +++ b/src/magic_di/exceptions.py @@ -5,8 +5,7 @@ from magic_di._utils import get_type_hints -class InjectorError(Exception): - ... +class InjectorError(Exception): ... class InjectionError(InjectorError): @@ -25,31 +24,32 @@ def _build_error_message(self) -> str: missing_kwargs = { param.name: type_hints.get(param.annotation) or param.annotation for param in full_signature.parameters.values() - if param.name not in self.signature.deps - and param.default is full_signature.empty + if param.name not in self.signature.deps and param.default is full_signature.empty } missing_kwargs_str = "\n".join( - [f"{name}: {repr(hint)}" for name, hint in missing_kwargs.items()] + [f"{name}: {hint!r}" for name, hint in missing_kwargs.items()], ) object_location = _get_object_source_location(self.obj) - - return ( - f"Failed to inject {self.obj}. \n" - f"{object_location}\n\n" - f"Missing arguments:\n" - f"{missing_kwargs_str}\n\n" - f"Hint: Did you forget to make these dependencies connectable? " - f"(Inherit from Connectable or implement __connect__ and __disconnect__ methods)" - ) - except Exception: + except Exception: # noqa: BLE001 return ( f"Failed to inject {self.obj}. Missing arguments\n" f"Hint: Did you forget to make these dependencies connectable? " - f"(Inherit from Connectable or implement __connect__ and __disconnect__ methods)" + f"(Inherit from Connectable or " + f"implement __connect__ and __disconnect__ methods)" ) + return ( + f"Failed to inject {self.obj}. \n" + f"{object_location}\n\n" + f"Missing arguments:\n" + f"{missing_kwargs_str}\n\n" + f"Hint: Did you forget to make these dependencies connectable? " + f"(Inherit from Connectable or " + f"implement __connect__ and __disconnect__ methods)" + ) + class InspectionError(InjectorError): def __init__(self, obj: Any): @@ -59,22 +59,16 @@ def __str__(self) -> str: return self._build_error_message() def _build_error_message(self) -> str: - try: - object_location = _get_object_source_location(self.obj) + object_location = _get_object_source_location(self.obj) - return ( - f"Failed to inspect {self.obj}. \n" - f"{object_location}\n" - "See the exception above" - ) - except Exception: - return f"Failed to inspect {self.obj}. \n" "See the exception above" + return f"Failed to inspect {self.obj}. \n{object_location}\nSee the exception above" def _get_object_source_location(obj: Callable) -> str: try: _, obj_line_number = inspect.getsourcelines(obj) source_file = inspect.getsourcefile(obj) - return f"{source_file}:{obj_line_number}" - except Exception: + except Exception: # noqa: BLE001 return "" + + return f"{source_file}:{obj_line_number}" diff --git a/src/magic_di/fastapi/__init__.py b/src/magic_di/fastapi/__init__.py index 61c5d88..e39a413 100644 --- a/src/magic_di/fastapi/__init__.py +++ b/src/magic_di/fastapi/__init__.py @@ -1,6 +1,7 @@ """ Package containing tools for integrating Dependency Injector with FastAPI framework. """ + from ._app import inject_app from ._provide import Provide, Provider diff --git a/src/magic_di/fastapi/_app.py b/src/magic_di/fastapi/_app.py index 36f2cd5..f81bbc1 100644 --- a/src/magic_di/fastapi/_app.py +++ b/src/magic_di/fastapi/_app.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import inspect from contextlib import asynccontextmanager from typing import ( + TYPE_CHECKING, Annotated, Any, Callable, @@ -10,10 +13,12 @@ runtime_checkable, ) -from fastapi import FastAPI, routing from fastapi.params import Depends from magic_di._injector import DependencyInjector +if TYPE_CHECKING: + from fastapi import FastAPI, routing + @runtime_checkable class RouterProtocol(Protocol): @@ -35,7 +40,9 @@ def inject_app( Inject dependencies into a FastAPI application using the provided injector. This function sets up the FastAPI application to connect and disconnect the injector - at startup and shutdown, respectively. This ensures that all dependencies are properly + at startup and shutdown, respectively. + + This ensures that all dependencies are properly connected when the application starts, and properly disconnected when the application shuts down. @@ -50,8 +57,10 @@ def inject_app( use_deprecated_events (bool, optional): Indicate whether the app should be injected with starlette events (which are deprecated) or use app lifespans. - Important: Use this flag only if you still use app events, - because if lifespans are defined, events will be ignored by Starlette. + Important: Use this flag only + if you still use app events, + because if lifespans are defined, + events will be ignored by Starlette. Returns: FastAPI: The FastAPI application with dependencies injected. @@ -97,7 +106,8 @@ def _inject_app_with_events(app: FastAPI, collect_deps_fn: Callable) -> None: def _collect_dependencies( - injector: DependencyInjector, app_router: routing.APIRouter + injector: DependencyInjector, + app_router: routing.APIRouter, ) -> None: """ It walks through all routers and collects all app dependencies to connect them later @@ -111,10 +121,11 @@ def _collect_dependencies( for route in app_router.routes: if not isinstance(route, RouterProtocol): - raise ValueError( + error_msg = ( "Unexpected router class. " f"Router must contain .endpoint property with handler function: {route}" ) + raise TypeError(error_msg) if isinstance(route, FastAPIRouterProtocol): for dependencies in route.dependencies: @@ -133,17 +144,16 @@ def _inspect_and_lazy_inject(obj: object, injector: DependencyInjector) -> None: def _find_fastapi_dependencies(dependency: Callable) -> Iterator[Callable]: """ Recursively finds all FastAPI dependencies. - It looks for FastAPI’s Depends() in default arguments and in type annotations. + It looks for FastAPI's Depends() in default arguments and in type annotations. """ signature = inspect.signature(dependency) for param in signature.parameters.values(): - if isinstance(param.default, Depends): - if param.default.dependency: - yield from _find_fastapi_dependencies(param.default.dependency) + if isinstance(param.default, Depends) and param.default.dependency: + yield from _find_fastapi_dependencies(param.default.dependency) type_hint = param.annotation - if not get_origin(type_hint) is Annotated: + if get_origin(type_hint) is not Annotated: continue for annotation in type_hint.__metadata__: diff --git a/src/magic_di/fastapi/_provide.py b/src/magic_di/fastapi/_provide.py index 3e890c3..d38b39e 100644 --- a/src/magic_di/fastapi/_provide.py +++ b/src/magic_di/fastapi/_provide.py @@ -1,28 +1,32 @@ +from __future__ import annotations + from functools import lru_cache -from typing import TYPE_CHECKING, Annotated, Type, TypeVar +from typing import TYPE_CHECKING, Annotated, TypeVar from fastapi import Depends, Request -from magic_di import DependencyInjector from magic_di._injector import Injectable +if TYPE_CHECKING: + from magic_di import DependencyInjector + T = TypeVar("T") -class FastAPIInjectionError(Exception): - ... +class FastAPIInjectionError(Exception): ... class Provider: - def __getitem__(self, obj: Type[T]) -> Type[T]: + def __getitem__(self, obj: type[T]) -> type[T]: @lru_cache(maxsize=1) def get_dependency(injector: DependencyInjector) -> T: return injector.inject(obj)() def inject(request: Request) -> T: if not hasattr(request.app.state, "dependency_injector"): - raise FastAPIInjectionError( + error_msg = ( "FastAPI application is not injected. Did you forget to add `inject_app(app)`?" ) + raise FastAPIInjectionError(error_msg) injector = request.app.state.dependency_injector return get_dependency(injector) @@ -31,6 +35,6 @@ def inject(request: Request) -> T: if TYPE_CHECKING: - from typing import Union as Provide # hack for mypy + from typing import Union as Provide # hack for mypy # noqa: FIX004 else: Provide = Provider() diff --git a/src/magic_di/testing.py b/src/magic_di/testing.py index 8b5648a..3fb271b 100644 --- a/src/magic_di/testing.py +++ b/src/magic_di/testing.py @@ -7,7 +7,7 @@ def get_injectable_mock_cls(return_value): class ClientMetaclassMock(ConnectableProtocol): __annotations__ = {} - def __new__(cls, *args, **kwargs): + def __new__(cls, *_, **__): return return_value return ClientMetaclassMock @@ -15,7 +15,8 @@ def __new__(cls, *args, **kwargs): class InjectableMock(AsyncMock): """ - You can use this mock to override dependencies in tests and use AsyncMock instead of a real class instance + You can use this mock to override dependencies in tests + and use AsyncMock instead of a real class instance Example: @pytest.fixture() @@ -36,11 +37,9 @@ def test_http_handler(client): def mock_cls(self): return get_injectable_mock_cls(self) - async def __connect__(self): - ... + async def __connect__(self): ... - async def __disconnect__(self): - ... + async def __disconnect__(self): ... def __call__(self, *args, **kwargs): return self.__class__(*args, **kwargs) diff --git a/src/magic_di/utils.py b/src/magic_di/utils.py index 5da6c6a..1638592 100644 --- a/src/magic_di/utils.py +++ b/src/magic_di/utils.py @@ -1,10 +1,14 @@ +from __future__ import annotations + import asyncio import inspect -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any from magic_di import DependencyInjector +if TYPE_CHECKING: + from collections.abc import Callable + def inject_and_run(fn: Callable, injector: DependencyInjector | None = None) -> Any: """ @@ -35,7 +39,7 @@ async def run(): async with injector: if inspect.iscoroutinefunction(fn): return await injected() - else: - return injected() + + return injected() return asyncio.run(run()) diff --git a/tests/conftest.py b/tests/conftest.py index 3e20372..7ee44fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import asyncio from dataclasses import dataclass -from typing import Protocol, Union +from typing import Protocol import pytest - from magic_di import Connectable, DependencyInjector @@ -17,12 +18,10 @@ async def __disconnect__(self): self.connected = False -class Database(ConnectableClient): - ... +class Database(ConnectableClient): ... -class AnotherDatabase(ConnectableClient): - ... +class AnotherDatabase(ConnectableClient): ... class Repository(ConnectableClient): @@ -35,21 +34,19 @@ async def do_something(self): return self.connected and self.db.connected -class AsyncWorkers(ConnectableClient): - ... +class AsyncWorkers(ConnectableClient): ... @dataclass class Service(ConnectableClient): repo: Repository - workers: Union[AsyncWorkers, None] + workers: AsyncWorkers | None def is_alive(self): return self.repo.connected and self.workers.connected -class NonConnectableDatabase: - ... +class NonConnectableDatabase: ... @dataclass @@ -62,8 +59,7 @@ class BrokenService: repo: BrokenRepo -class RepoInterface(Protocol): - ... +class RepoInterface(Protocol): ... @dataclass diff --git a/tests/test_app.py b/tests/test_app.py index 367c55a..5174474 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,12 +1,12 @@ from typing import Annotated import pytest -from starlette.testclient import TestClient - from fastapi import APIRouter, Depends, FastAPI from magic_di import DependencyInjector from magic_di.fastapi import Provide, inject_app from magic_di.fastapi._provide import FastAPIInjectionError +from starlette.testclient import TestClient + from tests.conftest import Database, Service @@ -26,8 +26,7 @@ def hello_world(service: Provide[Service], some_query: str) -> dict: def test_app_injection_with_depends(injector): connected_global_dependency = False - class GlobalConnect(Database): - ... + class GlobalConnect(Database): ... def global_dependency(dep: Provide[GlobalConnect]): nonlocal connected_global_dependency @@ -43,11 +42,10 @@ def global_dependency(dep: Provide[GlobalConnect]): class MiddlewareNonConnectable: creds: str = "secret_creds" - def get_creds(self, value: str | None = None) -> str: + def get_creds(self, value: str | None = None) -> str: # noqa: FA102 return value or self.creds - class AnotherDatabase(Database): - ... + class AnotherDatabase(Database): ... def assert_db_connected(db: Provide[AnotherDatabase]) -> bool: assert db.connected @@ -55,6 +53,7 @@ def assert_db_connected(db: Provide[AnotherDatabase]) -> bool: def get_creds( mw: Provide[MiddlewareNonConnectable], + *, db_connect: bool = Depends(assert_db_connected), ): if db_connect: @@ -64,7 +63,8 @@ def get_creds( @app.get(path="/hello-world", dependencies=[]) def hello_world( - service: Provide[Service], creds: Annotated[str, Depends(get_creds)] + service: Provide[Service], + creds: Annotated[str, Depends(get_creds)], ) -> dict: assert isinstance(service, Service) return {"creds": creds, "is_alive": service.is_alive()} # type: ignore[attr-defined] @@ -77,7 +77,9 @@ def hello_world( @pytest.mark.parametrize("use_deprecated_events", [False, True]) def test_app_injection_clients_connect( - injector: DependencyInjector, use_deprecated_events: bool + injector: DependencyInjector, + *, + use_deprecated_events: bool, ): app = inject_app( FastAPI(), @@ -123,6 +125,5 @@ def test_app_injection_without_registered_injector(injector: DependencyInjector) def hello_world(service: Provide[Service]) -> str: return "OK" - with TestClient(app) as client: - with pytest.raises(FastAPIInjectionError): - client.get("/hello-world") + with TestClient(app) as client, pytest.raises(FastAPIInjectionError): + client.get("/hello-world") diff --git a/tests/test_celery.py b/tests/test_celery.py index b30f346..592ea46 100644 --- a/tests/test_celery.py +++ b/tests/test_celery.py @@ -1,3 +1,4 @@ +import contextlib import tempfile from contextlib import contextmanager from dataclasses import dataclass @@ -6,20 +7,22 @@ from unittest.mock import MagicMock, call import pytest -from starlette.testclient import TestClient - from celery import Celery from celery.bin.celery import celery # type: ignore[import] from fastapi import FastAPI from magic_di import DependencyInjector from magic_di.celery import ( + PROVIDE, BaseCeleryConnectableDeps, InjectableCeleryTask, - Provide, get_celery_loader, ) +from starlette.testclient import TestClient + from tests.conftest import AnotherDatabase, Service +TEST_ERR_MSG = "TEST_ERR_MSG" + @pytest.fixture(scope="module") def injector() -> DependencyInjector: @@ -29,6 +32,7 @@ def injector() -> DependencyInjector: @contextmanager def create_celery( injector: DependencyInjector, + *, use_broker_and_backend: bool, ) -> Iterator[Celery]: with tempfile.NamedTemporaryFile(delete=True) as temp_file: @@ -36,9 +40,7 @@ def create_celery( loader=get_celery_loader(injector), task_cls=InjectableCeleryTask, broker="memory://" if use_broker_and_backend else None, - backend=f"db+sqlite:///{temp_file.name}" - if use_broker_and_backend - else None, + backend=f"db+sqlite:///{temp_file.name}" if use_broker_and_backend else None, ) @@ -51,7 +53,7 @@ def celery_app(injector: DependencyInjector) -> Iterator[Celery]: @pytest.fixture(scope="module") def service_ping_task(celery_app: Celery) -> InjectableCeleryTask: @celery_app.task - async def service_ping(arg1: int, arg2: str, service: Service = Provide()) -> tuple: + async def service_ping(arg1: int, arg2: str, service: Service = PROVIDE) -> tuple: return arg1, arg2, service.is_alive() return cast(InjectableCeleryTask, service_ping) @@ -60,7 +62,7 @@ async def service_ping(arg1: int, arg2: str, service: Service = Provide()) -> tu @pytest.fixture(scope="module") def service_ping_task_sync(celery_app: Celery) -> InjectableCeleryTask: @celery_app.task - def service_ping_sync(arg1: int, arg2: str, service: Service = Provide()) -> tuple: + def service_ping_sync(arg1: int, arg2: str, service: Service = PROVIDE) -> tuple: return arg1, arg2, service.is_alive() return cast(InjectableCeleryTask, service_ping_sync) @@ -116,87 +118,89 @@ def run_celery( yield celery_app - try: - celery.main(args=["control", "shutdown"]) - except SystemExit: - ... + with contextlib.suppress(SystemExit): + (celery.main(args=["control", "shutdown"])) thread.join() def test_async_function_based_tasks( - run_celery: Celery, service_ping_task: InjectableCeleryTask + run_celery: Celery, + service_ping_task: InjectableCeleryTask, ): result = service_ping_task.apply_async(args=(1337, "leet")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1337, "leet", True] result = service_ping_task.apply_async(args=(1010, "wowowo")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1010, "wowowo", True] result = service_ping_task.apply(args=(1010, "123")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1010, "123", True] -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_sync_function_based_tasks( - run_celery: Celery, service_ping_task_sync: InjectableCeleryTask + run_celery: Celery, + service_ping_task_sync: InjectableCeleryTask, ): result = service_ping_task_sync.apply_async(args=(1337, "leet")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1337, "leet", True] result = service_ping_task_sync.apply_async(args=(1010, "wowowo")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1010, "wowowo", True] result = service_ping_task_sync.apply(args=(1010, "123")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1010, "123", True] def test_async_class_based_tasks( - run_celery: Celery, service_ping_class_based_task: InjectableCeleryTask + run_celery: Celery, + service_ping_class_based_task: InjectableCeleryTask, ): result = service_ping_class_based_task.apply_async(args=(1337, "leet")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1337, "leet", True] result = service_ping_class_based_task.apply_async(args=(1010, "wowowo")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1010, "wowowo", True] result = service_ping_class_based_task.apply(args=(1010, "123")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1010, "123", True] def test_sync_class_based_tasks( - run_celery: Celery, service_ping_class_based_task_sync: InjectableCeleryTask + run_celery: Celery, + service_ping_class_based_task_sync: InjectableCeleryTask, ): result = service_ping_class_based_task_sync.apply_async(args=(1337, "leet")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1337, "leet", True] result = service_ping_class_based_task_sync.apply_async(args=(1010, "wowowo")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1010, "wowowo", True] result = service_ping_class_based_task_sync.apply(args=(1010, "123")).get( - disable_sync_subtasks=False + disable_sync_subtasks=False, ) assert list(result) == [1010, "123", True] @@ -207,12 +211,12 @@ def test_retries_func_based_task(): mock = MagicMock() @app.task(max_retries=3, autoretry_for=(ValueError,), retry_backoff=0) - async def ping_task(service: Service = Provide()) -> None: + async def ping_task(service: Service = PROVIDE) -> None: assert service.is_alive() mock() - raise ValueError + raise ValueError(TEST_ERR_MSG) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=TEST_ERR_MSG): _ = ping_task.apply_async().get(disable_sync_subtasks=False) @@ -227,30 +231,31 @@ class PingTask(InjectableCeleryTask): autoretry_for = (ValueError,) retry_backoff = 0 - async def run(self, service: Service = Provide()): + async def run(self, service: Service = PROVIDE): assert service.is_alive() mock() - raise ValueError + raise ValueError(TEST_ERR_MSG) ping_class_task = app.register_task(PingTask) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=TEST_ERR_MSG): ping_class_task.apply_async().get(disable_sync_subtasks=False) - assert len(mock.call_args_list) == 4 + assert len(mock.call_args_list) == PingTask.max_retries + 1 @pytest.mark.parametrize( - "task_always_eager, use_broker_and_backend, expected_mock_calls", + ("task_always_eager", "use_broker_and_backend", "expected_mock_calls"), [ (False, True, [call()]), (True, False, [call(), call()]), (True, True, [call(), call()]), ], ) -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_async_function_based_tasks_inside_event_loop( service_ping_task: InjectableCeleryTask, + *, task_always_eager: bool, use_broker_and_backend: bool, expected_mock_calls: list, @@ -264,7 +269,9 @@ async def test_async_function_based_tasks_inside_event_loop( @app.task async def ping_task( - arg1: int, arg2: str, service: Service = Provide() + arg1: int, + arg2: str, + service: Service = PROVIDE, ) -> tuple: mock() return arg1, arg2, service.is_alive() diff --git a/tests/test_injector.py b/tests/test_injector.py index 94e6bff..66c03cb 100644 --- a/tests/test_injector.py +++ b/tests/test_injector.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from dataclasses import dataclass from typing import Annotated, Generic, TypeVar import pytest - from magic_di import Connectable, DependencyInjector, Injectable from magic_di.exceptions import InjectionError + from tests.conftest import ( AnotherDatabase, AsyncWorkers, @@ -18,7 +20,7 @@ ) -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_class_injection_success(injector): injected_service = injector.inject(Service)() assert not injected_service.is_alive() @@ -45,7 +47,7 @@ async def test_class_injection_success(injector): assert not injected_service.workers.connected -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_function_injection_success(injector): def run_service(service: Service): return service @@ -64,7 +66,7 @@ def test_class_injection_missing_class(injector): injector.inject(BrokenService) -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_class_injection_with_bindings(injector): injector.bind({RepoInterface: Repository}) @@ -126,7 +128,7 @@ def test_injector_iter_deps(injector): def test_injector_with_metaclass(injector): - class _GripServiceMetaClass(type): + class _ServiceMetaClass(type): def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict) -> type: for _ in attrs["__orig_bases__"]: ... @@ -136,8 +138,7 @@ def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict) -> type: tv = TypeVar("tv") - class ServiceGeneric(Generic[tv], metaclass=_GripServiceMetaClass): - ... + class ServiceGeneric(Generic[tv], metaclass=_ServiceMetaClass): ... @dataclass(frozen=True) class WrappedService(ServiceGeneric[Service]): diff --git a/tests/test_utils.py b/tests/test_utils.py index 470a6a4..756f75a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ from magic_di.utils import inject_and_run + from tests.conftest import Repository