diff --git a/libs/ai-endpoints/tests/unit_tests/conftest.py b/libs/ai-endpoints/tests/unit_tests/conftest.py index 41a73905..a64c535c 100644 --- a/libs/ai-endpoints/tests/unit_tests/conftest.py +++ b/libs/ai-endpoints/tests/unit_tests/conftest.py @@ -1,5 +1,5 @@ import re -from typing import Callable, List +from typing import Callable, Generator, List import pytest import requests_mock @@ -10,6 +10,7 @@ NVIDIAEmbeddings, NVIDIARerank, ) +from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE @pytest.fixture( @@ -46,6 +47,17 @@ def mock_v1_models(requests_mock: requests_mock.Mocker, mock_model: str) -> None ) +@pytest.fixture(autouse=True) +def reset_model_table() -> Generator[None, None, None]: + """ + Reset MODEL_TABLE between tests. + """ + original = MODEL_TABLE.copy() + yield + MODEL_TABLE.clear() + MODEL_TABLE.update(original) + + @pytest.fixture def mock_streaming_response( requests_mock: requests_mock.Mocker, mock_model: str diff --git a/libs/ai-endpoints/tests/unit_tests/test_available_models.py b/libs/ai-endpoints/tests/unit_tests/test_available_models.py index 369c01b0..9ee80c0a 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_available_models.py +++ b/libs/ai-endpoints/tests/unit_tests/test_available_models.py @@ -1,10 +1,7 @@ import warnings -from typing import Any, Generator - -import pytest +from typing import Any from langchain_nvidia_ai_endpoints import Model, register_model -from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE def test_model_listing(public_class: Any, mock_model: str) -> None: @@ -14,19 +11,9 @@ def test_model_listing(public_class: Any, mock_model: str) -> None: assert any(model.id == mock_model for model in models) -@pytest.fixture -def model_table() -> Generator[None, None, None]: - """ - Reset MODEL_TABLE between tests. - """ - original = MODEL_TABLE.copy() - yield - MODEL_TABLE.clear() - MODEL_TABLE.update(original) - - def test_model_listing_hosted( - public_class: Any, mock_model: str, model_table: None + public_class: Any, + mock_model: str, ) -> None: model = Model( id=mock_model, diff --git a/libs/ai-endpoints/tests/unit_tests/test_register_model.py b/libs/ai-endpoints/tests/unit_tests/test_register_model.py index f87efa11..ee936e29 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_register_model.py +++ b/libs/ai-endpoints/tests/unit_tests/test_register_model.py @@ -36,7 +36,7 @@ def test_mismatched_type_client(model_type: str, client: str) -> None: with pytest.raises(ValueError) as e: register_model( Model( - id=f"{model_type}-{client}", + id="model", model_type=model_type, client=client, endpoint="BOGUS", @@ -56,7 +56,7 @@ def test_duplicate_model_warns() -> None: assert "Overriding" in str(record[0].message) -def test_registered_model_usable(public_class: type) -> None: +def test_registered_model_usable(public_class: type, mock_model: str) -> None: model_type = { "ChatNVIDIA": "chat", "NVIDIAEmbeddings": "embedding", @@ -65,20 +65,19 @@ def test_registered_model_usable(public_class: type) -> None: }[public_class.__name__] with warnings.catch_warnings(): warnings.simplefilter("error") - id = f"registered-model-{model_type}" model = Model( - id=id, + id=mock_model, model_type=model_type, client=public_class.__name__, endpoint="BOGUS", ) register_model(model) - x = public_class(model=id, nvidia_api_key="a-bogus-key") - assert x.model == id + x = public_class(model=mock_model, nvidia_api_key="a-bogus-key") + assert x.model == mock_model def test_registered_model_without_client_usable(public_class: type) -> None: - id = f"test/no-client-{public_class.__name__}" + id = "test/no-client" model = Model(id=id, endpoint="BOGUS") register_model(model) with pytest.warns(UserWarning) as record: @@ -156,7 +155,7 @@ def test_registered_model_is_available() -> None: def test_registered_model_without_client_is_not_listed(public_class: type) -> None: - model_name = f"test/{public_class.__name__}" + model_name = "test/model" register_model(Model(id=model_name, endpoint="BOGUS")) models = public_class.get_available_models(api_key="BOGUS") # type: ignore assert model_name not in [model.id for model in models]