From 2db16164eb8a50b6ff8d3a4e03776549699e0255 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Mon, 19 Aug 2024 13:05:21 +0200 Subject: [PATCH 01/12] Add prompt caching, add example --- .../anthropic/example/prompt_caching.py | 63 +++++++++++++++++++ .../anthropic/chat/chat_generator.py | 3 + 2 files changed, 66 insertions(+) create mode 100644 integrations/anthropic/example/prompt_caching.py diff --git a/integrations/anthropic/example/prompt_caching.py b/integrations/anthropic/example/prompt_caching.py new file mode 100644 index 000000000..618090f8e --- /dev/null +++ b/integrations/anthropic/example/prompt_caching.py @@ -0,0 +1,63 @@ +# To run this example, you will need to set a `ANTHROPIC_API_KEY` environment variable. + +from haystack import Pipeline +from haystack.components.builders import ChatPromptBuilder +from haystack.components.converters import HTMLToDocument +from haystack.components.fetchers import LinkContentFetcher +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage +from haystack.utils import Secret + +from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator + +msg = ChatMessage.from_system( + "You are a prompt expert who answers questions based on the given documents.\n" + "Here are the documents:\n" + "{% for d in documents %} \n" + " {{d.content}} \n" + "{% endfor %}" +) + +fetch_pipeline = Pipeline() +fetch_pipeline.add_component("fetcher", LinkContentFetcher()) +fetch_pipeline.add_component("converter", HTMLToDocument()) +fetch_pipeline.add_component("prompt_builder", ChatPromptBuilder(template=[msg], variables=["documents"])) + +fetch_pipeline.connect("fetcher", "converter") +fetch_pipeline.connect("converter", "prompt_builder") + +result = fetch_pipeline.run( + data={ + "fetcher": {"urls": ["https://ar5iv.labs.arxiv.org/html/2310.04406"]}, + } +) + +# Now we have our document fetched as a ChatMessage +final_prompt_msg = result["prompt_builder"]["prompt"][0] + +# We add a cache control header to the prompt message +final_prompt_msg.meta["cache_control"] = {"type": "ephemeral"} + + +# Build QA pipeline +qa_pipeline = Pipeline() +qa_pipeline.add_component("llm", AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), + streaming_callback=print_streaming_chunk, + generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}}, +)) + +questions = ["Why is Monte-Carlo Tree Search used in LATS", + "Summarize LATS selection, expansion, evaluation, simulation, backpropagation, and reflection"] + +# Answer the questions using prompt caching (i.e. 
the entire document is cached, we run the question against it) +for question in questions: + print("Question: " + question) + qa_pipeline.run( + data={ + "llm": {"messages": [final_prompt_msg, + ChatMessage.from_user("Given these documents, answer the question:" + question)]}, + } + ) + print("\n") + diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 9954f08c5..9fc5066fd 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -72,6 +72,7 @@ class AnthropicChatGenerator: "temperature", "top_p", "top_k", + "extra_headers", ] def __init__( @@ -101,6 +102,7 @@ def __init__( - `temperature`: The temperature to use for sampling. - `top_p`: The top_p value to use for nucleus sampling. - `top_k`: The top_k value to use for top-k sampling. + - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features). :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a "chain of thought" messages before returning the actual function names and parameters in a message. If `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool @@ -260,6 +262,7 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict for m in messages: message_dict = dataclasses.asdict(m) filtered_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + filtered_message.update(m.meta or {}) anthropic_formatted_messages.append(filtered_message) return anthropic_formatted_messages From 37908ce2f1e7c82f73869c1792eabf3d7722013e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 20 Aug 2024 20:23:32 +0200 Subject: [PATCH 02/12] Print prompt caching data in example --- integrations/anthropic/example/prompt_caching.py | 4 +++- .../components/generators/anthropic/chat/chat_generator.py | 7 ++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/integrations/anthropic/example/prompt_caching.py b/integrations/anthropic/example/prompt_caching.py index 618090f8e..495f292e6 100644 --- a/integrations/anthropic/example/prompt_caching.py +++ b/integrations/anthropic/example/prompt_caching.py @@ -53,11 +53,13 @@ # Answer the questions using prompt caching (i.e. 
the entire document is cached, we run the question against it) for question in questions: print("Question: " + question) - qa_pipeline.run( + result = qa_pipeline.run( data={ "llm": {"messages": [final_prompt_msg, ChatMessage.from_user("Given these documents, answer the question:" + question)]}, } ) + + print("\n\nChecking cache usage:", result["llm"]["replies"][0].meta.get("usage")) print("\n") diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 9fc5066fd..5492abf01 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -184,7 +184,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, anthropic_formatted_messages = self._convert_to_anthropic_format(messages) # system message provided by the user overrides the system message from the self.generation_kwargs - system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None + system = [anthropic_formatted_messages[0]] if messages and messages[0].is_from(ChatRole.SYSTEM) else None if system: anthropic_formatted_messages = anthropic_formatted_messages[1:] @@ -262,6 +262,11 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict for m in messages: message_dict = dataclasses.asdict(m) filtered_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + if m.is_from(ChatRole.SYSTEM): + # system messages need to be in the format expected by the Anthropic API + filtered_message.pop("role") + filtered_message["type"] = "text" + filtered_message["text"] = filtered_message.pop("content") filtered_message.update(m.meta or {}) anthropic_formatted_messages.append(filtered_message) return anthropic_formatted_messages From f50b491f56412d14aace7625d0e9a4c54cebd9b1 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 20 Aug 2024 20:27:15 +0200 Subject: [PATCH 03/12] Lint --- .../anthropic/example/prompt_caching.py | 28 ++++++++++++------- integrations/anthropic/pyproject.toml | 1 + 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/integrations/anthropic/example/prompt_caching.py b/integrations/anthropic/example/prompt_caching.py index 495f292e6..6ec71c520 100644 --- a/integrations/anthropic/example/prompt_caching.py +++ b/integrations/anthropic/example/prompt_caching.py @@ -41,25 +41,33 @@ # Build QA pipeline qa_pipeline = Pipeline() -qa_pipeline.add_component("llm", AnthropicChatGenerator( - api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), - streaming_callback=print_streaming_chunk, - generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}}, -)) +qa_pipeline.add_component( + "llm", + AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), + streaming_callback=print_streaming_chunk, + generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}}, + ), +) -questions = ["Why is Monte-Carlo Tree Search used in LATS", - "Summarize LATS selection, expansion, evaluation, simulation, backpropagation, and reflection"] +questions = [ + "Why is Monte-Carlo Tree Search used in LATS", + "Summarize LATS selection, expansion, evaluation, simulation, backpropagation, and reflection", +] # Answer the questions using prompt 
caching (i.e. the entire document is cached, we run the question against it) for question in questions: print("Question: " + question) result = qa_pipeline.run( data={ - "llm": {"messages": [final_prompt_msg, - ChatMessage.from_user("Given these documents, answer the question:" + question)]}, + "llm": { + "messages": [ + final_prompt_msg, + ChatMessage.from_user("Given these documents, answer the question:" + question), + ] + }, } ) print("\n\nChecking cache usage:", result["llm"]["replies"][0].meta.get("usage")) print("\n") - diff --git a/integrations/anthropic/pyproject.toml b/integrations/anthropic/pyproject.toml index 3f8c9812b..e1d3fa867 100644 --- a/integrations/anthropic/pyproject.toml +++ b/integrations/anthropic/pyproject.toml @@ -106,6 +106,7 @@ select = [ "YTT", ] ignore = [ + "T201", # print statements # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible passwords From c8d93f11f83fb3907511b0446bd9e8008090ed6b Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 20 Aug 2024 21:02:54 +0200 Subject: [PATCH 04/12] Anthropic allows multiple system messages, simplify --- .../anthropic/chat/chat_generator.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 5492abf01..8390405c1 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -179,20 +179,20 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, f"Model parameters {disallowed_params} are not allowed and will be ignored. " f"Allowed parameters are {self.ALLOWED_PARAMS}." 
) - - # adapt ChatMessage(s) to the format expected by the Anthropic API - anthropic_formatted_messages = self._convert_to_anthropic_format(messages) - - # system message provided by the user overrides the system message from the self.generation_kwargs - system = [anthropic_formatted_messages[0]] if messages and messages[0].is_from(ChatRole.SYSTEM) else None - if system: - anthropic_formatted_messages = anthropic_formatted_messages[1:] + system_messages: List[ChatMessage] = [msg for msg in messages if msg.is_from(ChatRole.SYSTEM)] + non_system_messages: List[ChatMessage] = [msg for msg in messages if not msg.is_from(ChatRole.SYSTEM)] + system_messages_formatted: List[Dict[str, Any]] = ( + self._convert_to_anthropic_format(system_messages) if system_messages else [] + ) + messages_formatted: List[Dict[str, Any]] = ( + self._convert_to_anthropic_format(non_system_messages) if non_system_messages else [] + ) response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create( max_tokens=filtered_generation_kwargs.pop("max_tokens", 512), - system=system if system else filtered_generation_kwargs.pop("system", ""), + system=system_messages_formatted or filtered_generation_kwargs.pop("system", ""), model=self.model, - messages=anthropic_formatted_messages, + messages=messages_formatted, stream=self.streaming_callback is not None, **filtered_generation_kwargs, ) @@ -261,14 +261,15 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict anthropic_formatted_messages = [] for m in messages: message_dict = dataclasses.asdict(m) - filtered_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} + formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} if m.is_from(ChatRole.SYSTEM): # system messages need to be in the format expected by the Anthropic API - filtered_message.pop("role") - filtered_message["type"] = "text" - filtered_message["text"] = filtered_message.pop("content") - filtered_message.update(m.meta or {}) - anthropic_formatted_messages.append(filtered_message) + # remove role and content from the message dict, add type and text + formatted_message.pop("role") + formatted_message["type"] = "text" + formatted_message["text"] = formatted_message.pop("content") + formatted_message.update(m.meta or {}) + anthropic_formatted_messages.append(formatted_message) return anthropic_formatted_messages def _connect_chunks( From 8fba8f1301f33658825137a7ab5d569f300d9fb5 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 30 Aug 2024 17:29:27 +0200 Subject: [PATCH 05/12] PR feedback --- .../anthropic/example/prompt_caching.py | 56 ++++++++++++++----- .../anthropic/chat/chat_generator.py | 17 ++++++ 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/integrations/anthropic/example/prompt_caching.py b/integrations/anthropic/example/prompt_caching.py index 6ec71c520..60534bd5a 100644 --- a/integrations/anthropic/example/prompt_caching.py +++ b/integrations/anthropic/example/prompt_caching.py @@ -1,15 +1,18 @@ # To run this example, you will need to set a `ANTHROPIC_API_KEY` environment variable. 
+import time + from haystack import Pipeline from haystack.components.builders import ChatPromptBuilder from haystack.components.converters import HTMLToDocument from haystack.components.fetchers import LinkContentFetcher -from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage +from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.utils import Secret from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator +enable_prompt_caching = True + msg = ChatMessage.from_system( "You are a prompt expert who answers questions based on the given documents.\n" "Here are the documents:\n" @@ -18,6 +21,19 @@ "{% endfor %}" ) + +def measure_and_print_streaming_chunk(): + first_token_time = None + + def stream_callback(chunk: StreamingChunk) -> None: + nonlocal first_token_time + if first_token_time is None: + first_token_time = time.time() + print(chunk.content, flush=True, end="") + + return stream_callback, lambda: first_token_time + + fetch_pipeline = Pipeline() fetch_pipeline.add_component("fetcher", LinkContentFetcher()) fetch_pipeline.add_component("converter", HTMLToDocument()) @@ -36,28 +52,32 @@ final_prompt_msg = result["prompt_builder"]["prompt"][0] # We add a cache control header to the prompt message -final_prompt_msg.meta["cache_control"] = {"type": "ephemeral"} +if enable_prompt_caching: + final_prompt_msg.meta["cache_control"] = {"type": "ephemeral"} +generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if enable_prompt_caching else {} +claude_llm = AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), + generation_kwargs=generation_kwargs, +) # Build QA pipeline qa_pipeline = Pipeline() -qa_pipeline.add_component( - "llm", - AnthropicChatGenerator( - api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), - streaming_callback=print_streaming_chunk, - generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}}, - ), -) +qa_pipeline.add_component("llm", claude_llm) questions = [ - "Why is Monte-Carlo Tree Search used in LATS", - "Summarize LATS selection, expansion, evaluation, simulation, backpropagation, and reflection", + "What's this paper about?", + "What's the main contribution of this paper?", + "How can findings from this paper be applied to real-world problems?", ] # Answer the questions using prompt caching (i.e. 
the entire document is cached, we run the question against it) for question in questions: print("Question: " + question) + start_time = time.time() + streaming_callback, get_first_token_time = measure_and_print_streaming_chunk() + claude_llm.streaming_callback = streaming_callback + result = qa_pipeline.run( data={ "llm": { @@ -69,5 +89,11 @@ } ) - print("\n\nChecking cache usage:", result["llm"]["replies"][0].meta.get("usage")) - print("\n") + end_time = time.time() + total_time = end_time - start_time + time_to_first_token = get_first_token_time() - start_time + + print(f"\nTotal time: {total_time:.2f} seconds") + print(f"Time to first token: {time_to_first_token:.2f} seconds") + print(f"Cache usage: {result['llm']['replies'][0].meta.get('usage')}") + print("\n" + "=" * 50) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 8390405c1..82c47a385 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -188,6 +188,23 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, self._convert_to_anthropic_format(non_system_messages) if non_system_messages else [] ) + extra_headers = filtered_generation_kwargs.get("extra_headers", {}) + prompt_caching_on = "anthropic-beta" in extra_headers and "prompt-caching" in extra_headers["anthropic-beta"] + has_cached_messages = any("cache_control" in m for m in system_messages_formatted) or any( + "cache_control" in m for m in messages_formatted + ) + if has_cached_messages and not prompt_caching_on: + # this avoids Anthropic errors when prompt caching is not enabled + # but user requested individual messages to be cached + logger.warn( + "Prompt caching is not enabled but you requested individual messages to be cached. " + "Messages will be sent to the API without prompt caching." + ) + for m in system_messages_formatted: + m.pop("cache_control", None) + for m in messages_formatted: + m.pop("cache_control", None) + response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create( max_tokens=filtered_generation_kwargs.pop("max_tokens", 512), system=system_messages_formatted or filtered_generation_kwargs.pop("system", ""), From 3f9f0ae93960ab11791b20daf5ef4be5b886e646 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 6 Sep 2024 10:49:53 +0200 Subject: [PATCH 06/12] Update prompt_caching.py example to use ChatPromptBuilder 2.5 fixes --- .../anthropic/example/prompt_caching.py | 90 +++++++++---------- .../anthropic/chat/chat_generator.py | 2 +- 2 files changed, 44 insertions(+), 48 deletions(-) diff --git a/integrations/anthropic/example/prompt_caching.py b/integrations/anthropic/example/prompt_caching.py index 60534bd5a..c7b073938 100644 --- a/integrations/anthropic/example/prompt_caching.py +++ b/integrations/anthropic/example/prompt_caching.py @@ -11,18 +11,18 @@ from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator -enable_prompt_caching = True +# Advanced: We can also cache the HTTP GET requests for the HTML content to avoid repeating requests +# that fetched the same content. 
+# This type of caching requires requests_cache library to be installed +# Uncomment the following two lines to caching the HTTP requests -msg = ChatMessage.from_system( - "You are a prompt expert who answers questions based on the given documents.\n" - "Here are the documents:\n" - "{% for d in documents %} \n" - " {{d.content}} \n" - "{% endfor %}" -) +# import requests_cache +# requests_cache.install_cache("anthropic_demo") + +ENABLE_PROMPT_CACHING = True # Toggle this to enable or disable prompt caching -def measure_and_print_streaming_chunk(): +def create_streaming_callback(): first_token_time = None def stream_callback(chunk: StreamingChunk) -> None: @@ -34,36 +34,32 @@ def stream_callback(chunk: StreamingChunk) -> None: return stream_callback, lambda: first_token_time -fetch_pipeline = Pipeline() -fetch_pipeline.add_component("fetcher", LinkContentFetcher()) -fetch_pipeline.add_component("converter", HTMLToDocument()) -fetch_pipeline.add_component("prompt_builder", ChatPromptBuilder(template=[msg], variables=["documents"])) +# Until prompt caching graduates from beta, we need to set the anthropic-beta header +generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if ENABLE_PROMPT_CACHING else {} -fetch_pipeline.connect("fetcher", "converter") -fetch_pipeline.connect("converter", "prompt_builder") - -result = fetch_pipeline.run( - data={ - "fetcher": {"urls": ["https://ar5iv.labs.arxiv.org/html/2310.04406"]}, - } +claude_llm = AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), generation_kwargs=generation_kwargs ) -# Now we have our document fetched as a ChatMessage -final_prompt_msg = result["prompt_builder"]["prompt"][0] - -# We add a cache control header to the prompt message -if enable_prompt_caching: - final_prompt_msg.meta["cache_control"] = {"type": "ephemeral"} - -generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if enable_prompt_caching else {} -claude_llm = AnthropicChatGenerator( - api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), - generation_kwargs=generation_kwargs, +pipe = Pipeline() +pipe.add_component("fetcher", LinkContentFetcher()) +pipe.add_component("converter", HTMLToDocument()) +pipe.add_component("prompt_builder", ChatPromptBuilder(variables=["documents"])) +pipe.add_component("llm", claude_llm) +pipe.connect("fetcher", "converter") +pipe.connect("converter", "prompt_builder") +pipe.connect("prompt_builder.prompt", "llm.messages") + +system_message = ChatMessage.from_system( + "Claude is an AI assistant that answers questions based on the given documents.\n" + "Here are the documents:\n" + "{% for d in documents %} \n" + " {{d.content}} \n" + "{% endfor %}" ) -# Build QA pipeline -qa_pipeline = Pipeline() -qa_pipeline.add_component("llm", claude_llm) +if ENABLE_PROMPT_CACHING: + system_message.meta["cache_control"] = {"type": "ephemeral"} questions = [ "What's this paper about?", @@ -71,29 +67,29 @@ def stream_callback(chunk: StreamingChunk) -> None: "How can findings from this paper be applied to real-world problems?", ] -# Answer the questions using prompt caching (i.e. 
the entire document is cached, we run the question against it) for question in questions: - print("Question: " + question) + print(f"Question: {question}") start_time = time.time() - streaming_callback, get_first_token_time = measure_and_print_streaming_chunk() + streaming_callback, get_first_token_time = create_streaming_callback() + # reset LLM streaming callback to initialize new timers in streaming mode claude_llm.streaming_callback = streaming_callback - result = qa_pipeline.run( + result = pipe.run( data={ - "llm": { - "messages": [ - final_prompt_msg, - ChatMessage.from_user("Given these documents, answer the question:" + question), - ] - }, + "fetcher": {"urls": ["https://ar5iv.labs.arxiv.org/html/2310.04406"]}, + "prompt_builder": {"template": [system_message, ChatMessage.from_user(f"Answer the question: {question}")]}, } ) end_time = time.time() total_time = end_time - start_time time_to_first_token = get_first_token_time() - start_time - - print(f"\nTotal time: {total_time:.2f} seconds") + print("\n" + "-" * 100) + print(f"Total generation time: {total_time:.2f} seconds") print(f"Time to first token: {time_to_first_token:.2f} seconds") + # first time we create a prompt cache usage key 'cache_creation_input_tokens' will have a value of the number of + # tokens used to create the prompt cache + # on first subsequent cache hit we'll see a usage key 'cache_read_input_tokens' having a value of the number of + # tokens read from the cache print(f"Cache usage: {result['llm']['replies'][0].meta.get('usage')}") - print("\n" + "=" * 50) + print("\n" + "=" * 100) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index 82c47a385..ae6806709 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -280,7 +280,7 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict message_dict = dataclasses.asdict(m) formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v} if m.is_from(ChatRole.SYSTEM): - # system messages need to be in the format expected by the Anthropic API + # system messages are treated differently and MUST be in the format expected by the Anthropic API # remove role and content from the message dict, add type and text formatted_message.pop("role") formatted_message["type"] = "text" From 57f23233c1a5d484e611f6f45fe23142d698d9ef Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 6 Sep 2024 11:00:14 +0200 Subject: [PATCH 07/12] Small fixes --- .../generators/anthropic/chat/chat_generator.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py index ae6806709..43b50495c 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/chat_generator.py @@ -200,10 +200,8 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, "Prompt caching is not enabled but you requested 
individual messages to be cached. " "Messages will be sent to the API without prompt caching." ) - for m in system_messages_formatted: - m.pop("cache_control", None) - for m in messages_formatted: - m.pop("cache_control", None) + system_messages_formatted = list(map(self._remove_cache_control, system_messages_formatted)) + messages_formatted = list(map(self._remove_cache_control, messages_formatted)) response: Union[Message, Stream[MessageStreamEvent]] = self.client.messages.create( max_tokens=filtered_generation_kwargs.pop("max_tokens", 512), @@ -317,3 +315,11 @@ def _build_chunk(self, delta: TextDelta) -> StreamingChunk: :returns: The StreamingChunk. """ return StreamingChunk(content=delta.text) + + def _remove_cache_control(self, message: Dict[str, Any]) -> Dict[str, Any]: + """ + Removes the cache_control key from the message. + :param message: The message to remove the cache_control key from. + :returns: The message with the cache_control key removed. + """ + return {k: v for k, v in message.items() if k != "cache_control"} From b798a7e653577b55a69bf5c45785ecf58a237f71 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Fri, 6 Sep 2024 11:36:30 +0200 Subject: [PATCH 08/12] Add unit tests --- .../anthropic/tests/test_chat_generator.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 3ffa24c94..b9ea475f9 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -262,3 +262,92 @@ def test_tools_use(self): fc_response = json.loads(first_reply.content) assert "name" in fc_response, "First reply does not contain name of the tool" assert "input" in fc_response, "First reply does not contain input of the tool" + + def test_prompt_caching_enabled(self, monkeypatch): + """ + Test that the generation_kwargs extra_headers are correctly passed to the Anthropic API when prompt + caching is enabled + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator( + generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} + ) + assert component.generation_kwargs.get("extra_headers", {}).get("anthropic-beta") == "prompt-caching-2024-07-31" + + def test_prompt_caching_cache_control_without_extra_headers(self, monkeypatch, mock_chat_completion, caplog): + """ + Test that the cache_control is removed from the messages when prompt caching is not enabled via extra_headers + This is to avoid Anthropic errors when prompt caching is not enabled + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator() + + messages = [ChatMessage.from_system("System message"), ChatMessage.from_user("User message")] + + # Add cache_control to messages + for msg in messages: + msg.meta["cache_control"] = {"type": "ephemeral"} + + # Invoke run with messages + component.run(messages) + + # Check caplog for the warning message that should have been logged + assert any("Prompt caching" in record.message for record in caplog.records) + + # Check that the Anthropic API was called without cache_control in messages so that it does not raise an error + _, kwargs = mock_chat_completion.call_args + for msg in kwargs["messages"]: + assert "cache_control" not in msg + + @pytest.mark.parametrize("enable_caching", [True, False]) + def test_run_with_prompt_caching(self, monkeypatch, mock_chat_completion, enable_caching): + """ + Test 
that the generation_kwargs extra_headers are correctly passed to the Anthropic API in both cases of + prompt caching being enabled or not + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + + generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if enable_caching else {} + component = AnthropicChatGenerator(generation_kwargs=generation_kwargs) + + messages = [ChatMessage.from_system("System message"), ChatMessage.from_user("User message")] + + component.run(messages) + + # Check that the Anthropic API was called with the correct headers + _, kwargs = mock_chat_completion.call_args + headers = kwargs.get("extra_headers", {}) + if enable_caching: + assert "anthropic-beta" in headers + else: + assert "anthropic-beta" not in headers + + def test_to_dict_with_prompt_caching(self, monkeypatch): + """ + Test that the generation_kwargs extra_headers are correctly serialized to a dictionary + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + component = AnthropicChatGenerator( + generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} + ) + data = component.to_dict() + assert ( + data["init_parameters"]["generation_kwargs"]["extra_headers"]["anthropic-beta"] + == "prompt-caching-2024-07-31" + ) + + def test_from_dict_with_prompt_caching(self, monkeypatch): + """ + Test that the generation_kwargs extra_headers are correctly deserialized from a dictionary + """ + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + data = { + "type": "haystack_integrations.components.generators.anthropic.chat.chat_generator.AnthropicChatGenerator", + "init_parameters": { + "api_key": {"env_vars": ["ANTHROPIC_API_KEY"], "strict": True, "type": "env_var"}, + "model": "claude-3-5-sonnet-20240620", + "generation_kwargs": {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}}, + }, + } + component = AnthropicChatGenerator.from_dict(data) + assert component.generation_kwargs["extra_headers"]["anthropic-beta"] == "prompt-caching-2024-07-31" From 7c5e16b153d8672182fe91e02e71102c3e4b9198 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 17 Sep 2024 09:17:28 +0200 Subject: [PATCH 09/12] Improve UX for prompt caching example --- integrations/anthropic/example/prompt_caching.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/integrations/anthropic/example/prompt_caching.py b/integrations/anthropic/example/prompt_caching.py index c7b073938..d8cc0f0e8 100644 --- a/integrations/anthropic/example/prompt_caching.py +++ b/integrations/anthropic/example/prompt_caching.py @@ -91,5 +91,12 @@ def stream_callback(chunk: StreamingChunk) -> None: # tokens used to create the prompt cache # on first subsequent cache hit we'll see a usage key 'cache_read_input_tokens' having a value of the number of # tokens read from the cache - print(f"Cache usage: {result['llm']['replies'][0].meta.get('usage')}") + token_stats = result["llm"]["replies"][0].meta.get("usage") + if token_stats.get("cache_creation_input_tokens", 0) > 0: + print("Cache created! ", end="") + elif token_stats.get("cache_read_input_tokens", 0) > 0: + print("Cache hit! ", end="") + else: + print("Cache not used, something is wrong with the prompt caching setup. 
", end="") + print(f"Cache usage details: {token_stats}") print("\n" + "=" * 100) From befdbeba79d556845154ead057cbfaa17090c946 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 18 Sep 2024 13:35:04 +0200 Subject: [PATCH 10/12] Add unit test for _convert_to_anthropic_format --- .../anthropic/tests/test_chat_generator.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index b9ea475f9..c09402ac0 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -351,3 +351,51 @@ def test_from_dict_with_prompt_caching(self, monkeypatch): } component = AnthropicChatGenerator.from_dict(data) assert component.generation_kwargs["extra_headers"]["anthropic-beta"] == "prompt-caching-2024-07-31" + + @pytest.mark.unit + def test_convert_messages_to_anthropic_format(self, monkeypatch): + monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") + generator = AnthropicChatGenerator() + + # Test scenario 1: Regular user and assistant messages + messages = [ + ChatMessage.from_user("Hello"), + ChatMessage.from_assistant("Hi there!"), + ] + result = generator._convert_to_anthropic_format(messages) + assert result == [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + # Test scenario 2: System message + messages = [ChatMessage.from_system("You are a helpful assistant.")] + result = generator._convert_to_anthropic_format(messages) + assert result == [{"type": "text", "text": "You are a helpful assistant."}] + + # Test scenario 3: Mixed message types + messages = [ + ChatMessage.from_system("Be concise."), + ChatMessage.from_user("What's AI?"), + ChatMessage.from_assistant("Artificial Intelligence."), + ] + result = generator._convert_to_anthropic_format(messages) + assert result == [ + {"type": "text", "text": "Be concise."}, + {"role": "user", "content": "What's AI?"}, + {"role": "assistant", "content": "Artificial Intelligence."}, + ] + + # Test scenario 4: metadata + messages = [ + ChatMessage.from_user("What's AI?"), + ChatMessage.from_assistant("Artificial Intelligence.", meta={"confidence": 0.9}), + ] + result = generator._convert_to_anthropic_format(messages) + assert result == [ + {"role": "user", "content": "What's AI?"}, + {"role": "assistant", "content": "Artificial Intelligence.", "confidence": 0.9}, + ] + + # Test scenario 5: Empty message list + assert generator._convert_to_anthropic_format([]) == [] From c39a9f269c2b92f6c610f47013b3ce81da5eda1d Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 19 Sep 2024 09:24:14 +0200 Subject: [PATCH 11/12] More integration tests --- .../anthropic/tests/test_chat_generator.py | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index c09402ac0..6973e4f73 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -352,7 +352,6 @@ def test_from_dict_with_prompt_caching(self, monkeypatch): component = AnthropicChatGenerator.from_dict(data) assert component.generation_kwargs["extra_headers"]["anthropic-beta"] == "prompt-caching-2024-07-31" - @pytest.mark.unit def test_convert_messages_to_anthropic_format(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-api-key") generator = 
AnthropicChatGenerator() @@ -399,3 +398,28 @@ def test_convert_messages_to_anthropic_format(self, monkeypatch): # Test scenario 5: Empty message list assert generator._convert_to_anthropic_format([]) == [] + + @pytest.mark.integration + @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY", None), reason="ANTHROPIC_API_KEY not set") + def test_prompt_caching(self): + generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} + + claude_llm = AnthropicChatGenerator( + api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), generation_kwargs=generation_kwargs + ) + + # see https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations + system_message = ChatMessage.from_system("This is the cached, here we make it at least 1024 tokens long." * 70) + system_message.meta["cache_control"] = {"type": "ephemeral"} + + messages = [system_message, ChatMessage.from_user("What's in cached content?")] + result = claude_llm.run(messages) + + assert "replies" in result + assert len(result["replies"]) == 1 + token_usage = result["replies"][0].meta.get("usage") + + # either we created cache or we read it (depends on how you execute this integration test) + assert ( + token_usage.get("cache_creation_input_tokens") > 1024 or token_usage.get("cache_read_input_tokens") > 1024 + ) From 9ca99f5d90d018fb8b97e4bec82ce1573af4c601 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 19 Sep 2024 11:26:44 +0200 Subject: [PATCH 12/12] Update test to turn on/off prompt cache --- .../anthropic/tests/test_chat_generator.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/integrations/anthropic/tests/test_chat_generator.py b/integrations/anthropic/tests/test_chat_generator.py index 6973e4f73..155cf7950 100644 --- a/integrations/anthropic/tests/test_chat_generator.py +++ b/integrations/anthropic/tests/test_chat_generator.py @@ -401,8 +401,9 @@ def test_convert_messages_to_anthropic_format(self, monkeypatch): @pytest.mark.integration @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY", None), reason="ANTHROPIC_API_KEY not set") - def test_prompt_caching(self): - generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} + @pytest.mark.parametrize("cache_enabled", [True, False]) + def test_prompt_caching(self, cache_enabled): + generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if cache_enabled else {} claude_llm = AnthropicChatGenerator( api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), generation_kwargs=generation_kwargs @@ -410,7 +411,8 @@ def test_prompt_caching(self): # see https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations system_message = ChatMessage.from_system("This is the cached, here we make it at least 1024 tokens long." 
* 70) - system_message.meta["cache_control"] = {"type": "ephemeral"} + if cache_enabled: + system_message.meta["cache_control"] = {"type": "ephemeral"} messages = [system_message, ChatMessage.from_user("What's in cached content?")] result = claude_llm.run(messages) @@ -419,7 +421,12 @@ def test_prompt_caching(self): assert len(result["replies"]) == 1 token_usage = result["replies"][0].meta.get("usage") - # either we created cache or we read it (depends on how you execute this integration test) - assert ( - token_usage.get("cache_creation_input_tokens") > 1024 or token_usage.get("cache_read_input_tokens") > 1024 - ) + if cache_enabled: + # either we created cache or we read it (depends on how you execute this integration test) + assert ( + token_usage.get("cache_creation_input_tokens") > 1024 + or token_usage.get("cache_read_input_tokens") > 1024 + ) + else: + assert "cache_creation_input_tokens" not in token_usage + assert "cache_read_input_tokens" not in token_usage
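
Usage note (not part of the patches above): the following is a condensed sketch of the feature this series adds, distilled from the integrations/anthropic/example/prompt_caching.py example and the tests in the commits above. It assumes the AnthropicChatGenerator changes in this PR are applied, that an ANTHROPIC_API_KEY environment variable is set, and that prompt caching is still gated behind the "prompt-caching-2024-07-31" beta header; the placeholder system-prompt text is illustrative only.

    # Minimal prompt-caching sketch based on the example added in this PR.
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator

    # Opt into the prompt-caching beta via extra_headers, which this PR adds
    # to the generator's allowed generation kwargs.
    generator = AnthropicChatGenerator(
        api_key=Secret.from_env_var("ANTHROPIC_API_KEY"),
        generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}},
    )

    # Mark the (long, reusable) system prompt as cacheable; Anthropic only creates
    # a cache entry once the prompt is large enough (roughly 1024+ tokens).
    system_message = ChatMessage.from_system("Some long, reusable context. " * 300)
    system_message.meta["cache_control"] = {"type": "ephemeral"}

    result = generator.run(
        [system_message, ChatMessage.from_user("Answer based on the cached context.")]
    )

    # The usage metadata reports cache_creation_input_tokens on the first call and
    # cache_read_input_tokens on subsequent cache hits.
    print(result["replies"][0].meta.get("usage"))

If the cache_control metadata is set but the beta header is omitted, the generator added in this PR strips the cache_control entries and logs a warning instead of letting the Anthropic API reject the request.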