Skip to content

Commit

Permalink
value error, test cases fix
Browse files Browse the repository at this point in the history
  • Loading branch information
raspawar committed Jul 26, 2024
1 parent 7306a67 commit e46d14b
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 21 deletions.
8 changes: 5 additions & 3 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ def _validate_hosted_model_compatibility(
Raises:
ValueError: If the model is incompatible with the client.
"""
if model.client != cls_name:
if not model.client:
warnings.warn(f"Unable to determine validity of {name}")
elif model.client != cls_name:
raise ValueError(
f"Model {name} is incompatible with client {cls_name}. "
"Please check `available_models`."
f"Please check `{cls_name}.get_available_models()`."
)


Expand Down Expand Up @@ -110,7 +112,7 @@ def _validate_locally_hosted_model_compatibility(
if model.client != cls_name:
raise ValueError(
f"Model {model_name} is incompatible with client {cls_name}. "
"Please check `available_models`."
f"Please check `{cls_name}.get_available_models()`."
)

if model_name not in [model.id for model in client.available_models]:
Expand Down
13 changes: 4 additions & 9 deletions libs/ai-endpoints/tests/integration_tests/test_register_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,35 @@
# you will have to find the new ones from https://api.nvcf.nvidia.com/v2/nvcf/functions
#
@pytest.mark.parametrize(
"client, id, endpoint, model_type",
"client, id, endpoint",
[
(
ChatNVIDIA,
"meta/llama3-8b-instruct",
"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/a5a3ad64-ec2c-4bfc-8ef7-5636f26630fe",
"chat",
),
(
NVIDIAEmbeddings,
"NV-Embed-QA",
"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/09c64e32-2b65-4892-a285-2f585408d118",
"embedding",
),
(
NVIDIARerank,
"nv-rerank-qa-mistral-4b:1",
"https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/0bf77f50-5c35-4488-8e7a-f49bb1974af6",
"ranking",
),
],
)
def test_registered_model_functional(
client: type, id: str, endpoint: str, model_type: str, contact_service: Any
client: type, id: str, endpoint: str, contact_service: Any
) -> None:
model = Model(
id=id, endpoint=endpoint, client=client.__name__, model_type=model_type
)
model = Model(id=id, endpoint=endpoint)
with pytest.warns(
UserWarning
) as record: # warns because we're overriding known models
register_model(model)
contact_service(client(model=id))
assert len(record) == 1
assert len(record) == 2
assert isinstance(record[0].message, UserWarning)
assert "already registered" in str(record[0].message)
assert "Overriding" in str(record[0].message)
Expand Down
9 changes: 6 additions & 3 deletions libs/ai-endpoints/tests/unit_tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_known_unknown(public_class: type, known_unknown: str) -> None:
assert "unknown" in record[0].message.args[0]


def test_unknown_unknown(public_class: type) -> None:
def test_unknown_unknown(public_class: type, empty_v1_models: None) -> None:
"""
Test that a model not in /v1/models and not in known model table will be
rejected.
Expand Down Expand Up @@ -170,13 +170,16 @@ def test_hosted_all_incompatible(public_class: type, model: str, client: str) ->
"""
msg = (
"Model {model_name} is incompatible with client {cls_name}. "
"Please check `available_models`."
"Please check `{cls_name}.get_available_models()`."
)

if client != public_class.__name__:
with pytest.raises(ValueError) as err_msg:
public_class(model=model, nvidia_api_key="a-bogus-key")
assert err_msg == msg.format(model_name=model, cls_name=client)

assert msg.format(model_name=model, cls_name=public_class.__name__) in str(
err_msg.value
)


@pytest.mark.parametrize(
Expand Down
7 changes: 1 addition & 6 deletions libs/ai-endpoints/tests/unit_tests/test_register_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,13 @@ def test_registered_model_usable(public_class: type) -> None:

def test_registered_model_without_client_usable(public_class: type) -> None:
id = f"test/no-client-{public_class.__name__}"
incompatible_err_msg = "Model {name} is incompatible with client {cls_name}. \
Please check `available_models`."
model = Model(id=id, endpoint="BOGUS")
register_model(model)
# todo: this should warn that the model is known but type is not
# and therefore inference may not work
# Marking this as failed
with pytest.raises(ValueError) as err_msg:
with pytest.warns(UserWarning):
public_class(model=id, nvidia_api_key="a-bogus-key")
assert err_msg == incompatible_err_msg.format(
name=id, cls_name=public_class.__name__
)


def test_missing_endpoint() -> None:
Expand Down

0 comments on commit e46d14b

Please sign in to comment.