-
Notifications
You must be signed in to change notification settings - Fork 1
/
app.py
33 lines (22 loc) · 845 Bytes
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import logging
import os

import mlflow
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
app = FastAPI()

# MLflow connection settings. The defaults are the original hard-coded values;
# each can be overridden per deployment through the environment. The variable
# names are the ones MLflow itself reads, so an explicit set_tracking_uri()
# call no longer silently shadows a user's MLFLOW_TRACKING_URI.
mlflow.set_tracking_uri(
    os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
)
mlflow.set_experiment(os.environ.get("MLFLOW_EXPERIMENT_NAME", "RAG_MLFlow"))
# Turn on MLflow's LangChain autologging for this process.
mlflow.langchain.autolog()

# URI of the logged LangChain model to serve. The default points at one
# specific run; set MODEL_URI to serve your own model without editing code.
model_uri = os.environ.get(
    "MODEL_URI",
    "runs:/9976886f819f4400816192bbf73fcb25/rag_chain_with_source",
)
# Loaded once at import time; every request reuses this same chain object.
loaded_model = mlflow.langchain.load_model(model_uri)
class Query(BaseModel):
    """Request body for the ``POST /predict`` endpoint."""

    # The user's question; the endpoint passes it to ``loaded_model.invoke``.
    question: str
@app.post("/predict")
async def predict(query: Query):
    """Answer ``query.question`` with the loaded MLflow LangChain chain.

    Args:
        query: Request body holding the user's question.

    Returns:
        ``{"question": <the question>, "answer": <the answer>}``. If the chain
        returns a dict without an ``"answer"`` key (or an unrecognised result
        type), the answer is ``"No answer found"``.

    Raises:
        HTTPException: 500 if invoking the chain fails. The underlying error is
            logged server-side with its traceback instead of being echoed to
            the client.
    """
    try:
        # ``invoke`` is a blocking call. Running it in FastAPI's threadpool
        # keeps this ``async`` endpoint from stalling the event loop (and every
        # other in-flight request) for the whole LLM round-trip.
        result = await run_in_threadpool(loaded_model.invoke, query.question)
    except Exception as exc:
        logging.getLogger(__name__).exception("RAG chain invocation failed")
        raise HTTPException(status_code=500, detail="Prediction failed") from exc

    # NOTE(review): the chain is expected to return a dict with an "answer"
    # key (as the original code assumed). A bare string result is used as the
    # answer directly instead of crashing on ``.get``.
    if isinstance(result, dict):
        answer = result.get("answer", "No answer found")
    elif isinstance(result, str):
        answer = result
    else:
        logging.getLogger(__name__).warning(
            "Unexpected chain result type: %s", type(result).__name__
        )
        answer = "No answer found"
    return {"question": query.question, "answer": answer}
if __name__ == "__main__":
    # Defaults match the original hard-coded values (all interfaces, port
    # 8005). APP_HOST / APP_PORT override them; a bare HOST is avoided
    # because some shells export it as the machine's hostname.
    uvicorn.run(
        app,
        host=os.environ.get("APP_HOST", "0.0.0.0"),
        port=int(os.environ.get("APP_PORT", "8005")),
    )