def initialize_clients(self):
    """Create the Bedrock runtime client and the OpenSearch client.

    The Bedrock client is stored on ``self.bedrock_client``; the OpenSearch
    client is created by ``self.opensearch.initialize_opensearch_client()``.

    Returns:
        bool: True when both clients were set up, False when the OpenSearch
        helper reported failure or any exception was raised.
    """
    try:
        self.bedrock_client = boto3.client('bedrock-runtime', region_name=self.aws_region)
        opensearch_ready = self.opensearch.initialize_opensearch_client()
        if not opensearch_ready:
            print("Failed to initialize OpenSearch client.")
            return False
        print("Clients initialized successfully.")
        return True
    except Exception as e:
        # Any failure (bad region, missing credentials, ...) is reported, not raised.
        print(f"Failed to initialize clients: {e}")
        return False
def process_csv(self, file_path: str) -> List[Dict[str, str]]:
    """Turn every data row of a nominations CSV into one text document.

    The header must contain ``name``, ``category`` and ``film`` columns.
    An optional ``winner`` column marks a winning row with the value
    ``true`` (case-insensitive); every other value, including a missing
    cell, is treated as "did not win".

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        One ``{"text": ...}`` dict per data row, in file order.

    Raises:
        KeyError: If one of the required columns is absent from the header.
    """
    documents = []
    # newline='' is what the csv module expects so quoted fields containing
    # line breaks are parsed correctly; utf-8-sig also strips a leading BOM
    # that would otherwise become part of the first column name.
    with open(file_path, 'r', newline='', encoding='utf-8-sig') as csvfile:
        for row in csv.DictReader(csvfile):
            text = (
                f"{row['name']} got nominated under the category, "
                f"{row['category']}, for the film {row['film']}"
            )
            # DictReader fills cells missing from a short row with None, so
            # the .get() default alone does not protect the .lower() call.
            winner_flag = (row.get('winner') or '').strip().lower()
            if winner_flag != 'true':
                text += " but did not win"
            documents.append({"text": text})
    return documents
Please run setup first.") + return None + + delay = initial_delay + for attempt in range(max_retries): + try: + payload = {"inputText": text} + response = self.bedrock_client.invoke_model(modelId=self.EMBEDDING_MODEL_ID, body=json.dumps(payload)) + response_body = json.loads(response['body'].read()) + embedding = response_body.get('embedding') + if embedding is None: + print(f"No embedding returned for text: {text}") + print(f"Response body: {response_body}") + return None + return embedding + except botocore.exceptions.ClientError as e: + error_code = e.response['Error']['Code'] + error_message = e.response['Error']['Message'] + print(f"ClientError on attempt {attempt + 1}: {error_code} - {error_message}") + if error_code == 'ThrottlingException': + if attempt == max_retries - 1: + raise + time.sleep(delay + random.uniform(0, 1)) + delay *= backoff_factor + else: + raise + except Exception as ex: + print(f"Unexpected error on attempt {attempt + 1}: {ex}") + if attempt == max_retries - 1: + raise + return None + + def process_and_ingest_data(self, file_paths: List[str]): + if not self.initialize_clients(): + print("Failed to initialize clients. 
Aborting ingestion.") + return + + all_documents = [] + for file_path in file_paths: + print(f"Processing file: {file_path}") + documents = self.process_file(file_path) + all_documents.extend(documents) + + total_documents = len(all_documents) + print(f"Total documents to process: {total_documents}") + + print("Generating embeddings for the documents...") + success_count = 0 + error_count = 0 + with tqdm(total=total_documents, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') as pbar: + for doc in all_documents: + try: + embedding = self.text_embedding(doc['text']) + if embedding is not None: + doc['embedding'] = embedding + success_count += 1 + else: + error_count += 1 + print(f"{Fore.RED}Error generating embedding for document: {doc['text'][:50]}...{Style.RESET_ALL}") + except Exception as e: + error_count += 1 + print(f"{Fore.RED}Error processing document: {str(e)}{Style.RESET_ALL}") + pbar.update(1) + pbar.set_postfix({'Success': success_count, 'Errors': error_count}) + + print(f"\n{Fore.GREEN}Documents with successful embeddings: {success_count}{Style.RESET_ALL}") + print(f"{Fore.RED}Documents with failed embeddings: {error_count}{Style.RESET_ALL}") + + if success_count == 0: + print(f"{Fore.RED}No documents to ingest. 
def ingest_command(self, paths: List[str]):
    """Resolve *paths* to supported files and hand them to the ingestion step.

    Each entry may be a file or a directory. Directories contribute their
    direct children only (no recursion). Files are kept when their name ends
    in ``.csv``, ``.txt`` or ``.pdf`` (case-insensitive).

    Args:
        paths: File and/or directory paths supplied by the user.

    Returns:
        None. Progress and problems are reported on stdout.
    """
    supported_extensions = ('.csv', '.txt', '.pdf')

    all_files = []
    for path in paths:
        if os.path.isfile(path):
            all_files.append(path)
        elif os.path.isdir(path):
            # glob returns entries in arbitrary order; sort so repeated runs
            # ingest in the same order, and drop sub-directories so that a
            # folder named e.g. "archive.csv" is never passed to open().
            children = sorted(glob.glob(os.path.join(path, '*')))
            all_files.extend(child for child in children if os.path.isfile(child))
        else:
            print(f"{Fore.YELLOW}Invalid path: {path}{Style.RESET_ALL}")

    # dict.fromkeys de-duplicates while keeping first-seen order, so a file
    # named twice (or named and also inside a listed directory) is ingested once.
    valid_files = [
        candidate
        for candidate in dict.fromkeys(all_files)
        if candidate.lower().endswith(supported_extensions)
    ]

    if not valid_files:
        print(f"{Fore.RED}No valid files found for ingestion.{Style.RESET_ALL}")
        return

    print(f"{Fore.GREEN}Found {len(valid_files)} valid files for ingestion.{Style.RESET_ALL}")

    self.process_and_ingest_data(valid_files)
def initialize_opensearch_client(self):
    """Build the OpenSearch client from the configured endpoint and credentials.

    Serverless collections are signed with SigV4 (service ``aoss``) using the
    default boto3 session credentials; managed domains use basic auth with
    the configured username and password. The client is stored on
    ``self.opensearch_client``.

    Returns:
        bool: True when the client object was constructed, False when the
        endpoint or credentials are missing/unusable or construction failed.
    """
    if not self.opensearch_endpoint:
        print("OpenSearch endpoint not set. Please run setup first.")
        return False

    endpoint = self.opensearch_endpoint.strip()
    # urlparse() only fills hostname/port when a '//' network-location marker
    # is present, so a bare "search-x.region.es.amazonaws.com" would yield
    # hostname=None. Default to https, which is what use_ssl=True implies.
    if '://' not in endpoint:
        endpoint = f"https://{endpoint}"

    parsed_url = urlparse(endpoint)
    host = parsed_url.hostname
    if not host:
        print(f"Could not determine a host name from OpenSearch endpoint '{self.opensearch_endpoint}'.")
        return False
    try:
        # .port raises ValueError for a non-numeric or out-of-range port.
        port = parsed_url.port or 443
    except ValueError as ex:
        print(f"Invalid port in OpenSearch endpoint '{self.opensearch_endpoint}': {ex}")
        return False

    if not self.is_serverless and (not self.opensearch_username or not self.opensearch_password):
        print("OpenSearch username or password not set. Please run setup first.")
        return False

    try:
        if self.is_serverless:
            credentials = boto3.Session().get_credentials()
            auth = AWSV4SignerAuth(credentials, self.aws_region, 'aoss')
        else:
            auth = (self.opensearch_username, self.opensearch_password)

        self.opensearch_client = OpenSearch(
            hosts=[{'host': host, 'port': port}],
            http_auth=auth,
            use_ssl=True,
            verify_certs=True,
            connection_class=RequestsHttpConnection,
            pool_maxsize=20
        )
        print(f"Initialized OpenSearch client with host: {host} and port: {port}")
        return True
    except Exception as ex:
        print(f"Error initializing OpenSearch client: {ex}")
        return False
def bulk_index(self, actions):
    """Send *actions* to OpenSearch with the bulk helper and report counts.

    Args:
        actions: List of bulk action dicts (each with ``_index`` and
            ``_source``) to index.

    Returns:
        tuple[int, int]: ``(succeeded, failed)`` document counts. When the
        bulk call itself raises, every action is reported as failed.
    """
    if not actions:
        # Nothing to send; avoid a pointless round trip.
        print("Indexed 0 documents successfully. Failed to index 0 documents.")
        return 0, 0

    try:
        # stats_only=True makes the helper return a failure *count* instead of
        # a list of error items, and raise_on_error=False lets it keep going
        # and count rejected documents rather than raising on the first
        # chunk that contains an error.
        success, failed = opensearch_helpers.bulk(
            self.opensearch_client,
            actions,
            stats_only=True,
            raise_on_error=False,
        )
        print(f"Indexed {success} documents successfully. Failed to index {failed} documents.")
        return success, failed
    except Exception as e:
        print(f"Error during bulk indexing: {e}")
        return 0, len(actions)
def text_embedding(self, text, max_retries=5, initial_delay=1, backoff_factor=2):
    """Request an embedding vector for *text* from the Bedrock embedding model.

    Throttling errors and unexpected exceptions are retried with exponential
    back-off plus up to one second of random jitter.

    Args:
        text: Text to embed.
        max_retries: Maximum number of invocation attempts.
        initial_delay: Base delay in seconds before the second attempt.
        backoff_factor: Multiplier applied to the delay after every retry.

    Returns:
        The ``embedding`` value from the model response, or None when the
        client is not initialized, *text* is blank, the response carries no
        embedding, or ``max_retries`` is zero.

    Raises:
        botocore.exceptions.ClientError: Immediately for non-throttling
            errors, or after the last attempt for throttling errors.
        Exception: Any other error raised on the last attempt.
    """
    if self.bedrock_client is None:
        print("Bedrock client is not initialized. Please run setup first.")
        return None

    # A blank input cannot produce a meaningful vector; skip the remote call
    # and let callers handle None like any other embedding failure.
    if not text or not text.strip():
        print("Cannot generate an embedding for empty text.")
        return None

    request_body = json.dumps({"inputText": text})  # invariant across retries
    delay = initial_delay
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = self.bedrock_client.invoke_model(modelId=self.EMBEDDING_MODEL_ID, body=request_body)
            response_body = json.loads(response['body'].read())
            embedding = response_body.get('embedding')
            if embedding is None:
                print(f"No embedding returned for text: {text}")
                print(f"Response body: {response_body}")
                return None
            return embedding
        except botocore.exceptions.ClientError as e:
            error_code = e.response['Error']['Code']
            error_message = e.response['Error']['Message']
            print(f"ClientError on attempt {attempt + 1}: {error_code} - {error_message}")
            # Only throttling is worth retrying; anything else is a caller error.
            if error_code != 'ThrottlingException' or is_last_attempt:
                raise
        except Exception as ex:
            print(f"Unexpected error on attempt {attempt + 1}: {ex}")
            if is_last_attempt:
                raise
        # Reached only when the attempt failed and another one remains:
        # wait before retrying so a struggling service is not hammered.
        time.sleep(delay + random.uniform(0, 1))
        delay *= backoff_factor
    return None
def generate_answer(self, prompt, config):
    """Ask the Bedrock text model to complete *prompt*.

    The prompt is truncated so that prompt tokens plus the requested output
    tokens stay within an 8192-token budget.

    Args:
        prompt: Full prompt text (context plus question).
        config: Mapping with optional ``maxTokenCount``, ``temperature``,
            ``topP`` and ``stopSequences`` entries.

    Returns:
        The stripped ``outputText`` of the first result, or None when the
        client is not initialized, the token budget leaves no room for the
        prompt, the model returns no results, or any error occurs.
    """
    if self.bedrock_client is None:
        print("Bedrock client is not initialized. Please run setup first.")
        return None

    try:
        max_input_tokens = 8192  # Max tokens for the model
        expected_output_tokens = config.get('maxTokenCount', 1000)
        allowable_input_tokens = max_input_tokens - expected_output_tokens

        # A zero or negative budget would turn the slice below into an empty
        # prompt or silently chop tokens off the *end* (negative slice bound).
        if allowable_input_tokens <= 0:
            print(
                f"maxTokenCount ({expected_output_tokens}) leaves no room for the prompt "
                f"within the {max_input_tokens}-token limit."
            )
            return None

        # NOTE(review): cl100k_base is used here only to estimate length;
        # confirm it is a close enough proxy for the target model's tokenizer.
        encoding = tiktoken.get_encoding("cl100k_base")
        prompt_tokens = encoding.encode(prompt)

        if len(prompt_tokens) > allowable_input_tokens:
            # Truncate the prompt to fit within the model's token limit
            prompt = encoding.decode(prompt_tokens[:allowable_input_tokens])
            print(f"Prompt truncated to {allowable_input_tokens} tokens.")

        # Simplified LLM config with only supported parameters
        llm_config = {
            'maxTokenCount': expected_output_tokens,
            'temperature': config.get('temperature', 0.7),
            'topP': config.get('topP', 1.0),
            'stopSequences': config.get('stopSequences', [])
        }

        body = json.dumps({
            'inputText': prompt,
            'textGenerationConfig': llm_config
        })
        response = self.bedrock_client.invoke_model(modelId=self.LLM_MODEL_ID, body=body)
        response_body = json.loads(response['body'].read())
        results = response_body.get('results', [])
        if not results:
            print("No results returned from LLM.")
            return None
        return results[0].get('outputText', '').strip()
    except Exception as ex:
        print(f"Error generating answer from LLM: {ex}")
        return None
def main():
    """CLI entry point: parse arguments and dispatch to setup, ingest or query.

    ``setup`` runs the interactive setup and saves the resulting config.
    ``ingest`` and ``query`` load the saved config first; when ``--paths`` /
    ``--queries`` are omitted the values are collected interactively until
    an empty line is entered.
    """

    def _prompt_until_blank(message):
        """Collect input lines until the user submits an empty one."""
        return list(iter(lambda: input(message), ''))

    parser = argparse.ArgumentParser(description="RAG Pipeline CLI")
    parser.add_argument('command', choices=['setup', 'ingest', 'query'], help='Command to run')
    parser.add_argument('--paths', nargs='+', help='Paths to files or directories for ingestion')
    parser.add_argument('--queries', nargs='+', help='Query texts for search and answer generation')
    parser.add_argument('--num_results', type=int, default=5, help='Number of top results to retrieve for each query')

    args = parser.parse_args()

    if args.command == 'setup':
        # Setup does not read the stored config, so do not load it first: a
        # malformed config file must not block the command that rewrites it.
        setup = SetupClass()
        setup.setup_command()
        save_config(setup.config)
        return

    config = load_config()

    if args.command == 'ingest':
        paths = args.paths or _prompt_until_blank(
            "Enter a file or directory path (or press Enter to finish): "
        )
        IngestClass(config).ingest_command(paths)
    else:
        # argparse 'choices' guarantees the only remaining command is 'query'.
        queries = args.queries or _prompt_until_blank(
            "Enter a query (or press Enter to finish): "
        )
        QueryClass(config).query_command(queries, num_results=args.num_results)
def configure_aws(self):
    """Prompt for AWS credentials and store them through the AWS CLI.

    Asks for the access key ID, the secret access key (read without echo)
    and the default region, then runs ``aws configure set`` once per
    non-empty answer. Problems are reported on stdout, never raised.
    """
    # Local import so this block brings its own dependency into scope.
    from getpass import getpass

    print("Let's configure your AWS credentials.")

    aws_access_key_id = input("Enter your AWS Access Key ID: ").strip()
    # getpass keeps the secret out of the terminal scrollback.
    aws_secret_access_key = getpass("Enter your AWS Secret Access Key: ").strip()
    aws_region_input = input("Enter your preferred AWS region (e.g., us-west-2): ").strip()

    settings = (
        ('aws_access_key_id', aws_access_key_id),
        ('aws_secret_access_key', aws_secret_access_key),
        ('region', aws_region_input),
    )

    try:
        configured_any = False
        for setting_name, setting_value in settings:
            if not setting_value:
                # An empty answer would overwrite a working value with ''.
                print(f"No value entered for {setting_name}; leaving the existing setting unchanged.")
                continue
            # NOTE(review): the value is passed as a process argument and may be
            # visible to other local users in the process list while it runs.
            subprocess.run(['aws', 'configure', 'set', setting_name, setting_value], check=True)
            configured_any = True

        if configured_any:
            print("AWS credentials have been successfully configured.")
    except FileNotFoundError:
        print("The AWS CLI ('aws') was not found on PATH. Install it and run setup again.")
    except subprocess.CalledProcessError as e:
        print(f"An error occurred while configuring AWS credentials: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
def save_config(self, config):
    """Write *config* to ``self.CONFIG_FILE`` under the ``[DEFAULT]`` section.

    Args:
        config: Mapping of option name to value. Values are stored as
            strings; ``None`` is stored as an empty string.
    """
    parser = configparser.ConfigParser()
    # ConfigParser's default interpolation treats '%' as syntax and rejects a
    # bare percent sign (common in passwords) with ValueError. Escaping it as
    # '%%' makes the value round-trip through readers that use the default
    # ConfigParser settings.
    parser['DEFAULT'] = {
        key: ('' if value is None else str(value)).replace('%', '%%')
        for key, value in config.items()
    }
    with open(self.CONFIG_FILE, 'w') as f:
        parser.write(f)
def create_security_policies(self):
    """Create the encryption, network and data-access policies for the collection.

    Does nothing (beyond printing a notice) for managed domains. All three
    policy names are derived from ``self.collection_name`` and passed through
    ``self.get_truncated_name`` so they respect the same length limit.
    """
    if not self.is_serverless:
        print("Security policies are not applicable for managed OpenSearch domains.")
        return

    collection = self.collection_name
    collection_rule = {"Resource": [f"collection/{collection}"], "ResourceType": "collection"}

    encryption_policy = json.dumps({
        "Rules": [collection_rule],
        "AWSOwnedKey": True
    })

    network_policy = json.dumps([{
        "Rules": [collection_rule],
        "AllowFromPublic": True
    }])

    # NOTE(review): the rules below grant aoss:* on every collection and
    # index, not just this one — confirm that breadth is intended.
    data_access_policy = json.dumps([{
        "Rules": [
            {"Resource": ["collection/*"], "Permission": ["aoss:*"], "ResourceType": "collection"},
            {"Resource": ["index/*/*"], "Permission": ["aoss:*"], "ResourceType": "index"}
        ],
        "Principal": [self.iam_principal],
        "Description": f"Data access policy for {collection}"
    }])

    # Every name goes through the same truncation; previously the network
    # policy name was passed through untruncated, unlike the other two.
    encryption_policy_name = self.get_truncated_name(f"{collection}-enc-policy")
    network_policy_name = self.get_truncated_name(f"{collection}-net-policy")
    access_policy_name = self.get_truncated_name(f"{collection}-access-policy")

    self.create_security_policy("encryption", encryption_policy_name, f"{collection} encryption security policy", encryption_policy)
    self.create_security_policy("network", network_policy_name, f"{collection} network security policy", network_policy)
    self.create_access_policy(access_policy_name, f"{collection} data access policy", data_access_policy)
def get_collection_id(self, collection_name):
    """Look up the ID of the serverless collection called *collection_name*.

    Follows ``nextToken`` so collections beyond the first result page are
    found as well.

    Args:
        collection_name: Exact name of the collection to find.

    Returns:
        The collection ID, or None when no collection has that name or the
        listing call fails (the error is printed).
    """
    try:
        request = {}
        while True:
            response = self.aoss_client.list_collections(**request)
            for collection in response.get('collectionSummaries', []):
                if collection['name'] == collection_name:
                    return collection['id']
            next_token = response.get('nextToken')
            if not next_token:
                break  # last page reached without a match
            request['nextToken'] = next_token
    except Exception as ex:
        print(f"Error getting collection ID: {ex}")
    return None
def get_collection_endpoint(self):
    """Return the endpoint URL to use for OpenSearch requests.

    Managed domains simply return the configured ``self.opensearch_endpoint``.
    For serverless collections the endpoint is read from the collection
    details and also stored on ``self.opensearch_endpoint``.

    Returns:
        The endpoint URL, or None when the collection, its details or its
        endpoint cannot be found, or when a lookup raises.
    """
    if not self.is_serverless:
        return self.opensearch_endpoint

    try:
        collection_id = self.get_collection_id(self.collection_name)
        if not collection_id:
            print(f"Collection '{self.collection_name}' not found.")
            return None

        details = self.aoss_client.batch_get_collection(ids=[collection_id]).get('collectionDetails', [])
        if not details:
            print(f"No details found for collection ID '{collection_id}'.")
            return None

        endpoint = details[0].get('collectionEndpoint')
        # Stored even when empty, mirroring what the lookup returned.
        self.opensearch_endpoint = endpoint
        if not endpoint:
            print(f"No endpoint URL found in collection '{self.collection_name}'.")
            return None

        print(f"Collection '{self.collection_name}' has endpoint URL: {endpoint}")
        return endpoint
    except Exception as ex:
        print(f"Error retrieving collection endpoint: {ex}")
        return None
Please run setup first.") + return False + auth = (self.opensearch_username, self.opensearch_password) + + try: + self.opensearch_client = OpenSearch( + hosts=[{'host': host, 'port': port}], + http_auth=auth, + use_ssl=True, + verify_certs=True, + connection_class=RequestsHttpConnection, + pool_maxsize=20 + ) + print(f"Initialized OpenSearch client with host: {host} and port: {port}") + return True + except Exception as ex: + print(f"Error initializing OpenSearch client: {ex}") + return False + + def get_knn_index_details(self): + # Simplified dimension input + dimension_input = input("Press Enter to use the default embedding size (768), or type a custom size: ") + + if dimension_input.strip() == "": + embedding_dimension = 768 + else: + try: + embedding_dimension = int(dimension_input) + except ValueError: + print("Invalid input. Using default dimension of 768.") + embedding_dimension = 768 + + print(f"\nEmbedding dimension set to: {embedding_dimension}") + + # Space type selection + print("\nChoose the space type for KNN:") + print("1. L2 (Euclidean distance)") + print("2. Cosine similarity") + print("3. Inner product") + space_choice = input("Enter your choice (1-3), or press Enter for default (L2): ") + + if space_choice == "" or space_choice == "1": + space_type = "l2" + elif space_choice == "2": + space_type = "cosinesimil" + elif space_choice == "3": + space_type = "innerproduct" + else: + print("Invalid choice. 
Using default space type of L2 (Euclidean distance).") + space_type = "l2" + + print(f"Space type set to: {space_type}") + + return embedding_dimension, space_type + + + def create_index(self, embedding_dimension, space_type): + index_body = { + "mappings": { + "properties": { + "nominee_text": {"type": "text"}, + "nominee_vector": { + "type": "knn_vector", + "dimension": embedding_dimension, + "method": { + "name": "hnsw", + "space_type": space_type, + "engine": "nmslib", + "parameters": {"ef_construction": 512, "m": 16}, + }, + }, + } + }, + "settings": { + "index": { + "number_of_shards": 2, + "knn.algo_param": {"ef_search": 512}, + "knn": True, + } + }, + } + try: + self.opensearch_client.indices.create(index=self.index_name, body=index_body) + print(f"KNN index '{self.index_name}' created successfully with dimension {embedding_dimension} and space type {space_type}.") + except Exception as e: + if 'resource_already_exists_exception' in str(e).lower(): + print(f"Index '{self.index_name}' already exists.") + else: + print(f"Error creating index '{self.index_name}': {e}") + + def verify_and_create_index(self, embedding_dimension, space_type): + try: + index_exists = self.opensearch_client.indices.exists(index=self.index_name) + if index_exists: + print(f"KNN index '{self.index_name}' already exists.") + else: + self.create_index(embedding_dimension, space_type) + return True + except Exception as ex: + print(f"Error verifying or creating index: {ex}") + return False + + def get_truncated_name(self, base_name, max_length=32): + if len(base_name) <= max_length: + return base_name + return base_name[:max_length-3] + "..." + + def setup_command(self): + self.check_and_configure_aws() + self.setup_configuration() + + if not self.initialize_clients(): + print("Failed to initialize AWS clients. 
Setup incomplete.") + return + + if self.is_serverless: + self.create_security_policies() + collection_id = self.get_collection_id(self.collection_name) + if not collection_id: + print(f"Collection '{self.collection_name}' not found. Attempting to create it...") + collection_id = self.create_collection(self.collection_name) + + if collection_id: + if self.wait_for_collection_active(collection_id): + self.opensearch_endpoint = self.get_collection_endpoint() + if not self.opensearch_endpoint: + print("Failed to retrieve OpenSearch endpoint. Setup incomplete.") + return + else: + self.config['opensearch_endpoint'] = self.opensearch_endpoint + else: + print("Collection is not active. Setup incomplete.") + return + else: + if not self.opensearch_endpoint: + print("OpenSearch endpoint not set. Setup incomplete.") + return + + if self.initialize_opensearch_client(): + embedding_dimension, space_type = self.get_knn_index_details() + if self.verify_and_create_index(embedding_dimension, space_type): + print("Setup completed successfully.") + self.config['embedding_dimension'] = str(embedding_dimension) + self.config['space_type'] = space_type + else: + print("Index verification failed. Please check your index name and permissions.") + else: + print("Failed to initialize OpenSearch client. 
Setup incomplete.") \ No newline at end of file diff --git a/opensearch_py_ml/ml_commons/rag_pipeline/rag/requirements.txt b/opensearch_py_ml/ml_commons/rag_pipeline/rag/requirements.txt new file mode 100644 index 00000000..dc41b248 --- /dev/null +++ b/opensearch_py_ml/ml_commons/rag_pipeline/rag/requirements.txt @@ -0,0 +1,9 @@ +boto3 +opensearch-py +pandas +configparser +PyPDF2 +tiktoken +tqdm +colorama +requests_aws4auth diff --git a/opensearch_py_ml/ml_commons/rag_pipeline/rag/setup.py b/opensearch_py_ml/ml_commons/rag_pipeline/rag/setup.py new file mode 100644 index 00000000..b73dbcf5 --- /dev/null +++ b/opensearch_py_ml/ml_commons/rag_pipeline/rag/setup.py @@ -0,0 +1,15 @@ +from setuptools import setup, find_packages, find_namespace_packages + + + + +setup( + name="rag_pipeline", + version="0.1.0", + packages=find_namespace_packages(include=['opensearch_py_ml', 'opensearch_py_ml.*']), + entry_points={ + 'console_scripts': [ + 'rag=opensearch_py_ml.ml_commons.rag_pipeline.rag:main', + ], + }, +) \ No newline at end of file