From 13821d6086b52aabbf07b1194658b3fd36ccdfa9 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 11 Jul 2024 13:18:32 -0400 Subject: [PATCH] update ranking nim support for nims w/ /v1/models and multiple names --- .../langchain_nvidia_ai_endpoints/reranking.py | 3 +-- .../ai-endpoints/tests/integration_tests/conftest.py | 12 ++---------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py index 17e5dd01..03e7862e 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py @@ -27,7 +27,6 @@ class Config: _client: _NVIDIAClient = PrivateAttr(_NVIDIAClient) _default_batch_size: int = 32 - _deprecated_model: str = "ai-rerank-qa-mistral-4b" _default_model_name: str = "nv-rerank-qa-mistral-4b:1" base_url: str = Field( @@ -92,7 +91,7 @@ def get_available_models( def _rank(self, documents: List[str], query: str) -> List[Ranking]: response = self._client.client.get_req( payload={ - "model": "nv-rerank-qa-mistral-4b:1", + "model": self.model, "query": {"text": query}, "passages": [{"text": passage} for passage in documents], }, diff --git a/libs/ai-endpoints/tests/integration_tests/conftest.py b/libs/ai-endpoints/tests/integration_tests/conftest.py index 30388a96..d05ae598 100644 --- a/libs/ai-endpoints/tests/integration_tests/conftest.py +++ b/libs/ai-endpoints/tests/integration_tests/conftest.py @@ -73,19 +73,11 @@ def get_all_known_models() -> List[Model]: metafunc.parametrize("chat_model", models, ids=models) if "rerank_model" in metafunc.fixturenames: - models = ["nv-rerank-qa-mistral-4b:1"] + models = [NVIDIARerank._default_model_name] if model := metafunc.config.getoption("rerank_model_id"): models = [model] - # nim-mode reranking does not support model listing via /v1/models endpoint if metafunc.config.getoption("all_models"): - if mode.get("mode", None) == "nim": - models = [model.id for model in NVIDIARerank(**mode).available_models] - else: - models = [ - model.id - for model in get_all_known_models() - if model.model_type == "ranking" - ] + models = [model.id for model in NVIDIARerank(**mode).available_models] metafunc.parametrize("rerank_model", models, ids=models) if "vlm_model" in metafunc.fixturenames: