Skip to content

Commit

Permalink
Improve typing in the module
Browse files Browse the repository at this point in the history
Fixes bauerji#65

- make it pass mypy tests (i.e., improve type hinting)
- add mypy test to CI
- add py.typed to setup so that other packages recognize it's typed
- enable async view function decoration (via ensure_sync() as noted on
  https://flask.palletsprojects.com/en/latest/async-await/#extensions)

Signed-off-by: Marek Pikuła <[email protected]>
  • Loading branch information
MarekPikula committed Apr 18, 2023
1 parent 064976a commit 24b321e
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 115 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@ name: Tests
on: [push, pull_request]

jobs:
mypy:
runs-on: ubuntu-latest
strategy:
fail-fast: false

steps:
- uses: actions/checkout@v1
- name: Set up Python 3.7
uses: actions/setup-python@v1
with:
python-version: "3.7"
- name: Install dependencies
if: steps.cache-pip.outputs.cache-hit != 'true'
run: |
python -m pip install --upgrade pip
pip install -r requirements/test.pip
- name: Run mypy satic check
run: |
python3 -m mypy flask_pydantic/
build:
runs-on: ${{ matrix.os }}
strategy:
Expand Down
8 changes: 4 additions & 4 deletions flask_pydantic/converters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Type
from typing import Dict, List, Type, Union

from pydantic import BaseModel
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.datastructures import MultiDict


def convert_query_params(
query_params: ImmutableMultiDict, model: Type[BaseModel]
) -> dict:
query_params: "MultiDict[str, str]", model: Type[BaseModel]
) -> Dict[str, Union[str, List[str]]]:
"""
group query parameters into lists if model defines them
Expand Down
243 changes: 146 additions & 97 deletions flask_pydantic/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
from collections.abc import Iterable
from functools import wraps
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)

from flask import Response, current_app, jsonify, request
from flask.typing import ResponseReturnValue, RouteCallable

try:
from flask_restful import ( # type: ignore
original_flask_make_response as make_response,
)
except ImportError:
from flask import make_response

from flask import Response, current_app, jsonify, make_response, request
from pydantic import BaseModel, ValidationError
from pydantic.tools import parse_obj_as

Expand All @@ -13,22 +34,33 @@
)
from .exceptions import ValidationError as FailedValidation

try:
from flask_restful import original_flask_make_response as make_response
except ImportError:
pass
if TYPE_CHECKING:
from pydantic.error_wrappers import ErrorDict


ModelResponseReturnValue = Union[ResponseReturnValue, BaseModel]
ModelRouteCallable = Union[
Callable[..., ModelResponseReturnValue],
Callable[..., Awaitable[ModelResponseReturnValue]],
]


def make_json_response(
content: Union[BaseModel, Iterable[BaseModel]],
status_code: int,
by_alias: bool,
exclude_none: bool = False,
many: bool = False,
) -> Response:
"""serializes model, creates JSON response with given status code"""
if many:
js = f"[{', '.join([model.json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]"
if not isinstance(content, BaseModel):
js = "["
js += ", ".join(
[
model.json(exclude_none=exclude_none, by_alias=by_alias)
for model in content
]
)
js += "]"
else:
js = content.json(exclude_none=exclude_none, by_alias=by_alias)
response = make_response(js, status_code)
Expand Down Expand Up @@ -56,9 +88,9 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel
return [model(**fields) for fields in content]
except TypeError:
# iteration through `content` fails
err = [
err: List["ErrorDict"] = [
{
"loc": ["root"],
"loc": ("root",),
"msg": "is not an array of objects",
"type": "type_error.array",
}
Expand All @@ -68,30 +100,53 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel
raise ManyModelValidationError(ve.errors())


def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]:
errors = []
def validate_path_params(
func: ModelRouteCallable, kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], List["ErrorDict"]]:
errors: List["ErrorDict"] = []
validated = {}
for name, type_ in func.__annotations__.items():
if name in {"query", "body", "form", "return"}:
continue
try:
value = parse_obj_as(type_, kwargs.get(name))
validated[name] = value
except ValidationError as e:
err = e.errors()[0]
err["loc"] = [name]
except ValidationError as error:
err = error.errors()[0]
err["loc"] = (name,)
errors.append(err)
kwargs = {**kwargs, **validated}
return kwargs, errors


def get_body_dict(**params):
data = request.get_json(**params)
def get_body_dict(**params: Dict[str, Any]) -> Any:
data = request.get_json(**params) # type: ignore
if data is None and params.get("silent"):
return {}
return data


def _ensure_model_kwarg(
kwarg_name: str,
from_validate: Optional[Type[BaseModel]],
func: ModelRouteCallable,
) -> Tuple[Optional[Type[BaseModel]], bool]:
"""Get model information either from wrapped function or validate kwargs."""
in_func_kwargs = func.__annotations__.get(kwarg_name)
if in_func_kwargs is None:
return from_validate, False
assert isinstance(in_func_kwargs, type) and issubclass(
in_func_kwargs, BaseModel
), "Model in function arguments needs to be a BaseModel."

# Ensure that the most "detailed" model is used.
if from_validate is None:
return in_func_kwargs, True
if issubclass(in_func_kwargs, from_validate):
return in_func_kwargs, True
return from_validate, True


def validate(
body: Optional[Type[BaseModel]] = None,
query: Optional[Type[BaseModel]] = None,
Expand All @@ -100,7 +155,7 @@ def validate(
response_many: bool = False,
request_body_many: bool = False,
response_by_alias: bool = False,
get_json_params: Optional[dict] = None,
get_json_params: Optional[Dict[str, Any]] = None,
form: Optional[Type[BaseModel]] = None,
):
"""
Expand Down Expand Up @@ -163,105 +218,93 @@ def test_route_kwargs(query:Query, body:Body, form:Form):
-> that will render JSON response with serialized MyModel instance
"""

def decorate(func: Callable) -> Callable:
def decorate(func: ModelRouteCallable) -> RouteCallable:
@wraps(func)
def wrapper(*args, **kwargs):
q, b, f, err = None, None, None, {}
kwargs, path_err = validate_path_params(func, kwargs)
if path_err:
err["path_params"] = path_err
query_in_kwargs = func.__annotations__.get("query")
query_model = query_in_kwargs or query
if query_model:
def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> ResponseReturnValue:
q, b, f, err = None, None, None, FailedValidation()
func_kwargs, path_err = validate_path_params(func, kwargs)
if len(path_err) > 0:
err.path_params = path_err
query_model, query_in_kwargs = _ensure_model_kwarg("query", query, func)
if query_model is not None:
query_params = convert_query_params(request.args, query_model)
try:
q = query_model(**query_params)
except ValidationError as ve:
err["query_params"] = ve.errors()
body_in_kwargs = func.__annotations__.get("body")
body_model = body_in_kwargs or body
if body_model:
err.query_params = ve.errors()
body_model, body_in_kwargs = _ensure_model_kwarg("body", body, func)
if body_model is not None:
body_params = get_body_dict(**(get_json_params or {}))
if "__root__" in body_model.__fields__:
try:
b = body_model(__root__=body_params).__root__
except ValidationError as ve:
err["body_params"] = ve.errors()
elif request_body_many:
try:
try:
if "__root__" in body_model.__fields__:
b = body_model(__root__=body_params).__root__ # type: ignore
elif request_body_many:
b = validate_many_models(body_model, body_params)
except ManyModelValidationError as e:
err["body_params"] = e.errors()
else:
try:
else:
b = body_model(**body_params)
except TypeError:
content_type = request.headers.get("Content-Type", "").lower()
media_type = content_type.split(";")[0]
if media_type != "application/json":
return unsupported_media_type_response(content_type)
else:
raise JsonBodyParsingError()
except ValidationError as ve:
err["body_params"] = ve.errors()
form_in_kwargs = func.__annotations__.get("form")
form_model = form_in_kwargs or form
if form_model:
except (ValidationError, ManyModelValidationError) as error:
err.body_params = error.errors()
except TypeError as error:
content_type = request.headers.get("Content-Type", "").lower()
media_type = content_type.split(";")[0]
if media_type != "application/json":
return unsupported_media_type_response(content_type)
else:
raise JsonBodyParsingError() from error
form_model, form_in_kwargs = _ensure_model_kwarg("form", form, func)
if form_model is not None:
form_params = request.form
if "__root__" in form_model.__fields__:
try:
f = form_model(__root__=form_params).__root__
except ValidationError as ve:
err["form_params"] = ve.errors()
else:
try:
try:
if "__root__" in form_model.__fields__:
f = form_model(__root__=form_params).__root__ # type: ignore
else:
f = form_model(**form_params)
except TypeError:
content_type = request.headers.get("Content-Type", "").lower()
media_type = content_type.split(";")[0]
if media_type != "multipart/form-data":
return unsupported_media_type_response(content_type)
else:
raise JsonBodyParsingError
except ValidationError as ve:
err["form_params"] = ve.errors()
request.query_params = q
request.body_params = b
request.form_params = f
except TypeError as error:
content_type = request.headers.get("Content-Type", "").lower()
media_type = content_type.split(";")[0]
if media_type != "multipart/form-data":
return unsupported_media_type_response(content_type)
else:
raise JsonBodyParsingError() from error
except ValidationError as ve:
err.form_params = ve.errors()
request.query_params = q # type: ignore
request.body_params = b # type: ignore
request.form_params = f # type: ignore
if query_in_kwargs:
kwargs["query"] = q
func_kwargs["query"] = q
if body_in_kwargs:
kwargs["body"] = b
func_kwargs["body"] = b
if form_in_kwargs:
kwargs["form"] = f
func_kwargs["form"] = f

if err:
if err.check():
if current_app.config.get(
"FLASK_PYDANTIC_VALIDATION_ERROR_RAISE", False
):
raise FailedValidation(**err)
raise err
else:
status_code = current_app.config.get(
"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400
)
return make_response(
jsonify({"validation_error": err}),
status_code
jsonify({"validation_error": err.to_dict()}), status_code
)
res = func(*args, **kwargs)
res: ModelResponseReturnValue = current_app.ensure_sync(func)(
*args, **func_kwargs
)

if response_many:
if is_iterable_of_models(res):
return make_json_response(
res,
on_success_status,
by_alias=response_by_alias,
exclude_none=exclude_none,
many=True,
)
else:
if not is_iterable_of_models(res):
raise InvalidIterableOfModelsException(res)

return make_json_response(
res, # type: ignore # Iterability and type is ensured above.
on_success_status,
by_alias=response_by_alias,
exclude_none=exclude_none,
)

if isinstance(res, BaseModel):
return make_json_response(
res,
Expand All @@ -275,23 +318,29 @@ def wrapper(*args, **kwargs):
and len(res) in [2, 3]
and isinstance(res[0], BaseModel)
):
headers = None
headers: Optional[
Union[Dict[str, Any], Tuple[Any, ...], List[Any]]
] = None
status = on_success_status
if isinstance(res[1], (dict, tuple, list)):
headers = res[1]
elif len(res) == 3 and isinstance(res[2], (dict, tuple, list)):
status = res[1]
headers = res[2]
else:
elif isinstance(res[1], int):
status = res[1]

# Following type ignores should be fixed once
# https://github.com/python/mypy/issues/1178 is fixed.
if len(res) == 3 and isinstance(
res[2], (dict, tuple, list) # type: ignore[misc]
):
headers = res[2] # type: ignore[misc]

ret = make_json_response(
res[0],
status,
exclude_none=exclude_none,
by_alias=response_by_alias,
)
if headers:
if headers is not None:
ret.headers.update(headers)
return ret

Expand Down
Loading

0 comments on commit 24b321e

Please sign in to comment.