Skip to content

Commit

Permalink
L9T2 Fastapi (#15)
Browse files Browse the repository at this point in the history
* add requirements

* implement serving

* tests for fast api

* update readme
  • Loading branch information
yuriihavrylko authored Feb 11, 2024
1 parent 2ffe427 commit 9e483c9
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 1 deletion.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,10 @@ streamlit run src/serving/streamlit.py
```

![Alt text](assets/streamlit.png)


### Fast API

Postman

![Alt text](assets/fastapi.png)
1 change: 1 addition & 0 deletions app/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ scikit-learn==1.3.2
accelerate==0.25.0
datasets==2.16.1
wandb==0.16.1
httpx==0.26.0
3 changes: 2 additions & 1 deletion app/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ flask==3.0.0
gunicorn==21.2.0
transformers==4.36.2
streamlit==1.29.0

fastapi==0.108.0
uvicorn==0.25.0
20 changes: 20 additions & 0 deletions app/src/serving/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel

from src.serving.model import BertPredictor


class Payload(BaseModel):
text: List[str]


app = FastAPI()
predictor = BertPredictor.from_model_registry()


@app.post("/predict")
def predict(payload: Payload):
prediction = predictor.predict(text=payload.text)
return prediction
27 changes: 27 additions & 0 deletions app/tests/serving/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from fastapi.testclient import TestClient

from unittest.mock import Mock
from src.serving.model import BertPredictor

def mock_predict(text):
return ['positive' for _ in text]

BertPredictor.from_model_registry = Mock(return_value=Mock(predict=mock_predict))



def test_predict():
from src.serving.fastapi import app
client = TestClient(app)

payload = {
"text": ["This is a test sentence.", "Here's another one!"]
}

response = client.post("/predict", json=payload)

assert response.status_code == 200

data = response.json()
assert isinstance(data, list)
assert len(data) == len(payload['text'])
Binary file added assets/fastapi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 9e483c9

Please sign in to comment.