
Commit

Update prompt_caching.py example to use ChatPromptBuilder 2.5 fixes
vblagoje committed Sep 6, 2024
1 parent 8fba8f1 · commit 3f9f0ae
Showing 2 changed files with 44 additions and 48 deletions.
integrations/anthropic/example/prompt_caching.py (43 additions & 47 deletions)

@@ -11,18 +11,18 @@

 from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator
 
-enable_prompt_caching = True
-
-msg = ChatMessage.from_system(
-    "You are a prompt expert who answers questions based on the given documents.\n"
-    "Here are the documents:\n"
-    "{% for d in documents %} \n"
-    "  {{d.content}} \n"
-    "{% endfor %}"
-)
+# Advanced: we can also cache the HTTP GET requests for the HTML content to avoid
+# repeating requests that fetch the same content.
+# This type of caching requires the requests_cache library to be installed.
+# Uncomment the following two lines to cache the HTTP requests:
+
+# import requests_cache
+# requests_cache.install_cache("anthropic_demo")
+
+ENABLE_PROMPT_CACHING = True  # Toggle this to enable or disable prompt caching
 
 
-def measure_and_print_streaming_chunk():
+def create_streaming_callback():
     first_token_time = None
 
     def stream_callback(chunk: StreamingChunk) -> None:
@@ -34,66 +34,62 @@ def stream_callback(chunk: StreamingChunk) -> None:
     return stream_callback, lambda: first_token_time
 
 
-fetch_pipeline = Pipeline()
-fetch_pipeline.add_component("fetcher", LinkContentFetcher())
-fetch_pipeline.add_component("converter", HTMLToDocument())
-fetch_pipeline.add_component("prompt_builder", ChatPromptBuilder(template=[msg], variables=["documents"]))
-
-fetch_pipeline.connect("fetcher", "converter")
-fetch_pipeline.connect("converter", "prompt_builder")
-
-result = fetch_pipeline.run(
-    data={
-        "fetcher": {"urls": ["https://ar5iv.labs.arxiv.org/html/2310.04406"]},
-    }
-)
-
-# Now we have our document fetched as a ChatMessage
-final_prompt_msg = result["prompt_builder"]["prompt"][0]
-
-# We add a cache control header to the prompt message
-if enable_prompt_caching:
-    final_prompt_msg.meta["cache_control"] = {"type": "ephemeral"}
-
-generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if enable_prompt_caching else {}
-claude_llm = AnthropicChatGenerator(
-    api_key=Secret.from_env_var("ANTHROPIC_API_KEY"),
-    generation_kwargs=generation_kwargs,
-)
-
-# Build QA pipeline
-qa_pipeline = Pipeline()
-qa_pipeline.add_component("llm", claude_llm)
+# Until prompt caching graduates from beta, we need to set the anthropic-beta header
+generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if ENABLE_PROMPT_CACHING else {}
+
+claude_llm = AnthropicChatGenerator(
+    api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), generation_kwargs=generation_kwargs
+)
+
+pipe = Pipeline()
+pipe.add_component("fetcher", LinkContentFetcher())
+pipe.add_component("converter", HTMLToDocument())
+pipe.add_component("prompt_builder", ChatPromptBuilder(variables=["documents"]))
+pipe.add_component("llm", claude_llm)
+pipe.connect("fetcher", "converter")
+pipe.connect("converter", "prompt_builder")
+pipe.connect("prompt_builder.prompt", "llm.messages")
+
+system_message = ChatMessage.from_system(
+    "Claude is an AI assistant that answers questions based on the given documents.\n"
+    "Here are the documents:\n"
+    "{% for d in documents %} \n"
+    "  {{d.content}} \n"
+    "{% endfor %}"
+)
+
+if ENABLE_PROMPT_CACHING:
+    system_message.meta["cache_control"] = {"type": "ephemeral"}
 
 questions = [
     "What's this paper about?",
     "What's the main contribution of this paper?",
     "How can findings from this paper be applied to real-world problems?",
 ]
 
 # Answer the questions using prompt caching (i.e. the entire document is cached and each question is run against it)
 for question in questions:
-    print("Question: " + question)
+    print(f"Question: {question}")
     start_time = time.time()
-    streaming_callback, get_first_token_time = measure_and_print_streaming_chunk()
+    streaming_callback, get_first_token_time = create_streaming_callback()
     # reset the LLM streaming callback to initialize new timers in streaming mode
     claude_llm.streaming_callback = streaming_callback
 
-    result = qa_pipeline.run(
+    result = pipe.run(
         data={
-            "llm": {
-                "messages": [
-                    final_prompt_msg,
-                    ChatMessage.from_user("Given these documents, answer the question:" + question),
-                ]
-            },
+            "fetcher": {"urls": ["https://ar5iv.labs.arxiv.org/html/2310.04406"]},
+            "prompt_builder": {"template": [system_message, ChatMessage.from_user(f"Answer the question: {question}")]},
         }
     )
 
     end_time = time.time()
     total_time = end_time - start_time
     time_to_first_token = get_first_token_time() - start_time
 
-    print(f"\nTotal time: {total_time:.2f} seconds")
+    print("\n" + "-" * 100)
+    print(f"Total generation time: {total_time:.2f} seconds")
     print(f"Time to first token: {time_to_first_token:.2f} seconds")
     # The first time we create a prompt cache, the usage key 'cache_creation_input_tokens' reports the number of
     # tokens used to create the prompt cache.
     # On each subsequent cache hit, the usage key 'cache_read_input_tokens' reports the number of
     # tokens read from the cache.
     print(f"Cache usage: {result['llm']['replies'][0].meta.get('usage')}")
-    print("\n" + "=" * 50)
+    print("\n" + "=" * 100)
1 change: 1 addition & 1 deletion

@@ -280,7 +280,7 @@ def _convert_to_anthropic_format(self, messages: List[ChatMessage]) -> List[Dict
         message_dict = dataclasses.asdict(m)
         formatted_message = {k: v for k, v in message_dict.items() if k in {"role", "content"} and v}
         if m.is_from(ChatRole.SYSTEM):
-            # system messages need to be in the format expected by the Anthropic API
+            # system messages are treated differently and MUST be in the format expected by the Anthropic API
             # remove role and content from the message dict, add type and text
             formatted_message.pop("role")
             formatted_message["type"] = "text"
