Skip to content

Commit

Permalink
Catch and expose errors from ML model (#8)
Browse files Browse the repository at this point in the history
* Catch, format, and return errors from ML Model

Catch errors raised by ML model's predict method, format them as JSON,
and return as the body of the `500` status response.

* Add GitHub action for static analysis to CI

Extend CI workflow to include static analysis with flake8 and mypy.
  • Loading branch information
bloomonkey authored Jan 31, 2023
1 parent 52e4f6c commit da036be
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 16 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,26 @@ on:
branches: [ main ]

jobs:
static-analysis:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install poetry
poetry install
- name: Run flake8
run: |
poetry run flake8 fastapi_mlflow tests
- name: Run mypy
run: |
poetry run mypy fastapi_mlflow tests
build:

runs-on: ubuntu-latest
Expand Down
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- Catch errors raised by ML model's predict method, format them as JSON, and return as the body of the 500 status response.

## [0.3.2] - 2022-01-13
### Fixed
Expand All @@ -13,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.3.1] - 2022-11-24
### Fixed
- Assume that prediction output field(s) may be nullable. It's not uncommon to want ML models to return null or nan values at inference time (e.g. extrapolation outside of training data range).
- Coerce nan values in output to null in JSON
- Coerce nan values in output to null in JSON.

## [0.3.0] - 2022-10-07
### Added
Expand Down
2 changes: 0 additions & 2 deletions fastapi_mlflow/_mlflow_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from datetime import date, datetime
from typing import Dict, Optional, Union

from mlflow.types import Schema # type: ignore

MLFLOW_SIGNATURE_TO_PYTHON_TYPE_MAP = {
"boolean": bool,
"integer": int,
Expand Down
18 changes: 15 additions & 3 deletions fastapi_mlflow/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
"""
from inspect import signature

from fastapi import FastAPI
from mlflow.pyfunc import PyFuncModel # type: ignore
from fastapi import (
FastAPI,
Request,
)
from fastapi.responses import JSONResponse

from fastapi_mlflow.predictors import build_predictor
from mlflow.pyfunc import PyFuncModel # type: ignore
from fastapi_mlflow.predictors import build_predictor, PyFuncModelPredictError


def build_app(pyfunc_model: PyFuncModel) -> FastAPI:
Expand All @@ -23,4 +27,12 @@ def build_app(pyfunc_model: PyFuncModel) -> FastAPI:
response_model=response_model,
methods=["POST"],
)

@app.exception_handler(Exception)
def handle_exception(_: Request, exc: PyFuncModelPredictError):
return JSONResponse(
status_code=500,
content=exc.to_dict(),
)

return app
19 changes: 16 additions & 3 deletions fastapi_mlflow/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
Copyright (C) 2022, Auto Trader UK
"""
from typing import Any, Callable, List, no_type_check, Union, Dict
from typing import Any, Callable, List, no_type_check, Dict

import numpy as np
import numpy.typing as npt
import pandas as pd
from mlflow.pyfunc import PyFuncModel # type: ignore
from pydantic import BaseModel, create_model
Expand All @@ -24,6 +23,16 @@
)


class PyFuncModelPredictError(Exception):
def __init__(self, exc: Exception):
super().__init__()
self.error_type_name = exc.__class__.__name__
self.message = str(exc)

def to_dict(self):
return {"name": self.error_type_name, "message": self.message}


@no_type_check # Some types here are too dynamic for type checking
def build_predictor(model: PyFuncModel) -> Callable[[BaseModel], Any]:
"""Build and return a function that wraps the mlflow model.
Expand Down Expand Up @@ -69,7 +78,11 @@ def request_to_dataframe(request: Request) -> pd.DataFrame:
return df

def predictor(request: Request) -> Response:
predictions = model.predict(request_to_dataframe(request))
try:
predictions = model.predict(request_to_dataframe(request))
except Exception as exc:
raise PyFuncModelPredictError(exc) from exc

response_data = convert_predictions_to_python(predictions)
return Response(data=response_data)

Expand Down
35 changes: 34 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pandas as pd
import pytest
from mlflow.models import infer_signature # type: ignore
from mlflow.pyfunc import PyFuncModel, PythonModel, PythonModelContext
from mlflow.pyfunc import PyFuncModel, PythonModel, PythonModelContext # type: ignore
from mlflow.pyfunc import load_model as pyfunc_load_model # type: ignore
from mlflow.pyfunc import save_model as pyfunc_save_model

Expand Down Expand Up @@ -93,6 +93,13 @@ def predict(
)


class ExceptionRaiser(PythonModel):
def predict(
self, context: PythonModelContext, model_input: pd.DataFrame
) -> pd.DataFrame:
raise ValueError("I always raise an error!")


@pytest.fixture(scope="session")
def model_input() -> pd.DataFrame:
int32_ = [np.int32(i) for i in range(5)]
Expand Down Expand Up @@ -319,6 +326,32 @@ def pyfunc_model_nan_dataframe(
yield pyfunc_load_model(model_path)


# Model that always raises exceptions
@pytest.fixture(scope="session")
def python_model_value_error() -> PythonModel:
return ExceptionRaiser()


@pytest.fixture(scope="session")
def pyfunc_model_value_error(
python_model_value_error,
model_input: pd.DataFrame,
model_output_series: pd.Series, # Use to infer correct
) -> PyFuncModel:
signature = infer_signature(
model_input=model_input,
model_output=model_output_series,
)
with TemporaryDirectory() as temp_dir:
model_path = os.path.join(temp_dir, "model_exceptions")
pyfunc_save_model(
model_path,
python_model=python_model_value_error,
signature=signature,
)
yield pyfunc_load_model(model_path)


@pytest.fixture(
params=[
"ndarray",
Expand Down
22 changes: 20 additions & 2 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ def test_build_app_returns_good_predictions(
)

app = build_app(pyfunc_model)

client = TestClient(app)
df_str = model_input.to_json(orient="records")
request_data = f'{{"data": {df_str}}}'
response = client.post("/predictions", data=request_data)

response = client.post("/predictions", content=request_data)

assert response.status_code == 200
results = response.json()["data"]
try:
Expand All @@ -69,3 +70,20 @@ def test_build_app_returns_good_predictions(
assert [
{"prediction": v} for v in np.nditer(model_output)
] == results # type: ignore


def test_built_application_handles_model_exceptions(
model_input: pd.DataFrame, pyfunc_model_value_error: PyFuncModel
):
app = build_app(pyfunc_model_value_error)
client = TestClient(app, raise_server_exceptions=False)
df_str = model_input.to_json(orient="records")
request_data = f'{{"data": {df_str}}}'

response = client.post("/predictions", content=request_data)

assert response.status_code == 500
assert {
"name": "ValueError",
"message": "I always raise an error!",
} == response.json()
2 changes: 1 addition & 1 deletion tests/test_conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_pyfunc_model_nan_ndarray_instance(self, pyfunc_model_nan_ndarray):

def test_pyfunc_model_nan_ndarray_predict(
self,
pyfunc_model_nan_ndarray,
pyfunc_model_nan_ndarray,
model_input: pd.DataFrame,
):
"""PyFunc model with ndarray return type should predict correct values."""
Expand Down
3 changes: 1 addition & 2 deletions tests/test_mlflow_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from typing import Optional

import pytest
from mlflow.pyfunc import PyFuncModel
from mlflow.types.schema import Schema, ColSpec, DataType
from mlflow.types.schema import Schema, ColSpec, DataType # type: ignore
from fastapi_mlflow._mlflow_types import build_model_fields


Expand Down
18 changes: 17 additions & 1 deletion tests/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
import pytest
from mlflow.pyfunc import PyFuncModel # type: ignore

from fastapi_mlflow.predictors import build_predictor, convert_predictions_to_python
from fastapi_mlflow.predictors import (
build_predictor,
convert_predictions_to_python,
PyFuncModelPredictError,
)


def test_build_predictor_returns_callable(
Expand Down Expand Up @@ -148,6 +152,18 @@ def test_predictor_handles_model_returning_nan(
assert item.b is None


def test_predictor_exception_from_model_raised_from(
model_input: pd.DataFrame, pyfunc_model_value_error: PyFuncModel
):
predictor = build_predictor(pyfunc_model_value_error)

request_type = signature(predictor).parameters["request"].annotation
request_obj = request_type(data=model_input.to_dict(orient="records"))

with pytest.raises(PyFuncModelPredictError):
predictor(request_obj)


def test_convert_predictions_to_python_ndarray():
predictions = np.array([1, 2, 3, np.nan])
response_data = convert_predictions_to_python(predictions)
Expand Down

0 comments on commit da036be

Please sign in to comment.