diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
index 9c08f461e..6da92e1c8 100644
--- a/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
+++ b/packages/jupyter-ai-magics/jupyter_ai_magics/providers.py
@@ -318,7 +318,6 @@ def __init__(self, *args, **kwargs):
             ),
             "text": PromptTemplate.from_template("{prompt}"),  # No customization
         }
-
         super().__init__(*args, **kwargs, **model_kwargs)
 
     async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
@@ -583,7 +582,7 @@ def allows_concurrency(self):
 
 
 # References for using HuggingFaceEndpoint and InferenceClient:
-# https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client
+# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient
 # https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/huggingface_endpoint.py
 class HfHubProvider(BaseProvider, HuggingFaceEndpoint):
     id = "huggingface_hub"
@@ -630,10 +629,8 @@ def validate_environment(cls, values: Dict) -> Dict:
             )
         return values
 
-    # Handle image outputs
-    def _call(
-        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
-    ) -> str:
+    # Handle text and image outputs
+    def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
         """Call out to Hugging Face Hub's inference endpoint.
 
         Args:
@@ -657,21 +654,37 @@ def _call(
             stream=False,
             task=self.task,
         )
-        try:
+
+        try:  # Check whether this is a text-generation task
             response_text = json.loads(response.decode())[0]["generated_text"]
-        except KeyError:
-            response_text = json.loads(response.decode())["generated_text"]
-
-        # Maybe the generation has stopped at one of the stop sequences:
-        # then we remove this stop sequence from the end of the generated text
-        for stop_seq in invocation_params["stop_sequences"]:
-            if response_text[-len(stop_seq) :] == stop_seq:
-                response_text = response_text[: -len(stop_seq)]
-
-        if type(response) is dict and "error" in response:
-            raise ValueError(f"Error raised by inference API: {response['error']}")
+            # Maybe the generation has stopped at one of the stop sequences:
+            # then we remove this stop sequence from the end of the generated text
+            for stop_seq in invocation_params["stop_sequences"]:
+                if response_text[-len(stop_seq) :] == stop_seq:
+                    response_text = response_text[: -len(stop_seq)]
+            return response_text
+        except Exception:  # If that fails, try to process the response as a text-to-image task
+            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_to_image.example
+            # Custom code for handling image generation responses
+            if isinstance(response, bytes):  # Is this an image?
+                image = self.client.text_to_image(prompt)
+                image_format = image.format  # Presume it's a PIL ImageFile
+                mime_type = ""
+                if image_format == "JPEG":
+                    mime_type = "image/jpeg"
+                elif image_format == "PNG":
+                    mime_type = "image/png"
+                elif image_format == "GIF":
+                    mime_type = "image/gif"
+                else:
+                    raise ValueError(f"Unrecognized image format {image_format}")
+                buffer = io.BytesIO()
+                image.save(buffer, format=image_format)
+                # Encode image data to Base64 bytes, then decode bytes to str
+                return mime_type + ";base64," + base64.b64encode(buffer.getvalue()).decode()
+            else:
+                raise ValueError("Task not supported; only text-generation and text-to-image tasks are valid.")
 
-        return response_text
 
     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
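Note on the new `_call` contract: for text-generation tasks it returns plain generated text, while for text-to-image tasks it returns a string of the form `mime_type;base64,<data>`. Below is a minimal sketch of the consumer side, assuming only that contract; `decode_hf_output` and `IMAGE_MIME_TYPES` are hypothetical names and are not part of this diff.

    import base64
    import io

    from PIL import Image

    # MIME types emitted by HfHubProvider._call for image outputs
    IMAGE_MIME_TYPES = ("image/jpeg", "image/png", "image/gif")

    def decode_hf_output(output: str):
        """Hypothetical helper: split a `mime_type;base64,<data>` string back
        into a PIL image; pass ordinary generated text through unchanged."""
        mime_type, sep, payload = output.partition(";base64,")
        if sep and mime_type in IMAGE_MIME_TYPES:
            return Image.open(io.BytesIO(base64.b64decode(payload)))
        return output  # Text-generation output

For example, `decode_hf_output("image/png;base64,...")` yields a `PIL.Image.Image`, while `decode_hf_output("Hello")` returns the string unchanged.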