Skip to content

Commit

Permalink
fix: singleton reinstantiated if registered at last
Browse files Browse the repository at this point in the history
changed service provider build logic in order to
avoid singletons to be instantiated n times when
they're registered after its dependant
  • Loading branch information
lucas-labs committed Nov 11, 2023
1 parent de16370 commit 0633a7d
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 7 deletions.
26 changes: 20 additions & 6 deletions rodi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,12 @@ def _get_resolver(self, desired_type, context: ResolutionContext):
assert (
reg is not None
), f"A resolver for type {class_name(desired_type)} is not configured"
return reg(context)
resolver = reg(context)

# add the resolver to the context, so we can find it
# next time we need it
context.resolved[desired_type] = resolver
return resolver

def _get_resolvers_for_parameters(
self,
Expand Down Expand Up @@ -1146,14 +1151,23 @@ def build_provider(self) -> Services:
_map: Dict[Union[str, Type], Type] = {}

for _type, resolver in self._map.items():
# NB: do not call resolver if one was already prepared for the type
assert _type not in context.resolved, "_map keys must be unique"

if isinstance(resolver, DynamicResolver):
context.dynamic_chain.clear()

_map[_type] = resolver(context)
context.resolved[_type] = _map[_type]
if _type in context.resolved:
# assert _type not in context.resolved, "_map keys must be unique"
# check if its in the map
if _type in _map:
# NB: do not call resolver if one was already prepared for the type
raise OverridingServiceException(_type, resolver)
else:
resolved = context.resolved[_type]
else:
# add to context so that we don't repeat operations
resolved = resolver(context)
context.resolved[_type] = resolved

_map[_type] = resolved

type_name = class_name(_type)
if "." not in type_name:
Expand Down
7 changes: 6 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

@pytest.mark.parametrize("file_path", examples)
def test_example(file_path: str):
module_name = file_path.replace("./examples/", "").replace(".py", "")
module_name = (
# Windows
file_path.replace("./examples\\", "")
# Unix
.replace("./examples/", "").replace(".py", "")
)
# assertions are in imported modules
importlib.import_module(module_name)
52 changes: 52 additions & 0 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2539,3 +2539,55 @@ class A:
a = container.resolve(A)

assert a.foo == "foo"


def test_singleton_register_order_last():
"""
The registration order of singletons should not matter.
Check that singletons are not registered twice when they are registered
after their dependents.
"""

class Bar:
foo: Foo

class Bar2:
foo: Foo

container = Container()
container.register(Bar)
container.register(Bar2)
container._add_exact_singleton(Foo)

bar = container.resolve(Bar)
bar2 = container.resolve(Bar2)
foo = container.resolve(Foo)

# check that singletons are always the same instance
assert bar.foo is bar2.foo is foo


def test_singleton_register_order_first():
"""
The registration order of singletons should not matter.
Check that singletons are not registered twice when they are registered
before their dependents.
"""

class Bar:
foo: Foo

class Bar2:
foo: Foo

container = Container()
container._add_exact_singleton(Foo)
container.register(Bar)
container.register(Bar2)

bar = container.resolve(Bar)
bar2 = container.resolve(Bar2)
foo = container.resolve(Foo)

# check that singletons are always the same instance
assert bar.foo is bar2.foo is foo

0 comments on commit 0633a7d

Please sign in to comment.