diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py
index 3e911e4f4..1f54defa9 100644
--- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py
+++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py
@@ -5,7 +5,7 @@
 from haystack.utils import Secret, deserialize_secrets_inplace
 from tqdm import tqdm

-from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation
+from haystack_integrations.utils.nvidia import Model, NimBackend, is_hosted, url_validation, validate_hosted_model

 from .truncate import EmbeddingTruncateMode
@@ -98,7 +98,7 @@ def __init__(
     def default_model(self):
         """Set default model in local NIM mode."""
         valid_models = [
-            model.id for model in self.backend.models() if not model.base_model or model.base_model == model.id
+            model.id for model in self.available_models if not model.base_model or model.base_model == model.id
         ]
         name = next(iter(valid_models), None)
         if name:
@@ -129,12 +129,15 @@ def warm_up(self):
             api_url=self.api_url,
             api_key=self.api_key,
             model_kwargs=model_kwargs,
+            client=self.__class__.__name__,
+            model_type="embedding",
         )
         self._initialized = True

         if not self.model:
             self.default_model()
+        validate_hosted_model(self.__class__.__name__, self.model, self)

     def to_dict(self) -> Dict[str, Any]:
         """
@@ -157,6 +160,13 @@ def to_dict(self) -> Dict[str, Any]:
             truncate=str(self.truncate) if self.truncate is not None else None,
         )

+    @property
+    def available_models(self) -> List[Model]:
+        """
+        Get a list of available models that work with NvidiaDocumentEmbedder.
+        """
+        return self.backend.models() if self.backend else []
+
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "NvidiaDocumentEmbedder":
         """
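With this change, `warm_up()` validates the configured model against the known-model table before any inference request is made. A minimal sketch of the new behavior, assuming hosted mode and a valid `NVIDIA_API_KEY` in the environment (model ids taken from the tables introduced below):

```python
from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder

# A known embedding model passes validation during warm_up().
embedder = NvidiaDocumentEmbedder(model="nvidia/nv-embedqa-e5-v5")
embedder.warm_up()

# A known chat model is incompatible with an embedder client, so warm_up() raises:
# ValueError: Model meta/llama3-8b-instruct is incompatible with client NvidiaDocumentEmbedder. ...
bad = NvidiaDocumentEmbedder(model="meta/llama3-8b-instruct")
bad.warm_up()
```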
diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py
index 0387c32b7..7316f42a9 100644
--- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py
+++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py
@@ -4,7 +4,7 @@
 from haystack import component, default_from_dict, default_to_dict
 from haystack.utils import Secret, deserialize_secrets_inplace

-from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation
+from haystack_integrations.utils.nvidia import Model, NimBackend, is_hosted, url_validation, validate_hosted_model

 from .truncate import EmbeddingTruncateMode
@@ -82,7 +82,7 @@ def __init__(
     def default_model(self):
         """Set default model in local NIM mode."""
         valid_models = [
-            model.id for model in self.backend.models() if not model.base_model or model.base_model == model.id
+            model.id for model in self.available_models if not model.base_model or model.base_model == model.id
         ]
         name = next(iter(valid_models), None)
         if name:
@@ -113,12 +113,15 @@ def warm_up(self):
             api_url=self.api_url,
             api_key=self.api_key,
             model_kwargs=model_kwargs,
+            client=self.__class__.__name__,
+            model_type="embedding",
         )
         self._initialized = True

         if not self.model:
             self.default_model()
+        validate_hosted_model(self.__class__.__name__, self.model, self)

     def to_dict(self) -> Dict[str, Any]:
         """
@@ -137,6 +140,13 @@ def to_dict(self) -> Dict[str, Any]:
             truncate=str(self.truncate) if self.truncate is not None else None,
         )

+    @property
+    def available_models(self) -> List[Model]:
+        """
+        Get a list of available models that work with NvidiaTextEmbedder.
+        """
+        return self.backend.models() if self.backend else []
+
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "NvidiaTextEmbedder":
         """
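The new `available_models` property is a None-safe wrapper over the backend, so it can be queried both before and after `warm_up()`. A short usage sketch (output illustrative; real ids depend on the endpoint you are connected to):

```python
from haystack_integrations.components.embedders.nvidia import NvidiaTextEmbedder

embedder = NvidiaTextEmbedder(model="nvidia/nv-embed-v1")
assert embedder.available_models == []  # no backend exists until warm_up()

embedder.warm_up()
print([m.id for m in embedder.available_models])  # e.g. ["nvidia/nv-embed-v1", ...]
```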
+ """ + return self._backend.models() if self._backend else [] + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "NvidiaGenerator": """ diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py index da301d29d..11f70dba5 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py @@ -1,4 +1,5 @@ -from .nim_backend import Model, NimBackend -from .utils import is_hosted, url_validation +from .nim_backend import NimBackend +from .statics import Model +from .utils import determine_model, is_hosted, url_validation, validate_hosted_model -__all__ = ["NimBackend", "Model", "is_hosted", "url_validation"] +__all__ = ["NimBackend", "Model", "is_hosted", "url_validation", "validate_hosted_model", "determine_model"] diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index 0d1f57e5c..0387b47f0 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -1,27 +1,12 @@ -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple import requests from haystack import Document from haystack.utils import Secret -REQUEST_TIMEOUT = 60 - - -@dataclass -class Model: - """ - Model information. +from .statics import Model - id: unique identifier for the model, passed as model parameter for requests - aliases: list of aliases for the model - base_model: root model for the model - All aliases are deprecated and will trigger a warning when used. - """ - - id: str - aliases: Optional[List[str]] = field(default_factory=list) - base_model: Optional[str] = None +REQUEST_TIMEOUT = 60 class NimBackend: @@ -31,6 +16,8 @@ def __init__( api_url: str, api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), model_kwargs: Optional[Dict[str, Any]] = None, + client: Optional[Literal["NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder"]] = None, + model_type: Optional[Literal["chat", "embedding"]] = None, ): headers = { "Content-Type": "application/json", @@ -46,6 +33,8 @@ def __init__( self.model = model self.api_url = api_url self.model_kwargs = model_kwargs or {} + self.client = client + self.model_type = model_type def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: url = f"{self.api_url}/embeddings" @@ -125,7 +114,11 @@ def models(self) -> List[Model]: res.raise_for_status() data = res.json()["data"] - models = [Model(element["id"]) for element in data if "id" in element] + models = [ + Model(id=element["id"], client=self.client, model_type=self.model_type, base_model=element.get("root")) + for element in data + if "id" in element + ] if not models: msg = f"No hosted model were found at URL '{url}'." 
             raise ValueError(msg)
diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py
index 0d1f57e5c..0387b47f0 100644
--- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py
+++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py
@@ -1,27 +1,12 @@
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Literal, Optional, Tuple

 import requests
 from haystack import Document
 from haystack.utils import Secret

-REQUEST_TIMEOUT = 60
-
-
-@dataclass
-class Model:
-    """
-    Model information.
+from .statics import Model

-    id: unique identifier for the model, passed as model parameter for requests
-    aliases: list of aliases for the model
-    base_model: root model for the model
-
-    All aliases are deprecated and will trigger a warning when used.
-    """
-
-    id: str
-    aliases: Optional[List[str]] = field(default_factory=list)
-    base_model: Optional[str] = None
+REQUEST_TIMEOUT = 60


 class NimBackend:
@@ -31,6 +16,8 @@ def __init__(
         api_url: str,
         api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"),
         model_kwargs: Optional[Dict[str, Any]] = None,
+        client: Optional[Literal["NvidiaGenerator", "NvidiaTextEmbedder", "NvidiaDocumentEmbedder"]] = None,
+        model_type: Optional[Literal["chat", "embedding"]] = None,
     ):
         headers = {
             "Content-Type": "application/json",
@@ -46,6 +33,8 @@ def __init__(
         self.model = model
         self.api_url = api_url
         self.model_kwargs = model_kwargs or {}
+        self.client = client
+        self.model_type = model_type

     def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]:
         url = f"{self.api_url}/embeddings"
@@ -125,7 +114,11 @@ def models(self) -> List[Model]:
         res.raise_for_status()

         data = res.json()["data"]
-        models = [Model(element["id"]) for element in data if "id" in element]
+        models = [
+            Model(id=element["id"], client=self.client, model_type=self.model_type, base_model=element.get("root"))
+            for element in data
+            if "id" in element
+        ]
         if not models:
             msg = f"No hosted model were found at URL '{url}'."
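`NimBackend.models()` now stamps each returned `Model` with the client name and model type the backend was constructed with, so `validate_hosted_model` can reason about compatibility later. A sketch of what one element of the `/models` payload turns into (payload content is illustrative):

```python
from haystack_integrations.utils.nvidia import Model

# Hypothetical payload element, mirroring the two fields the backend reads.
element = {"id": "nvidia/nv-embedqa-e5-v5", "root": "nvidia/nv-embedqa-e5-v5"}

model = Model(
    id=element["id"],
    client="NvidiaTextEmbedder",  # forwarded from the new `client` constructor argument
    model_type="embedding",       # forwarded from the new `model_type` constructor argument
    base_model=element.get("root"),
)
assert model.base_model == model.id
```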
client="NvidiaGenerator", + aliases=["ai-mistral-large"], + ), + "mistralai/mixtral-8x22b-instruct-v0.1": Model( + id="mistralai/mixtral-8x22b-instruct-v0.1", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-mixtral-8x22b-instruct"], + ), + "meta/llama3-8b-instruct": Model( + id="meta/llama3-8b-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-llama3-8b"], + ), + "meta/llama3-70b-instruct": Model( + id="meta/llama3-70b-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-llama3-70b"], + ), + "microsoft/phi-3-mini-128k-instruct": Model( + id="microsoft/phi-3-mini-128k-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-phi-3-mini"], + ), + "snowflake/arctic": Model( + id="snowflake/arctic", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-arctic"], + ), + "databricks/dbrx-instruct": Model( + id="databricks/dbrx-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-dbrx-instruct"], + ), + "microsoft/phi-3-mini-4k-instruct": Model( + id="microsoft/phi-3-mini-4k-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-phi-3-mini-4k", "playground_phi2", "phi2"], + ), + "seallms/seallm-7b-v2.5": Model( + id="seallms/seallm-7b-v2.5", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-seallm-7b"], + ), + "aisingapore/sea-lion-7b-instruct": Model( + id="aisingapore/sea-lion-7b-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-sea-lion-7b-instruct"], + ), + "microsoft/phi-3-small-8k-instruct": Model( + id="microsoft/phi-3-small-8k-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-phi-3-small-8k-instruct"], + ), + "microsoft/phi-3-small-128k-instruct": Model( + id="microsoft/phi-3-small-128k-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-phi-3-small-128k-instruct"], + ), + "microsoft/phi-3-medium-4k-instruct": Model( + id="microsoft/phi-3-medium-4k-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-phi-3-medium-4k-instruct"], + ), + "ibm/granite-8b-code-instruct": Model( + id="ibm/granite-8b-code-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-granite-8b-code-instruct"], + ), + "ibm/granite-34b-code-instruct": Model( + id="ibm/granite-34b-code-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-granite-34b-code-instruct"], + ), + "google/codegemma-1.1-7b": Model( + id="google/codegemma-1.1-7b", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-codegemma-1.1-7b"], + ), + "mediatek/breeze-7b-instruct": Model( + id="mediatek/breeze-7b-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-breeze-7b-instruct"], + ), + "upstage/solar-10.7b-instruct": Model( + id="upstage/solar-10.7b-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-solar-10_7b-instruct"], + ), + "writer/palmyra-med-70b-32k": Model( + id="writer/palmyra-med-70b-32k", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-palmyra-med-70b-32k"], + ), + "writer/palmyra-med-70b": Model( + id="writer/palmyra-med-70b", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-palmyra-med-70b"], + ), + "mistralai/mistral-7b-instruct-v0.3": Model( + id="mistralai/mistral-7b-instruct-v0.3", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-mistral-7b-instruct-v03"], + ), + "01-ai/yi-large": Model( + id="01-ai/yi-large", + model_type="chat", + 
client="NvidiaGenerator", + aliases=["ai-yi-large"], + ), + "nvidia/nemotron-4-340b-instruct": Model( + id="nvidia/nemotron-4-340b-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["qa-nemotron-4-340b-instruct"], + ), + "mistralai/codestral-22b-instruct-v0.1": Model( + id="mistralai/codestral-22b-instruct-v0.1", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-codestral-22b-instruct-v01"], + supports_structured_output=True, + ), + "google/gemma-2-9b-it": Model( + id="google/gemma-2-9b-it", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-gemma-2-9b-it"], + ), + "google/gemma-2-27b-it": Model( + id="google/gemma-2-27b-it", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-gemma-2-27b-it"], + ), + "microsoft/phi-3-medium-128k-instruct": Model( + id="microsoft/phi-3-medium-128k-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-phi-3-medium-128k-instruct"], + ), + "deepseek-ai/deepseek-coder-6.7b-instruct": Model( + id="deepseek-ai/deepseek-coder-6.7b-instruct", + model_type="chat", + client="NvidiaGenerator", + aliases=["ai-deepseek-coder-6_7b-instruct"], + ), + "nv-mistralai/mistral-nemo-12b-instruct": Model( + id="nv-mistralai/mistral-nemo-12b-instruct", + model_type="chat", + client="NvidiaGenerator", + supports_tools=True, + supports_structured_output=True, + ), + "meta/llama-3.1-8b-instruct": Model( + id="meta/llama-3.1-8b-instruct", + model_type="chat", + client="NvidiaGenerator", + supports_tools=True, + supports_structured_output=True, + ), + "meta/llama-3.1-70b-instruct": Model( + id="meta/llama-3.1-70b-instruct", + model_type="chat", + client="NvidiaGenerator", + supports_tools=True, + supports_structured_output=True, + ), + "meta/llama-3.1-405b-instruct": Model( + id="meta/llama-3.1-405b-instruct", + model_type="chat", + client="NvidiaGenerator", + supports_tools=True, + supports_structured_output=True, + ), + "nvidia/usdcode-llama3-70b-instruct": Model( + id="nvidia/usdcode-llama3-70b-instruct", + model_type="chat", + client="NvidiaGenerator", + ), + "mistralai/mamba-codestral-7b-v0.1": Model( + id="mistralai/mamba-codestral-7b-v0.1", + model_type="chat", + client="NvidiaGenerator", + ), + "writer/palmyra-fin-70b-32k": Model( + id="writer/palmyra-fin-70b-32k", + model_type="chat", + client="NvidiaGenerator", + supports_structured_output=True, + ), + "google/gemma-2-2b-it": Model( + id="google/gemma-2-2b-it", + model_type="chat", + client="NvidiaGenerator", + ), + "mistralai/mistral-large-2-instruct": Model( + id="mistralai/mistral-large-2-instruct", + model_type="chat", + client="NvidiaGenerator", + supports_tools=True, + supports_structured_output=True, + ), + "mistralai/mathstral-7b-v0.1": Model( + id="mistralai/mathstral-7b-v0.1", + model_type="chat", + client="NvidiaGenerator", + ), + "rakuten/rakutenai-7b-instruct": Model( + id="rakuten/rakutenai-7b-instruct", + model_type="chat", + client="NvidiaGenerator", + ), + "rakuten/rakutenai-7b-chat": Model( + id="rakuten/rakutenai-7b-chat", + model_type="chat", + client="NvidiaGenerator", + ), + "baichuan-inc/baichuan2-13b-chat": Model( + id="baichuan-inc/baichuan2-13b-chat", + model_type="chat", + client="NvidiaGenerator", + ), + "thudm/chatglm3-6b": Model( + id="thudm/chatglm3-6b", + model_type="chat", + client="NvidiaGenerator", + ), + "microsoft/phi-3.5-mini-instruct": Model( + id="microsoft/phi-3.5-mini-instruct", + model_type="chat", + client="NvidiaGenerator", + ), + "microsoft/phi-3.5-moe-instruct": Model( + 
id="microsoft/phi-3.5-moe-instruct", + model_type="chat", + client="NvidiaGenerator", + ), + "nvidia/nemotron-mini-4b-instruct": Model( + id="nvidia/nemotron-mini-4b-instruct", + model_type="chat", + client="NvidiaGenerator", + ), + "ai21labs/jamba-1.5-large-instruct": Model( + id="ai21labs/jamba-1.5-large-instruct", + model_type="chat", + client="NvidiaGenerator", + ), + "ai21labs/jamba-1.5-mini-instruct": Model( + id="ai21labs/jamba-1.5-mini-instruct", + model_type="chat", + client="NvidiaGenerator", + ), + "yentinglin/llama-3-taiwan-70b-instruct": Model( + id="yentinglin/llama-3-taiwan-70b-instruct", + model_type="chat", + client="NvidiaGenerator", + ), + "tokyotech-llm/llama-3-swallow-70b-instruct-v0.1": Model( + id="tokyotech-llm/llama-3-swallow-70b-instruct-v0.1", + model_type="chat", + client="NvidiaGenerator", + ), +} + +EMBEDDING_MODEL_TABLE = { + "snowflake/arctic-embed-l": Model( + id="snowflake/arctic-embed-l", + model_type="embedding", + client="NvidiaTextEmbedder", + aliases=["ai-arctic-embed-l"], + ), + "NV-Embed-QA": Model( + id="NV-Embed-QA", + model_type="embedding", + client="NvidiaTextEmbedder", + endpoint="https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings", + aliases=[ + "ai-embed-qa-4", + "playground_nvolveqa_40k", + "nvolveqa_40k", + ], + ), + "nvidia/nv-embed-v1": Model( + id="nvidia/nv-embed-v1", + model_type="embedding", + client="NvidiaTextEmbedder", + aliases=["ai-nv-embed-v1"], + ), + "nvidia/nv-embedqa-mistral-7b-v2": Model( + id="nvidia/nv-embedqa-mistral-7b-v2", + model_type="embedding", + client="NvidiaTextEmbedder", + ), + "nvidia/nv-embedqa-e5-v5": Model( + id="nvidia/nv-embedqa-e5-v5", + model_type="embedding", + client="NvidiaTextEmbedder", + ), + "baai/bge-m3": Model( + id="baai/bge-m3", + model_type="embedding", + client="NvidiaTextEmbedder", + ), +} + + +MODEL_TABLE = { + **CHAT_MODEL_TABLE, + **EMBEDDING_MODEL_TABLE, +} diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py index 7d4dfc3b4..ed63824e4 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py @@ -1,7 +1,9 @@ import warnings -from typing import List +from typing import List, Optional from urllib.parse import urlparse, urlunparse +from .statics import MODEL_TABLE, Model + def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str: """ @@ -45,3 +47,75 @@ def is_hosted(api_url: str): "integrate.api.nvidia.com", "ai.api.nvidia.com", ] + + +def lookup_model(name: str) -> Optional[Model]: + """ + Lookup a model by name, using only the table of known models. + The name is either: + - directly in the table + - an alias in the table + - not found (None) + Callers can check to see if the name was an alias by + comparing the result's id field to the name they provided. + """ + model = None + if not (model := MODEL_TABLE.get(name)): + for mdl in MODEL_TABLE.values(): + if mdl.aliases and name in mdl.aliases: + model = mdl + break + return model + + +def determine_model(name: str) -> Optional[Model]: + """ + Determine the model to use based on a name, using + only the table of known models. + + Raise a warning if the model is found to be + an alias of a known model. + + If the model is not found, return None. + """ + if model := lookup_model(name): + # all aliases are deprecated + if model.id != name: + warn_msg = f"Model {name} is deprecated. Using {model.id} instead." 
diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py
index 7d4dfc3b4..ed63824e4 100644
--- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py
+++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py
@@ -1,7 +1,9 @@
 import warnings
-from typing import List
+from typing import List, Optional
 from urllib.parse import urlparse, urlunparse

+from .statics import MODEL_TABLE, Model
+

 def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str:
     """
@@ -45,3 +47,75 @@ def is_hosted(api_url: str):
         "integrate.api.nvidia.com",
         "ai.api.nvidia.com",
     ]
+
+
+def lookup_model(name: str) -> Optional[Model]:
+    """
+    Lookup a model by name, using only the table of known models.
+    The name is either:
+    - directly in the table
+    - an alias in the table
+    - not found (None)
+    Callers can check to see if the name was an alias by
+    comparing the result's id field to the name they provided.
+    """
+    model = None
+    if not (model := MODEL_TABLE.get(name)):
+        for mdl in MODEL_TABLE.values():
+            if mdl.aliases and name in mdl.aliases:
+                model = mdl
+                break
+    return model
+
+
+def determine_model(name: str) -> Optional[Model]:
+    """
+    Determine the model to use based on a name, using
+    only the table of known models.
+
+    Raise a warning if the model is found to be
+    an alias of a known model.
+
+    If the model is not found, return None.
+    """
+    if model := lookup_model(name):
+        # all aliases are deprecated
+        if model.id != name:
+            warn_msg = f"Model {name} is deprecated. Using {model.id} instead."
+            warnings.warn(warn_msg, UserWarning, stacklevel=1)
+    return model
+
+
+def validate_hosted_model(class_name: str, model_name: str, client) -> None:
+    """
+    Validates compatibility of the hosted model with the client.
+
+    Args:
+        class_name (str): Name of the client class performing validation.
+        model_name (str): The name of the model.
+        client: Client instance; its `available_models` is consulted as a fallback.
+
+    Raises:
+        ValueError: If the model is incompatible with the client.
+    """
+    if model := determine_model(model_name):
+        if not model.client:
+            warn_msg = f"Unable to determine validity of {model.id}"
+            warnings.warn(warn_msg, stacklevel=1)
+        elif model.model_type == "embedding" and class_name in ["NvidiaTextEmbedder", "NvidiaDocumentEmbedder"]:
+            pass
+        elif model.client != class_name:
+            err_msg = (f"Model {model.id} is incompatible with client {class_name}. "
+                       f"Please check `{class_name}.available_models`.")
+            raise ValueError(err_msg)
+    else:
+        candidates = [model for model in client.available_models if model.id == model_name]
+        assert len(candidates) <= 1, f"Multiple candidates for {model_name} in `available_models`: {candidates}"
+        if candidates:
+            model = candidates[0]
+            warn_msg = f"Found {model_name} in available_models, but type is unknown and inference may fail."
+            warnings.warn(warn_msg, stacklevel=1)
+        else:
+            err_msg = f"Model {model_name} is unknown, check `available_models`"
+            raise ValueError(err_msg)
+    # else:
+    #     if model_name not in [model.id for model in client.available_models]:
+    #         raise ValueError(f"No locally hosted {model_name} was found.")
diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py
index bef0f996e..d8e2a4f64 100644
--- a/integrations/nvidia/tests/test_document_embedder.py
+++ b/integrations/nvidia/tests/test_document_embedder.py
@@ -221,7 +221,7 @@ def test_run_default_model(self):
         with pytest.warns(UserWarning) as record:
             embedder.warm_up()

-        assert len(record) == 1
+        assert len(record) == 2
         assert "Default model is set as:" in str(record[0].message)
         assert embedder.model == "model1"
diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py
index 0bd8b1fc6..2b8ff7f06 100644
--- a/integrations/nvidia/tests/test_generator.py
+++ b/integrations/nvidia/tests/test_generator.py
@@ -162,7 +162,7 @@ def test_run_integration_with_default_model_nim_backend(self):
         )
         with pytest.warns(UserWarning) as record:
             generator.warm_up()
-        assert len(record) == 1
+        assert len(record) == 2
         assert "Default model is set as:" in str(record[0].message)
         assert generator._model == "model1"
         assert not generator.is_hosted
@@ -192,10 +192,11 @@ def test_run_integration_with_api_catalog(self):
         assert result["replies"]
         assert result["meta"]

+    @pytest.mark.usefixtures("mock_local_models")
     def test_local_nim_without_key(self) -> None:
         generator = NvidiaGenerator(
-            model="BOGUS",
-            api_url="http://localhost:8000",
+            model="model1",
+            api_url="http://localhost:8080",
             api_key=None,
         )
         generator.warm_up()
diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py
index 7c8428cc2..08cdfebe4 100644
--- a/integrations/nvidia/tests/test_text_embedder.py
+++ b/integrations/nvidia/tests/test_text_embedder.py
@@ -107,7 +107,7 @@ def test_run_default_model(self):
         with pytest.warns(UserWarning) as record:
             embedder.warm_up()

-        assert len(record) == 1
+        assert len(record) == 2
         assert "Default model is set as:" in str(record[0].message)
         assert embedder.model == "model1"
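The bumped warning counts in the tests follow from the new validation path: with a local NIM and no explicit model, `warm_up()` now emits two `UserWarning`s, the existing default-model notice plus the new unknown-type notice from `validate_hosted_model`. A condensed sketch of what the updated tests assert (requires the mocked `/models` endpoint from the test fixtures):

```python
import pytest

from haystack_integrations.components.generators.nvidia import NvidiaGenerator

generator = NvidiaGenerator(api_url="http://localhost:8080", api_key=None)
with pytest.warns(UserWarning) as record:
    generator.warm_up()

assert "Default model is set as:" in str(record[0].message)
assert "type is unknown and inference may fail" in str(record[1].message)
```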