From 522e43a39e48fbd280a2568c198125c89bf9f508 Mon Sep 17 00:00:00 2001 From: Jamie Stumme <3059647+StummeJ@users.noreply.github.com> Date: Mon, 26 Feb 2024 16:23:45 -0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20allow=20factory=20method?= =?UTF-8?q?=20to=20get=20requested=20type=20as=20well=20as=20activating=20?= =?UTF-8?q?type?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rodi/__init__.py | 33 +++++++++++++++------ tests/test_services.py | 65 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/rodi/__init__.py b/rodi/__init__.py index 0fc539e..c784c7c 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -336,9 +336,9 @@ def __init__(self, _type, factory): self._type = _type self.factory = factory - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: type): assert isinstance(context, ActivationScope) - return self.factory(context, parent_type) + return self.factory(context, parent_type, self._type) class SingletonFactoryTypeProvider: @@ -349,9 +349,9 @@ def __init__(self, _type, factory): self.factory = factory self.instance = None - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: Type): if self.instance is None: - self.instance = self.factory(context, parent_type) + self.instance = self.factory(context, parent_type, self._type) return self.instance @@ -362,11 +362,11 @@ def __init__(self, _type, factory): self._type = _type self.factory = factory - def __call__(self, context: ActivationScope, parent_type): + def __call__(self, context: ActivationScope, parent_type: Type): if self._type in context.scoped_services: return context.scoped_services[self._type] - instance = self.factory(context, parent_type) + instance = self.factory(context, parent_type, self._type) context.scoped_services[self._type] = instance return instance @@ -412,7 +412,7 @@ def get_annotations_type_provider( life_style: ServiceLifeStyle, resolver_context: ResolutionContext, ): - def factory(context, parent_type): + def factory(context, parent_type, requested_type): instance = concrete_type() for name, resolver in resolvers.items(): setattr(instance, name, resolver(context, parent_type)) @@ -784,10 +784,12 @@ def exec( FactoryCallableNoArguments = Callable[[], Any] FactoryCallableSingleArgument = Callable[[ActivationScope], Any] FactoryCallableTwoArguments = Callable[[ActivationScope, Type], Any] +FactoryCallableThreeArguments = Callable[[ActivationScope, Type, Type], Any] FactoryCallableType = Union[ FactoryCallableNoArguments, FactoryCallableSingleArgument, FactoryCallableTwoArguments, + FactoryCallableThreeArguments, ] @@ -797,7 +799,7 @@ class FactoryWrapperNoArgs: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type): + def __call__(self, context, activating_type, parent_type): return self.factory() @@ -807,10 +809,20 @@ class FactoryWrapperContextArg: def __init__(self, factory): self.factory = factory - def __call__(self, context, activating_type): + def __call__(self, context, activating_type, parent_type): return self.factory(context) +class FactoryWrapperPartentArg: + __slots__ = ("factory",) + + def __init__(self, factory): + self.factory = factory + + def __call__(self, context, activating_type, parent_type): + return self.factory(context, activating_type) + + class Container(ContainerProtocol): """ Configuration class for a collection of services. @@ -1097,6 +1109,9 @@ def _check_factory(factory, signature, handled_type) -> Callable: return FactoryWrapperContextArg(factory) if params_len == 2: + return FactoryWrapperPartentArg(factory) + + if params_len == 3: return factory raise InvalidFactory(handled_type) diff --git a/tests/test_services.py b/tests/test_services.py index 866eae0..ee9daed 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -754,19 +754,21 @@ def test_invalid_factory_too_many_arguments_throws(method_name): container = Container() method = getattr(container, method_name) - def factory(context, activating_type, extra_argument_mistake): + def factory(context, activating_type, requested_type, extra_argument_mistake): return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) - def factory(context, activating_type, extra_argument_mistake, two): + def factory(context, activating_type, requested_type, extra_argument_mistake, two): return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) - def factory(context, activating_type, extra_argument_mistake, two, three): + def factory( + context, activating_type, requested_type, extra_argument_mistake, two, three + ): return Cat("Celine") with raises(InvalidFactory): @@ -949,6 +951,15 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca return Cat("Celine") +def cat_factory_with_context_activating_type_and_requested_type( + context, activating_type, requested_type +) -> Cat: + assert isinstance(context, ActivationScope) + assert activating_type is Cat + assert requested_type is Cat + return Cat("Celine") + + @pytest.mark.parametrize( "method_name,factory", [ @@ -962,6 +973,7 @@ def cat_factory_with_context_and_activating_type(context, activating_type) -> Ca cat_factory_no_args, cat_factory_with_context, cat_factory_with_context_and_activating_type, + cat_factory_with_context_activating_type_and_requested_type, ] ], ) @@ -1156,6 +1168,53 @@ def factory(_, activating_type) -> Logger: ) +@pytest.mark.parametrize( + "method_name", ["add_transient_by_factory", "add_scoped_by_factory"] +) +def test_factory_can_receive_requested_type_as_parameter(method_name): + class Db: + def __init__(self, activating, requested): + self.activating = activating + self.requested = requested + + class Fetcher: + def __init__(self, db: Db): + self.db = db + + container = Container() + container._add_exact_transient(Foo) + + def factory(self, activating_type, requested_type) -> Db: + return Db( + activating_type.__module__ + "." + activating_type.__name__, + requested_type.__module__ + "." + requested_type.__name__, + ) + + method = getattr(container, method_name) + method(factory, Db) + + container._add_exact_transient(Fetcher) + + provider = container.build_provider() + + db = provider.get(Db) + + assert db is not None + assert db.activating is not None + assert db.activating == "tests.test_services.Db" + assert db.requested is not None + assert db.requested == "tests.test_services.Db" + + fetcher = provider.get(Fetcher) + + assert fetcher is not None + assert fetcher.db is not None + assert fetcher.db.activating is not None + assert fetcher.db.activating == "tests.test_services.Fetcher" + assert fetcher.db.requested is not None + assert fetcher.db.requested == "tests.test_services.Db" + + def test_service_provider_supports_set_by_class(): provider = Services()