Skip to content

Commit

Permalink
enforce api key standards for all public interfaces
Browse files Browse the repository at this point in the history
three ways to provide an API key (priority order, higher priority overrides lower)
 0. nvidia_api_key parameter to instance constructors
 1. api_key parameter to instance constructors
 2. NVIDIA_API_KEY environment variable, recommended approach
  • Loading branch information
mattf committed Apr 19, 2024
1 parent 4e0eb61 commit 9a340ee
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 15 deletions.
33 changes: 26 additions & 7 deletions libs/ai-endpoints/tests/integration_tests/test_api_key.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,46 @@
import inspect
import os

import pytest

from langchain_nvidia_ai_endpoints import ChatNVIDIA
import langchain_nvidia_ai_endpoints

from ..unit_tests.test_api_key import no_env_var

public_classes = [
member[1]
for member in inspect.getmembers(langchain_nvidia_ai_endpoints, inspect.isclass)
]


def test_missing_api_key_error() -> None:
@pytest.mark.parametrize("cls", public_classes)
def test_missing_api_key_error(cls: type) -> None:
with no_env_var("NVIDIA_API_KEY"):
chat = ChatNVIDIA()
client = cls()
with pytest.raises(ValueError) as exc_info:
chat.invoke("Hello, world!")
client.available_models
message = str(exc_info.value)
assert "401" in message
assert "Unauthorized" in message
assert "API key" in message


def test_bogus_api_key_error() -> None:
@pytest.mark.parametrize("cls", public_classes)
def test_bogus_api_key_error(cls: type) -> None:
with no_env_var("NVIDIA_API_KEY"):
chat = ChatNVIDIA(nvidia_api_key="BOGUS")
client = cls(nvidia_api_key="BOGUS")
with pytest.raises(ValueError) as exc_info:
chat.invoke("Hello, world!")
client.available_models
message = str(exc_info.value)
assert "401" in message
assert "Unauthorized" in message
assert "API key" in message


@pytest.mark.parametrize("cls", public_classes)
@pytest.mark.parametrize("param", ["nvidia_api_key", "api_key"])
def test_api_key(cls: type, param: str) -> None:
api_key = os.environ.get("NVIDIA_API_KEY")
with no_env_var("NVIDIA_API_KEY"):
client = cls(**{param: api_key})
assert client.available_models
40 changes: 32 additions & 8 deletions libs/ai-endpoints/tests/unit_tests/test_api_key.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,50 @@
import inspect
import os
from contextlib import contextmanager
from typing import Generator

from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
import pytest

import langchain_nvidia_ai_endpoints

public_classes = [
member[1]
for member in inspect.getmembers(langchain_nvidia_ai_endpoints, inspect.isclass)
]


@contextmanager
def no_env_var(var: str) -> Generator[None, None, None]:
try:
if key := os.environ.get(var, None):
if val := os.environ.get(var, None):
del os.environ[var]
yield
finally:
if key:
os.environ[var] = key
if val:
os.environ[var] = val


@pytest.mark.parametrize("cls", public_classes)
def test_create_without_api_key(cls: type) -> None:
with no_env_var("NVIDIA_API_KEY"):
cls()


def test_create_chat_without_api_key() -> None:
@pytest.mark.parametrize("cls", public_classes)
@pytest.mark.parametrize("param", ["nvidia_api_key", "api_key"])
def test_create_with_api_key(cls: type, param: str) -> None:
with no_env_var("NVIDIA_API_KEY"):
ChatNVIDIA()
cls(**{param: "just testing no failure"})


def test_create_embeddings_without_api_key() -> None:
@pytest.mark.parametrize("cls", public_classes)
def test_api_key_priority(cls: type) -> None:
with no_env_var("NVIDIA_API_KEY"):
NVIDIAEmbeddings()
os.environ["NVIDIA_API_KEY"] = "ENV"
assert cls().client.api_key.get_secret_value() == "ENV"
assert cls(nvidia_api_key="PARAM").client.api_key.get_secret_value() == "PARAM"
assert cls(api_key="PARAM").client.api_key.get_secret_value() == "PARAM"
assert (
cls(api_key="LOW", nvidia_api_key="HIGH").client.api_key.get_secret_value()
== "HIGH"
)

0 comments on commit 9a340ee

Please sign in to comment.