diff --git a/docs/docs/integrations/language-models/anthropic.md b/docs/docs/integrations/language-models/anthropic.md index 1648e957a49..8f9386f61de 100644 --- a/docs/docs/integrations/language-models/anthropic.md +++ b/docs/docs/integrations/language-models/anthropic.md @@ -41,6 +41,8 @@ AnthropicChatModel model = AnthropicChatModel.builder() .topK(...) .maxTokens(...) .stopSequences(...) + .cacheSystemMessages(...) + .cacheTools(...) .timeout(...) .maxRetries(...) .logRequests(...) @@ -85,6 +87,21 @@ Anthropic supports [tools](/tutorials/tools) in both streaming and non-streaming Anthropic documentation on tools can be found [here](https://docs.anthropic.com/claude/docs/tool-use). +## Caching + +`AnthropicChatModel` and `AnthropicStreamingChatModel` support caching of system messages and tools. +Caching is disabled by default. +It can be enabled by setting the `cacheSystemMessages` and `cacheTools` parameters, respectively. + +When enabled, `cache_control` blocks will be added to all system messages and tools, respectively. + +To use caching, please set `beta("prompt-caching-2024-07-31")`. + +`AnthropicChatModel` and `AnthropicStreamingChatModel` return `AnthropicTokenUsage` in the response, which +contains `cacheCreationInputTokens` and `cacheReadInputTokens`. + +More info on caching can be found [here](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching). + ## Quarkus See more details [here](https://docs.quarkiverse.io/quarkus-langchain4j/dev/anthropic.html). 
diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java index 547b2da07ac..02b16386186 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicChatModel.java @@ -5,6 +5,7 @@ import dev.langchain4j.data.message.*; import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest; import dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse; +import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent; import dev.langchain4j.model.anthropic.internal.client.AnthropicClient; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.listener.ChatModelErrorContext; @@ -29,6 +30,8 @@ import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createErrorContext; import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerRequest; import static dev.langchain4j.model.anthropic.InternalAnthropicHelper.createModelListenerResponse; +import static dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType.EPHEMERAL; +import static dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType.NO_CACHE; import static dev.langchain4j.model.anthropic.internal.mapper.AnthropicMapper.*; import static dev.langchain4j.model.anthropic.internal.sanitizer.MessageSanitizer.sanitizeMessages; import static java.util.Collections.emptyList; @@ -47,12 +50,14 @@ *
*
* The content of {@link SystemMessage}s is sent using the "system" parameter. - * If there are multiple {@link SystemMessage}s, they are concatenated with a double newline (\n\n). *
*
* Sanitization is performed on the {@link ChatMessage}s provided to conform to Anthropic API requirements. This process * includes verifying that the first message is a {@link UserMessage} and removing any consecutive {@link UserMessage}s. * Any messages removed during sanitization are logged as warnings and not submitted to the API. + *
+ *
+ * Supports caching {@link SystemMessage}s and {@link ToolSpecification}s. */ @Slf4j public class AnthropicChatModel implements ChatLanguageModel { @@ -64,26 +69,31 @@ public class AnthropicChatModel implements ChatLanguageModel { private final Integer topK; private final int maxTokens; private final List stopSequences; + private final boolean cacheSystemMessages; + private final boolean cacheTools; private final int maxRetries; private final List listeners; /** * Constructs an instance of an {@code AnthropicChatModel} with the specified parameters. * - * @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/" - * @param apiKey The API key for authentication with the Anthropic API. - * @param version The version of the Anthropic API. Default: "2023-06-01" - * @param beta The value of the "anthropic-beta" HTTP header. It is used when tools are present in the request. Default: "tools-2024-04-04" - * @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307" - * @param temperature The temperature - * @param topP The top-P - * @param topK The top-K - * @param maxTokens The maximum number of tokens to generate. Default: 1024 - * @param stopSequences The custom text sequences that will cause the model to stop generating - * @param timeout The timeout for API requests. Default: 60 seconds - * @param maxRetries The maximum number of retries for API requests. Default: 3 - * @param logRequests Whether to log the content of API requests using SLF4J. Default: false - * @param logResponses Whether to log the content of API responses using SLF4J. Default: false + * @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/" + * @param apiKey The API key for authentication with the Anthropic API. + * @param version The value of the "anthropic-version" HTTP header. Default: "2023-06-01" + * @param beta The value of the "anthropic-beta" HTTP header. 
+ * @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307" + * @param temperature The temperature + * @param topP The top-P + * @param topK The top-K + * @param maxTokens The maximum number of tokens to generate. Default: 1024 + * @param stopSequences The custom text sequences that will cause the model to stop generating + * @param cacheSystemMessages If true, it will add cache_control block to all system messages. Default: false + * @param cacheTools If true, it will add cache_control block to all tools. Default: false + * @param timeout The timeout for API requests. Default: 60 seconds + * @param maxRetries The maximum number of retries for API requests. Default: 3 + * @param logRequests Whether to log the content of API requests using SLF4J. Default: false + * @param logResponses Whether to log the content of API responses using SLF4J. Default: false + * @param listeners A list of {@link ChatModelListener} instances to be notified. */ @Builder private AnthropicChatModel(String baseUrl, @@ -96,6 +106,8 @@ private AnthropicChatModel(String baseUrl, Integer topK, Integer maxTokens, List stopSequences, + Boolean cacheSystemMessages, + Boolean cacheTools, Duration timeout, Integer maxRetries, Boolean logRequests, @@ -105,7 +117,7 @@ private AnthropicChatModel(String baseUrl, .baseUrl(getOrDefault(baseUrl, "https://api.anthropic.com/v1/")) .apiKey(apiKey) .version(getOrDefault(version, "2023-06-01")) - .beta(getOrDefault(beta, "tools-2024-04-04")) + .beta(beta) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .logRequests(getOrDefault(logRequests, false)) .logResponses(getOrDefault(logResponses, false)) @@ -116,6 +128,8 @@ private AnthropicChatModel(String baseUrl, this.topK = topK; this.maxTokens = getOrDefault(maxTokens, 1024); this.stopSequences = stopSequences; + this.cacheSystemMessages = getOrDefault(cacheSystemMessages, false); + this.cacheTools = getOrDefault(cacheTools, false); this.maxRetries = 
getOrDefault(maxRetries, 3); this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners); } @@ -150,8 +164,9 @@ public Response generate(List messages) { @Override public Response generate(List messages, List toolSpecifications) { + List sanitizedMessages = sanitizeMessages(messages); - String systemPrompt = toAnthropicSystemPrompt(messages); + List systemPrompt = toAnthropicSystemPrompt(messages, cacheSystemMessages ? EPHEMERAL : NO_CACHE); AnthropicCreateMessageRequest request = AnthropicCreateMessageRequest.builder() .model(modelName) @@ -163,7 +178,7 @@ public Response generate(List messages, List *
* The content of {@link SystemMessage}s is sent using the "system" parameter. - * If there are multiple {@link SystemMessage}s, they are concatenated with a double newline (\n\n). *
*
* Sanitization is performed on the {@link ChatMessage}s provided to ensure conformity with Anthropic API requirements. @@ -61,7 +64,7 @@ * Any messages removed during sanitization are logged as warnings and not submitted to the API. *
*
- * Does not support tools. + * Supports caching {@link SystemMessage}s and {@link ToolSpecification}s. */ @Slf4j public class AnthropicStreamingChatModel implements StreamingChatLanguageModel { @@ -73,34 +76,43 @@ public class AnthropicStreamingChatModel implements StreamingChatLanguageModel { private final Integer topK; private final int maxTokens; private final List stopSequences; + private final boolean cacheSystemMessages; + private final boolean cacheTools; private final List listeners; /** * Constructs an instance of an {@code AnthropicStreamingChatModel} with the specified parameters. * - * @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/" - * @param apiKey The API key for authentication with the Anthropic API. - * @param version The version of the Anthropic API. Default: "2023-06-01" - * @param modelName The name of the Anthropic model to use. Default: "claude-3-haiku-20240307" - * @param temperature The temperature - * @param topP The top-P - * @param topK The top-K - * @param maxTokens The maximum number of tokens to generate. Default: 1024 - * @param stopSequences The custom text sequences that will cause the model to stop generating - * @param timeout The timeout for API requests. Default: 60 seconds - * @param logRequests Whether to log the content of API requests using SLF4J. Default: false - * @param logResponses Whether to log the content of API responses using SLF4J. Default: false + * @param baseUrl The base URL of the Anthropic API. Default: "https://api.anthropic.com/v1/" + * @param apiKey The API key for authentication with the Anthropic API. + * @param version The value of the "anthropic-version" HTTP header. Default: "2023-06-01" + * @param beta The value of the "anthropic-beta" HTTP header. + * @param modelName The name of the Anthropic model to use. 
Default: "claude-3-haiku-20240307" + * @param temperature The temperature + * @param topP The top-P + * @param topK The top-K + * @param maxTokens The maximum number of tokens to generate. Default: 1024 + * @param stopSequences The custom text sequences that will cause the model to stop generating + * @param cacheSystemMessages If true, it will add cache_control block to all system messages. Default: false + * @param cacheTools If true, it will add cache_control block to all tools. Default: false + * @param timeout The timeout for API requests. Default: 60 seconds + * @param logRequests Whether to log the content of API requests using SLF4J. Default: false + * @param logResponses Whether to log the content of API responses using SLF4J. Default: false + * @param listeners A list of {@link ChatModelListener} instances to be notified. */ @Builder private AnthropicStreamingChatModel(String baseUrl, String apiKey, String version, + String beta, String modelName, Double temperature, Double topP, Integer topK, Integer maxTokens, List stopSequences, + Boolean cacheSystemMessages, + Boolean cacheTools, Duration timeout, Boolean logRequests, Boolean logResponses, @@ -109,6 +121,7 @@ private AnthropicStreamingChatModel(String baseUrl, .baseUrl(getOrDefault(baseUrl, "https://api.anthropic.com/v1/")) .apiKey(apiKey) .version(getOrDefault(version, "2023-06-01")) + .beta(beta) .timeout(getOrDefault(timeout, Duration.ofSeconds(60))) .logRequests(getOrDefault(logRequests, false)) .logResponses(getOrDefault(logResponses, false)) @@ -119,6 +132,8 @@ private AnthropicStreamingChatModel(String baseUrl, this.topK = topK; this.maxTokens = getOrDefault(maxTokens, 1024); this.stopSequences = stopSequences; + this.cacheSystemMessages = getOrDefault(cacheSystemMessages, false); + this.cacheTools = getOrDefault(cacheTools, false); this.listeners = listeners == null ? 
emptyList() : new ArrayList<>(listeners); } @@ -164,8 +179,9 @@ private void generate(List messages, List toolSpecifications, ToolSpecification toolThatMustBeExecuted, StreamingResponseHandler handler) { + List sanitizedMessages = sanitizeMessages(messages); - String systemPrompt = toAnthropicSystemPrompt(messages); + List systemPrompt = toAnthropicSystemPrompt(messages, cacheSystemMessages ? EPHEMERAL : NO_CACHE); ensureNotNull(handler, "handler"); AnthropicCreateMessageRequest.AnthropicCreateMessageRequestBuilder requestBuilder = AnthropicCreateMessageRequest.builder() @@ -179,11 +195,12 @@ private void generate(List messages, .topP(topP) .topK(topK); + AnthropicCacheType toolsCacheType = cacheTools ? EPHEMERAL : NO_CACHE; if (toolThatMustBeExecuted != null) { - requestBuilder.tools(toAnthropicTools(singletonList(toolThatMustBeExecuted))); + requestBuilder.tools(toAnthropicTools(singletonList(toolThatMustBeExecuted), toolsCacheType)); requestBuilder.toolChoice(AnthropicToolChoice.from(toolThatMustBeExecuted.name())); } else if (!isNullOrEmpty(toolSpecifications)) { - requestBuilder.tools(toAnthropicTools(toolSpecifications)); + requestBuilder.tools(toAnthropicTools(toolSpecifications, toolsCacheType)); } AnthropicCreateMessageRequest request = requestBuilder.build(); @@ -199,7 +216,7 @@ private void generate(List messages, } }); - StreamingResponseHandler listenerHandler = new StreamingResponseHandler() { + StreamingResponseHandler listenerHandler = new StreamingResponseHandler<>() { @Override public void onNext(String token) { handler.onNext(token); diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTokenUsage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTokenUsage.java new file mode 100644 index 00000000000..e915db216b2 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/AnthropicTokenUsage.java @@ -0,0 +1,45 @@ +package 
dev.langchain4j.model.anthropic; + +import dev.langchain4j.model.output.TokenUsage; + +public class AnthropicTokenUsage extends TokenUsage { + + private final Integer cacheCreationInputTokens; + private final Integer cacheReadInputTokens; + + /** + * Creates a new {@link AnthropicTokenUsage} instance with the given input, output token counts + * and cache creation/read input tokens. + * + * @param inputTokenCount The input token count, or null if unknown. + * @param outputTokenCount The output token count, or null if unknown. + * @param cacheCreationInputTokens The total number of cache-creation input tokens, or null if unknown. + * @param cacheReadInputTokens The total number of cache-read input tokens, or null if unknown. + */ + public AnthropicTokenUsage(Integer inputTokenCount, + Integer outputTokenCount, + Integer cacheCreationInputTokens, + Integer cacheReadInputTokens) { + super(inputTokenCount, outputTokenCount); + this.cacheCreationInputTokens = cacheCreationInputTokens; + this.cacheReadInputTokens = cacheReadInputTokens; + } + + /** + * Returns the total number of cache-creation input tokens, or null if unknown. + * + * @return The total number of cache-creation input tokens, or null if unknown. + */ + public Integer cacheCreationInputTokens() { + return cacheCreationInputTokens; + } + + /** + * Returns the total number of cache-read input tokens, or null if unknown. + * + * @return The total number of cache-read input tokens, or null if unknown. 
+ */ + public Integer cacheReadInputTokens() { + return cacheReadInputTokens; + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicApi.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicApi.java index 407fc50ed84..aed6a45eece 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicApi.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicApi.java @@ -20,5 +20,6 @@ Call createMessage(@Header(X_API_KEY) String api @Headers({"content-type: application/json"}) Call streamMessage(@Header(X_API_KEY) String apiKey, @Header("anthropic-version") String version, + @Header("anthropic-beta") String beta, @Body AnthropicCreateMessageRequest request); } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCacheControl.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCacheControl.java new file mode 100644 index 00000000000..17f064c1277 --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCacheControl.java @@ -0,0 +1,24 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import lombok.EqualsAndHashCode; +import lombok.Getter; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; + +@EqualsAndHashCode +@JsonInclude(NON_NULL) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +@Getter +public class AnthropicCacheControl { + + private final String type; + + public AnthropicCacheControl(String type) { + 
this.type = type; + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCacheType.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCacheType.java new file mode 100644 index 00000000000..3711b3d0cbf --- /dev/null +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCacheType.java @@ -0,0 +1,19 @@ +package dev.langchain4j.model.anthropic.internal.api; + +import java.util.function.Supplier; + +public enum AnthropicCacheType { + + NO_CACHE(() -> new AnthropicCacheControl("no_cache")), + EPHEMERAL(() -> new AnthropicCacheControl("ephemeral")); + + private final Supplier value; + + AnthropicCacheType(Supplier value) { + this.value = value; + } + + public AnthropicCacheControl cacheControl() { + return this.value.get(); + } +} diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageRequest.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageRequest.java index cc76d6ae250..31d5e3c9eb1 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageRequest.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicCreateMessageRequest.java @@ -24,7 +24,7 @@ public class AnthropicCreateMessageRequest { public String model; public List messages; - public String system; + public List system; public int maxTokens; public List stopSequences; public boolean stream; diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessageContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessageContent.java index 6f1071d5e05..a9bd9bda237 100644 --- 
a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessageContent.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicMessageContent.java @@ -15,8 +15,14 @@ public abstract class AnthropicMessageContent { public String type; + public AnthropicCacheControl cacheControl; public AnthropicMessageContent(String type) { this.type = type; } + + public AnthropicMessageContent(String type, AnthropicCacheControl cacheControl) { + this.type = type; + this.cacheControl = cacheControl; + } } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTextContent.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTextContent.java index d863ad14fd3..f5e833eeb42 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTextContent.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTextContent.java @@ -22,4 +22,9 @@ public AnthropicTextContent(String text) { super("text"); this.text = text; } + + public AnthropicTextContent(String text, AnthropicCacheControl cacheControl) { + super("text", cacheControl); + this.text = text; + } } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTool.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTool.java index f56ba7bc135..31ae6a7d74d 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTool.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicTool.java @@ -21,4 +21,5 @@ public class AnthropicTool { public String name; public String description; public AnthropicToolSchema inputSchema; + public AnthropicCacheControl cacheControl; } diff --git 
a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicUsage.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicUsage.java index a0d18893680..1bbec761c58 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicUsage.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/api/AnthropicUsage.java @@ -14,4 +14,6 @@ public class AnthropicUsage { public Integer inputTokens; public Integer outputTokens; + public Integer cacheCreationInputTokens; + public Integer cacheReadInputTokens; } \ No newline at end of file diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClient.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClient.java index 8d3a2433340..412b3460889 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClient.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/AnthropicClient.java @@ -61,9 +61,6 @@ public B version(String version) { } public B beta(String beta) { - if (beta == null) { - throw new IllegalArgumentException("beta cannot be null or empty"); - } this.beta = beta; return (B) this; } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/DefaultAnthropicClient.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/DefaultAnthropicClient.java index ad5e10f98ca..19e279e1a90 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/DefaultAnthropicClient.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/client/DefaultAnthropicClient.java @@ -12,8 +12,7 @@ import dev.langchain4j.model.anthropic.internal.api.AnthropicDelta; import 
dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage; import dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData; -import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent; -import dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent; +import dev.langchain4j.model.anthropic.AnthropicTokenUsage; import dev.langchain4j.model.anthropic.internal.api.AnthropicUsage; import dev.langchain4j.model.output.FinishReason; import dev.langchain4j.model.output.Response; @@ -42,7 +41,6 @@ import static com.fasterxml.jackson.databind.SerializationFeature.INDENT_OUTPUT; import static dev.langchain4j.internal.Utils.isNotNullOrEmpty; import static dev.langchain4j.internal.Utils.isNullOrBlank; -import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT; import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TOOL_USE; @@ -115,7 +113,7 @@ public DefaultAnthropicClient build() { public AnthropicCreateMessageResponse createMessage(AnthropicCreateMessageRequest request) { try { retrofit2.Response retrofitResponse - = anthropicApi.createMessage(apiKey, version, toBeta(request), request).execute(); + = anthropicApi.createMessage(apiKey, version, beta, request).execute(); if (retrofitResponse.isSuccessful()) { return retrofitResponse.body(); } else { @@ -131,17 +129,6 @@ public AnthropicCreateMessageResponse createMessage(AnthropicCreateMessageReques } } - private String toBeta(AnthropicCreateMessageRequest request) { - return hasTools(request) ? 
beta : null; - } - - private static boolean hasTools(AnthropicCreateMessageRequest request) { - return !isNullOrEmpty(request.tools) || request.messages.stream() - .flatMap(message -> message.content.stream()) - .anyMatch(content -> - (content instanceof AnthropicToolUseContent) || (content instanceof AnthropicToolResultContent)); - } - @Override public void createMessage(AnthropicCreateMessageRequest request, StreamingResponseHandler handler) { @@ -157,6 +144,9 @@ public void createMessage(AnthropicCreateMessageRequest request, StreamingRespon final AtomicInteger inputTokenCount = new AtomicInteger(); final AtomicInteger outputTokenCount = new AtomicInteger(); + final AtomicInteger cacheCreationInputTokens = new AtomicInteger(); + final AtomicInteger cacheReadInputTokens = new AtomicInteger(); + final AtomicReference responseId = new AtomicReference<>(); final AtomicReference responseModel = new AtomicReference<>(); @@ -238,6 +228,12 @@ private void handleUsage(AnthropicUsage usage) { if (usage.outputTokens != null) { this.outputTokenCount.addAndGet(usage.outputTokens); } + if (usage.cacheCreationInputTokens != null) { + this.cacheCreationInputTokens.addAndGet(usage.cacheCreationInputTokens); + } + if (usage.cacheReadInputTokens != null) { + this.cacheReadInputTokens.addAndGet(usage.cacheReadInputTokens); + } } private void handleContentBlockStart(AnthropicStreamingData data) { @@ -309,7 +305,7 @@ private void handleMessageStop() { private Response build() { String text = String.join("\n", contents); - TokenUsage tokenUsage = new TokenUsage(inputTokenCount.get(), outputTokenCount.get()); + TokenUsage tokenUsage = new AnthropicTokenUsage(inputTokenCount.get(), outputTokenCount.get(), cacheCreationInputTokens.get(), cacheReadInputTokens.get()); FinishReason finishReason = toFinishReason(stopReason); Map metadata = createMetadata(); @@ -385,7 +381,7 @@ public void onClosed(EventSource eventSource) { } }; - Call call = anthropicApi.streamMessage(apiKey, version, 
request); + Call call = anthropicApi.streamMessage(apiKey, version, beta, request); EventSources.createFactory(okHttpClient).newEventSource(call.request(), eventSourceListener); } } diff --git a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/mapper/AnthropicMapper.java b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/mapper/AnthropicMapper.java index 435526759a0..72fe3cc181b 100644 --- a/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/mapper/AnthropicMapper.java +++ b/langchain4j-anthropic/src/main/java/dev/langchain4j/model/anthropic/internal/mapper/AnthropicMapper.java @@ -13,11 +13,13 @@ import dev.langchain4j.data.message.TextContent; import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType; import dev.langchain4j.model.anthropic.internal.api.AnthropicContent; import dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent; import dev.langchain4j.model.anthropic.internal.api.AnthropicMessage; import dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent; import dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent; +import dev.langchain4j.model.anthropic.AnthropicTokenUsage; import dev.langchain4j.model.anthropic.internal.api.AnthropicTool; import dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent; import dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema; @@ -33,7 +35,6 @@ import static dev.langchain4j.internal.Exceptions.illegalArgument; import static dev.langchain4j.internal.Utils.isNotNullOrBlank; -import static dev.langchain4j.internal.Utils.isNullOrBlank; import static dev.langchain4j.internal.Utils.isNullOrEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; import static dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType.TEXT; 
@@ -94,7 +95,8 @@ private static List toAnthropicMessageContents(UserMess return message.contents().stream() .map(content -> { if (content instanceof TextContent) { - return new AnthropicTextContent(((TextContent) content).text()); + TextContent textContent = (TextContent) content; + return new AnthropicTextContent(textContent.text()); } else if (content instanceof ImageContent) { Image image = ((ImageContent) content).image(); if (image.url() != null) { @@ -138,17 +140,18 @@ private static List toAnthropicMessageContents(AiMessag return contents; } - public static String toAnthropicSystemPrompt(List messages) { - String systemPrompt = messages.stream() - .filter(message -> message instanceof SystemMessage) - .map(message -> ((SystemMessage) message).text()) - .collect(joining("\n\n")); - if (isNullOrBlank(systemPrompt)) { - return null; - } else { - return systemPrompt; - } + public static List toAnthropicSystemPrompt(List messages, AnthropicCacheType cacheType) { + return messages.stream() + .filter(message -> message instanceof SystemMessage) + .map(message -> { + SystemMessage systemMessage = (SystemMessage) message; + if (cacheType != AnthropicCacheType.NO_CACHE) { + return new AnthropicTextContent(systemMessage.text(), cacheType.cacheControl()); + } + return new AnthropicTextContent(systemMessage.text()); + }) + .collect(toList()); } public static AiMessage toAiMessage(List contents) { @@ -186,7 +189,7 @@ public static TokenUsage toTokenUsage(AnthropicUsage anthropicUsage) { if (anthropicUsage == null) { return null; } - return new TokenUsage(anthropicUsage.inputTokens, anthropicUsage.outputTokens); + return new AnthropicTokenUsage(anthropicUsage.inputTokens, anthropicUsage.outputTokens, anthropicUsage.cacheCreationInputTokens, anthropicUsage.cacheReadInputTokens); } public static FinishReason toFinishReason(String anthropicStopReason) { @@ -207,36 +210,42 @@ public static FinishReason toFinishReason(String anthropicStopReason) { } } - public static List 
toAnthropicTools(List toolSpecifications) { + public static List toAnthropicTools(List toolSpecifications, AnthropicCacheType cacheToolsPrompt) { if (toolSpecifications == null) { return null; } return toolSpecifications.stream() - .map(AnthropicMapper::toAnthropicTool) - .collect(toList()); + .map(toolSpecification -> toAnthropicTool(toolSpecification, cacheToolsPrompt)) + .collect(toList()); } - public static AnthropicTool toAnthropicTool(ToolSpecification toolSpecification) { + public static AnthropicTool toAnthropicTool(ToolSpecification toolSpecification, AnthropicCacheType cacheToolsPrompt) { + AnthropicTool. AnthropicToolBuilder toolBuilder; if (toolSpecification.parameters() != null) { JsonObjectSchema parameters = toolSpecification.parameters(); - return AnthropicTool.builder() - .name(toolSpecification.name()) - .description(toolSpecification.description()) - .inputSchema(AnthropicToolSchema.builder() - .properties(parameters != null ? toMap(parameters.properties()) : emptyMap()) - .required(parameters != null ? parameters.required() : emptyList()) - .build()) - .build(); + toolBuilder = AnthropicTool.builder() + .name(toolSpecification.name()) + .description(toolSpecification.description()) + .inputSchema(AnthropicToolSchema.builder() + .properties(parameters != null ? toMap(parameters.properties()) : emptyMap()) + .required(parameters != null ? parameters.required() : emptyList()) + .build()); + } else { ToolParameters parameters = toolSpecification.toolParameters(); - return AnthropicTool.builder() - .name(toolSpecification.name()) - .description(toolSpecification.description()) - .inputSchema(AnthropicToolSchema.builder() - .properties(parameters != null ? parameters.properties() : emptyMap()) - .required(parameters != null ? 
parameters.required() : emptyList()) - .build()) - .build(); + toolBuilder = AnthropicTool.builder() + .name(toolSpecification.name()) + .description(toolSpecification.description()) + .inputSchema(AnthropicToolSchema.builder() + .properties(parameters != null ? parameters.properties() : emptyMap()) + .required(parameters != null ? parameters.required() : emptyList()) + .build()); + + } + + if (cacheToolsPrompt != AnthropicCacheType.NO_CACHE) { + return toolBuilder.cacheControl(cacheToolsPrompt.cacheControl()).build(); } + return toolBuilder.build(); } } diff --git a/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json b/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json index ff3f2cdb180..ca82db57221 100644 --- a/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json +++ b/langchain4j-anthropic/src/main/resources/META-INF/native-image/dev.langchain4j/langchain4j-anthropic/reflect-config.json @@ -1,182 +1,200 @@ [ - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicContent", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": 
"dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicDelta", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicImageContentSource", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicMessage", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicRole", - "allDeclaredConstructors": true, - 
"allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicTool", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoiceType", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": 
true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - }, - { - "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicUsage", - "allDeclaredConstructors": true, - "allPublicConstructors": true, - "allDeclaredMethods": true, - "allPublicMethods": true, - "allDeclaredFields": true, - "allPublicFields": true - } -] \ No newline at end of file + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicCacheControl", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicCacheType", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicContentBlockType", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageRequest", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": 
"dev.langchain4j.model.anthropic.internal.api.AnthropicCreateMessageResponse", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicDelta", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicImageContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicImageContentSource", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicMessage", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicMessageContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicResponseMessage", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicRole", + "allDeclaredConstructors": true, + 
"allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicStreamingData", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicTextContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicTool", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoice", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolChoiceType", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolResultContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolSchema", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": 
true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicToolUseContent", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + }, + { + "name": "dev.langchain4j.model.anthropic.internal.api.AnthropicUsage", + "allDeclaredConstructors": true, + "allPublicConstructors": true, + "allDeclaredMethods": true, + "allPublicMethods": true, + "allDeclaredFields": true, + "allPublicFields": true + } +] diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java index ae97e6d77df..b70e63a9446 100644 --- a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicChatModelIT.java @@ -10,6 +10,7 @@ import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.junit.jupiter.api.Test; @@ -21,6 +22,7 @@ import java.time.Duration; import java.util.Base64; import java.util.List; +import java.util.Random; import java.util.stream.Stream; import static dev.langchain4j.agent.tool.JsonSchemaProperty.INTEGER; @@ -29,6 +31,8 @@ import static dev.langchain4j.data.message.ToolExecutionResultMessage.from; import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.internal.Utils.readBytes; +import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_HAIKU_20241022; +import static 
dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229; import static dev.langchain4j.model.output.FinishReason.LENGTH; @@ -206,6 +210,106 @@ void should_respect_stop_sequences() { assertThat(response.finishReason()).isEqualTo(OTHER); } + @Test + void should_cache_system_message() { + + // given + ChatLanguageModel model = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .beta("prompt-caching-2024-07-31") + .modelName(CLAUDE_3_HAIKU_20240307) + .cacheSystemMessages(true) + .logRequests(true) + .logResponses(true) + .build(); + + SystemMessage systemMessage = SystemMessage.from("What types of messages are supported in LangChain?".repeat(172) + randomString(2)); + UserMessage userMessage = new UserMessage(TextContent.from("What types of messages are supported in LangChain?")); + + // when + Response response = model.generate(systemMessage, userMessage); + + // then + AnthropicTokenUsage createCacheTokenUsage = (AnthropicTokenUsage) response.tokenUsage(); + assertThat(createCacheTokenUsage.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createCacheTokenUsage.cacheReadInputTokens()).isEqualTo(0); + + // when + Response response2 = model.generate(systemMessage, userMessage); + + // then + AnthropicTokenUsage readCacheTokenUsage = (AnthropicTokenUsage) response2.tokenUsage(); + assertThat(readCacheTokenUsage.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheTokenUsage.cacheReadInputTokens()).isGreaterThan(0); + } + + @Test + void should_cache_multiple_system_messages() { + + // given + ChatLanguageModel model = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .beta("prompt-caching-2024-07-31") + .modelName(CLAUDE_3_HAIKU_20240307) + .cacheSystemMessages(true) + .logRequests(true) 
+ .logResponses(true) + .build(); + + SystemMessage systemMessage = SystemMessage.from("What types of messages are supported in LangChain?".repeat(87) + randomString(2)); + SystemMessage systemMessage2 = SystemMessage.from("What types of messages are supported in LangChain?".repeat(87) + randomString(2)); + UserMessage userMessage = new UserMessage(TextContent.from("What types of messages are supported in LangChain?")); + + // when + Response response = model.generate(systemMessage, systemMessage2, userMessage); + + // then + AnthropicTokenUsage createCacheTokenUsage = (AnthropicTokenUsage) response.tokenUsage(); + assertThat(createCacheTokenUsage.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createCacheTokenUsage.cacheReadInputTokens()).isEqualTo(0); + + // when + Response response2 = model.generate(systemMessage, systemMessage2, userMessage); + + // then + AnthropicTokenUsage readCacheTokenUsage = (AnthropicTokenUsage) response2.tokenUsage(); + assertThat(readCacheTokenUsage.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheTokenUsage.cacheReadInputTokens()).isGreaterThan(0); + } + + @Test + void should_fail_if_more_than_four_system_message_with_cache() { + + // given + ChatLanguageModel model = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .beta("prompt-caching-2024-07-31") + .modelName(CLAUDE_3_HAIKU_20240307) + .cacheSystemMessages(true) + .logRequests(true) + .logResponses(true) + .build(); + + SystemMessage systemMessageOne = SystemMessage.from("banana"); + SystemMessage systemMessageTwo = SystemMessage.from("banana"); + SystemMessage systemMessageThree = SystemMessage.from("banana"); + SystemMessage systemMessageFour = SystemMessage.from("banana"); + SystemMessage systemMessageFive = SystemMessage.from("banana"); + + // then + assertThatThrownBy(() -> model.generate( + systemMessageOne, + systemMessageTwo, + systemMessageThree, + systemMessageFour, + systemMessageFive + )) + 
.isExactlyInstanceOf(RuntimeException.class) + .hasMessage("dev.langchain4j.model.anthropic.internal.client.AnthropicHttpException: " + + "{\"type\":\"error\",\"error\":{\"type\":\"invalid_request_error\",\"message\":\"messages: at least one message is required\"}}"); + + } + @Test void test_all_parameters() { @@ -348,6 +452,91 @@ void should_execute_a_tool_then_answer(AnthropicChatModelName modelName) { assertThat(secondResponse.finishReason()).isEqualTo(STOP); } + @Test + void should_cache_system_message_and_tools() { + + // given + AnthropicChatModel model = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .beta("prompt-caching-2024-07-31") + .modelName(CLAUDE_3_5_HAIKU_20241022) + .cacheSystemMessages(true) + .cacheTools(true) + .logRequests(true) + .logResponses(true) + .build(); + + SystemMessage systemMessage = SystemMessage.from("returns a sum of two numbers".repeat(210) + randomString(2)); + + UserMessage userMessage = userMessage("How much is 2+2 and 3+3? 
Call tools in parallel!"); + + ToolSpecification toolSpecification = ToolSpecification.builder() + .name("calculator") + .description(randomString(2)) + .parameters(JsonObjectSchema.builder() + .addIntegerProperty("first") + .addIntegerProperty("second") + .build()) + .build(); + + // when + Response response = model.generate(List.of(systemMessage, userMessage), List.of(toolSpecification)); + + // then + AnthropicTokenUsage createCacheTokenUsage = (AnthropicTokenUsage) response.tokenUsage(); + assertThat(createCacheTokenUsage.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createCacheTokenUsage.cacheReadInputTokens()).isEqualTo(0); + + // when + Response response2 = model.generate(List.of(systemMessage, userMessage), List.of(toolSpecification)); + + // then + AnthropicTokenUsage readCacheTokenUsage = (AnthropicTokenUsage) response2.tokenUsage(); + assertThat(readCacheTokenUsage.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheTokenUsage.cacheReadInputTokens()).isGreaterThan(0); + } + + @Test + void should_cache_tools() { + + // given + AnthropicChatModel model = AnthropicChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .beta("prompt-caching-2024-07-31") + .modelName(CLAUDE_3_5_HAIKU_20241022) + .cacheTools(true) + .logRequests(true) + .logResponses(true) + .build(); + + UserMessage userMessage = userMessage("How much is 2+2 and 3+3? 
Call tools in parallel!"); + + ToolSpecification toolSpecification = ToolSpecification.builder() + .name("calculator") + .description("returns a sum of two numbers".repeat(214) + randomString(2)) + .parameters(JsonObjectSchema.builder() + .addIntegerProperty("first") + .addIntegerProperty("second") + .build()) + .build(); + + // when + Response response = model.generate(singletonList(userMessage), List.of(toolSpecification)); + + // then + AnthropicTokenUsage createCacheTokenUsage = (AnthropicTokenUsage) response.tokenUsage(); + assertThat(createCacheTokenUsage.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(createCacheTokenUsage.cacheReadInputTokens()).isEqualTo(0); + + // when + Response response2 = model.generate(singletonList(userMessage), List.of(toolSpecification)); + + // then + AnthropicTokenUsage readCacheTokenUsage = (AnthropicTokenUsage) response2.tokenUsage(); + assertThat(readCacheTokenUsage.cacheCreationInputTokens()).isEqualTo(0); + assertThat(readCacheTokenUsage.cacheReadInputTokens()).isGreaterThan(0); + } + @Test void should_execute_multiple_tools_in_parallel_then_answer() { @@ -469,8 +658,19 @@ void should_execute_a_tool_with_nested_properties_then_answer() { } static Stream models_supporting_tools() { + // claude 2 does not support tools return stream(AnthropicChatModelName.values()) - .filter(modelName -> modelName.toString().startsWith("claude-3")) + .filter(modelName -> !modelName.toString().startsWith("claude-2")) .map(Arguments::of); } -} \ No newline at end of file + + static String randomString(int length) { + String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + Random random = new Random(); + StringBuilder result = new StringBuilder(length); + for (int i = 0; i < length; i++) { + result.append(characters.charAt(random.nextInt(characters.length()))); + } + return result.toString(); + } +} diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java 
b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java index 4ee9f134390..f6fc502e367 100644 --- a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicMapperTest.java @@ -28,7 +28,6 @@ class AnthropicMapperTest { @ParameterizedTest @MethodSource void test_toAnthropicMessages(List messages, List expectedAnthropicMessages) { - // when List anthropicMessages = toAnthropicMessages(messages); @@ -212,7 +211,7 @@ static Stream test_toAnthropicMessages() { void test_toAnthropicTool(ToolSpecification toolSpecification, AnthropicTool expectedAnthropicTool) { // when - AnthropicTool anthropicTool = toAnthropicTool(toolSpecification); + AnthropicTool anthropicTool = toAnthropicTool(toolSpecification, AnthropicCacheType.NO_CACHE); // then assertThat(anthropicTool).isEqualTo(expectedAnthropicTool); @@ -258,4 +257,4 @@ private static Map mapOf(Map.Entry... 
entries) { private static Map.Entry entry(K key, V value) { return new AbstractMap.SimpleEntry<>(key, value); } -} \ No newline at end of file +} diff --git a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModelIT.java b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModelIT.java index 816573faa1c..558789d7f4b 100644 --- a/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModelIT.java +++ b/langchain4j-anthropic/src/test/java/dev/langchain4j/model/anthropic/AnthropicStreamingChatModelIT.java @@ -6,10 +6,12 @@ import dev.langchain4j.data.message.ChatMessage; import dev.langchain4j.data.message.ImageContent; import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.TextContent; import dev.langchain4j.data.message.ToolExecutionResultMessage; import dev.langchain4j.data.message.UserMessage; import dev.langchain4j.model.chat.StreamingChatLanguageModel; import dev.langchain4j.model.chat.TestStreamingResponseHandler; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.TokenUsage; import org.jetbrains.annotations.NotNull; @@ -29,7 +31,9 @@ import static dev.langchain4j.data.message.UserMessage.userMessage; import static dev.langchain4j.internal.Utils.readBytes; import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.CAT_IMAGE_URL; +import static dev.langchain4j.model.anthropic.AnthropicChatModelIT.randomString; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_5_SONNET_20240620; +import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_HAIKU_20240307; import static dev.langchain4j.model.anthropic.AnthropicChatModelName.CLAUDE_3_SONNET_20240229; import static dev.langchain4j.model.output.FinishReason.STOP; import static 
dev.langchain4j.model.output.FinishReason.TOOL_EXECUTION; @@ -165,6 +169,66 @@ void test_all_parameters() { assertThat(response.content().text()).isNotBlank(); } + @Test + void should_cache_system_message() { + + // given + AnthropicStreamingChatModel model = AnthropicStreamingChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .beta("prompt-caching-2024-07-31") + .modelName(CLAUDE_3_HAIKU_20240307) + .cacheSystemMessages(true) + .logRequests(true) + .logResponses(true) + .build(); + + SystemMessage systemMessage = SystemMessage.from("What types of messages are supported in LangChain?".repeat(172) + randomString(2)); + UserMessage userMessage = new UserMessage(TextContent.from("What types of messages are supported in LangChain?")); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(asList(userMessage, systemMessage), handler); + AnthropicTokenUsage responseAnthropicTokenUsage = (AnthropicTokenUsage) handler.get().tokenUsage(); + + // then + assertThat(responseAnthropicTokenUsage.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(responseAnthropicTokenUsage.cacheReadInputTokens()).isEqualTo(0); + } + + @Test + void should_cache_tools() { + + // given + AnthropicStreamingChatModel model = AnthropicStreamingChatModel.builder() + .apiKey(System.getenv("ANTHROPIC_API_KEY")) + .beta("prompt-caching-2024-07-31") + .modelName(CLAUDE_3_HAIKU_20240307) + .cacheTools(true) + .logRequests(true) + .logResponses(true) + .build(); + + UserMessage userMessage = userMessage("How much is 2+2 and 3+3? 
Call tools in parallel!"); + + ToolSpecification toolSpecification = ToolSpecification.builder() + .name("calculator") + .description("returns a sum of two numbers".repeat(214) + randomString(2)) + .parameters(JsonObjectSchema.builder() + .addIntegerProperty("first") + .addIntegerProperty("second") + .build()) + .build(); + + // when + TestStreamingResponseHandler handler = new TestStreamingResponseHandler<>(); + model.generate(singletonList(userMessage), List.of(toolSpecification), handler); + AnthropicTokenUsage responseAnthropicTokenUsage = (AnthropicTokenUsage) handler.get().tokenUsage(); + + // then + assertThat(responseAnthropicTokenUsage.cacheCreationInputTokens()).isGreaterThan(0); + assertThat(responseAnthropicTokenUsage.cacheReadInputTokens()).isEqualTo(0); + } + @Test void should_fail_to_create_without_api_key() { @@ -375,4 +439,4 @@ private static void assertTokenUsage(@NotNull TokenUsage tokenUsage) { assertThat(tokenUsage.totalTokenCount()) .isEqualTo(tokenUsage.inputTokenCount() + tokenUsage.outputTokenCount()); } -} \ No newline at end of file +} diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/SystemMessage.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/SystemMessage.java index 61f59b08d63..1627be181ed 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/SystemMessage.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/data/message/SystemMessage.java @@ -72,4 +72,4 @@ public static SystemMessage from(String text) { public static SystemMessage systemMessage(String text) { return from(text); } -} +} \ No newline at end of file diff --git a/langchain4j-core/src/main/java/dev/langchain4j/data/message/TextContent.java b/langchain4j-core/src/main/java/dev/langchain4j/data/message/TextContent.java index c617ab7f305..8891d5f8170 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/data/message/TextContent.java +++ 
b/langchain4j-core/src/main/java/dev/langchain4j/data/message/TextContent.java @@ -62,4 +62,4 @@ public String toString() { public static TextContent from(String text) { return new TextContent(text); } -} +} \ No newline at end of file diff --git a/langchain4j-core/src/main/java/dev/langchain4j/model/output/TokenUsage.java b/langchain4j-core/src/main/java/dev/langchain4j/model/output/TokenUsage.java index c1c81a19da7..8a9bed6fa1a 100644 --- a/langchain4j-core/src/main/java/dev/langchain4j/model/output/TokenUsage.java +++ b/langchain4j-core/src/main/java/dev/langchain4j/model/output/TokenUsage.java @@ -13,6 +13,7 @@ public class TokenUsage { private final Integer outputTokenCount; private final Integer totalTokenCount; + /** * Creates a new {@link TokenUsage} instance with all fields set to null. */