Skip to content

Commit

Permalink
refactor: eliminate blind pytest.warns
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Jul 31, 2024
1 parent 923de96 commit f5e1622
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 7 deletions.
4 changes: 3 additions & 1 deletion libs/ai-endpoints/tests/integration_tests/test_api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

def test_missing_api_key_error(public_class: type, contact_service: Any) -> None:
with no_env_var("NVIDIA_API_KEY"):
with pytest.warns(UserWarning):
with pytest.warns(UserWarning) as record:
client = public_class()
assert len(record) == 1
assert "API key is required for the hosted" in str(record[0].message)
with pytest.raises(Exception) as exc_info:
contact_service(client)
message = str(exc_info.value)
Expand Down
11 changes: 8 additions & 3 deletions libs/ai-endpoints/tests/unit_tests/test_api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,18 @@ def mock_v1_local_models(requests_mock: Mocker) -> None:

def test_create_without_api_key(public_class: type) -> None:
with no_env_var("NVIDIA_API_KEY"):
with pytest.warns(UserWarning):
with pytest.warns(UserWarning) as record:
public_class()
assert len(record) == 1
assert "API key is required for the hosted" in str(record[0].message)


def test_create_unknown_url_no_api_key(public_class: type) -> None:
with no_env_var("NVIDIA_API_KEY") and pytest.warns(UserWarning):
public_class(base_url="https://test_url/v1")
with no_env_var("NVIDIA_API_KEY"):
with pytest.warns(UserWarning) as record:
public_class(base_url="https://test_url/v1")
assert len(record) == 1
assert "Default model is set as" in str(record[0].message)


@pytest.mark.parametrize("param", ["nvidia_api_key", "api_key"])
Expand Down
4 changes: 3 additions & 1 deletion libs/ai-endpoints/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def embedding(requests_mock: Mocker) -> Generator[NVIDIAEmbeddings, None, None]:
"usage": {"prompt_tokens": 8, "total_tokens": 8},
},
)
with pytest.warns(UserWarning):
with pytest.warns(UserWarning) as record:
yield NVIDIAEmbeddings(model=model, nvidia_api_key="a-bogus-key")
assert len(record) == 1
assert "type is unknown and inference may fail" in str(record[0].message)


def test_embed_documents_negative_input_int(embedding: NVIDIAEmbeddings) -> None:
Expand Down
4 changes: 3 additions & 1 deletion libs/ai-endpoints/tests/unit_tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ def test_default_known(public_class: type, known_unknown: str) -> None:
Test that a model in the model table will be accepted.
"""
# check if default model is getting set
with pytest.warns(UserWarning):
with pytest.warns(UserWarning) as record:
x = public_class(base_url="http://localhost:8000/v1")
assert x.model == known_unknown
assert len(record) == 1
assert "Default model is set as: mock-model" in str(record[0].message)


def test_default_lora(public_class: type) -> None:
Expand Down
4 changes: 3 additions & 1 deletion libs/ai-endpoints/tests/unit_tests/test_statics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@ def test_model_table_integrity_name_id(entry: str) -> None:


def test_determine_model_deprecated_alternative_warns(alias: str) -> None:
with pytest.warns(UserWarning):
with pytest.warns(UserWarning) as record:
determine_model(alias)
assert len(record) == 1
assert f"Model {alias} is deprecated" in str(record[0].message)

0 comments on commit f5e1622

Please sign in to comment.