Skip to content

Commit

Permalink
Merge pull request #85 from thehapyone/packages-update
Browse files Browse the repository at this point in the history
Packages update
  • Loading branch information
thehapyone authored Nov 11, 2024
2 parents 7ae1973 + 7839682 commit bac0db7
Show file tree
Hide file tree
Showing 9 changed files with 2,229 additions and 2,487 deletions.
8 changes: 8 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,20 @@ paths = ["/path/to/file1", "/path/to/file2"]
The [embedding] and [llm] sections configure the embedding engines and Large Language Models (LLMs) used by Sage.

```toml
## LLM and ReRanker Configuration are based on LiteLLM providers - https://docs.litellm.ai/docs/providers/
[llm]
model = "gpt-4-turbo"
[embedding]
type = "huggingface"
model = "jinaai/jina-embeddings-v2-base-en"
[reranker]
top_n = 5
model = "cohere/rerank-english-v2.0"
#model = "BAAI/bge-reranker-large"
revision = "55611d7bca2a7133960a6d3b71e083071bbfc312"
```

## Environment Variables for Sensitive Credentials
Expand Down
4,376 changes: 2,048 additions & 2,328 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions public/stylesheet.css
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,10 @@ img[alt="logo"] {
background-color: #0f0f0f !important;
color: #d3e2df !important;
}


/* Targeting code elements within specific div class */
.MuiBox-root.css-13s85fp code {
background-color: #0f0f0f !important;
color: white !important;
}
14 changes: 6 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sage"
version = "0.0.13"
version = "0.0.14"
description = "Sage: A conversational AI assistant simplifying data interactions with intuitive ease"
authors = ["Ayo Ayibiowu <[email protected]>"]
license = "Apache-2.0 license"
Expand Down Expand Up @@ -37,13 +37,11 @@ atlassian-python-api = "^3.41.11"
aiosqlite = "^0.20.0"
litellm = "^1.46.0"
dataclasses-json = "^0.6.7"
crewai = "^0.67.1"
cohere = "^5.6.1"
langchain = "^0.2.16"
langchain-core = "^0.2.41"
langchain-community = "^0.2.17"
langchain-cohere = "^0.1.9"
chainlit = "1.2.0"
chainlit = "1.3.2"
crewai = "^0.79.4"
langchain-core = "^0.3.15"
langchain = "^0.3.7"
langchain-community = "^0.3.5"

[tool.poetry.group.tests]
optional = true
Expand Down
12 changes: 3 additions & 9 deletions sage/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,10 @@ password = "your_web_password"
paths = ["/path/to/file1", "/path/to/file2"]

[reranker]
type = "huggingface"
top_n = 5

[reranker.cohere]
name = "your_cohere_reranker_name"
password = "your_cohere_password"

[reranker.huggingface]
name = "your_huggingface_reranker_name"
revision = "your_huggingface_revision"
model = "cohere/rerank-english-v2.0"
#model = "BAAI/bge-reranker-large"
revision = "55611d7bca2a7133960a6d3b71e083071bbfc312"

[embedding]
type = "huggingface"
Expand Down
45 changes: 17 additions & 28 deletions sage/sources/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sage.utils.exceptions import SourceException
from sage.utils.labels import generate_source_label
from sage.utils.supports import CustomFAISS as FAISS
from sage.utils.supports import aexecute_concurrently, asyncify
from sage.utils.supports import ReRanker, aexecute_concurrently, asyncify
from sage.validators.config_toml import ConfluenceModel, Files, GitlabModel, Web


Expand Down Expand Up @@ -240,34 +240,23 @@ def _compression_retriever(
raise SourceException("There is no valid reranker configuration found")

try:
if ranker_config.type == "cohere":
from langchain_cohere import CohereRerank

_compressor = CohereRerank(
top_n=ranker_config.top_n,
model=ranker_config.cohere.name,
cohere_api_key=ranker_config.cohere.password.get_secret_value(),
user_agent=core_config.user_agent,
)
elif ranker_config.type == "huggingface":
from sage.utils.supports import BgeRerank

_compressor = BgeRerank(
name=ranker_config.huggingface.name,
top_n=ranker_config.top_n,
cache_dir=str(core_config.data_dir / "models"),
revision=ranker_config.huggingface.revision,
)
else:
raise SourceException(
f"Reranker type {ranker_config.type} not supported has a valid compression retriever"
)
except Exception as error:
raise SourceException(str(error))
_compressor = ReRanker(
top_n=ranker_config.top_n,
model=ranker_config.model,
revision=ranker_config.revision,
cache_dir=str(core_config.data_dir / "models"),
)

_compression_retriever = ContextualCompressionRetriever(
base_compressor=_compressor, base_retriever=retriever
)
except Exception as e:
logger.error(
"An error has occurred while loading the compression retriever",
exc_info=True,
)
raise e

_compression_retriever = ContextualCompressionRetriever(
base_compressor=_compressor, base_retriever=retriever
)
return _compression_retriever

## Helper to create a retriever while the data input is a list of files path
Expand Down
125 changes: 124 additions & 1 deletion sage/utils/supports.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from langchain_community.docstore.base import AddableMixin
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.faiss import FAISS
from litellm import aembedding, embedding
from litellm import aembedding, arerank, embedding, rerank
from markdown import markdown
from pydantic import PrivateAttr, model_validator
from sentence_transformers import CrossEncoder

from sage.validators.config_toml import Config
Expand Down Expand Up @@ -190,6 +191,128 @@ def compress_documents(
return self._rerank(query=query, documents=documents)


class ReRanker(BaseDocumentCompressor):
    """Document compressor using litellm Rerank.

    Routes reranking through litellm's hosted ``rerank``/``arerank`` API by
    default; when the configured model name contains "huggingface" or "BAAI",
    it instead delegates to the local ``BgeRerank`` cross-encoder.
    """

    model: str = "cohere/rerank-english-v3.0"
    """Model name to use for reranking."""
    top_n: int = 10
    """Number of documents to return."""
    # Directory for caching locally downloaded models (used on the
    # HuggingFace path only; forwarded to BgeRerank).
    cache_dir: Optional[str] = None
    # Optional model revision pin (commit hash/tag) for HuggingFace models.
    revision: Optional[str] = None
    # Backend in use; flipped to "huggingface" by the validator below.
    provider: str = "litellm"
    # Local cross-encoder reranker; set only for huggingface/BAAI models.
    _hugging_reranker: BgeRerank = PrivateAttr(default=None)

    @model_validator(mode="after")
    def set_provider(self) -> "ReRanker":
        """Initialize huggingface reranker if applicable"""
        # A substring match on the model name decides the backend: anything
        # mentioning "huggingface" or "BAAI" is reranked locally.
        if any(x in self.model for x in ("huggingface", "BAAI")):
            self.provider = "huggingface"
            self._hugging_reranker = BgeRerank(
                name=self.model,
                top_n=self.top_n,
                cache_dir=self.cache_dir,
                revision=self.revision,
            )
        return self

    @staticmethod
    def _parse_response(
        response: list[dict], documents: Sequence[Document]
    ) -> Sequence[Document]:
        """Parse rerank response and attach scores to documents.

        Each result entry carries the index of the source document plus its
        relevance score; the original Document objects are returned in the
        response's (score-ranked) order with the score recorded in metadata.
        """
        final_results = []
        for r in response:
            doc = documents[r["index"]]
            doc.metadata["relevance_score"] = r["relevance_score"]
            final_results.append(doc)
        return final_results

    def _get_document_contents(self, documents: Sequence[Document]) -> list[str]:
        """Extract page contents from documents"""
        return [doc.page_content for doc in documents]

    def _rerank(self, query: str, documents: Sequence[Document]) -> Sequence[Document]:
        """Rerank the documents synchronously against the given query."""
        # Nothing to rank — avoid a pointless API call.
        if not documents:
            return []

        # Local HuggingFace path: delegate entirely to BgeRerank.
        if self._hugging_reranker:
            result = self._hugging_reranker.compress_documents(
                query=query, documents=documents
            )
            return result

        # Hosted path: litellm returns scored indices only
        # (return_documents=False), so map them back onto the inputs.
        response = rerank(
            model=self.model,
            query=query,
            documents=self._get_document_contents(documents),
            top_n=self.top_n,
            return_documents=False,
        )
        return self._parse_response(response.results, documents)

    async def _arerank(
        self, query: str, documents: Sequence[Document]
    ) -> Sequence[Document]:
        """Rerank the documents asynchronously against the given query."""
        if not documents:
            return []

        # NOTE(review): the HuggingFace path runs synchronously even here and
        # will block the event loop for large inputs — confirm acceptable.
        if self._hugging_reranker:
            result = self._hugging_reranker.compress_documents(
                query=query, documents=documents
            )
            return result

        response = await arerank(
            model=self.model,
            query=query,
            documents=self._get_document_contents(documents),
            top_n=self.top_n,
            return_documents=False,
        )
        return self._parse_response(response.results, documents)

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents.
        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
        Returns:
            A sequence of compressed documents.
        """
        return self._rerank(query=query, documents=documents)

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents.
        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
        Returns:
            A sequence of compressed documents.
        """
        return await self._arerank(query=query, documents=documents)


def markdown_to_text_using_html2text(markdown_text: str) -> str:
"""Convert the markdown docs into plaintext using the html2text plugin
Expand Down
34 changes: 3 additions & 31 deletions sage/validators/config_toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,40 +267,12 @@ class EmbeddingsConfig(BaseModel):
dimension: Optional[int] = None


class CohereReRanker(Password):
    """The Cohere rerank schema.

    Holds the Cohere reranker model name and its API credential, falling back
    to the COHERE_API_KEY environment variable when no password is configured.
    """

    # Name of the Cohere rerank model to use.
    name: str

    @model_validator(mode="after")
    def set_password(self) -> "Password":
        """Resolve the API key: explicit config value wins, env var is the fallback."""
        if self.password is None:
            # Fall back to the environment when the config omits a password.
            if password_env := os.getenv("COHERE_API_KEY"):
                self.password = SecretStr(password_env)
            else:
                # Neither source provided a credential — fail configuration early.
                raise ConfigException(
                    (
                        "The COHERE_API_KEY | config password is missing. "
                        "Please add it via an env variable or to the config password field."
                    )
                )
        return self


class HuggingFaceReRanker(BaseModel):
    """The HuggingFace schema.

    Identifies a HuggingFace reranker model by name and a pinned revision.
    """

    # HuggingFace model identifier (e.g. an org/model path).
    name: str
    # Pinned model revision (commit hash or tag); required, not optional.
    revision: str


class ReRankerConfig(ModelValidateType):
class ReRankerConfig(BaseModel):
"""Reranker config schema"""

cohere: Optional[CohereReRanker] = None
huggingface: Optional[HuggingFaceReRanker] = None
type: Literal["cohere", "huggingface"]
top_n: int = 5
model: str
revision: Optional[str] = None


class LLMConfig(BaseModel):
Expand Down
Loading

0 comments on commit bac0db7

Please sign in to comment.