Updated Hugging Face chat and magics processing with new APIs, clients (#784)

* Updated HF chat processing

(1) The API has changed and uses the HuggingFaceEndpoint class instead of HuggingFaceHub, which is deprecated.
(2) InferenceClient replaces InferenceApi (a minimal sketch of the new client follows).
(3) Removed legacy code that does not work with the new APIs.
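
A minimal sketch of the new-style client, for orientation only (the model id, token, and timeout here are placeholders, not values from this commit):

    from huggingface_hub import InferenceClient

    # Replaces huggingface_hub.inference_api.InferenceApi
    client = InferenceClient(
        model="gpt2",      # hypothetical model id
        token="hf_...",    # your HUGGINGFACEHUB_API_TOKEN
        timeout=120,
    )
    # High-level task method for text generation
    print(client.text_generation("Tell me a joke."))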

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Handle text-generation and text-to-image tasks

Added logic to branch to either the text-generation or the text-to-image task based on the type of response received (a rough sketch follows).
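
The idea, sketched with illustrative names (the committed code below is authoritative): text tasks return JSON containing a "generated_text" key, while text-to-image returns raw image bytes, so the payload itself selects the branch.

    import json

    def route_response(raw: bytes) -> str:
        # Text-generation responses decode to JSON like [{"generated_text": "..."}];
        # text-to-image responses are raw image bytes with no such key.
        if b"generated_text" in raw:
            return json.loads(raw.decode())[0]["generated_text"]
        return "image"  # defer to the image-handling path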

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reworking conditional branching for text vs image

Used a different approach to check for the task type.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
srdas and pre-commit-ci[bot] authored May 16, 2024
1 parent e0eaeaa commit 20875ad
Showing 3 changed files with 67 additions and 64 deletions.
127 changes: 65 additions & 62 deletions packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -44,7 +44,7 @@
     Bedrock,
     Cohere,
     GPT4All,
-    HuggingFaceHub,
+    HuggingFaceEndpoint,
     OpenAI,
     SagemakerEndpoint,
     Together,
@@ -318,7 +318,6 @@ def __init__(self, *args, **kwargs):
             ),
             "text": PromptTemplate.from_template("{prompt}"),  # No customization
         }
-
         super().__init__(*args, **kwargs, **model_kwargs)
 
     async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
@@ -582,14 +581,10 @@ def allows_concurrency(self):
         return False
 
 
-HUGGINGFACE_HUB_VALID_TASKS = (
-    "text2text-generation",
-    "text-generation",
-    "text-to-image",
-)
-
-
-class HfHubProvider(BaseProvider, HuggingFaceHub):
+# References for using HuggingFaceEndpoint and InferenceClient:
+# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient
+# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/huggingface_endpoint.py
+class HfHubProvider(BaseProvider, HuggingFaceEndpoint):
     id = "huggingface_hub"
     name = "Hugging Face Hub"
     models = ["*"]
@@ -609,33 +604,35 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        huggingfacehub_api_token = get_from_dict_or_env(
-            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
-        )
         try:
-            from huggingface_hub.inference_api import InferenceApi
+            huggingfacehub_api_token = get_from_dict_or_env(
+                values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+            )
+        except Exception as e:
+            raise ValueError(
+                "Could not authenticate with huggingface_hub. "
+                "Please check your API token."
+            ) from e
+        try:
+            from huggingface_hub import InferenceClient
 
-            repo_id = values["repo_id"]
-            client = InferenceApi(
-                repo_id=repo_id,
+            values["client"] = InferenceClient(
+                model=values["model"],
+                timeout=values["timeout"],
                 token=huggingfacehub_api_token,
-                task=values.get("task"),
+                **values["server_kwargs"],
             )
-            if client.task not in HUGGINGFACE_HUB_VALID_TASKS:
-                raise ValueError(
-                    f"Got invalid task {client.task}, "
-                    f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
-                )
-            values["client"] = client
         except ImportError:
             raise ValueError(
                 "Could not import huggingface_hub python package. "
                 "Please install it with `pip install huggingface_hub`."
             )
         return values
 
-    # Handle image outputs
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    # Handle text and image outputs
+    def _call(
+        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
         """Call out to Hugging Face Hub's inference endpoint.
 
         Args:
@@ -650,45 +647,51 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
                 response = hf("Tell me a joke.")
         """
-        _model_kwargs = self.model_kwargs or {}
-        response = self.client(inputs=prompt, params=_model_kwargs)
-
-        if type(response) is dict and "error" in response:
-            raise ValueError(f"Error raised by inference API: {response['error']}")
-
-        # Custom code for responding to image generation responses
-        if self.client.task == "text-to-image":
-            imageFormat = response.format  # Presume it's a PIL ImageFile
-            mimeType = ""
-            if imageFormat == "JPEG":
-                mimeType = "image/jpeg"
-            elif imageFormat == "PNG":
-                mimeType = "image/png"
-            elif imageFormat == "GIF":
-                mimeType = "image/gif"
+        invocation_params = self._invocation_params(stop, **kwargs)
+        invocation_params["stop"] = invocation_params[
+            "stop_sequences"
+        ]  # porting 'stop_sequences' into the 'stop' argument
+        response = self.client.post(
+            json={"inputs": prompt, "parameters": invocation_params},
+            stream=False,
+            task=self.task,
+        )
+
+        try:
+            if "generated_text" in str(response):
+                # text2text or text-generation task
+                response_text = json.loads(response.decode())[0]["generated_text"]
+                # Maybe the generation has stopped at one of the stop sequences:
+                # then we remove this stop sequence from the end of the generated text
+                for stop_seq in invocation_params["stop_sequences"]:
+                    if response_text[-len(stop_seq) :] == stop_seq:
+                        response_text = response_text[: -len(stop_seq)]
+                return response_text
             else:
-                raise ValueError(f"Unrecognized image format {imageFormat}")
-
-            buffer = io.BytesIO()
-            response.save(buffer, format=imageFormat)
-            # Encode image data to Base64 bytes, then decode bytes to str
-            return mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
-
-        if self.client.task == "text-generation":
-            # Text generation return includes the starter text.
-            text = response[0]["generated_text"][len(prompt) :]
-        elif self.client.task == "text2text-generation":
-            text = response[0]["generated_text"]
-        else:
+                # text-to-image task
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_to_image.example
+                # Custom code for responding to image generation responses
+                image = self.client.text_to_image(prompt)
+                imageFormat = image.format  # Presume it's a PIL ImageFile
+                mimeType = ""
+                if imageFormat == "JPEG":
+                    mimeType = "image/jpeg"
+                elif imageFormat == "PNG":
+                    mimeType = "image/png"
+                elif imageFormat == "GIF":
+                    mimeType = "image/gif"
+                else:
+                    raise ValueError(f"Unrecognized image format {imageFormat}")
+                buffer = io.BytesIO()
+                image.save(buffer, format=imageFormat)
+                # Encode image data to Base64 bytes, then decode bytes to str
+                return (
+                    mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
+                )
+        except:
             raise ValueError(
-                f"Got invalid task {self.client.task}, "
-                f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
+                "Task not supported, only text-generation and text-to-image tasks are valid."
             )
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
 
     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
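
Two details in the new _call are worth unpacking. First, the stop-sequence trimming: when generation halts at a stop sequence, the endpoint returns it appended to the text, so it is stripped off. A standalone sketch of that logic (illustrative, not part of the commit):

    def trim_stop_sequences(text: str, stop_sequences: list[str]) -> str:
        # Remove a trailing stop sequence, if generation halted on one
        for stop_seq in stop_sequences:
            if text.endswith(stop_seq):
                text = text[: -len(stop_seq)]
        return text

    print(trim_stop_sequences("Q: hi\nA: hello\nQ:", ["\nQ:"]))  # -> "Q: hi\nA: hello"

Second, the image branch returns a data-URI-style string rather than an image object. A minimal sketch of that encoding, assuming Pillow is installed and using a stand-in image:

    import base64
    import io

    from PIL import Image

    image = Image.new("RGB", (8, 8))  # stand-in for the generated image
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    data_uri = "image/png;base64," + base64.b64encode(buffer.getvalue()).decode()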
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
@@ -5,7 +5,7 @@
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferWindowMemory
-from langchain.prompts import PromptTemplate
+from langchain_core.prompts import PromptTemplate
 
 from .base import BaseChatHandler, SlashCommandRoutingType
 
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
@@ -12,9 +12,9 @@
 from langchain.chains import LLMChain
 from langchain.llms import BaseLLM
 from langchain.output_parsers import PydanticOutputParser
-from langchain.prompts import PromptTemplate
 from langchain.pydantic_v1 import BaseModel
 from langchain.schema.output_parser import BaseOutputParser
+from langchain_core.prompts import PromptTemplate
 
 
 class OutlineSection(BaseModel):
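
The only change in both chat handlers is the import path: PromptTemplate now comes from langchain_core. Usage is unchanged; a minimal sketch with placeholder template values:

    from langchain_core.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("Tell me a {adjective} joke about {topic}.")
    print(prompt.format(adjective="short", topic="notebooks"))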
