Skip to content

Commit

Permalink
Anthropic: caching of system messages and tools (langchain4j#1826)
Browse files Browse the repository at this point in the history
## Issue
Closes langchain4j#1915

## Change
Implemented caching of system messages and tools for Anthropic.

## General checklist
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
- [x] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
Claudio-code authored Nov 8, 2024
1 parent a48bf2c commit 33ab87d
Show file tree
Hide file tree
Showing 22 changed files with 715 additions and 279 deletions.
17 changes: 17 additions & 0 deletions docs/docs/integrations/language-models/anthropic.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ AnthropicChatModel model = AnthropicChatModel.builder()
.topK(...)
.maxTokens(...)
.stopSequences(...)
.cacheSystemMessages(...)
.cacheTools(...)
.timeout(...)
.maxRetries(...)
.logRequests(...)
Expand Down Expand Up @@ -85,6 +87,21 @@ Anthropic supports [tools](/tutorials/tools) in both streaming and non-streaming

Anthropic documentation on tools can be found [here](https://docs.anthropic.com/claude/docs/tool-use).

## Caching

`AnthropicChatModel` and `AnthropicStreamingChatModel` support caching of system messages and tools.
Caching is disabled by default.
It can be enabled by setting the `cacheSystemMessages` and `cacheTools` parameters, respectively.

When enabled, `cache_control` blocks will be added to all system messages and tools, respectively.

To use caching, please set `beta("prompt-caching-2024-07-31")`.

`AnthropicChatModel` and `AnthropicStreamingChatModel` return `AnthropicTokenUsage` in response which
contains `cacheCreationInputTokens` and `cacheReadInputTokens`.

More info on caching can be found [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching).

## Quarkus

See more details [here](https://docs.quarkiverse.io/quarkus-langchain4j/dev/anthropic.html).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
Expand All @@ -29,6 +30,8 @@
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createErrorContext;
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerRequest;
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerResponse;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType.EPHEMERAL;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType.NO_CACHE;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*;
import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages;
import static java.util.Collections.emptyList;
Expand All @@ -47,12 +50,14 @@
* <br>
* <br>
* The content of {@link SystemMessage}s is sent using the "system" parameter.
* If there are multiple {@link SystemMessage}s, they are concatenated with a double newline (\n\n).
* <br>
* <br>
* Sanitization is performed on the {@link ChatMessage}s provided to conform to Anthropic API requirements. This process
* includes verifying that the first message is a {@link UserMessage} and removing any consecutive {@link UserMessage}s.
* Any messages removed during sanitization are logged as warnings and not submitted to the API.
* <br>
* <br>
* Supports caching {@link SystemMessage}s and {@link ToolSpecification}s.
*/
@Slf4j
public class AnthropicChatModel implements ChatLanguageModel {
Expand All @@ -64,26 +69,31 @@ public class AnthropicChatModel implements ChatLanguageModel {
private final Integer topK;
private final int maxTokens;
private final List<String> stopSequences;
private final boolean cacheSystemMessages;
private final boolean cacheTools;
private final int maxRetries;
private final List<ChatModelListener> listeners;

/**
* Constructs an instance of an {@code AnthropicChatModel} with the specified parameters.
*
* @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/"
* @param apiKey The API key for authentication with the Anthropic API.
* @param version The version of the Anthropic API. Default: "2023-06-01"
* @param beta The value of the "anthropic-beta" HTTP header. It is used when tools are present in the request. Default: "tools-2024-04-04"
* @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307"
* @param temperature The temperature
* @param topP The top-P
* @param topK The top-K
* @param maxTokens The maximum number of tokens to generate. Default: 1024
* @param stopSequences The custom text sequences that will cause the model to stop generating
* @param timeout The timeout for API requests. Default: 60 seconds
* @param maxRetries The maximum number of retries for API requests. Default: 3
* @param logRequests Whether to log the content of API requests using SLF4J. Default: false
* @param logResponses Whether to log the content of API responses using SLF4J. Default: false
* @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/"
* @param apiKey The API key for authentication with the Anthropic API.
* @param version The value of the "anthropic-version" HTTP header. Default: "2023-06-01"
* @param beta The value of the "anthropic-beta" HTTP header.
* @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307"
* @param temperature The temperature
* @param topP The top-P
* @param topK The top-K
* @param maxTokens The maximum number of tokens to generate. Default: 1024
* @param stopSequences The custom text sequences that will cause the model to stop generating
* @param cacheSystemMessages If true, it will add cache_control block to all system messages. Default: false
* @param cacheTools If true, it will add cache_control block to all tools. Default: false
* @param timeout The timeout for API requests. Default: 60 seconds
* @param maxRetries The maximum number of retries for API requests. Default: 3
* @param logRequests Whether to log the content of API requests using SLF4J. Default: false
* @param logResponses Whether to log the content of API responses using SLF4J. Default: false
* @param listeners A list of {@link ChatModelListener} instances to be notified.
*/
@Builder
private AnthropicChatModel(String baseUrl,
Expand All @@ -96,6 +106,8 @@ private AnthropicChatModel(String baseUrl,
Integer topK,
Integer maxTokens,
List<String> stopSequences,
Boolean cacheSystemMessages,
Boolean cacheTools,
Duration timeout,
Integer maxRetries,
Boolean logRequests,
Expand All @@ -105,7 +117,7 @@ private AnthropicChatModel(String baseUrl,
.baseUrl(getOrDefault(baseUrl, "https://api.anthropic.com/v1/"))
.apiKey(apiKey)
.version(getOrDefault(version, "2023-06-01"))
.beta(getOrDefault(beta, "tools-2024-04-04"))
.beta(beta)
.timeout(getOrDefault(timeout, Duration.ofSeconds(60)))
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
Expand All @@ -116,6 +128,8 @@ private AnthropicChatModel(String baseUrl,
this.topK = topK;
this.maxTokens = getOrDefault(maxTokens, 1024);
this.stopSequences = stopSequences;
this.cacheSystemMessages = getOrDefault(cacheSystemMessages, false);
this.cacheTools = getOrDefault(cacheTools, false);
this.maxRetries = getOrDefault(maxRetries, 3);
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
}
Expand Down Expand Up @@ -150,8 +164,9 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {

@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {

List<ChatMessage> sanitizedMessages = sanitizeMessages(messages);
String systemPrompt = toAnthropicSystemPrompt(messages);
List<AnthropicTextContent> systemPrompt = toAnthropicSystemPrompt(messages, cacheSystemMessages ? EPHEMERAL : NO_CACHE);

AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder()
.model(modelName)
Expand All @@ -163,7 +178,7 @@ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecifi
.temperature(temperature)
.topP(topP)
.topK(topK)
.tools(toAnthropicTools(toolSpecifications))
.tools(toAnthropicTools(toolSpecifications, cacheTools ? EPHEMERAL : NO_CACHE))
.build();

ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType;
import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest;
import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent;
import dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
Expand All @@ -33,6 +35,8 @@
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307;
import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerRequest;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType.EPHEMERAL;
import static dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType.NO_CACHE;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicMessages;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicSystemPrompt;
import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.toAnthropicTools;
Expand All @@ -53,15 +57,14 @@
* <br>
* <br>
* The content of {@link SystemMessage}s is sent using the "system" parameter.
* If there are multiple {@link SystemMessage}s, they are concatenated with a double newline (\n\n).
* <br>
* <br>
* Sanitization is performed on the {@link ChatMessage}s provided to ensure conformity with Anthropic API requirements.
* This includes ensuring the first message is a {@link UserMessage} and that there are no consecutive {@link UserMessage}s.
* Any messages removed during sanitization are logged as warnings and not submitted to the API.
* <br>
* <br>
* Does not support tools.
* Supports caching {@link SystemMessage}s and {@link ToolSpecification}s.
*/
@Slf4j
public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
Expand All @@ -73,34 +76,43 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel {
private final Integer topK;
private final int maxTokens;
private final List<String> stopSequences;
private final boolean cacheSystemMessages;
private final boolean cacheTools;
private final List<ChatModelListener> listeners;

/**
* Constructs an instance of an {@code AnthropicStreamingChatModel} with the specified parameters.
*
* @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/"
* @param apiKey The API key for authentication with the Anthropic API.
* @param version The version of the Anthropic API. Default: "2023-06-01"
* @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307"
* @param temperature The temperature
* @param topP The top-P
* @param topK The top-K
* @param maxTokens The maximum number of tokens to generate. Default: 1024
* @param stopSequences The custom text sequences that will cause the model to stop generating
* @param timeout The timeout for API requests. Default: 60 seconds
* @param logRequests Whether to log the content of API requests using SLF4J. Default: false
* @param logResponses Whether to log the content of API responses using SLF4J. Default: false
* @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/"
* @param apiKey The API key for authentication with the Anthropic API.
* @param version The value of the "anthropic-version" HTTP header. Default: "2023-06-01"
* @param beta The value of the "anthropic-beta" HTTP header.
* @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307"
* @param temperature The temperature
* @param topP The top-P
* @param topK The top-K
* @param maxTokens The maximum number of tokens to generate. Default: 1024
* @param stopSequences The custom text sequences that will cause the model to stop generating
* @param cacheSystemMessages If true, it will add cache_control block to all system messages. Default: false
* @param cacheTools If true, it will add cache_control block to all tools. Default: false
* @param timeout The timeout for API requests. Default: 60 seconds
* @param logRequests Whether to log the content of API requests using SLF4J. Default: false
* @param logResponses Whether to log the content of API responses using SLF4J. Default: false
* @param listeners A list of {@link ChatModelListener} instances to be notified.
*/
@Builder
private AnthropicStreamingChatModel(String baseUrl,
String apiKey,
String version,
String beta,
String modelName,
Double temperature,
Double topP,
Integer topK,
Integer maxTokens,
List<String> stopSequences,
Boolean cacheSystemMessages,
Boolean cacheTools,
Duration timeout,
Boolean logRequests,
Boolean logResponses,
Expand All @@ -109,6 +121,7 @@ private AnthropicStreamingChatModel(String baseUrl,
.baseUrl(getOrDefault(baseUrl, "https://api.anthropic.com/v1/"))
.apiKey(apiKey)
.version(getOrDefault(version, "2023-06-01"))
.beta(beta)
.timeout(getOrDefault(timeout, Duration.ofSeconds(60)))
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
Expand All @@ -119,6 +132,8 @@ private AnthropicStreamingChatModel(String baseUrl,
this.topK = topK;
this.maxTokens = getOrDefault(maxTokens, 1024);
this.stopSequences = stopSequences;
this.cacheSystemMessages = getOrDefault(cacheSystemMessages, false);
this.cacheTools = getOrDefault(cacheTools, false);
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
}

Expand Down Expand Up @@ -164,8 +179,9 @@ private void generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted,
StreamingResponseHandler<AiMessage> handler) {

List<ChatMessage> sanitizedMessages = sanitizeMessages(messages);
String systemPrompt = toAnthropicSystemPrompt(messages);
List<AnthropicTextContent> systemPrompt = toAnthropicSystemPrompt(messages, cacheSystemMessages ? EPHEMERAL : NO_CACHE);
ensureNotNull(handler, "handler");

AnthropicCreateMessageRequest.AnthropicCreateMessageRequestBuilder requestBuilder = AnthropicCreateMessageRequest.builder()
Expand All @@ -179,11 +195,12 @@ private void generate(List<ChatMessage> messages,
.topP(topP)
.topK(topK);

AnthropicCacheType toolsCacheType = cacheTools ? EPHEMERAL : NO_CACHE;
if (toolThatMustBeExecuted != null) {
requestBuilder.tools(toAnthropicTools(singletonList(toolThatMustBeExecuted)));
requestBuilder.tools(toAnthropicTools(singletonList(toolThatMustBeExecuted), toolsCacheType));
requestBuilder.toolChoice(AnthropicToolChoice.from(toolThatMustBeExecuted.name()));
} else if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toAnthropicTools(toolSpecifications));
requestBuilder.tools(toAnthropicTools(toolSpecifications, toolsCacheType));
}

AnthropicCreateMessageRequest request = requestBuilder.build();
Expand All @@ -199,7 +216,7 @@ private void generate(List<ChatMessage> messages,
}
});

StreamingResponseHandler<AiMessage> listenerHandler = new StreamingResponseHandler<AiMessage>() {
StreamingResponseHandler<AiMessage> listenerHandler = new StreamingResponseHandler<>() {
@Override
public void onNext(String token) {
handler.onNext(token);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package dev.langchain4j.model.anthropic;

import dev.langchain4j.model.output.TokenUsage;

import java.util.Objects;

public class AnthropicTokenUsage extends TokenUsage {

private final Integer cacheCreationInputTokens;
private final Integer cacheReadInputTokens;

/**
* Creates a new {@link AnthropicTokenUsage} instance with the given input, output token counts
* and cache creation/read input tokens.
*
* @param inputTokenCount The input token count, or null if unknown.
* @param outputTokenCount The output token count, or null if unknown.
* @param cacheCreationInputTokens The total cached token created count, or null if unknown.
* @param cacheReadInputTokens The total cached token read count, or null if unknown.
*/
public AnthropicTokenUsage(Integer inputTokenCount,
Integer outputTokenCount,
Integer cacheCreationInputTokens,
Integer cacheReadInputTokens) {
super(inputTokenCount, outputTokenCount);
this.cacheCreationInputTokens = cacheCreationInputTokens;
this.cacheReadInputTokens = cacheReadInputTokens;
}

/**
* Returns The total cached token created count, or null if unknown.
*
* @return The total cached token created count, or null if unknown.
*/
public Integer cacheCreationInputTokens() {
return cacheCreationInputTokens;
}

/**
* Returns The total cached token read count, or null if unknown.
*
* @return The total cached token read count, or null if unknown.
*/
public Integer cacheReadInputTokens() {
return cacheReadInputTokens;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ Call<AnthropicCreateMessageResponse> createMessage(@Header(X_API_KEY) String api
@Headers({"content-type: application/json"})
Call<ResponseBody> streamMessage(@Header(X_API_KEY) String apiKey,
@Header("anthropic-version") String version,
@Header("anthropic-beta") String beta,
@Body AnthropicCreateMessageRequest request);
}
Loading

0 comments on commit 33ab87d

Please sign in to comment.