
Type checking, coverage testing #61

Merged · 12 commits · Mar 30, 2024
@@ -1,25 +1,26 @@
-name: Ruff Linting
+name: Build - linting
 on:
   pull_request:
     branches:
       - main

 jobs:
-  adopt-ruff:
+  linting:
     runs-on: self-hosted
     steps:
       - name: Check out repository code
         uses: actions/checkout@v4

       - name: Set up python
         id: setup-python
         uses: actions/setup-python@v5
-        with:
+        with:
           python-version: 3.x
-      - name: Install ruff
+
+      - name: Install ruff
         run: pip install ruff

-      - name: Run the adopt-ruff action
-        uses: chartboost/ruff-action@v1
+      - name: Ruff
+        run: |
+          ruff check src/
+          ruff format src/
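To try the new lint job locally before opening a PR, a rough Python equivalent of the two workflow steps is sketched below. It assumes `ruff` is already installed in the active environment; the helper itself is illustrative and not part of this PR.

```python
# Local stand-in for the CI lint steps above (a sketch, not part of the repo).
# Assumes `pip install ruff` has been run in the current environment.
import subprocess
import sys


def run_lint(target: str = "src/") -> int:
    """Run the same two ruff commands the workflow runs, in order."""
    for cmd in (["ruff", "check", target], ["ruff", "format", target]):
        result = subprocess.run(cmd)
        if result.returncode != 0:
            return result.returncode  # fail early, like the CI job would
    return 0


if __name__ == "__main__":
    sys.exit(run_lint())
```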
32 changes: 32 additions & 0 deletions .github/workflows/linting.yml
@@ -0,0 +1,32 @@
+name: linting
+on:
+  push:
+    # run on every branch except main; a bare `!main` is invalid YAML
+    # (an unquoted `!` starts a tag), so branches-ignore is the working form
+    branches-ignore:
+      - main
+
+jobs:
+  lint:
+    runs-on: self-hosted
+    steps:
+      - name: Check out repository code
+        uses: actions/checkout@v4
+
+      - name: Set up python
+        id: setup-python
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.x
+
+      - name: Install ruff
+        run: pip install ruff
+
+      - name: Ruff check and fix
+        run: ruff check src/ --fix
+
+      - name: Ruff format
+        run: ruff format src/
+
+      - name: Commit
+        uses: stefanzweifel/git-auto-commit-action@v4
+        with:
+          commit_message: 'style fixes by ruff'
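The push-triggered workflow applies ruff's fixes and commits them back (git-auto-commit-action only commits when files changed). A hypothetical local mirror of that loop, for contributors who prefer to run it before pushing; the helper below is illustrative and not part of the repo:

```python
# Hypothetical local mirror of the autofix workflow: fix, format, commit.
import subprocess


def autofix_and_commit(target: str = "src/") -> None:
    subprocess.run(["ruff", "check", target, "--fix"], check=False)  # apply safe fixes
    subprocess.run(["ruff", "format", target], check=False)          # then format
    # Commit only if ruff actually changed something.
    dirty = subprocess.run(["git", "diff", "--quiet"]).returncode != 0
    if dirty:
        subprocess.run(["git", "add", target], check=True)
        subprocess.run(["git", "commit", "-m", "style fixes by ruff"], check=True)


if __name__ == "__main__":
    autofix_and_commit()
```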
15 changes: 0 additions & 15 deletions .github/workflows/ruff_commit.yml

This file was deleted.

4 changes: 2 additions & 2 deletions README.md
@@ -8,7 +8,6 @@
 ![Static Badge](https://img.shields.io/badge/codestyle-pyflake-purple?labelColor=white)
 ![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-pr/arjbingly/Capstone_5)

-
 ## Project Overview

 - A ready-to-deploy RAG pipeline for document retrieval.
@@ -29,6 +28,7 @@ Moreover, further customization can be made on the config file, `src/config.ini`
 - _For Dev:_ `pip install -e .`

 ### Requirements
+
 Required packages include (_refer to [pyproject.toml](pyproject.toml)_):

 - PyTorch
@@ -40,7 +40,7 @@ Required packages include (_refer to [pyproject.toml](pyproject.toml)_):

 ### LLM Models

-- **To run models locally** refer the [LLM Quantize Readme](./llm_quantize/readme.md) for details on downloading and
+- **To run models locally** refer to the [LLM Quantize Readme](./llm_quantize/README) for details on downloading and
   quantizing LLM models.
 - **To run models from Huggingface**, change the `model_name` under `llm` in `src/config.ini` to the huggingface
   repo-id (If
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/config.ini
@@ -9,7 +9,7 @@ max_new_tokens : 1024
 temperature : 0.1
 n_batch_gpu_cpp : 1024
 n_ctx_cpp : 6000
-n_gpu_layers_cpp : -1
+n_gpu_layers_cpp : 16
 # The number of layers to put on the GPU. Mixtral-18
 std_out : True
 base_dir : ${root:root_path}/models
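For context, `-1` tells llama.cpp to offload every layer to the GPU, while `16` caps offload at 16 layers, trading throughput for VRAM headroom. A sketch of how these settings would typically reach the model, assuming a LangChain `LlamaCpp` loader and an illustrative model path (the repo's actual loader is not part of this diff):

```python
# Sketch of how the edited config values plausibly reach llama.cpp.
# The class is LangChain's LlamaCpp; the model path is a made-up example.
from langchain_community.llms import LlamaCpp

llm = LlamaCpp(
    model_path="models/Mixtral-8x7B/ggml-model-Q5_K_M.gguf",  # hypothetical path
    n_gpu_layers=16,   # was -1 (offload all layers); now only 16 go to the GPU
    n_batch=1024,      # n_batch_gpu_cpp
    n_ctx=6000,        # n_ctx_cpp
    temperature=0.1,
    max_tokens=1024,   # max_new_tokens
)
```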
6 changes: 3 additions & 3 deletions src/grag/components/embedding.py
@@ -30,13 +30,13 @@ def __init__(self, embedding_type: str, embedding_model: str):
         match self.embedding_type:
             case "sentence-transformers":
                 self.embedding_function = SentenceTransformerEmbeddings(
-                    model_name=self.embedding_model
+                    model_name=self.embedding_model  # type: ignore
                 )
             case "instructor-embedding":
                 self.embedding_instruction = "Represent the document for retrival"
                 self.embedding_function = HuggingFaceInstructEmbeddings(
-                    model_name=self.embedding_model
+                    model_name=self.embedding_model  # type: ignore
                 )
-                self.embedding_function.embed_instruction = self.embedding_instruction
+                self.embedding_function.embed_instruction = self.embedding_instruction  # type: ignore
             case _:
                 raise Exception("embedding_type is invalid")
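A usage sketch for the `Embedding` wrapper annotated above; the constructor arguments come from the match statement in the diff, but the model name is illustrative:

```python
# Illustrative use of the Embedding wrapper above; the model name is an example.
from grag.components.embedding import Embedding

st_embedding = Embedding(
    embedding_type="sentence-transformers",
    embedding_model="all-MiniLM-L6-v2",
)
vector = st_embedding.embedding_function.embed_query("What is a RAG pipeline?")
print(len(vector))  # dimensionality of the embedding

# An unrecognized type still fails loudly:
# Embedding(embedding_type="word2vec", embedding_model="x")  -> Exception
```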
10 changes: 5 additions & 5 deletions src/grag/components/llm.py
@@ -74,18 +74,18 @@ def model_name(self):
         """Returns the name of the model."""
         return self._model_name

-    @model_name.setter
-    def model_name(self, value):
-        """Sets the model name."""
-        self._model_name = value
-
     @property
     def model_path(self):
         """Returns the path to the model."""
         return str(
             self.base_dir / self.model_name / f"ggml-model-{self.quantization}.gguf"
         )

+    @model_name.setter
+    def model_name(self, value):
+        """Sets the model name."""
+        self._model_name = value
+
     def hf_pipeline(self, is_local=False):
         """Loads the model using Hugging Face transformers.
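The hunk above reorders the `model_name` accessors around the `model_path` property. The pattern in isolation, as a minimal self-contained sketch with made-up values:

```python
# Minimal sketch of the pattern above: a settable name plus a path derived
# from it. Values are made up for illustration.
from pathlib import Path


class Model:
    def __init__(self, name: str, quantization: str = "Q5_K_M"):
        self.base_dir = Path("models")
        self.quantization = quantization
        self._model_name = name

    @property
    def model_name(self):
        """Returns the name of the model."""
        return self._model_name

    @model_name.setter
    def model_name(self, value):
        """Sets the model name."""
        self._model_name = value

    @property
    def model_path(self):
        """Returns the path to the model."""
        return str(self.base_dir / self.model_name / f"ggml-model-{self.quantization}.gguf")


m = Model("Llama-2-7B")
m.model_name = "Mixtral-8x7B"  # the setter keeps model_path in sync
print(m.model_path)            # models/Mixtral-8x7B/ggml-model-Q5_K_M.gguf
```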
8 changes: 4 additions & 4 deletions src/grag/components/multivec_retriever.py
@@ -75,7 +75,7 @@ def __init__(
         self.store = LocalFileStore(self.store_path)
         self.retriever = MultiVectorRetriever(
             vectorstore=self.vectordb.langchain_client,
-            byte_store=self.store,
+            docstore=self.store,  # type: ignore
             id_key=self.id_key,
         )
         self.splitter = TextSplitter()
@@ -157,7 +157,7 @@ async def aadd_docs(self, docs: List[Document]):
         chunks = self.split_docs(docs)
         doc_ids = self.gen_doc_ids(docs)
         await self.vectordb.aadd_docs(chunks)
-        self.retriever.docstore.mset(list(zip(doc_ids)))
+        self.retriever.docstore.mset(list(zip(doc_ids, docs)))

     def get_chunk(self, query: str, with_score=False, top_k=None):
         """Returns the most similar chunks from the vector database.
@@ -241,7 +241,7 @@ def ingest(
         glob_pattern: str = "**/*.pdf",
         dry_run: bool = False,
         verbose: bool = True,
-        parser_kwargs: dict = None,
+        parser_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """Ingests the files in directory.
@@ -283,7 +283,7 @@ async def aingest(
         glob_pattern: str = "**/*.pdf",
         dry_run: bool = False,
         verbose: bool = True,
-        parser_kwargs: dict = None,
+        parser_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """Asynchronously ingests the files in directory.
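`parser_kwargs: dict = None` is exactly the kind of annotation a type checker rejects, since `None` is not a `dict`; `Optional[Dict[str, Any]]` is the honest type. A sketch of the pattern, with `parser_kwargs or {}` as an assumed way the body might consume it (the diff does not show that part):

```python
# Sketch of the Optional-kwargs pattern adopted above: default to None,
# then substitute an empty dict inside the function body.
from typing import Any, Dict, Optional


def ingest(glob_pattern: str = "**/*.pdf",
           parser_kwargs: Optional[Dict[str, Any]] = None) -> None:
    kwargs = parser_kwargs or {}  # safe: never a shared mutable default
    print(f"parsing {glob_pattern} with options {kwargs}")


ingest()                             # parsing **/*.pdf with options {}
ingest(parser_kwargs={"ocr": True})  # parsing **/*.pdf with options {'ocr': True}
```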
7 changes: 3 additions & 4 deletions src/grag/components/prompt.py
@@ -109,7 +109,9 @@ def load(cls, filepath: Union[Path, str]):

     def format(self, **kwargs) -> str:
         """Formats the prompt with provided keys and returns a string."""
-        return self.prompt.format(**kwargs)
+        if self.prompt is not None:
+            return self.prompt.format(**kwargs)
+        raise ValueError("Prompt is not defined.")


 class FewShotPrompt(Prompt):
@@ -136,9 +138,6 @@ class FewShotPrompt(Prompt):
     prefix: str
     suffix: str
     example_template: str
-    prompt: Optional[FewShotPromptTemplate] = Field(
-        exclude=True, repr=False, default=None
-    )

     def __init__(self, **kwargs):
         """Initialize the prompt."""
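The None-guard added to `format()` above follows a common pattern for optional attributes: fail fast with a clear error instead of letting `None.format(...)` raise an opaque `AttributeError`. A minimal standalone sketch (not the repo's `Prompt` class):

```python
# Minimal sketch of the guard pattern above: fail fast when an optional
# template was never populated.
from typing import Optional


class Template:
    def __init__(self, template: Optional[str] = None):
        self.template = template

    def format(self, **kwargs) -> str:
        if self.template is not None:
            return self.template.format(**kwargs)
        raise ValueError("Template is not defined.")


print(Template("Hello, {name}!").format(name="Ada"))  # Hello, Ada!
```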
6 changes: 4 additions & 2 deletions src/grag/components/text_splitter.py
@@ -4,6 +4,8 @@
 - TextSplitter
 """

+from typing import Union
+
 from langchain.text_splitter import RecursiveCharacterTextSplitter

 from .utils import get_config
@@ -22,8 +24,8 @@ class TextSplitter:

     def __init__(
         self,
-        chunk_size: int = text_splitter_conf["chunk_size"],
-        chunk_overlap: int = text_splitter_conf["chunk_overlap"],
+        chunk_size: Union[int, str] = text_splitter_conf["chunk_size"],
+        chunk_overlap: Union[int, str] = text_splitter_conf["chunk_overlap"],
     ):
         """Initialize TextSplitter."""
         self.text_splitter = RecursiveCharacterTextSplitter(
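The widened annotations reflect that `ConfigParser` hands back every value as a string unless it is explicitly converted, so defaults sourced from `text_splitter_conf` are `str` at runtime. A small demonstration, with illustrative section and option names:

```python
# ConfigParser returns strings, which is why the annotation widens to
# Union[int, str]. Section and option names here are illustrative.
from configparser import ConfigParser

config = ConfigParser()
config.read_string("""
[text_splitter]
chunk_size : 5000
chunk_overlap : 400
""")

raw = config["text_splitter"]["chunk_size"]
print(type(raw), raw)                # <class 'str'> 5000

# Converting at the point of use keeps the splitter happy:
chunk_size = int(raw)
print(type(chunk_size), chunk_size)  # <class 'int'> 5000
```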
69 changes: 3 additions & 66 deletions src/grag/components/utils.py
@@ -7,15 +7,12 @@
 - get_config: retrieves and parses the configuration settings from the 'config.ini' file.
 """

-import json
 import os
-import textwrap
 from configparser import ConfigParser, ExtendedInterpolation
 from pathlib import Path
 from typing import List

 from langchain_core.documents import Document
-from langchain_core.prompts import ChatPromptTemplate


 def stuff_docs(docs: List[Document]) -> str:
@@ -30,67 +27,6 @@ def stuff_docs(docs: List[Document]) -> str:
     return "\n\n".join([doc.page_content for doc in docs])


-def reformat_text_with_line_breaks(input_text, max_width=110):
-    """Reformat the given text to ensure each line does not exceed a specific width, preserving existing line breaks.
-
-    Args:
-        input_text (str): The text to be reformatted.
-        max_width (int): The maximum width of each line.
-
-    Returns:
-        str: The reformatted text with preserved line breaks and adjusted line width.
-    """
-    # Divide the text into separate lines
-    original_lines = input_text.split("\n")
-
-    # Apply wrapping to each individual line
-    reformatted_lines = [
-        textwrap.fill(line, width=max_width) for line in original_lines
-    ]
-
-    # Combine the lines back into a single text block
-    reformatted_text = "\n".join(reformatted_lines)
-
-    return reformatted_text
-
-
-def display_llm_output_and_sources(response_from_llm):
-    """Displays the result from an LLM response and lists the sources.
-
-    Args:
-        response_from_llm (dict): The response object from an LLM which includes the result and source documents.
-    """
-    # Display the main result from the LLM response
-    print(response_from_llm["result"])
-
-    # Separator for clarity
-    print("\nSources:")
-
-    # Loop through each source document and print its source
-    for source in response_from_llm["source_documents"]:
-        print(source.metadata["source"])
-
-
-def load_prompt(json_file: str | os.PathLike, return_input_vars=False):
-    """Loads a prompt template from json file and returns a langchain ChatPromptTemplate.
-
-    Args:
-        json_file: path to the prompt template json file.
-        return_input_vars: if true returns a list of expected input variables for the prompt.
-
-    Returns:
-        langchain_core.prompts.ChatPromptTemplate (and a list of input vars if return_input_vars is True)
-
-    """
-    with open(f"{json_file}", "r") as f:
-        prompt_json = json.load(f)
-    prompt_template = ChatPromptTemplate.from_template(prompt_json["template"])
-
-    input_vars = prompt_json["input_variables"]
-
-    return (prompt_template, input_vars) if return_input_vars else prompt_template
-
-
 def find_config_path(current_path: Path) -> Path:
     """Finds the path of the 'config.ini' file by traversing up the directory tree from the current path.

@@ -125,8 +61,9 @@ def get_config() -> ConfigParser:
     """
     # Assuming this script is somewhere inside your project directory
     script_location = Path(__file__).resolve()
-    if os.environ.get("CONFIG_PATH"):
-        config_path = os.environ.get("CONFIG_PATH")
+    config_path_ = os.environ.get("CONFIG_PATH")
+    if config_path_:
+        config_path = Path(config_path_)
     else:
         config_path = find_config_path(script_location)
         os.environ["CONFIG_PATH"] = str(config_path)
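Binding `os.environ.get(...)` to a local first lets the type checker narrow `Optional[str]` to `str` inside the branch, and wrapping it in `Path` keeps `config_path` a single type on both branches. A sketch of the same pattern outside the repo:

```python
# Sketch of the narrowing pattern above: bind os.environ.get once so the
# type checker can narrow Optional[str] to str inside the if-branch.
import os
from pathlib import Path


def resolve_config_path(fallback: Path) -> Path:
    config_path_ = os.environ.get("CONFIG_PATH")  # Optional[str]
    if config_path_:                              # narrows to str
        return Path(config_path_)
    os.environ["CONFIG_PATH"] = str(fallback)     # cache for later calls
    return fallback


print(resolve_config_path(Path.cwd() / "config.ini"))
```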
10 changes: 7 additions & 3 deletions src/grag/components/vectordb/base.py
@@ -5,7 +5,7 @@
 """

 from abc import ABC, abstractmethod
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 from langchain_community.vectorstores.utils import filter_complex_metadata
 from langchain_core.documents import Document
@@ -14,6 +14,10 @@
 class VectorDB(ABC):
     """Abstract base class for vector database clients."""

+    def __init__(self):
+        """Initialize the vector database client."""
+        self.allowed_metadata_types = ()
+
     @abstractmethod
     def __len__(self) -> int:
         """Number of chunks in the vector database."""
@@ -51,7 +55,7 @@ async def aadd_docs(self, docs: List[Document], verbose: bool = True) -> None:

     @abstractmethod
     def get_chunk(
-        self, query: str, with_score: bool = False, top_k: int = None
+        self, query: str, with_score: bool = False, top_k: Optional[int] = None
     ) -> Union[List[Document], List[Tuple[Document, float]]]:
         """Returns the most similar chunks from the vector database.

@@ -67,7 +71,7 @@ def get_chunk(

     @abstractmethod
     async def aget_chunk(
-        self, query: str, with_score: bool = False, top_k: int = None
+        self, query: str, with_score: bool = False, top_k: Optional[int] = None
     ) -> Union[List[Document], List[Tuple[Document, float]]]:
         """Returns the most similar chunks from the vector database (asynchronous).

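A standalone sketch of the interface shape after this change: a concrete `__init__` supplying a shared default, plus `Optional[int]` for `top_k`. It mirrors `VectorDB` without importing it, since the full abstract surface is not shown in the diff:

```python
# Standalone sketch of the Optional[int] default pattern adopted above.
# TinyVectorDB mirrors the shape of VectorDB without importing it.
from abc import ABC, abstractmethod
from typing import List, Optional


class TinyVectorDB(ABC):
    def __init__(self):
        self.allowed_metadata_types = ()  # shared default, as in VectorDB

    @abstractmethod
    def get_chunk(self, query: str, top_k: Optional[int] = None) -> List[str]:
        """Return the most similar chunks."""


class InMemoryDB(TinyVectorDB):
    def __init__(self, chunks: List[str]):
        super().__init__()
        self.chunks = chunks

    def get_chunk(self, query: str, top_k: Optional[int] = None) -> List[str]:
        k = top_k if top_k is not None else 3  # None means "use the default"
        return self.chunks[:k]


db = InMemoryDB(["a", "b", "c", "d"])
print(db.get_chunk("query"))           # ['a', 'b', 'c']
print(db.get_chunk("query", top_k=2))  # ['a', 'b']
```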