diff --git a/README.md b/README.md index a6079de..6f55e8c 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,13 @@ -# GRAG +

GRAG

+ +[![License: AGPL v3](https://img.shields.io/badge/License-AGPL_v3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0) +![Static Badge](https://img.shields.io/badge/docstring%20style-google-pink?labelColor=white) +![Static Badge](https://img.shields.io/badge/linter-ruff-yellow?labelColor=white) +![Docs](https://img.shields.io/github/actions/workflow/status/arjbingly/Capstone_5/ruff_linting.yml) +![Static Badge](https://img.shields.io/badge/buildstyle-hatchling-purple?labelColor=white) +![Static Badge](https://img.shields.io/badge/codestyle-pyflake-purple?labelColor=white) +![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-pr/arjbingly/Capstone_5) + ## Project Overview diff --git a/cookbook/Basic-RAG/BasicRAG_stuff.py b/cookbook/Basic-RAG/BasicRAG_stuff.py index da95ec6..554c305 100644 --- a/cookbook/Basic-RAG/BasicRAG_stuff.py +++ b/cookbook/Basic-RAG/BasicRAG_stuff.py @@ -6,6 +6,7 @@ client = DeepLakeClient(collection_name="test") retriever = Retriever(vectordb=client) + rag = BasicRAG(doc_chain="stuff", retriever=retriever) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index b97ba19..e27f822 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,3 +117,6 @@ docstring-code-format = true [tool.ruff.lint.pydocstyle] convention = "google" + +[tool.mypy] +ignore_missing_imports = true diff --git a/src/config.ini b/src/config.ini index e23c1b5..d55e002 100644 --- a/src/config.ini +++ b/src/config.ini @@ -1,5 +1,5 @@ [llm] -model_name : Llama-2-7b-chat +model_name : Llama-2-13b-chat # meta-llama/Llama-2-70b-chat-hf Mixtral-8x7B-Instruct-v0.1 quantization : Q5_K_M pipeline : llama_cpp diff --git a/src/grag/components/multivec_retriever.py b/src/grag/components/multivec_retriever.py index 05478df..5a396fa 100644 --- a/src/grag/components/multivec_retriever.py +++ b/src/grag/components/multivec_retriever.py @@ -49,7 +49,7 @@ def __init__( store_path: str = multivec_retriever_conf["store_path"], id_key: str = multivec_retriever_conf["id_key"], namespace: str = multivec_retriever_conf["namespace"], - top_k=1, + top_k=int(multivec_retriever_conf["top_k"]), client_kwargs: Optional[Dict[str, Any]] = None, ): """Initialize the Retriever. diff --git a/src/tests/rag/basic_rag_test.py b/src/tests/rag/basic_rag_test.py index 2249028..b8c2ceb 100644 --- a/src/tests/rag/basic_rag_test.py +++ b/src/tests/rag/basic_rag_test.py @@ -1,11 +1,16 @@ -from typing import Text, List +from typing import List, Text +from grag.components.multivec_retriever import Retriever +from grag.components.vectordb.deeplake_client import DeepLakeClient from grag.rag.basic_rag import BasicRAG +client = DeepLakeClient(collection_name="test") +retriever = Retriever(vectordb=client) + def test_rag_stuff(): - rag = BasicRAG(doc_chain="stuff") - response, sources = rag("What is simulated annealing?") + rag = BasicRAG(doc_chain="stuff", retriever=retriever) + response, sources = rag("What is Flash Attention?") assert isinstance(response, Text) assert isinstance(sources, List) assert all(isinstance(s, str) for s in sources) @@ -13,9 +18,8 @@ def test_rag_stuff(): def test_rag_refine(): - rag = BasicRAG(doc_chain="refine") - response, sources = rag("What is simulated annealing?") - # assert isinstance(response, Text) + rag = BasicRAG(doc_chain="refine", retriever=retriever) + response, sources = rag("What is Flash Attention?") assert isinstance(response, List) assert all(isinstance(s, str) for s in response) assert isinstance(sources, List)