From e215185b1cd9d20340a1fe1f56f3c1ba47fb6d74 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:14:49 +0100 Subject: [PATCH] Update Google Vertex integration (#102) * Pin google-cloud-aiplatform version * Make generation kwargs explicit and implement serialization * Implement GeminiChatGenerator * Fix linting and remove dead code * Bump version --- integrations/google-vertex/pyproject.toml | 2 +- .../src/google_vertex_haystack/__about__.py | 2 +- .../generators/chat/gemini.py | 179 ++++++++++++++++++ .../generators/gemini.py | 108 ++++++++--- 4 files changed, 261 insertions(+), 30 deletions(-) create mode 100644 integrations/google-vertex/src/google_vertex_haystack/generators/chat/gemini.py diff --git a/integrations/google-vertex/pyproject.toml b/integrations/google-vertex/pyproject.toml index 2455b4fa9..a34853a60 100644 --- a/integrations/google-vertex/pyproject.toml +++ b/integrations/google-vertex/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "google-cloud-aiplatform", + "google-cloud-aiplatform>=1.38", ] [project.urls] diff --git a/integrations/google-vertex/src/google_vertex_haystack/__about__.py b/integrations/google-vertex/src/google_vertex_haystack/__about__.py index 0e4fa27cf..d4a92df1b 100644 --- a/integrations/google-vertex/src/google_vertex_haystack/__about__.py +++ b/integrations/google-vertex/src/google_vertex_haystack/__about__.py @@ -1,4 +1,4 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.0.1" +__version__ = "0.0.2" diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/chat/gemini.py b/integrations/google-vertex/src/google_vertex_haystack/generators/chat/gemini.py new file mode 100644 index 000000000..765706301 --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/chat/gemini.py @@ -0,0 +1,179 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +import vertexai +from haystack.core.component import component +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses.byte_stream import ByteStream +from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from vertexai.preview.generative_models import ( + Content, + FunctionDeclaration, + GenerationConfig, + GenerativeModel, + HarmBlockThreshold, + HarmCategory, + Part, + Tool, +) + +logger = logging.getLogger(__name__) + + +@component +class GeminiChatGenerator: + def __init__( + self, + *, + model: str = "gemini-pro", + project_id: str, + location: Optional[str] = None, + generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, + safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, + tools: Optional[List[Tool]] = None, + ): + """ + Multi modal generator using Gemini model via Google Vertex AI. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official Google documentation: + https://cloud.google.com/docs/authentication/provide-credentials-adc + + :param project_id: ID of the GCP project to use. + :param model: Name of the model to use, defaults to "gemini-pro-vision". + :param location: The default location to use when making API calls, if not set uses us-central-1. + Defaults to None. + :param kwargs: Additional keyword arguments to pass to the model. + For a list of supported arguments see the `GenerativeModel.generate_content()` documentation. + """ + + # Login to GCP. This will fail if user has not set up their gcloud SDK + vertexai.init(project=project_id, location=location) + + self._model_name = model + self._project_id = project_id + self._location = location + self._model = GenerativeModel(self._model_name) + + self._generation_config = generation_config + self._safety_settings = safety_settings + self._tools = tools + + def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: + return { + "name": function._raw_function_declaration.name, + "parameters": function._raw_function_declaration.parameters, + "description": function._raw_function_declaration.description, + } + + def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: + return { + "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], + } + + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(config, dict): + return config + return { + "temperature": config._raw_generation_config.temperature, + "top_p": config._raw_generation_config.top_p, + "top_k": config._raw_generation_config.top_k, + "candidate_count": config._raw_generation_config.candidate_count, + "max_output_tokens": config._raw_generation_config.max_output_tokens, + "stop_sequences": config._raw_generation_config.stop_sequences, + } + + def to_dict(self) -> Dict[str, Any]: + data = default_to_dict( + self, + model=self._model_name, + project_id=self._project_id, + location=self._location, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, + ) + if (tools := data["init_parameters"].get("tools")) is not None: + data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GeminiChatGenerator": + if (tools := data["init_parameters"].get("tools")) is not None: + data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) + + return default_from_dict(cls, data) + + def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: + if isinstance(part, str): + return Part.from_text(part) + elif isinstance(part, ByteStream): + return Part.from_data(part.data, part.mime_type) + elif isinstance(part, Part): + return part + else: + msg = f"Unsupported type {type(part)} for part {part}" + raise ValueError(msg) + + def _message_to_part(self, message: ChatMessage) -> Part: + if message.role == ChatRole.SYSTEM and message.name: + p = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) + for k, v in message.content.items(): + p.function_call.args[k] = v + return p + elif message.role == ChatRole.SYSTEM: + return Part.from_text(message.content) + elif message.role == ChatRole.FUNCTION: + return Part.from_function_response(name=message.name, response=message.content) + elif message.role == ChatRole.USER: + return self._convert_part(message.content) + + def _message_to_content(self, message: ChatMessage) -> Content: + if message.role == ChatRole.SYSTEM and message.name: + part = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) + for k, v in message.content.items(): + part.function_call.args[k] = v + elif message.role == ChatRole.SYSTEM: + part = Part.from_text(message.content) + elif message.role == ChatRole.FUNCTION: + part = Part.from_function_response(name=message.name, response=message.content) + elif message.role == ChatRole.USER: + part = self._convert_part(message.content) + else: + msg = f"Unsupported message role {message.role}" + raise ValueError(msg) + role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model" + return Content(parts=[part], role=role) + + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage]): + history = [self._message_to_content(m) for m in messages[:-1]] + session = self._model.start_chat(history=history) + + new_message = self._message_to_part(messages[-1]) + res = session.send_message( + content=new_message, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, + ) + + replies = [] + for candidate in res.candidates: + for part in candidate.content.parts: + if part._raw_part.text != "": + replies.append(ChatMessage.from_system(part.text)) + elif part.function_call is not None: + replies.append( + ChatMessage( + content=dict(part.function_call.args.items()), + role=ChatRole.SYSTEM, + name=part.function_call.name, + ) + ) + + return {"replies": replies} diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py b/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py index b01dc6795..aa8f9d9e8 100644 --- a/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py @@ -8,8 +8,13 @@ from haystack.dataclasses.byte_stream import ByteStream from vertexai.preview.generative_models import ( Content, + FunctionDeclaration, + GenerationConfig, GenerativeModel, + HarmBlockThreshold, + HarmCategory, Part, + Tool, ) logger = logging.getLogger(__name__) @@ -17,7 +22,16 @@ @component class GeminiGenerator: - def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, + *, + model: str = "gemini-pro-vision", + project_id: str, + location: Optional[str] = None, + generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, + safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, + tools: Optional[List[Tool]] = None, + ): """ Multi modal generator using Gemini model via Google Vertex AI. @@ -29,8 +43,19 @@ def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, locatio :param model: Name of the model to use, defaults to "gemini-pro-vision". :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. - :param kwargs: Additional keyword arguments to pass to the model. - For a list of supported arguments see the `GenerativeModel.generate_content()` documentation. + :param generation_config: The generation config to use, defaults to None. + Can either be a GenerationConfig object or a dictionary of parameters. + Accepted fields are: + - temperature + - top_p + - top_k + - candidate_count + - max_output_tokens + - stop_sequences + :param safety_settings: The safety settings to use, defaults to None. + A dictionary of HarmCategory to HarmBlockThreshold. + :param tools: The tools to use, defaults to None. + A list of Tool objects that can be used to modify the generation process. """ # Login to GCP. This will fail if user has not set up their gcloud SDK @@ -39,23 +64,59 @@ def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, locatio self._model_name = model self._project_id = project_id self._location = location - self._kwargs = kwargs - - if kwargs.get("stream"): - msg = "The `stream` parameter is not supported by the Gemini generator." - raise ValueError(msg) - self._model = GenerativeModel(self._model_name) + self._generation_config = generation_config + self._safety_settings = safety_settings + self._tools = tools + + def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: + return { + "name": function._raw_function_declaration.name, + "parameters": function._raw_function_declaration.parameters, + "description": function._raw_function_declaration.description, + } + + def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: + return { + "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], + } + + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(config, dict): + return config + return { + "temperature": config._raw_generation_config.temperature, + "top_p": config._raw_generation_config.top_p, + "top_k": config._raw_generation_config.top_k, + "candidate_count": config._raw_generation_config.candidate_count, + "max_output_tokens": config._raw_generation_config.max_output_tokens, + "stop_sequences": config._raw_generation_config.stop_sequences, + } + def to_dict(self) -> Dict[str, Any]: - # TODO: This is not fully implemented yet - return default_to_dict( - self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs + data = default_to_dict( + self, + model=self._model_name, + project_id=self._project_id, + location=self._location, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, ) + if (tools := data["init_parameters"].get("tools")) is not None: + data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) + return data @classmethod def from_dict(cls, data: Dict[str, Any]) -> "GeminiGenerator": - # TODO: This is not fully implemented yet + if (tools := data["init_parameters"].get("tools")) is not None: + data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) + return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -74,7 +135,12 @@ def run(self, parts: Variadic[List[Union[str, ByteStream, Part]]]): converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] - res = self._model.generate_content(contents=contents, **self._kwargs) + res = self._model.generate_content( + contents=contents, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, + ) self._model.start_chat() answers = [] for candidate in res.candidates: @@ -89,17 +155,3 @@ def run(self, parts: Variadic[List[Union[str, ByteStream, Part]]]): answers.append(function_call) return {"answers": answers} - - -# generator = GeminiGenerator(project_id="infinite-byte-223810") -# res = generator.run(["What can you do for me?"]) -# res -# another_res = generator.run(["Can you solve this math problems?", "2 + 2", "3 + 3", "1 / 1"]) -# another_res["answers"] -# from pathlib import Path - -# image = ByteStream.from_file_path( -# Path("/Users/silvanocerza/Downloads/photo_2023-11-07_11-45-42.jpg"), mime_type="image/jpeg" -# ) -# res = generator.run(["What is this about?", image]) -# res["answers"]