From 84d889d9e2b162b6e2dcb084edfbe1d5380e849c Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Mon, 16 Dec 2024 11:44:25 -0800 Subject: [PATCH] Fix chat model streaming for streamMode messages --- examples/package.json | 2 +- libs/langgraph/package.json | 2 +- libs/langgraph/src/pregel/messages.ts | 4 +++ libs/langgraph/src/tests/pregel.test.ts | 38 +++++++++++++++++++++++++ yarn.lock | 12 ++++---- 5 files changed, 50 insertions(+), 8 deletions(-) diff --git a/examples/package.json b/examples/package.json index 9e1c7c64..01e53692 100644 --- a/examples/package.json +++ b/examples/package.json @@ -8,7 +8,7 @@ "devDependencies": { "@langchain/anthropic": "^0.3.5", "@langchain/community": "^0.3.9", - "@langchain/core": "^0.3.23", + "@langchain/core": "^0.3.24", "@langchain/groq": "^0.1.2", "@langchain/langgraph": "workspace:*", "@langchain/mistralai": "^0.1.1", diff --git a/libs/langgraph/package.json b/libs/langgraph/package.json index ce81840b..210ec387 100644 --- a/libs/langgraph/package.json +++ b/libs/langgraph/package.json @@ -43,7 +43,7 @@ "@jest/globals": "^29.5.0", "@langchain/anthropic": "^0.3.5", "@langchain/community": "^0.3.9", - "@langchain/core": "^0.3.23", + "@langchain/core": "^0.3.24", "@langchain/langgraph-checkpoint-postgres": "workspace:*", "@langchain/langgraph-checkpoint-sqlite": "workspace:*", "@langchain/openai": "^0.3.11", diff --git a/libs/langgraph/src/pregel/messages.ts b/libs/langgraph/src/pregel/messages.ts index 33ca5720..e9d3ce01 100644 --- a/libs/langgraph/src/pregel/messages.ts +++ b/libs/langgraph/src/pregel/messages.ts @@ -31,6 +31,8 @@ function isChatGenerationChunk(x: unknown): x is ChatGenerationChunk { * A callback handler that implements stream_mode=messages. * Collects messages from (1) chat model stream events and (2) node outputs. */ +// TODO: Make this import and explicitly implement the +// CallbackHandlerPrefersStreaming interface once we drop support for core 0.2 export class StreamMessagesHandler extends BaseCallbackHandler { name = "StreamMessagesHandler"; @@ -42,6 +44,8 @@ export class StreamMessagesHandler extends BaseCallbackHandler { emittedChatModelRunIds: Record = {}; + lc_prefer_streaming = true; + constructor(streamFn: (streamChunk: StreamChunk) => void) { super(); this.streamFn = streamFn; diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index b3cd61fb..59b08c71 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -3087,6 +3087,44 @@ graph TD; }); }); + it("Supports automatic streaming with streamMode messages", async () => { + const llm = new FakeChatModel({ + responses: [ + new AIMessage({ + id: "ai1", + content: "foobar", + }), + ], + }); + + const StateAnnotation = Annotation.Root({ + question: Annotation, + answer: Annotation, + }); + + const generate = async (state: typeof StateAnnotation.State) => { + const response = await llm.invoke(state.question); + return { answer: response.content }; + }; + + // Compile application and test + const graph = new StateGraph(StateAnnotation) + .addNode("generate", generate) + .addEdge("__start__", "generate") + .addEdge("generate", "__end__") + .compile(); + + const inputs = { question: "How are you?" }; + + const stream = await graph.stream(inputs, { streamMode: "messages" }); + + const aiMessageChunks = []; + for await (const [message] of stream) { + aiMessageChunks.push(message); + } + expect(aiMessageChunks.length).toBeGreaterThan(1); + }); + it("State graph packets", async () => { const AgentState = Annotation.Root({ messages: Annotation({ diff --git a/yarn.lock b/yarn.lock index e3570734..ee4245de 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1599,9 +1599,9 @@ __metadata: languageName: node linkType: hard -"@langchain/core@npm:^0.3.23": - version: 0.3.23 - resolution: "@langchain/core@npm:0.3.23" +"@langchain/core@npm:^0.3.24": + version: 0.3.24 + resolution: "@langchain/core@npm:0.3.24" dependencies: "@cfworker/json-schema": ^4.0.2 ansi-styles: ^5.0.0 @@ -1615,7 +1615,7 @@ __metadata: uuid: ^10.0.0 zod: ^3.22.4 zod-to-json-schema: ^3.22.3 - checksum: b8cb67c2201fb44feb6136ee0ab097217a760e918d6f33e8cb0152095c960ed9102605a23227b014127f57eadaa0a8aaf62b238557c18b4ef111feb8faf360cf + checksum: c2986e7ed8b7b869e27d633a14cd00d6a4777004ea59f4f70e99fc6b9db4d7e87d687aa8ad84e03684ea5053ea5f4b454c44716092401dc5cf8fd1b8d5cfe9d1 languageName: node linkType: hard @@ -1851,7 +1851,7 @@ __metadata: "@jest/globals": ^29.5.0 "@langchain/anthropic": ^0.3.5 "@langchain/community": ^0.3.9 - "@langchain/core": ^0.3.23 + "@langchain/core": ^0.3.24 "@langchain/langgraph-checkpoint": ~0.0.13 "@langchain/langgraph-checkpoint-postgres": "workspace:*" "@langchain/langgraph-checkpoint-sqlite": "workspace:*" @@ -6649,7 +6649,7 @@ __metadata: dependencies: "@langchain/anthropic": ^0.3.5 "@langchain/community": ^0.3.9 - "@langchain/core": ^0.3.23 + "@langchain/core": ^0.3.24 "@langchain/groq": ^0.1.2 "@langchain/langgraph": "workspace:*" "@langchain/mistralai": ^0.1.1