-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #17205 from jdavcs/dev_cbv
Vendorize fastapi-utls.cbv
- Loading branch information
Showing
6 changed files
with
218 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
""" | ||
Original implementation by David Montague (@dmontagu) | ||
https://github.com/dmontagu/fastapi-utils | ||
""" | ||
from __future__ import annotations | ||
|
||
import inspect | ||
from collections.abc import Callable | ||
from typing import ( | ||
Any, | ||
get_type_hints, | ||
TypeVar, | ||
) | ||
|
||
from fastapi import ( | ||
APIRouter, | ||
Depends, | ||
) | ||
from pydantic.typing import is_classvar | ||
from starlette.routing import ( | ||
Route, | ||
WebSocketRoute, | ||
) | ||
|
||
T = TypeVar("T") | ||
|
||
CBV_CLASS_KEY = "__cbv_class__" | ||
|
||
|
||
def cbv(router: APIRouter) -> Callable[[type[T]], type[T]]: | ||
""" | ||
This function returns a decorator that converts the decorated into a class-based view for the provided router. | ||
Any methods of the decorated class that are decorated as endpoints using the router provided to this function | ||
will become endpoints in the router. The first positional argument to the methods (typically `self`) | ||
will be populated with an instance created using FastAPI's dependency-injection. | ||
For more detail, review the documentation at | ||
https://fastapi-utils.davidmontague.xyz/user-guide/class-based-views/#the-cbv-decorator | ||
""" | ||
|
||
def decorator(cls: type[T]) -> type[T]: | ||
return _cbv(router, cls) | ||
|
||
return decorator | ||
|
||
|
||
def _cbv(router: APIRouter, cls: type[T]) -> type[T]: | ||
""" | ||
Replaces any methods of the provided class `cls` that are endpoints of routes in `router` with updated | ||
function calls that will properly inject an instance of `cls`. | ||
""" | ||
_init_cbv(cls) | ||
cbv_router = APIRouter() | ||
function_members = inspect.getmembers(cls, inspect.isfunction) | ||
functions_set = {func for _, func in function_members} | ||
cbv_routes = [ | ||
route | ||
for route in router.routes | ||
if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in functions_set | ||
] | ||
for route in cbv_routes: | ||
router.routes.remove(route) | ||
_update_cbv_route_endpoint_signature(cls, route) | ||
cbv_router.routes.append(route) | ||
router.include_router(cbv_router) | ||
return cls | ||
|
||
|
||
def _init_cbv(cls: type[Any]) -> None: | ||
""" | ||
Idempotently modifies the provided `cls`, performing the following modifications: | ||
* The `__init__` function is updated to set any class-annotated dependencies as instance attributes | ||
* The `__signature__` attribute is updated to indicate to FastAPI what arguments should be passed to the initializer | ||
""" | ||
if getattr(cls, CBV_CLASS_KEY, False): # pragma: no cover | ||
return # Already initialized | ||
old_init: Callable[..., Any] = cls.__init__ | ||
old_signature = inspect.signature(old_init) | ||
old_parameters = list(old_signature.parameters.values())[1:] # drop `self` parameter | ||
new_parameters = [ | ||
x for x in old_parameters if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) | ||
] | ||
dependency_names: list[str] = [] | ||
for name, hint in get_type_hints(cls).items(): | ||
if is_classvar(hint): | ||
continue | ||
parameter_kwargs = {"default": getattr(cls, name, Ellipsis)} | ||
dependency_names.append(name) | ||
new_parameters.append( | ||
inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=hint, **parameter_kwargs) | ||
) | ||
new_signature = old_signature.replace(parameters=new_parameters) | ||
|
||
def new_init(self: Any, *args: Any, **kwargs: Any) -> None: | ||
for dep_name in dependency_names: | ||
dep_value = kwargs.pop(dep_name) | ||
setattr(self, dep_name, dep_value) | ||
old_init(self, *args, **kwargs) | ||
|
||
setattr(cls, "__signature__", new_signature) # noqa: B010 | ||
setattr(cls, "__init__", new_init) # noqa: B010 | ||
setattr(cls, CBV_CLASS_KEY, True) | ||
|
||
|
||
def _update_cbv_route_endpoint_signature(cls: type[Any], route: Route | WebSocketRoute) -> None: | ||
""" | ||
Fixes the endpoint signature for a cbv route to ensure FastAPI performs dependency injection properly. | ||
""" | ||
old_endpoint = route.endpoint | ||
old_signature = inspect.signature(old_endpoint) | ||
old_parameters: list[inspect.Parameter] = list(old_signature.parameters.values()) | ||
old_first_parameter = old_parameters[0] | ||
new_first_parameter = old_first_parameter.replace(default=Depends(cls)) | ||
new_parameters = [new_first_parameter] + [ | ||
parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in old_parameters[1:] | ||
] | ||
new_signature = old_signature.replace(parameters=new_parameters) | ||
setattr(route.endpoint, "__signature__", new_signature) # noqa: B010 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
""" | ||
Original implementation by David Montague (@dmontagu) | ||
https://github.com/dmontagu/fastapi-utils | ||
""" | ||
from __future__ import annotations | ||
|
||
from typing import ( | ||
Any, | ||
ClassVar, | ||
Optional, | ||
) | ||
|
||
from fastapi import ( | ||
APIRouter, | ||
Depends, | ||
FastAPI, | ||
) | ||
from starlette.testclient import TestClient | ||
|
||
from galaxy.webapps.galaxy.api.cbv import cbv | ||
|
||
|
||
def test_cbv() -> None: | ||
router = APIRouter() | ||
|
||
def dependency() -> int: | ||
return 1 | ||
|
||
@cbv(router) | ||
class CBV: | ||
x: int = Depends(dependency) | ||
cx: ClassVar[int] = 1 | ||
cy: ClassVar[int] | ||
|
||
def __init__(self, z: int = Depends(dependency)): | ||
self.y = 1 | ||
self.z = z | ||
|
||
@router.get("/", response_model=int) | ||
def f(self) -> int: | ||
return self.cx + self.x + self.y + self.z | ||
|
||
@router.get("/classvar", response_model=bool) | ||
def g(self) -> bool: | ||
return hasattr(self, "cy") | ||
|
||
app = FastAPI() | ||
app.include_router(router) | ||
client = TestClient(app) | ||
response_1 = client.get("/") | ||
assert response_1.status_code == 200 | ||
assert response_1.content == b"4" | ||
|
||
response_2 = client.get("/classvar") | ||
assert response_2.status_code == 200 | ||
assert response_2.content == b"false" | ||
|
||
|
||
def test_method_order_preserved() -> None: | ||
router = APIRouter() | ||
|
||
@cbv(router) | ||
class TestCBV: | ||
@router.get("/test") | ||
def get_test(self) -> int: | ||
return 1 | ||
|
||
@router.get("/{item_id}") | ||
def get_item(self) -> int: # Alphabetically before `get_test` | ||
return 2 | ||
|
||
app = FastAPI() | ||
app.include_router(router) | ||
|
||
assert TestClient(app).get("/test").json() == 1 | ||
assert TestClient(app).get("/other").json() == 2 | ||
|
||
|
||
def test_multiple_decorators() -> None: | ||
router = APIRouter() | ||
|
||
@cbv(router) | ||
class RootHandler: | ||
@router.get("/items/?") | ||
@router.get("/items/{item_path:path}") | ||
@router.get("/database/{item_path:path}") | ||
def root(self, item_path: Optional[str] = None, item_query: Optional[str] = None) -> Any: # noqa: UP007 | ||
if item_path: | ||
return {"item_path": item_path} | ||
if item_query: | ||
return {"item_query": item_query} | ||
return [] | ||
|
||
client = TestClient(router) | ||
|
||
assert client.get("/items").json() == [] | ||
assert client.get("/items/1").json() == {"item_path": "1"} | ||
assert client.get("/database/abc").json() == {"item_path": "abc"} |