Skip to content

Commit

Permalink
Update instances
Browse files Browse the repository at this point in the history
  • Loading branch information
jamakase committed Sep 13, 2024
1 parent 017d5a0 commit 9ec8ac6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
2 changes: 1 addition & 1 deletion services/ml/app/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
class Config:
PORT: int = int(os.environ.get("PORT", 8000))
QDRANT_HOST: str = os.environ.get("QDRANT_HOST", "http://localhost:6333")
QDRANT_COLLECTION_NAME: str = os.environ.get("QDRANT_COLLECTION_NAME", "test")
QDRANT_COLLECTION_NAME: str = os.environ.get("QDRANT_COLLECTION_NAME", "docs_768")
QDRANT_API_KEY: str = os.environ.get("QDRANT_API_KEY")
LLM_SOURCE: str = os.environ.get("LLM_SOURCE", "openai")

Expand Down
17 changes: 15 additions & 2 deletions services/ml/packages/llm/llm_instance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
from app.config import Config

from langchain.embeddings.base import Embeddings
from sentence_transformers import SentenceTransformer

# 1. Обёртка для модели SentenceTransformer
class SentenceTransformerEmbeddings(Embeddings):
def __init__(self, model_name: str = 'all-mpnet-base-v2'):
self.model = SentenceTransformer(model_name)

def embed_documents(self, texts):
return self.model.encode(texts)

def embed_query(self, text):
return self.model.encode([text])[0]

class LLMInstance:
def __init__(self, config: Config):
Expand All @@ -9,15 +22,15 @@ def __init__(self, config: Config):

if config.LLM_SOURCE == "openai":
from langchain_openai import ChatOpenAI
from langchain_community.embeddings import FastEmbedEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings

self.llm = ChatOpenAI(
model=config.MODEL,
openai_api_key=config.OPENAI_API_KEY,
openai_api_base=config.OPENAI_BASE_URL,
max_tokens=1000,
)
self.embeddings = FastEmbedEmbeddings()
self.embeddings = SentenceTransformerEmbeddings()
elif config.LLM_SOURCE == "ygpt":
from langchain_community.chat_models import ChatYandexGPT
from langchain_community.embeddings import YandexGPTEmbeddings
Expand Down
2 changes: 1 addition & 1 deletion services/ml/packages/retriever/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(
client.create_collection(
collection_name=collection_name,
vectors_config={
"size": 384,
"size": 768,
"distance": "Cosine",
},
)
Expand Down

0 comments on commit 9ec8ac6

Please sign in to comment.