Skip to content

Commit

Permalink
fix pydantic support
Browse files Browse the repository at this point in the history
  • Loading branch information
RB387 committed Mar 27, 2024
1 parent f2894cc commit 6a04174
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 19 deletions.
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pytest-github-actions-annotate-failures = "*"
pytest-cov = "*"
python-kacl = "*"
ruff = ">=0.2.0"
pydantic = "*"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
15 changes: 3 additions & 12 deletions src/magic_di/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,24 +95,15 @@ def new(_: Any) -> T:
# >>> injected_redis() # works
# >>> injected_redis(timeout=10) # doesn't work
#
# Here we manually create a new class using the `type` metaclass
# to prevent possible overrides of it.
# We copy all object attributes (__dict__)
# so that upon inspection,
# the class should look exactly like the wrapped class.
# However, we override the __new__ method
# to return an instance of the original class.
# Here we manually create a new singleton class factory using the `type` metaclass
# Since the original class was not modified, it will use its own metaclass.
return functools.wraps(
obj,
updated=(),
)(
type(
obj.__name__,
(obj,),
{
**obj.__dict__,
"__new__": new,
},
(),
{"__new__": new},
),
)
4 changes: 2 additions & 2 deletions src/magic_di/_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def connect(self) -> None:
self.inject(postponed)

for cls, instance in self._deps.iter_instances():
if is_injectable(cls):
if is_injectable(instance):
self.logger.debug("Connecting %s...", cls.__name__)
await instance.__connect__()

Expand All @@ -159,7 +159,7 @@ async def disconnect(self) -> None:
Disconnect all injected dependencies
"""
for cls, instance in self._deps.iter_instances(reverse=True):
if is_injectable(cls):
if is_injectable(instance):
try:
await instance.__disconnect__()
except Exception:
Expand Down
9 changes: 8 additions & 1 deletion src/magic_di/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ def safe_is_subclass(sub_cls: Any, cls: type) -> bool:
return False


def safe_is_instance(sub_cls: Any, cls: type) -> bool:
try:
return isinstance(sub_cls, cls)
except TypeError:
return False


def is_injectable(cls: Any) -> bool:
"""
Check if a class is a subclass of ConnectableProtocol.
Expand All @@ -76,7 +83,7 @@ def is_injectable(cls: Any) -> bool:
Returns:
bool: True if the class is a subclass of ConnectableProtocol, False otherwise.
"""
return safe_is_subclass(cls, ConnectableProtocol)
return safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(cls, ConnectableProtocol)


def get_type_hints(obj: Any, *, include_extras: bool = False) -> dict[str, type]:
Expand Down
11 changes: 11 additions & 0 deletions tests/test_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,14 @@ class ServiceWithNonConnectable:

injected = injector.inject(ServiceWithNonConnectable)()
assert isinstance(injected.db, NonConnectableDatabase)


def test_injector_pydantic_metaclass_doesnt_break(injector: DependencyInjector) -> None:
from pydantic import BaseModel

class PydanticClass(BaseModel, Connectable):
var: int = 1

injected = injector.inject(PydanticClass)()
assert isinstance(injected, PydanticClass)
assert injected.var == 1

0 comments on commit 6a04174

Please sign in to comment.