Skip to content

Commit

Permalink
add nvidia_api_key and api_key params to NVIDIARerank
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Apr 20, 2024
1 parent e2de402 commit d95fd5a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
15 changes: 14 additions & 1 deletion libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,21 @@ class Config:
)

def __init__(self, **kwargs: Any):
"""
Create a new NVIDIARerank document compressor.
Unless you plan to use the "nim" mode, you need to provide an API key. Your
options are -
0. Pass the key as the nvidia_api_key parameter.
1. Pass the key as the api_key parameter.
2. Set the NVIDIA_API_KEY environment variable, recommended.
Precedence is in the order listed above.
"""
super().__init__(**kwargs)
self._client = _NVIDIAClient(model=self.model)
self._client = _NVIDIAClient(
model=self.model,
api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
)

@property
def available_models(self) -> List[Model]:
Expand Down
21 changes: 13 additions & 8 deletions libs/ai-endpoints/tests/unit_tests/test_api_key.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import os
from contextlib import contextmanager
from typing import Generator
from typing import Any, Generator

import pytest

Expand Down Expand Up @@ -39,12 +39,17 @@ def test_create_with_api_key(cls: type, param: str) -> None:

@pytest.mark.parametrize("cls", public_classes)
def test_api_key_priority(cls: type) -> None:
# ChatNVIDIA and NVIDIAEmbeddings currently expose a client attribute
def get_api_key(instance: Any) -> str:
if isinstance(instance, langchain_nvidia_ai_endpoints.ChatNVIDIA) or isinstance(
instance, langchain_nvidia_ai_endpoints.NVIDIAEmbeddings
):
return instance.client.api_key.get_secret_value()
return instance._client.client.api_key.get_secret_value()

with no_env_var("NVIDIA_API_KEY"):
os.environ["NVIDIA_API_KEY"] = "ENV"
assert cls().client.api_key.get_secret_value() == "ENV"
assert cls(nvidia_api_key="PARAM").client.api_key.get_secret_value() == "PARAM"
assert cls(api_key="PARAM").client.api_key.get_secret_value() == "PARAM"
assert (
cls(api_key="LOW", nvidia_api_key="HIGH").client.api_key.get_secret_value()
== "HIGH"
)
assert get_api_key(cls()) == "ENV"
assert get_api_key(cls(nvidia_api_key="PARAM")) == "PARAM"
assert get_api_key(cls(api_key="PARAM")) == "PARAM"
assert get_api_key(cls(api_key="LOW", nvidia_api_key="HIGH")) == "HIGH"

0 comments on commit d95fd5a

Please sign in to comment.