diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py index 73973871..358ad907 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py @@ -37,8 +37,21 @@ class Config: ) def __init__(self, **kwargs: Any): + """ + Create a new NVIDIARerank document compressor. + + Unless you plan to use the "nim" mode, you need to provide an API key. Your + options are - + 0. Pass the key as the nvidia_api_key parameter. + 1. Pass the key as the api_key parameter. + 2. Set the NVIDIA_API_KEY environment variable, recommended. + Precedence is in the order listed above. + """ super().__init__(**kwargs) - self._client = _NVIDIAClient(model=self.model) + self._client = _NVIDIAClient( + model=self.model, + api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)), + ) @property def available_models(self) -> List[Model]: diff --git a/libs/ai-endpoints/tests/unit_tests/test_api_key.py b/libs/ai-endpoints/tests/unit_tests/test_api_key.py index 49cbeb4d..1018720e 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_api_key.py +++ b/libs/ai-endpoints/tests/unit_tests/test_api_key.py @@ -1,7 +1,7 @@ import inspect import os from contextlib import contextmanager -from typing import Generator +from typing import Any, Generator import pytest @@ -39,12 +39,17 @@ def test_create_with_api_key(cls: type, param: str) -> None: @pytest.mark.parametrize("cls", public_classes) def test_api_key_priority(cls: type) -> None: + # ChatNVIDIA and NVIDIAEmbeddings currently expose a client attribute + def get_api_key(instance: Any) -> str: + if isinstance(instance, langchain_nvidia_ai_endpoints.ChatNVIDIA) or isinstance( + instance, langchain_nvidia_ai_endpoints.NVIDIAEmbeddings + ): + return instance.client.api_key.get_secret_value() + return instance._client.client.api_key.get_secret_value() + with no_env_var("NVIDIA_API_KEY"): os.environ["NVIDIA_API_KEY"] = "ENV" - assert cls().client.api_key.get_secret_value() == "ENV" - assert cls(nvidia_api_key="PARAM").client.api_key.get_secret_value() == "PARAM" - assert cls(api_key="PARAM").client.api_key.get_secret_value() == "PARAM" - assert ( - cls(api_key="LOW", nvidia_api_key="HIGH").client.api_key.get_secret_value() - == "HIGH" - ) + assert get_api_key(cls()) == "ENV" + assert get_api_key(cls(nvidia_api_key="PARAM")) == "PARAM" + assert get_api_key(cls(api_key="PARAM")) == "PARAM" + assert get_api_key(cls(api_key="LOW", nvidia_api_key="HIGH")) == "HIGH"