From ad697d9457c6044883c7f6f734f2e3e273e55d58 Mon Sep 17 00:00:00 2001 From: Yurii Havrylko Date: Tue, 9 Jan 2024 16:39:28 +0100 Subject: [PATCH] tests for fast api --- app/requirements-dev.txt | 1 + app/tests/serving/test_fastapi.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 app/tests/serving/test_fastapi.py diff --git a/app/requirements-dev.txt b/app/requirements-dev.txt index 5ccc0aa..bb0edeb 100644 --- a/app/requirements-dev.txt +++ b/app/requirements-dev.txt @@ -7,3 +7,4 @@ scikit-learn==1.3.2 accelerate==0.25.0 datasets==2.16.1 wandb==0.16.1 +httpx==0.26.0 diff --git a/app/tests/serving/test_fastapi.py b/app/tests/serving/test_fastapi.py new file mode 100644 index 0000000..cecd4d8 --- /dev/null +++ b/app/tests/serving/test_fastapi.py @@ -0,0 +1,27 @@ +from fastapi.testclient import TestClient + +from unittest.mock import Mock +from src.serving.model import BertPredictor + +def mock_predict(text): + return ['positive' for _ in text] + +BertPredictor.from_model_registry = Mock(return_value=Mock(predict=mock_predict)) + + + +def test_predict(): + from src.serving.fastapi import app + client = TestClient(app) + + payload = { + "text": ["This is a test sentence.", "Here's another one!"] + } + + response = client.post("/predict", json=payload) + + assert response.status_code == 200 + + data = response.json() + assert isinstance(data, list) + assert len(data) == len(payload['text'])