Skip to content

Commit

Permalink
Merge branch 'main' into vectordb
Browse files Browse the repository at this point in the history
  • Loading branch information
arjbingly authored Mar 24, 2024
2 parents f94114e + 58be4d8 commit 7a7d5a7
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 21 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ exclude_lines = [
[tool.ruff]
line-length = 88
indent-width = 4
extend-exclude = ["tests", "others"]

[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "I", "D"]
ignore = ["D104"]
exclude = ["__about__.py"]


[tool.ruff.format]
quote-style = "double"
Expand Down
10 changes: 9 additions & 1 deletion src/config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ embedding_model : hkunlp/instructor-xl
store_path : ${data:data_path}/vectordb
allow_reset : True

[deeplake]
collection_name : arxiv
# embedding_type : sentence-transformers
# embedding_model : "all-mpnet-base-v2"
embedding_type : instructor-embedding
embedding_model : hkunlp/instructor-xl
store_path : ${data:data_path}/vectordb

[text_splitter]
chunk_size : 5000
chunk_overlap : 400
Expand All @@ -57,4 +65,4 @@ table_as_html : True
data_path : ${root:root_path}/data

[root]
root_path : /home/ubuntu/volume_2k/Capstone_5
root_path : /home/ubuntu/CapStone/Capstone_5
8 changes: 8 additions & 0 deletions src/grag/components/embedding.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""Class for embedding.
This module provides:
- Embedding
"""

from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.embeddings.sentence_transformer import (
SentenceTransformerEmbeddings,
Expand All @@ -6,6 +12,7 @@

class Embedding:
"""A class for vector embeddings.
Supports:
huggingface sentence transformers -> model_type = 'sentence-transformers'
huggingface instructor embeddings -> model_type = 'instructor-embedding'
Expand All @@ -17,6 +24,7 @@ class Embedding:
"""

def __init__(self, embedding_type: str, embedding_model: str):
"""Initialize the embedding with embedding_type and embedding_model."""
self.embedding_type = embedding_type
self.embedding_model = embedding_model
match self.embedding_type:
Expand Down
3 changes: 3 additions & 0 deletions src/grag/components/llm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Class for LLM."""

import os
from pathlib import Path

Expand Down Expand Up @@ -50,6 +52,7 @@ def __init__(
quantization=llm_conf["quantization"],
pipeline=llm_conf["pipeline"],
):
"""Initialize the LLM class using the given parameters."""
self.base_dir = Path(base_dir)
self._model_name = model_name
self.quantization = quantization
Expand Down
29 changes: 20 additions & 9 deletions src/grag/components/multivec_retriever.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""Class for retriever.
This module provides:
- Retriever
"""

import asyncio
import uuid
from typing import Any, Dict, List, Optional
Expand All @@ -14,9 +20,11 @@


class Retriever:
"""A class for multi vector retriever, it connects to a vector database and a local file store.
It is used to return most similar chunks from a vector store but has the additional funcationality
to return a linked document, chunk, etc.
"""A class for multi vector retriever.
It connects to a vector database and a local file store.
It is used to return most similar chunks from a vector store but has the additional functionality to return a
linked document, chunk, etc.
Attributes:
store_path: Path to the local file store
Expand All @@ -39,7 +47,9 @@ def __init__(
top_k=1,
client_kwargs: Optional[Dict[str, Any]] = None
):
"""Args:
"""Initialize the Retriever.
Args:
store_path: Path to the local file store, defaults to argument from config file
id_key: A key prefix for identifying documents, defaults to argument from config file
namespace: A namespace for producing unique id, defaults to argument from congig file
Expand Down Expand Up @@ -67,6 +77,7 @@ def __init__(

def id_gen(self, doc: Document) -> str:
"""Takes a document and returns a unique id (uuid5) using the namespace and document source.
This ensures that a single document always gets the same unique id.
Args:
Expand All @@ -90,7 +101,9 @@ def gen_doc_ids(self, docs: List[Document]) -> List[str]:
return [self.id_gen(doc) for doc in docs]

def split_docs(self, docs: List[Document]) -> List[Document]:
"""Takes a list of documents and splits them into smaller chunks using TextSplitter from compoenents.text_splitter
"""Takes a list of documents and splits them into smaller chunks.
Using TextSplitter from components.text_splitter
Also adds the unique parent document id into metadata
Args:
Expand All @@ -110,8 +123,7 @@ def split_docs(self, docs: List[Document]) -> List[Document]:
return chunks

def add_docs(self, docs: List[Document]):
"""Takes a list of documents, splits them using the split_docs method and then adds them into the vector database
and adds the parent document into the file store.
"""Adds given documents into the vector database also adds the parent document into the file store.
Args:
docs: List of langchain_core.documents.Document
Expand All @@ -126,8 +138,7 @@ def add_docs(self, docs: List[Document]):
self.retriever.docstore.mset(list(zip(doc_ids, docs)))

async def aadd_docs(self, docs: List[Document]):
"""Takes a list of documents, splits them using the split_docs method and then adds them into the vector database
and adds the parent document into the file store.
"""Adds given documents into the vector database also adds the parent document into the file store.
Args:
docs: List of langchain_core.documents.Document
Expand Down
13 changes: 11 additions & 2 deletions src/grag/components/parse_pdf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""Classes for parsing files.
This module provides:
- ParsePDF
"""

from langchain_core.documents import Document
from unstructured.partition.pdf import partition_pdf

Expand Down Expand Up @@ -32,7 +38,7 @@ def __init__(
add_captions_to_blocks=parser_conf["add_captions_to_blocks"],
table_as_html=parser_conf["table_as_html"],
):
# Instantialize instance variables with parameters
"""Initialize instance variables with parameters."""
self.strategy = strategy
if extract_images: # by default always extract Table
self.extract_image_block_types = [
Expand Down Expand Up @@ -72,7 +78,8 @@ def partition(self, path: str):

def classify(self, partitions):
"""Classifies the partitioned elements into Text, Tables, and Images list in a dictionary.
Add captions for each element (if available).
Also adds captions for each element (if available).
Parameters:
partitions (list): The list of partitioned elements from the PDF document.
Expand Down Expand Up @@ -117,6 +124,8 @@ def classify(self, partitions):
return classified_elements

def text_concat(self, elements) -> str:
"""Context aware concatenates all elements into a single string."""
full_text = ""
for current_element, next_element in zip(elements, elements[1:]):
curr_type = current_element.category
next_type = next_element.category
Expand Down
51 changes: 50 additions & 1 deletion src/grag/components/prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""Classes for prompts.
This module provides:
- Prompt - for generic prompts
- FewShotPrompt - for few-shot prompts
"""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
Expand All @@ -13,6 +20,20 @@


class Prompt(BaseModel):
"""A class for generic prompts.
Attributes:
name (str): The prompt name (Optional, defaults to "custom_prompt")
llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None")
task (str): The task (Optional, defaults to QA)
source (str): The source of the prompt (Optional, defaults to "NoSource")
doc_chain (str): The doc chain for the prompt ("stuff", "refine") (Optional, defaults to "stuff")
language (str): The language of the prompt (Optional, defaults to "en")
filepath (str): The filepath of the prompt (Optional)
input_keys (List[str]): The input keys for the prompt
template (str): The template for the prompt
"""

name: str = Field(default="custom_prompt")
llm_type: str = Field(default="None")
task: str = Field(default="QA")
Expand All @@ -27,13 +48,15 @@ class Prompt(BaseModel):
@field_validator("input_keys")
@classmethod
def validate_input_keys(cls, v) -> List[str]:
"""Validate the input_keys field."""
if v is None or v == []:
raise ValueError("input_keys cannot be empty")
return v

@field_validator("doc_chain")
@classmethod
def validate_doc_chain(cls, v: str) -> str:
"""Validate the doc_chain field."""
if v not in SUPPORTED_DOC_CHAINS:
raise ValueError(
f"The provided doc_chain, {v} is not supported, supported doc_chains are {SUPPORTED_DOC_CHAINS}"
Expand All @@ -43,6 +66,7 @@ def validate_doc_chain(cls, v: str) -> str:
@field_validator("task")
@classmethod
def validate_task(cls, v: str) -> str:
"""Validate the task field."""
if v not in SUPPORTED_TASKS:
raise ValueError(
f"The provided task, {v} is not supported, supported tasks are {SUPPORTED_TASKS}"
Expand All @@ -53,6 +77,7 @@ def validate_task(cls, v: str) -> str:
# def load_template(self):
# self.prompt = ChatPromptTemplate.from_template(self.template)
def __init__(self, **kwargs):
"""Initialize the prompt."""
super().__init__(**kwargs)
self.prompt = PromptTemplate(
input_variables=self.input_keys, template=self.template
Expand All @@ -61,6 +86,7 @@ def __init__(self, **kwargs):
def save(
self, filepath: Union[Path, str, None], overwrite=False
) -> Union[None, ValueError]:
"""Saves the prompt class into a json file."""
dump = self.model_dump_json(indent=2, exclude_defaults=True, exclude_none=True)
if filepath is None:
filepath = f"{self.name}.json"
Expand All @@ -74,17 +100,37 @@ def save(

@classmethod
def load(cls, filepath: Union[Path, str]):
"""Loads a json file and returns a Prompt class."""
with open(f"{filepath}", "r") as f:
prompt_json = json.load(f)
_prompt = cls(**prompt_json)
_prompt.filepath = str(filepath)
return _prompt

def format(self, **kwargs):
def format(self, **kwargs) -> str:
"""Formats the prompt with provided keys and returns a string."""
return self.prompt.format(**kwargs)


class FewShotPrompt(Prompt):
"""A class for generic prompts.
Attributes:
name (str): The prompt name (Optional, defaults to "custom_prompt") (Parent Class)
llm_type (str): The type of llm, llama2, etc (Optional, defaults to "None") (Parent Class)
task (str): The task (Optional, defaults to QA) (Parent Class)
source (str): The source of the prompt (Optional, defaults to "NoSource") (Parent Class)
doc_chain (str): The doc chain for the prompt ("stuff", "refine") (Optional, defaults to "stuff") (Parent Class)
language (str): The language of the prompt (Optional, defaults to "en") (Parent Class)
filepath (str): The filepath of the prompt (Optional) (Parent Class)
input_keys (List[str]): The input keys for the prompt (Parent Class)
input_keys (List[str]): The output keys for the prompt
prefix (str): The template prefix for the prompt
suffix (str): The template suffix for the prompt
example_template (str): The template for formatting the examples
examples (List[Dict[str, Any]]): The list of examples, each example is a dictionary with respective keys
"""

output_keys: List[str]
examples: List[Dict[str, Any]]
prefix: str
Expand All @@ -95,6 +141,7 @@ class FewShotPrompt(Prompt):
)

def __init__(self, **kwargs):
"""Initialize the prompt."""
super().__init__(**kwargs)
eg_formatter = PromptTemplate(
input_vars=self.input_keys + self.output_keys,
Expand All @@ -111,13 +158,15 @@ def __init__(self, **kwargs):
@field_validator("output_keys")
@classmethod
def validate_output_keys(cls, v) -> List[str]:
"""Validate the output_keys field."""
if v is None or v == []:
raise ValueError("output_keys cannot be empty")
return v

@field_validator("examples")
@classmethod
def validate_examples(cls, v) -> List[Dict[str, Any]]:
"""Validate the examples field."""
if v is None or v == []:
raise ValueError("examples cannot be empty")
for eg in v:
Expand Down
25 changes: 22 additions & 3 deletions src/grag/components/text_splitter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""Class for splitting/chunking text.
This module provides:
- TextSplitter
"""

from langchain.text_splitter import RecursiveCharacterTextSplitter

from .utils import get_config
Expand All @@ -7,10 +13,23 @@

# %%
class TextSplitter:
def __init__(self):
"""Class for recursively chunking text, it prioritizes '/n/n then '/n' and so on.
Attributes:
chunk_size: maximum size of chunk
chunk_overlap: chunk overlap size
"""

def __init__(
self,
chunk_size: int = text_splitter_conf["chunk_size"],
chunk_overlap: int = text_splitter_conf["chunk_overlap"],
):
"""Initialize TextSplitter."""
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=int(text_splitter_conf["chunk_size"]),
chunk_overlap=int(text_splitter_conf["chunk_overlap"]),
chunk_size=int(chunk_size),
chunk_overlap=int(chunk_overlap),
length_function=len,
is_separator_regex=False,
)
"""Initialize TextSplitter using chunk_size and chunk_overlap"""
Loading

0 comments on commit 7a7d5a7

Please sign in to comment.