diff --git a/rag-langchain/README.md b/rag-langchain/README.md
index a82b2c8f..ec3acecb 100644
--- a/rag-langchain/README.md
+++ b/rag-langchain/README.md
@@ -30,5 +30,5 @@ snapshot_download(repo_id="BAAI/bge-base-en-v1.5",
 Follow the instructions below to build you container image and run it locally. 
 
 * `podman build -t ragapp rag-langchain -f rag-langchain/builds/Containerfile`
-* `podman run --rm -it -p 8501:8501 -v Local/path/to/locallm/models/:/rag/models:Z -v Local/path/to/locallm/data:/rag/data:Z ragapp -- -H 10.88.0.1 -m http://10.88.0.1:8001/v1`
+* `podman run --rm -it -p 8501:8501 -v Local/path/to/locallm/models/:/rag/models:Z -v Local/path/to/locallm/data:/rag/data:Z -e MODEL_SERVICE_ENDPOINT=http://10.88.0.1:8001/v1 ragapp -- -H 10.88.0.1`
 
diff --git a/rag-langchain/ai-studio.yaml b/rag-langchain/ai-studio.yaml
index e9e1db30..de273a5a 100644
--- a/rag-langchain/ai-studio.yaml
+++ b/rag-langchain/ai-studio.yaml
@@ -15,7 +15,7 @@ application:
       ports:
         - 8001
     - name: chromadb-server
-      contextdir:: builds/chromadb
+      contextdir: builds/chromadb
       containerfile: Containerfile
       vectordb: true
       arch:
diff --git a/rag-langchain/rag_app.py b/rag-langchain/rag_app.py
index 71399391..128bbccc 100644
--- a/rag-langchain/rag_app.py
+++ b/rag-langchain/rag_app.py
@@ -18,15 +18,21 @@ import argparse
 import pathlib
 
+model_service = os.getenv("MODEL_SERVICE_ENDPOINT",
+                          "http://0.0.0.0:8001/v1")
+
 parser = argparse.ArgumentParser()
 parser.add_argument("-c", "--chunk_size", default=150)
 parser.add_argument("-e", "--embedding_model", default="BAAI/bge-base-en-v1.5")
 parser.add_argument("-H", "--vdb_host", default="0.0.0.0")
 parser.add_argument("-p", "--vdb_port", default="8000")
 parser.add_argument("-n", "--name", default="test_collection")
-parser.add_argument("-m", "--model_url", default="http://0.0.0.0:8001/v1")
 args = parser.parse_args()
 
+vectorDB_client = HttpClient(host=args.vdb_host,
+                             port=args.vdb_port,
+                             settings=Settings(allow_reset=True,))
+
 def clear_vdb():
-    global client
-    client.delete_collection(args.name)
+    global vectorDB_client
+    vectorDB_client.delete_collection(args.name)
 
@@ -56,10 +62,7 @@ def get_files():
 embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=args.embedding_model)
 e = SentenceTransformerEmbeddings(model_name=args.embedding_model)
 
-client = HttpClient(host=args.vdb_host,
-                    port=args.vdb_port,
-                    settings=Settings(allow_reset=True,))
-collection = client.get_or_create_collection(args.name,
+collection = vectorDB_client.get_or_create_collection(args.name,
                                              embedding_function=embedding_func)
 if collection.count() < 1 and data != None:
     print("populating db")
@@ -87,13 +90,13 @@ def get_files():
 for msg in st.session_state.messages:
     st.chat_message(msg["role"]).write(msg["content"])
 
-db = Chroma(client=client,
+db = Chroma(client=vectorDB_client,
             collection_name=args.name,
             embedding_function=e
             )
 retriever = db.as_retriever(threshold=0.75)
 
-llm = ChatOpenAI(base_url=args.model_url,
+llm = ChatOpenAI(base_url=model_service,
                  api_key="EMPTY",
                  streaming=True,
                  callbacks=[StreamlitCallbackHandler(st.container(),