Skip to content

Commit

Permalink
🐛 fix: allow factory method to get requested type as well as activati…
Browse files Browse the repository at this point in the history
…ng type
  • Loading branch information
StummeJ committed Feb 27, 2024
1 parent 95edd65 commit 522e43a
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 12 deletions.
33 changes: 24 additions & 9 deletions rodi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,9 @@ def __init__(self, _type, factory):
self._type = _type
self.factory = factory

def __call__(self, context: ActivationScope, parent_type):
def __call__(self, context: ActivationScope, parent_type: type):
assert isinstance(context, ActivationScope)
return self.factory(context, parent_type)
return self.factory(context, parent_type, self._type)


class SingletonFactoryTypeProvider:
Expand All @@ -349,9 +349,9 @@ def __init__(self, _type, factory):
self.factory = factory
self.instance = None

def __call__(self, context: ActivationScope, parent_type):
def __call__(self, context: ActivationScope, parent_type: Type):
if self.instance is None:
self.instance = self.factory(context, parent_type)
self.instance = self.factory(context, parent_type, self._type)
return self.instance


Expand All @@ -362,11 +362,11 @@ def __init__(self, _type, factory):
self._type = _type
self.factory = factory

def __call__(self, context: ActivationScope, parent_type):
def __call__(self, context: ActivationScope, parent_type: Type):
if self._type in context.scoped_services:
return context.scoped_services[self._type]

instance = self.factory(context, parent_type)
instance = self.factory(context, parent_type, self._type)
context.scoped_services[self._type] = instance
return instance

Expand Down Expand Up @@ -412,7 +412,7 @@ def get_annotations_type_provider(
life_style: ServiceLifeStyle,
resolver_context: ResolutionContext,
):
def factory(context, parent_type):
def factory(context, parent_type, requested_type):
instance = concrete_type()
for name, resolver in resolvers.items():
setattr(instance, name, resolver(context, parent_type))
Expand Down Expand Up @@ -784,10 +784,12 @@ def exec(
FactoryCallableNoArguments = Callable[[], Any]
FactoryCallableSingleArgument = Callable[[ActivationScope], Any]
FactoryCallableTwoArguments = Callable[[ActivationScope, Type], Any]
FactoryCallableThreeArguments = Callable[[ActivationScope, Type, Type], Any]
FactoryCallableType = Union[
FactoryCallableNoArguments,
FactoryCallableSingleArgument,
FactoryCallableTwoArguments,
FactoryCallableThreeArguments,
]


Expand All @@ -797,7 +799,7 @@ class FactoryWrapperNoArgs:
def __init__(self, factory):
self.factory = factory

def __call__(self, context, activating_type):
def __call__(self, context, activating_type, parent_type):
return self.factory()


Expand All @@ -807,10 +809,20 @@ class FactoryWrapperContextArg:
def __init__(self, factory):
self.factory = factory

def __call__(self, context, activating_type):
def __call__(self, context, activating_type, parent_type):
return self.factory(context)


class FactoryWrapperPartentArg:
__slots__ = ("factory",)

def __init__(self, factory):
self.factory = factory

def __call__(self, context, activating_type, parent_type):
return self.factory(context, activating_type)


class Container(ContainerProtocol):
"""
Configuration class for a collection of services.
Expand Down Expand Up @@ -1097,6 +1109,9 @@ def _check_factory(factory, signature, handled_type) -> Callable:
return FactoryWrapperContextArg(factory)

if params_len == 2:
return FactoryWrapperPartentArg(factory)

if params_len == 3:
return factory

raise InvalidFactory(handled_type)
Expand Down
65 changes: 62 additions & 3 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,19 +754,21 @@ def test_invalid_factory_too_many_arguments_throws(method_name):
container = Container()
method = getattr(container, method_name)

def factory(context, activating_type, extra_argument_mistake):
def factory(context, activating_type, requested_type, extra_argument_mistake):
return Cat("Celine")

with raises(InvalidFactory):
method(factory, Cat)

def factory(context, activating_type, extra_argument_mistake, two):
def factory(context, activating_type, requested_type, extra_argument_mistake, two):
return Cat("Celine")

with raises(InvalidFactory):
method(factory, Cat)

def factory(context, activating_type, extra_argument_mistake, two, three):
def factory(
context, activating_type, requested_type, extra_argument_mistake, two, three
):
return Cat("Celine")

with raises(InvalidFactory):
Expand Down Expand Up @@ -949,6 +951,15 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca
return Cat("Celine")


def cat_factory_with_context_activating_type_and_requested_type(
context, activating_type, requested_type
) -> Cat:
assert isinstance(context, ActivationScope)
assert activating_type is Cat
assert requested_type is Cat
return Cat("Celine")


@pytest.mark.parametrize(
"method_name,factory",
[
Expand All @@ -962,6 +973,7 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca
cat_factory_no_args,
cat_factory_with_context,
cat_factory_with_context_and_activating_type,
cat_factory_with_context_activating_type_and_requested_type,
]
],
)
Expand Down Expand Up @@ -1156,6 +1168,53 @@ def factory(_, activating_type) -> Logger:
)


@pytest.mark.parametrize(
"method_name", ["add_transient_by_factory", "add_scoped_by_factory"]
)
def test_factory_can_receive_requested_type_as_parameter(method_name):
class Db:
def __init__(self, activating, requested):
self.activating = activating
self.requested = requested

class Fetcher:
def __init__(self, db: Db):
self.db = db

container = Container()
container._add_exact_transient(Foo)

def factory(self, activating_type, requested_type) -> Db:
return Db(
activating_type.__module__ + "." + activating_type.__name__,
requested_type.__module__ + "." + requested_type.__name__,
)

method = getattr(container, method_name)
method(factory, Db)

container._add_exact_transient(Fetcher)

provider = container.build_provider()

db = provider.get(Db)

assert db is not None
assert db.activating is not None
assert db.activating == "tests.test_services.Db"
assert db.requested is not None
assert db.requested == "tests.test_services.Db"

fetcher = provider.get(Fetcher)

assert fetcher is not None
assert fetcher.db is not None
assert fetcher.db.activating is not None
assert fetcher.db.activating == "tests.test_services.Fetcher"
assert fetcher.db.requested is not None
assert fetcher.db.requested == "tests.test_services.Db"


def test_service_provider_supports_set_by_class():
provider = Services()

Expand Down

0 comments on commit 522e43a

Please sign in to comment.