diff --git a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py index f29a675388770..0f793f2828d6b 100644 --- a/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py +++ b/libs/partners/huggingface/langchain_huggingface/llms/huggingface_endpoint.py @@ -1,5 +1,6 @@ import json # type: ignore[import-not-found] import logging +import os from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import ( @@ -66,9 +67,10 @@ class HuggingFaceEndpoint(LLM): """ # noqa: E501 endpoint_url: Optional[str] = None - """Endpoint URL to use.""" + """Endpoint URL to use. If repo_id is not specified then this needs to be given or + should be passed as an env variable in `HF_INFERENCE_ENDPOINT`""" repo_id: Optional[str] = None - """Repo to use.""" + """Repo to use. If endpoint_url is not specified then this needs to be given""" huggingfacehub_api_token: Optional[str] = None max_new_tokens: int = 512 """Maximum number of generated tokens""" @@ -146,19 +148,38 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["model_kwargs"] = extra - values["endpoint_url"] = get_from_dict_or_env( - values, "endpoint_url", "HF_INFERENCE_ENDPOINT", None - ) - - if values["endpoint_url"] is None and "repo_id" not in values: + # to correctly create the InferenceClient and AsyncInferenceClient + # in validate_environment, we need to populate values["model"]. + # from InferenceClient docstring: + # model (`str`, `optional`): + # The model to run inference with. Can be a model id hosted on the Hugging + # Face Hub, e.g. `bigcode/starcoder` + # or a URL to a deployed Inference Endpoint. Defaults to None, in which + # case a recommended model is + # automatically selected for the task. + + # this string could be in 2 places of descending priority: + # 1. 
values["model"] or values["endpoint_url"] or values["repo_id"] + # (equal priority - don't allow more than one set) + # 2. os.environ["HF_INFERENCE_ENDPOINT"] (if none above set) + + model = values.get("model") + endpoint_url = values.get("endpoint_url") + repo_id = values.get("repo_id") + + if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1: raise ValueError( - "Please specify an `endpoint_url` or `repo_id` for the model." + "Please specify either a `model` OR an `endpoint_url` OR a `repo_id`," + "not more than one." ) - if values["endpoint_url"] is not None and "repo_id" in values: + values["model"] = ( + model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT") + ) + if not values["model"]: raise ValueError( - "Please specify either an `endpoint_url` OR a `repo_id`, not both." + "Please specify a `model` or an `endpoint_url` or a `repo_id` for the " + "model." ) - values["model"] = values.get("endpoint_url") or values.get("repo_id") return values @root_validator(pre=False, skip_on_failure=True)