From 99c2427206fc76369072fe0c5925b8aaa74b82d0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 11 Jun 2024 12:37:10 +0000 Subject: [PATCH 1/6] add default model code --- .../langchain_nvidia_ai_endpoints/_common.py | 23 +++++++++++++++---- .../chat_models.py | 2 +- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 5b1e7ae0..b7017adf 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -107,6 +107,7 @@ class Config: description="Headers template must contain `call` and `stream` keys.", ) _available_models: Optional[List[Model]] = PrivateAttr(default=None) + _default_model: Optional[Model] = PrivateAttr(default=None) @classmethod def is_lc_serializable(cls) -> bool: @@ -156,6 +157,12 @@ def __add_authorization(self, payload: dict) -> dict: ) return payload + @property + def default_model(self) -> Optional[Model]: + if self._available_models is None: + self.available_models + return self._available_models[0] if self._available_models else None + @property def available_models(self) -> list[Model]: """List the available models that can be invoked.""" @@ -510,7 +517,7 @@ class _NVIDIAClient(BaseModel): client: NVEModel = Field(NVEModel) - model: str = Field(..., description="Name of the model to invoke") + model: Optional[str] = Field(..., description="Name of the model to invoke") is_hosted: bool = Field(True) #################################################################################### @@ -529,6 +536,17 @@ def _preprocess_args(cls, values: Any) -> Any: @root_validator def _postprocess_args(cls, values: Any) -> Any: + name = values.get("model") + if not name: + # set default model + name = values.get("client").default_model.id + values["model"] = name + warnings.warn( + f"Default model is set as: {name}. \n" + "Set model using model parameter. \n" + "To get available models use available_models property.", + UserWarning, + ) if values["is_hosted"]: if not values["client"].api_key: warnings.warn( @@ -536,8 +554,6 @@ def _postprocess_args(cls, values: Any) -> Any: "This will become an error in the future.", UserWarning, ) - - name = values.get("model") if model := determine_model(name): values["model"] = model.id # not all models are on https://integrate.api.nvidia.com/v1, @@ -558,7 +574,6 @@ def _postprocess_args(cls, values: Any) -> Any: raise ValueError( f"Model {name} is unknown, check `available_models`" ) - return values @classmethod diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index eba47749..2481e351 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -136,7 +136,7 @@ class ChatNVIDIA(BaseChatModel): "https://integrate.api.nvidia.com/v1", description="Base url for model listing an invocation", ) - model: str = Field(_default_model, description="Name of the model to invoke") + model: Optional[str] = Field(..., description="Name of the model to invoke") temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]") max_tokens: Optional[int] = Field( 1024, description="Maximum # of tokens to generate" From 21fde67b12a56859de63de60bcf771a86441f9e1 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 11 Jun 2024 12:47:48 +0000 Subject: [PATCH 2/6] change default model to local mode only --- .../langchain_nvidia_ai_endpoints/_common.py | 22 ++++++++++--------- .../chat_models.py | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index b7017adf..17fbb1c8 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -537,16 +537,6 @@ def _preprocess_args(cls, values: Any) -> Any: @root_validator def _postprocess_args(cls, values: Any) -> Any: name = values.get("model") - if not name: - # set default model - name = values.get("client").default_model.id - values["model"] = name - warnings.warn( - f"Default model is set as: {name}. \n" - "Set model using model parameter. \n" - "To get available models use available_models property.", - UserWarning, - ) if values["is_hosted"]: if not values["client"].api_key: warnings.warn( @@ -574,6 +564,18 @@ def _postprocess_args(cls, values: Any) -> Any: raise ValueError( f"Model {name} is unknown, check `available_models`" ) + else: + if not name: + # set default model + name = values.get("client").default_model.id + values["model"] = name + warnings.warn( + f"Default model is set as: {name}. \n" + "Set model using model parameter. \n" + "To get available models use available_models property.", + UserWarning, + ) + return values @classmethod diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 2481e351..445a2edf 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -136,7 +136,7 @@ class ChatNVIDIA(BaseChatModel): "https://integrate.api.nvidia.com/v1", description="Base url for model listing an invocation", ) - model: Optional[str] = Field(..., description="Name of the model to invoke") + model: Optional[str] = Field(description="Name of the model to invoke") temperature: Optional[float] = Field(description="Sampling temperature in [0, 1]") max_tokens: Optional[int] = Field( 1024, description="Maximum # of tokens to generate" From 7d7fceb5a0b35ff3c212877cb2d7aa8ce495d1d8 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 24 Jun 2024 13:36:54 +0000 Subject: [PATCH 3/6] test cases, other models --- .../langchain_nvidia_ai_endpoints/_common.py | 29 ++++++++---- .../langchain_nvidia_ai_endpoints/_statics.py | 1 + .../chat_models.py | 1 + .../embeddings.py | 3 +- .../reranking.py | 5 +-- .../tests/unit_tests/test_api_key.py | 4 +- .../tests/unit_tests/test_base_url.py | 6 ++- .../tests/unit_tests/test_model.py | 44 +++++++++++++++++++ 8 files changed, 77 insertions(+), 16 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 17fbb1c8..551445a6 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -107,7 +107,6 @@ class Config: description="Headers template must contain `call` and `stream` keys.", ) _available_models: Optional[List[Model]] = PrivateAttr(default=None) - _default_model: Optional[Model] = PrivateAttr(default=None) @classmethod def is_lc_serializable(cls) -> bool: @@ -157,12 +156,6 @@ def __add_authorization(self, payload: dict) -> dict: ) return payload - @property - def default_model(self) -> Optional[Model]: - if self._available_models is None: - self.available_models - return self._available_models[0] if self._available_models else None - @property def available_models(self) -> list[Model]: """List the available models that can be invoked.""" @@ -192,6 +185,13 @@ def available_models(self) -> list[Model]: # so we'll let it through. use of this model will be # accompanied by a warning. model = Model(id=element["id"]) + + # add base model for local-nim mode + model.base_model = element.get("root") + + if model.base_model and model.id != model.base_model: + model.model_type = "lora" + self._available_models.append(model) return self._available_models @@ -532,6 +532,10 @@ def _preprocess_args(cls, values: Any) -> Any: "ai.api.nvidia.com", ] + # set default model for hosted endpoint + if values["is_hosted"] and not values["model"]: + values["model"] = values["default_model"] + return values @root_validator @@ -567,7 +571,7 @@ def _postprocess_args(cls, values: Any) -> Any: else: if not name: # set default model - name = values.get("client").default_model.id + name = values.get("client").available_models[0].id values["model"] = name warnings.warn( f"Default model is set as: {name}. \n" @@ -575,7 +579,7 @@ def _postprocess_args(cls, values: Any) -> Any: "To get available models use available_models property.", UserWarning, ) - + return values @classmethod @@ -603,6 +607,13 @@ def get_available_models( **kwargs: Any, ) -> List[Model]: """Retrieve a list of available models.""" + + # set client for lora models in local-nim mode + if not self.is_hosted: + for model in self.client.available_models: + if model.model_type == "lora": + model.client = filter + available = [ model for model in self.client.available_models if model.client == filter ] diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py index 5703d4f0..01e506bf 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_statics.py @@ -20,6 +20,7 @@ class Model(BaseModel): client: Optional[str] = None endpoint: Optional[str] = None aliases: Optional[list] = None + base_model: Optional[str] = None def __hash__(self) -> int: return hash(self.id) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py index 445a2edf..51a9be59 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py @@ -173,6 +173,7 @@ def __init__(self, **kwargs: Any): self._client = _NVIDIAClient( base_url=self.base_url, model=self.model, + default_model=self._default_model, api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)), infer_path="{base_url}/chat/completions", ) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py index 767c5402..a151c5dd 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py @@ -33,7 +33,7 @@ class Config: "https://integrate.api.nvidia.com/v1", description="Base url for model listing an invocation", ) - model: str = Field(_default_model, description="Name of the model to invoke") + model: Optional[str] = Field(description="Name of the model to invoke") truncate: Literal["NONE", "START", "END"] = Field( default="NONE", description=( @@ -72,6 +72,7 @@ def __init__(self, **kwargs: Any): self._client = _NVIDIAClient( base_url=self.base_url, model=self.model, + default_model=self._default_model, api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)), infer_path="{base_url}/embeddings", ) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py index bde7c844..17e5dd01 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py @@ -35,9 +35,7 @@ class Config: description="Base url for model listing an invocation", ) top_n: int = Field(5, ge=0, description="The number of documents to return.") - model: str = Field( - _default_model_name, description="The model to use for reranking." - ) + model: Optional[str] = Field(description="The model to use for reranking.") max_batch_size: int = Field( _default_batch_size, ge=1, description="The maximum batch size." ) @@ -65,6 +63,7 @@ def __init__(self, **kwargs: Any): self._client = _NVIDIAClient( base_url=self.base_url, model=self.model, + default_model=self._default_model_name, api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)), infer_path="{base_url}/ranking", ) diff --git a/libs/ai-endpoints/tests/unit_tests/test_api_key.py b/libs/ai-endpoints/tests/unit_tests/test_api_key.py index c1d23324..e9b9c148 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_api_key.py +++ b/libs/ai-endpoints/tests/unit_tests/test_api_key.py @@ -3,6 +3,7 @@ from typing import Any, Generator import pytest +import requests from langchain_core.pydantic_v1 import SecretStr @@ -25,7 +26,8 @@ def test_create_without_api_key(public_class: type) -> None: def test_create_unknown_url_no_api_key(public_class: type) -> None: with no_env_var("NVIDIA_API_KEY"): - public_class(base_url="https://test_url/v1") + with pytest.raises(requests.exceptions.ConnectionError): + public_class(base_url="https://test_url/v1") @pytest.mark.parametrize("param", ["nvidia_api_key", "api_key"]) diff --git a/libs/ai-endpoints/tests/unit_tests/test_base_url.py b/libs/ai-endpoints/tests/unit_tests/test_base_url.py index f5ee84d6..508611ce 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_base_url.py +++ b/libs/ai-endpoints/tests/unit_tests/test_base_url.py @@ -1,4 +1,5 @@ import pytest +import requests @pytest.mark.parametrize( @@ -33,5 +34,6 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None: ], ) def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None: - client = public_class(base_url=base_url) - assert not client._client.is_hosted + with pytest.raises(requests.exceptions.ConnectionError): + client = public_class(base_url=base_url) + assert not client._client.is_hosted diff --git a/libs/ai-endpoints/tests/unit_tests/test_model.py b/libs/ai-endpoints/tests/unit_tests/test_model.py index d67977a9..4f06417b 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_model.py +++ b/libs/ai-endpoints/tests/unit_tests/test_model.py @@ -26,6 +26,31 @@ def mock_v1_models(requests_mock: Mocker, known_unknown: str) -> None: ) +@pytest.fixture(autouse=True) +def mock_v1_local_models(requests_mock: Mocker, known_unknown: str) -> None: + requests_mock.get( + "http://localhost:8000/v1/models", + json={ + "data": [ + { + "id": known_unknown, + "object": "model", + "created": 1234567890, + "owned_by": "OWNER", + "root": known_unknown, + }, + { + "id": "lora1", + "object": "model", + "created": 1234567890, + "owned_by": "OWNER", + "root": known_unknown, + }, + ] + }, + ) + + @pytest.mark.parametrize( "alias", [ @@ -84,3 +109,22 @@ def test_unknown_unknown(public_class: type) -> None: with pytest.raises(ValueError) as e: public_class(model="test/unknown-unknown", nvidia_api_key="a-bogus-key") assert "unknown" in str(e.value) + + +def test_default_known(public_class: type, known_unknown: str) -> None: + """ + Test that a model in the model table will be accepted. + """ + # check if default model is getting set + with pytest.warns(UserWarning): + x = public_class(base_url="http://localhost:8000/v1") + assert x.model == known_unknown + + +def test_default_lora(public_class: type) -> None: + """ + Test that a model in the model table will be accepted. + """ + # find a model that matches the public_class under test + x = public_class(base_url="http://localhost:8000/v1", model="lora1") + assert x.model == "lora1" From d553c9d9b06c53fba0fbe4364ac3f9ecea9a21dc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 24 Jun 2024 13:46:44 +0000 Subject: [PATCH 4/6] notebook update --- .../docs/chat/nvidia_ai_endpoints.ipynb | 22 +++++++++++++++ .../text_embedding/nvidia_ai_endpoints.ipynb | 27 ++++++++++++++++--- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb index d33970d2..e0eb5f55 100644 --- a/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb +++ b/libs/ai-endpoints/docs/chat/nvidia_ai_endpoints.ipynb @@ -137,6 +137,28 @@ "llm = ChatNVIDIA(base_url=\"http://localhost:8000/v1\", model=\"meta/llama3-8b-instruct\")" ] }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7d4a4e2e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/raspawar/langchain-nvidia/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py:583: UserWarning: Default model is set as: meta/llama3-8b-instruct. \n", + "Set model using model parameter. \n", + "To get available models use available_models property.\n", + " UserWarning,\n" + ] + } + ], + "source": [ + "# OR connect to an embedding NIM running at localhost:8000, with default model(first available model)\n", + "llm = ChatNVIDIA(base_url=\"http://localhost:8000/v1\")" + ] + }, { "cell_type": "markdown", "id": "71d37987-d568-4a73-9d2a-8bd86323f8bf", diff --git a/libs/ai-endpoints/docs/text_embedding/nvidia_ai_endpoints.ipynb b/libs/ai-endpoints/docs/text_embedding/nvidia_ai_endpoints.ipynb index 233b21c0..8fdc43d6 100644 --- a/libs/ai-endpoints/docs/text_embedding/nvidia_ai_endpoints.ipynb +++ b/libs/ai-endpoints/docs/text_embedding/nvidia_ai_endpoints.ipynb @@ -143,14 +143,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings\n", "\n", "# connect to an embedding NIM running at localhost:8080\n", - "embedder = NVIDIAEmbeddings(base_url=\"http://localhost:8080/v1\")" + "embedder = NVIDIAEmbeddings(base_url=\"http://localhost:9080/v1\", model=\"NV-Embed-QA\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ubuntu/raspawar/langchain-nvidia/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py:579: UserWarning: Default model is set as: NV-Embed-QA. \n", + "Set model using model parameter. \n", + "To get available models use available_models property.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "# connect to an default embedding NIM running at localhost:8080\n", + "embedder = NVIDIAEmbeddings(base_url=\"http://localhost:9080/v1\")" ] }, { @@ -521,7 +542,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.12" } }, "nbformat": 4, From 566011956f030a63118c23a794cac26295e4fea2 Mon Sep 17 00:00:00 2001 From: raspawar Date: Tue, 25 Jun 2024 13:30:36 +0000 Subject: [PATCH 5/6] fix review changes --- .../langchain_nvidia_ai_endpoints/_common.py | 16 +++++------- .../tests/unit_tests/test_api_key.py | 23 ++++++++++++++--- .../tests/unit_tests/test_base_url.py | 25 ++++++++++++++++--- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 551445a6..177ae65f 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -189,9 +189,6 @@ def available_models(self) -> list[Model]: # add base model for local-nim mode model.base_model = element.get("root") - if model.base_model and model.id != model.base_model: - model.model_type = "lora" - self._available_models.append(model) return self._available_models @@ -608,14 +605,13 @@ def get_available_models( ) -> List[Model]: """Retrieve a list of available models.""" - # set client for lora models in local-nim mode - if not self.is_hosted: - for model in self.client.available_models: - if model.model_type == "lora": - model.client = filter - available = [ - model for model in self.client.available_models if model.client == filter + model + for model in self.client.available_models + if ( + model.client == filter + or (model.base_model and model.base_model != model.id) + ) ] # if we're talking to a hosted endpoint, we mix in the known models diff --git a/libs/ai-endpoints/tests/unit_tests/test_api_key.py b/libs/ai-endpoints/tests/unit_tests/test_api_key.py index e9b9c148..c79b0017 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_api_key.py +++ b/libs/ai-endpoints/tests/unit_tests/test_api_key.py @@ -3,8 +3,8 @@ from typing import Any, Generator import pytest -import requests from langchain_core.pydantic_v1 import SecretStr +from requests_mock import Mocker @contextmanager @@ -18,6 +18,24 @@ def no_env_var(var: str) -> Generator[None, None, None]: os.environ[var] = val +@pytest.fixture(autouse=True) +def mock_v1_local_models(requests_mock: Mocker) -> None: + requests_mock.get( + "https://test_url/v1/models", + json={ + "data": [ + { + "id": "model1", + "object": "model", + "created": 1234567890, + "owned_by": "OWNER", + "root": "model1", + }, + ] + }, + ) + + def test_create_without_api_key(public_class: type) -> None: with no_env_var("NVIDIA_API_KEY"): with pytest.warns(UserWarning): @@ -26,8 +44,7 @@ def test_create_without_api_key(public_class: type) -> None: def test_create_unknown_url_no_api_key(public_class: type) -> None: with no_env_var("NVIDIA_API_KEY"): - with pytest.raises(requests.exceptions.ConnectionError): - public_class(base_url="https://test_url/v1") + public_class(base_url="https://test_url/v1") @pytest.mark.parametrize("param", ["nvidia_api_key", "api_key"]) diff --git a/libs/ai-endpoints/tests/unit_tests/test_base_url.py b/libs/ai-endpoints/tests/unit_tests/test_base_url.py index 508611ce..ee48ec50 100644 --- a/libs/ai-endpoints/tests/unit_tests/test_base_url.py +++ b/libs/ai-endpoints/tests/unit_tests/test_base_url.py @@ -1,5 +1,5 @@ import pytest -import requests +from requests_mock import Mocker @pytest.mark.parametrize( @@ -25,6 +25,24 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None: assert client._client.is_hosted +@pytest.fixture(autouse=True) +def mock_v1_local_models(requests_mock: Mocker, base_url: str) -> None: + requests_mock.get( + f"{base_url}/models", + json={ + "data": [ + { + "id": "model1", + "object": "model", + "created": 1234567890, + "owned_by": "OWNER", + "root": "model1", + }, + ] + }, + ) + + @pytest.mark.parametrize( "base_url", [ @@ -34,6 +52,5 @@ def test_param_base_url_hosted(public_class: type, base_url: str) -> None: ], ) def test_param_base_url_not_hosted(public_class: type, base_url: str) -> None: - with pytest.raises(requests.exceptions.ConnectionError): - client = public_class(base_url=base_url) - assert not client._client.is_hosted + client = public_class(base_url=base_url) + assert not client._client.is_hosted From af78e07cb1948a0878b5e792fc9ed93c45e127d6 Mon Sep 17 00:00:00 2001 From: raspawar Date: Tue, 25 Jun 2024 14:43:07 +0000 Subject: [PATCH 6/6] fix for available_models filtering --- .../langchain_nvidia_ai_endpoints/_common.py | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py index 177ae65f..f79b93f6 100644 --- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py +++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py @@ -566,17 +566,29 @@ def _postprocess_args(cls, values: Any) -> Any: f"Model {name} is unknown, check `available_models`" ) else: + # set default model if not name: - # set default model - name = values.get("client").available_models[0].id - values["model"] = name - warnings.warn( - f"Default model is set as: {name}. \n" - "Set model using model parameter. \n" - "To get available models use available_models property.", - UserWarning, - ) - + if not (client := values.get("client")): + warnings.warn(f"Unable to determine validity of {name}") + else: + valid_models = [ + model.id + for model in client.available_models + if model.base_model and model.id == model.base_model + ] + name = next(iter(valid_models), None) + if name: + warnings.warn( + f"Default model is set as: {name}. \n" + "Set model using model parameter. \n" + "To get available models use available_models property.", + UserWarning, + ) + values["model"] = name + else: + raise ValueError( + f"Model {name} is unknown, check `available_models`" + ) return values @classmethod @@ -605,23 +617,16 @@ def get_available_models( ) -> List[Model]: """Retrieve a list of available models.""" - available = [ - model - for model in self.client.available_models - if ( - model.client == filter - or (model.base_model and model.base_model != model.id) - ) - ] + available = self.client.available_models # if we're talking to a hosted endpoint, we mix in the known models # because they are not all discoverable by listing. for instance, # the NV-Embed-QA and VLM models are hosted on ai.api.nvidia.com # instead of integrate.api.nvidia.com. if self.is_hosted: - known = set( - model for model in MODEL_TABLE.values() if model.client == filter - ) - available = list(set(available) | known) + known = set(MODEL_TABLE.values()) + available = [ + model for model in set(available) | known if model.client == filter + ] return available