-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AnthropicVertexChatGenerator component (#1192)
* Created a model adapter * Create adapter class and add VertexAPI * Add chat generator for Anthropic Vertex * Add tests * Small fix * Improve doc_strings * Make project_id and region mandatory params * Small fix
- Loading branch information
Showing
3 changed files
with
334 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
135 changes: 135 additions & 0 deletions
135
...c/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import os | ||
from typing import Any, Callable, Dict, Optional | ||
|
||
from haystack import component, default_from_dict, default_to_dict, logging | ||
from haystack.dataclasses import StreamingChunk | ||
from haystack.utils import deserialize_callable, serialize_callable | ||
|
||
from anthropic import AnthropicVertex | ||
|
||
from .chat_generator import AnthropicChatGenerator | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@component
class AnthropicVertexChatGenerator(AnthropicChatGenerator):
    """
    Enables text generation using state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API.

    It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`,
    accessible through the Vertex AI API endpoint.

    To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled.
    Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden.
    Before making requests, you may need to authenticate with GCP using `gcloud auth login`.
    For more details, refer to the [guide](https://docs.anthropic.com/en/api/claude-on-vertex-ai).

    Any valid text generation parameters for the Anthropic messaging API can be passed to
    the AnthropicVertex API. Users can provide these parameters directly to the component via
    the `generation_kwargs` parameter in `__init__` or the `run` method.

    For more details on the parameters supported by the Anthropic API, refer to the
    Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages).

    ```python
    from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]
    client = AnthropicVertexChatGenerator(
        model="claude-3-sonnet@20240229",
        project_id="your-project-id", region="your-region"
    )
    response = client.run(messages)
    print(response)

    >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that
    >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing
    >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and
    >> communicate in natural languages like English, Spanish, or Chinese.', role=<ChatRole.ASSISTANT: 'assistant'>,
    >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn',
    >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]}
    ```

    For more details on supported models and their capabilities, refer to the Anthropic
    [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude).
    """

    def __init__(
        self,
        region: str,
        project_id: str,
        model: str = "claude-3-5-sonnet@20240620",
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        ignore_tools_thinking_messages: bool = True,
    ):
        """
        Creates an instance of AnthropicVertexChatGenerator.

        :param region: The region where the Anthropic model is deployed (for example, "us-central1").
            This parameter has no default; if a falsy value (such as `None` or an empty string) is passed,
            the value of the `REGION` environment variable is used instead.
        :param project_id: The GCP project ID where the Anthropic model is deployed.
            If a falsy value is passed, the value of the `PROJECT_ID` environment variable is used instead.
        :param model: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the AnthropicVertex endpoint. See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post)
            for more details.

            Supported generation_kwargs parameters are:
            - `system`: The system message to be passed to the model.
            - `max_tokens`: The maximum number of tokens to generate.
            - `metadata`: A dictionary of metadata to be passed to the model.
            - `stop_sequences`: A list of strings that the model should stop generating at.
            - `temperature`: The temperature to use for sampling.
            - `top_p`: The top_p value to use for nucleus sampling.
            - `top_k`: The top_k value to use for top-k sampling.
            - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features).
        :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a
            "chain of thought" messages before returning the actual function names and parameters in a message. If
            `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool
            use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use)
            for more details.
        """
        # NOTE(review): the parent `__init__` is intentionally not invoked here; presumably it sets up the
        # direct Anthropic API client, which this subclass replaces with `AnthropicVertex` - confirm.

        # Explicit arguments win; the environment variables are only consulted for falsy arguments.
        self.region = region or os.environ.get("REGION")
        self.project_id = project_id or os.environ.get("PROJECT_ID")
        self.model = model
        self.generation_kwargs = generation_kwargs or {}
        self.streaming_callback = streaming_callback
        self.client = AnthropicVertex(region=self.region, project_id=self.project_id)
        self.ignore_tools_thinking_messages = ignore_tools_thinking_messages

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The streaming callback, when set, is stored as the string produced by `serialize_callable`;
        otherwise `None` is stored.

        :returns:
            The serialized component as a dictionary.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            region=self.region,
            project_id=self.project_id,
            model=self.model,
            streaming_callback=callback_name,
            generation_kwargs=self.generation_kwargs,
            ignore_tools_thinking_messages=self.ignore_tools_thinking_messages,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator":
        """
        Deserialize this component from a dictionary.

        If `init_parameters` contains a serialized `streaming_callback`, it is replaced in place
        (inside `data`) by the callable returned from `deserialize_callable` before the component is built.

        :param data: The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
197 changes: 197 additions & 0 deletions
197
integrations/anthropic/tests/test_vertex_chat_generator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
import os | ||
|
||
import anthropic | ||
import pytest | ||
from haystack.components.generators.utils import print_streaming_chunk | ||
from haystack.dataclasses import ChatMessage, ChatRole | ||
|
||
from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator | ||
|
||
|
||
@pytest.fixture
def chat_messages():
    """Return a short system + user conversation shared by the tests in this module."""
    return [
        ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."),
        ChatMessage.from_user("What's the capital of France?"),
    ]
|
||
|
||
class TestAnthropicVertexChatGenerator:
    """Unit and integration tests for `AnthropicVertexChatGenerator`."""

    def test_init_default(self):
        """Only the required arguments are given; every optional attribute takes its default."""
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.streaming_callback is None
        assert not component.generation_kwargs
        assert component.ignore_tools_thinking_messages

    def test_init_with_parameters(self):
        """Every constructor argument is stored unchanged on the instance."""
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            model="claude-3-5-sonnet@20240620",
            streaming_callback=print_streaming_chunk,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
            ignore_tools_thinking_messages=False,
        )
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.streaming_callback is print_streaming_chunk
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
        assert component.ignore_tools_thinking_messages is False

    def test_to_dict_default(self):
        """Serializing a default component emits all init parameters with their default values."""
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": None,
                "generation_kwargs": {},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_to_dict_with_parameters(self):
        """A named streaming callback is serialized as its dotted import path."""
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            streaming_callback=print_streaming_chunk,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_to_dict_with_lambda_streaming_callback(self):
        """A lambda streaming callback is serialized using this test module's path."""
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            model="claude-3-5-sonnet@20240620",
            streaming_callback=lambda x: x,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "tests.test_vertex_chat_generator.<lambda>",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_from_dict(self):
        """Deserialization restores every init parameter, including the streaming callback callable."""
        data = {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }
        component = AnthropicVertexChatGenerator.from_dict(data)
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.streaming_callback is print_streaming_chunk
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
        assert component.ignore_tools_thinking_messages is True

    def test_run(self, chat_messages, mock_chat_completion):
        """`run` returns a dict holding exactly one `ChatMessage` reply."""
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        response = component.run(chat_messages)

        # check that the component returns the correct ChatMessage response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        # `all(...)` is required: asserting the bare list would pass for any non-empty list.
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_params(self, chat_messages, mock_chat_completion):
        """`generation_kwargs` given at init time are forwarded to the mocked API call."""
        component = AnthropicVertexChatGenerator(
            region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5}
        )
        response = component.run(chat_messages)

        # check that the component calls the Anthropic API with the correct parameters
        _, kwargs = mock_chat_completion.call_args
        assert kwargs["max_tokens"] == 10
        assert kwargs["temperature"] == 0.5

        # check that the component returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        # `all(...)` is required: asserting the bare list would pass for any non-empty list.
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    # Both variables are needed for a live call, so skip unless REGION *and* PROJECT_ID are set.
    @pytest.mark.skipif(
        not (os.environ.get("REGION") and os.environ.get("PROJECT_ID")),
        reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.",
    )
    @pytest.mark.integration
    def test_live_run_wrong_model(self, chat_messages):
        """A live call with an unknown model name raises `anthropic.NotFoundError`."""
        component = AnthropicVertexChatGenerator(
            model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID")
        )
        with pytest.raises(anthropic.NotFoundError):
            component.run(chat_messages)

    # Both variables are needed for a live call, so skip unless REGION *and* PROJECT_ID are set.
    @pytest.mark.skipif(
        not (os.environ.get("REGION") and os.environ.get("PROJECT_ID")),
        reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.",
    )
    @pytest.mark.integration
    def test_default_inference_params(self, chat_messages):
        """A live call with default generation parameters returns a sensible assistant reply."""
        client = AnthropicVertexChatGenerator(
            region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229"
        )
        response = client.run(chat_messages)

        assert "replies" in response, "Response does not contain 'replies' key"
        replies = response["replies"]
        assert isinstance(replies, list), "Replies is not a list"
        assert len(replies) > 0, "No replies received"

        first_reply = replies[0]
        assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
        assert first_reply.content, "First reply has no content"
        assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
        assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
        assert first_reply.meta, "First reply has no metadata"
|
||
# The Anthropic Messages API is the same for the AnthropicVertex and Anthropic endpoints, so the | ||
# remaining tests are skipped for AnthropicVertexChatGenerator; they are already covered by the AnthropicChatGenerator tests. |