Skip to content

Commit

Permalink
Merge branch 'container-validators' into 'main'
Browse files Browse the repository at this point in the history
Add container validators

See merge request ThirVondukr/aioinject!7
  • Loading branch information
ThirVondukr committed Jan 9, 2024
2 parents 828832d + 8afa7b4 commit ca74b81
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dist/

# Coverage
.coverage
coverage.xml

# Cache
.pytest_cache
Expand Down
2 changes: 1 addition & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ tasks:
cmds:
- task: format
- task: typecheck
- task: test
- task: testcov
5 changes: 5 additions & 0 deletions aioinject/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Protocol,
TypeAlias,
TypeVar,
runtime_checkable,
)

from aioinject.markers import Inject
Expand Down Expand Up @@ -154,6 +155,7 @@ def _guess_impl(factory: _FactoryType[_T]) -> type[_T]:
return return_type


@runtime_checkable
class Provider(Protocol[_T]):
type_: type[_T]
impl: Any
Expand All @@ -176,6 +178,9 @@ def is_async(self) -> bool:
def dependencies(self) -> tuple[Dependency[object], ...]:
return tuple(collect_dependencies(self.type_hints))

def __repr__(self) -> str:
return f"{self.__class__.__qualname__}(type={self.type_}, implementation={self.impl})"


class Callable(Provider[_T]):
def __init__(
Expand Down
27 changes: 27 additions & 0 deletions aioinject/validation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from collections.abc import Sequence

from aioinject import Callable, Singleton
from aioinject.validation._builtin import (
ForbidDependency,
all_dependencies_are_present,
)
from aioinject.validation._validate import validate_container
from aioinject.validation.abc import ContainerValidator


DEFAULT_VALIDATORS: Sequence[ContainerValidator] = [
all_dependencies_are_present,
ForbidDependency(
dependant=lambda p: isinstance(p, Singleton),
dependency=lambda p: isinstance(p, Callable)
and not isinstance(p, Singleton),
),
]

__all__ = [
"DEFAULT_VALIDATORS",
"ContainerValidator",
"all_dependencies_are_present",
"ForbidDependency",
"validate_container",
]
57 changes: 57 additions & 0 deletions aioinject/validation/_builtin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from collections.abc import Callable, Sequence
from typing import Any

import aioinject
from aioinject import Provider
from aioinject.validation.abc import ContainerValidator
from aioinject.validation.error import (
ContainerValidationError,
DependencyNotFoundError,
)


def all_dependencies_are_present(
container: aioinject.Container,
) -> Sequence[ContainerValidationError]:
errors = []
for providers in container.providers.values():
for provider in providers:
for dependency in provider.dependencies:
if dependency.type_ not in container.providers:
error = DependencyNotFoundError(
message=f"Provider for type {dependency.type_} not found",
dependency=dependency.type_,
)
errors.append(error)

return errors


class ForbidDependency(ContainerValidator):
def __init__(
self,
dependant: Callable[[Provider[Any]], bool],
dependency: Callable[[Provider[Any]], bool],
) -> None:
self.dependant = dependant
self.dependency = dependency

def __call__(
self,
container: aioinject.Container,
) -> Sequence[ContainerValidationError]:
errors = []
for providers in container.providers.values():
for provider in providers:
if not self.dependant(provider):
continue

for dependency in provider.dependencies:
dependency_provider = container.get_provider(
type_=dependency.type_,
impl=dependency.implementation,
)
if self.dependency(dependency_provider):
msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}"
errors.append(ContainerValidationError(msg))
return errors
20 changes: 20 additions & 0 deletions aioinject/validation/_validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from collections.abc import Iterable

import aioinject
from aioinject.validation.abc import ContainerValidator
from aioinject.validation.error import (
ContainerValidationError,
ContainerValidationErrorGroup,
)


def validate_container(
container: aioinject.Container,
validators: Iterable[ContainerValidator],
) -> None:
errors: list[ContainerValidationError] = []
for validator in validators:
errors.extend(validator.__call__(container))

if errors:
raise ContainerValidationErrorGroup(errors=errors)
13 changes: 13 additions & 0 deletions aioinject/validation/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from collections.abc import Sequence
from typing import Protocol

import aioinject
from aioinject.validation.error import ContainerValidationError


class ContainerValidator(Protocol):
def __call__(
self,
container: aioinject.Container,
) -> Sequence[ContainerValidationError]:
...
22 changes: 22 additions & 0 deletions aioinject/validation/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import dataclasses
from collections.abc import Sequence
from typing import Any


@dataclasses.dataclass(slots=True, frozen=True)
class ContainerValidationError:
message: str


@dataclasses.dataclass(slots=True, frozen=True)
class DependencyNotFoundError(ContainerValidationError):
dependency: type[Any]


@dataclasses.dataclass
class ContainerValidationErrorGroup(Exception): # noqa: N818
# Exception group could be used instead, but it's added in python 3.11
errors: Sequence[ContainerValidationError]

def __str__(self) -> str:
return repr(self) # pragma: no cover
2 changes: 1 addition & 1 deletion tests/ext/strawberry/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def hello_world_sync(
return f"{argument}-{provided_value}"

@strawberry.field
async def dataloader(self, info: Info[Context, None]) -> Sequence[int]:
async def dataloader(self, info: Info[Any, None]) -> Sequence[int]:
return await info.context.numbers.load_many(list(range(100)))


Expand Down
Empty file added tests/validation/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/validation/test_dependency_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from collections.abc import Sequence
from typing import Any

import pytest

from aioinject import Callable, Container, Provider
from aioinject.validation import (
all_dependencies_are_present,
validate_container,
)
from aioinject.validation.error import (
ContainerValidationErrorGroup,
DependencyNotFoundError,
)


_VALIDATORS = [all_dependencies_are_present]


def _str_dependency(number: int) -> str:
return str(number)


@pytest.mark.parametrize(
"providers",
[
[Callable(int)],
[Callable(_str_dependency), Callable(int)],
],
)
def test_ok(providers: Sequence[Provider[Any]]) -> None:
container = Container()
for provider in providers:
container.register(provider)

validate_container(container, _VALIDATORS)


def test_err() -> None:
container = Container()
container.register(Callable(_str_dependency))

with pytest.raises(ContainerValidationErrorGroup) as exc_info:
validate_container(container, _VALIDATORS)

assert len(exc_info.value.errors) == 1
err = exc_info.value.errors[0]
assert isinstance(err, DependencyNotFoundError)
assert err.dependency == int
55 changes: 55 additions & 0 deletions tests/validation/test_provider_dependency_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from collections.abc import Sequence
from typing import Any

import pytest

from aioinject import Callable, Container, Provider, Singleton
from aioinject.validation import ForbidDependency, validate_container
from aioinject.validation.error import (
ContainerValidationErrorGroup,
)


_VALIDATORS = [
ForbidDependency(
dependant=lambda p: isinstance(p, Singleton),
dependency=lambda p: isinstance(p, Callable)
and not isinstance(p, Singleton),
),
]


def _str_dependency(number: int) -> str:
return str(number)


@pytest.mark.parametrize(
"providers",
[
[Callable(int)],
[Callable(_str_dependency), Callable(int)],
[Callable(_str_dependency), Singleton(int)],
[Singleton(_str_dependency), Singleton(int)],
],
)
def test_ok(providers: Sequence[Provider[Any]]) -> None:
container = Container()
for provider in providers:
container.register(provider)

validate_container(container, _VALIDATORS)


@pytest.mark.parametrize(
"providers",
[
[Singleton(_str_dependency), Callable(int)],
],
)
def test_err(providers: Sequence[Provider[Any]]) -> None:
container = Container()
for provider in providers:
container.register(provider)

with pytest.raises(ContainerValidationErrorGroup):
validate_container(container, _VALIDATORS)

0 comments on commit ca74b81

Please sign in to comment.