-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #24 from yuriihavrylko/feature/l9t1-streamlit
L9T1 Streamlit
- Loading branch information
Showing
15 changed files
with
168 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,8 @@ jobs: | |
run: | | ||
cd app/ | ||
pytest tests/ | ||
env: | ||
PYTHONPATH: '.' | ||
|
||
push-image: | ||
needs: tests | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,4 +7,5 @@ scikit-learn==1.3.2 | |
accelerate==0.25.0 | ||
datasets==2.16.1 | ||
wandb==0.16.1 | ||
httpx==0.26.0 | ||
ipykernel==6.28.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,6 @@ | ||
flask==3.0.0 | ||
gunicorn==21.2.0 | ||
transformers==4.36.2 | ||
streamlit==1.29.0 | ||
fastapi==0.108.0 | ||
uvicorn==0.25.0 |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,15 @@ | ||
from pathlib import Path | ||
import wandb | ||
|
||
def publish_model(model_path, project, name, model_type="model"): | ||
run = wandb.init(project=project, job_type="model-publishing") | ||
artifact = wandb.Artifact(name, type=model_type) | ||
artifact.add_dir(model_path) | ||
run.log_artifact(artifact) | ||
run.finish() | ||
def publish_model(model_path: str, project: str, name: str, model_type: str = "model"):
    """Log the directory ``model_path`` to W&B as an artifact called ``name``.

    Args:
        model_path: Directory whose contents are added to the artifact.
        project: W&B project the publishing run is created in.
        name: Name given to the artifact.
        model_type: Artifact type passed to ``wandb.Artifact``.
    """
    publishing_run = wandb.init(project=project, job_type="model-publishing")
    # The run is used as a context manager so it is closed even if logging fails.
    with publishing_run as run:
        model_artifact = wandb.Artifact(name, type=model_type)
        model_artifact.add_dir(model_path)
        run.log_artifact(model_artifact)
        print(f"Published {name} to W&B")
|
||
def download_model(model_name: str, project: str, download_path: "str | Path", model_type: str = "model") -> Path:
    """Download the W&B artifact ``model_name`` into ``download_path``.

    Args:
        model_name: Artifact identifier passed to ``run.use_artifact``.
        project: W&B project the consuming run is created in.
        download_path: Target directory; a ``str`` or a ``Path`` is accepted.
        model_type: Artifact type passed to ``run.use_artifact``.

    Returns:
        The directory the artifact was downloaded to.
    """
    with wandb.init(project=project) as run:
        artifact = run.use_artifact(model_name, type=model_type)
        # Callers pass either a plain string or a Path; hand wandb a string
        # so both spellings behave the same.
        artifact_dir = artifact.download(root=str(download_path))
        print(f"Downloaded {model_name} to {artifact_dir}")
    return Path(artifact_dir)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from typing import List | ||
|
||
from fastapi import FastAPI | ||
from pydantic import BaseModel | ||
|
||
from src.serving.model import BertPredictor | ||
|
||
|
||
class Payload(BaseModel):
    """Request body of ``POST /predict``: the batch of texts to classify."""

    text: List[str]


# Both objects are created at import time, so importing this module
# already resolves the predictor through BertPredictor.from_model_registry().
app = FastAPI()
predictor = BertPredictor.from_model_registry()


@app.post("/predict")
def predict(payload: Payload):
    """Run the shared predictor on ``payload.text`` and return its output as is."""
    return predictor.predict(text=payload.text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import List | ||
import torch | ||
from pathlib import Path | ||
from filelock import FileLock | ||
from transformers import (BertForSequenceClassification, BertTokenizer) | ||
from torch.nn.functional import softmax | ||
from src.helpers.wandb_registry import download_model | ||
|
||
|
||
MODEL_ID = "yurii-havrylko/huggingface/bert_fake_news:v0" | ||
MODEL_PATH = "/tmp/model" | ||
MODEL_LOCK = ".lock-file" | ||
PROJECT = "huggingface" | ||
|
||
class BertPredictor:
    """Wrap a BERT sequence-classification checkpoint for batch inference."""

    def __init__(self, model_load_path: str):
        """Load the tokenizer and the model weights from ``model_load_path``."""
        self.tokenizer = BertTokenizer.from_pretrained(model_load_path)
        self.model = BertForSequenceClassification.from_pretrained(model_load_path)
        # The predictor only serves inference, so keep the model in eval mode.
        self.model.eval()
        self.labels = ['LABEL_0', 'LABEL_1']

    @torch.no_grad()
    def predict(self, text: List[str]):
        """Score every input text against every label.

        Args:
            text: Batch of input texts.

        Returns:
            One list per input text, each holding a ``{"label", "score"}``
            dict per label, where ``score`` is the softmax probability.
            An empty batch yields an empty list.
        """
        texts = list(text)
        if not texts:
            # Nothing to encode; skip the tokenizer and the model entirely.
            return []

        # truncation=True keeps over-long inputs within the model's maximum
        # sequence length instead of passing them on untruncated.
        text_encoded = self.tokenizer.batch_encode_plus(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        logits = self.model(**text_encoded).logits
        probabilities = softmax(logits, dim=1)
        return [
            [{"label": self.labels[i], "score": float(score)} for i, score in enumerate(prob)]
            for prob in probabilities
        ]

    @classmethod
    def from_model_registry(cls) -> "BertPredictor":
        """Build a predictor from ``MODEL_PATH``, downloading the model if needed.

        The download is skipped when ``model.safetensors`` already exists in
        ``MODEL_PATH``; the check and the download run under ``MODEL_LOCK``.
        """
        with FileLock(MODEL_LOCK):
            if not (Path(MODEL_PATH) / "model.safetensors").exists():
                download_model(model_name=MODEL_ID, download_path=MODEL_PATH, project=PROJECT)

        return cls(model_load_path=MODEL_PATH)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import pandas as pd | ||
import streamlit as st | ||
|
||
from src.serving.model import BertPredictor | ||
|
||
def init_state():
    """Create an empty 'history' list in the session state if it is missing."""
    if 'history' in st.session_state:
        return
    st.session_state['history'] = []
|
||
@st.cache_resource()
def get_model() -> BertPredictor:
    """Return the predictor built by ``BertPredictor.from_model_registry``."""
    model = BertPredictor.from_model_registry()
    return model


# Module-level handle used by the page functions below.
predictor = get_model()
|
||
def prediction():
    """Render the single-text form, run inference on click, and list the history."""
    st.subheader("Fake news prediction")
    input_sent = st.text_area(
        "Type an English news here",
        value="This is example input",
        height=150,
    )

    run_clicked = st.button("Run Inference")
    if run_clicked:
        pred = predictor.predict([input_sent])
        history_entry = {'Input': input_sent, 'Prediction': pred}
        st.session_state['history'].append(history_entry)
        st.write("**Prediction:**", pred[0])

    # Newest entries are shown first.
    st.write("**Prediction History**")
    for item in reversed(st.session_state['history']):
        st.text(f"Input: {item['Input']}\nPrediction: {item['Prediction'][0]}\n")
|
||
def batch_pred():
    """Render the CSV upload form and show the predictions for every row.

    The uploaded file is read with pandas and displayed, then each row's
    text is classified with ``predictor.predict``. The best-scoring label
    and its score are appended as ``label`` and ``score`` columns.
    """
    st.subheader("Batch Prediction from CSV")
    uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
    if not uploaded_file:
        return

    dataframe = pd.read_csv(uploaded_file)
    st.markdown("#### Input dataframe")
    st.dataframe(dataframe)

    if dataframe.empty:
        # Nothing to classify; avoid sending an empty batch to the model.
        st.warning("The uploaded CSV contains no rows to classify.")
        return

    # NOTE(review): the expected CSV schema is not defined anywhere visible;
    # this assumes a "text" column and otherwise falls back to the first
    # column. Confirm against the intended input format.
    text_column = "text" if "text" in dataframe.columns else dataframe.columns[0]
    texts = dataframe[text_column].fillna("").astype(str).tolist()

    # The predictor exposes only predict(), which returns one list of
    # {"label", "score"} dicts per input text.
    predictions = predictor.predict(texts)
    best = [max(row, key=lambda entry: entry["score"]) for row in predictions]

    dataframe_with_pred = dataframe.copy()
    dataframe_with_pred["label"] = [entry["label"] for entry in best]
    dataframe_with_pred["score"] = [entry["score"] for entry in best]

    st.markdown("#### Result dataframe")
    st.dataframe(dataframe_with_pred)
|
||
def main():
    """Draw the page: title first, then session setup, then the prediction form."""
    st.title("BERT Model Prediction Service")
    # The history list must exist before prediction() reads it.
    init_state()
    prediction()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from fastapi.testclient import TestClient | ||
|
||
from unittest.mock import Mock | ||
from src.serving.model import BertPredictor | ||
|
||
def mock_predict(text):
    """Stand-in for the predictor's ``predict``: one 'positive' per input item."""
    labels = []
    for _ in text:
        labels.append('positive')
    return labels
|
||
def test_predict():
    """POST a two-item batch to /predict and check one result comes back per item."""
    # Imported here so the file's top-level import block stays untouched.
    from unittest.mock import patch

    fake_predictor = Mock(predict=mock_predict)

    # Patch the registry loader only while the app module is imported.
    # Assigning the mock onto the class at module level would replace
    # from_model_registry for every other test in the same session.
    with patch.object(BertPredictor, "from_model_registry", return_value=fake_predictor):
        from src.serving.fastapi import app

    client = TestClient(app)

    payload = {
        "text": ["This is a test sentence.", "Here's another one!"]
    }

    response = client.post("/predict", json=payload)

    assert response.status_code == 200

    data = response.json()
    assert isinstance(data, list)
    assert len(data) == len(payload['text'])
    # The fake predictor labels every input 'positive'.
    assert data == ['positive'] * len(payload['text'])
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.