From 43bdb6b6a71b245055fbe53c5ef4fa93b2ba879f Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 18 Jul 2024 05:00:37 -0400 Subject: [PATCH] add support for reranking truncate parameter --- .../docs/retrievers/nvidia_rerank.ipynb | 16 +++- .../reranking.py | 30 ++++++-- .../tests/integration_tests/test_ranking.py | 38 ++++++++++ .../tests/unit_tests/test_ranking.py | 73 +++++++++++++++++++ 4 files changed, 147 insertions(+), 10 deletions(-) create mode 100644 libs/ai-endpoints/tests/unit_tests/test_ranking.py diff --git a/libs/ai-endpoints/docs/retrievers/nvidia_rerank.ipynb b/libs/ai-endpoints/docs/retrievers/nvidia_rerank.ipynb index 3164e8e0..207d167a 100644 --- a/libs/ai-endpoints/docs/retrievers/nvidia_rerank.ipynb +++ b/libs/ai-endpoints/docs/retrievers/nvidia_rerank.ipynb @@ -574,7 +574,19 @@ "source": [ "#### Combine and rank documents\n", "\n", - "Let's combine the BM25 as well as semantic search results. The resulting `docs` will be ordered by their relevance to the query by the reranking NIM." + "Let's combine the BM25 as well as semantic search results. The resulting `docs` will be ordered by their relevance to the query by the reranking NIM.\n", + "\n", + "#### Note on truncation\n", + "\n", + "Reranking models typically have a fixed context window that determines the maximum number of input tokens that can be processed. This limit could be a hard limit, equal to the model's maximum input token length, or an effective limit, beyond which the accuracy of the ranking decreases.\n", + "\n", + "Since models operate on tokens and applications usually work with text, it can be challenging for an application to ensure that its input stays within the model's token limits. By default, an exception is thrown if the input is too large.\n", + "\n", + "To assist with this, NVIDIA's NIMs (API Catalog or local) provide a `truncate` parameter that truncates the input on the server side if it's too large.\n", + "\n", + "The `truncate` parameter has three options:\n", + " - \"NONE\": The default option. An exception is thrown if the input is too large.\n", + " - \"END\": The server truncates the input from the end (right), discarding tokens as necessary." ] }, { @@ -598,7 +610,7 @@ } ], "source": [ - "ranker = NVIDIARerank()\n", + "ranker = NVIDIARerank(truncate=\"END\")\n", "\n", "all_docs = bm25_docs + sem_docs\n", "\n", diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py index 03e7862e..85ce4f73 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Generator, List, Optional, Sequence +from typing import Any, Generator, List, Literal, Optional, Sequence from langchain_core.callbacks.manager import Callbacks from langchain_core.documents import Document @@ -35,6 +35,13 @@ class Config: ) top_n: int = Field(5, ge=0, description="The number of documents to return.") model: Optional[str] = Field(description="The model to use for reranking.") + truncate: Optional[Literal["NONE", "END"]] = Field( + description=( + "Truncate input text if it exceeds the model's maximum token length. " + "Default is model dependent and is likely to raise error if an " + "input is too long." + ), + ) max_batch_size: int = Field( _default_batch_size, ge=1, description="The maximum batch size." ) @@ -53,6 +60,9 @@ def __init__(self, **kwargs: Any): nvidia_api_key (str): The API key to use for connecting to the hosted NIM. api_key (str): Alternative to nvidia_api_key. base_url (str): The base URL of the NIM to connect to. + truncate (str): "NONE", "END", truncate input text if it exceeds + the model's context length. Default is model dependent and + is likely to raise an error if an input is too long. API Key: - The recommended way to provide the API key is through the `NVIDIA_API_KEY` @@ -89,13 +99,14 @@ def get_available_models( # todo: batching when len(documents) > endpoint's max batch size def _rank(self, documents: List[str], query: str) -> List[Ranking]: - response = self._client.client.get_req( - payload={ - "model": self.model, - "query": {"text": query}, - "passages": [{"text": passage} for passage in documents], - }, - ) + payload = { + "model": self.model, + "query": {"text": query}, + "passages": [{"text": passage} for passage in documents], + } + if self.truncate: + payload["truncate"] = self.truncate + response = self._client.client.get_req(payload=payload) if response.status_code != 200: response.raise_for_status() # todo: handle errors @@ -134,6 +145,9 @@ def batch(ls: list, size: int) -> Generator[List[Document], None, None]: query=query, documents=[d.page_content for d in doc_batch] ) for ranking in rankings: + assert ( + 0 <= ranking.index < len(doc_batch) + ), "invalid response from server: index out of range" doc = doc_batch[ranking.index] doc.metadata["relevance_score"] = ranking.logit results.append(doc) diff --git a/libs/ai-endpoints/tests/integration_tests/test_ranking.py b/libs/ai-endpoints/tests/integration_tests/test_ranking.py index 5937cb30..867aab64 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_ranking.py +++ b/libs/ai-endpoints/tests/integration_tests/test_ranking.py @@ -179,3 +179,41 @@ def test_rerank_batching( # result_docs[i].page_content == reference_docs[i].page_content # for i in range(top_n) # ), "batched results do not match unbatched results" + + +@pytest.mark.parametrize("truncate", ["END"]) +def test_truncate_positive(rerank_model: str, mode: dict, truncate: str) -> None: + query = "What is acceleration?" + documents = [ + Document(page_content="NVIDIA " * length) + for length in [32, 1024, 64, 128, 2048, 256, 512] + ] + client = NVIDIARerank( + model=rerank_model, top_n=len(documents), truncate=truncate, **mode + ) + response = client.compress_documents(documents=documents, query=query) + assert len(response) == len(documents) + + +@pytest.mark.parametrize("truncate", [None, "NONE"]) +@pytest.mark.xfail( + reason=( + "truncation is inconsistent across models, " + "nv-rerank-qa-mistral-4b:1 truncates by default " + "while others do not" + ) +) +def test_truncate_negative(rerank_model: str, mode: dict, truncate: str) -> None: + query = "What is acceleration?" + documents = [ + Document(page_content="NVIDIA " * length) + for length in [32, 1024, 64, 128, 2048, 256, 512] + ] + truncate_param = {} + if truncate: + truncate_param = {"truncate": truncate} + client = NVIDIARerank(model=rerank_model, **truncate_param, **mode) + with pytest.raises(Exception) as e: + client.compress_documents(documents=documents, query=query) + assert "400" in str(e.value) + assert "exceeds maximum allowed" in str(e.value) diff --git a/libs/ai-endpoints/tests/unit_tests/test_ranking.py b/libs/ai-endpoints/tests/unit_tests/test_ranking.py new file mode 100644 index 00000000..887ef3e9 --- /dev/null +++ b/libs/ai-endpoints/tests/unit_tests/test_ranking.py @@ -0,0 +1,73 @@ +from typing import Any, Literal, Optional + +import pytest +from langchain_core.documents import Document +from requests_mock import Mocker + +from langchain_nvidia_ai_endpoints import NVIDIARerank + + +@pytest.fixture(autouse=True) +def mock_v1_models(requests_mock: Mocker) -> None: + requests_mock.get( + "https://integrate.api.nvidia.com/v1/models", + json={ + "data": [ + { + "id": "mock-model", + "object": "model", + "created": 1234567890, + "owned_by": "OWNER", + } + ] + }, + ) + + +@pytest.fixture(autouse=True) +def mock_v1_ranking(requests_mock: Mocker) -> None: + requests_mock.post( + "https://integrate.api.nvidia.com/v1/ranking", + json={ + "rankings": [ + {"index": 0, "logit": 4.2}, + ] + }, + ) + + +@pytest.mark.parametrize( + "truncate", + [ + None, + "END", + "NONE", + ], +) +def test_truncate( + requests_mock: Mocker, + truncate: Optional[Literal["END", "NONE"]], +) -> None: + truncate_param = {} + if truncate: + truncate_param = {"truncate": truncate} + client = NVIDIARerank(model="mock-model", **truncate_param) + response = client.compress_documents( + documents=[Document(page_content="Nothing really.")], query="What is it?" + ) + + assert len(response) == 1 + + assert requests_mock.last_request is not None + request_payload = requests_mock.last_request.json() + if truncate is None: + assert "truncate" not in request_payload + else: + assert "truncate" in request_payload + assert request_payload["truncate"] == truncate + + +@pytest.mark.parametrize("truncate", [True, False, 1, 0, 1.0, "START", "BOGUS"]) +def test_truncate_invalid(truncate: Any) -> None: + with pytest.raises(ValueError): + NVIDIARerank(truncate=truncate)