From e46d14bead6cbd97e066c5e910153ba83286b3c6 Mon Sep 17 00:00:00 2001 From: raspawar Date: Fri, 26 Jul 2024 16:46:15 +0530 Subject: [PATCH] value error, test cases fix --- .../langchain_nvidia_ai_endpoints/utils.py | 8 +++++--- .../tests/integration_tests/test_register_model.py | 13 ++++--------- libs/ai-endpoints/tests/unit_tests/test_model.py | 9 ++++++--- .../tests/unit_tests/test_register_model.py | 7 +------ 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py index b20c9cc2..ad83bbb7 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/utils.py @@ -46,10 +46,12 @@ def _validate_hosted_model_compatibility( Raises: ValueError: If the model is incompatible with the client. """ - if model.client != cls_name: + if not model.client: + warnings.warn(f"Unable to determine validity of {name}") + elif model.client != cls_name: raise ValueError( f"Model {name} is incompatible with client {cls_name}. " - "Please check `available_models`." + f"Please check `{cls_name}.get_available_models()`." ) @@ -110,7 +112,7 @@ def _validate_locally_hosted_model_compatibility( if model.client != cls_name: raise ValueError( f"Model {model_name} is incompatible with client {cls_name}. " - "Please check `available_models`." + f"Please check `{cls_name}.get_available_models()`." ) if model_name not in [model.id for model in client.available_models]: diff --git a/libs/ai-endpoints/tests/integration_tests/test_register_model.py b/libs/ai-endpoints/tests/integration_tests/test_register_model.py index fa302fa1..77e5aa50 100644 --- a/libs/ai-endpoints/tests/integration_tests/test_register_model.py +++ b/libs/ai-endpoints/tests/integration_tests/test_register_model.py @@ -16,40 +16,35 @@ # you will have to find the new ones from https://api.nvcf.nvidia.com/v2/nvcf/functions # @pytest.mark.parametrize( - "client, id, endpoint, model_type", + "client, id, endpoint", [ ( ChatNVIDIA, "meta/llama3-8b-instruct", "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/a5a3ad64-ec2c-4bfc-8ef7-5636f26630fe", - "chat", ), ( NVIDIAEmbeddings, "NV-Embed-QA", "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/09c64e32-2b65-4892-a285-2f585408d118", - "embedding", ), ( NVIDIARerank, "nv-rerank-qa-mistral-4b:1", "https://api.nvcf.nvidia.com/v2/nvcf/pexec/functions/0bf77f50-5c35-4488-8e7a-f49bb1974af6", - "ranking", ), ], ) def test_registered_model_functional( - client: type, id: str, endpoint: str, model_type: str, contact_service: Any + client: type, id: str, endpoint: str, contact_service: Any ) -> None: - model = Model( - id=id, endpoint=endpoint, client=client.__name__, model_type=model_type - ) + model = Model(id=id, endpoint=endpoint) with pytest.warns( UserWarning ) as record: # warns because we're overriding known models register_model(model) contact_service(client(model=id)) - assert len(record) == 1 + assert len(record) == 2 assert isinstance(record[0].message, UserWarning) assert "already registered" in str(record[0].message) assert "Overriding" in str(record[0].message) diff --git a/libs/ai-endpoints/tests/unit_tests/test_model.py b/libs/ai-endpoints/tests/unit_tests/test_model.py index 7633a128..5e2bcae4 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_model.py +++ b/libs/ai-endpoints/tests/unit_tests/test_model.py @@ -129,7 +129,7 @@ def test_known_unknown(public_class: type, known_unknown: str) -> None: assert "unknown" in record[0].message.args[0] -def test_unknown_unknown(public_class: type) -> None: +def test_unknown_unknown(public_class: type, empty_v1_models: None) -> None: """ Test that a model not in /v1/models and not in known model table will be rejected. @@ -170,13 +170,16 @@ def test_hosted_all_incompatible(public_class: type, model: str, client: str) -> """ msg = ( "Model {model_name} is incompatible with client {cls_name}. " - "Please check `available_models`." + "Please check `{cls_name}.get_available_models()`." ) if client != public_class.__name__: with pytest.raises(ValueError) as err_msg: public_class(model=model, nvidia_api_key="a-bogus-key") - assert err_msg == msg.format(model_name=model, cls_name=client) + + assert msg.format(model_name=model, cls_name=public_class.__name__) in str( + err_msg.value + ) @pytest.mark.parametrize( diff --git a/libs/ai-endpoints/tests/unit_tests/test_register_model.py b/libs/ai-endpoints/tests/unit_tests/test_register_model.py index 9c4d4514..5289b176 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_register_model.py +++ b/libs/ai-endpoints/tests/unit_tests/test_register_model.py @@ -64,18 +64,13 @@ def test_registered_model_usable(public_class: type) -> None: def test_registered_model_without_client_usable(public_class: type) -> None: id = f"test/no-client-{public_class.__name__}" - incompatible_err_msg = "Model {name} is incompatible with client {cls_name}. \ - Please check `available_models`." model = Model(id=id, endpoint="BOGUS") register_model(model) # todo: this should warn that the model is known but type is not # and therefore inference may not work # Marking this as failed - with pytest.raises(ValueError) as err_msg: + with pytest.warns(UserWarning): public_class(model=id, nvidia_api_key="a-bogus-key") - assert err_msg == incompatible_err_msg.format( - name=id, cls_name=public_class.__name__ - ) def test_missing_endpoint() -> None: