Skip to content

Commit

Permalink
use globals of the function; don't support local ctx; add tests in se…
Browse files Browse the repository at this point in the history
…parated modules.
  • Loading branch information
nrbnlulu committed May 9, 2024
1 parent 2896b63 commit 692ffb0
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 62 deletions.
16 changes: 8 additions & 8 deletions aioinject/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import inspect
import sys
import typing
from collections.abc import Callable, Iterator
from contextlib import (
Expand Down Expand Up @@ -105,13 +106,12 @@ def _get_type_hints(
return typing.get_type_hints(obj, include_extras=True, localns=context)


def get_fn_ns(fn: Callable[..., Any]) -> dict[str, Any]:
return getattr(sys.modules.get(fn.__module__, None), "__dict__", {})


def get_return_annotation(
ret_annotation: str,
globals_: dict[str, Any],
locals_: dict[str, Any],
) -> Any:
return eval( # noqa: S307
ret_annotation,
globals_,
locals_,
)
context: dict[str, Any],
) -> type[Any]:
return eval(ret_annotation, context) # noqa: S307
11 changes: 4 additions & 7 deletions aioinject/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from aioinject._utils import (
_get_type_hints,
get_fn_ns,
get_return_annotation,
is_context_manager_function,
remove_annotation,
Expand Down Expand Up @@ -165,16 +166,12 @@ def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]: # noqa: C901
# functions might have dependecies in them
# and we don't have the container context here so
# we can't call _get_type_hints
init_frame = (
inspect.currentframe().f_back.f_back # type: ignore[union-attr] # pyright: ignore [reportArgumentType, reportOptionalMemberAccess]
)
assert init_frame # noqa: S101
ret_annotation = unwrapped.__annotations__["return"]

try:
ret_annotation = unwrapped.__annotations__["return"]
return_type = get_return_annotation(
ret_annotation,
globals_=init_frame.f_globals,
locals_=init_frame.f_locals,
context=get_fn_ns(unwrapped),
)
except NameError as e:
msg = f"Factory {factory.__qualname__} does not specify return type. Or it's type is not defined yet."
Expand Down
Empty file.
18 changes: 18 additions & 0 deletions tests/container/mod_tests/provider_fn_deffered_dep_missuse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from __future__ import annotations

from aioinject.containers import Container
from aioinject.providers import Singleton


cont = Container()


# notice that C is not defined yet
def get_c() -> C:
return C()


cont.register(Singleton(get_c))


class C: ...
31 changes: 31 additions & 0 deletions tests/container/mod_tests/provider_fn_with_deffered_dep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from dataclasses import dataclass

from aioinject.containers import Container
from aioinject.providers import Singleton


@dataclass
class B:
a: A


cont = Container()


def get_b(a: A) -> B:
return B(a)


cont.register(Singleton(get_b))


class A: ...


cont.register(Singleton(A))
with cont.sync_context() as ctx:
a = ctx.resolve(A)
b = ctx.resolve(B)
assert b.a is a
52 changes: 5 additions & 47 deletions tests/container/test_future_annotations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import pytest

from aioinject.containers import Container
from aioinject.providers import Scoped, Singleton
from aioinject.providers import Scoped


async def test_deffered_dependecies() -> None:
Expand Down Expand Up @@ -38,54 +37,13 @@ def register_decimal_scoped() -> None:


def test_provider_fn_with_deffered_dep() -> None:
@dataclass
class B:
a: A

cont = Container()

def get_b(a: A) -> B:
return B(a)

cont.register(Singleton(get_b))

class A: ...

cont.register(Singleton(A))
with cont.sync_context() as ctx:
a = ctx.resolve(A)
b = ctx.resolve(B)
assert b.a is a


class C: ...


def test_provider_fn_deffered_dep_globals() -> None:

def get_c(_: D) -> C:
return C()

class D: ...

cont = Container()
cont.register(Singleton(D))
cont.register(Singleton(get_c))
with cont.sync_context() as ctx:
_ = ctx.resolve(D)
_ = ctx.resolve(C)
pass


def test_provider_fn_deffered_dep_missuse() -> None:
cont = Container()

def get_a() -> A:
return A()

# notice that B is not defined yet

with pytest.raises(ValueError) as exc_info: # noqa: PT011
cont.register(Singleton(get_a))
from tests.container.mod_tests import (
provider_fn_deffered_dep_missuse, # noqa: F401
)
assert exc_info.match("Or it's type is not defined yet.")

class A: ...

0 comments on commit 692ffb0

Please sign in to comment.