Skip to content

Commit

Permalink
reset MODEL_TABLE between all unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattf committed Sep 24, 2024
1 parent f64f520 commit e3af181
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 25 deletions.
14 changes: 13 additions & 1 deletion libs/ai-endpoints/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Callable, List
from typing import Callable, Generator, List

import pytest
import requests_mock
Expand All @@ -10,6 +10,7 @@
NVIDIAEmbeddings,
NVIDIARerank,
)
from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE


@pytest.fixture(
Expand Down Expand Up @@ -46,6 +47,17 @@ def mock_v1_models(requests_mock: requests_mock.Mocker, mock_model: str) -> None
)


@pytest.fixture(autouse=True)
def reset_model_table() -> Generator[None, None, None]:
"""
Reset MODEL_TABLE between tests.
"""
original = MODEL_TABLE.copy()
yield
MODEL_TABLE.clear()
MODEL_TABLE.update(original)


@pytest.fixture
def mock_streaming_response(
requests_mock: requests_mock.Mocker, mock_model: str
Expand Down
19 changes: 3 additions & 16 deletions libs/ai-endpoints/tests/unit_tests/test_available_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import warnings
from typing import Any, Generator

import pytest
from typing import Any

from langchain_nvidia_ai_endpoints import Model, register_model
from langchain_nvidia_ai_endpoints._statics import MODEL_TABLE


def test_model_listing(public_class: Any, mock_model: str) -> None:
Expand All @@ -14,19 +11,9 @@ def test_model_listing(public_class: Any, mock_model: str) -> None:
assert any(model.id == mock_model for model in models)


@pytest.fixture
def model_table() -> Generator[None, None, None]:
"""
Reset MODEL_TABLE between tests.
"""
original = MODEL_TABLE.copy()
yield
MODEL_TABLE.clear()
MODEL_TABLE.update(original)


def test_model_listing_hosted(
public_class: Any, mock_model: str, model_table: None
public_class: Any,
mock_model: str,
) -> None:
model = Model(
id=mock_model,
Expand Down
15 changes: 7 additions & 8 deletions libs/ai-endpoints/tests/unit_tests/test_register_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_mismatched_type_client(model_type: str, client: str) -> None:
with pytest.raises(ValueError) as e:
register_model(
Model(
id=f"{model_type}-{client}",
id="model",
model_type=model_type,
client=client,
endpoint="BOGUS",
Expand All @@ -56,7 +56,7 @@ def test_duplicate_model_warns() -> None:
assert "Overriding" in str(record[0].message)


def test_registered_model_usable(public_class: type) -> None:
def test_registered_model_usable(public_class: type, mock_model: str) -> None:
model_type = {
"ChatNVIDIA": "chat",
"NVIDIAEmbeddings": "embedding",
Expand All @@ -65,20 +65,19 @@ def test_registered_model_usable(public_class: type) -> None:
}[public_class.__name__]
with warnings.catch_warnings():
warnings.simplefilter("error")
id = f"registered-model-{model_type}"
model = Model(
id=id,
id=mock_model,
model_type=model_type,
client=public_class.__name__,
endpoint="BOGUS",
)
register_model(model)
x = public_class(model=id, nvidia_api_key="a-bogus-key")
assert x.model == id
x = public_class(model=mock_model, nvidia_api_key="a-bogus-key")
assert x.model == mock_model


def test_registered_model_without_client_usable(public_class: type) -> None:
id = f"test/no-client-{public_class.__name__}"
id = "test/no-client"
model = Model(id=id, endpoint="BOGUS")
register_model(model)
with pytest.warns(UserWarning) as record:
Expand Down Expand Up @@ -156,7 +155,7 @@ def test_registered_model_is_available() -> None:


def test_registered_model_without_client_is_not_listed(public_class: type) -> None:
model_name = f"test/{public_class.__name__}"
model_name = "test/model"
register_model(Model(id=model_name, endpoint="BOGUS"))
models = public_class.get_available_models(api_key="BOGUS") # type: ignore
assert model_name not in [model.id for model in models]

0 comments on commit e3af181

Please sign in to comment.