Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DependenciesHealthcheck dependency #17

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* @RB387 @jerry-git
* @RB387 @jerry-git @erhosen
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Added `magic_di.healthcheck.DependenciesHealthcheck` class to make health checks of injected dependencies that implement `magic_di.healthcheck.PingableProtocol` interface
### Fixed
- Inject dependencies inside of event loop in `magic_di.utils.inject_and_run` to prevent wrong event loop usage inside of the injected dependencies
### Changed
- Cruft update to get changes from the cookiecutter template
- Renamed LICENCE -> LICENSE, now it's automatically included in the wheel created by poetry
Expand Down
33 changes: 33 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Dependency Injector with minimal boilerplate code, built-in support for FastAPI
* [Custom integrations](#custom-integrations)
* [Manual injection](#manual-injection)
* [Forced injections](#forced-injections)
* [Healthcheck](#healthcheck)
* [Testing](#testing)
* [Default simple mock](#default-simple-mock)
* [Custom mocks](#custom-mocks)
Expand Down Expand Up @@ -330,6 +331,38 @@ class Service(Connectable):
dependency: Annotated[NonConnectableDependency, Injectable]
```

## Healthchecks
You can implement `Pingable` protocol to define healthchecks for your clients. The `DependenciesHealthcheck` will call the `__ping__` method on all injected clients that implement this protocol.

```python
from magic_di.healthcheck import DependenciesHealthcheck


class Service(Connectable):
def __init__(self, db: Database):
self.db = db

def is_connected(self):
return self.db.connected

async def __ping__(self) -> None:
if not self.is_connected():
raise Exception("Service is not connected")


@app.get(path="/hello-world")
def hello_world(service: Provide[Service]) -> dict:
return {
"is_connected": service.is_connected()
}


@app.get(path="/healthcheck")
async def healthcheck_handler(healthcheck: Provide[DependenciesHealthcheck]) -> dict:
await healthcheck.ping_dependencies()
return {"alive": True}
```

## Testing
If you need to mock a dependency in tests, you can easily do so by using the `injector.override` context manager and still use this dependency injector.

Expand Down
9 changes: 4 additions & 5 deletions src/magic_di/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import inspect
from dataclasses import dataclass
from threading import Lock
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, TypeVar

if TYPE_CHECKING:
from collections.abc import Iterable

from magic_di import ConnectableProtocol

T = TypeVar("T")

Expand Down Expand Up @@ -53,7 +52,7 @@ def iter_instances(
self,
*,
reverse: bool = False,
) -> Iterable[tuple[type, ConnectableProtocol]]:
) -> Iterable[tuple[type, object]]:
with self._lock:
deps_iter: Iterable[Dependency[Any]] = list(
reversed(self._deps.values()) if reverse else self._deps.values(),
Expand All @@ -70,8 +69,8 @@ def _get(self, obj: type[T]) -> type[T] | None:

def _wrap(obj: type[T], *args: Any, **kwargs: Any) -> type[T]:
if not inspect.isclass(obj):
partial = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) # type: ignore[var-annotated]
return cast(type[T], partial)
partial: type[T] = functools.wraps(obj)(functools.partial(obj, *args, **kwargs)) # type: ignore[assignment]
return partial

_instance: T | None = None

Expand Down
33 changes: 22 additions & 11 deletions src/magic_di/_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from magic_di._utils import (
get_cls_from_optional,
get_type_hints,
is_injectable,
is_connectable,
safe_is_instance,
safe_is_subclass,
)
from magic_di.exceptions import InjectionError, InspectionError

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator

from magic_di._connectable import ConnectableProtocol

# flag to use in typing.Annotated
# to forcefully mark dependency as injectable
Expand Down Expand Up @@ -111,22 +111,22 @@ def inspect(self, obj: AnyObject) -> Signature[AnyObject]:
hints_with_extras = get_type_hints(obj, include_extras=True)

if not hints:
return Signature(obj, is_injectable=is_injectable(obj))
return Signature(obj, is_injectable=bool(is_connectable(obj)))

if inspect.ismethod(obj):
hints.pop("self", None)

hints.pop("return", None)

signature = Signature(obj, is_injectable=is_injectable(obj))
signature = Signature(obj, is_injectable=bool(is_connectable(obj)))

for name, hint_ in hints.items():
hint = self._unwrap_type_hint(hint_)
hint_with_extra = hints_with_extras[name]

if is_injector(hint):
signature.injector_arg = name
elif not is_injectable(hint) and not is_forcefully_marked_as_injectable(
elif not is_connectable(hint) and not is_forcefully_marked_as_injectable(
hint_with_extra,
):
signature.kwargs[name] = hint
Expand All @@ -149,30 +149,41 @@ async def connect(self) -> None:
self.inject(postponed)

for cls, instance in self._deps.iter_instances():
if is_injectable(instance):
if connectable_instance := is_connectable(instance):
self.logger.debug("Connecting %s...", cls.__name__)
await instance.__connect__()
await connectable_instance.__connect__()

async def disconnect(self) -> None:
"""
Disconnect all injected dependencies
"""
for cls, instance in self._deps.iter_instances(reverse=True):
if is_injectable(instance):
if connectable_instance := is_connectable(instance):
try:
await instance.__disconnect__()
await connectable_instance.__disconnect__()
except Exception:
self.logger.exception("Failed to disconnect %s", cls.__name__)

def get_dependencies_by_interface(
self,
interface: Callable[..., AnyObject],
) -> Iterable[AnyObject]:
"""
Get all injected dependencies that implement a particular interface.
"""
for _, instance in self._deps.iter_instances():
if safe_is_instance(instance, interface): # type: ignore[arg-type]
yield instance # type: ignore[misc]

async def __aenter__(self) -> DependencyInjector: # noqa: PYI034
await self.connect()
return self

async def __aexit__(self, *args: object, **kwargs: Any) -> None:
await self.disconnect()

def iter_deps(self) -> Iterable[ConnectableProtocol]:
instance: ConnectableProtocol
def iter_deps(self) -> Iterable[object]:
instance: object

for _, instance in self._deps.iter_instances():
yield instance
Expand Down
13 changes: 10 additions & 3 deletions src/magic_di/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_cls_from_optional(cls: T) -> T:
Extract the actual class from a union that includes None.
If it is not a union type hint, it returns the same type hint.
Example:
```python
>>> get_cls_from_optional(Union[str, None])
str
>>> get_cls_from_optional(str | None)
Expand All @@ -35,6 +36,7 @@ def get_cls_from_optional(cls: T) -> T:
str
>>> get_cls_from_optional(int)
int
```
Args:
cls (T): Type hint for class
Returns:
Expand Down Expand Up @@ -73,17 +75,22 @@ def safe_is_instance(sub_cls: Any, cls: type) -> bool:
return False


def is_injectable(cls: Any) -> bool:
def is_connectable(cls: Any) -> ConnectableProtocol | None:
"""
Check if a class is a subclass of ConnectableProtocol.

Args:
cls (Any): The class to check.

Returns:
bool: True if the class is a subclass of ConnectableProtocol, False otherwise.
ConnectableProtocol | None: return instance if the class
is a subclass of ConnectableProtocol, None otherwise.
"""
return safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(cls, ConnectableProtocol)
connectable = safe_is_subclass(cls, ConnectableProtocol) or safe_is_instance(
cls,
ConnectableProtocol,
)
return cls if connectable else None


def get_type_hints(obj: Any, *, include_extras: bool = False) -> dict[str, type]:
Expand Down
73 changes: 73 additions & 0 deletions src/magic_di/healthcheck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import asyncio
from asyncio import Future
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, Protocol

from magic_di import Connectable, ConnectableProtocol, DependencyInjector


class PingableProtocol(ConnectableProtocol, Protocol):
async def __ping__(self) -> None: ...


@dataclass
class DependenciesHealthcheck(Connectable):
"""
Injectable Healthcheck component that pings all injected dependencies
that implement the PingableProtocol

Example usage:

```python
from app.components.services.health import DependenciesHealthcheck

async def main(redis: Redis, deps_healthcheck: DependenciesHealthcheck) -> None:
await deps_healthcheck.ping_dependencies() # redis will be pinged if it has method __ping__

inject_and_run(main)
```
"""

injector: DependencyInjector

async def ping_dependencies(self, max_concurrency: int = 1) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be beneficial if this would return information about which ones failed (and which one succeeded)? Then it would be possible to include that information in the response of the healthcheck endpoint. At minimum, I think there could be some logging inside this method about failures.

Copy link
Collaborator

@erhosen erhosen Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently one of our services uses this approach: each dependency (such as redis, mongo or kafka-consumer) implements a ping method, that might raise an exception. Each of those pings has they own set of exceptions (e.g., ConnectionFailure, NetworkTimeout, OperationFailure for MongoDB), and we’ve intentionally decided not to catch them.

This approach ensures that in case of fire, we get a complete stack trace in the logs, making it easier to debug what happened.

And it also leads to 500 error status code for a typical FastAPI/Starlette server. Since these health checks are primarily designed for Kubernetes - correct status code is basically the only thing that matters

What do you think about it?

Copy link
Collaborator Author

@RB387 RB387 Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be beneficial if this would return information about which ones failed (and which one succeeded)?

Hmm. Do we have any case where we need to know which dependencies are healthy or not?
Usually we only want to know if everything is healthy.

At minimum, I think there could be some logging inside this method about failures.

That’s a good point. I added debug logging, but I decided not to log failures since they are raised to user and user can catch them and log if needed.

But, what I’m thinking about is, do we need the ignore_exceptions argument? To ignore some concrete exceptions that are raised by some dependencies

Copy link
Collaborator

@jerry-git jerry-git Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I guess it makes sense, at least in K8s (probes) use cases. I was just thinking that peeps in the wider community might have also non-K8s use cases in which it might be beneficial to know which ones failed. Anyway, I guess this will be sufficient for majority.

do we need the ignore_exceptions argument

At least I can't think of a real world use case in which one would like to check the healthiness of all dependencies but ignore some 🤔

"""
Ping all dependencies that implement the PingableProtocol

:param max_concurrency: Maximum number of concurrent pings
"""
tasks: set[Future[Any]] = set()

try:
for dependency in self.injector.get_dependencies_by_interface(PingableProtocol):
future = asyncio.ensure_future(self.ping(dependency))
tasks.add(future)

if len(tasks) >= max_concurrency:
tasks, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

if tasks:
await asyncio.gather(*tasks)
tasks = set()

finally:
for task in tasks:
task.cancel()

if tasks:
with suppress(asyncio.CancelledError):
await asyncio.gather(*tasks)

async def ping(self, dependency: PingableProtocol) -> None:
"""
Ping a single dependency

:param dependency: Dependency to ping
"""
dependency_name = dependency.__class__.__name__
self.injector.logger.debug("Pinging dependency %s...", dependency_name)

await dependency.__ping__()

self.injector.logger.debug("Dependency %s is healthy", dependency_name)
20 changes: 11 additions & 9 deletions src/magic_di/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,20 @@ class InjectableMock(AsyncMock):
and use AsyncMock instead of a real class instance

Example:
@pytest.fixture()
def client():
injector = DependencyInjector()
```python
@pytest.fixture()
def client():
injector = DependencyInjector()

with injector.override({Service: InjectableMock().mock_cls}):
with TestClient(app) as client:
yield client
with injector.override({Service: InjectableMock().mock_cls}):
with TestClient(app) as client:
yield client

def test_http_handler(client):
resp = client.post('/hello-world')
def test_http_handler(client):
resp = client.post('/hello-world')

assert resp.status_code == 200
assert resp.status_code == 200
```
"""

@property
Expand Down
4 changes: 2 additions & 2 deletions src/magic_di/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def inject_and_run(
"""
injector = injector or DependencyInjector()

injected = injector.inject(fn)

async def run() -> T:
injected = injector.inject(fn)

async with injector:
if inspect.iscoroutinefunction(fn):
return await injected() # type: ignore[misc,no-any-return]
Expand Down
Loading
Loading