diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
index a520411c..3ba8d3c1 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
@@ -524,6 +524,11 @@ def _postprocess_args(cls, values: Any) -> Any:
         name = values.get("model")
         if model := determine_model(name):
             values["model"] = model.id
+            # not all models are on https://integrate.api.nvidia.com/v1,
+            # those that are not are served from their own endpoints
+            if model.endpoint:
+                # we override the infer_path to use the custom endpoint
+                values["client"].infer_path = model.endpoint
         else:
             if not (client := values.get("client")):
                 warnings.warn(f"Unable to determine validity of {name}")
diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
index 79d3a7ce..02233311 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/chat_models.py
@@ -44,7 +44,7 @@
 from langchain_core.tools import BaseTool
 
 from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
-from langchain_nvidia_ai_endpoints._statics import Model, determine_model
+from langchain_nvidia_ai_endpoints._statics import Model
 
 _CallbackManager = Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
 _DictOrPydanticClass = Union[Dict[str, Any], Type[BaseModel]]
@@ -168,17 +168,11 @@ def __init__(self, **kwargs: Any):
             environment variable.
         """
         super().__init__(**kwargs)
-        infer_path = "{base_url}/chat/completions"
-        # not all chat models are on https://integrate.api.nvidia.com/v1,
-        # those that are not are served from their own endpoints
-        if model := determine_model(self.model):
-            if model.endpoint:  # some models have custom endpoints
-                infer_path = model.endpoint
         self._client = _NVIDIAClient(
             base_url=self.base_url,
             model=self.model,
             api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
-            infer_path=infer_path,
+            infer_path="{base_url}/chat/completions",
         )
         # todo: only store the model in one place
         # the model may be updated to a newer name during initialization
diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
index 59d77447..767c5402 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py
@@ -8,7 +8,7 @@
 from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr, validator
 
 from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
-from langchain_nvidia_ai_endpoints._statics import Model, determine_model
+from langchain_nvidia_ai_endpoints._statics import Model
 from langchain_nvidia_ai_endpoints.callbacks import usage_callback_var
 
 
@@ -69,17 +69,11 @@ def __init__(self, **kwargs: Any):
             environment variable.
         """
         super().__init__(**kwargs)
-        infer_path = "{base_url}/embeddings"
-        # not all embedding models are on https://integrate.api.nvidia.com/v1,
-        # those that are not are served from their own endpoints
-        if model := determine_model(self.model):
-            if model.endpoint:  # some models have custom endpoints
-                infer_path = model.endpoint
         self._client = _NVIDIAClient(
             base_url=self.base_url,
             model=self.model,
             api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
-            infer_path=infer_path,
+            infer_path="{base_url}/embeddings",
         )
         # todo: only store the model in one place
         # the model may be updated to a newer name during initialization
diff --git a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
index 7267a85d..bde7c844 100644
--- a/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
+++ b/libs/ai-endpoints/langchain_nvidia_ai_endpoints/reranking.py
@@ -8,7 +8,7 @@
 from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr
 
 from langchain_nvidia_ai_endpoints._common import _NVIDIAClient
-from langchain_nvidia_ai_endpoints._statics import Model, determine_model
+from langchain_nvidia_ai_endpoints._statics import Model
 
 
 class Ranking(BaseModel):
@@ -62,17 +62,11 @@ def __init__(self, **kwargs: Any):
            environment variable.
         """
         super().__init__(**kwargs)
-        infer_path = "{base_url}/ranking"
-        # not all models are on https://integrate.api.nvidia.com/v1,
-        # those that are not are served from their own endpoints
-        if model := determine_model(self.model):
-            if model.endpoint:  # some models have custom endpoints
-                infer_path = model.endpoint
         self._client = _NVIDIAClient(
             base_url=self.base_url,
             model=self.model,
             api_key=kwargs.get("nvidia_api_key", kwargs.get("api_key", None)),
-            infer_path=infer_path,
+            infer_path="{base_url}/ranking",
         )
         # todo: only store the model in one place
         # the model may be updated to a newer name during initialization