Skip to content

Commit

Permalink
Fix handling non-pydantic arguments in view function
Browse files Browse the repository at this point in the history
  • Loading branch information
cybski committed Feb 5, 2023
1 parent 4b9b7c6 commit 75b7824
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
12 changes: 8 additions & 4 deletions flask_pydantic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union

from flask import Response, current_app, jsonify, make_response, request
from werkzeug.routing import Rule
from pydantic import BaseModel, ValidationError
from pydantic.tools import parse_obj_as

Expand Down Expand Up @@ -68,12 +69,16 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel
raise ManyModelValidationError(ve.errors())


def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]:
def validate_path_params(
func: Callable, kwargs: dict, url_rule: Rule
) -> Tuple[dict, list]:
errors = []
validated = {}
for name, type_ in func.__annotations__.items():
if name in {"query", "body", "form", "return"}:
continue
if not url_rule or not name in url_rule.arguments:
continue
try:
value = parse_obj_as(type_, kwargs.get(name))
validated[name] = value
Expand Down Expand Up @@ -167,7 +172,7 @@ def decorate(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs):
q, b, f, err = None, None, None, {}
kwargs, path_err = validate_path_params(func, kwargs)
kwargs, path_err = validate_path_params(func, kwargs, request.url_rule)
if path_err:
err["path_params"] = path_err
query_in_kwargs = func.__annotations__.get("query")
Expand Down Expand Up @@ -245,8 +250,7 @@ def wrapper(*args, **kwargs):
"FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400
)
return make_response(
jsonify({"validation_error": err}),
status_code
jsonify({"validation_error": err}), status_code
)
res = func(*args, **kwargs)

Expand Down
28 changes: 28 additions & 0 deletions tests/func/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,17 @@ def int_path_param(obj_id):
return IdObj(id=obj_id)


@pytest.fixture
def app_with_non_pydantic_argument(app):
class IdObj(BaseModel):
id: str

@app.route("/path_param/<obj_id>/", methods=["GET"])
@validate()
def int_path_param(obj_id: str, non_pydantic: ValidationError = None):
return IdObj(id=obj_id)


@pytest.fixture
def app_with_custom_root_type(app):
class Person(BaseModel):
Expand Down Expand Up @@ -386,6 +397,23 @@ def test_str_param_passes(self, client):
assert response.json == expected_response


@pytest.mark.usefixtures("app_with_non_pydantic_argument")
class TestFunctionArgument:
def test_int_str_param_passes(self, client):
id_ = 12
expected_response = {"id": str(id_)}
response = client.get(f"/path_param/{id_}/")

assert response.json == expected_response

def test_str_param_passes(self, client):
id_ = "twelve"
expected_response = {"id": id_}
response = client.get(f"/path_param/{id_}/")

assert response.json == expected_response


@pytest.mark.usefixtures("app_with_optional_body")
class TestGetJsonParams:
def test_empty_body_fails(self, client):
Expand Down

0 comments on commit 75b7824

Please sign in to comment.