-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
enforce api key standards for all public interfaces
three ways to provide an API key (priority order, higher priority overrides lower) 0. nvidia_api_key parameter to instance constructors 1. api_key parameter to instance constructors 2. NVIDIA_API_KEY environment variable, recommended approach
- Loading branch information
Showing
2 changed files
with
58 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,46 @@ | ||
import inspect | ||
import os | ||
|
||
import pytest | ||
|
||
from langchain_nvidia_ai_endpoints import ChatNVIDIA | ||
import langchain_nvidia_ai_endpoints | ||
|
||
from ..unit_tests.test_api_key import no_env_var | ||
|
||
public_classes = [ | ||
member[1] | ||
for member in inspect.getmembers(langchain_nvidia_ai_endpoints, inspect.isclass) | ||
] | ||
|
||
|
||
def test_missing_api_key_error() -> None: | ||
@pytest.mark.parametrize("cls", public_classes) | ||
def test_missing_api_key_error(cls: type) -> None: | ||
with no_env_var("NVIDIA_API_KEY"): | ||
chat = ChatNVIDIA() | ||
client = cls() | ||
with pytest.raises(ValueError) as exc_info: | ||
chat.invoke("Hello, world!") | ||
client.available_models | ||
message = str(exc_info.value) | ||
assert "401" in message | ||
assert "Unauthorized" in message | ||
assert "API key" in message | ||
|
||
|
||
def test_bogus_api_key_error() -> None: | ||
@pytest.mark.parametrize("cls", public_classes) | ||
def test_bogus_api_key_error(cls: type) -> None: | ||
with no_env_var("NVIDIA_API_KEY"): | ||
chat = ChatNVIDIA(nvidia_api_key="BOGUS") | ||
client = cls(nvidia_api_key="BOGUS") | ||
with pytest.raises(ValueError) as exc_info: | ||
chat.invoke("Hello, world!") | ||
client.available_models | ||
message = str(exc_info.value) | ||
assert "401" in message | ||
assert "Unauthorized" in message | ||
assert "API key" in message | ||
|
||
|
||
@pytest.mark.parametrize("cls", public_classes) | ||
@pytest.mark.parametrize("param", ["nvidia_api_key", "api_key"]) | ||
def test_api_key(cls: type, param: str) -> None: | ||
api_key = os.environ.get("NVIDIA_API_KEY") | ||
with no_env_var("NVIDIA_API_KEY"): | ||
client = cls(**{param: api_key}) | ||
assert client.available_models |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,50 @@ | ||
import inspect | ||
import os | ||
from contextlib import contextmanager | ||
from typing import Generator | ||
|
||
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings | ||
import pytest | ||
|
||
import langchain_nvidia_ai_endpoints | ||
|
||
public_classes = [ | ||
member[1] | ||
for member in inspect.getmembers(langchain_nvidia_ai_endpoints, inspect.isclass) | ||
] | ||
|
||
|
||
@contextmanager | ||
def no_env_var(var: str) -> Generator[None, None, None]: | ||
try: | ||
if key := os.environ.get(var, None): | ||
if val := os.environ.get(var, None): | ||
del os.environ[var] | ||
yield | ||
finally: | ||
if key: | ||
os.environ[var] = key | ||
if val: | ||
os.environ[var] = val | ||
|
||
|
||
@pytest.mark.parametrize("cls", public_classes) | ||
def test_create_without_api_key(cls: type) -> None: | ||
with no_env_var("NVIDIA_API_KEY"): | ||
cls() | ||
|
||
|
||
def test_create_chat_without_api_key() -> None: | ||
@pytest.mark.parametrize("cls", public_classes) | ||
@pytest.mark.parametrize("param", ["nvidia_api_key", "api_key"]) | ||
def test_create_with_api_key(cls: type, param: str) -> None: | ||
with no_env_var("NVIDIA_API_KEY"): | ||
ChatNVIDIA() | ||
cls(**{param: "just testing no failure"}) | ||
|
||
|
||
def test_create_embeddings_without_api_key() -> None: | ||
@pytest.mark.parametrize("cls", public_classes) | ||
def test_api_key_priority(cls: type) -> None: | ||
with no_env_var("NVIDIA_API_KEY"): | ||
NVIDIAEmbeddings() | ||
os.environ["NVIDIA_API_KEY"] = "ENV" | ||
assert cls().client.api_key.get_secret_value() == "ENV" | ||
assert cls(nvidia_api_key="PARAM").client.api_key.get_secret_value() == "PARAM" | ||
assert cls(api_key="PARAM").client.api_key.get_secret_value() == "PARAM" | ||
assert ( | ||
cls(api_key="LOW", nvidia_api_key="HIGH").client.api_key.get_secret_value() | ||
== "HIGH" | ||
) |