From 1be142aaf6e27144c6b2b5d819f88f77856512e1 Mon Sep 17 00:00:00 2001
From: Brace Sproul
Date: Thu, 11 Jul 2024 10:19:34 -0700
Subject: [PATCH] core[minor]: Fix caching of complex message types (#6028)

* core[minor]: Fix caching of complex message types

* chore: lint files

* cache -> caches

* fix test

* cleanup test

* skip in mistral/groq

* cr
---
 langchain-core/langchain.config.js            |  2 +-
 .../src/{caches.ts => caches/base.ts}         | 16 ++---
 .../src/caches/tests/in_memory_cache.test.ts  | 40 +++++++++++
 langchain-core/src/language_models/base.ts    |  4 +-
 .../src/language_models/chat_models.ts        |  2 +-
 langchain-core/src/language_models/llms.ts    |  2 +-
 .../language_models/tests/chat_models.test.ts | 39 +++++++++++
 langchain-core/src/load/import_map.ts         |  2 +-
 .../src/messages/tests/message_utils.test.ts  | 68 +++++++++++++++++++
 langchain-core/src/messages/utils.ts          |  6 +-
 langchain-core/src/tests/caches.test.ts       |  2 +-
 langchain-core/src/utils/testing/index.ts     |  9 ++-
 .../tests/chat_models.standard.int.test.ts    | 10 ++-
 .../tests/chat_models.standard.int.test.ts    |  8 +++
 .../src/integration_tests/chat_models.ts      | 52 ++++++++++++++
 15 files changed, 244 insertions(+), 18 deletions(-)
 rename langchain-core/src/{caches.ts => caches/base.ts} (83%)
 create mode 100644 langchain-core/src/caches/tests/in_memory_cache.test.ts

diff --git a/langchain-core/langchain.config.js b/langchain-core/langchain.config.js
index 0b4f41da98ba..6c1b4e569be2 100644
--- a/langchain-core/langchain.config.js
+++ b/langchain-core/langchain.config.js
@@ -13,7 +13,7 @@ export const config = {
   internals: [/node\:/, /js-tiktoken/, /langsmith/],
   entrypoints: {
     agents: "agents",
-    caches: "caches",
+    caches: "caches/base",
     "callbacks/base": "callbacks/base",
     "callbacks/manager": "callbacks/manager",
     "callbacks/promises": "callbacks/promises",
diff --git a/langchain-core/src/caches.ts b/langchain-core/src/caches/base.ts
similarity index 83%
rename from langchain-core/src/caches.ts
rename to langchain-core/src/caches/base.ts
index ac2c88905ff6..388e5f298481 100644
--- a/langchain-core/src/caches.ts
+++ b/langchain-core/src/caches/base.ts
@@ -1,17 +1,17 @@
-import { insecureHash } from "./utils/hash.js";
-import type { Generation, ChatGeneration } from "./outputs.js";
-import { mapStoredMessageToChatMessage } from "./messages/utils.js";
-import { type StoredGeneration } from "./messages/base.js";
+import { insecureHash } from "../utils/hash.js";
+import type { Generation, ChatGeneration } from "../outputs.js";
+import { mapStoredMessageToChatMessage } from "../messages/utils.js";
+import { type StoredGeneration } from "../messages/base.js";
 
 /**
- * This cache key should be consistent across all versions of langchain.
- * It is currently NOT consistent across versions of langchain.
+ * This cache key should be consistent across all versions of LangChain.
+ * It is currently NOT consistent across versions of LangChain.
  *
  * A huge benefit of having a remote cache (like redis) is that you can
  * access the cache from different processes/machines. The allows you to
- * seperate concerns and scale horizontally.
+ * separate concerns and scale horizontally.
  *
- * TODO: Make cache key consistent across versions of langchain.
+ * TODO: Make cache key consistent across versions of LangChain.
  */
 export const getCacheKey = (...strings: string[]): string =>
   insecureHash(strings.join("_"));
diff --git a/langchain-core/src/caches/tests/in_memory_cache.test.ts b/langchain-core/src/caches/tests/in_memory_cache.test.ts
new file mode 100644
index 000000000000..4f2e837e717d
--- /dev/null
+++ b/langchain-core/src/caches/tests/in_memory_cache.test.ts
@@ -0,0 +1,40 @@
+import { MessageContentComplex } from "../../messages/base.js";
+import { InMemoryCache } from "../base.js";
+
+test("InMemoryCache works", async () => {
+  const cache = new InMemoryCache();
+
+  await cache.update("prompt", "key1", [
+    {
+      text: "text1",
+    },
+  ]);
+
+  const result = await cache.lookup("prompt", "key1");
+  expect(result).toBeDefined();
+  if (!result) {
+    return;
+  }
+  expect(result[0].text).toBe("text1");
+});
+
+test("InMemoryCache works with complex message types", async () => {
+  const cache = new InMemoryCache<MessageContentComplex[]>();
+
+  await cache.update("prompt", "key1", [
+    {
+      type: "text",
+      text: "text1",
+    },
+  ]);
+
+  const result = await cache.lookup("prompt", "key1");
+  expect(result).toBeDefined();
+  if (!result) {
+    return;
+  }
+  expect(result[0]).toEqual({
+    type: "text",
+    text: "text1",
+  });
+});
diff --git a/langchain-core/src/language_models/base.ts b/langchain-core/src/language_models/base.ts
index 56dd3b9aa80d..8adea3e83b4e 100644
--- a/langchain-core/src/language_models/base.ts
+++ b/langchain-core/src/language_models/base.ts
@@ -1,7 +1,7 @@
 import type { Tiktoken, TiktokenModel } from "js-tiktoken/lite";
 import { z } from "zod";
 
-import { type BaseCache, InMemoryCache } from "../caches.js";
+import { type BaseCache, InMemoryCache } from "../caches/base.js";
 import {
   type BasePromptValueInterface,
   StringPromptValue,
@@ -481,7 +481,7 @@ export abstract class BaseLanguageModel<
    * @param callOptions Call options for the model
    * @returns A unique cache key.
    */
-  protected _getSerializedCacheKeyParametersForCall(
+  _getSerializedCacheKeyParametersForCall( // TODO: Fix when we remove the RunnableLambda backwards compatibility shim.
     { config, ...callOptions }: CallOptions & { config?: RunnableConfig }
   ): string {
diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts
index f3b900427ba0..0bb1a0f77202 100644
--- a/langchain-core/src/language_models/chat_models.ts
+++ b/langchain-core/src/language_models/chat_models.ts
@@ -32,7 +32,7 @@ import {
   type Callbacks,
 } from "../callbacks/manager.js";
 import type { RunnableConfig } from "../runnables/config.js";
-import type { BaseCache } from "../caches.js";
+import type { BaseCache } from "../caches/base.js";
 import { StructuredToolInterface } from "../tools.js";
 import {
   Runnable,
diff --git a/langchain-core/src/language_models/llms.ts b/langchain-core/src/language_models/llms.ts
index 5f93d6c5ce49..20b0e812deb7 100644
--- a/langchain-core/src/language_models/llms.ts
+++ b/langchain-core/src/language_models/llms.ts
@@ -23,7 +23,7 @@ import {
   type BaseLanguageModelParams,
 } from "./base.js";
 import type { RunnableConfig } from "../runnables/config.js";
-import type { BaseCache } from "../caches.js";
+import type { BaseCache } from "../caches/base.js";
 import { isStreamEventsHandler } from "../tracers/event_stream.js";
 import { isLogStreamHandler } from "../tracers/log_stream.js";
 import { concat } from "../utils/stream.js";
diff --git a/langchain-core/src/language_models/tests/chat_models.test.ts b/langchain-core/src/language_models/tests/chat_models.test.ts
index 526ce3b49fb9..6ddb97f3d067 100644
--- a/langchain-core/src/language_models/tests/chat_models.test.ts
+++ b/langchain-core/src/language_models/tests/chat_models.test.ts
@@ -4,6 +4,9 @@ import { test } from "@jest/globals";
 import { z } from "zod";
 import { zodToJsonSchema } from "zod-to-json-schema";
 import { FakeChatModel, FakeListChatModel } from "../../utils/testing/index.js";
+import { HumanMessage } from "../../messages/human.js";
+import { getBufferString } from "../../messages/utils.js";
+import { AIMessage } from "../../messages/ai.js";
 
 test("Test ChatModel accepts array shorthand for messages", async () => {
   const model = new FakeChatModel({});
@@ -189,3 +192,39 @@ test("Test ChatModel withStructuredOutput new syntax and includeRaw", async () =
   // No error
   console.log(response.parsed);
 });
+
+test("Test ChatModel can cache complex messages", async () => {
+  const model = new FakeChatModel({
+    cache: true,
+  });
+  if (!model.cache) {
+    throw new Error("Cache not enabled");
+  }
+
+  const contentToCache = [
+    {
+      type: "text",
+      text: "Hello there!",
+    },
+  ];
+  const humanMessage = new HumanMessage({
+    content: contentToCache,
+  });
+
+  const prompt = getBufferString([humanMessage]);
+  const llmKey = model._getSerializedCacheKeyParametersForCall({});
+
+  // Invoke model to trigger cache update
+  await model.invoke([humanMessage]);
+
+  const value = await model.cache.lookup(prompt, llmKey);
+  expect(value).toBeDefined();
+  if (!value) return;
+
+  expect(value[0].text).toEqual(JSON.stringify(contentToCache, null, 2));
+
+  expect("message" in value[0]).toBeTruthy();
+  if (!("message" in value[0])) return;
+  const cachedMsg = value[0].message as AIMessage;
+  expect(cachedMsg.content).toEqual(JSON.stringify(contentToCache, null, 2));
+});
diff --git a/langchain-core/src/load/import_map.ts b/langchain-core/src/load/import_map.ts
index 6eb63fe8f386..294f3feca879 100644
--- a/langchain-core/src/load/import_map.ts
+++ b/langchain-core/src/load/import_map.ts
@@ -1,7 +1,7 @@
 // Auto-generated by `scripts/create-entrypoints.js`. Do not edit manually.
 
 export * as agents from "../agents.js";
-export * as caches from "../caches.js";
+export * as caches from "../caches/base.js";
 export * as callbacks__base from "../callbacks/base.js";
 export * as callbacks__manager from "../callbacks/manager.js";
 export * as callbacks__promises from "../callbacks/promises.js";
diff --git a/langchain-core/src/messages/tests/message_utils.test.ts b/langchain-core/src/messages/tests/message_utils.test.ts
index f17f40e83cb0..2cd179326bde 100644
--- a/langchain-core/src/messages/tests/message_utils.test.ts
+++ b/langchain-core/src/messages/tests/message_utils.test.ts
@@ -8,6 +8,7 @@ import { AIMessage } from "../ai.js";
 import { HumanMessage } from "../human.js";
 import { SystemMessage } from "../system.js";
 import { BaseMessage } from "../base.js";
+import { getBufferString } from "../utils.js";
 
 describe("filterMessage", () => {
   const getMessages = () => [
@@ -431,3 +432,70 @@ describe("trimMessages can trim", () => {
     expect(typeof (trimmedMessages as any).func).toBe("function");
   });
 });
+
+test("getBufferString can handle complex messages", () => {
+  const messageArr1 = [new HumanMessage("Hello there!")];
+  const messageArr2 = [
+    new AIMessage({
+      content: [
+        {
+          type: "text",
+          text: "Hello there!",
+        },
+      ],
+    }),
+  ];
+  const messageArr3 = [
+    new HumanMessage({
+      content: [
+        {
+          type: "image_url",
+          image_url: {
+            url: "https://example.com/image.jpg",
+          },
+        },
+        {
+          type: "image_url",
+          image_url: "https://example.com/image.jpg",
+        },
+      ],
+    }),
+  ];
+
+  const bufferString1 = getBufferString(messageArr1);
+  expect(bufferString1).toBe("Human: Hello there!");
+
+  const bufferString2 = getBufferString(messageArr2);
+  expect(bufferString2).toBe(
+    `AI: ${JSON.stringify(
+      [
+        {
+          type: "text",
+          text: "Hello there!",
+        },
+      ],
+      null,
+      2
+    )}`
+  );
+
+  const bufferString3 = getBufferString(messageArr3);
+  expect(bufferString3).toBe(
+    `Human: ${JSON.stringify(
+      [
+        {
+          type: "image_url",
+          image_url: {
+            url: "https://example.com/image.jpg",
+          },
+        },
+        {
+          type: "image_url",
+          image_url: "https://example.com/image.jpg",
+        },
+      ],
+      null,
+      2
+    )}`
+  );
+});
diff --git a/langchain-core/src/messages/utils.ts b/langchain-core/src/messages/utils.ts
index 287e79be9c0a..ac3c6b6594d2 100644
--- a/langchain-core/src/messages/utils.ts
+++ b/langchain-core/src/messages/utils.ts
@@ -82,7 +82,11 @@ export function getBufferString(
       throw new Error(`Got unsupported message type: ${m._getType()}`);
     }
     const nameStr = m.name ? `${m.name}, ` : "";
-    string_messages.push(`${role}: ${nameStr}${m.content}`);
+    const readableContent =
+      typeof m.content === "string"
+        ? m.content
+        : JSON.stringify(m.content, null, 2);
+    string_messages.push(`${role}: ${nameStr}${readableContent}`);
   }
   return string_messages.join("\n");
 }
diff --git a/langchain-core/src/tests/caches.test.ts b/langchain-core/src/tests/caches.test.ts
index dfda89ab456b..6165d7b3913c 100644
--- a/langchain-core/src/tests/caches.test.ts
+++ b/langchain-core/src/tests/caches.test.ts
@@ -1,6 +1,6 @@
 import { test, expect } from "@jest/globals";
 
-import { InMemoryCache } from "../caches.js";
+import { InMemoryCache } from "../caches/base.js";
 
 test("InMemoryCache", async () => {
   const cache = new InMemoryCache();
diff --git a/langchain-core/src/utils/testing/index.ts b/langchain-core/src/utils/testing/index.ts
index eaa9722a3962..6eaf89177887 100644
--- a/langchain-core/src/utils/testing/index.ts
+++ b/langchain-core/src/utils/testing/index.ts
@@ -184,7 +184,14 @@ export class FakeChatModel extends BaseChatModel {
         ],
       };
     }
-    const text = messages.map((m) => m.content).join("\n");
+    const text = messages
+      .map((m) => {
+        if (typeof m.content === "string") {
+          return m.content;
+        }
+        return JSON.stringify(m.content, null, 2);
+      })
+      .join("\n");
     await runManager?.handleLLMNewToken(text);
     return {
       generations: [
diff --git a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts
index 4eb55e6cd5c7..9e1a2774771f 100644
--- a/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts
+++ b/libs/langchain-groq/src/tests/chat_models.standard.int.test.ts
@@ -44,7 +44,15 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests<
     this.skipTestMessage(
       "testToolMessageHistoriesListContent",
       "ChatGroq",
-      "Not properly implemented."
+      "Complex message types not properly implemented"
     );
   }
+
+  async testCacheComplexMessageTypes() {
+    this.skipTestMessage(
+      "testCacheComplexMessageTypes",
+      "ChatGroq",
+      "Complex message types not properly implemented"
+    );
+  }
 }
diff --git a/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts b/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts
index 248b3892f727..a17ff1085925 100644
--- a/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts
+++ b/libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts
@@ -23,6 +23,14 @@ class ChatMistralAIStandardIntegrationTests extends ChatModelIntegrationTests<
       functionId: "123456789",
     });
   }
+
+  async testCacheComplexMessageTypes() {
+    this.skipTestMessage(
+      "testCacheComplexMessageTypes",
+      "ChatMistralAI",
+      "Complex message types not properly implemented"
+    );
+  }
 }
 
 const testClass = new ChatMistralAIStandardIntegrationTests();
diff --git a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts
index ce8ff4c3a96b..77a76038ec83 100644
--- a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts
+++ b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts
@@ -7,6 +7,7 @@ import {
   HumanMessage,
   ToolMessage,
   UsageMetadata,
+  getBufferString,
 } from "@langchain/core/messages";
 import { z } from "zod";
 import { StructuredTool } from "@langchain/core/tools";
@@ -438,6 +439,50 @@ export abstract class ChatModelIntegrationTests<
     expect(tool_calls[0].name).toBe("math_addition");
   }
 
+  async testCacheComplexMessageTypes() {
+    const model = new this.Cls({
+      ...this.constructorArgs,
+      cache: true,
+    });
+    if (!model.cache) {
+      throw new Error("Cache not enabled");
+    }
+
+    const humanMessage = new HumanMessage({
+      content: [
+        {
+          type: "text",
+          text: "Hello there!",
+        },
+      ],
+    });
+    const prompt = getBufferString([humanMessage]);
+    const llmKey = model._getSerializedCacheKeyParametersForCall({} as any);
+
+    // Invoke the model to trigger a cache update.
+    await model.invoke([humanMessage]);
+    const cacheValue = await model.cache.lookup(prompt, llmKey);
+
+    // Ensure only one generation was added to the cache.
+    expect(cacheValue !== null).toBeTruthy();
+    if (!cacheValue) return;
+    expect(cacheValue).toHaveLength(1);
+
+    expect("message" in cacheValue[0]).toBeTruthy();
+    if (!("message" in cacheValue[0])) return;
+    const cachedMessage = cacheValue[0].message as AIMessage;
+
+    // Invoke the model again with the same prompt, triggering a cache hit.
+    const result = await model.invoke([humanMessage]);
+
+    expect(result.content).toBe(cacheValue[0].text);
+    expect(result).toEqual(cachedMessage);
+
+    // Verify a second generation was not added to the cache.
+    const cacheValue2 = await model.cache.lookup(prompt, llmKey);
+    expect(cacheValue2).toEqual(cacheValue);
+  }
+
   /**
    * Run all unit tests for the chat model.
    * Each test is wrapped in a try/catch block to prevent the entire test suite from failing.
@@ -531,6 +576,13 @@
       console.error("testBindToolsWithOpenAIFormattedTools failed", e);
     }
 
+    try {
+      await this.testCacheComplexMessageTypes();
+    } catch (e: any) {
+      allTestsPassed = false;
+      console.error("testCacheComplexMessageTypes failed", e);
+    }
+
     return allTestsPassed;
   }
 }
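
A minimal usage sketch of the behavior this patch fixes, shown outside the diff. It assumes the public "@langchain/core/messages" and "@langchain/core/caches" entrypoints mapped above and an ESM context where top-level await is available; the llm key string and the cached generation text are placeholder values, not part of the PR.

import { HumanMessage, getBufferString } from "@langchain/core/messages";
import { InMemoryCache } from "@langchain/core/caches";

// A message whose content is an array of complex parts rather than a string.
const message = new HumanMessage({
  content: [{ type: "text", text: "Hello there!" }],
});

// Before this patch, getBufferString interpolated array content directly into
// the template string, so complex messages rendered as "[object Object]" and
// distinct prompts could collide under one cache key. It now JSON-stringifies
// any non-string content.
const prompt = getBufferString([message]);
// prompt === `Human: ${JSON.stringify(message.content, null, 2)}`

// Chat models use this rendered string as the "prompt" half of the cache key,
// paired with the serialized call options ("llm key").
const cache = new InMemoryCache();
await cache.update(prompt, "llm-key", [{ text: "cached generation" }]);
console.log(await cache.lookup(prompt, "llm-key")); // -> [ { text: "cached generation" } ]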