Merge pull request #75 from langchain-ai/mattf/add-truncate-to-reranking
Add support for the reranking API change with the truncate parameter
mattf authored Jul 19, 2024
2 parents 6040bab + 43bdb6b commit 05ede29
Showing 4 changed files with 147 additions and 10 deletions.
16 changes: 14 additions & 2 deletions libs/ai-endpoints/docs/retrievers/nvidia_rerank.ipynb
@@ -574,7 +574,19 @@
"source": [
"#### Combine and rank documents\n",
"\n",
"Let's combine the BM25 as well as semantic search results. The resulting `docs` will be ordered by their relevance to the query by the reranking NIM."
"Let's combine the BM25 as well as semantic search results. The resulting `docs` will be ordered by their relevance to the query by the reranking NIM.\n",
"\n",
"#### Note on truncation\n",
"\n",
"Reranking models typically have a fixed context window that determines the maximum number of input tokens that can be processed. This limit could be a hard limit, equal to the model's maximum input token length, or an effective limit, beyond which the accuracy of the ranking decreases.\n",
"\n",
"Since models operate on tokens and applications usually work with text, it can be challenging for an application to ensure that its input stays within the model's token limits. By default, an exception is thrown if the input is too large.\n",
"\n",
"To assist with this, NVIDIA's NIMs (API Catalog or local) provide a `truncate` parameter that truncates the input on the server side if it's too large.\n",
"\n",
"The `truncate` parameter has three options:\n",
" - \"NONE\": The default option. An exception is thrown if the input is too large.\n",
" - \"END\": The server truncates the input from the end (right), discarding tokens as necessary."
]
},
{
@@ -598,7 +610,7 @@
}
],
"source": [
"ranker = NVIDIARerank()\n",
"ranker = NVIDIARerank(truncate=\"END\")\n",
"\n",
"all_docs = bm25_docs + sem_docs\n",
"\n",
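For reference, a minimal usage sketch of the parameter introduced above, using only the API surface shown in this diff (`NVIDIARerank`, `truncate`, `compress_documents`); the query and document texts are placeholders, and a configured `NVIDIA_API_KEY` is assumed:

```python
from langchain_core.documents import Document
from langchain_nvidia_ai_endpoints import NVIDIARerank

# truncate="END" asks the NIM to trim over-long passages on the server side
# instead of returning an error; omitting it keeps the model's default behavior.
ranker = NVIDIARerank(truncate="END")

docs = [
    Document(page_content="NVIDIA " * 2048),  # deliberately longer than a typical context window
    Document(page_content="Acceleration is the rate of change of velocity."),
]
ranked = ranker.compress_documents(documents=docs, query="What is acceleration?")
for doc in ranked:
    print(doc.metadata["relevance_score"], doc.page_content[:40])
```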
30 changes: 22 additions & 8 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Generator, List, Optional, Sequence
from typing import Any, Generator, List, Literal, Optional, Sequence

from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
@@ -35,6 +35,13 @@ class Config:
)
top_n: int = Field(5, ge=0, description="The number of documents to return.")
model: Optional[str] = Field(description="The model to use for reranking.")
truncate: Optional[Literal["NONE", "END"]] = Field(
description=(
"Truncate input text if it exceeds the model's maximum token length. "
"Default is model dependent and is likely to raise error if an "
"input is too long."
),
)
max_batch_size: int = Field(
_default_batch_size, ge=1, description="The maximum batch size."
)
@@ -53,6 +60,9 @@ def __init__(self, **kwargs: Any):
nvidia_api_key (str): The API key to use for connecting to the hosted NIM.
api_key (str): Alternative to nvidia_api_key.
base_url (str): The base URL of the NIM to connect to.
truncate (str): "NONE", "END", truncate input text if it exceeds
the model's context length. Default is model dependent and
is likely to raise an error if an input is too long.
API Key:
- The recommended way to provide the API key is through the `NVIDIA_API_KEY`
@@ -89,13 +99,14 @@ def get_available_models(

# todo: batching when len(documents) > endpoint's max batch size
def _rank(self, documents: List[str], query: str) -> List[Ranking]:
response = self._client.client.get_req(
payload={
"model": self.model,
"query": {"text": query},
"passages": [{"text": passage} for passage in documents],
},
)
payload = {
"model": self.model,
"query": {"text": query},
"passages": [{"text": passage} for passage in documents],
}
if self.truncate:
payload["truncate"] = self.truncate
response = self._client.client.get_req(payload=payload)
if response.status_code != 200:
response.raise_for_status()
# todo: handle errors
@@ -134,6 +145,9 @@ def batch(ls: list, size: int) -> Generator[List[Document], None, None]:
query=query, documents=[d.page_content for d in doc_batch]
)
for ranking in rankings:
assert (
0 <= ranking.index < len(doc_batch)
), "invalid response from server: index out of range"
doc = doc_batch[ranking.index]
doc.metadata["relevance_score"] = ranking.logit
results.append(doc)
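For illustration, this is roughly the request body the updated `_rank` sends when `truncate` is set; field names follow the code above, while the model id and texts are placeholder values:

```python
# Payload assembled by _rank (illustrative values only):
payload = {
    "model": "nv-rerank-qa-mistral-4b:1",           # placeholder model id
    "query": {"text": "What is acceleration?"},
    "passages": [{"text": "NVIDIA Blackwell ..."}],
    "truncate": "END",  # key is omitted entirely when self.truncate is None
}
```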
38 changes: 38 additions & 0 deletions libs/ai-endpoints/tests/integration_tests/test_ranking.py
@@ -179,3 +179,41 @@ def test_rerank_batching(
# result_docs[i].page_content == reference_docs[i].page_content
# for i in range(top_n)
# ), "batched results do not match unbatched results"


@pytest.mark.parametrize("truncate", ["END"])
def test_truncate_positive(rerank_model: str, mode: dict, truncate: str) -> None:
query = "What is acceleration?"
documents = [
Document(page_content="NVIDIA " * length)
for length in [32, 1024, 64, 128, 2048, 256, 512]
]
client = NVIDIARerank(
model=rerank_model, top_n=len(documents), truncate=truncate, **mode
)
response = client.compress_documents(documents=documents, query=query)
assert len(response) == len(documents)


@pytest.mark.parametrize("truncate", [None, "NONE"])
@pytest.mark.xfail(
reason=(
"truncation is inconsistent across models, "
"nv-rerank-qa-mistral-4b:1 truncates by default "
"while others do not"
)
)
def test_truncate_negative(rerank_model: str, mode: dict, truncate: str) -> None:
query = "What is acceleration?"
documents = [
Document(page_content="NVIDIA " * length)
for length in [32, 1024, 64, 128, 2048, 256, 512]
]
truncate_param = {}
if truncate:
truncate_param = {"truncate": truncate}
client = NVIDIARerank(model=rerank_model, **truncate_param, **mode)
with pytest.raises(Exception) as e:
client.compress_documents(documents=documents, query=query)
assert "400" in str(e.value)
assert "exceeds maximum allowed" in str(e.value)
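As the negative test above suggests, an over-long input without truncation can surface as an HTTP 400 whose message mentions exceeding the maximum allowed length, although the behavior varies by model. A rough caller-side fallback (not part of this change, and assuming an `NVIDIA_API_KEY` is configured) might look like:

```python
from langchain_core.documents import Document
from langchain_nvidia_ai_endpoints import NVIDIARerank

docs = [Document(page_content="NVIDIA " * 2048)]
query = "What is acceleration?"

try:
    # No truncate set: some models reject input that exceeds their context length.
    ranked = NVIDIARerank().compress_documents(documents=docs, query=query)
except Exception as err:
    if "400" not in str(err):
        raise
    # Retry with server-side truncation from the end of the input.
    ranked = NVIDIARerank(truncate="END").compress_documents(documents=docs, query=query)
```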
73 changes: 73 additions & 0 deletions libs/ai-endpoints/tests/unit_tests/test_ranking.py
@@ -0,0 +1,73 @@
from typing import Any, Literal, Optional

import pytest
from langchain_core.documents import Document
from requests_mock import Mocker

from langchain_nvidia_ai_endpoints import NVIDIARerank


@pytest.fixture(autouse=True)
def mock_v1_models(requests_mock: Mocker) -> None:
requests_mock.get(
"https://integrate.api.nvidia.com/v1/models",
json={
"data": [
{
"id": "mock-model",
"object": "model",
"created": 1234567890,
"owned_by": "OWNER",
}
]
},
)


@pytest.fixture(autouse=True)
def mock_v1_ranking(requests_mock: Mocker) -> None:
requests_mock.post(
"https://integrate.api.nvidia.com/v1/ranking",
json={
"rankings": [
{"index": 0, "logit": 4.2},
]
},
)


@pytest.mark.parametrize(
"truncate",
[
None,
"END",
"NONE",
],
)
def test_truncate(
requests_mock: Mocker,
truncate: Optional[Literal["END", "NONE"]],
) -> None:
truncate_param = {}
if truncate:
truncate_param = {"truncate": truncate}
client = NVIDIARerank(model="mock-model", **truncate_param)
response = client.compress_documents(
documents=[Document(page_content="Nothing really.")], query="What is it?"
)

assert len(response) == 1

assert requests_mock.last_request is not None
request_payload = requests_mock.last_request.json()
if truncate is None:
assert "truncate" not in request_payload
else:
assert "truncate" in request_payload
assert request_payload["truncate"] == truncate


@pytest.mark.parametrize("truncate", [True, False, 1, 0, 1.0, "START", "BOGUS"])
def test_truncate_invalid(truncate: Any) -> None:
with pytest.raises(ValueError):
NVIDIARerank(truncate=truncate)
