Add stream param
Amnah199 committed Aug 11, 2024
1 parent 7f5b12e commit 86f6219
Showing 4 changed files with 16 additions and 3 deletions.
2 changes: 1 addition & 1 deletion integrations/google_ai/pyproject.toml
@@ -48,7 +48,7 @@ dependencies = [
"haystack-pydoc-tools",
]
[tool.hatch.envs.default.scripts]
test = "pytest --reruns 3 --reruns-delay 30 -x {args:tests}"
test = "pytest --reruns 0 --reruns-delay 30 -x {args:tests}"
test-cov = "coverage run -m pytest --reruns 3 --reruns-delay 30 -x {args:tests}"
cov-report = ["- coverage combine", "coverage report"]
cov = ["test-cov", "cov-report"]
integrations/google_ai: GoogleAIGeminiGenerator
@@ -70,6 +70,7 @@ def __init__(
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
+stream: Optional[bool] = False,
):
"""
Initializes a `GoogleAIGeminiGenerator` instance.
@@ -90,7 +91,9 @@ def __init__(
:param safety_settings: The safety settings to use.
A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values.
For more information, see [the API reference](https://ai.google.dev/api)
-:param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling).
+:param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/
+function_calling).
+:param stream: Whether to stream the response.
"""
genai.configure(api_key=api_key.resolve_value())

@@ -100,6 +103,7 @@ def __init__(
self._safety_settings = safety_settings
self._tools = tools
self._model = GenerativeModel(self._model_name, tools=self._tools)
+self._stream = stream

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
@@ -127,13 +131,15 @@ def to_dict(self) -> Dict[str, Any]:
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
+stream=self._stream
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools]
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
if (safety_settings := data["init_parameters"].get("safety_settings")) is not None:
data["init_parameters"]["safety_settings"] = {k.value: v.value for k, v in safety_settings.items()}

return data

@classmethod
@@ -194,6 +200,7 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
contents=contents,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
+stream=self._stream
)
self._model.start_chat()
replies = []
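Taken together, the hunks above thread the new stream flag from __init__ through to_dict and into the generate_content call. A minimal usage sketch follows; it is not part of this commit, the model name is a placeholder, and it assumes the Google AI API key is exported in the environment variable the component reads by default.

# Hypothetical usage of the new flag; not part of this commit.
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator

gemini = GoogleAIGeminiGenerator(model="gemini-1.5-flash", stream=True)  # model name is a placeholder

# run() still returns {"replies": [...]}; with stream=True the underlying
# generate_content() call is asked for partial chunks instead of a single response.
result = gemini.run(parts=["What is the tallest mountain on Earth?"])
print(result["replies"])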
2 changes: 2 additions & 0 deletions integrations/google_ai/tests/generators/test_gemini.py
@@ -108,6 +108,7 @@ def test_to_dict(monkeypatch):
"stop_sequences": ["stop"],
},
"safety_settings": {10: 3},
"stream": False,
"tools": [
b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai"
b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08"
@@ -140,6 +141,7 @@ def test_from_dict(monkeypatch):
b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08"
b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location"
],
"stream": False
},
}
)
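The test updates pin the new key in the serialized init_parameters. Continuing the sketch above, the round trip they guard looks roughly like this; the exact expectations live in test_gemini.py, whose generator uses the default stream=False.

# The flag must survive to_dict()/from_dict() so pipelines serialize and reload cleanly.
data = gemini.to_dict()
assert data["init_parameters"]["stream"] is True  # the tests assert False, matching the default

restored = GoogleAIGeminiGenerator.from_dict(data)
assert restored._stream is True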
integrations/google_vertex: Gemini generator (Vertex AI)
@@ -60,6 +60,7 @@ def __init__(
generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
tools: Optional[List[Tool]] = None,
+stream: Optional[bool] = False
):
"""
Multi-modal generator using Gemini model via Google Vertex AI.
@@ -87,6 +88,7 @@ def __init__(
:param tools: List of tools to use when generating content. See the documentation for
[Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool)
the list of supported arguments.
+:param stream: Whether to stream the response.
"""

# Login to GCP. This will fail if user has not set up their gcloud SDK
@@ -100,6 +102,7 @@ def __init__(
self._generation_config = generation_config
self._safety_settings = safety_settings
self._tools = tools
+self._stream = stream

def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
return {
@@ -140,6 +143,7 @@ def to_dict(self) -> Dict[str, Any]:
generation_config=self._generation_config,
safety_settings=self._safety_settings,
tools=self._tools,
+stream=self._stream
)
if (tools := data["init_parameters"].get("tools")) is not None:
data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools]
@@ -191,7 +195,7 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
contents=contents,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
-tools=self._tools,
+stream=self._stream,
)
self._model.start_chat()
replies = []
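The Vertex AI generator gets the same flag and forwards it to the generate_content call in run(), where per the hunk above it takes the place of the tools= argument. At the SDK level, stream=True switches the return value from a single response to an iterator of partial chunks. A rough sketch of that difference using the vertexai SDK directly, independent of this commit; project, location, and model names are placeholders.

import vertexai
from vertexai.preview.generative_models import GenerativeModel

# Placeholders: point these at a real GCP project with Vertex AI enabled.
vertexai.init(project="my-gcp-project", location="us-central1")
model = GenerativeModel("gemini-1.5-flash")

# stream=False (the component's default): one aggregated response object.
response = model.generate_content("Write a haiku about mountains.")
print(response.text)

# stream=True: partial responses arrive as they are generated.
for chunk in model.generate_content("Write a haiku about mountains.", stream=True):
    print(chunk.text, end="")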
