Commit
Initial commit for RAG pipeline scripts
Signed-off-by: hmumtazz <[email protected]>
Showing 8 changed files with 1,099 additions and 0 deletions.
12 changes: 12 additions & 0 deletions
.gitignore
@@ -0,0 +1,12 @@
# Ignore data and ingestion directories
ml_commons/rag_pipeline/data/
ml_commons/rag_pipeline/ingestion/
ml_commons/rag_pipeline/rag/config.ini
# Ignore virtual environment
.venv/
# Or, specify the full path
/Users/hmumtazz/.cursor-tutor/opensearch-py-ml/.venv/

# Ignore Python cache files
__pycache__/
*.pyc
194 changes: 194 additions & 0 deletions
opensearch_py_ml/ml_commons/rag_pipeline/rag/ingest_class.py
@@ -0,0 +1,194 @@
# ingest_class.py

import os
import glob
import json
import tiktoken
from tqdm import tqdm
from colorama import Fore, Style, init
from typing import List, Dict
import csv
import PyPDF2
import boto3
import botocore
import time
import random

from opensearch_class import OpenSearchClass

init(autoreset=True)  # Initialize colorama


class IngestClass:
    EMBEDDING_MODEL_ID = 'amazon.titan-embed-text-v1'

    def __init__(self, config):
        self.config = config
        self.aws_region = config.get('region')
        self.index_name = config.get('index_name')
        self.bedrock_client = None
        self.opensearch = OpenSearchClass(config)

    def initialize_clients(self):
        try:
            self.bedrock_client = boto3.client('bedrock-runtime', region_name=self.aws_region)
            if self.opensearch.initialize_opensearch_client():
                print("Clients initialized successfully.")
                return True
            else:
                print("Failed to initialize OpenSearch client.")
                return False
        except Exception as e:
            print(f"Failed to initialize clients: {e}")
            return False

    def process_file(self, file_path: str) -> List[Dict[str, str]]:
        _, file_extension = os.path.splitext(file_path)

        if file_extension.lower() == '.csv':
            return self.process_csv(file_path)
        elif file_extension.lower() == '.txt':
            return self.process_txt(file_path)
        elif file_extension.lower() == '.pdf':
            return self.process_pdf(file_path)
        else:
            print(f"Unsupported file type: {file_extension}")
            return []

    def process_csv(self, file_path: str) -> List[Dict[str, str]]:
        documents = []
        with open(file_path, 'r') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                text = f"{row['name']} got nominated under the category, {row['category']}, for the film {row['film']}"
                if row.get('winner', '').lower() != 'true':
                    text += " but did not win"
                documents.append({"text": text})
        return documents

    def process_txt(self, file_path: str) -> List[Dict[str, str]]:
        with open(file_path, 'r') as txtfile:
            content = txtfile.read()
        return [{"text": content}]

    def process_pdf(self, file_path: str) -> List[Dict[str, str]]:
        documents = []
        with open(file_path, 'rb') as pdffile:
            pdf_reader = PyPDF2.PdfReader(pdffile)
            for page in pdf_reader.pages:
                extracted_text = page.extract_text()
                if extracted_text:  # Ensure that text was extracted
                    documents.append({"text": extracted_text})
        return documents

    def text_embedding(self, text, max_retries=5, initial_delay=1, backoff_factor=2):
        if self.bedrock_client is None:
            print("Bedrock client is not initialized. Please run setup first.")
            return None

        delay = initial_delay
        for attempt in range(max_retries):
            try:
                payload = {"inputText": text}
                response = self.bedrock_client.invoke_model(modelId=self.EMBEDDING_MODEL_ID, body=json.dumps(payload))
                response_body = json.loads(response['body'].read())
                embedding = response_body.get('embedding')
                if embedding is None:
                    print(f"No embedding returned for text: {text}")
                    print(f"Response body: {response_body}")
                    return None
                return embedding
            except botocore.exceptions.ClientError as e:
                error_code = e.response['Error']['Code']
                error_message = e.response['Error']['Message']
                print(f"ClientError on attempt {attempt + 1}: {error_code} - {error_message}")
                if error_code == 'ThrottlingException':
                    if attempt == max_retries - 1:
                        raise
                    # Exponential backoff with jitter before retrying a throttled call
                    time.sleep(delay + random.uniform(0, 1))
                    delay *= backoff_factor
                else:
                    raise
            except Exception as ex:
                print(f"Unexpected error on attempt {attempt + 1}: {ex}")
                if attempt == max_retries - 1:
                    raise
        return None

    def process_and_ingest_data(self, file_paths: List[str]):
        if not self.initialize_clients():
            print("Failed to initialize clients. Aborting ingestion.")
            return

        all_documents = []
        for file_path in file_paths:
            print(f"Processing file: {file_path}")
            documents = self.process_file(file_path)
            all_documents.extend(documents)

        total_documents = len(all_documents)
        print(f"Total documents to process: {total_documents}")

        print("Generating embeddings for the documents...")
        success_count = 0
        error_count = 0
        with tqdm(total=total_documents, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as pbar:
            for doc in all_documents:
                try:
                    embedding = self.text_embedding(doc['text'])
                    if embedding is not None:
                        doc['embedding'] = embedding
                        success_count += 1
                    else:
                        error_count += 1
                        print(f"{Fore.RED}Error generating embedding for document: {doc['text'][:50]}...{Style.RESET_ALL}")
                except Exception as e:
                    error_count += 1
                    print(f"{Fore.RED}Error processing document: {str(e)}{Style.RESET_ALL}")
                pbar.update(1)
                pbar.set_postfix({'Success': success_count, 'Errors': error_count})

        print(f"\n{Fore.GREEN}Documents with successful embeddings: {success_count}{Style.RESET_ALL}")
        print(f"{Fore.RED}Documents with failed embeddings: {error_count}{Style.RESET_ALL}")

        if success_count == 0:
            print(f"{Fore.RED}No documents to ingest. Aborting ingestion.{Style.RESET_ALL}")
            return

        print(f"{Fore.YELLOW}Ingesting data into OpenSearch...{Style.RESET_ALL}")
        actions = []
        for doc in all_documents:
            if 'embedding' in doc and doc['embedding'] is not None:
                action = {
                    "_index": self.index_name,
                    "_source": {
                        "nominee_text": doc['text'],
                        "nominee_vector": doc['embedding']
                    }
                }
                actions.append(action)

        success, failed = self.opensearch.bulk_index(actions)
        print(f"{Fore.GREEN}Successfully ingested {success} documents.{Style.RESET_ALL}")
        print(f"{Fore.RED}Failed to ingest {failed} documents.{Style.RESET_ALL}")

    def ingest_command(self, paths: List[str]):
        all_files = []
        for path in paths:
            if os.path.isfile(path):
                all_files.append(path)
            elif os.path.isdir(path):
                all_files.extend(glob.glob(os.path.join(path, '*')))
            else:
                print(f"{Fore.YELLOW}Invalid path: {path}{Style.RESET_ALL}")

        supported_extensions = ['.csv', '.txt', '.pdf']
        valid_files = [f for f in all_files if any(f.lower().endswith(ext) for ext in supported_extensions)]

        if not valid_files:
            print(f"{Fore.RED}No valid files found for ingestion.{Style.RESET_ALL}")
            return

        print(f"{Fore.GREEN}Found {len(valid_files)} valid files for ingestion.{Style.RESET_ALL}")

        self.process_and_ingest_data(valid_files)
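For reference, a minimal usage sketch of the ingestion flow above. The config keys mirror those read in __init__; every concrete value (region, endpoint, credentials, data path) is a placeholder, not part of this commit:

# usage_sketch.py, a hypothetical driver, not part of this commit.
# Assumes an OpenSearch domain and AWS credentials with Bedrock access.
from ingest_class import IngestClass

config = {
    'region': 'us-east-1',                    # placeholder AWS region
    'index_name': 'rag-index',                # placeholder index name
    'is_serverless': 'False',
    'opensearch_endpoint': 'https://my-domain.example.com',  # placeholder endpoint
    'opensearch_username': 'admin',           # placeholder credentials
    'opensearch_password': 'change-me',
}

ingest = IngestClass(config)
# Walks the given paths, keeps .csv/.txt/.pdf files, embeds each document
# with Bedrock Titan, and bulk-indexes the results into OpenSearch.
ingest.ingest_command(['./ml_commons/rag_pipeline/data'])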
127 changes: 127 additions & 0 deletions
opensearch_py_ml/ml_commons/rag_pipeline/rag/opensearch_class.py
@@ -0,0 +1,127 @@
# opensearch_class.py

from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, exceptions as opensearch_exceptions
import boto3
from urllib.parse import urlparse
from opensearchpy import helpers as opensearch_helpers


class OpenSearchClass:
    def __init__(self, config):
        self.config = config
        self.opensearch_client = None
        self.aws_region = config.get('region')
        self.index_name = config.get('index_name')
        self.is_serverless = config.get('is_serverless', 'False') == 'True'
        self.opensearch_endpoint = config.get('opensearch_endpoint')
        self.opensearch_username = config.get('opensearch_username')
        self.opensearch_password = config.get('opensearch_password')

    def initialize_opensearch_client(self):
        if not self.opensearch_endpoint:
            print("OpenSearch endpoint not set. Please run setup first.")
            return False

        parsed_url = urlparse(self.opensearch_endpoint)
        host = parsed_url.hostname
        port = parsed_url.port or 443

        if self.is_serverless:
            credentials = boto3.Session().get_credentials()
            auth = AWSV4SignerAuth(credentials, self.aws_region, 'aoss')
        else:
            if not self.opensearch_username or not self.opensearch_password:
                print("OpenSearch username or password not set. Please run setup first.")
                return False
            auth = (self.opensearch_username, self.opensearch_password)

        try:
            self.opensearch_client = OpenSearch(
                hosts=[{'host': host, 'port': port}],
                http_auth=auth,
                use_ssl=True,
                verify_certs=True,
                connection_class=RequestsHttpConnection,
                pool_maxsize=20
            )
            print(f"Initialized OpenSearch client with host: {host} and port: {port}")
            return True
        except Exception as ex:
            print(f"Error initializing OpenSearch client: {ex}")
            return False

    def create_index(self, embedding_dimension, space_type):
        index_body = {
            "mappings": {
                "properties": {
                    "nominee_text": {"type": "text"},
                    "nominee_vector": {
                        "type": "knn_vector",
                        "dimension": embedding_dimension,
                        "method": {
                            "name": "hnsw",
                            "space_type": space_type,
                            "engine": "nmslib",
                            "parameters": {"ef_construction": 512, "m": 16},
                        },
                    },
                }
            },
            "settings": {
                "index": {
                    "number_of_shards": 2,
                    "knn.algo_param": {"ef_search": 512},
                    "knn": True,
                }
            },
        }
        try:
            self.opensearch_client.indices.create(index=self.index_name, body=index_body)
            print(f"KNN index '{self.index_name}' created successfully with dimension {embedding_dimension} and space type {space_type}.")
        except opensearch_exceptions.RequestError as e:
            if 'resource_already_exists_exception' in str(e).lower():
                print(f"Index '{self.index_name}' already exists.")
            else:
                print(f"Error creating index '{self.index_name}': {e}")

    def verify_and_create_index(self, embedding_dimension, space_type):
        try:
            index_exists = self.opensearch_client.indices.exists(index=self.index_name)
            if index_exists:
                print(f"KNN index '{self.index_name}' already exists.")
            else:
                self.create_index(embedding_dimension, space_type)
            return True
        except Exception as ex:
            print(f"Error verifying or creating index: {ex}")
            return False

    def bulk_index(self, actions):
        try:
            # stats_only=True makes helpers.bulk return (success_count, failure_count)
            # rather than (success_count, list_of_errors), so the counts printed and
            # returned below are both integers.
            success, failed = opensearch_helpers.bulk(self.opensearch_client, actions, stats_only=True)
            print(f"Indexed {success} documents successfully. Failed to index {failed} documents.")
            return success, failed
        except Exception as e:
            print(f"Error during bulk indexing: {e}")
            return 0, len(actions)

    def search(self, vector, k=5):
        try:
            response = self.opensearch_client.search(
                index=self.index_name,
                body={
                    "size": k,
                    "_source": ["nominee_text"],
                    "query": {
                        "knn": {
                            "nominee_vector": {
                                "vector": vector,
                                "k": k
                            }
                        }
                    }
                }
            )
            return response['hits']['hits']
        except Exception as e:
            print(f"Error during search: {e}")
            return []
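And a minimal sketch of the OpenSearch side on its own: index creation followed by a k-NN query. The 1536 dimension matches the output of amazon.titan-embed-text-v1; the endpoint, credentials, and index name are placeholders, and the zero vector stands in for a real query embedding:

# search_sketch.py, a hypothetical driver, not part of this commit.
from opensearch_class import OpenSearchClass

config = {
    'region': 'us-east-1',                    # placeholder
    'index_name': 'rag-index',                # placeholder
    'is_serverless': 'False',
    'opensearch_endpoint': 'https://my-domain.example.com',  # placeholder
    'opensearch_username': 'admin',           # placeholder credentials
    'opensearch_password': 'change-me',
}

os_client = OpenSearchClass(config)
if os_client.initialize_opensearch_client():
    # amazon.titan-embed-text-v1 produces 1536-dimensional vectors.
    os_client.verify_and_create_index(embedding_dimension=1536, space_type='l2')
    query_vector = [0.0] * 1536  # stand-in for a real query embedding
    for hit in os_client.search(query_vector, k=5):
        print(hit['_source']['nominee_text'])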