From 0633a7df7f05c2f6d4f0e11b2299507668563806 Mon Sep 17 00:00:00 2001 From: Lucas Colombo Date: Sat, 11 Nov 2023 03:00:04 -0300 Subject: [PATCH] fix: singleton reinstantiated if registered at last changed service provider build logic in order to avoid singletons to be instantiated n times when they're registered after its dependant --- rodi/__init__.py | 26 ++++++++++++++++----- tests/test_examples.py | 7 +++++- tests/test_services.py | 52 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 7 deletions(-) diff --git a/rodi/__init__.py b/rodi/__init__.py index 21f4f0e..11c35b7 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -475,7 +475,12 @@ def _get_resolver(self, desired_type, context: ResolutionContext): assert ( reg is not None ), f"A resolver for type {class_name(desired_type)} is not configured" - return reg(context) + resolver = reg(context) + + # add the resolver to the context, so we can find it + # next time we need it + context.resolved[desired_type] = resolver + return resolver def _get_resolvers_for_parameters( self, @@ -1146,14 +1151,23 @@ def build_provider(self) -> Services: _map: Dict[Union[str, Type], Type] = {} for _type, resolver in self._map.items(): - # NB: do not call resolver if one was already prepared for the type - assert _type not in context.resolved, "_map keys must be unique" - if isinstance(resolver, DynamicResolver): context.dynamic_chain.clear() - _map[_type] = resolver(context) - context.resolved[_type] = _map[_type] + if _type in context.resolved: + # assert _type not in context.resolved, "_map keys must be unique" + # check if its in the map + if _type in _map: + # NB: do not call resolver if one was already prepared for the type + raise OverridingServiceException(_type, resolver) + else: + resolved = context.resolved[_type] + else: + # add to context so that we don't repeat operations + resolved = resolver(context) + context.resolved[_type] = resolved + + _map[_type] = resolved type_name = class_name(_type) if "." not in type_name: diff --git a/tests/test_examples.py b/tests/test_examples.py index acaae27..0a7b3c6 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -12,6 +12,11 @@ @pytest.mark.parametrize("file_path", examples) def test_example(file_path: str): - module_name = file_path.replace("./examples/", "").replace(".py", "") + module_name = ( + # Windows + file_path.replace("./examples\\", "") + # Unix + .replace("./examples/", "").replace(".py", "") + ) # assertions are in imported modules importlib.import_module(module_name) diff --git a/tests/test_services.py b/tests/test_services.py index f93fc69..b9bc1aa 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -2539,3 +2539,55 @@ class A: a = container.resolve(A) assert a.foo == "foo" + + +def test_singleton_register_order_last(): + """ + The registration order of singletons should not matter. + Check that singletons are not registered twice when they are registered + after their dependents. + """ + + class Bar: + foo: Foo + + class Bar2: + foo: Foo + + container = Container() + container.register(Bar) + container.register(Bar2) + container._add_exact_singleton(Foo) + + bar = container.resolve(Bar) + bar2 = container.resolve(Bar2) + foo = container.resolve(Foo) + + # check that singletons are always the same instance + assert bar.foo is bar2.foo is foo + + +def test_singleton_register_order_first(): + """ + The registration order of singletons should not matter. + Check that singletons are not registered twice when they are registered + before their dependents. + """ + + class Bar: + foo: Foo + + class Bar2: + foo: Foo + + container = Container() + container._add_exact_singleton(Foo) + container.register(Bar) + container.register(Bar2) + + bar = container.resolve(Bar) + bar2 = container.resolve(Bar2) + foo = container.resolve(Foo) + + # check that singletons are always the same instance + assert bar.foo is bar2.foo is foo