diff --git a/app/requirements-dev.txt b/app/requirements-dev.txt index 5ccc0aa..bb0edeb 100644 --- a/app/requirements-dev.txt +++ b/app/requirements-dev.txt @@ -7,3 +7,4 @@ scikit-learn==1.3.2 accelerate==0.25.0 datasets==2.16.1 wandb==0.16.1 +httpx==0.26.0 diff --git a/app/tests/serving/test_fastapi.py b/app/tests/serving/test_fastapi.py new file mode 100644 index 0000000..cecd4d8 --- /dev/null +++ b/app/tests/serving/test_fastapi.py @@ -0,0 +1,27 @@ +from fastapi.testclient import TestClient + +from unittest.mock import Mock +from src.serving.model import BertPredictor + +def mock_predict(text): + return ['positive' for _ in text] + +BertPredictor.from_model_registry = Mock(return_value=Mock(predict=mock_predict)) + + + +def test_predict(): + from src.serving.fastapi import app + client = TestClient(app) + + payload = { + "text": ["This is a test sentence.", "Here's another one!"] + } + + response = client.post("/predict", json=payload) + + assert response.status_code == 200 + + data = response.json() + assert isinstance(data, list) + assert len(data) == len(payload['text'])