diff --git a/README.md b/README.md
index a6079de..6f55e8c 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,13 @@
-# GRAG
+
GRAG
+
+[![License: AGPL v3](https://img.shields.io/badge/License-AGPL_v3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
+![Static Badge](https://img.shields.io/badge/docstring%20style-google-pink?labelColor=white)
+![Static Badge](https://img.shields.io/badge/linter-ruff-yellow?labelColor=white)
+![Docs](https://img.shields.io/github/actions/workflow/status/arjbingly/Capstone_5/ruff_linting.yml)
+![Static Badge](https://img.shields.io/badge/buildstyle-hatchling-purple?labelColor=white)
+![Static Badge](https://img.shields.io/badge/codestyle-pyflake-purple?labelColor=white)
+![GitHub Issues or Pull Requests](https://img.shields.io/github/issues-pr/arjbingly/Capstone_5)
+
## Project Overview
diff --git a/cookbook/Basic-RAG/BasicRAG_stuff.py b/cookbook/Basic-RAG/BasicRAG_stuff.py
index da95ec6..554c305 100644
--- a/cookbook/Basic-RAG/BasicRAG_stuff.py
+++ b/cookbook/Basic-RAG/BasicRAG_stuff.py
@@ -6,6 +6,7 @@
client = DeepLakeClient(collection_name="test")
retriever = Retriever(vectordb=client)
+
rag = BasicRAG(doc_chain="stuff", retriever=retriever)
if __name__ == "__main__":
diff --git a/pyproject.toml b/pyproject.toml
index b97ba19..e27f822 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,3 +117,6 @@ docstring-code-format = true
[tool.ruff.lint.pydocstyle]
convention = "google"
+
+[tool.mypy]
+ignore_missing_imports = true
diff --git a/src/config.ini b/src/config.ini
index e23c1b5..d55e002 100644
--- a/src/config.ini
+++ b/src/config.ini
@@ -1,5 +1,5 @@
[llm]
-model_name : Llama-2-7b-chat
+model_name : Llama-2-13b-chat
# meta-llama/Llama-2-70b-chat-hf Mixtral-8x7B-Instruct-v0.1
quantization : Q5_K_M
pipeline : llama_cpp
diff --git a/src/grag/components/multivec_retriever.py b/src/grag/components/multivec_retriever.py
index 05478df..5a396fa 100644
--- a/src/grag/components/multivec_retriever.py
+++ b/src/grag/components/multivec_retriever.py
@@ -49,7 +49,7 @@ def __init__(
store_path: str = multivec_retriever_conf["store_path"],
id_key: str = multivec_retriever_conf["id_key"],
namespace: str = multivec_retriever_conf["namespace"],
- top_k=1,
+ top_k=int(multivec_retriever_conf["top_k"]),
client_kwargs: Optional[Dict[str, Any]] = None,
):
"""Initialize the Retriever.
diff --git a/src/tests/rag/basic_rag_test.py b/src/tests/rag/basic_rag_test.py
index 2249028..b8c2ceb 100644
--- a/src/tests/rag/basic_rag_test.py
+++ b/src/tests/rag/basic_rag_test.py
@@ -1,11 +1,16 @@
-from typing import Text, List
+from typing import List, Text
+from grag.components.multivec_retriever import Retriever
+from grag.components.vectordb.deeplake_client import DeepLakeClient
from grag.rag.basic_rag import BasicRAG
+client = DeepLakeClient(collection_name="test")
+retriever = Retriever(vectordb=client)
+
def test_rag_stuff():
- rag = BasicRAG(doc_chain="stuff")
- response, sources = rag("What is simulated annealing?")
+ rag = BasicRAG(doc_chain="stuff", retriever=retriever)
+ response, sources = rag("What is Flash Attention?")
assert isinstance(response, Text)
assert isinstance(sources, List)
assert all(isinstance(s, str) for s in sources)
@@ -13,9 +18,8 @@ def test_rag_stuff():
def test_rag_refine():
- rag = BasicRAG(doc_chain="refine")
- response, sources = rag("What is simulated annealing?")
- # assert isinstance(response, Text)
+ rag = BasicRAG(doc_chain="refine", retriever=retriever)
+ response, sources = rag("What is Flash Attention?")
assert isinstance(response, List)
assert all(isinstance(s, str) for s in response)
assert isinstance(sources, List)