diff --git a/libs/checkpoint/src/serde/jsonplus.ts b/libs/checkpoint/src/serde/jsonplus.ts index eaa82c52..c2b54883 100644 --- a/libs/checkpoint/src/serde/jsonplus.ts +++ b/libs/checkpoint/src/serde/jsonplus.ts @@ -3,6 +3,55 @@ import { load } from "@langchain/core/load"; import { SerializerProtocol } from "./base.js"; +function isLangChainSerializable(value: Record<string, any>) { + return ( + typeof value.lc_serializable === "boolean" && Array.isArray(value.lc_id) + ); +} + +function isLangChainSerializedObject(value: Record<string, any>) { + return ( + value !== null && + value.lc === 1 && + value.type === "constructor" && + Array.isArray(value.id) + ); +} + +const _serialize = (value: any, seen = new WeakSet()): string => { + const defaultValue = _default("", value); + + if (defaultValue === null) return "null"; + if (typeof defaultValue === "string") return JSON.stringify(defaultValue); + if (typeof defaultValue === "number" || typeof defaultValue === "boolean") + return defaultValue.toString(); + if (typeof defaultValue === "object") { + if (seen.has(defaultValue)) { + throw new TypeError("Circular reference detected"); + } + seen.add(defaultValue); + + if (Array.isArray(defaultValue)) { + const result = `[${defaultValue + .map((item) => _serialize(item, seen)) + .join(",")}]`; + seen.delete(defaultValue); + return result; + } else if (isLangChainSerializable(defaultValue)) { + const result = JSON.stringify(defaultValue); seen.delete(defaultValue); return result; + } else { + const entries = Object.entries(defaultValue).map( + ([k, v]) => `${JSON.stringify(k)}:${_serialize(v, seen)}` + ); + const result = `{${entries.join(",")}}`; + seen.delete(defaultValue); + return result; + } + } + // Only be reached for functions or symbols + return JSON.stringify(defaultValue); +}; + async function _reviver(value: any): Promise<any> { if (value && typeof value === "object") { if (value.lc === 2 && value.type === "undefined") { @@ -40,7 +89,7 @@ async function _reviver(value: any): Promise<any> { } catch (error) { return value; } - } else if (value.lc === 
1) { + } else if (isLangChainSerializedObject(value)) { return load(JSON.stringify(value)); } else if (Array.isArray(value)) { return Promise.all(value.map((item) => _reviver(item))); @@ -93,23 +142,8 @@ function _default(_key: string, obj: any): any { export class JsonPlusSerializer implements SerializerProtocol { protected _dumps(obj: any): Uint8Array { - const jsonString = JSON.stringify(obj, (key, value) => { - if (value && typeof value === "object") { - if (Array.isArray(value)) { - // Handle arrays - return value.map((item) => _default(key, item)); - } else { - // Handle objects - const serialized: any = {}; - for (const [k, v] of Object.entries(value)) { - serialized[k] = _default(k, v); - } - return serialized; - } - } - return _default(key, value); - }); - return new TextEncoder().encode(jsonString); + const encoder = new TextEncoder(); + return encoder.encode(_serialize(obj)); } dumpsTyped(obj: any): [string, Uint8Array] { diff --git a/libs/checkpoint/src/serde/tests/jsonplus.test.ts b/libs/checkpoint/src/serde/tests/jsonplus.test.ts index 70bb4ee5..56df9237 100644 --- a/libs/checkpoint/src/serde/tests/jsonplus.test.ts +++ b/libs/checkpoint/src/serde/tests/jsonplus.test.ts @@ -3,6 +3,44 @@ import { AIMessage, HumanMessage } from "@langchain/core/messages"; import { uuid6 } from "../../id.js"; import { JsonPlusSerializer } from "../jsonplus.js"; +const messageWithToolCall = new AIMessage({ + content: "", + tool_calls: [ + { + name: "current_weather_sf", + args: { + input: "", + }, + type: "tool_call", + id: "call_Co6nrPmiAdWWZQHCNdEZUjTe", + }, + ], + invalid_tool_calls: [], + additional_kwargs: { + function_call: undefined, + tool_calls: [ + { + id: "call_Co6nrPmiAdWWZQHCNdEZUjTe", + type: "function", + function: { + name: "current_weather_sf", + arguments: '{"input":""}', + }, + }, + ], + }, + response_metadata: { + tokenUsage: { + completionTokens: 15, + promptTokens: 84, + totalTokens: 99, + }, + finish_reason: "tool_calls", + system_fingerprint: 
"fp_a2ff031fb5", + }, + id: "chatcmpl-A0s8Rd97RnFo6xMlYgpJDDfV8J1cl", +}); + const complexValue = { number: 1, id: uuid6(-1), @@ -14,6 +52,7 @@ const complexValue = { ]), regex: /foo*/gi, message: new AIMessage("test message"), + messageWithToolCall, array: [ new Error("nestedfoo"), 5, @@ -40,6 +79,7 @@ const VALUES = [ ["empty string", ""], ["simple string", "foobar"], ["various data types", complexValue], + ["an AIMessage with a tool call", messageWithToolCall], ] satisfies [string, unknown][]; it.each(VALUES)( @@ -51,3 +91,17 @@ it.each(VALUES)( expect(deserialized).toEqual(value); } ); + +it("Should throw an error for circular JSON inputs", async () => { + const a: Record = {}; + const b: Record = {}; + a.b = b; + b.a = a; + + const circular = { + a, + b, + }; + const serde = new JsonPlusSerializer(); + expect(() => serde.dumpsTyped(circular)).toThrow(); +}); diff --git a/libs/langgraph/src/tests/prebuilt.int.test.ts b/libs/langgraph/src/tests/prebuilt.int.test.ts index bb75975e..d2745bda 100644 --- a/libs/langgraph/src/tests/prebuilt.int.test.ts +++ b/libs/langgraph/src/tests/prebuilt.int.test.ts @@ -3,14 +3,10 @@ import { it, beforeAll, describe, expect } from "@jest/globals"; import { Tool } from "@langchain/core/tools"; import { ChatOpenAI } from "@langchain/openai"; -import { BaseMessage, HumanMessage } from "@langchain/core/messages"; -import { RunnableLambda } from "@langchain/core/runnables"; -import { z } from "zod"; -import { - createReactAgent, - createFunctionCallingExecutor, -} from "../prebuilt/index.js"; +import { HumanMessage } from "@langchain/core/messages"; +import { createReactAgent } from "../prebuilt/index.js"; import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js"; +import { MemorySaverAssertImmutable } from "./utils.js"; // Tracing slows down the tests beforeAll(() => { @@ -23,144 +19,40 @@ beforeAll(() => { initializeAsyncLocalStorageSingleton(); }); -describe("createFunctionCallingExecutor", () => { - 
it("can call a function", async () => { - const weatherResponse = `Not too cold, not too hot 😎`; - const model = new ChatOpenAI(); - class SanFranciscoWeatherTool extends Tool { - name = "current_weather"; - - description = "Get the current weather report for San Francisco, CA"; +describe("createReactAgent", () => { + const weatherResponse = `Not too cold, not too hot 😎`; + class SanFranciscoWeatherTool extends Tool { + name = "current_weather_sf"; - constructor() { - super(); - } + description = "Get the current weather report for San Francisco, CA"; - async _call(_: string): Promise { - return weatherResponse; - } + constructor() { + super(); } - const tools = [new SanFranciscoWeatherTool()]; - - const functionsAgentExecutor = createFunctionCallingExecutor({ - model, - tools, - }); - - const response = await functionsAgentExecutor.invoke({ - messages: [new HumanMessage("What's the weather like in SF?")], - }); - - // It needs at least one human message, one AI and one function message. 
- expect(response.messages.length > 3).toBe(true); - const firstFunctionMessage = (response.messages as Array).find( - (message) => message._getType() === "function" - ); - expect(firstFunctionMessage).toBeDefined(); - expect(firstFunctionMessage?.content).toBe(weatherResponse); - }); - - it("can stream a function", async () => { - const weatherResponse = `Not too cold, not too hot 😎`; - const model = new ChatOpenAI({ - streaming: true, - }); - class SanFranciscoWeatherTool extends Tool { - name = "current_weather"; - - description = "Get the current weather report for San Francisco, CA"; - constructor() { - super(); - } - - async _call(_: string): Promise { - return weatherResponse; - } + async _call(_: string): Promise { + return weatherResponse; } - const tools = [new SanFranciscoWeatherTool()]; + } + class NewYorkWeatherTool extends Tool { + name = "current_weather_ny"; - const functionsAgentExecutor = createFunctionCallingExecutor({ - model, - tools, - }); + description = "Get the current weather report for New York City, NY"; - const stream = await functionsAgentExecutor.stream( - { - messages: [new HumanMessage("What's the weather like in SF?")], - }, - { streamMode: "values" } - ); - const fullResponse = []; - for await (const item of stream) { - fullResponse.push(item); + constructor() { + super(); } - // human -> agent -> action -> agent - expect(fullResponse.length).toEqual(4); - - const endState = fullResponse[fullResponse.length - 1]; - // 1 human, 2 llm calls, 1 function call. 
- expect(endState.messages.length).toEqual(4); - const functionCall = endState.messages.find( - (message: BaseMessage) => message._getType() === "function" - ); - expect(functionCall.content).toBe(weatherResponse); - }); - - it("can accept RunnableToolLike tools", async () => { - const weatherResponse = `Not too cold, not too hot 😎`; - const model = new ChatOpenAI(); - - const sfWeatherTool = RunnableLambda.from(async (_) => weatherResponse); - const tools = [ - sfWeatherTool.asTool({ - name: "current_weather", - description: "Get the current weather report for San Francisco, CA", - schema: z.object({ - location: z.string(), - }), - }), - ]; - - const functionsAgentExecutor = createFunctionCallingExecutor({ - model, - tools, - }); - - const response = await functionsAgentExecutor.invoke({ - messages: [new HumanMessage("What's the weather like in SF?")], - }); - - // It needs at least one human message, one AI and one function message. - expect(response.messages.length > 3).toBe(true); - const firstFunctionMessage = (response.messages as Array).find( - (message) => message._getType() === "function" - ); - expect(firstFunctionMessage).toBeDefined(); - expect(firstFunctionMessage?.content).toBe(weatherResponse); - }); -}); - -describe("createReactAgent", () => { - it("can call a tool", async () => { - const weatherResponse = `Not too cold, not too hot 😎`; - const model = new ChatOpenAI(); - class SanFranciscoWeatherTool extends Tool { - name = "current_weather"; - - description = "Get the current weather report for San Francisco, CA"; - - constructor() { - super(); - } - - async _call(_: string): Promise { - return weatherResponse; - } + async _call(_: string): Promise { + return weatherResponse; } - const tools = [new SanFranciscoWeatherTool()]; + } + const tools = [new SanFranciscoWeatherTool(), new NewYorkWeatherTool()]; + it("can call a tool", async () => { + const model = new ChatOpenAI({ + model: "gpt-4o", + }); const reactAgent = createReactAgent({ llm: model, 
tools }); const response = await reactAgent.invoke({ @@ -174,40 +66,31 @@ describe("createReactAgent", () => { expect(lastMessage.content.toLowerCase()).toContain("not too cold"); }); - it("can stream a tool call", async () => { - const weatherResponse = `Not too cold, not too hot 😎`; + it("can stream a tool call with a checkpointer", async () => { const model = new ChatOpenAI({ - streaming: true, + model: "gpt-4o", }); - class SanFranciscoWeatherTool extends Tool { - name = "current_weather"; - - description = "Get the current weather report for San Francisco, CA"; - constructor() { - super(); - } - - async _call(_: string): Promise { - return weatherResponse; - } - } - const tools = [new SanFranciscoWeatherTool()]; + const checkpointer = new MemorySaverAssertImmutable(); - const reactAgent = createReactAgent({ llm: model, tools }); + const reactAgent = createReactAgent({ + llm: model, + tools, + checkpointSaver: checkpointer, + }); const stream = await reactAgent.stream( { messages: [new HumanMessage("What's the weather like in SF?")], }, - { streamMode: "values" } + { configurable: { thread_id: "foo" }, streamMode: "values" } ); const fullResponse = []; for await (const item of stream) { fullResponse.push(item); } - // human -> agent -> action -> agent + // human -> agent -> tool -> agent expect(fullResponse.length).toEqual(4); const endState = fullResponse[fullResponse.length - 1]; // 1 human, 2 ai, 1 tool. 
@@ -216,5 +99,24 @@ describe("createReactAgent", () => { const lastMessage = endState.messages[endState.messages.length - 1]; expect(lastMessage._getType()).toBe("ai"); expect(lastMessage.content.toLowerCase()).toContain("not too cold"); + const stream2 = await reactAgent.stream( + { + messages: [new HumanMessage("What about NYC?")], + }, + { configurable: { thread_id: "foo" }, streamMode: "values" } + ); + const fullResponse2 = []; + for await (const item of stream2) { + fullResponse2.push(item); + } + // human -> agent -> tool -> agent + expect(fullResponse2.length).toEqual(4); + const endState2 = fullResponse2[fullResponse2.length - 1]; + // 2 human, 4 ai, 2 tool. + expect(endState2.messages.length).toEqual(8); + + const lastMessage2 = endState2.messages[endState2.messages.length - 1]; + expect(lastMessage2._getType()).toBe("ai"); + expect(lastMessage2.content.toLowerCase()).toContain("not too cold"); }); }); diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index 9c48494e..b624679c 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -2108,7 +2108,7 @@ describe("StateGraph", () => { ], } ); - await new Promise((resolve) => setTimeout(resolve, 100)); + await new Promise((resolve) => setTimeout(resolve, 200)); expect(result).toEqual({ input: "what is the weather in sf?", agentOutcome: {