Skip to content

Commit

Permalink
Tests for pipeline se/deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Sep 6, 2024
1 parent 1d83d68 commit c33d7dd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
16 changes: 16 additions & 0 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from unittest.mock import MagicMock, Mock, patch

import pytest
from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.dataclasses import ChatMessage, StreamingChunk
from vertexai.generative_models import (
Content,
Expand Down Expand Up @@ -278,3 +280,17 @@ def streaming_callback(chunk: StreamingChunk) -> None:

mock_model.send_message.assert_called_once()
assert streaming_callback_called == ["First part", "Second part"]


def test_serialization_deserialization_pipeline():

pipeline = Pipeline()
template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
pipeline.add_component("prompt_builder", ChatPromptBuilder(template=template))
pipeline.add_component("gemini", VertexAIGeminiChatGenerator(project_id="TestID123"))
pipeline.connect("prompt_builder.prompt", "gemini.messages")

pipeline_dict = pipeline.to_dict()

new_pipeline = Pipeline.from_dict(pipeline_dict)
assert new_pipeline == pipeline
19 changes: 19 additions & 0 deletions integrations/google_vertex/tests/test_gemini.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from unittest.mock import MagicMock, Mock, patch

from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.dataclasses import StreamingChunk
from vertexai.generative_models import (
FunctionDeclaration,
Expand Down Expand Up @@ -254,3 +256,20 @@ def streaming_callback(_chunk: StreamingChunk) -> None:
gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback)
gemini.run(["Come on, stream!"])
assert streaming_callback_called


def test_serialization_deserialization_pipeline():
template = """
Answer the following questions:
1. What is the weather like today?
"""
pipeline = Pipeline()

pipeline.add_component("prompt_builder", PromptBuilder(template=template))
pipeline.add_component("gemini", VertexAIGeminiGenerator(project_id="TestID123"))
pipeline.connect("prompt_builder", "gemini")

pipeline_dict = pipeline.to_dict()

new_pipeline = Pipeline.from_dict(pipeline_dict)
assert new_pipeline == pipeline

0 comments on commit c33d7dd

Please sign in to comment.