Skip to content

Commit

Permalink
Merge pull request #24 from yuriihavrylko/feature/l9t1-streamlit
Browse files Browse the repository at this point in the history
L9T1 Streamlit
  • Loading branch information
yuriihavrylko authored Feb 11, 2024
2 parents f0cada9 + 7b5ec58 commit 3f435b9
Show file tree
Hide file tree
Showing 15 changed files with 168 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ jobs:
run: |
cd app/
pytest tests/
env:
PYTHONPATH: '.'

push-image:
needs: tests
Expand Down
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ DH Images:
Works on push to master/feature*
![Alt text](assets/actions.png)


### Streamlit

Run:
```
streamlit run src/serving/streamlit.py
```

![Alt text](assets/streamlit.png)


### Fast API

Postman

![Alt text](assets/fastapi.png)

### DVC

Install DVC
Expand Down Expand Up @@ -100,3 +117,4 @@ Kubernetes
```
kubectl create -f deployment/minio.yml
```
Empty file added app/__init__.py
Empty file.
1 change: 1 addition & 0 deletions app/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ scikit-learn==1.3.2
accelerate==0.25.0
datasets==2.16.1
wandb==0.16.1
httpx==0.26.0
ipykernel==6.28.0
3 changes: 3 additions & 0 deletions app/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
flask==3.0.0
gunicorn==21.2.0
transformers==4.36.2
streamlit==1.29.0
fastapi==0.108.0
uvicorn==0.25.0
Empty file added app/src/__init__.py
Empty file.
Empty file added app/src/helpers/__init__.py
Empty file.
18 changes: 12 additions & 6 deletions app/src/helpers/wandb_registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from pathlib import Path
import wandb

def publish_model(model_path, project, name, model_type="model"):
run = wandb.init(project=project, job_type="model-publishing")
artifact = wandb.Artifact(name, type=model_type)
artifact.add_dir(model_path)
run.log_artifact(artifact)
run.finish()
def publish_model(model_path: str, project: str, name: str, model_type: str = "model"):
with wandb.init(project=project, job_type="model-publishing") as run:
artifact = wandb.Artifact(name, type=model_type)
artifact.add_dir(model_path)
run.log_artifact(artifact)
print(f"Published {name} to W&B")

def download_model(model_name: str, project: str, download_path: Path, model_type: str = "model"):
with wandb.init(project=project) as run:
artifact = run.use_artifact(model_name, type=model_type)
artifact_dir = artifact.download(root=download_path)
print(f"Downloaded {model_name} to {artifact_dir}")
Empty file added app/src/serving/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions app/src/serving/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List

from fastapi import FastAPI
from pydantic import BaseModel

from src.serving.model import BertPredictor


class Payload(BaseModel):
text: List[str]


app = FastAPI()
predictor = BertPredictor.from_model_registry()


@app.post("/predict")
def predict(payload: Payload):
prediction = predictor.predict(text=payload.text)
return prediction
40 changes: 40 additions & 0 deletions app/src/serving/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import List
import torch
from pathlib import Path
from filelock import FileLock
from transformers import (BertForSequenceClassification, BertTokenizer)
from torch.nn.functional import softmax
from src.helpers.wandb_registry import download_model


MODEL_ID = "yurii-havrylko/huggingface/bert_fake_news:v0"
MODEL_PATH = "/tmp/model"
MODEL_LOCK = ".lock-file"
PROJECT = "huggingface"

class BertPredictor:
def __init__(self, model_load_path: str):
self.tokenizer = BertTokenizer.from_pretrained(model_load_path)
self.model = BertForSequenceClassification.from_pretrained(model_load_path)
self.model.eval()
self.labels = ['LABEL_0', 'LABEL_1']

@torch.no_grad()
def predict(self, text: List[str]):
text_encoded = self.tokenizer.batch_encode_plus(list(text), return_tensors="pt", padding=True)
bert_outputs = self.model(**text_encoded).logits
probabilities = softmax(bert_outputs, dim=1)
results = []
for prob in probabilities:
result = [{"label": self.labels[i], "score": float(score)} for i, score in enumerate(prob)]
results.append(result)
return results

@classmethod
def from_model_registry(cls) -> "BertPredictor":
with FileLock(MODEL_LOCK):
if not (Path(MODEL_PATH) / "model.safetensors").exists():
download_model(model_name=MODEL_ID, download_path=MODEL_PATH, project=PROJECT)

return cls(model_load_path=MODEL_PATH)

45 changes: 45 additions & 0 deletions app/src/serving/streamlit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pandas as pd
import streamlit as st

from src.serving.model import BertPredictor

def init_state():
if 'history' not in st.session_state:
st.session_state['history'] = []

@st.cache_resource()
def get_model() -> BertPredictor:
return BertPredictor.from_model_registry()

predictor = get_model()

def prediction():
st.subheader("Fake news prediction")
input_sent = st.text_area("Type an English news here", value="This is example input", height=150)
if st.button("Run Inference"):
pred = predictor.predict([input_sent])
st.session_state['history'].append({'Input': input_sent, 'Prediction': pred})
st.write("**Prediction:**", pred[0])
st.write("**Prediction History**")
for item in reversed(st.session_state['history']):
st.text(f"Input: {item['Input']}\nPrediction: {item['Prediction'][0]}\n")

def batch_pred():
st.subheader("Batch Prediction from CSV")
uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
if uploaded_file:
dataframe = pd.read_csv(uploaded_file)
st.markdown("#### Input dataframe")
st.dataframe(dataframe)

dataframe_with_pred = predictor.run_inference_on_dataframe(dataframe)
st.markdown("#### Result dataframe")
st.dataframe(dataframe_with_pred)

def main():
st.title("BERT Model Prediction Service")
init_state()
prediction()

if __name__ == "__main__":
main()
27 changes: 27 additions & 0 deletions app/tests/serving/test_fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from fastapi.testclient import TestClient

from unittest.mock import Mock
from src.serving.model import BertPredictor

def mock_predict(text):
return ['positive' for _ in text]

BertPredictor.from_model_registry = Mock(return_value=Mock(predict=mock_predict))



def test_predict():
from src.serving.fastapi import app
client = TestClient(app)

payload = {
"text": ["This is a test sentence.", "Here's another one!"]
}

response = client.post("/predict", json=payload)

assert response.status_code == 200

data = response.json()
assert isinstance(data, list)
assert len(data) == len(payload['text'])
Binary file added assets/fastapi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/streamlit.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 3f435b9

Please sign in to comment.