diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py
index 43b50495c..56a740146 100644
--- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py
+++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py
@@ -1,4 +1,3 @@
-import dataclasses
 import json
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
 
@@ -275,8 +274,16 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict
         """
         anthropic_formatted_messages = []
         for m in messages:
-            message_dict = dataclasses.asdict(m)
-            formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v}
+            message_dict = m.to_dict()
+            formatted_message = {}
+
+            # legacy format
+            if "role" in message_dict and "content" in message_dict:
+                formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v}
+            # new format
+            elif "_role" in message_dict and "_content" in message_dict:
+                formatted_message = {"role": m.role.value, "content": m.text}
+
             if m.is_from(ChatRole.SYSTEM):
                 # system messages are treated differently and MUST be in the format expected by the Anthropic API
                 # remove role and content from the message dict, add type and text
diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py
index 9a111fc9d..69b3265aa 100644
--- a/integrations/anthropic/tests/test_chat_generator.py
+++ b/integrations/anthropic/tests/test_chat_generator.py
@@ -421,6 +421,8 @@ def test_prompt_caching(self, cache_enabled):
         assert len(result["replies"]) == 1
 
         token_usage = result["replies"][0].meta.get("usage")
+        print(token_usage)
+
         if cache_enabled:
             # either we created cache or we read it (depends on how you execute this integration test)
             assert (
@@ -428,5 +430,5 @@ def test_prompt_caching(self, cache_enabled):
                 or token_usage.get("cache_read_input_tokens") > 1024
             )
         else:
-            assert "cache_creation_input_tokens" not in token_usage
-            assert "cache_read_input_tokens" not in token_usage
+            assert token_usage["cache_creation_input_tokens"] == 0
+            assert token_usage["cache_read_input_tokens"] == 0
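
Note: the new branch in `_convert_to_anthropic_format` distinguishes two serialization shapes of `ChatMessage.to_dict()`. A minimal standalone sketch of that distinction, assuming a `haystack-ai` version where `ChatMessage` exposes `from_user`, `role`, and `text` (the commented dict shapes are illustrative; the exact keys depend on the installed Haystack version):

```python
from haystack.dataclasses import ChatMessage

msg = ChatMessage.from_user("Hello")
d = msg.to_dict()

# Older Haystack releases serialize the dataclass fields directly,
# e.g. {"role": "user", "content": "Hello", ...}; newer releases
# serialize private attributes, e.g. {"_role": "user", "_content": [...], ...}.
if "role" in d and "content" in d:  # legacy format
    formatted = {k: v for k, v in d.items() if k in {"role", "content"} and v}
elif "_role" in d and "_content" in d:  # new format
    formatted = {"role": msg.role.value, "content": msg.text}

print(formatted)  # {"role": "user", "content": "Hello"} in both cases
```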