From 5969ea8d874811039f9bf390ec4df7e64b93633c Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Mon, 16 Dec 2024 12:00:18 -0800 Subject: [PATCH] fix(langgraph): Fix chat model streaming for streamMode messages (#745) --- examples/package.json | 2 +- libs/langgraph/package.json | 2 +- libs/langgraph/src/pregel/messages.ts | 4 + libs/langgraph/src/tests/pregel.test.ts | 137 ++++++++++-------------- yarn.lock | 12 +-- 5 files changed, 70 insertions(+), 87 deletions(-) diff --git a/examples/package.json b/examples/package.json index 9e1c7c64..01e53692 100644 --- a/examples/package.json +++ b/examples/package.json @@ -8,7 +8,7 @@ "devDependencies": { "@langchain/anthropic": "^0.3.5", "@langchain/community": "^0.3.9", - "@langchain/core": "^0.3.23", + "@langchain/core": "^0.3.24", "@langchain/groq": "^0.1.2", "@langchain/langgraph": "workspace:*", "@langchain/mistralai": "^0.1.1", diff --git a/libs/langgraph/package.json b/libs/langgraph/package.json index ce81840b..210ec387 100644 --- a/libs/langgraph/package.json +++ b/libs/langgraph/package.json @@ -43,7 +43,7 @@ "@jest/globals": "^29.5.0", "@langchain/anthropic": "^0.3.5", "@langchain/community": "^0.3.9", - "@langchain/core": "^0.3.23", + "@langchain/core": "^0.3.24", "@langchain/langgraph-checkpoint-postgres": "workspace:*", "@langchain/langgraph-checkpoint-sqlite": "workspace:*", "@langchain/openai": "^0.3.11", diff --git a/libs/langgraph/src/pregel/messages.ts b/libs/langgraph/src/pregel/messages.ts index 33ca5720..e9d3ce01 100644 --- a/libs/langgraph/src/pregel/messages.ts +++ b/libs/langgraph/src/pregel/messages.ts @@ -31,6 +31,8 @@ function isChatGenerationChunk(x: unknown): x is ChatGenerationChunk { * A callback handler that implements stream_mode=messages. * Collects messages from (1) chat model stream events and (2) node outputs. */ +// TODO: Make this import and explicitly implement the +// CallbackHandlerPrefersStreaming interface once we drop support for core 0.2 export class StreamMessagesHandler extends BaseCallbackHandler { name = "StreamMessagesHandler"; @@ -42,6 +44,8 @@ export class StreamMessagesHandler extends BaseCallbackHandler { emittedChatModelRunIds: Record = {}; + lc_prefer_streaming = true; + constructor(streamFn: (streamChunk: StreamChunk) => void) { super(); this.streamFn = streamFn; diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index b3cd61fb..410d5e34 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -3087,6 +3087,44 @@ graph TD; }); }); + it("Supports automatic streaming with streamMode messages", async () => { + const llm = new FakeChatModel({ + responses: [ + new AIMessage({ + id: "ai1", + content: "foobar", + }), + ], + }); + + const StateAnnotation = Annotation.Root({ + question: Annotation, + answer: Annotation, + }); + + const generate = async (state: typeof StateAnnotation.State) => { + const response = await llm.invoke(state.question); + return { answer: response.content }; + }; + + // Compile application and test + const graph = new StateGraph(StateAnnotation) + .addNode("generate", generate) + .addEdge("__start__", "generate") + .addEdge("generate", "__end__") + .compile(); + + const inputs = { question: "How are you?" }; + + const stream = await graph.stream(inputs, { streamMode: "messages" }); + + const aiMessageChunks = []; + for await (const [message] of stream) { + aiMessageChunks.push(message); + } + expect(aiMessageChunks.length).toBeGreaterThan(1); + }); + it("State graph packets", async () => { const AgentState = Annotation.Root({ messages: Annotation({ @@ -8310,11 +8348,11 @@ graph TD; it("should work with streamMode messages and custom from within a subgraph", async () => { const child = new StateGraph(MessagesAnnotation) .addNode("c_one", () => ({ - messages: [new HumanMessage("foo"), new AIMessage("bar")], + messages: [new HumanMessage("f"), new AIMessage("b")], })) .addNode("c_two", async (_, config) => { const model = new FakeChatModel({ - responses: [new AIMessage("123"), new AIMessage("baz")], + responses: [new AIMessage("1"), new AIMessage("2")], }).withConfig({ tags: ["c_two_chat_model"] }); const stream = await model.stream("yo", { ...config, @@ -8336,7 +8374,7 @@ graph TD; const parent = new StateGraph(MessagesAnnotation) .addNode("p_one", async (_, config) => { const toolExecutor = RunnableLambda.from(async () => { - return [new ToolMessage({ content: "qux", tool_call_id: "test" })]; + return [new ToolMessage({ content: "q", tool_call_id: "test" })]; }); config.writer?.({ from: "parent", @@ -8348,7 +8386,7 @@ graph TD; .addNode("p_two", child.compile()) .addNode("p_three", async (_, config) => { const model = new FakeChatModel({ - responses: [new AIMessage("parent")], + responses: [new AIMessage("x")], }); await model.invoke("hey", config); return { messages: [] }; @@ -8369,7 +8407,7 @@ graph TD; [ new _AnyIdToolMessage({ tool_call_id: "test", - content: "qux", + content: "q", }), { langgraph_step: 1, @@ -8386,7 +8424,7 @@ graph TD; ], [ new _AnyIdHumanMessage({ - content: "foo", + content: "f", }), { langgraph_step: 1, @@ -8403,7 +8441,7 @@ graph TD; ], [ new _AnyIdAIMessage({ - content: "bar", + content: "b", }), { langgraph_step: 1, @@ -8455,12 +8493,11 @@ graph TD; ls_provider: "FakeChatModel", ls_stop: undefined, tags: ["c_two_chat_model"], - name: "c_two_chat_model_stream", }, ], [ new _AnyIdAIMessageChunk({ - content: "3", + content: "2", }), { langgraph_step: 2, @@ -8471,35 +8508,14 @@ graph TD; __pregel_resuming: false, __pregel_task_id: expect.any(String), checkpoint_ns: expect.stringMatching(/^p_two:/), - ls_model_type: "chat", - ls_provider: "FakeChatModel", + name: "c_two", + tags: ["graph:step:2"], ls_stop: undefined, - tags: ["c_two_chat_model"], - name: "c_two_chat_model_stream", }, ], [ - new _AnyIdAIMessage({ - content: "baz", - }), - { - langgraph_step: 2, - langgraph_node: "c_two", - langgraph_triggers: ["c_one"], - langgraph_path: [PULL, "c_two"], - langgraph_checkpoint_ns: expect.stringMatching(/^p_two:.*\|c_two:.*/), - __pregel_resuming: false, - __pregel_task_id: expect.any(String), - checkpoint_ns: expect.stringMatching(/^p_two:/), - ls_model_type: "chat", - ls_provider: "FakeChatModel", - ls_stop: undefined, - tags: ["c_two_chat_model"], - }, - ], - [ - new _AnyIdAIMessage({ - content: "parent", + new _AnyIdAIMessageChunk({ + content: "x", }), { langgraph_step: 3, @@ -8530,14 +8546,6 @@ graph TD; content: "1", from: "subgraph", }, - { - content: "2", - from: "subgraph", - }, - { - content: "3", - from: "subgraph", - }, ]); const streamedCombinedEvents: StateSnapshot[] = await gatherIterator( @@ -8554,7 +8562,7 @@ graph TD; [ new _AnyIdToolMessage({ tool_call_id: "test", - content: "qux", + content: "q", }), { langgraph_step: 1, @@ -8574,7 +8582,7 @@ graph TD; "messages", [ new _AnyIdHumanMessage({ - content: "foo", + content: "f", }), { langgraph_step: 1, @@ -8595,7 +8603,7 @@ graph TD; "messages", [ new _AnyIdAIMessage({ - content: "bar", + content: "b", }), { langgraph_step: 1, @@ -8657,41 +8665,14 @@ graph TD; ls_provider: "FakeChatModel", ls_stop: undefined, tags: ["c_two_chat_model"], - name: "c_two_chat_model_stream", }, ], ], - ["custom", { from: "subgraph", content: "2" }], [ "messages", [ new _AnyIdAIMessageChunk({ - content: "3", - }), - { - langgraph_step: 2, - langgraph_node: "c_two", - langgraph_triggers: ["c_one"], - langgraph_path: [PULL, "c_two"], - langgraph_checkpoint_ns: - expect.stringMatching(/^p_two:.*\|c_two:.*/), - __pregel_resuming: false, - __pregel_task_id: expect.any(String), - checkpoint_ns: expect.stringMatching(/^p_two:/), - ls_model_type: "chat", - ls_provider: "FakeChatModel", - ls_stop: undefined, - tags: ["c_two_chat_model"], - name: "c_two_chat_model_stream", - }, - ], - ], - ["custom", { from: "subgraph", content: "3" }], - [ - "messages", - [ - new _AnyIdAIMessage({ - content: "baz", + content: "2", }), { langgraph_step: 2, @@ -8703,18 +8684,16 @@ graph TD; __pregel_resuming: false, __pregel_task_id: expect.any(String), checkpoint_ns: expect.stringMatching(/^p_two:/), - ls_model_type: "chat", - ls_provider: "FakeChatModel", - ls_stop: undefined, - tags: ["c_two_chat_model"], + tags: ["graph:step:2"], + name: "c_two", }, ], ], [ "messages", [ - new _AnyIdAIMessage({ - content: "parent", + new _AnyIdAIMessageChunk({ + content: "x", }), { langgraph_step: 3, diff --git a/yarn.lock b/yarn.lock index e3570734..ee4245de 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1599,9 +1599,9 @@ __metadata: languageName: node linkType: hard -"@langchain/core@npm:^0.3.23": - version: 0.3.23 - resolution: "@langchain/core@npm:0.3.23" +"@langchain/core@npm:^0.3.24": + version: 0.3.24 + resolution: "@langchain/core@npm:0.3.24" dependencies: "@cfworker/json-schema": ^4.0.2 ansi-styles: ^5.0.0 @@ -1615,7 +1615,7 @@ __metadata: uuid: ^10.0.0 zod: ^3.22.4 zod-to-json-schema: ^3.22.3 - checksum: b8cb67c2201fb44feb6136ee0ab097217a760e918d6f33e8cb0152095c960ed9102605a23227b014127f57eadaa0a8aaf62b238557c18b4ef111feb8faf360cf + checksum: c2986e7ed8b7b869e27d633a14cd00d6a4777004ea59f4f70e99fc6b9db4d7e87d687aa8ad84e03684ea5053ea5f4b454c44716092401dc5cf8fd1b8d5cfe9d1 languageName: node linkType: hard @@ -1851,7 +1851,7 @@ __metadata: "@jest/globals": ^29.5.0 "@langchain/anthropic": ^0.3.5 "@langchain/community": ^0.3.9 - "@langchain/core": ^0.3.23 + "@langchain/core": ^0.3.24 "@langchain/langgraph-checkpoint": ~0.0.13 "@langchain/langgraph-checkpoint-postgres": "workspace:*" "@langchain/langgraph-checkpoint-sqlite": "workspace:*" @@ -6649,7 +6649,7 @@ __metadata: dependencies: "@langchain/anthropic": ^0.3.5 "@langchain/community": ^0.3.9 - "@langchain/core": ^0.3.23 + "@langchain/core": ^0.3.24 "@langchain/groq": ^0.1.2 "@langchain/langgraph": "workspace:*" "@langchain/mistralai": ^0.1.1