From 75b78241c8e7da16dd0c22f94c67c403c79c0603 Mon Sep 17 00:00:00 2001 From: Piotr Cybulski Date: Sun, 5 Feb 2023 22:52:54 +0100 Subject: [PATCH] Fix handling non-pydantic arguments in view function --- flask_pydantic/core.py | 12 ++++++++---- tests/func/test_app.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index 912415a..e5765ac 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union from flask import Response, current_app, jsonify, make_response, request +from werkzeug.routing import Rule from pydantic import BaseModel, ValidationError from pydantic.tools import parse_obj_as @@ -68,12 +69,16 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel raise ManyModelValidationError(ve.errors()) -def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: +def validate_path_params( + func: Callable, kwargs: dict, url_rule: Rule +) -> Tuple[dict, list]: errors = [] validated = {} for name, type_ in func.__annotations__.items(): if name in {"query", "body", "form", "return"}: continue + if not url_rule or not name in url_rule.arguments: + continue try: value = parse_obj_as(type_, kwargs.get(name)) validated[name] = value @@ -167,7 +172,7 @@ def decorate(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): q, b, f, err = None, None, None, {} - kwargs, path_err = validate_path_params(func, kwargs) + kwargs, path_err = validate_path_params(func, kwargs, request.url_rule) if path_err: err["path_params"] = path_err query_in_kwargs = func.__annotations__.get("query") @@ -245,8 +250,7 @@ def wrapper(*args, **kwargs): "FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400 ) return make_response( - jsonify({"validation_error": err}), - status_code + jsonify({"validation_error": err}), status_code ) res = func(*args, **kwargs) diff --git a/tests/func/test_app.py b/tests/func/test_app.py index a01edd7..7a720b7 100644 --- a/tests/func/test_app.py +++ b/tests/func/test_app.py @@ -85,6 +85,17 @@ def int_path_param(obj_id): return IdObj(id=obj_id) +@pytest.fixture +def app_with_non_pydantic_argument(app): + class IdObj(BaseModel): + id: str + + @app.route("/path_param//", methods=["GET"]) + @validate() + def int_path_param(obj_id: str, non_pydantic: ValidationError = None): + return IdObj(id=obj_id) + + @pytest.fixture def app_with_custom_root_type(app): class Person(BaseModel): @@ -386,6 +397,23 @@ def test_str_param_passes(self, client): assert response.json == expected_response +@pytest.mark.usefixtures("app_with_non_pydantic_argument") +class TestFunctionArgument: + def test_int_str_param_passes(self, client): + id_ = 12 + expected_response = {"id": str(id_)} + response = client.get(f"/path_param/{id_}/") + + assert response.json == expected_response + + def test_str_param_passes(self, client): + id_ = "twelve" + expected_response = {"id": id_} + response = client.get(f"/path_param/{id_}/") + + assert response.json == expected_response + + @pytest.mark.usefixtures("app_with_optional_body") class TestGetJsonParams: def test_empty_body_fails(self, client):