From 6a041747fc742427f20187e8011c1ae56a0cc21b Mon Sep 17 00:00:00 2001 From: Nikita Zavadin Date: Wed, 27 Mar 2024 12:50:18 +0100 Subject: [PATCH] fix pydantic support --- poetry.lock | 8 ++++---- pyproject.toml | 1 + src/magic_di/_container.py | 15 +++------------ src/magic_di/_injector.py | 4 ++-- src/magic_di/_utils.py | 9 ++++++++- tests/test_injector.py | 11 +++++++++++ 6 files changed, 29 insertions(+), 19 deletions(-) diff --git a/poetry.lock b/poetry.lock index 706b7bc..ae010db 100644 --- a/poetry.lock +++ b/poetry.lock @@ -18,7 +18,7 @@ vine = ">=5.0.0,<6.0.0" name = "annotated-types" version = "0.6.0" description = "Reusable constraint types to use with typing.Annotated" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, @@ -1130,7 +1130,7 @@ wcwidth = "*" name = "pydantic" version = "2.6.4" description = "Data validation using Python type hints" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "pydantic-2.6.4-py3-none-any.whl", hash = "sha256:cc46fce86607580867bdc3361ad462bab9c222ef042d3da86f2fb333e1d916c5"}, @@ -1149,7 +1149,7 @@ email = ["email-validator (>=2.0.0)"] name = "pydantic-core" version = "2.16.3" description = "" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "pydantic_core-2.16.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:75b81e678d1c1ede0785c7f46690621e4c6e63ccd9192af1f0bd9d504bbb6bf4"}, @@ -1888,4 +1888,4 @@ fastapi = ["fastapi"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "d3ba4bd941d4f05eeaebdfae12672ba2af7b56490e4d1f6e8c6042142dc82d0f" +content-hash = "f14d11227dff0ec4c8b5977390bb3011f7d4a15064891899d6757adf7ff283c6" diff --git a/pyproject.toml b/pyproject.toml index 277b637..69bd7cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ pytest-github-actions-annotate-failures = "*" pytest-cov = "*" python-kacl = "*" ruff = ">=0.2.0" +pydantic = "*" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/src/magic_di/_container.py b/src/magic_di/_container.py index e72f314..3cce013 100644 --- a/src/magic_di/_container.py +++ b/src/magic_di/_container.py @@ -95,13 +95,7 @@ def new(_: Any) -> T: # >>> injected_redis() # works # >>> injected_redis(timeout=10) # doesn't work # - # Here we manually create a new class using the `type` metaclass - # to prevent possible overrides of it. - # We copy all object attributes (__dict__) - # so that upon inspection, - # the class should look exactly like the wrapped class. - # However, we override the __new__ method - # to return an instance of the original class. + # Here we manually create a new singleton class factory using the `type` metaclass # Since the original class was not modified, it will use its own metaclass. return functools.wraps( obj, @@ -109,10 +103,7 @@ def new(_: Any) -> T: )( type( obj.__name__, - (obj,), - { - **obj.__dict__, - "__new__": new, - }, + (), + {"__new__": new}, ), ) diff --git a/src/magic_di/_injector.py b/src/magic_di/_injector.py index 4f2cda5..4dbe6dc 100644 --- a/src/magic_di/_injector.py +++ b/src/magic_di/_injector.py @@ -150,7 +150,7 @@ async def connect(self) -> None: self.inject(postponed) for cls, instance in self._deps.iter_instances(): - if is_injectable(cls): + if is_injectable(instance): self.logger.debug("Connecting %s...", cls.__name__) await instance.__connect__() @@ -159,7 +159,7 @@ async def disconnect(self) -> None: Disconnect all injected dependencies """ for cls, instance in self._deps.iter_instances(reverse=True): - if is_injectable(cls): + if is_injectable(instance): try: await instance.__disconnect__() except Exception: diff --git a/src/magic_di/_utils.py b/src/magic_di/_utils.py index 38cd427..a4b8e56 100644 --- a/src/magic_di/_utils.py +++ b/src/magic_di/_utils.py @@ -66,6 +66,13 @@ def safe_is_subclass(sub_cls: Any, cls: type) -> bool: return False +def safe_is_instance(sub_cls: Any, cls: type) -> bool: + try: + return isinstance(sub_cls, cls) + except TypeError: + return False + + def is_injectable(cls: Any) -> bool: """ Check if a class is a subclass of ConnectableProtocol. @@ -76,7 +83,7 @@ def is_injectable(cls: Any) -> bool: Returns: bool: True if the class is a subclass of ConnectableProtocol, False otherwise. """ - return safe_is_subclass(cls, ConnectableProtocol) + return safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(cls, ConnectableProtocol) def get_type_hints(obj: Any, *, include_extras: bool = False) -> dict[str, type]: diff --git a/tests/test_injector.py b/tests/test_injector.py index 04559ca..2deef33 100644 --- a/tests/test_injector.py +++ b/tests/test_injector.py @@ -155,3 +155,14 @@ class ServiceWithNonConnectable: injected = injector.inject(ServiceWithNonConnectable)() assert isinstance(injected.db, NonConnectableDatabase) + + +def test_injector_pydantic_metaclass_doesnt_break(injector: DependencyInjector) -> None: + from pydantic import BaseModel + + class PydanticClass(BaseModel, Connectable): + var: int = 1 + + injected = injector.inject(PydanticClass)() + assert isinstance(injected, PydanticClass) + assert injected.var == 1