Merge pull request #2359 from Agenta-AI/docs/docs-rag-qa
Docs RAG QA
Showing 12 changed files with 1,460 additions and 0 deletions.
Several of the changed files are not shown below: large diffs are not rendered by default, and some files could not be displayed.
examples/custom_workflows/rag-docs-qa/.env.example
@@ -0,0 +1,13 @@
DOCS_PATH=
DOCS_BASE_URL=
OPENAI_API_KEY=
COHERE_API_KEY=
COLLECTION_NAME=
QDRANT_URL=
QDRANT_API_KEY=

# optional
MISTRAL_API_KEY=
ANTHROPIC_API_KEY=
GEMINI_API_KEY=
GROQ_API_KEY=
examples/custom_workflows/rag-docs-qa/README.md
@@ -0,0 +1,62 @@
# RAG Q&A Documentation System

This project implements a RAG system for documentation Q&A. The documentation is expected to be in MDX format (for this tutorial we use our own documentation, which is built with Docusaurus).

The stack used:

- Qdrant as the vector database
- Cohere for embeddings
- OpenAI for the LLM and embeddings

## Requirements

- A running Qdrant database
- A Cohere API key
- An OpenAI API key

## Setup

1. Set up a virtual environment and install dependencies:

```bash
uv venv
source .venv/bin/activate  # On Unix/macOS
# or
.venv\Scripts\activate  # On Windows

uv pip compile requirements.in --output-file requirements.txt

uv pip sync requirements.txt
```

2. Copy `.env.example` to `.env` and fill in your configuration:

```bash
cp .env.example .env

DOCS_PATH=        # Path to your documentation folder containing the MDX files
DOCS_BASE_URL=    # Base URL of your documentation site, used to generate the links in citations
OPENAI_API_KEY=   # Your OpenAI API key
COHERE_API_KEY=   # Your Cohere API key
COLLECTION_NAME=  # Name of the Qdrant collection in which to store the embeddings
QDRANT_URL=       # URL of your Qdrant server
QDRANT_API_KEY=   # API key of your Qdrant server
AGENTA_API_KEY=   # Your Agenta API key
```

3. Run the ingestion script:

```bash
python ingest.py
```

4. Serve the application to Agenta:

```bash
agenta init
agenta variant serve query.py
```

## Notes

- `generate_test_set.py` generates a test set of questions from the documentation for use in evaluation; a quick way to inspect its output is sketched below.
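A minimal sketch of checking the generated test set, assuming `generate_test_set.py` has already been run and has written `test_set.csv` (with its single `query` column) to the working directory:

```python
# Illustrative sanity check of the generated test set (not part of the example code).
import pandas as pd

df = pd.read_csv("test_set.csv")  # written by generate_test_set.py
print(f"{len(df)} questions generated")
print(df["query"].head().to_string(index=False))  # preview the first few questions
```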
examples/custom_workflows/rag-docs-qa/generate_test_set.py (99 additions, 0 deletions)
@@ -0,0 +1,99 @@
import os
import glob
from pathlib import Path
import pandas as pd
from dotenv import load_dotenv
from litellm import completion
import frontmatter
import tqdm
import json

# Load environment variables
load_dotenv()


def get_files(docs_path):
    """Get all markdown files recursively."""
    # Return every MDX file under docs_path (same pattern as the glob in main())
    return glob.glob(os.path.join(docs_path, "**/*.mdx"), recursive=True)


def extract_content(file_path):
    """Extract content from markdown file."""
    with open(file_path, "r", encoding="utf-8") as f:
        post = frontmatter.load(f)
        # Get title from frontmatter or filename
        title = post.get("title", Path(file_path).stem)
        # Get content without frontmatter
        content = post.content
        return title, content


def generate_questions(title, content):
    """Generate questions using OpenAI."""
    system_prompt = """You are a helpful assistant that generates questions based on documentation content.
    Generate 5 questions that could be answered using the provided documentation.
    Your response must be a JSON object with a single key "questions" containing an array of strings."""

    # Limit content length to avoid token limits (the cap here is an arbitrary safeguard)
    user_prompt = f"""
    Title: {title}
    Content: {content[:8000]}
    Generate 5 questions about this documentation. Put yourself in the shoes of a user attempting to 1) figure out how to use the product for a use case, 2) troubleshoot an issue, or 3) learn about the features of the product.
    The user in this case is a technical user (an AI engineer) who is trying to build an LLM application.
    The user would write the questions as they would ask them in a chat with a human, so not all questions will be clear and well written.
    """

    try:
        response = completion(
            model="gpt-3.5-turbo-0125",  # Using a model that supports JSON mode
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object"},
        )

        # Check if the response was complete
        if response.choices[0].finish_reason == "length":
            print(f"Warning: Response was truncated for {title}")
            return []

        # Parse JSON response - no need for eval()
        result = json.loads(response.choices[0].message.content)
        return result["questions"]

    except Exception as e:
        print(f"Error generating questions for {title}: {str(e)}")
        return []


def main():
    docs_path = os.getenv("DOCS_PATH")
    if not docs_path:
        raise ValueError("DOCS_PATH environment variable not set")

    # Get all files
    files = glob.glob(os.path.join(docs_path, "**/*.mdx"), recursive=True)
    all_questions = []
    # Process each file
    for file_path in tqdm.tqdm(files, desc="Processing documentation files"):
        if "/reference/api" in file_path:
            # skip api docs
            continue
        try:
            title, content = extract_content(file_path)
            questions = generate_questions(title, content)
            all_questions.extend(questions)
        except Exception as e:
            print(f"Error processing {file_path}: {str(e)}")
            continue

    # Save to CSV
    df = pd.DataFrame({"query": all_questions})
    df.to_csv("test_set.csv", index=False, lineterminator="\n")
    print(f"Generated {len(all_questions)} questions and saved to test_set.csv")


if __name__ == "__main__":
    main()
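For clarity, this is the shape of response that `generate_questions()` expects back from the model when JSON mode is enabled; the questions shown are invented purely for illustration:

```python
import json

# Hypothetical example of a well-formed model response for generate_questions().
raw = '{"questions": ["How do I create a test set in Agenta?", "Why does my variant fail to deploy?"]}'
parsed = json.loads(raw)
assert isinstance(parsed["questions"], list)
print(parsed["questions"])
```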
examples/custom_workflows/rag-docs-qa/ingest.py
@@ -0,0 +1,204 @@
import os
import glob
import uuid
from typing import List, Dict
import hashlib
from datetime import datetime
import frontmatter
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http import models
from litellm import embedding
import tqdm

# Load environment variables
load_dotenv()

# Constants
OPENAI_EMBEDDING_DIM = 1536  # For text-embedding-ada-002
COHERE_EMBEDDING_DIM = 1024  # For embed-english-v3.0
# Read the collection name from the environment (see .env.example), with a fallback
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "docs_collection")

# Initialize Qdrant client
qdrant_client = QdrantClient(
    url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY")
)


def get_all_docs(docs_path: str) -> List[str]:
    """Get all MDX files in the docs directory."""
    return glob.glob(os.path.join(docs_path, "**/*.mdx"), recursive=True)


def calculate_doc_hash(content: str) -> str:
    """Calculate a hash for the document content."""
    return hashlib.md5(content.encode()).hexdigest()


def get_doc_url(file_path: str, docs_path: str, docs_base_url: str) -> str:
    """Convert file path to documentation URL."""
    relative_path = os.path.relpath(file_path, docs_path)
    # Remove .mdx extension and convert to URL path
    url_path = os.path.splitext(relative_path)[0]
    return f"{docs_base_url}/{url_path}"


def chunk_text(text: str, max_chunk_size: int = 1500) -> List[str]:
    """
    Split text into chunks based on paragraphs and size.
    Tries to maintain context by keeping paragraphs together when possible.
    """
    # Split by double newlines to preserve paragraph structure
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]

    chunks = []
    current_chunk = []
    current_size = 0

    for paragraph in paragraphs:
        paragraph_size = len(paragraph)

        # If a single paragraph is too large, split it by sentences
        if paragraph_size > max_chunk_size:
            sentences = [s.strip() + "." for s in paragraph.split(".") if s.strip()]
            for sentence in sentences:
                if len(sentence) > max_chunk_size:
                    # If even a sentence is too long, split it by chunks
                    for i in range(0, len(sentence), max_chunk_size):
                        chunks.append(sentence[i : i + max_chunk_size])
                elif current_size + len(sentence) > max_chunk_size:
                    # Start new chunk
                    chunks.append(" ".join(current_chunk))
                    current_chunk = [sentence]
                    current_size = len(sentence)
                else:
                    current_chunk.append(sentence)
                    current_size += len(sentence)
        # If adding this paragraph would exceed the limit, start a new chunk
        elif current_size + paragraph_size > max_chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = [paragraph]
            current_size = paragraph_size
        else:
            current_chunk.append(paragraph)
            current_size += paragraph_size

    # Add the last chunk if it exists
    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks


def process_doc(file_path: str, docs_path: str, docs_base_url: str) -> List[Dict]:
    """Process a single document into chunks with metadata."""
    with open(file_path, "r", encoding="utf-8") as f:
        # Parse frontmatter and content
        post = frontmatter.load(f)
        content = post.content

    # Calculate document hash
    doc_hash = calculate_doc_hash(content)

    # Get document URL
    doc_url = get_doc_url(file_path, docs_path, docs_base_url)

    # Create base metadata
    metadata = {
        "title": post.get("title", ""),
        "url": doc_url,
        "file_path": file_path,
        "last_updated": datetime.utcnow().isoformat(),
        "doc_hash": doc_hash,
    }

    # Chunk the content
    chunks = chunk_text(content)

    return [
        {"content": chunk, "metadata": metadata, "doc_hash": doc_hash}
        for chunk in chunks
    ]


def get_embeddings(text: str) -> Dict[str, List[float]]:
    """Get embeddings using both OpenAI and Cohere models via LiteLLM."""
    # Get OpenAI embedding
    openai_response = embedding(model="text-embedding-ada-002", input=[text])
    openai_embedding = openai_response["data"][0]["embedding"]

    # Get Cohere embedding
    cohere_response = embedding(
        model="cohere/embed-english-v3.0",
        input=[text],
        input_type="search_document",  # Specific to Cohere v3 models
    )
    cohere_embedding = cohere_response["data"][0]["embedding"]

    return {"openai": openai_embedding, "cohere": cohere_embedding}


def setup_qdrant_collection():
    """Create or recreate the vector collection."""
    # Delete if exists
    try:
        qdrant_client.delete_collection(COLLECTION_NAME)
    except Exception:
        pass

    # Create collection with two vector types
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config={
            "openai": models.VectorParams(
                size=OPENAI_EMBEDDING_DIM, distance=models.Distance.COSINE
            ),
            "cohere": models.VectorParams(
                size=COHERE_EMBEDDING_DIM, distance=models.Distance.COSINE
            ),
        },
    )

def upsert_chunks(chunks: List[Dict]):
    """Upsert document chunks to the vector store."""
    for i, chunk in enumerate(chunks):
        # Get both embeddings using LiteLLM
        embeddings = get_embeddings(chunk["content"])

        # Create payload
        payload = {**chunk["metadata"], "content": chunk["content"], "chunk_index": i}

        # Upsert to Qdrant
        qdrant_client.upsert(
            collection_name=COLLECTION_NAME,
            points=[
                models.PointStruct(
                    # Derive a deterministic UUID per chunk: Qdrant point IDs must be
                    # integers or UUIDs, and reusing the raw doc_hash for every chunk
                    # would make chunks of the same document overwrite each other.
                    id=str(uuid.uuid5(uuid.NAMESPACE_URL, f"{chunk['doc_hash']}-{i}")),
                    payload=payload,
                    vector=embeddings,  # Contains both 'openai' and 'cohere' embeddings
                )
            ],
        )


def main():
    # Get environment variables
    docs_path = os.getenv("DOCS_PATH")
    docs_base_url = os.getenv("DOCS_BASE_URL")

    if not docs_path or not docs_base_url:
        raise ValueError("DOCS_PATH and DOCS_BASE_URL must be set in .env file")

    # Create fresh collection
    setup_qdrant_collection()

    # Process all documents
    all_docs = get_all_docs(docs_path)
    for doc_path in tqdm.tqdm(all_docs):
        print(f"Processing {doc_path}")
        chunks = process_doc(doc_path, docs_path, docs_base_url)
        upsert_chunks(chunks)


if __name__ == "__main__":
    main()
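The retrieval side of this example, `query.py` (served to Agenta in step 4 of the README), is not rendered in this view. Purely as an illustration of how the collection built by `ingest.py` could be queried, and not the repository's actual implementation, a retrieval helper might look like the sketch below; the function name, the `top_k` parameter, and the choice to search only the OpenAI vector are assumptions:

```python
# Hypothetical retrieval sketch against the collection created by ingest.py.
import os

from dotenv import load_dotenv
from litellm import embedding
from qdrant_client import QdrantClient

load_dotenv()

client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "docs_collection")


def search_docs(question: str, top_k: int = 5):
    """Embed the question and search the 'openai' named vector (illustrative only)."""
    # Use the same OpenAI embedding model as ingest.py so query and document vectors match
    response = embedding(model="text-embedding-ada-002", input=[question])
    query_vector = response["data"][0]["embedding"]

    # Search against the "openai" named vector created by setup_qdrant_collection()
    hits = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=("openai", query_vector),
        limit=top_k,
    )
    # Each payload carries the chunk text plus the source URL used for citations
    return [(hit.payload["url"], hit.payload["content"]) for hit in hits]
```

The `"cohere"` named vector stored by `ingest.py` could be searched the same way (embedding the query with `cohere/embed-english-v3.0` and `input_type="search_query"`), which is what would allow comparing or combining the two embedding models at retrieval time.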