Update test to turn on/off prompt cache
vblagoje committed Sep 19, 2024
1 parent c39a9f2 · commit 9ca99f5
Showing 1 changed file with 14 additions and 7 deletions.
integrations/anthropic/tests/test_chat_generator.py (21 changes: 14 additions & 7 deletions)
@@ -401,16 +401,18 @@ def test_convert_messages_to_anthropic_format(self, monkeypatch):
 
     @pytest.mark.integration
     @pytest.mark.skipif(not os.environ.get("ANTHROPIC_API_KEY", None), reason="ANTHROPIC_API_KEY not set")
-    def test_prompt_caching(self):
-        generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}}
+    @pytest.mark.parametrize("cache_enabled", [True, False])
+    def test_prompt_caching(self, cache_enabled):
+        generation_kwargs = {"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}} if cache_enabled else {}
 
         claude_llm = AnthropicChatGenerator(
             api_key=Secret.from_env_var("ANTHROPIC_API_KEY"), generation_kwargs=generation_kwargs
         )
 
         # see https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations
         system_message = ChatMessage.from_system("This is the cached, here we make it at least 1024 tokens long." * 70)
-        system_message.meta["cache_control"] = {"type": "ephemeral"}
+        if cache_enabled:
+            system_message.meta["cache_control"] = {"type": "ephemeral"}
 
         messages = [system_message, ChatMessage.from_user("What's in cached content?")]
         result = claude_llm.run(messages)
@@ -419,7 +421,12 @@ def test_prompt_caching(self):
         assert len(result["replies"]) == 1
         token_usage = result["replies"][0].meta.get("usage")
 
-        # either we created cache or we read it (depends on how you execute this integration test)
-        assert (
-            token_usage.get("cache_creation_input_tokens") > 1024 or token_usage.get("cache_read_input_tokens") > 1024
-        )
+        if cache_enabled:
+            # either we created cache or we read it (depends on how you execute this integration test)
+            assert (
+                token_usage.get("cache_creation_input_tokens") > 1024
+                or token_usage.get("cache_read_input_tokens") > 1024
+            )
+        else:
+            assert "cache_creation_input_tokens" not in token_usage
+            assert "cache_read_input_tokens" not in token_usage
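
For reference, below is a minimal standalone sketch of the behaviour the cache_enabled=True branch exercises: the prompt-caching beta is opted into via an extra request header, and a long system prompt is marked as cacheable. This sketch is not part of the commit; the import paths are assumptions based on the Haystack 2.x Anthropic integration layout, and it requires ANTHROPIC_API_KEY to be set.

# Minimal sketch, not part of this commit. Import paths are assumed from the
# Haystack 2.x Anthropic integration layout; ANTHROPIC_API_KEY must be set.
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator

# Opt in to the prompt-caching beta via the extra header (the cache_enabled=True case).
generator = AnthropicChatGenerator(
    api_key=Secret.from_env_var("ANTHROPIC_API_KEY"),
    generation_kwargs={"extra_headers": {"anthropic-beta": "prompt-caching-2024-07-31"}},
)

# A system prompt above the ~1024-token cache minimum, marked as cacheable.
system_message = ChatMessage.from_system("Long, stable instructions for the assistant. " * 200)
system_message.meta["cache_control"] = {"type": "ephemeral"}

messages = [system_message, ChatMessage.from_user("What's in the cached content?")]

# The first call is expected to write the cache; a repeat call with the same
# prefix (within the cache TTL) should read from it instead.
first = generator.run(messages)
second = generator.run(messages)

print(first["replies"][0].meta.get("usage"))   # expect cache_creation_input_tokens > 0
print(second["replies"][0].meta.get("usage"))  # expect cache_read_input_tokens > 0

Which of the two usage counters is populated on any given run depends on whether the cache entry already exists, which is why the test only asserts that one of them exceeds the minimum size when caching is enabled.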
