Skip to content

Commit

Permalink
Ruff format patch
Browse files Browse the repository at this point in the history
  • Loading branch information
arjbingly committed May 1, 2024
1 parent 1f912ba commit 515a672
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 93 deletions.
3 changes: 1 addition & 2 deletions cookbook/RAG-GUI/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
RAG-GUI
"""RAG-GUI
=======
A cookbook demonstrating how to run a RAG app on streamlit.
Expand Down
14 changes: 7 additions & 7 deletions src/grag/components/create_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from importlib_resources import files


def create_config(path: Union[str, Path] = '.') -> None:
def create_config(path: Union[str, Path] = ".") -> None:
"""Create a configuration file if it doesn't exist.
This function checks for the existence of a 'config.ini' file at the given path.
Expand All @@ -17,8 +17,8 @@ def create_config(path: Union[str, Path] = '.') -> None:
and does not overwrite the existing file.
Args:
path (Union[str, Path]): The directory path where the 'config.ini' should be
located. If not specified, defaults to the current
path (Union[str, Path]): The directory path where the 'config.ini' should be
located. If not specified, defaults to the current
directory ('.').
Returns:
Expand All @@ -29,15 +29,15 @@ def create_config(path: Union[str, Path] = '.') -> None:
PermissionError: If the process does not have permission to write to the specified
directory.
"""
default_config_path = files(grag.resources).joinpath('default_config.ini')
path = Path(path) / 'config.ini'
default_config_path = files(grag.resources).joinpath("default_config.ini")
path = Path(path) / "config.ini"
path = path.resolve()
if path.exists():
print('Config file already exists')
print("Config file already exists")
else:
shutil.copyfile(default_config_path, path, follow_symlinks=True)
print(f"Created config file at {path}")


if __name__ == '__main__':
if __name__ == "__main__":
create_config()
15 changes: 9 additions & 6 deletions src/grag/components/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ def __init__(
model_name: str,
quantization: str,
pipeline: str,
device_map: str = 'auto',
task: str = 'text-generation',
max_new_tokens: str = '1024',
device_map: str = "auto",
task: str = "text-generation",
max_new_tokens: str = "1024",
temperature: Union[str, int] = 0.1,
n_batch: Union[str, int] = 1024,
n_ctx: Union[str, int] = 6000,
n_gpu_layers: Union[str, int] = -1,
std_out: Union[bool, str] = True,
base_dir: Union[str, Path] = Path('models'),
base_dir: Union[str, Path] = Path("models"),
callbacks=None,
):
"""Initialize the LLM class using the given parameters.
Expand Down Expand Up @@ -184,8 +184,11 @@ def llama_cpp(self):
return llm

def load_model(
self, model_name: Optional[str] = None, pipeline: Optional[str] = None, quantization: Optional[str] = None,
is_local: Optional[bool] = None
self,
model_name: Optional[str] = None,
pipeline: Optional[str] = None,
quantization: Optional[str] = None,
is_local: Optional[bool] = None,
):
"""Loads the model based on the specified pipeline and model name.
Expand Down
50 changes: 26 additions & 24 deletions src/grag/components/multivec_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Retriever:
linked document, chunk, etc.
Attributes:
vectordb: ChromaClient class instance from components.client
vectordb: ChromaClient class instance from components.client
(Optional, if the user provides it, store_path, id_key and namespace is not considered)
store_path: Path to the local file store
id_key: A key prefix for identifying documents
Expand All @@ -44,13 +44,13 @@ class Retriever:
"""

def __init__(
self,
vectordb: Optional[VectorDB] = None,
store_path: Union[str, Path] = Path('data/doc_store'),
top_k: Union[str, int] = 3,
id_key: str = 'doc_id',
namespace: str = '71e4b558187b270922923569301f1039',
client_kwargs: Optional[Dict[str, Any]] = None,
self,
vectordb: Optional[VectorDB] = None,
store_path: Union[str, Path] = Path("data/doc_store"),
top_k: Union[str, int] = 3,
id_key: str = "doc_id",
namespace: str = "71e4b558187b270922923569301f1039",
client_kwargs: Optional[Dict[str, Any]] = None,
):
"""Initialize the Retriever.
Expand All @@ -66,10 +66,12 @@ def __init__(
self.id_key = id_key
self.namespace = uuid.UUID(namespace)
if vectordb is None:
if any([self.store_path is None,
self.id_key is None,
self.namespace is None]):
raise TypeError("Arguments [store_path, id_key, namespace] or vectordb must be provided.")
if any(
[self.store_path is None, self.id_key is None, self.namespace is None]
):
raise TypeError(
"Arguments [store_path, id_key, namespace] or vectordb must be provided."
)
if client_kwargs is not None:
self.vectordb = DeepLakeClient(**client_kwargs)
else:
Expand Down Expand Up @@ -241,12 +243,12 @@ def get_docs_from_chunks(self, chunks: List[Document], one_to_one=False):
return [d for d in docs if d is not None]

def ingest(
self,
dir_path: Union[str, Path],
glob_pattern: str = "**/*.pdf",
dry_run: bool = False,
verbose: bool = True,
parser_kwargs: Optional[Dict[str, Any]] = None,
self,
dir_path: Union[str, Path],
glob_pattern: str = "**/*.pdf",
dry_run: bool = False,
verbose: bool = True,
parser_kwargs: Optional[Dict[str, Any]] = None,
):
"""Ingests the files in directory.
Expand Down Expand Up @@ -283,12 +285,12 @@ def ingest(
print(f"DRY RUN: found - {filepath.relative_to(dir_path)}")

async def aingest(
self,
dir_path: Union[str, Path],
glob_pattern: str = "**/*.pdf",
dry_run: bool = False,
verbose: bool = True,
parser_kwargs: Optional[Dict[str, Any]] = None,
self,
dir_path: Union[str, Path],
glob_pattern: str = "**/*.pdf",
dry_run: bool = False,
verbose: bool = True,
parser_kwargs: Optional[Dict[str, Any]] = None,
):
"""Asynchronously ingests the files in directory.
Expand Down
28 changes: 14 additions & 14 deletions src/grag/components/parse_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@ class ParsePDF:
image_output_dir (str): Directory to save extracted images, if any.
add_captions_to_text (bool): Whether to include figure captions in text output. Default is True.
add_captions_to_blocks (bool): Whether to add captions to table and image blocks. Default is True.
add_caption_first (bool): Whether to place captions before their corresponding image or table in the output.
add_caption_first (bool): Whether to place captions before their corresponding image or table in the output.
Default is True.
table_as_html (bool): Whether to add table elements as HTML. Default is False.
"""

def __init__(
self,
single_text_out: bool = True,
strategy: str = "hi_res",
infer_table_structure: bool = True,
extract_images: bool = True,
image_output_dir: Optional[str] = None,
add_captions_to_text: bool = True,
add_captions_to_blocks: bool = True,
add_caption_first: bool = True,
table_as_html: bool = False,
self,
single_text_out: bool = True,
strategy: str = "hi_res",
infer_table_structure: bool = True,
extract_images: bool = True,
image_output_dir: Optional[str] = None,
add_captions_to_text: bool = True,
add_captions_to_blocks: bool = True,
add_caption_first: bool = True,
table_as_html: bool = False,
):
"""Initialize instance variables with parameters."""
self.strategy = strategy
Expand Down Expand Up @@ -98,7 +98,7 @@ def classify(self, partitions):
if element.category == "Table":
if self.add_captions_to_blocks and i + 1 < len(partitions):
if (
partitions[i + 1].category == "FigureCaption"
partitions[i + 1].category == "FigureCaption"
): # check for caption
caption_element = partitions[i + 1]
else:
Expand All @@ -109,7 +109,7 @@ def classify(self, partitions):
elif element.category == "Image":
if self.add_captions_to_blocks and i + 1 < len(partitions):
if (
partitions[i + 1].category == "FigureCaption"
partitions[i + 1].category == "FigureCaption"
): # check for caption
caption_element = partitions[i + 1]
else:
Expand Down Expand Up @@ -197,7 +197,7 @@ def process_tables(self, elements):

if caption_element:
if (
self.add_caption_first
self.add_caption_first
): # if there is a caption, add that before the element
content = "\n\n".join([str(caption_element), table_data])
else:
Expand Down
2 changes: 1 addition & 1 deletion src/grag/components/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, **kwargs):
)

def save(
self, filepath: Union[Path, str, None], overwrite=False
self, filepath: Union[Path, str, None], overwrite=False
) -> Union[None, ValueError]:
"""Saves the prompt class into a json file."""
dump = self.model_dump_json(indent=2, exclude_defaults=True, exclude_none=True)
Expand Down
4 changes: 1 addition & 3 deletions src/grag/components/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ class TextSplitter:
"""

def __init__(
self,
chunk_size: Union[int, str] = 2000,
chunk_overlap: Union[int, str] = 400
self, chunk_size: Union[int, str] = 2000, chunk_overlap: Union[int, str] = 400
):
"""Initialize TextSplitter using chunk_size and chunk_overlap."""
self.text_splitter = RecursiveCharacterTextSplitter(
Expand Down
8 changes: 4 additions & 4 deletions src/grag/components/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def get_config(load_env=False):
if config_path_:
config_path = Path(config_path_)
else:
script_location = Path('.').resolve()
script_location = Path(".").resolve()
config_path = find_config_path(script_location)
if config_path is not None:
os.environ["CONFIG_PATH"] = str(config_path)
Expand All @@ -86,9 +86,9 @@ def get_config(load_env=False):
config = ConfigParser(interpolation=ExtendedInterpolation())
config.read(config_path)
print(f"Loaded config from {config_path}.")
# Load .env
# Load .env
if load_env:
env_path = Path(config['env']['env_path'])
env_path = Path(config["env"]["env_path"])
if env_path.exists():
load_dotenv(env_path)
print(f"Loaded environment variables from {env_path}")
Expand All @@ -112,7 +112,7 @@ def configure_args(cls):
Raises:
TypeError: If there is a mismatch in provided arguments and class constructor requirements.
"""
module_namespace = cls.__module__.split('.')[-1]
module_namespace = cls.__module__.split(".")[-1]

config = get_config()[module_namespace]

Expand Down
4 changes: 2 additions & 2 deletions src/grag/components/vectordb/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def aadd_docs(self, docs: List[Document], verbose: bool = True) -> None:

@abstractmethod
def get_chunk(
self, query: str, with_score: bool = False, top_k: Optional[int] = None
self, query: str, with_score: bool = False, top_k: Optional[int] = None
) -> Union[List[Document], List[Tuple[Document, float]]]:
"""Returns the most similar chunks from the vector database.
Expand All @@ -72,7 +72,7 @@ def get_chunk(

@abstractmethod
async def aget_chunk(
self, query: str, with_score: bool = False, top_k: Optional[int] = None
self, query: str, with_score: bool = False, top_k: Optional[int] = None
) -> Union[List[Document], List[Tuple[Document, float]]]:
"""Returns the most similar chunks from the vector database (asynchronous).
Expand Down
28 changes: 14 additions & 14 deletions src/grag/components/vectordb/chroma_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,22 @@ class ChromaClient(VectorDB):
"""

def __init__(
self,
host: str = 'localhost',
port: Union[str, int] = 8000,
collection_name: str = 'grag',
embedding_type: str = 'instructor-embedding',
embedding_model: str = 'hkunlp/instructor-xl',
self,
host: str = "localhost",
port: Union[str, int] = 8000,
collection_name: str = "grag",
embedding_type: str = "instructor-embedding",
embedding_model: str = "hkunlp/instructor-xl",
):
"""Initialize a ChromaClient object.
Args:
host: IP Address of hosted Chroma Vectorstore, defaults to localhost
port: port address of hosted Chroma Vectorstore, defaults to 8000
collection_name: name of the collection in the Chroma Vectorstore, defaults to 'grag'
embedding_type: type of embedding used, supported 'sentence-transformers' and 'instructor-embedding',
embedding_type: type of embedding used, supported 'sentence-transformers' and 'instructor-embedding',
defaults to instructor-embedding
embedding_model: model name of embedding used, should correspond to the embedding_type,
embedding_model: model name of embedding used, should correspond to the embedding_type,
defaults to hkunlp/instructor-xl.
"""
self.host = host
Expand Down Expand Up @@ -127,7 +127,7 @@ def add_docs(self, docs: List[Document], verbose=True) -> None:
"""
docs = self._filter_metadata(docs)
for doc in (
tqdm(docs, desc=f"Adding to {self.collection_name}:") if verbose else docs
tqdm(docs, desc=f"Adding to {self.collection_name}:") if verbose else docs
):
_id = self.langchain_client.add_documents([doc])

Expand All @@ -144,17 +144,17 @@ async def aadd_docs(self, docs: List[Document], verbose=True) -> None:
docs = self._filter_metadata(docs)
if verbose:
for doc in atqdm(
docs,
desc=f"Adding documents to {self.collection_name}",
total=len(docs),
docs,
desc=f"Adding documents to {self.collection_name}",
total=len(docs),
):
await self.langchain_client.aadd_documents([doc])
else:
for doc in docs:
await self.langchain_client.aadd_documents([doc])

def get_chunk(
self, query: str, with_score: bool = False, top_k: Optional[int] = None
self, query: str, with_score: bool = False, top_k: Optional[int] = None
) -> Union[List[Document], List[Tuple[Document, float]]]:
"""Returns the most similar chunks from the chroma database.
Expand All @@ -177,7 +177,7 @@ def get_chunk(
)

async def aget_chunk(
self, query: str, with_score=False, top_k=None
self, query: str, with_score=False, top_k=None
) -> Union[List[Document], List[Tuple[Document, float]]]:
"""Returns the most (cosine) similar chunks from the vector database, asynchronously.
Expand Down
Loading

0 comments on commit 515a672

Please sign in to comment.