diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py index c2c1ee40d..0bd29898e 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AnthropicChatGenerator +from .chat.vertex_chat_generator import AnthropicVertexChatGenerator from .generator import AnthropicGenerator -__all__ = ["AnthropicGenerator", "AnthropicChatGenerator"] +__all__ = ["AnthropicGenerator", "AnthropicChatGenerator", "AnthropicVertexChatGenerator"] diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py new file mode 100644 index 000000000..4ece944cd --- /dev/null +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py @@ -0,0 +1,135 @@ +import os +from typing import Any, Callable, Dict, Optional + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import StreamingChunk +from haystack.utils import deserialize_callable, serialize_callable + +from anthropic import AnthropicVertex + +from .chat_generator import AnthropicChatGenerator + +logger = logging.getLogger(__name__) + + +@component +class AnthropicVertexChatGenerator(AnthropicChatGenerator): + """ + + Enables text generation using state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API. + It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`, + accessible through the Vertex AI API endpoint. + + To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled. + Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden. + Before making requests, you may need to authenticate with GCP using `gcloud auth login`. + For more details, refer to the [guide] (https://docs.anthropic.com/en/api/claude-on-vertex-ai). + + Any valid text generation parameters for the Anthropic messaging API can be passed to + the AnthropicVertex API. Users can provide these parameters directly to the component via + the `generation_kwargs` parameter in `__init__` or the `run` method. + + For more details on the parameters supported by the Anthropic API, refer to the + Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages). + + ```python + from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + from haystack.dataclasses import ChatMessage + + messages = [ChatMessage.from_user("What's Natural Language Processing?")] + client = AnthropicVertexChatGenerator( + model="claude-3-sonnet@20240229", + project_id="your-project-id", region="your-region" + ) + response = client.run(messages) + print(response) + + >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that + >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing + >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and + >> communicate in natural languages like English, Spanish, or Chinese.', role=, + >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn', + >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]} + ``` + + For more details on supported models and their capabilities, refer to the Anthropic + [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude). + + """ + + def __init__( + self, + region: str, + project_id: str, + model: str = "claude-3-5-sonnet@20240620", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ignore_tools_thinking_messages: bool = True, + ): + """ + Creates an instance of AnthropicVertexChatGenerator. + + :param region: The region where the Anthropic model is deployed. Defaults to "us-central1". + :param project_id: The GCP project ID where the Anthropic model is deployed. + :param model: The name of the model to use. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to + the AnthropicVertex endpoint. See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post) + for more details. + + Supported generation_kwargs parameters are: + - `system`: The system message to be passed to the model. + - `max_tokens`: The maximum number of tokens to generate. + - `metadata`: A dictionary of metadata to be passed to the model. + - `stop_sequences`: A list of strings that the model should stop generating at. + - `temperature`: The temperature to use for sampling. + - `top_p`: The top_p value to use for nucleus sampling. + - `top_k`: The top_k value to use for top-k sampling. + - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features). + :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a + "chain of thought" messages before returning the actual function names and parameters in a message. If + `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool + use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) + for more details. + """ + self.region = region or os.environ.get("REGION") + self.project_id = project_id or os.environ.get("PROJECT_ID") + self.model = model + self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback + self.client = AnthropicVertex(region=self.region, project_id=self.project_id) + self.ignore_tools_thinking_messages = ignore_tools_thinking_messages + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + region=self.region, + project_id=self.project_id, + model=self.model, + streaming_callback=callback_name, + generation_kwargs=self.generation_kwargs, + ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py new file mode 100644 index 000000000..a67e801ad --- /dev/null +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -0,0 +1,197 @@ +import os + +import anthropic +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + +class TestAnthropicVertexChatGenerator: + def test_init_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is None + assert not component.generation_kwargs + assert component.ignore_tools_thinking_messages + + def test_init_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ignore_tools_thinking_messages=False, + ) + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.ignore_tools_thinking_messages is False + + def test_to_dict_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": None, + "generation_kwargs": {}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_lambda_streaming_callback(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=lambda x: x, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "tests.test_vertex_chat_generator.", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_from_dict(self): + data = { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + component = AnthropicVertexChatGenerator.from_dict(data) + assert component.model == "claude-3-5-sonnet@20240620" + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_run(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + response = component.run(chat_messages) + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_with_params(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator( + region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5} + ) + response = component.run(chat_messages) + + # check that the component calls the Anthropic API with the correct parameters + _, kwargs = mock_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_live_run_wrong_model(self, chat_messages): + component = AnthropicVertexChatGenerator( + model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID") + ) + with pytest.raises(anthropic.NotFoundError): + component.run(chat_messages) + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_default_inference_params(self, chat_messages): + client = AnthropicVertexChatGenerator( + region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229" + ) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" + + # Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint, + # remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator.