From 1de2a10f27a16250b0c9453372e4827c4ba25abc Mon Sep 17 00:00:00 2001 From: Roberto Prevato Date: Sun, 31 Jan 2021 12:24:31 +0100 Subject: [PATCH] v1.1.0 :grapes: --- CHANGELOG.md | 5 + README.md | 62 +++- requirements.txt | 1 + rodi/__init__.py | 237 +++++++++----- setup.py | 4 +- tests/examples.py | 36 +-- tests/test_services.py | 720 ++++++++++++++++++++++++++++------------- 7 files changed, 719 insertions(+), 346 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e75c3cf..ddcebbc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [1.0.9] - 2021-01-31 :grapes: +- Adds support to resolve class attributes annotations +- Changes how classes without an `__init__` method are handled +- Updates links to the GitHub organization, [Neoteroi](https://github.com/Neoteroi) + ## [1.0.9] - 2020-11-08 :octocat: - Completely migrates to GitHub Workflows - Improves build to test Python 3.6 and 3.9 diff --git a/README.md b/README.md index 95e83e0..1fb9d80 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,28 @@ -![Build](https://github.com/RobertoPrevato/rodi/workflows/Build/badge.svg) +![Build](https://github.com/Neoteroi/rodi/workflows/Build/badge.svg) [![pypi](https://img.shields.io/pypi/v/rodi.svg)](https://pypi.python.org/pypi/rodi) -[![versions](https://img.shields.io/pypi/pyversions/rodi.svg)](https://github.com/RobertoPrevato/rodi) -[![codecov](https://codecov.io/gh/RobertoPrevato/rodi/branch/master/graph/badge.svg?token=VzAnusWIZt)](https://codecov.io/gh/RobertoPrevato/rodi) -[![license](https://img.shields.io/github/license/RobertoPrevato/rodi.svg)](https://github.com/RobertoPrevato/rodi/blob/master/LICENSE) +[![versions](https://img.shields.io/pypi/pyversions/rodi.svg)](https://github.com/Neoteroi/rodi) +[![codecov](https://codecov.io/gh/Neoteroi/rodi/branch/master/graph/badge.svg?token=VzAnusWIZt)](https://codecov.io/gh/Neoteroi/rodi) +[![license](https://img.shields.io/github/license/Neoteroi/rodi.svg)](https://github.com/Neoteroi/rodi/blob/master/LICENSE) # Implementation of dependency injection for Python 3 **Features:** * types resolution by signature types annotations (_type hints_) -* types resolution by constructor parameter names and aliases (_convention over configuration_) -* unintrusive: builds objects graph **without** the need to change classes source code (unlike some other Python implementations of dependency injection, like _[inject](https://pypi.org/project/Inject/)_) +* types resolution by class annotations (_type hints_) (new in version 1.1.0 :star:) +* types resolution by constructor parameter names and aliases (_convention over + configuration_) +* unintrusive: builds objects graph **without** the need to change classes + source code (unlike some other Python implementations of dependency + injection, like _[inject](https://pypi.org/project/Inject/)_) * minimum overhead to obtain services, once the objects graph is built -* support for singletons, transient and scoped services +* support for singletons, transient, and scoped services -This library is freely inspired by .NET Standard `Microsoft.Extensions.DependencyInjection` implementation (_ref. [MSDN, Dependency injection in ASP.NET Core](https://docs.microsoft.com/en-us/aspnet/core/fundamentals/dependency-injection?view=aspnetcore-2.1), [Using dependency injection in a .Net Core console application](https://andrewlock.net/using-dependency-injection-in-a-net-core-console-application/)_). +This library is freely inspired by .NET Standard +`Microsoft.Extensions.DependencyInjection` implementation (_ref. [MSDN, +Dependency injection in ASP.NET +Core](https://docs.microsoft.com/en-us/aspnet/core/fundamentals/dependency-injection?view=aspnetcore-2.1), +[Using dependency injection in a .Net Core console +application](https://andrewlock.net/using-dependency-injection-in-a-net-core-console-application/)_). ## Installation @@ -22,16 +31,30 @@ pip install rodi ``` ## De gustibus non disputandum est -This library is designed to work both by type hints in constructor signatures, and by constructor parameter names (convention over configuration), like described in below diagram. It can be useful for those who like type hinting, those who don't, and those who like having both options. +This library is designed to work both by type hints in constructor signatures, +and by constructor parameter names (convention over configuration), like +described in below diagram. It can be useful for those who like type hinting, +those who don't, and those who like having both options. -![Usage option](https://raw.githubusercontent.com/RobertoPrevato/rodi/master/documentation/rodi-design-taste.png "Usage option") +![Usage +option](https://raw.githubusercontent.com/Neoteroi/rodi/master/documentation/rodi-design-taste.png +"Usage option") ## Minimum overhead -`rodi` works by inspecting ____init____ methods **once** at runtime, to generate functions that return instances of desired types. Validation steps, for example to detect circular dependencies or missing services, are done when building these functions, so additional validation is not needed when activating services. +`rodi` works by inspecting ____init____ methods **once** at +runtime, to generate functions that return instances of desired types. +Validation steps, for example to detect circular dependencies or missing +services, are done when building these functions, so additional validation is +not needed when activating services. -For this reason, services are first registered inside an instance of _Container_ class, which implements a method _build_provider()_ that returns an instance of _Services_. The service provider is then used to obtain desired services by type or name. Inspection and validation steps are done only when creating an instance of service provider. +For this reason, services are first registered inside an instance of +_Container_ class, which implements a method _build_provider()_ that +returns an instance of _Services_. The service provider is then used to obtain +desired services by type or name. Inspection and validation steps are done only +when creating an instance of service provider. -![Classes](https://raw.githubusercontent.com/RobertoPrevato/rodi/master/documentation/classes.png "Classes") +![Classes](https://raw.githubusercontent.com/Neoteroi/rodi/master/documentation/classes.png +"Classes") In the example below, a singleton is registered by exact instance. @@ -50,7 +73,16 @@ assert cat.name == "Celine" ## Service life style: * singleton - instantiated only once per service provider * transient - services are instantiated every time they are required -* scoped - instantiated only once per service resolution (root call, e.g. once per web request) +* scoped - instantiated only once per service resolution (root call, e.g. once + per web request) + +## Usage in BlackSheep +`rodi` is used in the [BlackSheep](https://www.neoteroi.dev/blacksheep/) web +framework to implement [dependency +injection](https://www.neoteroi.dev/blacksheep/dependency-injection/) for +request handlers. Refer to the framework's documentation for more information +on how dependency injection is # Documentation -For documentation and examples, please refer to the [wiki in GitHub, https://github.com/RobertoPrevato/rodi/wiki](https://github.com/RobertoPrevato/rodi/wiki). +For documentation and examples, please refer to the [wiki in GitHub, +https://github.com/Neoteroi/rodi/wiki](https://github.com/Neoteroi/rodi/wiki). diff --git a/requirements.txt b/requirements.txt index f531a63..4d4ef45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ pytest-cov pytest-asyncio flake8 mypy +dataclasses==0.8; python_version < '3.7' diff --git a/rodi/__init__.py b/rodi/__init__.py index 9d54119..cbf4c2f 100644 --- a/rodi/__init__.py +++ b/rodi/__init__.py @@ -2,7 +2,16 @@ from collections import defaultdict from enum import Enum from inspect import Signature, _empty, isabstract, isclass, iscoroutinefunction -from typing import Any, Callable, Dict, Optional, Type, Union +from typing import ( + Any, + Callable, + Dict, + Mapping, + Optional, + Type, + Union, + get_type_hints, +) AliasesTypeHint = Dict[str, Union[Type, str]] @@ -11,19 +20,9 @@ class DIException(Exception): """Base exception class for DI exceptions.""" -class ClassNotDefiningInitMethod(DIException): - """Exception risen when a registered class does - not implement a specific __init__ method""" - - def __init__(self, desired_type): - super().__init__( - f"Cannot activate an instance of '{desired_type.__name__}' " - f"because it does not implement a specific __init__ method." - ) - - class CannotResolveParameterException(DIException): - """Exception risen when it is not possible to resolve a parameter, + """ + Exception risen when it is not possible to resolve a parameter, necessary to instantiate a type.""" def __init__(self, param_name, desired_type): @@ -46,7 +45,8 @@ def __init__(self, param_name, desired_type): class OverridingServiceException(DIException): - """Exception risen when registering a service + """ + Exception risen when registering a service would override an existing one.""" def __init__(self, key, value): @@ -221,7 +221,8 @@ def __init__(self, _type, factory): self.factory = factory def __call__(self, context: GetServiceContext, parent_type): - return self.factory(context.provider, parent_type) + assert isinstance(context, GetServiceContext) + return self.factory(context, parent_type) class SingletonFactoryTypeProvider: @@ -234,7 +235,7 @@ def __init__(self, _type, factory): def __call__(self, context: GetServiceContext, parent_type): if self.instance is None: - self.instance = self.factory(context.provider, parent_type) + self.instance = self.factory(context, parent_type) return self.instance @@ -249,7 +250,7 @@ def __call__(self, context: GetServiceContext, parent_type): if self._type in context.scoped_services: return context.scoped_services[self._type] - instance = self.factory(context.provider, parent_type) + instance = self.factory(context, parent_type) context.scoped_services[self._type] = instance return instance @@ -289,6 +290,28 @@ def __call__(self, context, parent_type): return self._instance +def get_annotations_type_provider( + concrete_type: Type, + resolvers: Mapping[str, Callable], + life_style: ServiceLifeStyle, + resolver_context: ResolveContext, +): + def factory(context, parent_type): + instance = concrete_type() + for name, resolver in resolvers.items(): + setattr(instance, name, resolver(context, parent_type)) + return instance + + return FactoryResolver(concrete_type, factory, life_style)(resolver_context) + + +def _get_plain_class_factory(concrete_type: Type): + def factory(*args): + return concrete_type() + + return factory + + class InstanceResolver: __slots__ = ("instance",) @@ -302,18 +325,28 @@ def __call__(self, context: ResolveContext): return InstanceProvider(self.instance) +class Dependency: + __slots__ = ("name", "annotation") + + def __init__(self, name, annotation): + self.name = name + self.annotation = annotation + + class DynamicResolver: - __slots__ = ("concrete_type", "params", "services", "lifestyle") + __slots__ = ("_concrete_type", "services", "life_style") - def __init__(self, concrete_type, services, lifestyle): + def __init__(self, concrete_type, services, life_style): assert isclass(concrete_type) assert not isabstract(concrete_type) - self.concrete_type = concrete_type + self._concrete_type = concrete_type self.services = services - sig = Signature.from_callable(concrete_type.__init__) - self.params = sig.parameters - self.lifestyle = lifestyle + self.life_style = life_style + + @property + def concrete_type(self): + return self._concrete_type def _get_resolver(self, desired_type, context: ResolveContext): # NB: the following two lines are important to ensure that singletons @@ -328,29 +361,13 @@ def _get_resolver(self, desired_type, context: ResolveContext): ) return reg(context) - def __call__(self, context: ResolveContext): - concrete_type = self.concrete_type - + def _get_resolvers_for_parameters( + self, + concrete_type, + context: ResolveContext, + params: Mapping[str, Dependency], + ): chain = context.dynamic_chain - if concrete_type in chain: - raise CircularDependencyException(chain[0], concrete_type) - - chain.append(concrete_type) - - if getattr(concrete_type, "__init__") is object.__init__: - raise ClassNotDefiningInitMethod(concrete_type) - - params = self.params - - if len(params) == 1: - if self.lifestyle == ServiceLifeStyle.SINGLETON: - return SingletonTypeProvider(concrete_type, None) - - if self.lifestyle == ServiceLifeStyle.SCOPED: - return ScopedTypeProvider(concrete_type) - - return TypeProvider(concrete_type) - fns = [] services = self.services @@ -379,9 +396,9 @@ def __call__(self, context: ResolveContext): aliases = services._aliases[param_name] if aliases: - assert len(aliases) == 1, ( - "Configured aliases must " "not be ambiguous" - ) + assert ( + len(aliases) == 1 + ), "Configured aliases cannot be ambiguous" for param_type in aliases: break @@ -393,32 +410,91 @@ def __call__(self, context: ResolveContext): param_resolver = self._get_resolver(param_type, context) fns.append(param_resolver) + return fns + + def _resolve_by_init_method(self, context: ResolveContext): + sig = Signature.from_callable(self.concrete_type.__init__) + params = { + key: Dependency(key, value.annotation) + for key, value in sig.parameters.items() + } + concrete_type = self.concrete_type + + if len(params) == 1 and next(iter(params.keys())) == "self": + if self.life_style == ServiceLifeStyle.SINGLETON: + return SingletonTypeProvider(concrete_type, None) - if self.lifestyle == ServiceLifeStyle.SINGLETON: + if self.life_style == ServiceLifeStyle.SCOPED: + return ScopedTypeProvider(concrete_type) + + return TypeProvider(concrete_type) + + fns = self._get_resolvers_for_parameters(concrete_type, context, params) + + if self.life_style == ServiceLifeStyle.SINGLETON: return SingletonTypeProvider(concrete_type, fns) - if self.lifestyle == ServiceLifeStyle.SCOPED: + if self.life_style == ServiceLifeStyle.SCOPED: return ScopedArgsTypeProvider(concrete_type, fns) return ArgsTypeProvider(concrete_type, fns) + def _resolve_by_annotations( + self, context: ResolveContext, annotations: Dict[str, Type] + ): + params = {key: Dependency(key, value) for key, value in annotations.items()} + concrete_type = self.concrete_type + + fns = self._get_resolvers_for_parameters(concrete_type, context, params) + resolvers = {} + + i = 0 + for name in annotations.keys(): + resolvers[name] = fns[i] + i += 1 + + return get_annotations_type_provider( + self.concrete_type, resolvers, self.life_style, context + ) + + def __call__(self, context: ResolveContext): + concrete_type = self.concrete_type + + chain = context.dynamic_chain + if concrete_type in chain: + raise CircularDependencyException(chain[0], concrete_type) + + chain.append(concrete_type) + + if getattr(concrete_type, "__init__") is object.__init__: + annotations = get_type_hints(concrete_type) + + if annotations: + return self._resolve_by_annotations(context, annotations) + + return FactoryResolver( + concrete_type, _get_plain_class_factory(concrete_type), self.life_style + )(context) + + return self._resolve_by_init_method(context) + class FactoryResolver: - __slots__ = ("concrete_type", "factory", "params", "lifestyle") + __slots__ = ("concrete_type", "factory", "params", "life_style") - def __init__(self, concrete_type, factory, lifestyle): + def __init__(self, concrete_type, factory, life_style): assert isclass(concrete_type) assert not isabstract(concrete_type) self.factory = factory self.concrete_type = concrete_type - self.lifestyle = lifestyle + self.life_style = life_style def __call__(self, context: ResolveContext): - if self.lifestyle == ServiceLifeStyle.SINGLETON: + if self.life_style == ServiceLifeStyle.SINGLETON: return SingletonFactoryTypeProvider(self.concrete_type, self.factory) - if self.lifestyle == ServiceLifeStyle.SCOPED: + if self.life_style == ServiceLifeStyle.SCOPED: return ScopedFactoryTypeProvider(self.concrete_type, self.factory) return FactoryTypeProvider(self.concrete_type, self.factory) @@ -438,7 +514,7 @@ def to_standard_param_name(name): class Services: """ Provides methods to activate instances of classes, - by cached callback functions. + by cached activator functions. """ __slots__ = ("_map", "_executors") @@ -484,7 +560,8 @@ def get( desired_type: Union[Type, str], context: Optional[GetServiceContext] = None, ) -> Any: - """Gets a service of desired type, returning an activated instance. + """ + Gets a service of the desired type, returning an activated instance. :param desired_type: desired service type. :param context: optional context, used to handle scoped services. @@ -537,7 +614,11 @@ def executor(scoped: Optional[Dict[Type, Any]] = None): return executor - def exec(self, method: Callable, scoped: Optional[Dict[Type, Any]] = None,) -> Any: + def exec( + self, + method: Callable, + scoped: Optional[Dict[Type, Any]] = None, + ) -> Any: try: executor = self._executors[method] except KeyError: @@ -578,7 +659,8 @@ def __call__(self, context, activating_type): class Container: - """Configuration class for a collection of services.""" + """ + Configuration class for a collection of services.""" __slots__ = ("_map", "_aliases", "_exact_aliases", "strict") @@ -592,18 +674,19 @@ def __contains__(self, key): return key in self._map def register( - self, base_type: Type, concrete_type: Type, lifestyle: ServiceLifeStyle + self, base_type: Type, concrete_type: Type, life_style: ServiceLifeStyle ): assert issubclass(concrete_type, base_type), ( f"Cannot register {base_type.__name__} for abstract class " f"{concrete_type.__name__}" ) - self._bind(base_type, DynamicResolver(concrete_type, self, lifestyle)) + self._bind(base_type, DynamicResolver(concrete_type, self, life_style)) return self def add_alias(self, name: str, desired_type: Type): - """Adds an alias to the set of inferred aliases. + """ + Adds an alias to the set of inferred aliases. :param name: parameter name :param desired_type: desired type by parameter name @@ -617,7 +700,8 @@ def add_alias(self, name: str, desired_type: Type): return self def add_aliases(self, values: AliasesTypeHint): - """Adds aliases to the set of inferred aliases. + """ + Adds aliases to the set of inferred aliases. :param values: mapping object (parameter name: class) :return: self @@ -627,7 +711,8 @@ def add_aliases(self, values: AliasesTypeHint): return self def set_alias(self, name: str, desired_type: Type, override: bool = False): - """Sets an exact alias for a desired type. + """ + Sets an exact alias for a desired type. :param name: parameter name :param desired_type: desired type by parameter name @@ -667,7 +752,8 @@ def _bind(self, key: Type, value: Any): self._aliases[to_standard_param_name(key_name)].add(key) def add_instance(self, instance: Any, declared_class: Optional[Type] = None): - """Registers an exact instance, optionally by declared class. + """ + Registers an exact instance, optionally by declared class. :param instance: singleton to be registered :param declared_class: optionally, lets define the class used as reference of @@ -681,7 +767,8 @@ def add_instance(self, instance: Any, declared_class: Optional[Type] = None): return self def add_singleton(self, base_type: Type, concrete_type: Optional[Type] = None): - """Registers a type by base type, to be instantiated with singleton lifetime. + """ + Registers a type by base type, to be instantiated with singleton lifetime. If a single type is given, the method `add_exact_singleton` is used. :param base_type: registered type. If a concrete type is provided, it must @@ -695,7 +782,8 @@ def add_singleton(self, base_type: Type, concrete_type: Optional[Type] = None): return self.register(base_type, concrete_type, ServiceLifeStyle.SINGLETON) def add_scoped(self, base_type: Type, concrete_type: Optional[Type] = None): - """Registers a type by base type, to be instantiated with scoped lifetime. + """ + Registers a type by base type, to be instantiated with scoped lifetime. If a single type is given, the method `add_exact_scoped` is used. :param base_type: registered type. If a concrete type is provided, it must @@ -709,7 +797,8 @@ def add_scoped(self, base_type: Type, concrete_type: Optional[Type] = None): return self.register(base_type, concrete_type, ServiceLifeStyle.SCOPED) def add_transient(self, base_type: Type, concrete_type: Optional[Type] = None): - """Registers a type by base type, to be instantiated with transient lifetime. + """ + Registers a type by base type, to be instantiated with transient lifetime. If a single type is given, the method `add_exact_transient` is used. :param base_type: registered type. If a concrete type is provided, it must @@ -723,7 +812,8 @@ def add_transient(self, base_type: Type, concrete_type: Optional[Type] = None): return self.register(base_type, concrete_type, ServiceLifeStyle.TRANSIENT) def add_exact_singleton(self, concrete_type: Type): - """Registers an exact type, to be instantiated with singleton lifetime. + """ + Registers an exact type, to be instantiated with singleton lifetime. :param concrete_type: concrete class :return: the service collection itself @@ -736,7 +826,8 @@ def add_exact_singleton(self, concrete_type: Type): return self def add_exact_scoped(self, concrete_type: Type): - """Registers an exact type, to be instantiated with scoped lifetime. + """ + Registers an exact type, to be instantiated with scoped lifetime. :param concrete_type: concrete class :return: the service collection itself @@ -748,7 +839,8 @@ def add_exact_scoped(self, concrete_type: Type): return self def add_exact_transient(self, concrete_type: Type): - """Registers an exact type, to be instantiated with transient lifetime. + """ + Registers an exact type, to be instantiated with transient lifetime. :param concrete_type: concrete class :return: the service collection itself @@ -816,7 +908,8 @@ def register_factory( ) def build_provider(self) -> Services: - """Builds and returns a service provider that can be used to activate and + """ + Builds and returns a service provider that can be used to activate and obtain services. The configuration of services is validated at this point, if any service cannot diff --git a/setup.py b/setup.py index e636e92..42cde96 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ def readme(): setup( name="rodi", - version="1.0.9", + version="1.1.0", description="Implementation of dependency injection for Python 3", long_description=readme(), long_description_content_type="text/markdown", @@ -22,7 +22,7 @@ def readme(): "Programming Language :: Python :: 3.9", "Operating System :: OS Independent", ], - url="https://github.com/RobertoPrevato/rodi", + url="https://github.com/Neoteroi/rodi", author="RobertoPrevato", author_email="roberto.prevato@gmail.com", keywords="dependency injection type hints typing convention", diff --git a/tests/examples.py b/tests/examples.py index 3c0c19a..a514961 100644 --- a/tests/examples.py +++ b/tests/examples.py @@ -5,14 +5,12 @@ # domain object: class Cat: - def __init__(self, name): self.name = name # abstract interface class ICatsRepository(ABC): - @abstractmethod def get_by_id(self, _id) -> Cat: pass @@ -20,7 +18,6 @@ def get_by_id(self, _id) -> Cat: # one of the possible implementations of ICatsRepository class InMemoryCatsRepository(ICatsRepository): - def __init__(self): self._cats = {} @@ -30,7 +27,6 @@ def get_by_id(self, _id) -> Cat: # NB: example of business layer class, using interface of repository class GetCatRequestHandler: - def __init__(self, cats_repository: ICatsRepository): self.repo = cats_repository @@ -41,13 +37,11 @@ def get_cat(self, _id): # NB: example of controller class; class CatsController: - def __init__(self, get_cat_request_handler: GetCatRequestHandler): self.cat_request_handler = get_cat_request_handler class IRequestContext(ABC): - @property @abstractmethod def id(self): @@ -60,7 +54,6 @@ def user(self): class RequestContext(IRequestContext): - def __init__(self): pass @@ -74,20 +67,17 @@ def user(self): class ServiceSettings: - def __init__(self, foo_db_connection_string): self.foo_db_connection_string = foo_db_connection_string class FooDBContext: - def __init__(self, service_settings: ServiceSettings): self.settings = service_settings self.connection_string = service_settings.foo_db_connection_string class FooDBCatsRepository(ICatsRepository): - def __init__(self, context: FooDBContext): self.context = context @@ -96,7 +86,6 @@ def get_by_id(self, _id) -> Cat: class IValueProvider: - @property @abstractmethod def value(self): @@ -105,7 +94,7 @@ def value(self): class ValueProvider(IValueProvider): - __slots__ = ('_value') + __slots__ = "_value" def __init__(self, value): self._value = value @@ -116,32 +105,28 @@ def value(self): class IdGetter: - def __init__(self): self.value = uuid.uuid4() def __repr__(self): - return f'' + return f"" def __str__(self): - return f'' + return f"" class A: - def __init__(self, id_getter: IdGetter): self.id_getter = id_getter class B: - def __init__(self, a: A, id_getter: IdGetter): self.a = a self.id_getter = id_getter class C: - def __init__(self, a: A, b: B, id_getter: IdGetter): self.a = a self.b = b @@ -153,69 +138,58 @@ class ICircle(ABC): class Circle(ICircle): - def __init__(self, circular: ICircle): # NB: this is not supported by DI self.circular = circular class Shape: - def __init__(self, circle: Circle): self.circle = circle class Foo: - def __init__(self): pass class UfoOne: - def __init__(self): pass class UfoTwo: - def __init__(self, one: UfoOne): self.one = one class UfoThree(UfoTwo): - def __init__(self, one: UfoOne, foo: Foo): super().__init__(one) self.foo = foo class UfoFour(UfoThree): - def __init__(self, one: UfoOne, foo: Foo): super().__init__(one, foo) class TypeWithOptional: - def __init__(self, foo: Optional[Foo]): self.foo = foo class SelfReferencingCircle: - - def __init__(self, circle: 'SelfReferencingCircle'): + def __init__(self, circle: "SelfReferencingCircle"): self.circular = circle class TrickyCircle: - def __init__(self, circle: ICircle): self.circular = circle class ResolveThisByParameterName: - def __init__(self, icats_repository): self.cats_repository = icats_repository @@ -225,7 +199,6 @@ class IByParamName: class FooByParamName(IByParamName): - def __init__(self, foo): self.foo = foo @@ -286,7 +259,6 @@ def __init__(self): class PrecedenceOfTypeHintsOverNames: - def __init__(self, foo: Q, ko: P): self.q = foo self.p = ko diff --git a/tests/test_services.py b/tests/test_services.py index d2dd53d..ff425f6 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -1,25 +1,68 @@ - from abc import ABC +from dataclasses import dataclass from typing import Type import pytest from pytest import raises - from rodi import ( - AliasAlreadyDefined, AliasConfigurationError, - CannotResolveParameterException, CircularDependencyException, - ClassNotDefiningInitMethod, Container, GetServiceContext, InstanceResolver, - InvalidFactory, InvalidOperationInStrictMode, MissingTypeException, - OverridingServiceException, ServiceLifeStyle, Services, - UnsupportedUnionTypeException, to_standard_param_name) + AliasAlreadyDefined, + AliasConfigurationError, + CannotResolveParameterException, + CircularDependencyException, + Container, + GetServiceContext, + InstanceResolver, + InvalidFactory, + InvalidOperationInStrictMode, + MissingTypeException, + OverridingServiceException, + ServiceLifeStyle, + Services, + UnsupportedUnionTypeException, + to_standard_param_name, +) + from tests.examples import ( - A, B, C, Cat, CatsController, Circle, Foo, FooByParamName, - FooDBCatsRepository, FooDBContext, GetCatRequestHandler, IByParamName, - ICatsRepository, ICircle, IdGetter, InMemoryCatsRepository, - IRequestContext, Jang, Jing, Ko, Ok, P, PrecedenceOfTypeHintsOverNames, Q, - R, RequestContext, ResolveThisByParameterName, ServiceSettings, Shape, - TrickyCircle, TypeWithOptional, UfoFour, UfoOne, UfoThree, UfoTwo, W, X, Y, - Z) + A, + B, + C, + Cat, + CatsController, + Circle, + Foo, + FooByParamName, + FooDBCatsRepository, + FooDBContext, + GetCatRequestHandler, + IByParamName, + ICatsRepository, + ICircle, + IdGetter, + InMemoryCatsRepository, + IRequestContext, + Jang, + Jing, + Ko, + Ok, + P, + PrecedenceOfTypeHintsOverNames, + Q, + R, + RequestContext, + ResolveThisByParameterName, + ServiceSettings, + Shape, + TrickyCircle, + TypeWithOptional, + UfoFour, + UfoOne, + UfoThree, + UfoTwo, + W, + X, + Y, + Z, +) def arrange_cats_example(): @@ -28,18 +71,21 @@ def arrange_cats_example(): container.add_scoped(IRequestContext, RequestContext) container.add_exact_transient(GetCatRequestHandler) container.add_exact_transient(CatsController) - container.add_instance(ServiceSettings('foodb:example;something;')) + container.add_instance(ServiceSettings("foodb:example;something;")) container.add_exact_transient(FooDBContext) return container -@pytest.mark.parametrize('value,expected_result', ( - ('CamelCase', 'camel_case'), - ('HTTPResponse', 'http_response'), - ('ICatsRepository', 'icats_repository'), - ('Cat', 'cat'), - ('UFO', 'ufo'), -)) +@pytest.mark.parametrize( + "value,expected_result", + ( + ("CamelCase", "camel_case"), + ("HTTPResponse", "http_response"), + ("ICatsRepository", "icats_repository"), + ("Cat", "cat"), + ("UFO", "ufo"), + ), +) def test_standard_param_name(value, expected_result): snaked = to_standard_param_name(value) assert snaked == expected_result @@ -47,7 +93,7 @@ def test_standard_param_name(value, expected_result): def test_singleton_by_instance(): container = Container() - container.add_instance(Cat('Celine')) + container.add_instance(Cat("Celine")) provider = container.build_provider() cat = provider.get(Cat) @@ -74,7 +120,7 @@ def test_transient_by_type_with_parameters(): container.add_transient(ICatsRepository, FooDBCatsRepository) # NB: - container.add_instance(ServiceSettings('foodb:example;something;')) + container.add_instance(ServiceSettings("foodb:example;something;")) container.add_exact_transient(FooDBContext) provider = container.build_provider() @@ -83,7 +129,7 @@ def test_transient_by_type_with_parameters(): assert isinstance(cats_repo, FooDBCatsRepository) assert isinstance(cats_repo.context, FooDBContext) assert isinstance(cats_repo.context.settings, ServiceSettings) - assert cats_repo.context.connection_string == 'foodb:example;something;' + assert cats_repo.context.connection_string == "foodb:example;something;" def test_add_transient_shortcut(): @@ -91,7 +137,7 @@ def test_add_transient_shortcut(): container.add_transient(ICatsRepository, FooDBCatsRepository) # NB: - container.add_instance(ServiceSettings('foodb:example;something;')) + container.add_instance(ServiceSettings("foodb:example;something;")) container.add_transient(FooDBContext) provider = container.build_provider() @@ -100,7 +146,7 @@ def test_add_transient_shortcut(): assert isinstance(cats_repo, FooDBCatsRepository) assert isinstance(cats_repo.context, FooDBContext) assert isinstance(cats_repo.context.settings, ServiceSettings) - assert cats_repo.context.connection_string == 'foodb:example;something;' + assert cats_repo.context.connection_string == "foodb:example;something;" def test_raises_for_overriding_service(): @@ -110,17 +156,17 @@ def test_raises_for_overriding_service(): with pytest.raises(OverridingServiceException) as context: container.add_singleton(ICircle, Circle) - assert 'ICircle' in str(context.value) + assert "ICircle" in str(context.value) with pytest.raises(OverridingServiceException) as context: container.add_transient(ICircle, Circle) - assert 'ICircle' in str(context.value) + assert "ICircle" in str(context.value) with pytest.raises(OverridingServiceException) as context: container.add_scoped(ICircle, Circle) - assert 'ICircle' in str(context.value) + assert "ICircle" in str(context.value) def test_raises_for_circular_dependency(): @@ -130,7 +176,7 @@ def test_raises_for_circular_dependency(): with pytest.raises(CircularDependencyException) as context: container.build_provider() - assert 'Circle' in str(context.value) + assert "Circle" in str(context.value) def test_raises_for_circular_dependency_with_dynamic_resolver(): @@ -225,7 +271,7 @@ def test_raises_for_optional_parameter(): with pytest.raises(UnsupportedUnionTypeException) as context: container.build_provider() - assert 'foo' in str(context.value) + assert "foo" in str(context.value) def test_raises_for_nested_circular_dependency(): @@ -236,7 +282,7 @@ def test_raises_for_nested_circular_dependency(): with pytest.raises(CircularDependencyException) as context: container.build_provider() - assert 'Circle' in str(context.value) + assert "Circle" in str(context.value) def test_interdependencies(): @@ -292,7 +338,6 @@ def __init__(self): pass class B2: - def __init__(self, c: C): self.c = c @@ -301,7 +346,6 @@ def __init__(self, c: C): self.c = c class A: - def __init__(self, b1: B1, b2: B2): self.b1 = b1 self.b2 = b2 @@ -335,7 +379,7 @@ def __init__(self, b1: B1, b2: B2): def test_scoped_services_context_used_more_than_once_manual_dispose(): container = Container() - container.add_instance('value') + container.add_instance("value") provider = container.build_provider() context = GetServiceContext(provider) @@ -449,14 +493,13 @@ def test_exact_alias(): container = arrange_cats_example() class UsingAlias: - def __init__(self, example): self.cats_controller = example container.add_exact_transient(UsingAlias) # arrange an exact alias for UsingAlias class init parameter: - container.set_alias('example', CatsController) + container.set_alias("example", CatsController) provider = container.build_provider() u = provider.get(UsingAlias) @@ -470,13 +513,11 @@ def test_additional_alias(): container = arrange_cats_example() class UsingAlias: - def __init__(self, example, settings): self.cats_controller = example self.settings = settings class AnotherUsingAlias: - def __init__(self, cats_controller, service_settings): self.cats_controller = cats_controller self.settings = service_settings @@ -485,8 +526,8 @@ def __init__(self, cats_controller, service_settings): container.add_exact_transient(AnotherUsingAlias) # arrange an exact alias for UsingAlias class init parameter: - container.add_alias('example', CatsController) - container.add_alias('settings', ServiceSettings) + container.add_alias("example", CatsController) + container.add_alias("settings", ServiceSettings) provider = container.build_provider() u = provider.get(UsingAlias) @@ -506,11 +547,11 @@ def __init__(self, cats_controller, service_settings): def test_get_service_by_name_or_alias(): container = arrange_cats_example() - container.add_alias('k', CatsController) + container.add_alias("k", CatsController) provider = container.build_provider() - for name in {'CatsController', 'cats_controller', 'k'}: + for name in {"CatsController", "cats_controller", "k"}: service = provider.get(name) assert isinstance(service, CatsController) @@ -521,20 +562,19 @@ def test_get_service_by_name_or_alias(): def test_missing_service_returns_none(): container = Container() provider = container.build_provider() - service = provider.get('not_existing') + service = provider.get("not_existing") assert service is None -@pytest.mark.parametrize('method_name', [ - 'add_singleton_by_factory', - 'add_transient_by_factory', - 'add_scoped_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_singleton_by_factory", "add_transient_by_factory", "add_scoped_by_factory"], +) def test_by_factory_type_annotation(method_name): container = Container() def factory(_) -> Cat: - return Cat('Celine') + return Cat("Celine") method = getattr(container, method_name) @@ -545,18 +585,18 @@ def factory(_) -> Cat: cat = provider.get(Cat) assert cat is not None - assert cat.name == 'Celine' + assert cat.name == "Celine" - if method_name == 'add_singleton_by_factory': + if method_name == "add_singleton_by_factory": cat_2 = provider.get(Cat) assert cat_2 is cat - if method_name == 'add_transient_by_factory': + if method_name == "add_transient_by_factory": assert provider.get(Cat) is not cat assert provider.get(Cat) is not cat assert provider.get(Cat) is not cat - if method_name == 'add_scoped_by_factory': + if method_name == "add_scoped_by_factory": with GetServiceContext() as context: cat_2 = provider.get(Cat, context) assert cat_2 is not cat @@ -566,44 +606,42 @@ def factory(_) -> Cat: assert provider.get(Cat, context) is cat_2 -@pytest.mark.parametrize('method_name', [ - 'add_singleton_by_factory', - 'add_transient_by_factory', - 'add_scoped_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_singleton_by_factory", "add_transient_by_factory", "add_scoped_by_factory"], +) def test_invalid_factory_too_many_arguments_throws(method_name): container = Container() method = getattr(container, method_name) def factory(context, activating_type, extra_argument_mistake): - return Cat('Celine') + return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) def factory(context, activating_type, extra_argument_mistake, two): - return Cat('Celine') + return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) def factory(context, activating_type, extra_argument_mistake, two, three): - return Cat('Celine') + return Cat("Celine") with raises(InvalidFactory): method(factory, Cat) -@pytest.mark.parametrize('method_name', [ - 'add_singleton_by_factory', - 'add_transient_by_factory', - 'add_scoped_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_singleton_by_factory", "add_transient_by_factory", "add_scoped_by_factory"], +) def test_add_singleton_by_factory_given_type(method_name): container = Container() def factory(a): - return Cat('Celine') + return Cat("Celine") method = getattr(container, method_name) @@ -614,18 +652,18 @@ def factory(a): cat = provider.get(Cat) assert cat is not None - assert cat.name == 'Celine' + assert cat.name == "Celine" - if method_name == 'add_singleton_by_factory': + if method_name == "add_singleton_by_factory": cat_2 = provider.get(Cat) assert cat_2 is cat - if method_name == 'add_transient_by_factory': + if method_name == "add_transient_by_factory": assert provider.get(Cat) is not cat assert provider.get(Cat) is not cat assert provider.get(Cat) is not cat - if method_name == 'add_scoped_by_factory': + if method_name == "add_scoped_by_factory": with GetServiceContext() as context: cat_2 = provider.get(Cat, context) assert cat_2 is not cat @@ -635,16 +673,15 @@ def factory(a): assert provider.get(Cat, context) is cat_2 -@pytest.mark.parametrize('method_name', [ - 'add_singleton_by_factory', - 'add_transient_by_factory', - 'add_scoped_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_singleton_by_factory", "add_transient_by_factory", "add_scoped_by_factory"], +) def test_add_singleton_by_factory_raises_for_missing_type(method_name): container = Container() def factory(_): - return Cat('Celine') + return Cat("Celine") method = getattr(container, method_name) @@ -758,29 +795,36 @@ def test_proper_handling_of_inheritance(): def cat_factory_no_args() -> Cat: - return Cat('Celine') + return Cat("Celine") def cat_factory_with_context(context) -> Cat: - assert isinstance(context, Services) - return Cat('Celine') + assert isinstance(context, GetServiceContext) + return Cat("Celine") def cat_factory_with_context_and_activating_type(context, activating_type) -> Cat: - assert isinstance(context, Services) + assert isinstance(context, GetServiceContext) assert activating_type is Cat - return Cat('Celine') + return Cat("Celine") @pytest.mark.parametrize( - 'method_name,factory', - [(name, method) for name in - ['add_singleton_by_factory', - 'add_transient_by_factory', - 'add_scoped_by_factory'] for method in - [cat_factory_no_args, - cat_factory_with_context, - cat_factory_with_context_and_activating_type]]) + "method_name,factory", + [ + (name, method) + for name in [ + "add_singleton_by_factory", + "add_transient_by_factory", + "add_scoped_by_factory", + ] + for method in [ + cat_factory_no_args, + cat_factory_with_context, + cat_factory_with_context_and_activating_type, + ] + ], +) def test_by_factory_with_different_parameters(method_name, factory): container = Container() @@ -789,19 +833,16 @@ def test_by_factory_with_different_parameters(method_name, factory): provider = container.build_provider() - # ??? why not? cat = provider.get(Cat) assert cat is not None - assert cat.name == 'Celine' + assert cat.name == "Celine" -@pytest.mark.parametrize('method_name', [ - 'add_transient_by_factory', - 'add_scoped_by_factory' -]) +@pytest.mark.parametrize( + "method_name", ["add_transient_by_factory", "add_scoped_by_factory"] +) def test_factory_can_receive_activating_type_as_parameter(method_name): - class Logger: def __init__(self, name): self.name = name @@ -823,15 +864,14 @@ def __init__(self, foo: Foo, logger: Logger): container.add_exact_transient(Foo) def factory(_, activating_type) -> Logger: - return Logger(activating_type.__module__ + '.' + activating_type.__name__) + return Logger(activating_type.__module__ + "." + activating_type.__name__) method = getattr(container, method_name) method(factory) - container\ - .add_exact_transient(HelpController)\ - .add_exact_transient(HomeController)\ - .add_exact_transient(FooController) + container.add_exact_transient(HelpController).add_exact_transient( + HomeController + ).add_exact_transient(FooController) provider = container.build_provider() @@ -839,19 +879,19 @@ def factory(_, activating_type) -> Logger: assert help_controller is not None assert help_controller.logger is not None - assert help_controller.logger.name == 'tests.test_services.HelpController' + assert help_controller.logger.name == "tests.test_services.HelpController" home_controller = provider.get(HomeController) assert home_controller is not None assert home_controller.logger is not None - assert home_controller.logger.name == 'tests.test_services.HomeController' + assert home_controller.logger.name == "tests.test_services.HomeController" foo_controller = provider.get(FooController) assert foo_controller is not None assert foo_controller.logger is not None - assert foo_controller.logger.name == 'tests.test_services.FooController' + assert foo_controller.logger.name == "tests.test_services.FooController" def test_factory_can_receive_activating_type_as_parameter_nested_resolution(): @@ -878,7 +918,7 @@ def __init__(self, logger: Logger, handler: HelpHandler): def factory(_, activating_type) -> Logger: # NB: this scenario is tested for rolog library - return Logger(activating_type.__module__ + '.' + activating_type.__name__) + return Logger(activating_type.__module__ + "." + activating_type.__name__) container.add_transient_by_factory(factory) @@ -891,8 +931,8 @@ def factory(_, activating_type) -> Logger: assert help_controller is not None assert help_controller.logger is not None - assert help_controller.logger.name == 'tests.test_services.HelpController' - assert help_controller.handler.repo.logger.name == 'tests.test_services.HelpRepo' + assert help_controller.logger.name == "tests.test_services.HelpController" + assert help_controller.handler.repo.logger.name == "tests.test_services.HelpRepo" def test_factory_can_receive_activating_type_as_parameter_nested_resolution_many(): @@ -920,10 +960,9 @@ def __init__(self, another_path_2: AnotherPathTwo): self.child = another_path_2 class HelpController: - def __init__(self, - handler: HelpHandler, - another_path: AnotherPath, - logger: Logger): + def __init__( + self, handler: HelpHandler, another_path: AnotherPath, logger: Logger + ): self.logger = logger self.handler = handler self.other = another_path @@ -932,19 +971,20 @@ def __init__(self, def factory(_, activating_type) -> Logger: # NB: this scenario is tested for rolog library - return Logger(activating_type.__module__ + '.' + activating_type.__name__) + return Logger(activating_type.__module__ + "." + activating_type.__name__) container.add_transient_by_factory(factory) - - container.add_instance(ServiceSettings('foo:foo')) - - for service_type in {HelpRepo, - HelpHandler, - HelpController, - AnotherPath, - AnotherPathTwo, - Foo, - FooDBContext}: + container.add_instance(ServiceSettings("foo:foo")) + + for service_type in { + HelpRepo, + HelpHandler, + HelpController, + AnotherPath, + AnotherPathTwo, + Foo, + FooDBContext, + }: container.add_exact_transient(service_type) provider = container.build_provider() @@ -953,16 +993,18 @@ def factory(_, activating_type) -> Logger: assert help_controller is not None assert help_controller.logger is not None - assert help_controller.logger.name == 'tests.test_services.HelpController' - assert help_controller.handler.repo.logger.name == 'tests.test_services.HelpRepo' - assert help_controller.other.child.logger.name == 'tests.test_services.' \ - 'AnotherPathTwo' + assert help_controller.logger.name == "tests.test_services.HelpController" + assert help_controller.handler.repo.logger.name == "tests.test_services.HelpRepo" + assert ( + help_controller.other.child.logger.name == "tests.test_services." + "AnotherPathTwo" + ) def test_service_provider_supports_set_by_class(): provider = Services() - singleton_cat = Cat('Celine') + singleton_cat = Cat("Celine") provider.set(Cat, singleton_cat) @@ -971,7 +1013,7 @@ def test_service_provider_supports_set_by_class(): assert cat is not None assert cat.name == "Celine" - cat = provider.get('Cat') + cat = provider.get("Cat") assert cat is not None assert cat.name == "Celine" @@ -980,11 +1022,11 @@ def test_service_provider_supports_set_by_class(): def test_service_provider_supports_set_by_name(): provider = Services() - singleton_cat = Cat('Celine') + singleton_cat = Cat("Celine") - provider.set('my_cat', singleton_cat) + provider.set("my_cat", singleton_cat) - cat = provider.get('my_cat') + cat = provider.get("my_cat") assert cat is not None assert cat.name == "Celine" @@ -993,7 +1035,7 @@ def test_service_provider_supports_set_by_name(): def test_service_provider_supports_set_and_get_item_by_class(): provider = Services() - singleton_cat = Cat('Celine') + singleton_cat = Cat("Celine") provider[Cat] = singleton_cat @@ -1002,7 +1044,7 @@ def test_service_provider_supports_set_and_get_item_by_class(): assert cat is not None assert cat.name == "Celine" - cat = provider['Cat'] + cat = provider["Cat"] assert cat is not None assert cat.name == "Celine" @@ -1011,11 +1053,11 @@ def test_service_provider_supports_set_and_get_item_by_class(): def test_service_provider_supports_set_and_get_item_by_name(): provider = Services() - singleton_cat = Cat('Celine') + singleton_cat = Cat("Celine") - provider['my_cat'] = singleton_cat + provider["my_cat"] = singleton_cat - cat = provider['my_cat'] + cat = provider["my_cat"] assert cat is not None assert cat.name == "Celine" @@ -1024,54 +1066,55 @@ def test_service_provider_supports_set_and_get_item_by_name(): def test_service_provider_supports_set_simple_values(): provider = Services() - provider['one'] = 10 - provider['two'] = 12 - provider['three'] = 16 + provider["one"] = 10 + provider["two"] = 12 + provider["three"] = 16 - assert provider['one'] == 10 - assert provider['two'] == 12 - assert provider['three'] == 16 + assert provider["one"] == 10 + assert provider["two"] == 12 + assert provider["three"] == 16 -def test_container_raises_for_class_without_init(): +def test_container_handles_class_without_init(): container = Container() class WithoutInit: pass container.add_exact_singleton(WithoutInit) + provider = container.build_provider() - with raises(ClassNotDefiningInitMethod): - container.build_provider() + instance = provider.get(WithoutInit) + assert isinstance(instance, WithoutInit) def test_raises_invalid_factory_for_non_callable(): container = Container() with raises(InvalidFactory): - container.register_factory('Not a factory', Cat, ServiceLifeStyle.SINGLETON) + container.register_factory("Not a factory", Cat, ServiceLifeStyle.SINGLETON) def test_set_alias_raises_in_strict_mode(): container = Container(strict=True) with raises(InvalidOperationInStrictMode): - container.set_alias('something', Cat) + container.set_alias("something", Cat) def test_set_alias_raises_if_alias_is_defined(): container = Container() - container.set_alias('something', Cat) + container.set_alias("something", Cat) with raises(AliasAlreadyDefined): - container.set_alias('something', Foo) + container.set_alias("something", Foo) def test_set_alias_requires_configured_type(): container = Container() - container.set_alias('something', Cat) + container.set_alias("something", Cat) with raises(AliasConfigurationError): container.build_provider() @@ -1080,72 +1123,63 @@ def test_set_alias_requires_configured_type(): def test_set_aliases(): container = Container() - container.add_instance(Cat('Celine')) + container.add_instance(Cat("Celine")) container.add_instance(Foo()) - container.set_aliases({ - 'a': Cat, - 'b': Foo - }) + container.set_aliases({"a": Cat, "b": Foo}) provider = container.build_provider() - x = provider.get('a') + x = provider.get("a") assert isinstance(x, Cat) - assert x.name == 'Celine' + assert x.name == "Celine" - assert isinstance(provider.get('b'), Foo) + assert isinstance(provider.get("b"), Foo) def test_add_alias_raises_in_strict_mode(): container = Container(strict=True) with raises(InvalidOperationInStrictMode): - container.add_alias('something', Cat) + container.add_alias("something", Cat) def test_add_alias_raises_if_alias_is_defined(): container = Container() - container.add_alias('something', Cat) + container.add_alias("something", Cat) with raises(AliasAlreadyDefined): - container.add_alias('something', Foo) + container.add_alias("something", Foo) def test_add_aliases(): container = Container() - container.add_instance(Cat('Celine')) + container.add_instance(Cat("Celine")) container.add_instance(Foo()) - container.add_aliases({ - 'a': Cat, - 'b': Foo - }) + container.add_aliases({"a": Cat, "b": Foo}) - container.add_aliases({ - 'c': Cat, - 'd': Foo - }) + container.add_aliases({"c": Cat, "d": Foo}) provider = container.build_provider() - for alias in {'a', 'c'}: + for alias in {"a", "c"}: x = provider.get(alias) assert isinstance(x, Cat) - assert x.name == 'Celine' + assert x.name == "Celine" - for alias in {'b', 'd'}: + for alias in {"b", "d"}: assert isinstance(provider.get(alias), Foo) def test_add_alias_requires_configured_type(): container = Container() - container.add_alias('something', Cat) + container.add_alias("something", Cat) with raises(AliasConfigurationError): container.build_provider() @@ -1198,7 +1232,7 @@ def __init__(self): container.add_exact_transient(C) with raises(AliasAlreadyDefined): - container.add_alias('b', C) # <-- ambiguity + container.add_alias("b", C) # <-- ambiguity def test_cannot_resolve_parameter_in_strict_mode_throws(): @@ -1222,10 +1256,10 @@ def __init__(self, c): def test_services_set_throws_if_service_is_already_defined(): services = Services() - services.set('example', {}) + services.set("example", {}) with raises(OverridingServiceException): - services.set('example', []) + services.set("example", []) def test_scoped_services_exact(): @@ -1308,15 +1342,14 @@ def test_instance_resolver_representation(): resolver = InstanceResolver(singleton) representation = repr(resolver) - assert representation.startswith(' BBase: - assert isinstance(context, Services) + def bbase_factory(context: GetServiceContext, activating_type: Type) -> BBase: + assert isinstance(context, GetServiceContext) assert activating_type is A return B() @@ -1352,11 +1385,10 @@ def bbase_factory(context: Services, activating_type: Type) -> BBase: assert isinstance(a.b, B) -@pytest.mark.parametrize('method_name', [ - 'add_transient_by_factory', - 'add_scoped_by_factory', - 'add_singleton_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], +) def test_factories_activating_scoped_type_consistency(method_name): container = Container() @@ -1374,8 +1406,8 @@ class B(BBase): def __init__(self): pass - def bbase_factory(context: Services, activating_type: Type) -> BBase: - assert isinstance(context, Services) + def bbase_factory(context: GetServiceContext, activating_type: Type) -> BBase: + assert isinstance(context, GetServiceContext) assert activating_type is A return B() @@ -1392,11 +1424,10 @@ def bbase_factory(context: Services, activating_type: Type) -> BBase: assert isinstance(a.b, B) -@pytest.mark.parametrize('method_name', [ - 'add_transient_by_factory', - 'add_scoped_by_factory', - 'add_singleton_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], +) def test_factories_activating_singleton_type_consistency(method_name): container = Container() @@ -1414,8 +1445,8 @@ class B(BBase): def __init__(self): pass - def bbase_factory(context: Services, activating_type: Type) -> BBase: - assert isinstance(context, Services) + def bbase_factory(context: GetServiceContext, activating_type: Type) -> BBase: + assert isinstance(context, GetServiceContext) assert activating_type is A return B() @@ -1432,11 +1463,10 @@ def bbase_factory(context: Services, activating_type: Type) -> BBase: assert isinstance(a.b, B) -@pytest.mark.parametrize('method_name', [ - 'add_transient_by_factory', - 'add_scoped_by_factory', - 'add_singleton_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], +) def test_factories_type_transient_consistency_nested(method_name): container = Container() @@ -1461,8 +1491,8 @@ class C(CBase): def __init__(self): pass - def cbase_factory(context: Services, activating_type: Type) -> CBase: - assert isinstance(context, Services) + def cbase_factory(context: GetServiceContext, activating_type: Type) -> CBase: + assert isinstance(context, GetServiceContext) assert activating_type is B return C() @@ -1481,11 +1511,10 @@ def cbase_factory(context: Services, activating_type: Type) -> CBase: assert isinstance(a.b.c, C) -@pytest.mark.parametrize('method_name', [ - 'add_transient_by_factory', - 'add_scoped_by_factory', - 'add_singleton_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], +) def test_factories_type_scoped_consistency_nested(method_name): container = Container() @@ -1510,8 +1539,8 @@ class C(CBase): def __init__(self): pass - def cbase_factory(context: Services, activating_type: Type) -> CBase: - assert isinstance(context, Services) + def cbase_factory(context: GetServiceContext, activating_type: Type) -> CBase: + assert isinstance(context, GetServiceContext) assert activating_type is B return C() @@ -1530,11 +1559,10 @@ def cbase_factory(context: Services, activating_type: Type) -> CBase: assert isinstance(a.b.c, C) -@pytest.mark.parametrize('method_name', [ - 'add_transient_by_factory', - 'add_scoped_by_factory', - 'add_singleton_by_factory' -]) +@pytest.mark.parametrize( + "method_name", + ["add_transient_by_factory", "add_scoped_by_factory", "add_singleton_by_factory"], +) def test_factories_type_singleton_consistency_nested(method_name): container = Container() @@ -1559,8 +1587,8 @@ class C(CBase): def __init__(self): pass - def cbase_factory(context: Services, activating_type: Type) -> CBase: - assert isinstance(context, Services) + def cbase_factory(context: GetServiceContext, activating_type: Type) -> CBase: + assert isinstance(context, GetServiceContext) assert activating_type is B return C() @@ -1577,3 +1605,245 @@ def cbase_factory(context: Services, activating_type: Type) -> CBase: assert isinstance(a, A) assert isinstance(a.b, B) assert isinstance(a.b.c, C) + + +def test_annotation_resolution(): + class B: + pass + + class A: + dep: B + + container = Container() + + b_singleton = B() + container.add_instance(b_singleton) + container.add_exact_scoped(A) + + provider = container.build_provider() + + instance = provider.get(A) + + assert isinstance(instance, A) + assert instance.dep is not None + assert instance.dep is b_singleton + + +def test_annotation_resolution_scoped(): + class B: + pass + + class A: + dep: B + + container = Container() + + b_singleton = B() + container.add_instance(b_singleton) + container.add_exact_scoped(A) + + provider = container.build_provider() + + with GetServiceContext() as context: + instance = provider.get(A, context) + + assert isinstance(instance, A) + assert instance.dep is not None + assert instance.dep is b_singleton + + second = provider.get(A, context) + assert instance is second + + third = provider.get(A) + assert third is not instance + + +def test_annotation_nested_resolution_1(): + class D: + pass + + class C: + pass + + class B: + dep_1: C + dep_2: D + + class A: + dep: B + + container = Container() + + container.add_instance(C()) + container.add_instance(D()) + container.add_exact_transient(B) + container.add_exact_scoped(A) + + provider = container.build_provider() + + with GetServiceContext(provider) as context: + instance = provider.get(A, context) + + assert isinstance(instance, A) + assert isinstance(instance.dep, B) + assert isinstance(instance.dep.dep_1, C) + assert isinstance(instance.dep.dep_2, D) + + second = provider.get(A, context) + assert instance is second + + third = provider.get(A) + assert third is not instance + + +def test_annotation_nested_resolution_2(): + class E: + pass + + class D: + dep: E + + class C: + dep: E + + class B: + dep_1: C + dep_2: D + + class A: + dep: B + + container = Container() + + container.add_scoped_by_factory(E, E) + container.add_exact_scoped(C) + container.add_exact_scoped(D) + container.add_exact_transient(B) + container.add_exact_scoped(A) + + provider = container.build_provider() + + with GetServiceContext(provider) as context: + instance = provider.get(A, context) + + assert isinstance(instance, A) + assert isinstance(instance.dep, B) + assert isinstance(instance.dep.dep_1, C) + assert isinstance(instance.dep.dep_2, D) + assert isinstance(instance.dep.dep_1.dep, E) + assert isinstance(instance.dep.dep_2.dep, E) + assert instance.dep.dep_1.dep is instance.dep.dep_2.dep + + second = provider.get(A, context) + assert instance is second + assert instance.dep.dep_1.dep is second.dep.dep_1.dep + + third = provider.get(A) + assert third is not instance + + +def test_annotation_resolution_singleton(): + class B: + pass + + class A: + dep: B + + container = Container() + + b_singleton = B() + container.add_instance(b_singleton) + container.add_exact_singleton(A) + + provider = container.build_provider() + + instance = provider.get(A) + + assert isinstance(instance, A) + assert instance.dep is not None + assert instance.dep is b_singleton + + second = provider.get(A) + assert instance is second + + +def test_annotation_resolution_transient(): + class B: + pass + + class A: + dep: B + + container = Container() + + b_singleton = B() + container.add_instance(b_singleton) + container.add_exact_transient(A) + + provider = container.build_provider() + + with GetServiceContext() as context: + instance = provider.get(A, context) + + assert isinstance(instance, A) + assert instance.dep is not None + assert instance.dep is b_singleton + + second = provider.get(A, context) + assert instance is not second + + assert isinstance(second, A) + assert second.dep is not None + assert second.dep is b_singleton + + +def test_annotations_abstract_type_transient_service(): + class FooCatsRepository(ICatsRepository): + def get_by_id(self, _id) -> Cat: + return Cat("foo") + + class GetCatRequestHandler: + cats_repository: ICatsRepository + + def get_cat(self, _id): + cat = self.cats_repository.get_by_id(_id) + return cat + + container = Container() + container.add_transient(ICatsRepository, FooCatsRepository) + container.add_exact_transient(GetCatRequestHandler) + provider = container.build_provider() + + cats_repo = provider.get(ICatsRepository) + assert isinstance(cats_repo, FooCatsRepository) + + other_cats_repo = provider.get(ICatsRepository) + assert cats_repo is not other_cats_repo + + get_cat_handler = provider.get(GetCatRequestHandler) + assert isinstance(get_cat_handler, GetCatRequestHandler) + assert isinstance(get_cat_handler.cats_repository, FooCatsRepository) + + +def test_support_for_dataclasses(): + @dataclass + class Settings: + region: str + + @dataclass + class GetInfoHandler: + service_settings: Settings + + def handle_request(self): + return {"service_region": self.service_settings.region} + + container = Container() + container.add_instance(Settings(region="Western Europe")) + container.add_exact_scoped(GetInfoHandler) + + provider = container.build_provider() + + info_handler = provider.get(GetInfoHandler) + + assert isinstance(info_handler, GetInfoHandler) + assert isinstance(info_handler.service_settings, Settings)