core[minor]: Fix caching of complex message types (#6028)
* core[minor]: Fix caching of complex message types

* chore: lint files

* cache -> caches

* fix test

* cleanup test

* skip in mistral/groq

* cr
bracesproul authored Jul 11, 2024
1 parent 281198d commit 1be142a
Showing 15 changed files with 244 additions and 18 deletions.
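A minimal sketch of the behavior this commit fixes (illustrative only, not part of the diff): when a message's content is an array of content blocks rather than a plain string, JavaScript string interpolation coerces the array to "[object Object]", so buffer strings, and the cache keys derived from them, could collide for different complex prompts. The changes to getBufferString and FakeChatModel below serialize non-string content with JSON.stringify instead.

import { HumanMessage, getBufferString } from "@langchain/core/messages";

// Complex (array) message content, e.g. a multi-part or multimodal prompt.
const message = new HumanMessage({
  content: [{ type: "text", text: "Hello there!" }],
});

// Plain string interpolation coerces the array and loses the content:
console.log(`Human: ${message.content}`); // "Human: [object Object]"

// With this commit, getBufferString serializes non-string content:
console.log(getBufferString([message]));
// Human: [
//   {
//     "type": "text",
//     "text": "Hello there!"
//   }
// ]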
2 changes: 1 addition & 1 deletion langchain-core/langchain.config.js
@@ -13,7 +13,7 @@ export const config = {
internals: [/node\:/, /js-tiktoken/, /langsmith/],
entrypoints: {
agents: "agents",
caches: "caches",
caches: "caches/base",
"callbacks/base": "callbacks/base",
"callbacks/manager": "callbacks/manager",
"callbacks/promises": "callbacks/promises",
langchain-core/src/caches.ts → langchain-core/src/caches/base.ts
@@ -1,17 +1,17 @@
import { insecureHash } from "./utils/hash.js";
import type { Generation, ChatGeneration } from "./outputs.js";
import { mapStoredMessageToChatMessage } from "./messages/utils.js";
import { type StoredGeneration } from "./messages/base.js";
import { insecureHash } from "../utils/hash.js";
import type { Generation, ChatGeneration } from "../outputs.js";
import { mapStoredMessageToChatMessage } from "../messages/utils.js";
import { type StoredGeneration } from "../messages/base.js";

/**
- * This cache key should be consistent across all versions of langchain.
- * It is currently NOT consistent across versions of langchain.
+ * This cache key should be consistent across all versions of LangChain.
+ * It is currently NOT consistent across versions of LangChain.
*
* A huge benefit of having a remote cache (like redis) is that you can
* access the cache from different processes/machines. The allows you to
- * seperate concerns and scale horizontally.
+ * separate concerns and scale horizontally.
*
- * TODO: Make cache key consistent across versions of langchain.
+ * TODO: Make cache key consistent across versions of LangChain.
*/
export const getCacheKey = (...strings: string[]): string =>
insecureHash(strings.join("_"));
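As a rough sketch of how this key gets used (an assumption based on the cache implementations built on top of it, which are not shown in full in this diff): the prompt string produced by getBufferString and the model's serialized call parameters are joined and hashed into a single lookup key.

// Sketch only: getCacheKey and insecureHash are the functions above;
// the prompt and llmKey values here are hypothetical examples.
const prompt = "Human: Hello there!"; // typically produced by getBufferString(messages)
const llmKey = '{"temperature":0}';   // hypothetical serialized call options
const key = getCacheKey(prompt, llmKey);
// equivalent to insecureHash('Human: Hello there!_{"temperature":0}')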
40 changes: 40 additions & 0 deletions langchain-core/src/caches/tests/in_memory_cache.test.ts
@@ -0,0 +1,40 @@
import { MessageContentComplex } from "../../messages/base.js";
import { InMemoryCache } from "../base.js";

test("InMemoryCache works", async () => {
const cache = new InMemoryCache();

await cache.update("prompt", "key1", [
{
text: "text1",
},
]);

const result = await cache.lookup("prompt", "key1");
expect(result).toBeDefined();
if (!result) {
return;
}
expect(result[0].text).toBe("text1");
});

test("InMemoryCache works with complex message types", async () => {
const cache = new InMemoryCache<MessageContentComplex[]>();

await cache.update("prompt", "key1", [
{
type: "text",
text: "text1",
},
]);

const result = await cache.lookup("prompt", "key1");
expect(result).toBeDefined();
if (!result) {
return;
}
expect(result[0]).toEqual({
type: "text",
text: "text1",
});
});
4 changes: 2 additions & 2 deletions langchain-core/src/language_models/base.ts
@@ -1,7 +1,7 @@
import type { Tiktoken, TiktokenModel } from "js-tiktoken/lite";

import { z } from "zod";
- import { type BaseCache, InMemoryCache } from "../caches.js";
+ import { type BaseCache, InMemoryCache } from "../caches/base.js";
import {
type BasePromptValueInterface,
StringPromptValue,
@@ -481,7 +481,7 @@ export abstract class BaseLanguageModel<
* @param callOptions Call options for the model
* @returns A unique cache key.
*/
- protected _getSerializedCacheKeyParametersForCall(
+ _getSerializedCacheKeyParametersForCall(
// TODO: Fix when we remove the RunnableLambda backwards compatibility shim.
{ config, ...callOptions }: CallOptions & { config?: RunnableConfig }
): string {
2 changes: 1 addition & 1 deletion langchain-core/src/language_models/chat_models.ts
@@ -32,7 +32,7 @@ import {
type Callbacks,
} from "../callbacks/manager.js";
import type { RunnableConfig } from "../runnables/config.js";
- import type { BaseCache } from "../caches.js";
+ import type { BaseCache } from "../caches/base.js";
import { StructuredToolInterface } from "../tools.js";
import {
Runnable,
2 changes: 1 addition & 1 deletion langchain-core/src/language_models/llms.ts
@@ -23,7 +23,7 @@ import {
type BaseLanguageModelParams,
} from "./base.js";
import type { RunnableConfig } from "../runnables/config.js";
- import type { BaseCache } from "../caches.js";
+ import type { BaseCache } from "../caches/base.js";
import { isStreamEventsHandler } from "../tracers/event_stream.js";
import { isLogStreamHandler } from "../tracers/log_stream.js";
import { concat } from "../utils/stream.js";
39 changes: 39 additions & 0 deletions langchain-core/src/language_models/tests/chat_models.test.ts
@@ -4,6 +4,9 @@ import { test } from "@jest/globals";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import { FakeChatModel, FakeListChatModel } from "../../utils/testing/index.js";
import { HumanMessage } from "../../messages/human.js";
import { getBufferString } from "../../messages/utils.js";
import { AIMessage } from "../../messages/ai.js";

test("Test ChatModel accepts array shorthand for messages", async () => {
const model = new FakeChatModel({});
@@ -189,3 +192,39 @@ test("Test ChatModel withStructuredOutput new syntax and includeRaw", async () =
// No error
console.log(response.parsed);
});

test("Test ChatModel can cache complex messages", async () => {
const model = new FakeChatModel({
cache: true,
});
if (!model.cache) {
throw new Error("Cache not enabled");
}

const contentToCache = [
{
type: "text",
text: "Hello there!",
},
];
const humanMessage = new HumanMessage({
content: contentToCache,
});

const prompt = getBufferString([humanMessage]);
const llmKey = model._getSerializedCacheKeyParametersForCall({});

// Invoke model to trigger cache update
await model.invoke([humanMessage]);

const value = await model.cache.lookup(prompt, llmKey);
expect(value).toBeDefined();
if (!value) return;

expect(value[0].text).toEqual(JSON.stringify(contentToCache, null, 2));

expect("message" in value[0]).toBeTruthy();
if (!("message" in value[0])) return;
const cachedMsg = value[0].message as AIMessage;
expect(cachedMsg.content).toEqual(JSON.stringify(contentToCache, null, 2));
});
2 changes: 1 addition & 1 deletion langchain-core/src/load/import_map.ts
@@ -1,7 +1,7 @@
// Auto-generated by `scripts/create-entrypoints.js`. Do not edit manually.

export * as agents from "../agents.js";
- export * as caches from "../caches.js";
+ export * as caches from "../caches/base.js";
export * as callbacks__base from "../callbacks/base.js";
export * as callbacks__manager from "../callbacks/manager.js";
export * as callbacks__promises from "../callbacks/promises.js";
68 changes: 68 additions & 0 deletions langchain-core/src/messages/tests/message_utils.test.ts
@@ -8,6 +8,7 @@ import { AIMessage } from "../ai.js";
import { HumanMessage } from "../human.js";
import { SystemMessage } from "../system.js";
import { BaseMessage } from "../base.js";
import { getBufferString } from "../utils.js";

describe("filterMessage", () => {
const getMessages = () => [
@@ -431,3 +432,70 @@ describe("trimMessages can trim", () => {
expect(typeof (trimmedMessages as any).func).toBe("function");
});
});

test("getBufferString can handle complex messages", () => {
const messageArr1 = [new HumanMessage("Hello there!")];
const messageArr2 = [
new AIMessage({
content: [
{
type: "text",
text: "Hello there!",
},
],
}),
];
const messageArr3 = [
new HumanMessage({
content: [
{
type: "image_url",
image_url: {
url: "https://example.com/image.jpg",
},
},
{
type: "image_url",
image_url: "https://example.com/image.jpg",
},
],
}),
];

const bufferString1 = getBufferString(messageArr1);
expect(bufferString1).toBe("Human: Hello there!");

const bufferString2 = getBufferString(messageArr2);
expect(bufferString2).toBe(
`AI: ${JSON.stringify(
[
{
type: "text",
text: "Hello there!",
},
],
null,
2
)}`
);

const bufferString3 = getBufferString(messageArr3);
expect(bufferString3).toBe(
`Human: ${JSON.stringify(
[
{
type: "image_url",
image_url: {
url: "https://example.com/image.jpg",
},
},
{
type: "image_url",
image_url: "https://example.com/image.jpg",
},
],
null,
2
)}`
);
});
6 changes: 5 additions & 1 deletion langchain-core/src/messages/utils.ts
@@ -82,7 +82,11 @@ export function getBufferString(
throw new Error(`Got unsupported message type: ${m._getType()}`);
}
const nameStr = m.name ? `${m.name}, ` : "";
- string_messages.push(`${role}: ${nameStr}${m.content}`);
+ const readableContent =
+ typeof m.content === "string"
+ ? m.content
+ : JSON.stringify(m.content, null, 2);
+ string_messages.push(`${role}: ${nameStr}${readableContent}`);
}
return string_messages.join("\n");
}
2 changes: 1 addition & 1 deletion langchain-core/src/tests/caches.test.ts
@@ -1,6 +1,6 @@
import { test, expect } from "@jest/globals";

import { InMemoryCache } from "../caches.js";
import { InMemoryCache } from "../caches/base.js";

test("InMemoryCache", async () => {
const cache = new InMemoryCache();
9 changes: 8 additions & 1 deletion langchain-core/src/utils/testing/index.ts
@@ -184,7 +184,14 @@ export class FakeChatModel extends BaseChatModel {
],
};
}
- const text = messages.map((m) => m.content).join("\n");
+ const text = messages
+ .map((m) => {
+ if (typeof m.content === "string") {
+ return m.content;
+ }
+ return JSON.stringify(m.content, null, 2);
+ })
+ .join("\n");
await runManager?.handleLLMNewToken(text);
return {
generations: [
10 changes: 9 additions & 1 deletion libs/langchain-groq/src/tests/chat_models.standard.int.test.ts
@@ -44,7 +44,15 @@ class ChatGroqStandardIntegrationTests extends ChatModelIntegrationTests<
this.skipTestMessage(
"testToolMessageHistoriesListContent",
"ChatGroq",
"Not properly implemented."
"Complex message types not properly implemented"
);
}

async testCacheComplexMessageTypes() {
this.skipTestMessage(
"testCacheComplexMessageTypes",
"ChatGroq",
"Complex message types not properly implemented"
);
}
}
libs/langchain-mistralai/src/tests/chat_models.standard.int.test.ts
@@ -23,6 +23,14 @@ class ChatMistralAIStandardIntegrationTests extends ChatModelIntegrationTests<
functionId: "123456789",
});
}

async testCacheComplexMessageTypes() {
this.skipTestMessage(
"testCacheComplexMessageTypes",
"ChatMistralAI",
"Complex message types not properly implemented"
);
}
}

const testClass = new ChatMistralAIStandardIntegrationTests();
52 changes: 52 additions & 0 deletions libs/langchain-standard-tests/src/integration_tests/chat_models.ts
@@ -7,6 +7,7 @@ import {
HumanMessage,
ToolMessage,
UsageMetadata,
getBufferString,
} from "@langchain/core/messages";
import { z } from "zod";
import { StructuredTool } from "@langchain/core/tools";
@@ -438,6 +439,50 @@ export abstract class ChatModelIntegrationTests<
expect(tool_calls[0].name).toBe("math_addition");
}

async testCacheComplexMessageTypes() {
const model = new this.Cls({
...this.constructorArgs,
cache: true,
});
if (!model.cache) {
throw new Error("Cache not enabled");
}

const humanMessage = new HumanMessage({
content: [
{
type: "text",
text: "Hello there!",
},
],
});
const prompt = getBufferString([humanMessage]);
const llmKey = model._getSerializedCacheKeyParametersForCall({} as any);

// Invoke the model to trigger a cache update.
await model.invoke([humanMessage]);
const cacheValue = await model.cache.lookup(prompt, llmKey);

// Ensure only one generation was added to the cache.
expect(cacheValue !== null).toBeTruthy();
if (!cacheValue) return;
expect(cacheValue).toHaveLength(1);

expect("message" in cacheValue[0]).toBeTruthy();
if (!("message" in cacheValue[0])) return;
const cachedMessage = cacheValue[0].message as AIMessage;

// Invoke the model again with the same prompt, triggering a cache hit.
const result = await model.invoke([humanMessage]);

expect(result.content).toBe(cacheValue[0].text);
expect(result).toEqual(cachedMessage);

// Verify a second generation was not added to the cache.
const cacheValue2 = await model.cache.lookup(prompt, llmKey);
expect(cacheValue2).toEqual(cacheValue);
}

/**
* Run all unit tests for the chat model.
* Each test is wrapped in a try/catch block to prevent the entire test suite from failing.
@@ -531,6 +576,13 @@ export abstract class ChatModelIntegrationTests<
console.error("testBindToolsWithOpenAIFormattedTools failed", e);
}

try {
await this.testCacheComplexMessageTypes();
} catch (e: any) {
allTestsPassed = false;
console.error("testCacheComplexMessageTypes failed", e);
}

return allTestsPassed;
}
}
