From 064976a9a897be6e557f15f24b07b8dd8aa45f75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Tue, 18 Apr 2023 16:09:22 +0200 Subject: [PATCH 1/7] Set isort profile to black MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marek Pikuła --- setup.cfg | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 3acda29..c828549 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,4 +19,7 @@ ignore = E121 E122 E123 E124 E125 E126 E127 E128 E711 E712 F811 F841 H803 E501 E exclude = .circleci, .github, - venv \ No newline at end of file + venv + +[isort] +profile = black From f6c2afc96dd43b36cc86cd594221eef84aafacfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Tue, 18 Apr 2023 18:36:07 +0200 Subject: [PATCH 2/7] Improve typing in the module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #65 - make it pass mypy tests (i.e., improve type hinting) - add mypy test to CI - add py.typed to setup so that other packages recognize it's typed - enable async view function decoration (via ensure_sync() as noted on https://flask.palletsprojects.com/en/latest/async-await/#extensions) Signed-off-by: Marek Pikuła --- .github/workflows/tests.yml | 20 +++ flask_pydantic/converters.py | 8 +- flask_pydantic/core.py | 245 +++++++++++++++++++++-------------- flask_pydantic/exceptions.py | 32 +++-- flask_pydantic/py.typed | 0 requirements/test.pip | 2 + setup.py | 1 + 7 files changed, 192 insertions(+), 116 deletions(-) create mode 100644 flask_pydantic/py.typed diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index aea03de..c165eee 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -3,6 +3,26 @@ name: Tests on: [push, pull_request] jobs: + mypy: + runs-on: ubuntu-latest + strategy: + fail-fast: false + + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.7 + uses: actions/setup-python@v1 + with: + python-version: "3.7" + - name: Install dependencies + if: steps.cache-pip.outputs.cache-hit != 'true' + run: | + python -m pip install --upgrade pip + pip install -r requirements/test.pip + - name: Run mypy satic check + run: | + python3 -m mypy flask_pydantic/ + build: runs-on: ${{ matrix.os }} strategy: diff --git a/flask_pydantic/converters.py b/flask_pydantic/converters.py index 0fed087..f128631 100644 --- a/flask_pydantic/converters.py +++ b/flask_pydantic/converters.py @@ -1,12 +1,12 @@ -from typing import Type +from typing import Dict, List, Type, Union from pydantic import BaseModel -from werkzeug.datastructures import ImmutableMultiDict +from werkzeug.datastructures import MultiDict def convert_query_params( - query_params: ImmutableMultiDict, model: Type[BaseModel] -) -> dict: + query_params: "MultiDict[str, str]", model: Type[BaseModel] +) -> Dict[str, Union[str, List[str]]]: """ group query parameters into lists if model defines them diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index 912415a..294f36a 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -1,7 +1,28 @@ +from collections.abc import Iterable from functools import wraps -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) + +from flask import Response, current_app, jsonify, request +from flask.typing import ResponseReturnValue, RouteCallable + +try: + from flask_restful import ( # type: ignore + original_flask_make_response as make_response, + ) +except ImportError: + from flask import make_response -from flask import Response, current_app, jsonify, make_response, request from pydantic import BaseModel, ValidationError from pydantic.tools import parse_obj_as @@ -13,22 +34,33 @@ ) from .exceptions import ValidationError as FailedValidation -try: - from flask_restful import original_flask_make_response as make_response -except ImportError: - pass +if TYPE_CHECKING: + from pydantic.error_wrappers import ErrorDict + + +ModelResponseReturnValue = Union[ResponseReturnValue, BaseModel] +ModelRouteCallable = Union[ + Callable[..., ModelResponseReturnValue], + Callable[..., Awaitable[ModelResponseReturnValue]], +] def make_json_response( - content: Union[BaseModel, Iterable[BaseModel]], + content: "Union[BaseModel, Iterable[BaseModel]]", status_code: int, by_alias: bool, exclude_none: bool = False, - many: bool = False, ) -> Response: """serializes model, creates JSON response with given status code""" - if many: - js = f"[{', '.join([model.json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]" + if not isinstance(content, BaseModel): + js = "[" + js += ", ".join( + [ + model.json(exclude_none=exclude_none, by_alias=by_alias) + for model in content + ] + ) + js += "]" else: js = content.json(exclude_none=exclude_none, by_alias=by_alias) response = make_response(js, status_code) @@ -56,9 +88,9 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel return [model(**fields) for fields in content] except TypeError: # iteration through `content` fails - err = [ + err: List["ErrorDict"] = [ { - "loc": ["root"], + "loc": ("root",), "msg": "is not an array of objects", "type": "type_error.array", } @@ -68,8 +100,10 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel raise ManyModelValidationError(ve.errors()) -def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: - errors = [] +def validate_path_params( + func: ModelRouteCallable, kwargs: Dict[str, Any] +) -> Tuple[Dict[str, Any], List["ErrorDict"]]: + errors: List["ErrorDict"] = [] validated = {} for name, type_ in func.__annotations__.items(): if name in {"query", "body", "form", "return"}: @@ -77,21 +111,42 @@ def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: try: value = parse_obj_as(type_, kwargs.get(name)) validated[name] = value - except ValidationError as e: - err = e.errors()[0] - err["loc"] = [name] + except ValidationError as error: + err = error.errors()[0] + err["loc"] = (name,) errors.append(err) kwargs = {**kwargs, **validated} return kwargs, errors -def get_body_dict(**params): - data = request.get_json(**params) +def get_body_dict(**params: Dict[str, Any]) -> Any: + data = request.get_json(**params) # type: ignore if data is None and params.get("silent"): return {} return data +def _ensure_model_kwarg( + kwarg_name: str, + from_validate: Optional[Type[BaseModel]], + func: ModelRouteCallable, +) -> Tuple[Optional[Type[BaseModel]], bool]: + """Get model information either from wrapped function or validate kwargs.""" + in_func_kwargs = func.__annotations__.get(kwarg_name) + if in_func_kwargs is None: + return from_validate, False + assert isinstance(in_func_kwargs, type) and issubclass( + in_func_kwargs, BaseModel + ), "Model in function arguments needs to be a BaseModel." + + # Ensure that the most "detailed" model is used. + if from_validate is None: + return in_func_kwargs, True + if issubclass(in_func_kwargs, from_validate): + return in_func_kwargs, True + return from_validate, True + + def validate( body: Optional[Type[BaseModel]] = None, query: Optional[Type[BaseModel]] = None, @@ -100,7 +155,7 @@ def validate( response_many: bool = False, request_body_many: bool = False, response_by_alias: bool = False, - get_json_params: Optional[dict] = None, + get_json_params: Optional[Dict[str, Any]] = None, form: Optional[Type[BaseModel]] = None, ): """ @@ -163,105 +218,93 @@ def test_route_kwargs(query:Query, body:Body, form:Form): -> that will render JSON response with serialized MyModel instance """ - def decorate(func: Callable) -> Callable: + def decorate(func: ModelRouteCallable) -> RouteCallable: @wraps(func) - def wrapper(*args, **kwargs): - q, b, f, err = None, None, None, {} - kwargs, path_err = validate_path_params(func, kwargs) - if path_err: - err["path_params"] = path_err - query_in_kwargs = func.__annotations__.get("query") - query_model = query_in_kwargs or query - if query_model: + def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> ResponseReturnValue: + q, b, f, err = None, None, None, FailedValidation() + func_kwargs, path_err = validate_path_params(func, kwargs) + if len(path_err) > 0: + err.path_params = path_err + query_model, query_in_kwargs = _ensure_model_kwarg("query", query, func) + if query_model is not None: query_params = convert_query_params(request.args, query_model) try: q = query_model(**query_params) except ValidationError as ve: - err["query_params"] = ve.errors() - body_in_kwargs = func.__annotations__.get("body") - body_model = body_in_kwargs or body - if body_model: + err.query_params = ve.errors() + body_model, body_in_kwargs = _ensure_model_kwarg("body", body, func) + if body_model is not None: body_params = get_body_dict(**(get_json_params or {})) - if "__root__" in body_model.__fields__: - try: - b = body_model(__root__=body_params).__root__ - except ValidationError as ve: - err["body_params"] = ve.errors() - elif request_body_many: - try: + try: + if "__root__" in body_model.__fields__: + b = body_model(__root__=body_params).__root__ # type: ignore + elif request_body_many: b = validate_many_models(body_model, body_params) - except ManyModelValidationError as e: - err["body_params"] = e.errors() - else: - try: + else: b = body_model(**body_params) - except TypeError: - content_type = request.headers.get("Content-Type", "").lower() - media_type = content_type.split(";")[0] - if media_type != "application/json": - return unsupported_media_type_response(content_type) - else: - raise JsonBodyParsingError() - except ValidationError as ve: - err["body_params"] = ve.errors() - form_in_kwargs = func.__annotations__.get("form") - form_model = form_in_kwargs or form - if form_model: + except (ValidationError, ManyModelValidationError) as error: + err.body_params = error.errors() + except TypeError as error: + content_type = request.headers.get("Content-Type", "").lower() + media_type = content_type.split(";")[0] + if media_type != "application/json": + return unsupported_media_type_response(content_type) + else: + raise JsonBodyParsingError() from error + form_model, form_in_kwargs = _ensure_model_kwarg("form", form, func) + if form_model is not None: form_params = request.form - if "__root__" in form_model.__fields__: - try: - f = form_model(__root__=form_params).__root__ - except ValidationError as ve: - err["form_params"] = ve.errors() - else: - try: + try: + if "__root__" in form_model.__fields__: + f = form_model(__root__=form_params).__root__ # type: ignore + else: f = form_model(**form_params) - except TypeError: - content_type = request.headers.get("Content-Type", "").lower() - media_type = content_type.split(";")[0] - if media_type != "multipart/form-data": - return unsupported_media_type_response(content_type) - else: - raise JsonBodyParsingError - except ValidationError as ve: - err["form_params"] = ve.errors() - request.query_params = q - request.body_params = b - request.form_params = f + except TypeError as error: + content_type = request.headers.get("Content-Type", "").lower() + media_type = content_type.split(";")[0] + if media_type != "multipart/form-data": + return unsupported_media_type_response(content_type) + else: + raise JsonBodyParsingError() from error + except ValidationError as ve: + err.form_params = ve.errors() + request.query_params = q # type: ignore + request.body_params = b # type: ignore + request.form_params = f # type: ignore if query_in_kwargs: - kwargs["query"] = q + func_kwargs["query"] = q if body_in_kwargs: - kwargs["body"] = b + func_kwargs["body"] = b if form_in_kwargs: - kwargs["form"] = f + func_kwargs["form"] = f - if err: + if err.check(): if current_app.config.get( "FLASK_PYDANTIC_VALIDATION_ERROR_RAISE", False ): - raise FailedValidation(**err) + raise err else: status_code = current_app.config.get( "FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400 ) return make_response( - jsonify({"validation_error": err}), - status_code + jsonify({"validation_error": err.to_dict()}), status_code ) - res = func(*args, **kwargs) + res: ModelResponseReturnValue = current_app.ensure_sync(func)( + *args, **func_kwargs + ) if response_many: - if is_iterable_of_models(res): - return make_json_response( - res, - on_success_status, - by_alias=response_by_alias, - exclude_none=exclude_none, - many=True, - ) - else: + if not is_iterable_of_models(res): raise InvalidIterableOfModelsException(res) + return make_json_response( + res, # type: ignore # Iterability and type is ensured above. + on_success_status, + by_alias=response_by_alias, + exclude_none=exclude_none, + ) + if isinstance(res, BaseModel): return make_json_response( res, @@ -275,23 +318,29 @@ def wrapper(*args, **kwargs): and len(res) in [2, 3] and isinstance(res[0], BaseModel) ): - headers = None + headers: Optional[ + Union[Dict[str, Any], Tuple[Any, ...], List[Any]] + ] = None status = on_success_status if isinstance(res[1], (dict, tuple, list)): headers = res[1] - elif len(res) == 3 and isinstance(res[2], (dict, tuple, list)): - status = res[1] - headers = res[2] - else: + elif isinstance(res[1], int): status = res[1] + # Following type ignores should be fixed once + # https://github.com/python/mypy/issues/1178 is fixed. + if len(res) == 3 and isinstance( + res[2], (dict, tuple, list) # type: ignore[misc] + ): + headers = res[2] # type: ignore[misc] + ret = make_json_response( res[0], status, exclude_none=exclude_none, by_alias=response_by_alias, ) - if headers: + if headers is not None: ret.headers.update(headers) return ret diff --git a/flask_pydantic/exceptions.py b/flask_pydantic/exceptions.py index e214cc1..14eb7ba 100644 --- a/flask_pydantic/exceptions.py +++ b/flask_pydantic/exceptions.py @@ -1,4 +1,8 @@ -from typing import List, Optional +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any, List, Optional + +if TYPE_CHECKING: + from pydantic.error_wrappers import ErrorDict class BaseFlaskPydanticException(Exception): @@ -24,7 +28,7 @@ class ManyModelValidationError(BaseFlaskPydanticException): """This exception is raised if there is a failure during validation of many models in an iterable""" - def __init__(self, errors: List[dict], *args): + def __init__(self, errors: List["ErrorDict"], *args: Any): self._errors = errors super().__init__(*args) @@ -32,19 +36,19 @@ def errors(self): return self._errors +@dataclass class ValidationError(BaseFlaskPydanticException): """This exception is raised if there is a failure during validation if the user has configured an exception to be raised instead of a response""" - def __init__( - self, - body_params: Optional[List[dict]] = None, - form_params: Optional[List[dict]] = None, - path_params: Optional[List[dict]] = None, - query_params: Optional[List[dict]] = None, - ): - super().__init__() - self.body_params = body_params - self.form_params = form_params - self.path_params = path_params - self.query_params = query_params + body_params: Optional[List["ErrorDict"]] = None + form_params: Optional[List["ErrorDict"]] = None + path_params: Optional[List["ErrorDict"]] = None + query_params: Optional[List["ErrorDict"]] = None + + def check(self) -> bool: + """Check if any param resulted in error.""" + return any(value is not None for _, value in asdict(self).items()) + + def to_dict(self): + return {key: value for key, value in asdict(self).items() if value is not None} diff --git a/flask_pydantic/py.typed b/flask_pydantic/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/requirements/test.pip b/requirements/test.pip index 14530b2..a959c83 100644 --- a/requirements/test.pip +++ b/requirements/test.pip @@ -6,3 +6,5 @@ pytest-flake8 pytest-coverage pytest-black pytest-mock + +mypy diff --git a/setup.py b/setup.py index 09be0eb..1a4c595 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ def find_version(file_path: Path = VERSION_FILE_PATH) -> str: long_description=README, long_description_content_type="text/markdown", packages=["flask_pydantic"], + package_data={"flask_pydantic": ["py.typed"]}, install_requires=list(get_install_requires()), python_requires=">=3.6", classifiers=[ From 9fe323eb388b2ce8c433a12f408cf42ea94019ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Tue, 18 Apr 2023 16:56:50 +0000 Subject: [PATCH 3/7] Update GitHub actions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marek Pikuła --- .github/workflows/tests.yml | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c165eee..26b9b8a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -9,13 +9,16 @@ jobs: fail-fast: false steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 - name: Set up Python 3.7 - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: "3.7" + cache: pip + cache-dependency-path: | + requirements/base.pip + requirements/test.pip - name: Install dependencies - if: steps.cache-pip.outputs.cache-hit != 'true' run: | python -m pip install --upgrade pip pip install -r requirements/test.pip @@ -32,13 +35,16 @@ jobs: os: [ubuntu-latest, macOS-latest] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: | + requirements/base.pip + requirements/test.pip - name: Install dependencies - if: steps.cache-pip.outputs.cache-hit != 'true' run: | python -m pip install --upgrade pip pip install -r requirements/test.pip From 8d70d0f88a4cdac9a05d4f49906524c701902cd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Tue, 18 Apr 2023 20:52:46 +0200 Subject: [PATCH 4/7] Run linters in a separate matrix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marek Pikuła --- .github/workflows/tests.yml | 14 ++++++++------ requirements/test.pip | 2 ++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 26b9b8a..ffb0338 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -3,10 +3,15 @@ name: Tests on: [push, pull_request] jobs: - mypy: + lint: runs-on: ubuntu-latest strategy: fail-fast: false + matrix: + tool: + - flake8 + - mypy + - black --diff --check steps: - uses: actions/checkout@v3 @@ -22,9 +27,9 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements/test.pip - - name: Run mypy satic check + - name: Run ${{ matrix.tool }} static check run: | - python3 -m mypy flask_pydantic/ + python3 -m ${{ matrix.tool }} flask_pydantic/ tests/ build: runs-on: ${{ matrix.os }} @@ -51,6 +56,3 @@ jobs: - name: Test with pytest run: | python3 -m pytest - - name: Lint with flake8 - run: | - flake8 . diff --git a/requirements/test.pip b/requirements/test.pip index a959c83..d5f68a5 100644 --- a/requirements/test.pip +++ b/requirements/test.pip @@ -7,4 +7,6 @@ pytest-coverage pytest-black pytest-mock +black +flake8 mypy From 163dff9c9d128064166e401491c73bd63f1f1b36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Tue, 18 Apr 2023 20:56:14 +0200 Subject: [PATCH 5/7] Make it possible to use TypeVars for model hinting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marek Pikuła --- flask_pydantic/core.py | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index 294f36a..d7690c9 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -1,3 +1,4 @@ +import inspect from collections.abc import Iterable from functools import wraps from typing import ( @@ -10,7 +11,9 @@ Optional, Tuple, Type, + TypeVar, Union, + get_args, ) from flask import Response, current_app, jsonify, request @@ -126,25 +129,42 @@ def get_body_dict(**params: Dict[str, Any]) -> Any: return data +def _get_type_generic(hint: Any): + """Extract type information from bound TypeVar or Type[TypeVar].""" + if isinstance(hint, TypeVar): + assert ( + getattr(hint, "__bound__", None) is not None + ), "If using TypeVar, you need to specify bound model." + return getattr(hint, "__bound__") + + args = get_args(hint) + if len(args) > 0 and isinstance(args[0], TypeVar): + assert ( + getattr(args[0], "__bound__", None) is not None + ), "If using TypeVar, you need to specify bound model." + return getattr(args[0], "__bound__") + + return hint + + def _ensure_model_kwarg( kwarg_name: str, from_validate: Optional[Type[BaseModel]], func: ModelRouteCallable, ) -> Tuple[Optional[Type[BaseModel]], bool]: """Get model information either from wrapped function or validate kwargs.""" - in_func_kwargs = func.__annotations__.get(kwarg_name) - if in_func_kwargs is None: - return from_validate, False - assert isinstance(in_func_kwargs, type) and issubclass( - in_func_kwargs, BaseModel - ), "Model in function arguments needs to be a BaseModel." + func_spec = inspect.getfullargspec(func) + in_func_arg = kwarg_name in func_spec.args or kwarg_name in func_spec.kwonlyargs + from_func = _get_type_generic(func_spec.annotations.get(kwarg_name)) + if from_func is None or not isinstance(from_func, type): + return _get_type_generic(from_validate), in_func_arg # Ensure that the most "detailed" model is used. if from_validate is None: - return in_func_kwargs, True - if issubclass(in_func_kwargs, from_validate): - return in_func_kwargs, True - return from_validate, True + return from_func, in_func_arg + if issubclass(from_func, from_validate): + return from_func, in_func_arg + return from_validate, in_func_arg def validate( From 3e350941ada7cf8db923ac8c26d2dd0ec55779c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Tue, 18 Apr 2023 20:58:05 +0200 Subject: [PATCH 6/7] Fix typing in tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marek Pikuła --- tests/conftest.py | 78 +++++++++++++++++++++++++++-------------- tests/func/test_app.py | 4 +-- tests/unit/test_core.py | 12 ++++--- 3 files changed, 60 insertions(+), 34 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e2b8ed1..6b94bb6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,15 @@ -from typing import List, Optional, Type +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar import pytest from flask import Flask, request -from flask_pydantic import validate from pydantic import BaseModel +from pydantic.generics import GenericModel + +from flask_pydantic import validate @pytest.fixture -def posts() -> List[dict]: +def posts() -> List[Dict[str, Any]]: return [ {"title": "title 1", "text": "random text", "views": 1}, {"title": "2", "text": "another text", "views": 2}, @@ -16,30 +18,33 @@ def posts() -> List[dict]: ] +class Query(BaseModel): + limit: int = 2 + min_views: Optional[int] + + @pytest.fixture def query_model() -> Type[BaseModel]: - class Query(BaseModel): - limit: int = 2 - min_views: Optional[int] - return Query +class Body(BaseModel): + search_term: str + exclude: Optional[str] + + @pytest.fixture def body_model() -> Type[BaseModel]: - class Body(BaseModel): - search_term: str - exclude: Optional[str] - return Body +class Form(BaseModel): + search_term: str + exclude: Optional[str] + + @pytest.fixture def form_model() -> Type[BaseModel]: - class Form(BaseModel): - search_term: str - exclude: Optional[str] - return Form @@ -53,23 +58,30 @@ class Post(BaseModel): return Post -@pytest.fixture -def response_model(post_model: BaseModel) -> Type[BaseModel]: - class Response(BaseModel): - results: List[post_model] - count: int +PostModelT = TypeVar("PostModelT", bound=BaseModel) - return Response +class Response(GenericModel, Generic[PostModelT]): + results: List[PostModelT] + count: int -def is_excluded(post: dict, exclude: Optional[str]) -> bool: + +@pytest.fixture +def response_model(post_model: PostModelT) -> Type[Response[PostModelT]]: + return Response[PostModelT] + + +def is_excluded(post: Dict[str, Any], exclude: Optional[str]) -> bool: if exclude is None: return False return exclude in post["title"] or exclude in post["text"] def pass_search( - post: dict, search_term: str, exclude: Optional[str], min_views: Optional[int] + post: Dict[str, Any], + search_term: str, + exclude: Optional[str], + min_views: Optional[int], ) -> bool: return ( (search_term in post["title"] or search_term in post["text"]) @@ -78,8 +90,20 @@ def pass_search( ) +QueryModelT = TypeVar("QueryModelT", bound=Query) +BodyModelT = TypeVar("BodyModelT", bound=Body) +FormModelT = TypeVar("FormModelT", bound=Form) + + @pytest.fixture -def app(posts, response_model, query_model, body_model, post_model, form_model): +def app( + posts: List[Dict[str, Any]], + response_model: Type[Response[PostModelT]], + query_model: Type[QueryModelT], + body_model: Type[BodyModelT], + post_model: Type[PostModelT], + form_model: Type[FormModelT], +) -> Flask: app = Flask("test_app") app.config["DEBUG"] = True app.config["TESTING"] = True @@ -88,7 +112,7 @@ def app(posts, response_model, query_model, body_model, post_model, form_model): @validate(query=query_model, body=body_model) def post(): query_params = request.query_params - body = request.body_params + body: BodyModelT = request.body_params results = [ post_model(**p) for p in posts @@ -98,7 +122,7 @@ def post(): @app.route("/search/kwargs", methods=["POST"]) @validate() - def post_kwargs(query: query_model, body: body_model): + def post_kwargs(query: Type[QueryModelT], body: Type[BodyModelT]): results = [ post_model(**p) for p in posts @@ -108,7 +132,7 @@ def post_kwargs(query: query_model, body: body_model): @app.route("/search/form/kwargs", methods=["POST"]) @validate() - def post_kwargs_form(query: query_model, form: form_model): + def post_kwargs_form(query: Type[QueryModelT], form: Type[BodyModelT]): results = [ post_model(**p) for p in posts diff --git a/tests/func/test_app.py b/tests/func/test_app.py index a01edd7..8660955 100644 --- a/tests/func/test_app.py +++ b/tests/func/test_app.py @@ -97,7 +97,7 @@ class PersonBulk(BaseModel): @app.route("/root_type", methods=["POST"]) @validate() def root_type(body: PersonBulk): - return {"number": len(body)} + return {"number": len(body)} # type: ignore @pytest.fixture @@ -293,7 +293,7 @@ def test_custom_headers(client): @pytest.mark.usefixtures("app_with_custom_headers_status") -def test_custom_headers(client): +def test_custom_headers_201_status(client): response = client.get("/custom_headers_status") assert response.json == {"test": 1} assert response.status_code == 201 diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index 62b1baf..22eb2d9 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -16,7 +16,7 @@ class ValidateParams(NamedTuple): body_model: Optional[Type[BaseModel]] = None query_model: Optional[Type[BaseModel]] = None form_model: Optional[Type[BaseModel]] = None - response_model: Type[BaseModel] = None + response_model: Optional[Type[BaseModel]] = None on_success_status: int = 200 request_query: ImmutableMultiDict = ImmutableMultiDict({}) request_body: Union[dict, List[dict]] = {} @@ -47,7 +47,7 @@ class RequestBodyModel(BaseModel): class FormModel(BaseModel): f1: int - f2: str = None + f2: Optional[str] = None class RequestBodyModelRoot(BaseModel): @@ -228,10 +228,11 @@ def test_validate_kwargs(self, mocker, request_ctx, parameters: ValidateParams): mock_request.form = parameters.request_form def f( - body: parameters.body_model, - query: parameters.query_model, - form: parameters.form_model, + body: parameters.body_model, # type: ignore + query: parameters.query_model, # type: ignore + form: parameters.form_model, # type: ignore ): + assert parameters.response_model is not None return parameters.response_model( **body.dict(), **query.dict(), **form.dict() ) @@ -408,6 +409,7 @@ def f() -> Any: body = mock_request.body_params.dict() if mock_request.query_params: query = mock_request.query_params.dict() + assert parameters.response_model is not None return parameters.response_model(**body, **query) response = validate( From 136fb94d29c099a32ffbc6c640e2342c2346cbd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Tue, 18 Apr 2023 21:00:57 +0200 Subject: [PATCH 7/7] Fix get_args for Python 3.7 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marek Pikuła --- flask_pydantic/core.py | 2 +- requirements/base.pip | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index d7690c9..1bb0ba2 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -13,11 +13,11 @@ Type, TypeVar, Union, - get_args, ) from flask import Response, current_app, jsonify, request from flask.typing import ResponseReturnValue, RouteCallable +from typing_extensions import get_args try: from flask_restful import ( # type: ignore diff --git a/requirements/base.pip b/requirements/base.pip index 6640194..0b1c0de 100644 --- a/requirements/base.pip +++ b/requirements/base.pip @@ -1,2 +1,3 @@ Flask pydantic>=1.7 +typing-extensions>=4.1.1