Skip to content

Commit

Permalink
huggingface: fix model param population (#24743)
Browse files Browse the repository at this point in the history
- **Description:** Fix the validation error for `endpoint_url` for
HuggingFaceEndpoint. I have given a detailed description of the issue in
the issue that I have created.
- **Issue:** #24742

---------

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
keenborder786 and efriis authored Aug 24, 2024
1 parent c7a8af2 commit 9a29398
Showing 1 changed file with 32 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json # type: ignore[import-not-found]
import logging
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional

from langchain_core.callbacks import (
Expand Down Expand Up @@ -66,9 +67,10 @@ class HuggingFaceEndpoint(LLM):
""" # noqa: E501

endpoint_url: Optional[str] = None
"""Endpoint URL to use."""
"""Endpoint URL to use. If repo_id is not specified then this needs to given or
should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
repo_id: Optional[str] = None
"""Repo to use."""
"""Repo to use. If endpoint_url is not specified then this needs to given"""
huggingfacehub_api_token: Optional[str] = None
max_new_tokens: int = 512
"""Maximum number of generated tokens"""
Expand Down Expand Up @@ -146,19 +148,38 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:

values["model_kwargs"] = extra

values["endpoint_url"] = get_from_dict_or_env(
values, "endpoint_url", "HF_INFERENCE_ENDPOINT", None
)

if values["endpoint_url"] is None and "repo_id" not in values:
# to correctly create the InferenceClient and AsyncInferenceClient
# in validate_environment, we need to populate values["model"].
# from InferenceClient docstring:
# model (`str`, `optional`):
# The model to run inference with. Can be a model id hosted on the Hugging
# Face Hub, e.g. `bigcode/starcoder`
# or a URL to a deployed Inference Endpoint. Defaults to None, in which
# case a recommended model is
# automatically selected for the task.

# this string could be in 3 places of descending priority:
# 2. values["model"] or values["endpoint_url"] or values["repo_id"]
# (equal priority - don't allow both set)
# 3. values["HF_INFERENCE_ENDPOINT"] (if none above set)

model = values.get("model")
endpoint_url = values.get("endpoint_url")
repo_id = values.get("repo_id")

if sum([bool(model), bool(endpoint_url), bool(repo_id)]) > 1:
raise ValueError(
"Please specify an `endpoint_url` or `repo_id` for the model."
"Please specify either a `model` OR an `endpoint_url` OR a `repo_id`,"
"not more than one."
)
if values["endpoint_url"] is not None and "repo_id" in values:
values["model"] = (
model or endpoint_url or repo_id or os.environ.get("HF_INFERENCE_ENDPOINT")
)
if not values["model"]:
raise ValueError(
"Please specify either an `endpoint_url` OR a `repo_id`, not both."
"Please specify a `model` or an `endpoint_url` or a `repo_id` for the "
"model."
)
values["model"] = values.get("endpoint_url") or values.get("repo_id")
return values

@root_validator(pre=False, skip_on_failure=True)
Expand Down

0 comments on commit 9a29398

Please sign in to comment.