Add openai compatible API (#49)
* Add openai compatible API

* Revert remove ygpt

* Fix input and output

* Minor fix for formatting
jamakase authored Sep 12, 2024
1 parent 182026d commit adda5b0
Showing 8 changed files with 940 additions and 518 deletions.
4 changes: 2 additions & 2 deletions docker-compose.ollama.yaml
@@ -100,7 +100,7 @@ services:
     depends_on:
       - db
       - redis
-      - backend_app
+      - ml
 
   ml:
     restart: unless-stopped
@@ -109,8 +109,8 @@ services:
       context: ./services/ml
       dockerfile: Dockerfile
     environment:
-      FAISS_INDEX_PATH: /data/faiss_index
      OLLAMA_HOST: http://ollama:11434
+      QDRANT_HOST: http://qdrant:6333
      LLM_SOURCE: ollama
      MODEL: llama3.1
    volumes:
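Since the ml service now reads QDRANT_HOST instead of a FAISS index path, a quick connectivity check looks like this (a minimal sketch; "qdrant" is the compose-network hostname from the diff above, use http://localhost:6333 from the host machine):

from qdrant_client import QdrantClient

# "qdrant" only resolves inside the compose network; from the host use localhost.
client = QdrantClient(url="http://qdrant:6333")
print(client.get_collections())  # answers without raising if the service is reachable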
6 changes: 2 additions & 4 deletions docker-compose.yaml
@@ -109,11 +109,9 @@ services:
       context: ./services/ml
       dockerfile: Dockerfile
     environment:
-      LLM_SOURCE: ollama
-      OLLAMA_HOST: http://localhost:11434
-      MODEL_NAME: llama3.1
+      LLM_SOURCE: openai
+      QDRANT_HOST: http://qdrant:6333
-      # EMBEDDING_MODEL: sentence-transformers/all-mpnet-base-v2
      OPENAI_API_KEY: ${OPENAI_API_KEY}
    volumes:
      - ./services/ml:/app
      - ./services/ml/data:/data
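Docker Compose interpolates ${OPENAI_API_KEY} from the host shell, so the key itself never lands in the file. A hypothetical preflight helper in the same spirit (check_env is illustrative, not part of this repo):

import os
import sys

def check_env(*names: str) -> None:
    """Fail fast with a readable message when required variables are unset."""
    missing = [name for name in names if not os.environ.get(name)]
    if missing:
        sys.exit(f"Missing required environment variables: {', '.join(missing)}")

check_env("OPENAI_API_KEY")  # run before docker compose up, or at app startup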
25 changes: 18 additions & 7 deletions services/ml/app/config/config.py
@@ -1,12 +1,23 @@
 import os
 
 class Config:
-    FAISS_INDEX_PATH: str = os.environ.get("FAISS_INDEX_PATH", os.path.join(os.path.dirname(__file__), "../../", "data", "faiss_index"))
-    PORT: int = int(os.environ.get("PORT", 8000))
     QDRANT_HOST: str = os.environ.get("QDRANT_HOST", "http://localhost:6333")
     QDRANT_COLLECTION_NAME: str = os.environ.get("QDRANT_COLLECTION_NAME", "test")
-    MODEL: str = os.environ.get("MODEL", "meta-llama/llama-3.1-8b-instruct:free")
+    PORT: int = int(os.environ.get("PORT", 8000))
-    LLM_SOURCE: str = os.environ.get("LLM_SOURCE", "openrouter")
-    OPENROUTER_API_KEY: str = os.environ.get("OPENROUTER_API_KEY", "placeholder")
-    EMBEDDING_MODEL: str = os.environ.get("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
-    OLLAMA_HOST: str = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
+    LLM_SOURCE: str = os.environ.get("LLM_SOURCE", "openai")
+
+    def __init__(self):
+        if self.LLM_SOURCE == "openai":
+            self.MODEL = os.environ.get("MODEL", "meta-llama/llama-3.1-8b-instruct:free")
+            self.OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://openrouter.ai/api/v1")
+            self.OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
+            self.EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "sentence-transformers/all-mpnet-base-v2")
+        elif self.LLM_SOURCE == "ollama":
+            self.OLLAMA_HOST: str = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
+            self.MODEL = os.environ.get("MODEL", "llama3.1")
+            self.EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "llama3.1")
+        elif self.LLM_SOURCE == "ygpt":
+            pass
+        else:
+            raise ValueError(f"Unsupported LLM_SOURCE: {self.LLM_SOURCE}")
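One subtlety of the rewritten Config: LLM_SOURCE is read in the class body, so it must be in the environment before app.config is first imported; the per-source attributes are then filled in by __init__. A usage sketch:

import os

os.environ["LLM_SOURCE"] = "ollama"  # must be set before app.config is imported

from app.config import Config

config = Config()          # __init__ takes the ollama branch
print(config.MODEL)        # "llama3.1" unless MODEL is overridden
print(config.OLLAMA_HOST)  # "http://localhost:11434" by default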
39 changes: 27 additions & 12 deletions services/ml/app/server.py
@@ -1,23 +1,37 @@
 from fastapi import FastAPI
-from langchain.chains import RetrievalQA
+from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
 from langserve import add_routes
 
-from .config import config
-
 from packages.llm import LLMInstance
-from packages.retriever import Retriever
 from packages.prompts import prompt
+from packages.retriever import Retriever
+from langchain_core.pydantic_v1 import BaseModel, Field, validator
+from langchain_core.output_parsers import StrOutputParser
+
+from .config import config
 
 llm_instance = LLMInstance(config)
 retriever_instance = Retriever(llm_instance.get_embeddings(), host=config.QDRANT_HOST, collection_name=config.QDRANT_COLLECTION_NAME)
 
-# Create a RetrievalQA chain
-qa_chain = RetrievalQA.from_chain_type(
-    llm=llm_instance.get_llm(),
-    chain_type="stuff",
-    retriever=retriever_instance.get_retriever(),
-    chain_type_kwargs={"prompt": prompt},
-)
+class InputChat(BaseModel):
+    """Input for the chat endpoint."""
+    question: str = Field(..., description="The query to retrieve relevant documents.")
+
+class OutputChat(BaseModel):
+    """Output for the chat endpoint."""
+    result: str = Field(..., description="The output containing the result.")
+
+# Modify the qa_chain definition
+qa_chain = (
+    RunnableParallel(
+        {"context": retriever_instance.get_retriever(), "question": RunnablePassthrough()}
+    )
+    | prompt
+    | llm_instance.get_llm()
+    | StrOutputParser()
+    | (lambda x: {"result": x})
+).with_types(input_type=InputChat, output_type=OutputChat)
 
 
 app = FastAPI(
     title="LangChain Server",
@@ -26,6 +40,7 @@
 )
 
 # Add routes for the QA chain instead of just the retriever
+add_routes(app, path="/llm", runnable=llm_instance.get_llm())
 add_routes(app, qa_chain)
 
 if __name__ == "__main__":
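With the server up, both runnables are reachable through LangServe's standard /invoke endpoints (a sketch assuming the default port 8000 from Config; the question text is just an example):

import requests

# Raw LLM, mounted under /llm by the first add_routes call.
llm = requests.post(
    "http://localhost:8000/llm/invoke",
    json={"input": "Say hello"},
)
print(llm.json()["output"])

# QA chain, mounted at the application root; the payload mirrors InputChat.
qa = requests.post(
    "http://localhost:8000/invoke",
    json={"input": {"question": "What does the ml service do?"}},
)
print(qa.json()["output"])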
17 changes: 6 additions & 11 deletions services/ml/packages/llm/llm_instance.py
@@ -1,5 +1,3 @@
-import os
-
 from app.config import Config
 
 
@@ -9,20 +7,17 @@ def __init__(self, config: Config):
         self.llm = None
         self.embeddings = None
 
-        if config.LLM_SOURCE == "openrouter":
-            from langchain_community.chat_models import ChatOpenAI
-            from langchain_huggingface.embeddings import HuggingFaceEmbeddings
+        if config.LLM_SOURCE == "openai":
+            from langchain_openai import ChatOpenAI
+            from langchain_community.embeddings import FastEmbedEmbeddings
 
             self.llm = ChatOpenAI(
                 model=config.MODEL,
-                openai_api_key=config.OPENROUTER_API_KEY,
-                base_url="https://openrouter.ai/api/v1",
+                openai_api_key=config.OPENAI_API_KEY,
+                openai_api_base=config.OPENAI_BASE_URL,
                 max_tokens=1000,
             )
-            cache_folder = os.path.join(os.getcwd(), "model_cache")
-            self.embeddings = HuggingFaceEmbeddings(
-                model_name=config.EMBEDDING_MODEL, cache_folder=cache_folder
-            )
+            self.embeddings = FastEmbedEmbeddings()
         elif config.LLM_SOURCE == "ygpt":
            from langchain_community.chat_models import ChatYandexGPT
            from langchain_community.embeddings import YandexGPTEmbeddings
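Because the client now only assumes an OpenAI-compatible endpoint, switching providers is two settings rather than a code change (a sketch; the key is a placeholder, and the URL mirrors the Config default for OpenRouter):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model="meta-llama/llama-3.1-8b-instruct:free",
    openai_api_key="sk-...",                         # placeholder
    openai_api_base="https://openrouter.ai/api/v1",  # any OpenAI-compatible server
    max_tokens=1000,
)
print(llm.invoke("ping").content)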
2 changes: 1 addition & 1 deletion services/ml/packages/retriever/retriever.py
@@ -12,7 +12,7 @@ def __init__(self, embeddings: Embeddings, host: str, collection_name: str):
         client.create_collection(
             collection_name=collection_name,
             vectors_config={
-                "size": 768,
+                "size": 384,
                 "distance": "Cosine",
             }
         )
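The size change tracks the embedding swap: all-mpnet-base-v2 produces 768-dimensional vectors, while FastEmbed's default model produces 384. A quick way to confirm the dimension before recreating the collection (a sketch):

from langchain_community.embeddings import FastEmbedEmbeddings

dim = len(FastEmbedEmbeddings().embed_query("dimension probe"))
print(dim)  # expected 384 for FastEmbed's default model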