diff --git a/.github/workflows/rag.yaml b/.github/workflows/rag.yaml index 6af3c500..531101e1 100644 --- a/.github/workflows/rag.yaml +++ b/.github/workflows/rag.yaml @@ -5,16 +5,16 @@ on: branches: - main paths: - - ./recipes/common/Makefile.common - - ./recipes/natural_language_processing/rag/** - - .github/workflows/rag.yaml + - 'recipes/common/Makefile.common' + - 'recipes/natural_language_processing/rag/**' + - '.github/workflows/rag.yaml' push: branches: - main paths: - - ./recipes/common/Makefile.common - - ./recipes/natural_language_processing/rag/** - - .github/workflows/rag.yaml + - 'recipes/common/Makefile.common' + - '/recipes/natural_language_processing/rag/**' + - '.github/workflows/rag.yaml' workflow_dispatch: diff --git a/.gitignore b/.gitignore index 347e2d36..ca818ba0 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ recipes/common/bin/* */.venv/ training/cloud/examples training/instructlab/instructlab +vector_dbs/milvus/volumes/milvus/* diff --git a/recipes/natural_language_processing/rag/app/Containerfile b/recipes/natural_language_processing/rag/app/Containerfile index cd51ff3b..faa8ab00 100644 --- a/recipes/natural_language_processing/rag/app/Containerfile +++ b/recipes/natural_language_processing/rag/app/Containerfile @@ -16,6 +16,7 @@ COPY requirements.txt . RUN pip install --upgrade pip RUN pip install --no-cache-dir --upgrade -r /rag/requirements.txt COPY rag_app.py . +COPY manage_vectordb.py . EXPOSE 8501 ENV HF_HUB_CACHE=/rag/models/ ENTRYPOINT [ "streamlit", "run" ,"rag_app.py" ] diff --git a/recipes/natural_language_processing/rag/app/manage_vectordb.py b/recipes/natural_language_processing/rag/app/manage_vectordb.py new file mode 100644 index 00000000..82566abd --- /dev/null +++ b/recipes/natural_language_processing/rag/app/manage_vectordb.py @@ -0,0 +1,81 @@ +from langchain_community.vectorstores import Chroma +from chromadb import HttpClient +from chromadb.config import Settings +import chromadb.utils.embedding_functions as embedding_functions +from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings +from langchain_community.vectorstores import Milvus +from pymilvus import MilvusClient +from pymilvus import connections, utility + +class VectorDB: + def __init__(self, vector_vendor, host, port, collection_name, embedding_model): + self.vector_vendor = vector_vendor + self.host = host + self.port = port + self.collection_name = collection_name + self.embedding_model = embedding_model + + def connect(self): + # Connection logic + print(f"Connecting to {self.host}:{self.port}...") + if self.vector_vendor == "chromadb": + self.client = HttpClient(host=self.host, + port=self.port, + settings=Settings(allow_reset=True,)) + elif self.vector_vendor == "milvus": + self.client = MilvusClient(uri=f"http://{self.host}:{self.port}") + return self.client + + def populate_db(self, documents): + # Logic to populate the VectorDB with vectors + e = SentenceTransformerEmbeddings(model_name=self.embedding_model) + print(f"Populating VectorDB with vectors...") + if self.vector_vendor == "chromadb": + embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=self.embedding_model) + collection = self.client.get_or_create_collection(self.collection_name, + embedding_function=embedding_func) + if collection.count() < 1: + db = Chroma.from_documents( + documents=documents, + embedding=e, + collection_name=self.collection_name, + client=self.client + ) + print("DB populated") + else: + db = Chroma(client=self.client, + collection_name=self.collection_name, + embedding_function=e, + ) + print("DB already populated") + + elif self.vector_vendor == "milvus": + connections.connect(host=self.host, port=self.port) + if not utility.has_collection(self.collection_name): + print("Populating VectorDB with vectors...") + db = Milvus.from_documents( + documents, + e, + collection_name=self.collection_name, + connection_args={"host": self.host, "port": self.port}, + ) + print("DB populated") + else: + print("DB already populated") + db = Milvus( + e, + collection_name=self.collection_name, + connection_args={"host": self.host, "port": self.port}, + ) + return db + + def clear_db(self): + print(f"Clearing VectorDB...") + try: + if self.vector_vendor == "chromadb": + self.client.delete_collection(self.collection_name) + elif self.vector_vendor == "milvus": + self.client.drop_collection(self.collection_name) + print("Cleared DB") + except: + print("Couldn't clear the collection possibly because it doesn't exist") diff --git a/recipes/natural_language_processing/rag/app/populate_vectordb.py b/recipes/natural_language_processing/rag/app/populate_vectordb.py deleted file mode 100644 index 2bbb6efc..00000000 --- a/recipes/natural_language_processing/rag/app/populate_vectordb.py +++ /dev/null @@ -1,36 +0,0 @@ -from langchain_community.document_loaders import TextLoader -from langchain.text_splitter import CharacterTextSplitter -import chromadb.utils.embedding_functions as embedding_functions -import chromadb -from chromadb.config import Settings -import uuid -import os -import argparse -import time - -parser = argparse.ArgumentParser() -parser.add_argument("-d", "--docs", default="data/fake_meeting.txt") -parser.add_argument("-c", "--chunk_size", default=150) -parser.add_argument("-e", "--embedding_model", default="BAAI/bge-base-en-v1.5") -parser.add_argument("-H", "--vdb_host", default="0.0.0.0") -parser.add_argument("-p", "--vdb_port", default="8000") -parser.add_argument("-n", "--name", default="test_collection") -args = parser.parse_args() - -raw_documents = TextLoader(args.docs).load() -text_splitter = CharacterTextSplitter(separator = ".", chunk_size=int(args.chunk_size), chunk_overlap=0) -docs = text_splitter.split_documents(raw_documents) -os.environ["TORCH_HOME"] = "./models/" - -embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=args.embedding_model) -client = chromadb.HttpClient(host=args.vdb_host, - port=args.vdb_port, - settings=Settings(allow_reset=True,)) -collection = client.get_or_create_collection(args.name, - embedding_function=embedding_func) -for doc in docs: - collection.add( - ids=[str(uuid.uuid1())], - metadatas=doc.metadata, - documents=doc.page_content - ) \ No newline at end of file diff --git a/recipes/natural_language_processing/rag/app/rag_app.py b/recipes/natural_language_processing/rag/app/rag_app.py index 3fa664c2..bc87455a 100644 --- a/recipes/natural_language_processing/rag/app/rag_app.py +++ b/recipes/natural_language_processing/rag/app/rag_app.py @@ -1,91 +1,68 @@ from langchain_openai import ChatOpenAI from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnablePassthrough -from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings from langchain.text_splitter import CharacterTextSplitter from langchain_community.callbacks import StreamlitCallbackHandler -from langchain_community.vectorstores import Chroma +from langchain_community.document_loaders import TextLoader from langchain_community.document_loaders import PyPDFLoader -from langchain.schema.document import Document -from chromadb import HttpClient -from chromadb.config import Settings -import chromadb.utils.embedding_functions as embedding_functions -import streamlit as st +from manage_vectordb import VectorDB import tempfile -import uuid +import streamlit as st import os model_service = os.getenv("MODEL_ENDPOINT","http://0.0.0.0:8001") model_service = f"{model_service}/v1" chunk_size = os.getenv("CHUNK_SIZE", 150) embedding_model = os.getenv("EMBEDDING_MODEL","BAAI/bge-base-en-v1.5") +vdb_vendor = os.getenv("VECTORDB_VENDOR", "chromadb") vdb_host = os.getenv("VECTORDB_HOST", "0.0.0.0") vdb_port = os.getenv("VECTORDB_PORT", "8000") vdb_name = os.getenv("VECTORDB_NAME", "test_collection") +vdb = VectorDB(vdb_vendor, vdb_host, vdb_port, vdb_name, embedding_model) +vectorDB_client = vdb.connect() +def split_docs(raw_documents): + text_splitter = CharacterTextSplitter(separator = ".", + chunk_size=int(chunk_size), + chunk_overlap=0) + docs = text_splitter.split_documents(raw_documents) + return docs -vectorDB_client = HttpClient(host=vdb_host, - port=vdb_port, - settings=Settings(allow_reset=True,)) - -def clear_vdb(): - global vectorDB_client - try: - vectorDB_client.delete_collection(vdb_name) - print("Cleared DB") - except: - pass def read_file(file): file_type = file.type - if file_type == "application/pdf": temp = tempfile.NamedTemporaryFile() with open(temp.name, "wb") as f: f.write(file.getvalue()) loader = PyPDFLoader(temp.name) - pages = loader.load() - text = "".join([p.page_content for p in pages]) if file_type == "text/plain": - text = file.read().decode() - - return text + temp = tempfile.NamedTemporaryFile() + with open(temp.name, "wb") as f: + f.write(file.getvalue()) + loader = TextLoader(temp.name) + raw_documents = loader.load() + return raw_documents st.title("📚 RAG DEMO") with st.sidebar: file = st.file_uploader(label="📄 Upload Document", - type=[".txt",".pdf"], - on_change=clear_vdb - ) + type=[".txt",".pdf"], + on_change=vdb.clear_db + ) ### populate the DB #### -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=embedding_model) -e = SentenceTransformerEmbeddings(model_name=embedding_model) - -collection = vectorDB_client.get_or_create_collection(vdb_name, - embedding_function=embedding_func) -if collection.count() < 1 and file != None: - print("populating db") +if file != None: text = read_file(file) - raw_documents = [Document(page_content=text, - metadata={"":""})] - text_splitter = CharacterTextSplitter(separator = ".", - chunk_size=int(chunk_size), - chunk_overlap=0) - docs = text_splitter.split_documents(raw_documents) - for doc in docs: - collection.add( - ids=[str(uuid.uuid1())], - metadatas=doc.metadata, - documents=doc.page_content - ) -if file == None: - print("Empty VectorDB") + documents = split_docs(text) + db = vdb.populate_db(documents) + retriever = db.as_retriever(threshold=0.75) else: - print("DB already populated") + retriever = {} + print("Empty VectorDB") + + ######################## if "messages" not in st.session_state: @@ -95,11 +72,6 @@ def read_file(file): for msg in st.session_state.messages: st.chat_message(msg["role"]).write(msg["content"]) -db = Chroma(client=vectorDB_client, - collection_name=vdb_name, - embedding_function=e - ) -retriever = db.as_retriever(threshold=0.75) llm = ChatOpenAI(base_url=model_service, api_key="EMPTY", @@ -126,4 +98,4 @@ def read_file(file): response = chain.invoke(prompt) st.chat_message("assistant").markdown(response.content) st.session_state.messages.append({"role": "assistant", "content": response.content}) - st.rerun() + st.rerun() \ No newline at end of file diff --git a/vector_dbs/Makefile b/vector_dbs/chromadb/Makefile similarity index 58% rename from vector_dbs/Makefile rename to vector_dbs/chromadb/Makefile index e3725928..7f7e125b 100644 --- a/vector_dbs/Makefile +++ b/vector_dbs/chromadb/Makefile @@ -3,4 +3,4 @@ APPIMAGE ?= quay.io/ai-lab/${APP}:latest .PHONY: build build: - podman build -f chromadb/Containerfile -t ${APPIMAGE} . + podman build -f Containerfile -t ${APPIMAGE} . diff --git a/vector_dbs/milvus/Containerfile b/vector_dbs/milvus/Containerfile new file mode 100644 index 00000000..779a32bc --- /dev/null +++ b/vector_dbs/milvus/Containerfile @@ -0,0 +1,2 @@ +FROM docker.io/milvusdb/milvus:master-20240426-bed6363f +ADD embedEtcd.yaml /milvus/configs/embedEtcd.yaml diff --git a/vector_dbs/milvus/Makefile b/vector_dbs/milvus/Makefile new file mode 100644 index 00000000..9c7a4c96 --- /dev/null +++ b/vector_dbs/milvus/Makefile @@ -0,0 +1,55 @@ +REGISTRY ?= quay.io +REGISTRY_ORG ?= ai-lab +COMPONENT = vector_dbs + +IMAGE ?= $(REGISTRY)/$(REGISTRY_ORG)/$(COMPONENT)/milvus:latest + +ARCH ?= $(shell uname -m) +PLATFORM ?= linux/$(ARCH) + +gRCP_PORT := 19530 +REST_PORT := 9091 +CLIENT_PORT := 2379 + +LIB_MILVUS_DIR_MOUNTPATH := $(shell pwd)/volumes/milvus + +.PHONY: build +build: + podman build --platform $(PLATFORM) -f Containerfile -t ${IMAGE} . + +.PHONY: run +run: + podman run -d \ + --name milvus-standalone \ + --security-opt seccomp:unconfined \ + -e ETCD_USE_EMBED=true \ + -e ETCD_CONFIG_PATH=/milvus/configs/embedEtcd.yaml \ + -e COMMON_STORAGETYPE=local \ + -v $(LIB_MILVUS_DIR_MOUNTPATH):/var/lib/milvus \ + -p $(gRCP_PORT):$(gRCP_PORT) \ + -p $(REST_PORT):$(REST_PORT) \ + -p $(CLIENT_PORT):$(CLIENT_PORT) \ + --health-cmd="curl -f http://localhost:$(REST_PORT)/healthz" \ + --health-interval=30s \ + --health-start-period=90s \ + --health-timeout=20s \ + --health-retries=3 \ + $(IMAGE) \ + milvus run standalone 1> /dev/null + +.PHONY: stop +stop: + -podman stop milvus-standalone + +.PHONY: delete +delete: + -podman rm milvus-standalone -f + +.PHONY: podman-clean +podman-clean: + @container_ids=$$(podman ps --format "{{.ID}} {{.Image}}" | awk '$$2 == "$(IMAGE)" {print $$1}'); \ + echo "removing all containers with IMAGE=$(IMAGE)"; \ + for id in $$container_ids; do \ + echo "Removing container: $$id,"; \ + podman rm -f $$id; \ + done diff --git a/vector_dbs/milvus/embedEtcd.yaml b/vector_dbs/milvus/embedEtcd.yaml new file mode 100644 index 00000000..32954faa --- /dev/null +++ b/vector_dbs/milvus/embedEtcd.yaml @@ -0,0 +1,5 @@ +listen-client-urls: http://0.0.0.0:2379 +advertise-client-urls: http://0.0.0.0:2379 +quota-backend-bytes: 4294967296 +auto-compaction-mode: revision +auto-compaction-retention: '1000'