Skip to content

Commit

Permalink
Initial commit for RAG pipeline scripts
Browse files Browse the repository at this point in the history
Signed-off-by: hmumtazz <[email protected]>
  • Loading branch information
hmumtazz committed Nov 15, 2024
1 parent 39284d6 commit db78131
Show file tree
Hide file tree
Showing 8 changed files with 1,099 additions and 0 deletions.
12 changes: 12 additions & 0 deletions opensearch_py_ml/ml_commons/rag_pipeline/rag/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Ignore data and ingestion directories
ml_commons/rag_pipeline/data/
ml_commons/rag_pipeline/ingestion/
ml_commons/rag_pipeline/rag/config.ini
# Ignore virtual environment
.venv/
# Or, specify the full path
/Users/hmumtazz/.cursor-tutor/opensearch-py-ml/.venv/

# Ignore Python cache files
__pycache__/
*.pyc
194 changes: 194 additions & 0 deletions opensearch_py_ml/ml_commons/rag_pipeline/rag/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# ingest_class.py

import os
import glob
import json
import tiktoken
from tqdm import tqdm
from colorama import Fore, Style, init
from typing import List, Dict
import csv
import PyPDF2
import boto3
import botocore
import time
import random


from opensearch_class import OpenSearchClass

init(autoreset=True) # Initialize colorama

class IngestClass:
EMBEDDING_MODEL_ID = 'amazon.titan-embed-text-v1'

def __init__(self, config):
self.config = config
self.aws_region = config.get('region')
self.index_name = config.get('index_name')
self.bedrock_client = None
self.opensearch = OpenSearchClass(config)

def initialize_clients(self):
try:
self.bedrock_client = boto3.client('bedrock-runtime', region_name=self.aws_region)
if self.opensearch.initialize_opensearch_client():
print("Clients initialized successfully.")
return True
else:
print("Failed to initialize OpenSearch client.")
return False
except Exception as e:
print(f"Failed to initialize clients: {e}")
return False

def process_file(self, file_path: str) -> List[Dict[str, str]]:
_, file_extension = os.path.splitext(file_path)

if file_extension.lower() == '.csv':
return self.process_csv(file_path)
elif file_extension.lower() == '.txt':
return self.process_txt(file_path)
elif file_extension.lower() == '.pdf':
return self.process_pdf(file_path)
else:
print(f"Unsupported file type: {file_extension}")
return []

def process_csv(self, file_path: str) -> List[Dict[str, str]]:
documents = []
with open(file_path, 'r') as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
text = f"{row['name']} got nominated under the category, {row['category']}, for the film {row['film']}"
if row.get('winner', '').lower() != 'true':
text += " but did not win"
documents.append({"text": text})
return documents

def process_txt(self, file_path: str) -> List[Dict[str, str]]:
with open(file_path, 'r') as txtfile:
content = txtfile.read()
return [{"text": content}]

def process_pdf(self, file_path: str) -> List[Dict[str, str]]:
documents = []
with open(file_path, 'rb') as pdffile:
pdf_reader = PyPDF2.PdfReader(pdffile)
for page in pdf_reader.pages:
extracted_text = page.extract_text()
if extracted_text: # Ensure that text was extracted
documents.append({"text": extracted_text})
return documents

def text_embedding(self, text, max_retries=5, initial_delay=1, backoff_factor=2):
if self.bedrock_client is None:
print("Bedrock client is not initialized. Please run setup first.")
return None

delay = initial_delay
for attempt in range(max_retries):
try:
payload = {"inputText": text}
response = self.bedrock_client.invoke_model(modelId=self.EMBEDDING_MODEL_ID, body=json.dumps(payload))
response_body = json.loads(response['body'].read())
embedding = response_body.get('embedding')
if embedding is None:
print(f"No embedding returned for text: {text}")
print(f"Response body: {response_body}")
return None
return embedding
except botocore.exceptions.ClientError as e:
error_code = e.response['Error']['Code']
error_message = e.response['Error']['Message']
print(f"ClientError on attempt {attempt + 1}: {error_code} - {error_message}")
if error_code == 'ThrottlingException':
if attempt == max_retries - 1:
raise
time.sleep(delay + random.uniform(0, 1))
delay *= backoff_factor
else:
raise
except Exception as ex:
print(f"Unexpected error on attempt {attempt + 1}: {ex}")
if attempt == max_retries - 1:
raise
return None

def process_and_ingest_data(self, file_paths: List[str]):
if not self.initialize_clients():
print("Failed to initialize clients. Aborting ingestion.")
return

all_documents = []
for file_path in file_paths:
print(f"Processing file: {file_path}")
documents = self.process_file(file_path)
all_documents.extend(documents)

total_documents = len(all_documents)
print(f"Total documents to process: {total_documents}")

print("Generating embeddings for the documents...")
success_count = 0
error_count = 0
with tqdm(total=total_documents, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as pbar:
for doc in all_documents:
try:
embedding = self.text_embedding(doc['text'])
if embedding is not None:
doc['embedding'] = embedding
success_count += 1
else:
error_count += 1
print(f"{Fore.RED}Error generating embedding for document: {doc['text'][:50]}...{Style.RESET_ALL}")
except Exception as e:
error_count += 1
print(f"{Fore.RED}Error processing document: {str(e)}{Style.RESET_ALL}")
pbar.update(1)
pbar.set_postfix({'Success': success_count, 'Errors': error_count})

print(f"\n{Fore.GREEN}Documents with successful embeddings: {success_count}{Style.RESET_ALL}")
print(f"{Fore.RED}Documents with failed embeddings: {error_count}{Style.RESET_ALL}")

if success_count == 0:
print(f"{Fore.RED}No documents to ingest. Aborting ingestion.{Style.RESET_ALL}")
return

print(f"{Fore.YELLOW}Ingesting data into OpenSearch...{Style.RESET_ALL}")
actions = []
for doc in all_documents:
if 'embedding' in doc and doc['embedding'] is not None:
action = {
"_index": self.index_name,
"_source": {
"nominee_text": doc['text'],
"nominee_vector": doc['embedding']
}
}
actions.append(action)

success, failed = self.opensearch.bulk_index(actions)
print(f"{Fore.GREEN}Successfully ingested {success} documents.{Style.RESET_ALL}")
print(f"{Fore.RED}Failed to ingest {failed} documents.{Style.RESET_ALL}")

def ingest_command(self, paths: List[str]):
all_files = []
for path in paths:
if os.path.isfile(path):
all_files.append(path)
elif os.path.isdir(path):
all_files.extend(glob.glob(os.path.join(path, '*')))
else:
print(f"{Fore.YELLOW}Invalid path: {path}{Style.RESET_ALL}")

supported_extensions = ['.csv', '.txt', '.pdf']
valid_files = [f for f in all_files if any(f.lower().endswith(ext) for ext in supported_extensions)]

if not valid_files:
print(f"{Fore.RED}No valid files found for ingestion.{Style.RESET_ALL}")
return

print(f"{Fore.GREEN}Found {len(valid_files)} valid files for ingestion.{Style.RESET_ALL}")

self.process_and_ingest_data(valid_files)
127 changes: 127 additions & 0 deletions opensearch_py_ml/ml_commons/rag_pipeline/rag/opensearch_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# opensearch_class.py

from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, exceptions as opensearch_exceptions
import boto3
from urllib.parse import urlparse
from opensearchpy import helpers as opensearch_helpers

class OpenSearchClass:
def __init__(self, config):
self.config = config
self.opensearch_client = None
self.aws_region = config.get('region')
self.index_name = config.get('index_name')
self.is_serverless = config.get('is_serverless', 'False') == 'True'
self.opensearch_endpoint = config.get('opensearch_endpoint')
self.opensearch_username = config.get('opensearch_username')
self.opensearch_password = config.get('opensearch_password')

def initialize_opensearch_client(self):
if not self.opensearch_endpoint:
print("OpenSearch endpoint not set. Please run setup first.")
return False

parsed_url = urlparse(self.opensearch_endpoint)
host = parsed_url.hostname
port = parsed_url.port or 443

if self.is_serverless:
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, self.aws_region, 'aoss')
else:
if not self.opensearch_username or not self.opensearch_password:
print("OpenSearch username or password not set. Please run setup first.")
return False
auth = (self.opensearch_username, self.opensearch_password)

try:
self.opensearch_client = OpenSearch(
hosts=[{'host': host, 'port': port}],
http_auth=auth,
use_ssl=True,
verify_certs=True,
connection_class=RequestsHttpConnection,
pool_maxsize=20
)
print(f"Initialized OpenSearch client with host: {host} and port: {port}")
return True
except Exception as ex:
print(f"Error initializing OpenSearch client: {ex}")
return False

def create_index(self, embedding_dimension, space_type):
index_body = {
"mappings": {
"properties": {
"nominee_text": {"type": "text"},
"nominee_vector": {
"type": "knn_vector",
"dimension": embedding_dimension,
"method": {
"name": "hnsw",
"space_type": space_type,
"engine": "nmslib",
"parameters": {"ef_construction": 512, "m": 16},
},
},
}
},
"settings": {
"index": {
"number_of_shards": 2,
"knn.algo_param": {"ef_search": 512},
"knn": True,
}
},
}
try:
self.opensearch_client.indices.create(index=self.index_name, body=index_body)
print(f"KNN index '{self.index_name}' created successfully with dimension {embedding_dimension} and space type {space_type}.")
except opensearch_exceptions.RequestError as e:
if 'resource_already_exists_exception' in str(e).lower():
print(f"Index '{self.index_name}' already exists.")
else:
print(f"Error creating index '{self.index_name}': {e}")

def verify_and_create_index(self, embedding_dimension, space_type):
try:
index_exists = self.opensearch_client.indices.exists(index=self.index_name)
if index_exists:
print(f"KNN index '{self.index_name}' already exists.")
else:
self.create_index(embedding_dimension, space_type)
return True
except Exception as ex:
print(f"Error verifying or creating index: {ex}")
return False

def bulk_index(self, actions):
try:
success, failed = opensearch_helpers.bulk(self.opensearch_client, actions)
print(f"Indexed {success} documents successfully. Failed to index {failed} documents.")
return success, failed
except Exception as e:
print(f"Error during bulk indexing: {e}")
return 0, len(actions)

def search(self, vector, k=5):
try:
response = self.opensearch_client.search(
index=self.index_name,
body={
"size": k,
"_source": ["nominee_text"],
"query": {
"knn": {
"nominee_vector": {
"vector": vector,
"k": k
}
}
}
}
)
return response['hits']['hits']
except Exception as e:
print(f"Error during search: {e}")
return []
Loading

0 comments on commit db78131

Please sign in to comment.