Skip to content

Commit

Permalink
Merge pull request #69 from langchain-ai/mattf/update-ranking-support
Browse files Browse the repository at this point in the history
update ranking nim support for nims w/ /v1/models and multiple names
  • Loading branch information
mattf authored Jul 16, 2024
2 parents c6c0c33 + 13821d6 commit 8125dd2
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 12 deletions.
3 changes: 1 addition & 2 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class Config:
_client: _NVIDIAClient = PrivateAttr(_NVIDIAClient)

_default_batch_size: int = 32
_deprecated_model: str = "ai-rerank-qa-mistral-4b"
_default_model_name: str = "nv-rerank-qa-mistral-4b:1"

base_url: str = Field(
Expand Down Expand Up @@ -92,7 +91,7 @@ def get_available_models(
def _rank(self, documents: List[str], query: str) -> List[Ranking]:
response = self._client.client.get_req(
payload={
"model": "nv-rerank-qa-mistral-4b:1",
"model": self.model,
"query": {"text": query},
"passages": [{"text": passage} for passage in documents],
},
Expand Down
12 changes: 2 additions & 10 deletions libs/ai-endpoints/tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,11 @@ def get_all_known_models() -> List[Model]:
metafunc.parametrize("chat_model", models, ids=models)

if "rerank_model" in metafunc.fixturenames:
models = ["nv-rerank-qa-mistral-4b:1"]
models = [NVIDIARerank._default_model_name]
if model := metafunc.config.getoption("rerank_model_id"):
models = [model]
# nim-mode reranking does not support model listing via /v1/models endpoint
if metafunc.config.getoption("all_models"):
if mode.get("mode", None) == "nim":
models = [model.id for model in NVIDIARerank(**mode).available_models]
else:
models = [
model.id
for model in get_all_known_models()
if model.model_type == "ranking"
]
models = [model.id for model in NVIDIARerank(**mode).available_models]
metafunc.parametrize("rerank_model", models, ids=models)

if "vlm_model" in metafunc.fixturenames:
Expand Down

0 comments on commit 8125dd2

Please sign in to comment.