Skip to content

Commit

Permalink
test(fastapi): exception propagation into dependencies with yield
Browse files Browse the repository at this point in the history
  • Loading branch information
ThirVondukr committed Nov 22, 2024
1 parent 5ee7e31 commit 886755a
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 24 deletions.
15 changes: 8 additions & 7 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ def override(self, *providers: Provider[Any]) -> Iterator[None]:
previous[provider.type_] = self.providers.get(provider.type_)
self.providers[provider.type_] = provider

yield

for provider in providers:
del self.providers[provider.type_]
prev = previous[provider.type_]
if prev is not None:
self.providers[provider.type_] = prev
try:
yield
finally:
for provider in providers:
del self.providers[provider.type_]
prev = previous[provider.type_]
if prev is not None:
self.providers[provider.type_] = prev

async def __aenter__(self) -> Self:
for extension in self.extensions:
Expand Down
11 changes: 11 additions & 0 deletions tests/ext/fastapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import aioinject
from aioinject import Inject
from aioinject.ext.fastapi import AioInjectMiddleware, inject
from tests.ext.utils import PropagatedError


@inject
Expand All @@ -35,6 +36,16 @@ async def route_with_depends(
) -> dict[str, str | int]:
return {"value": number}


@app_.get("/raise-exception")
@inject
async def raises_exception(
number: Annotated[int, Depends(dependency)],
) -> dict[str, str | int]:
if number == 0:
raise PropagatedError
return {"value": number}

return app_


Expand Down
33 changes: 33 additions & 0 deletions tests/ext/fastapi/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import contextlib
import uuid
from typing import Any

import httpx
import pytest
from _pytest.fixtures import SubRequest

import aioinject
from aioinject import Scoped, Transient, Singleton
from tests.ext.utils import ExceptionPropagation, PropagatedError


@pytest.fixture(params=["/function-route", "/depends"])
Expand Down Expand Up @@ -32,3 +36,32 @@ async def test_function_route_override(
response = await http_client.get(route)
assert response.status_code == httpx.codes.OK.value
assert response.json() == {"value": expected}

@pytest.mark.parametrize(
("provider_type", "should_propagate"),
[
(Singleton, False),
(Scoped, True),
(Transient, True),
],
)
async def test_propagation(
http_client: httpx.AsyncClient,
container: aioinject.Container,
route: str,
provider_type: Any,
should_propagate: bool,
) -> None:
propagation = ExceptionPropagation()

with (
container.override(provider_type(propagation.dependency, type_=int)), # type: ignore[call-arg]
contextlib.suppress(Exception),

):
await http_client.get("/raise-exception")

if should_propagate:
assert isinstance(propagation.exc, PropagatedError)
else:
assert propagation.exc is None
4 changes: 2 additions & 2 deletions tests/ext/litestar/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import aioinject
from aioinject import Inject
from aioinject.ext.litestar import AioInjectPlugin, inject
from tests.ext.utils import PropagatedError


@pytest.fixture
Expand All @@ -26,8 +27,7 @@ async def raise_exception(
provided: Annotated[int, Inject],
) -> dict[str, str | int]:
if provided == 0:
msg = "Raised Exception"
raise Exception(msg) # noqa: TRY002
raise PropagatedError
return {"value": provided}

return Litestar(
Expand Down
20 changes: 5 additions & 15 deletions tests/ext/litestar/test_litestar.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import contextlib
import uuid
from collections.abc import AsyncIterator

import httpx
import pytest

import aioinject
from aioinject import Provider, Scoped, Singleton, Transient
from tests.ext.utils import ExceptionPropagation, PropagatedError


async def test_function_route(
Expand Down Expand Up @@ -43,25 +43,15 @@ async def test_should_propagate_exceptions(
provider_type: type[Provider[int]],
should_propagate: bool,
) -> None:
exc = None

@contextlib.asynccontextmanager
async def dependency() -> AsyncIterator[int]:
nonlocal exc
try:
yield 0
except Exception as e:
exc = e
raise
propagation = ExceptionPropagation()

with (
container.override(provider_type(dependency)), # type: ignore[call-arg]
container.override(provider_type(propagation.dependency)), # type: ignore[call-arg]
contextlib.suppress(Exception),
):
await http_client.get("/raise-exception")

if should_propagate:
assert type(exc) is Exception
assert str(exc) == "Raised Exception"
assert isinstance(propagation.exc, PropagatedError)
else:
assert exc is None
assert propagation.exc is None
19 changes: 19 additions & 0 deletions tests/ext/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import contextlib
from collections.abc import AsyncIterator


class PropagatedError(Exception):
pass


class ExceptionPropagation:
def __init__(self) -> None:
self.exc: BaseException | None = None

@contextlib.asynccontextmanager
async def dependency(self) -> AsyncIterator[int]:
try:
yield 0
except Exception as exc:
self.exc = exc
raise

0 comments on commit 886755a

Please sign in to comment.