-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #75 from langchain-ai/mattf/add-truncate-to-reranking
add support for reranking api change w/ truncate parameter
- Loading branch information
Showing
4 changed files
with
147 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Any, Literal, Optional | ||
|
||
import pytest | ||
from langchain_core.documents import Document | ||
from requests_mock import Mocker | ||
|
||
from langchain_nvidia_ai_endpoints import NVIDIARerank | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mock_v1_models(requests_mock: Mocker) -> None: | ||
requests_mock.get( | ||
"https://integrate.api.nvidia.com/v1/models", | ||
json={ | ||
"data": [ | ||
{ | ||
"id": "mock-model", | ||
"object": "model", | ||
"created": 1234567890, | ||
"owned_by": "OWNER", | ||
} | ||
] | ||
}, | ||
) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mock_v1_ranking(requests_mock: Mocker) -> None: | ||
requests_mock.post( | ||
"https://integrate.api.nvidia.com/v1/ranking", | ||
json={ | ||
"rankings": [ | ||
{"index": 0, "logit": 4.2}, | ||
] | ||
}, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"truncate", | ||
[ | ||
None, | ||
"END", | ||
"NONE", | ||
], | ||
) | ||
def test_truncate( | ||
requests_mock: Mocker, | ||
truncate: Optional[Literal["END", "NONE"]], | ||
) -> None: | ||
truncate_param = {} | ||
if truncate: | ||
truncate_param = {"truncate": truncate} | ||
client = NVIDIARerank(model="mock-model", **truncate_param) | ||
response = client.compress_documents( | ||
documents=[Document(page_content="Nothing really.")], query="What is it?" | ||
) | ||
|
||
assert len(response) == 1 | ||
|
||
assert requests_mock.last_request is not None | ||
request_payload = requests_mock.last_request.json() | ||
if truncate is None: | ||
assert "truncate" not in request_payload | ||
else: | ||
assert "truncate" in request_payload | ||
assert request_payload["truncate"] == truncate | ||
|
||
|
||
@pytest.mark.parametrize("truncate", [True, False, 1, 0, 1.0, "START", "BOGUS"]) | ||
def test_truncate_invalid(truncate: Any) -> None: | ||
with pytest.raises(ValueError): | ||
NVIDIARerank(truncate=truncate) |