From 698efbd1e0c77a2c7e78be4f6080d96b8a0cd912 Mon Sep 17 00:00:00 2001
From: Arjun Bingly
Date: Fri, 22 Mar 2024 18:03:14 -0400
Subject: [PATCH 1/4] Update to remove ruff errors

---
 pyproject.toml                            |  4 ++
 src/config.ini                            | 10 ++++-
 src/grag/components/embedding.py          |  8 ++++
 src/grag/components/llm.py                | 30 +++++++-------
 src/grag/components/multivec_retriever.py | 38 ++++++++++-------
 src/grag/components/parse_pdf.py          | 36 +++++++++-------
 src/grag/components/prompt.py             | 50 ++++++++++++++++++++++-
 src/grag/components/text_splitter.py      | 22 ++++++++--
 src/grag/components/utils.py              | 17 ++++++--
 9 files changed, 163 insertions(+), 52 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 897ab02..f7c2d4f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -101,9 +101,13 @@ exclude_lines = [
 [tool.ruff]
 line-length = 88
 indent-width = 4
+extend-exclude = ["tests", "others"]
 
 [tool.ruff.lint]
 select = ["E4", "E7", "E9", "F", "I", "D"]
+ignore = ["D104"]
+exclude = ["__about__.py"]
+
 
 [tool.ruff.format]
 quote-style = "double"
diff --git a/src/config.ini b/src/config.ini
index 452ac04..54990bf 100644
--- a/src/config.ini
+++ b/src/config.ini
@@ -25,6 +25,14 @@ embedding_model : hkunlp/instructor-xl
 store_path : ${data:data_path}/vectordb
 allow_reset : True
 
+[deeplake]
+collection_name : arxiv
+# embedding_type : sentence-transformers
+# embedding_model : "all-mpnet-base-v2"
+embedding_type : instructor-embedding
+embedding_model : hkunlp/instructor-xl
+store_path : ${data:data_path}/vectordb
+
 [text_splitter]
 chunk_size : 5000
 chunk_overlap : 400
@@ -51,4 +59,4 @@ table_as_html : True
 data_path : ${root:root_path}/data
 
 [root]
-root_path : /home/ubuntu/volume_2k/Capstone_5
\ No newline at end of file
+root_path : /home/ubuntu/CapStone/Capstone_5
diff --git a/src/grag/components/embedding.py b/src/grag/components/embedding.py
index 7a9d249..73202e0 100644
--- a/src/grag/components/embedding.py
+++ b/src/grag/components/embedding.py
@@ -1,3 +1,9 @@
+"""Class for embedding.
+
+This module provies:
+- Embedding
+"""
+
 from langchain_community.embeddings import HuggingFaceInstructEmbeddings
 from langchain_community.embeddings.sentence_transformer import (
     SentenceTransformerEmbeddings,
@@ -6,6 +12,7 @@
 class Embedding:
     """A class for vector embeddings. 
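+
+    Example (illustrative sketch; assumes the instance exposes the underlying
+    langchain embedding object as `embedding_function` and that the named model
+    can be downloaded from HuggingFace):
+        embedding = Embedding(embedding_type='sentence-transformers',
+                              embedding_model='all-mpnet-base-v2')
+        vector = embedding.embedding_function.embed_query('Hello world')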
+    
     Supports:
         huggingface sentence transformers -> model_type = 'sentence-transformers'
         huggingface instructor embeddings -> model_type = 'instructor-embedding'
 
@@ -17,6 +24,7 @@ class Embedding:
     """
 
     def __init__(self, embedding_type: str, embedding_model: str):
+        """Initialize the embedding with embedding_type and embedding_model."""
        self.embedding_type = embedding_type
         self.embedding_model = embedding_model
         match self.embedding_type:
diff --git a/src/grag/components/llm.py b/src/grag/components/llm.py
index 20db968..23d2a26 100644
--- a/src/grag/components/llm.py
+++ b/src/grag/components/llm.py
@@ -1,3 +1,4 @@
+"""Class for LLM."""
 import os
 from pathlib import Path
 
@@ -36,20 +37,21 @@ class LLM:
     """
 
     def __init__(
-        self,
-        model_name=llm_conf["model_name"],
-        device_map=llm_conf["device_map"],
-        task=llm_conf["task"],
-        max_new_tokens=llm_conf["max_new_tokens"],
-        temperature=llm_conf["temperature"],
-        n_batch=llm_conf["n_batch_gpu_cpp"],
-        n_ctx=llm_conf["n_ctx_cpp"],
-        n_gpu_layers=llm_conf["n_gpu_layers_cpp"],
-        std_out=llm_conf["std_out"],
-        base_dir=llm_conf["base_dir"],
-        quantization=llm_conf["quantization"],
-        pipeline=llm_conf["pipeline"],
+            self,
+            model_name=llm_conf["model_name"],
+            device_map=llm_conf["device_map"],
+            task=llm_conf["task"],
+            max_new_tokens=llm_conf["max_new_tokens"],
+            temperature=llm_conf["temperature"],
+            n_batch=llm_conf["n_batch_gpu_cpp"],
+            n_ctx=llm_conf["n_ctx_cpp"],
+            n_gpu_layers=llm_conf["n_gpu_layers_cpp"],
+            std_out=llm_conf["std_out"],
+            base_dir=llm_conf["base_dir"],
+            quantization=llm_conf["quantization"],
+            pipeline=llm_conf["pipeline"],
     ):
+        """Initialize the LLM class using the given parameters."""
         self.base_dir = Path(base_dir)
         self._model_name = model_name
         self.quantization = quantization
@@ -160,7 +162,7 @@ def llama_cpp(self):
         return llm
 
     def load_model(
-        self, model_name=None, pipeline=None, quantization=None, is_local=None
+            self, model_name=None, pipeline=None, quantization=None, is_local=None
     ):
         """Loads the model based on the specified pipeline and model name.
 
diff --git a/src/grag/components/multivec_retriever.py b/src/grag/components/multivec_retriever.py
index 18ed752..253cb31 100644
--- a/src/grag/components/multivec_retriever.py
+++ b/src/grag/components/multivec_retriever.py
@@ -1,3 +1,8 @@
+"""Class for retriever.
+
+This module provides:
+- Retriever
+"""
 import asyncio
 import uuid
 from typing import List
@@ -13,9 +18,11 @@
 
 
 class Retriever:
-    """A class for multi vector retriever, it connects to a vector database and a local file store.
-    It is used to return most similar chunks from a vector store but has the additional funcationality
-    to return a linked document, chunk, etc.
+    """A class for multi vector retriever.
+
+    It connects to a vector database and a local file store.
+    It is used to return most similar chunks from a vector store but has the additional functionality to return a 
+    linked document, chunk, etc.
 
     Attributes:
         store_path: Path to the local file store
@@ -30,13 +37,15 @@ class Retriever:
     """
 
     def __init__(
-        self,
-        store_path: str = multivec_retriever_conf["store_path"],
-        id_key: str = multivec_retriever_conf["id_key"],
-        namespace: str = multivec_retriever_conf["namespace"],
-        top_k=1,
+            self,
+            store_path: str = multivec_retriever_conf["store_path"],
+            id_key: str = multivec_retriever_conf["id_key"],
+            namespace: str = multivec_retriever_conf["namespace"],
+            top_k=1,
     ):
-        """Args:
+        """Initialize the Retriever. 
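+
+        Example (illustrative; assumes a vector store already populated at the
+        configured store_path):
+            retriever = Retriever(top_k=3)
+            chunks = retriever.get_chunk("some query")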
+
+        Args:
             store_path: Path to the local file store, defaults to argument from config file
             id_key: A key prefix for identifying documents, defaults to argument from config file
             namespace: A namespace for producing unique id, defaults to argument from config file
             top_k: Number of top chunks to return from similarity search, defaults to 1
@@ -58,6 +67,7 @@
 
     def id_gen(self, doc: Document) -> str:
         """Takes a document and returns a unique id (uuid5) using the namespace and document source.
+        
         This ensures that a single document always gets the same unique id.
 
         Args:
@@ -81,7 +91,9 @@ def gen_doc_ids(self, docs: List[Document]) -> List[str]:
         return [self.id_gen(doc) for doc in docs]
 
     def split_docs(self, docs: List[Document]) -> List[Document]:
-        """Takes a list of documents and splits them into smaller chunks using TextSplitter from compoenents.text_splitter
+        """Takes a list of documents and splits them into smaller chunks.
+
+        Using TextSplitter from components.text_splitter
         Also adds the unique parent document id into metadata
 
         Args:
@@ -101,8 +113,7 @@ def split_docs(self, docs: List[Document]) -> List[Document]:
         return chunks
 
     def add_docs(self, docs: List[Document]):
-        """Takes a list of documents, splits them using the split_docs method and then adds them into the vector database
-        and adds the parent document into the file store.
+        """Adds given documents into the vector database and the parent documents into the file store.
 
         Args:
             docs: List of langchain_core.documents.Document
@@ -117,8 +128,7 @@ def add_docs(self, docs: List[Document]):
         self.retriever.docstore.mset(list(zip(doc_ids, docs)))
 
     async def aadd_docs(self, docs: List[Document]):
-        """Takes a list of documents, splits them using the split_docs method and then adds them into the vector database
-        and adds the parent document into the file store.
+        """Adds given documents into the vector database and the parent documents into the file store.
 
         Args:
             docs: List of langchain_core.documents.Document
diff --git a/src/grag/components/parse_pdf.py b/src/grag/components/parse_pdf.py
index d918c93..9ab2205 100644
--- a/src/grag/components/parse_pdf.py
+++ b/src/grag/components/parse_pdf.py
@@ -1,3 +1,8 @@
+"""Classes for parsing files. 
+
+This module provides:
+- ParsePDF
+"""
 from langchain_core.documents import Document
 from unstructured.partition.pdf import partition_pdf
 
@@ -22,17 +27,17 @@ class ParsePDF:
     """
 
     def __init__(
-        self,
-        single_text_out=parser_conf["single_text_out"],
-        strategy=parser_conf["strategy"],
-        infer_table_structure=parser_conf["infer_table_structure"],
-        extract_images=parser_conf["extract_images"],
-        image_output_dir=parser_conf["image_output_dir"],
-        add_captions_to_text=parser_conf["add_captions_to_text"],
-        add_captions_to_blocks=parser_conf["add_captions_to_blocks"],
-        table_as_html=parser_conf["table_as_html"],
+            self,
+            single_text_out=parser_conf["single_text_out"],
+            strategy=parser_conf["strategy"],
+            infer_table_structure=parser_conf["infer_table_structure"],
+            extract_images=parser_conf["extract_images"],
+            image_output_dir=parser_conf["image_output_dir"],
+            add_captions_to_text=parser_conf["add_captions_to_text"],
+            add_captions_to_blocks=parser_conf["add_captions_to_blocks"],
+            table_as_html=parser_conf["table_as_html"],
     ):
-        # Instantialize instance variables with parameters
+        """Initialize instance variables with parameters."""
         self.strategy = strategy
         if extract_images:  # by default always extract Table
             self.extract_image_block_types = [
@@ -72,7 +77,8 @@ def partition(self, path: str):
 
     def classify(self, partitions):
         """Classifies the partitioned elements into Text, Tables, and Images list in a dictionary.
-        Add captions for each element (if available).
+
+        Also adds captions for each element (if available).
 
         Parameters:
             partitions (list): The list of partitioned elements from the PDF document.
@@ -88,7 +94,7 @@ def classify(self, partitions):
             if element.category == "Table":
                 if self.add_captions_to_blocks and i + 1 < len(partitions):
                     if (
-                        partitions[i + 1].category == "FigureCaption"
+                            partitions[i + 1].category == "FigureCaption"
                     ):  # check for caption
                         caption_element = partitions[i + 1]
                     else:
@@ -99,7 +105,7 @@ def classify(self, partitions):
             elif element.category == "Image":
                 if self.add_captions_to_blocks and i + 1 < len(partitions):
                     if (
-                        partitions[i + 1].category == "FigureCaption"
+                            partitions[i + 1].category == "FigureCaption"
                     ):  # check for caption
                         caption_element = partitions[i + 1]
                     else:
@@ -117,6 +123,8 @@ def classify(self, partitions):
         return classified_elements
 
     def text_concat(self, elements) -> str:
+        """Concatenates all elements into a single string in a context-aware manner."""
+        full_text = ""
         for current_element, next_element in zip(elements, elements[1:]):
             curr_type = current_element.category
             next_type = next_element.category
@@ -185,7 +193,7 @@ def process_tables(self, elements):
 
             if caption_element:
                 if (
-                    self.add_caption_first
+                        self.add_caption_first
                 ):  # if there is a caption, add that before the element
                     content = "\n\n".join([str(caption_element), table_data])
                 else:
diff --git a/src/grag/components/prompt.py b/src/grag/components/prompt.py
index ecefa71..86bf8df 100644
--- a/src/grag/components/prompt.py
+++ b/src/grag/components/prompt.py
@@ -1,3 +1,9 @@
+"""Classes for prompts.
+
+This module provides:
+- Prompt - for generic prompts
+- FewShotPrompt - for few-shot prompts
+"""
 import json
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -13,6 +19,19 @@
 
 
 class Prompt(BaseModel):
+    """A class for generic prompts. 
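+
+    Example (illustrative sketch; the template text is only a placeholder):
+        prompt = Prompt(input_keys=["context", "question"],
+                        template="Given {context}, answer: {question}")
+        text = prompt.format(context="...", question="...")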
+
+    Attributes:
+        name (str): The prompt name (Optional, defaults to "custom_prompt")
+        llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None")
+        task (str): The task (Optional, defaults to QA)
+        source (str): The source of the prompt (Optional, defaults to "NoSource")
+        doc_chain (str): The doc chain for the prompt ("stuff", "refine") (Optional, defaults to "stuff")
+        language (str): The language of the prompt (Optional, defaults to "en")
+        filepath (str): The filepath of the prompt (Optional)
+        input_keys (List[str]): The input keys for the prompt
+        template (str): The template for the prompt
+    """
     name: str = Field(default="custom_prompt")
     llm_type: str = Field(default="None")
     task: str = Field(default="QA")
@@ -27,6 +46,7 @@ class Prompt(BaseModel):
     @field_validator("input_keys")
     @classmethod
     def validate_input_keys(cls, v) -> List[str]:
+        """Validate the input_keys field."""
         if v is None or v == []:
             raise ValueError("input_keys cannot be empty")
         return v
@@ -34,6 +54,7 @@ def validate_input_keys(cls, v) -> List[str]:
     @field_validator("doc_chain")
     @classmethod
     def validate_doc_chain(cls, v: str) -> str:
+        """Validate the doc_chain field."""
         if v not in SUPPORTED_DOC_CHAINS:
             raise ValueError(
                 f"The provided doc_chain, {v} is not supported, supported doc_chains are {SUPPORTED_DOC_CHAINS}"
@@ -43,6 +64,7 @@ def validate_doc_chain(cls, v: str) -> str:
     @field_validator("task")
     @classmethod
     def validate_task(cls, v: str) -> str:
+        """Validate the task field."""
         if v not in SUPPORTED_TASKS:
             raise ValueError(
                 f"The provided task, {v} is not supported, supported tasks are {SUPPORTED_TASKS}"
@@ -53,14 +75,16 @@ def validate_task(cls, v: str) -> str:
     # def load_template(self):
     #     self.prompt = ChatPromptTemplate.from_template(self.template)
     def __init__(self, **kwargs):
+        """Initialize the prompt."""
         super().__init__(**kwargs)
         self.prompt = PromptTemplate(
             input_variables=self.input_keys, template=self.template
         )
 
     def save(
-        self, filepath: Union[Path, str, None], overwrite=False
+            self, filepath: Union[Path, str, None], overwrite=False
     ) -> Union[None, ValueError]:
+        """Saves the prompt class into a json file."""
         dump = self.model_dump_json(indent=2, exclude_defaults=True, exclude_none=True)
         if filepath is None:
             filepath = f"{self.name}.json"
@@ -74,17 +98,36 @@ def save(
 
     @classmethod
     def load(cls, filepath: Union[Path, str]):
+        """Loads a json file and returns a Prompt class."""
         with open(f"{filepath}", "r") as f:
             prompt_json = json.load(f)
             _prompt = cls(**prompt_json)
             _prompt.filepath = str(filepath)
             return _prompt
 
-    def format(self, **kwargs):
+    def format(self, **kwargs) -> str:
+        """Formats the prompt with provided keys and returns a string."""
         return self.prompt.format(**kwargs)
 
 
 class FewShotPrompt(Prompt):
+    """A class for few-shot prompts. 
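+
+    Example (illustrative sketch; all field values are placeholders):
+        few_shot = FewShotPrompt(
+            input_keys=["question"],
+            output_keys=["answer"],
+            prefix="Answer questions like the examples below:",
+            suffix="Question: {question}",
+            example_template="Question: {question}, Answer: {answer}",
+            examples=[{"question": "What is 2+2?", "answer": "4"}],
+        )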
+
+    Attributes:
+        name (str): The prompt name (Optional, defaults to "custom_prompt") (Parent Class)
+        llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None") (Parent Class)
+        task (str): The task (Optional, defaults to QA) (Parent Class)
+        source (str): The source of the prompt (Optional, defaults to "NoSource") (Parent Class)
+        doc_chain (str): The doc chain for the prompt ("stuff", "refine") (Optional, defaults to "stuff") (Parent Class)
+        language (str): The language of the prompt (Optional, defaults to "en") (Parent Class)
+        filepath (str): The filepath of the prompt (Optional) (Parent Class)
+        input_keys (List[str]): The input keys for the prompt (Parent Class)
+        output_keys (List[str]): The output keys for the prompt
+        prefix (str): The template prefix for the prompt
+        suffix (str): The template suffix for the prompt
+        example_template (str): The template for formatting the examples
+        examples (List[Dict[str, Any]]): The list of examples, each example is a dictionary with respective keys
+    """
     output_keys: List[str]
     examples: List[Dict[str, Any]]
     prefix: str
@@ -95,6 +138,7 @@ class FewShotPrompt(Prompt):
     )
 
     def __init__(self, **kwargs):
+        """Initialize the prompt."""
         super().__init__(**kwargs)
         eg_formatter = PromptTemplate(
             input_vars=self.input_keys + self.output_keys,
@@ -111,6 +155,7 @@ def __init__(self, **kwargs):
     @field_validator("output_keys")
     @classmethod
     def validate_output_keys(cls, v) -> List[str]:
+        """Validate the output_keys field."""
         if v is None or v == []:
             raise ValueError("output_keys cannot be empty")
         return v
@@ -118,6 +163,7 @@ def validate_output_keys(cls, v) -> List[str]:
     @field_validator("examples")
     @classmethod
     def validate_examples(cls, v) -> List[Dict[str, Any]]:
+        """Validate the examples field."""
         if v is None or v == []:
             raise ValueError("examples cannot be empty")
         for eg in v:
diff --git a/src/grag/components/text_splitter.py b/src/grag/components/text_splitter.py
index cff3c7c..a0ecfd9 100644
--- a/src/grag/components/text_splitter.py
+++ b/src/grag/components/text_splitter.py
@@ -1,3 +1,8 @@
+"""Class for splitting/chunking text.
+
+This module provides:
+- TextSplitter
+"""
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from .utils import get_config
@@ -7,10 +12,20 @@
 
 # %%
 class TextSplitter:
-    def __init__(self):
+    """Class for recursively chunking text, it prioritizes '\n\n' then '\n' and so on.
+
+    Attributes:
+        chunk_size: maximum size of chunk
+        chunk_overlap: chunk overlap size
+    """
+
+    def __init__(self,
+                 chunk_size: int = text_splitter_conf["chunk_size"],
+                 chunk_overlap: int = text_splitter_conf["chunk_overlap"]):
+        """Initialize TextSplitter."""
         self.text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=int(text_splitter_conf["chunk_size"]),
-            chunk_overlap=int(text_splitter_conf["chunk_overlap"]),
+            chunk_size=int(chunk_size),
+            chunk_overlap=int(chunk_overlap),
             length_function=len,
             is_separator_regex=False,
         )
diff --git a/src/grag/components/utils.py b/src/grag/components/utils.py
index cb64258..0233d32 100644
--- a/src/grag/components/utils.py
+++ b/src/grag/components/utils.py
@@ -1,3 +1,11 @@
+"""Utils functions.
+
+This module provides:
+- stuff_docs: concats langchain documents into string
+- load_prompt: loads json prompt to langchain prompt
+- find_config_path: finds the path of the 'config.ini' file by traversing up the directory tree from the current path. 
+- get_config: retrieves and parses the configuration settings from the 'config.ini' file.
+"""
 import json
 import os
 import textwrap
@@ -10,7 +18,9 @@
 
 
 def stuff_docs(docs: List[Document]) -> str:
-    """Args:
+    r"""Concatenates langchain documents into a string using '\n\n' separator.
+
+    Args:
         docs: List of langchain_core.documents.Document
 
     Returns:
@@ -20,8 +30,7 @@ def stuff_docs(docs: List[Document]) -> str:
 
 
 def reformat_text_with_line_breaks(input_text, max_width=110):
-    """Reformat the given text to ensure each line does not exceed a specific width,
-    preserving existing line breaks.
+    """Reformat the given text to ensure each line does not exceed a specific width, preserving existing line breaks.
 
     Args:
         input_text (str): The text to be reformatted.
@@ -62,7 +71,7 @@ def display_llm_output_and_sources(response_from_llm):
 
 
 def load_prompt(json_file: str | os.PathLike, return_input_vars=False):
-    """Loads a prompt template from json file and returns a langchain ChatPromptTemplate
+    """Loads a prompt template from json file and returns a langchain ChatPromptTemplate.
 
     Args:
         json_file: path to the prompt template json file.

From acd4ba282f6c0dc821b80a4bc23a1a70849349e6 Mon Sep 17 00:00:00 2001
From: Arjun Bingly
Date: Fri, 22 Mar 2024 18:20:14 -0400
Subject: [PATCH 2/4] Ruff bugs

---
 src/grag/rag/basic_rag.py | 44 ++++++++++++++++++++++++++++++++-------
 1 file changed, 36 insertions(+), 8 deletions(-)

diff --git a/src/grag/rag/basic_rag.py b/src/grag/rag/basic_rag.py
index a99ecdd..05a7520 100644
--- a/src/grag/rag/basic_rag.py
+++ b/src/grag/rag/basic_rag.py
@@ -1,3 +1,8 @@
+"""Class for Basic RAG.
+
+This module provides:
+- BasicRAG
+"""
 import json
 from typing import List, Union
 
@@ -13,15 +18,27 @@
 
 
 class BasicRAG:
+    """Class for Basic RAG.
+
+    Attributes:
+        model_name (str): Name of the llm model
+        doc_chain (str): Name of the document chain, ("stuff", "refine"), defaults to "stuff"
+        task (str): Name of task, defaults to "QA"
+        llm_kwargs (dict): Keyword arguments for LLM class
+        retriever_kwargs (dict): Keyword arguments for Retriever class
+        custom_prompt (Prompt): Prompt, defaults to None
+    """
+
     def __init__(
-        self,
-        model_name=None,
-        doc_chain="stuff",
-        task="QA",
-        llm_kwargs=None,
-        retriever_kwargs=None,
-        custom_prompt: Union[Prompt, FewShotPrompt, None] = None,
+            self,
+            model_name=None,
+            doc_chain="stuff",
+            task="QA",
+            llm_kwargs=None,
+            retriever_kwargs=None,
+            custom_prompt: Union[Prompt, FewShotPrompt, List[Union[Prompt, FewShotPrompt]], None] = None,
     ):
+        """Initialize BasicRAG."""
         if retriever_kwargs is None:
             self.retriever = Retriever()
         else:
@@ -54,6 +71,7 @@ def __init__(
 
     @property
     def model_name(self):
+        """Return the name of the model."""
         return self._model_name
 
     @model_name.setter
@@ -67,6 +85,7 @@ def model_name(self, value):
 
     @property
     def doc_chain(self):
+        """Returns the doc_chain."""
         return self._doc_chain
 
     @doc_chain.setter
@@ -86,6 +105,7 @@ def doc_chain(self, value):
 
     @property
     def task(self):
+        """Returns the task."""
         return self._task
 
     @task.setter
@@ -99,6 +119,7 @@ def task(self, value):
         self.prompt_matcher()
 
     def prompt_matcher(self):
+        """Matches relevant prompt using model, task and doc_chain."""
         matcher_path = self.prompt_path.joinpath("matcher.json")
         with open(f"{matcher_path}", "r") as f:
             matcher_dict = json.load(f)
@@ -122,7 +143,9 @@ def prompt_matcher(self):
 
     @staticmethod
     def stuff_docs(docs: List[Document]) -> str:
-        """Args:
+        r"""Concatenates docs into a string separated by '\n\n'. 
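+
+        Example (illustrative):
+            text = BasicRAG.stuff_docs([Document(page_content="a"), Document(page_content="b")])
+            # text == "a\n\nb"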
+
+        Args:
             docs: List of langchain_core.documents.Document
 
         Returns:
@@ -132,6 +155,8 @@ def stuff_docs(docs: List[Document]) -> str:
 
     @staticmethod
     def output_parser(call_func):
+        """Decorator to format llm output."""
+
         def output_parser_wrapper(*args, **kwargs):
             response, sources = call_func(*args, **kwargs)
             if conf["llm"]["std_out"] == "False":
@@ -146,6 +171,7 @@ def output_parser_wrapper(*args, **kwargs):
 
     @output_parser
     def stuff_call(self, query: str):
+        """Call function for stuff chain."""
         retrieved_docs = self.retriever.get_chunk(query)
         context = self.stuff_docs(retrieved_docs)
         prompt = self.main_prompt.format(context=context, question=query)
@@ -155,6 +181,7 @@ def stuff_call(self, query: str):
 
     @output_parser
     def refine_call(self, query: str):
+        """Call function for refine chain."""
         retrieved_docs = self.retriever.get_chunk(query)
         sources = [doc.metadata["source"] for doc in retrieved_docs]
         responses = []
@@ -176,6 +203,7 @@ def refine_call(self, query: str):
         return responses, sources
 
     def __call__(self, query: str):
+        """Call function for the class."""
         if self.doc_chain == "stuff":
             return self.stuff_call(query)
         elif self.doc_chain == "refine":

From 2df28d1c7c2cf0b640faddcd796f8888607fc182 Mon Sep 17 00:00:00 2001
From: arjbingly
Date: Fri, 22 Mar 2024 22:23:56 +0000
Subject: [PATCH 3/4] style fixes by ruff

---
 src/grag/components/embedding.py          |  2 +-
 src/grag/components/llm.py                | 29 ++++++++++++-----------
 src/grag/components/multivec_retriever.py | 21 ++++++++--------
 src/grag/components/parse_pdf.py          | 27 +++++++++++----------
 src/grag/components/prompt.py             |  7 ++++--
 src/grag/components/text_splitter.py      | 11 +++++----
 src/grag/components/utils.py              |  3 ++-
 src/grag/rag/basic_rag.py                 | 21 +++++++++------
 8 files changed, 67 insertions(+), 54 deletions(-)

diff --git a/src/grag/components/embedding.py b/src/grag/components/embedding.py
index 73202e0..eab107f 100644
--- a/src/grag/components/embedding.py
+++ b/src/grag/components/embedding.py
@@ -12,7 +12,7 @@
 
 class Embedding:
     """A class for vector embeddings. 
-    
+
     Supports:
         huggingface sentence transformers -> model_type = 'sentence-transformers'
         huggingface instructor embeddings -> model_type = 'instructor-embedding'
diff --git a/src/grag/components/llm.py b/src/grag/components/llm.py
index 23d2a26..6e7296c 100644
--- a/src/grag/components/llm.py
+++ b/src/grag/components/llm.py
@@ -1,4 +1,5 @@
 """Class for LLM."""
+
 import os
 from pathlib import Path
 
@@ -37,19 +38,19 @@ class LLM:
     """
 
     def __init__(
-            self,
-            model_name=llm_conf["model_name"],
-            device_map=llm_conf["device_map"],
-            task=llm_conf["task"],
-            max_new_tokens=llm_conf["max_new_tokens"],
-            temperature=llm_conf["temperature"],
-            n_batch=llm_conf["n_batch_gpu_cpp"],
-            n_ctx=llm_conf["n_ctx_cpp"],
-            n_gpu_layers=llm_conf["n_gpu_layers_cpp"],
-            std_out=llm_conf["std_out"],
-            base_dir=llm_conf["base_dir"],
-            quantization=llm_conf["quantization"],
-            pipeline=llm_conf["pipeline"],
+        self,
+        model_name=llm_conf["model_name"],
+        device_map=llm_conf["device_map"],
+        task=llm_conf["task"],
+        max_new_tokens=llm_conf["max_new_tokens"],
+        temperature=llm_conf["temperature"],
+        n_batch=llm_conf["n_batch_gpu_cpp"],
+        n_ctx=llm_conf["n_ctx_cpp"],
+        n_gpu_layers=llm_conf["n_gpu_layers_cpp"],
+        std_out=llm_conf["std_out"],
+        base_dir=llm_conf["base_dir"],
+        quantization=llm_conf["quantization"],
+        pipeline=llm_conf["pipeline"],
     ):
         """Initialize the LLM class using the given parameters."""
         self.base_dir = Path(base_dir)
@@ -162,7 +163,7 @@ def llama_cpp(self):
         return llm
 
     def load_model(
-            self, model_name=None, pipeline=None, quantization=None, is_local=None
+        self, model_name=None, pipeline=None, quantization=None, is_local=None
     ):
         """Loads the model based on the specified pipeline and model name.
 
diff --git a/src/grag/components/multivec_retriever.py b/src/grag/components/multivec_retriever.py
index 253cb31..97684dd 100644
--- a/src/grag/components/multivec_retriever.py
+++ b/src/grag/components/multivec_retriever.py
@@ -3,6 +3,7 @@
 This module provides:
 - Retriever
 """
+
 import asyncio
 import uuid
 from typing import List
@@ -19,9 +20,9 @@
 
 class Retriever:
     """A class for multi vector retriever.
-    
+
     It connects to a vector database and a local file store.
-    It is used to return most similar chunks from a vector store but has the additional functionality to return a 
+    It is used to return most similar chunks from a vector store but has the additional functionality to return a
     linked document, chunk, etc.
 
     Attributes:
         store_path: Path to the local file store
@@ -37,14 +38,14 @@ class Retriever:
     def __init__(
-            self,
-            store_path: str = multivec_retriever_conf["store_path"],
-            id_key: str = multivec_retriever_conf["id_key"],
-            namespace: str = multivec_retriever_conf["namespace"],
-            top_k=1,
+        self,
+        store_path: str = multivec_retriever_conf["store_path"],
+        id_key: str = multivec_retriever_conf["id_key"],
+        namespace: str = multivec_retriever_conf["namespace"],
+        top_k=1,
     ):
         """Initialize the Retriever.
-        
+
         Args:
             store_path: Path to the local file store, defaults to argument from config file
             id_key: A key prefix for identifying documents, defaults to argument from config file
@@ -67,7 +68,7 @@ def __init__(
 
     def id_gen(self, doc: Document) -> str:
         """Takes a document and returns a unique id (uuid5) using the namespace and document source.
-        
+
         This ensures that a single document always gets the same unique id.
 
         Args:
@@ -92,7 +93,7 @@ def gen_doc_ids(self, docs: List[Document]) -> List[str]:
 
     def split_docs(self, docs: List[Document]) -> List[Document]:
         """Takes a list of documents and splits them into smaller chunks. 
-        
+
         Using TextSplitter from components.text_splitter
         Also adds the unique parent document id into metadata
 
diff --git a/src/grag/components/parse_pdf.py b/src/grag/components/parse_pdf.py
index 9ab2205..dc30f8a 100644
--- a/src/grag/components/parse_pdf.py
+++ b/src/grag/components/parse_pdf.py
@@ -3,6 +3,7 @@
 This module provides:
 - ParsePDF
 """
+
 from langchain_core.documents import Document
 from unstructured.partition.pdf import partition_pdf
 
@@ -27,15 +28,15 @@ class ParsePDF:
     """
 
     def __init__(
-            self,
-            single_text_out=parser_conf["single_text_out"],
-            strategy=parser_conf["strategy"],
-            infer_table_structure=parser_conf["infer_table_structure"],
-            extract_images=parser_conf["extract_images"],
-            image_output_dir=parser_conf["image_output_dir"],
-            add_captions_to_text=parser_conf["add_captions_to_text"],
-            add_captions_to_blocks=parser_conf["add_captions_to_blocks"],
-            table_as_html=parser_conf["table_as_html"],
+        self,
+        single_text_out=parser_conf["single_text_out"],
+        strategy=parser_conf["strategy"],
+        infer_table_structure=parser_conf["infer_table_structure"],
+        extract_images=parser_conf["extract_images"],
+        image_output_dir=parser_conf["image_output_dir"],
+        add_captions_to_text=parser_conf["add_captions_to_text"],
+        add_captions_to_blocks=parser_conf["add_captions_to_blocks"],
+        table_as_html=parser_conf["table_as_html"],
     ):
         """Initialize instance variables with parameters."""
         self.strategy = strategy
@@ -78,7 +78,7 @@ def partition(self, path: str):
 
     def classify(self, partitions):
         """Classifies the partitioned elements into Text, Tables, and Images list in a dictionary.
-        
+
         Also adds captions for each element (if available).
 
         Parameters:
             partitions (list): The list of partitioned elements from the PDF document.
@@ -95,7 +95,7 @@ def classify(self, partitions):
             if element.category == "Table":
                 if self.add_captions_to_blocks and i + 1 < len(partitions):
                     if (
-                            partitions[i + 1].category == "FigureCaption"
+                        partitions[i + 1].category == "FigureCaption"
                     ):  # check for caption
                         caption_element = partitions[i + 1]
                     else:
@@ -106,7 +106,7 @@ def classify(self, partitions):
             elif element.category == "Image":
                 if self.add_captions_to_blocks and i + 1 < len(partitions):
                     if (
-                            partitions[i + 1].category == "FigureCaption"
+                        partitions[i + 1].category == "FigureCaption"
                     ):  # check for caption
                         caption_element = partitions[i + 1]
                     else:
@@ -194,7 +194,7 @@ def process_tables(self, elements):
 
             if caption_element:
                 if (
-                        self.add_caption_first
+                    self.add_caption_first
                 ):  # if there is a caption, add that before the element
                     content = "\n\n".join([str(caption_element), table_data])
                 else:
diff --git a/src/grag/components/prompt.py b/src/grag/components/prompt.py
index 86bf8df..4364c06 100644
--- a/src/grag/components/prompt.py
+++ b/src/grag/components/prompt.py
@@ -4,6 +4,7 @@
 - Prompt - for generic prompts
 - FewShotPrompt - for few-shot prompts
 """
+
 import json
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -20,7 +21,7 @@
 
 class Prompt(BaseModel):
     """A class for generic prompts. 
-    
+
     Attributes:
         name (str): The prompt name (Optional, defaults to "custom_prompt")
         llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None")
@@ -32,6 +33,7 @@ class Prompt(BaseModel):
         input_keys (List[str]): The input keys for the prompt
         template (str): The template for the prompt
     """
+
     name: str = Field(default="custom_prompt")
     llm_type: str = Field(default="None")
     task: str = Field(default="QA")
@@ -82,7 +84,7 @@ def __init__(self, **kwargs):
     )
 
     def save(
-            self, filepath: Union[Path, str, None], overwrite=False
+        self, filepath: Union[Path, str, None], overwrite=False
     ) -> Union[None, ValueError]:
         """Saves the prompt class into a json file."""
         dump = self.model_dump_json(indent=2, exclude_defaults=True, exclude_none=True)
         if filepath is None:
             filepath = f"{self.name}.json"
@@ -128,6 +130,7 @@ class FewShotPrompt(Prompt):
         example_template (str): The template for formatting the examples
         examples (List[Dict[str, Any]]): The list of examples, each example is a dictionary with respective keys
     """
+
     output_keys: List[str]
     examples: List[Dict[str, Any]]
     prefix: str
diff --git a/src/grag/components/text_splitter.py b/src/grag/components/text_splitter.py
index a0ecfd9..d04c9a5 100644
--- a/src/grag/components/text_splitter.py
+++ b/src/grag/components/text_splitter.py
@@ -3,6 +3,7 @@
 This module provides:
 - TextSplitter
 """
+
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 from .utils import get_config
@@ -13,15 +14,17 @@
 
 # %%
 class TextSplitter:
     """Class for recursively chunking text, it prioritizes '\n\n' then '\n' and so on.
-    
+
     Attributes:
         chunk_size: maximum size of chunk
         chunk_overlap: chunk overlap size
     """
 
-    def __init__(self,
-                 chunk_size: int = text_splitter_conf["chunk_size"],
-                 chunk_overlap: int = text_splitter_conf["chunk_overlap"]):
+    def __init__(
+        self,
+        chunk_size: int = text_splitter_conf["chunk_size"],
+        chunk_overlap: int = text_splitter_conf["chunk_overlap"],
+    ):
         """Initialize TextSplitter."""
         self.text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=int(chunk_size),
diff --git a/src/grag/components/utils.py b/src/grag/components/utils.py
index 0233d32..2e34dc9 100644
--- a/src/grag/components/utils.py
+++ b/src/grag/components/utils.py
@@ -6,6 +6,7 @@
 - find_config_path: finds the path of the 'config.ini' file by traversing up the directory tree from the current path.
 - get_config: retrieves and parses the configuration settings from the 'config.ini' file.
 """
+
 import json
 import os
 import textwrap
@@ -19,7 +20,7 @@
 
 def stuff_docs(docs: List[Document]) -> str:
     r"""Concatenates langchain documents into a string using '\n\n' separator.
-    
+
     Args:
         docs: List of langchain_core.documents.Document
 
diff --git a/src/grag/rag/basic_rag.py b/src/grag/rag/basic_rag.py
index 05a7520..be055ba 100644
--- a/src/grag/rag/basic_rag.py
+++ b/src/grag/rag/basic_rag.py
@@ -3,6 +3,7 @@
 This module provides:
 - BasicRAG
 """
+
 import json
 from typing import List, Union
 
@@ -19,7 +20,7 @@
 
 class BasicRAG:
     """Class for Basic RAG. 
-    
+
     Attributes:
         model_name (str): Name of the llm model
         doc_chain (str): Name of the document chain, ("stuff", "refine"), defaults to "stuff"
         task (str): Name of task, defaults to "QA"
         llm_kwargs (dict): Keyword arguments for LLM class
@@ -30,13 +31,15 @@ class BasicRAG:
     def __init__(
-            self,
-            model_name=None,
-            doc_chain="stuff",
-            task="QA",
-            llm_kwargs=None,
-            retriever_kwargs=None,
-            custom_prompt: Union[Prompt, FewShotPrompt, List[Union[Prompt, FewShotPrompt]], None] = None,
+        self,
+        model_name=None,
+        doc_chain="stuff",
+        task="QA",
+        llm_kwargs=None,
+        retriever_kwargs=None,
+        custom_prompt: Union[
+            Prompt, FewShotPrompt, List[Union[Prompt, FewShotPrompt]], None
+        ] = None,
     ):
         """Initialize BasicRAG."""
         if retriever_kwargs is None:
             self.retriever = Retriever()
@@ -144,7 +147,7 @@ def prompt_matcher(self):
 
     @staticmethod
     def stuff_docs(docs: List[Document]) -> str:
         r"""Concatenates docs into a string separated by '\n\n'.
-        
+
         Args:
             docs: List of langchain_core.documents.Document
 

From 00e2d6bcd915df648917212e2c6dd01703f20a0e Mon Sep 17 00:00:00 2001
From: Arjun Bingly
Date: Sat, 23 Mar 2024 15:31:37 -0400
Subject: [PATCH 4/4] Update embedding docstring

---
 src/grag/components/embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/grag/components/embedding.py b/src/grag/components/embedding.py
index eab107f..eeb0f82 100644
--- a/src/grag/components/embedding.py
+++ b/src/grag/components/embedding.py
@@ -1,6 +1,6 @@
 """Class for embedding.
 
-This module provies:
+This module provides:
 - Embedding
 """
 