Skip to content

Commit

Permalink
huggingface[patch]: make HuggingFaceEndpoint serializable (#27027)
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Oct 1, 2024
1 parent 9d10151 commit b5e28d3
Show file tree
Hide file tree
Showing 3 changed files with 829 additions and 791 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional

from huggingface_hub import ( # type: ignore[import-untyped]
AsyncInferenceClient,
InferenceClient,
login,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import from_env, get_pydantic_field_names
from pydantic import ConfigDict, Field, model_validator
from langchain_core.utils import get_pydantic_field_names, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from typing_extensions import Self

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,10 +78,12 @@ class HuggingFaceEndpoint(LLM):
should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
repo_id: Optional[str] = None
"""Repo to use. If endpoint_url is not specified then this needs to given"""
huggingfacehub_api_token: Optional[str] = Field(
default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
huggingfacehub_api_token: Optional[SecretStr] = Field(
default_factory=secret_from_env(
["HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN"], default=None
)
)
max_new_tokens: int = 512
max_new_tokens: int = Field(default=512, alias="max_tokens")
"""Maximum number of generated tokens"""
top_k: Optional[int] = None
"""The number of highest probability vocabulary tokens to keep for
Expand Down Expand Up @@ -116,14 +123,15 @@ class HuggingFaceEndpoint(LLM):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `call` not explicitly specified"""
model: str
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
task: Optional[str] = None
"""Task to call the model with.
Should be a task that returns `generated_text` or `summary_text`."""

model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)

@model_validator(mode="before")
Expand Down Expand Up @@ -189,36 +197,23 @@ def build_extra(cls, values: Dict[str, Any]) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that package is installed and that the API token is valid."""
try:
from huggingface_hub import login # type: ignore[import]

except ImportError:
raise ImportError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)

huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN"
)

if huggingfacehub_api_token is not None:
if self.huggingfacehub_api_token is not None:
try:
login(token=huggingfacehub_api_token)
login(token=self.huggingfacehub_api_token.get_secret_value())
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e

from huggingface_hub import AsyncInferenceClient, InferenceClient

# Instantiate clients with supported kwargs
sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
self.client = InferenceClient(
model=self.model,
timeout=self.timeout,
token=huggingfacehub_api_token,
token=self.huggingfacehub_api_token.get_secret_value()
if self.huggingfacehub_api_token
else None,
**{
key: value
for key, value in self.server_kwargs.items()
Expand All @@ -230,7 +225,9 @@ def validate_environment(self) -> Self:
self.async_client = AsyncInferenceClient(
model=self.model,
timeout=self.timeout,
token=huggingfacehub_api_token,
token=self.huggingfacehub_api_token.get_secret_value()
if self.huggingfacehub_api_token
else None,
**{
key: value
for key, value in self.server_kwargs.items()
Expand Down Expand Up @@ -426,3 +423,15 @@ async def _astream(
# break if stop sequence found
if stop_seq_found:
break

@classmethod
def is_lc_serializable(cls) -> bool:
return True

@classmethod
def get_lc_namespace(cls) -> list[str]:
return ["langchain_huggingface", "llms"]

@property
def lc_secrets(self) -> dict[str, str]:
return {"huggingfacehub_api_token": "HUGGINGFACEHUB_API_TOKEN"}
Loading

0 comments on commit b5e28d3

Please sign in to comment.