Update Google Vertex integration #102

Merged · 5 commits · Dec 14, 2023
2 changes: 1 addition & 1 deletion integrations/google-vertex/pyproject.toml
@@ -25,7 +25,7 @@ classifiers = [
]
dependencies = [
"haystack-ai",
"google-cloud-aiplatform",
"google-cloud-aiplatform>=1.38",
]

[project.urls]
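The version floor matters because the `vertexai.preview.generative_models` module used throughout this PR first shipped in `google-cloud-aiplatform` 1.38.0, so older installs fail at import time. A quick sanity-check sketch (not part of the PR; assumes a plain X.Y.Z version string):

# Sketch: verify the installed SDK is new enough for the preview Gemini API.
from importlib.metadata import version

major, minor = (int(p) for p in version("google-cloud-aiplatform").split(".")[:2])
assert (major, minor) >= (1, 38), "Gemini support requires google-cloud-aiplatform>=1.38"

from vertexai.preview.generative_models import GenerativeModel  # import fails on older versions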
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
__version__ = "0.0.1"
__version__ = "0.0.2"
@@ -0,0 +1,179 @@
import logging
from typing import Any, Dict, List, Optional, Union

import vertexai
from haystack.core.component import component
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.byte_stream import ByteStream
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from vertexai.preview.generative_models import (
Content,
FunctionDeclaration,
GenerationConfig,
GenerativeModel,
HarmBlockThreshold,
HarmCategory,
Part,
Tool,
)

logger = logging.getLogger(__name__)


@component
class GeminiChatGenerator:
def __init__(
self,
*,
model: str = "gemini-pro",
project_id: str,
location: Optional[str] = None,
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
):
"""
Multi modal generator using Gemini model via Google Vertex AI.

Authenticates using Google Cloud Application Default Credentials (ADCs).
For more information see the official Google documentation:
https://cloud.google.com/docs/authentication/provide-credentials-adc

:param project_id: ID of the GCP project to use.
:param model: Name of the model to use, defaults to "gemini-pro-vision".
:param location: The default location to use when making API calls, if not set uses us-central-1.
Defaults to None.
:param kwargs: Additional keyword arguments to pass to the model.
For a list of supported arguments see the `GenerativeModel.generate_content()` documentation.
"""

        # Log in to GCP. This will fail if the user has not set up their gcloud SDK.
vertexai.init(project=project_id, location=location)

self._model_name = model
self._project_id = project_id
self._location = location
self._model = GenerativeModel(self._model_name)

self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools

def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
return {
"name": function._raw_function_declaration.name,
"parameters": function._raw_function_declaration.parameters,
"description": function._raw_function_declaration.description,
}

def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]:
return {
"function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations],
}

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
return config
return {
"temperature": config._raw_generation_config.temperature,
"top_p": config._raw_generation_config.top_p,
"top_k": config._raw_generation_config.top_k,
"candidate_count": config._raw_generation_config.candidate_count,
"max_output_tokens": config._raw_generation_config.max_output_tokens,
"stop_sequences": config._raw_generation_config.stop_sequences,
}

def to_dict(self) -> Dict[str, Any]:
data = default_to_dict(
self,
model=self._model_name,
project_id=self._project_id,
location=self._location,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
return data

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GeminiChatGenerator":
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)

return default_from_dict(cls, data)

def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
if isinstance(part, str):
return Part.from_text(part)
elif isinstance(part, ByteStream):
return Part.from_data(part.data, part.mime_type)
elif isinstance(part, Part):
return part
else:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

def _message_to_part(self, message: ChatMessage) -> Part:
if message.role == ChatRole.SYSTEM and message.name:
p = Part.from_dict({"function_call": {"name": message.name, "args": {}}})
for k, v in message.content.items():
p.function_call.args[k] = v
return p
elif message.role == ChatRole.SYSTEM:
return Part.from_text(message.content)
elif message.role == ChatRole.FUNCTION:
return Part.from_function_response(name=message.name, response=message.content)
        elif message.role == ChatRole.USER:
            return self._convert_part(message.content)
        else:
            msg = f"Unsupported message role {message.role}"
            raise ValueError(msg)

def _message_to_content(self, message: ChatMessage) -> Content:
if message.role == ChatRole.SYSTEM and message.name:
part = Part.from_dict({"function_call": {"name": message.name, "args": {}}})
for k, v in message.content.items():
part.function_call.args[k] = v
elif message.role == ChatRole.SYSTEM:
part = Part.from_text(message.content)
elif message.role == ChatRole.FUNCTION:
part = Part.from_function_response(name=message.name, response=message.content)
elif message.role == ChatRole.USER:
part = self._convert_part(message.content)
else:
msg = f"Unsupported message role {message.role}"
raise ValueError(msg)
role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model"
return Content(parts=[part], role=role)

@component.output_types(replies=List[ChatMessage])
def run(self, messages: List[ChatMessage]):
history = [self._message_to_content(m) for m in messages[:-1]]
session = self._model.start_chat(history=history)

new_message = self._message_to_part(messages[-1])
res = session.send_message(
content=new_message,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
)

replies = []
for candidate in res.candidates:
for part in candidate.content.parts:
if part._raw_part.text != "":
replies.append(ChatMessage.from_system(part.text))
elif part.function_call is not None:
replies.append(
ChatMessage(
content=dict(part.function_call.args.items()),
role=ChatRole.SYSTEM,
name=part.function_call.name,
)
)

return {"replies": replies}
@@ -8,16 +8,30 @@
from haystack.dataclasses.byte_stream import ByteStream
from vertexai.preview.generative_models import (
Content,
FunctionDeclaration,
GenerationConfig,
GenerativeModel,
HarmBlockThreshold,
HarmCategory,
Part,
Tool,
)

logger = logging.getLogger(__name__)


@component
class GeminiGenerator:
-    def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, location: Optional[str] = None, **kwargs):
+    def __init__(
+        self,
+        *,
+        model: str = "gemini-pro-vision",
+        project_id: str,
+        location: Optional[str] = None,
+        generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
+        safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
"""
        Multi-modal generator using Gemini models via Google Vertex AI.

@@ -29,8 +43,19 @@ def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, locatio
:param model: Name of the model to use, defaults to "gemini-pro-vision".
        :param location: The default location to use when making API calls, if not set uses us-central1.
Defaults to None.
-        :param kwargs: Additional keyword arguments to pass to the model.
-            For a list of supported arguments see the `GenerativeModel.generate_content()` documentation.
+        :param generation_config: The generation config to use, defaults to None.
+            Can either be a GenerationConfig object or a dictionary of parameters.
+            Accepted fields are:
+                - temperature
+                - top_p
+                - top_k
+                - candidate_count
+                - max_output_tokens
+                - stop_sequences
+        :param safety_settings: The safety settings to use, defaults to None.
+            A dictionary of HarmCategory to HarmBlockThreshold.
+        :param tools: The tools to use, defaults to None.
+            A list of Tool objects that can be used to modify the generation process.
"""

        # Log in to GCP. This will fail if the user has not set up their gcloud SDK.
@@ -39,23 +64,59 @@ def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, locatio
self._model_name = model
self._project_id = project_id
self._location = location
-        self._kwargs = kwargs
-
-        if kwargs.get("stream"):
-            msg = "The `stream` parameter is not supported by the Gemini generator."
-            raise ValueError(msg)

        self._model = GenerativeModel(self._model_name)

+        self._generation_config = generation_config
+        self._safety_settings = safety_settings
+        self._tools = tools

def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
return {
"name": function._raw_function_declaration.name,
"parameters": function._raw_function_declaration.parameters,
"description": function._raw_function_declaration.description,
}

def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]:
return {
"function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations],
}

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
return config
return {
"temperature": config._raw_generation_config.temperature,
"top_p": config._raw_generation_config.top_p,
"top_k": config._raw_generation_config.top_k,
"candidate_count": config._raw_generation_config.candidate_count,
"max_output_tokens": config._raw_generation_config.max_output_tokens,
"stop_sequences": config._raw_generation_config.stop_sequences,
}

def to_dict(self) -> Dict[str, Any]:
-        # TODO: This is not fully implemented yet
-        return default_to_dict(
-            self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs
+        data = default_to_dict(
+            self,
+            model=self._model_name,
+            project_id=self._project_id,
+            location=self._location,
+            generation_config=self._generation_config,
+            safety_settings=self._safety_settings,
+            tools=self._tools,
        )
+        if (tools := data["init_parameters"].get("tools")) is not None:
+            data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools]
+        if (generation_config := data["init_parameters"].get("generation_config")) is not None:
+            data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
+        return data

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GeminiGenerator":
-        # TODO: This is not fully implemented yet
+        if (tools := data["init_parameters"].get("tools")) is not None:
+            data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
+        if (generation_config := data["init_parameters"].get("generation_config")) is not None:
+            data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)

return default_from_dict(cls, data)

def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
@@ -74,7 +135,12 @@ def run(self, parts: Variadic[List[Union[str, ByteStream, Part]]]):
converted_parts = [self._convert_part(p) for p in parts]

contents = [Content(parts=converted_parts, role="user")]
-        res = self._model.generate_content(contents=contents, **self._kwargs)
+        res = self._model.generate_content(
+            contents=contents,
+            generation_config=self._generation_config,
+            safety_settings=self._safety_settings,
+            tools=self._tools,
+        )
self._model.start_chat()
answers = []
for candidate in res.candidates:
@@ -89,17 +155,3 @@ def run(self, parts: Variadic[List[Union[str, ByteStream, Part]]]):
answers.append(function_call)

return {"answers": answers}


-# generator = GeminiGenerator(project_id="infinite-byte-223810")
-# res = generator.run(["What can you do for me?"])
-# res
-# another_res = generator.run(["Can you solve this math problems?", "2 + 2", "3 + 3", "1 / 1"])
-# another_res["answers"]
-# from pathlib import Path
-
-# image = ByteStream.from_file_path(
-#     Path("/Users/silvanocerza/Downloads/photo_2023-11-07_11-45-42.jpg"), mime_type="image/jpeg"
-# )
-# res = generator.run(["What is this about?", image])
-# res["answers"]