Skip to content

Commit

Permalink
fix(langgraph): Fix chat model streaming for streamMode messages (#745)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Dec 16, 2024
1 parent d19858a commit 5969ea8
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 87 deletions.
2 changes: 1 addition & 1 deletion examples/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"devDependencies": {
"@langchain/anthropic": "^0.3.5",
"@langchain/community": "^0.3.9",
"@langchain/core": "^0.3.23",
"@langchain/core": "^0.3.24",
"@langchain/groq": "^0.1.2",
"@langchain/langgraph": "workspace:*",
"@langchain/mistralai": "^0.1.1",
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"@jest/globals": "^29.5.0",
"@langchain/anthropic": "^0.3.5",
"@langchain/community": "^0.3.9",
"@langchain/core": "^0.3.23",
"@langchain/core": "^0.3.24",
"@langchain/langgraph-checkpoint-postgres": "workspace:*",
"@langchain/langgraph-checkpoint-sqlite": "workspace:*",
"@langchain/openai": "^0.3.11",
Expand Down
4 changes: 4 additions & 0 deletions libs/langgraph/src/pregel/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ function isChatGenerationChunk(x: unknown): x is ChatGenerationChunk {
* A callback handler that implements stream_mode=messages.
* Collects messages from (1) chat model stream events and (2) node outputs.
*/
// TODO: Make this import and explicitly implement the
// CallbackHandlerPrefersStreaming interface once we drop support for core 0.2
export class StreamMessagesHandler extends BaseCallbackHandler {
name = "StreamMessagesHandler";

Expand All @@ -42,6 +44,8 @@ export class StreamMessagesHandler extends BaseCallbackHandler {

emittedChatModelRunIds: Record<string, boolean> = {};

lc_prefer_streaming = true;

constructor(streamFn: (streamChunk: StreamChunk) => void) {
super();
this.streamFn = streamFn;
Expand Down
137 changes: 58 additions & 79 deletions libs/langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3087,6 +3087,44 @@ graph TD;
});
});

it("Supports automatic streaming with streamMode messages", async () => {
const llm = new FakeChatModel({
responses: [
new AIMessage({
id: "ai1",
content: "foobar",
}),
],
});

const StateAnnotation = Annotation.Root({
question: Annotation<string>,
answer: Annotation<string>,
});

const generate = async (state: typeof StateAnnotation.State) => {
const response = await llm.invoke(state.question);
return { answer: response.content };
};

// Compile application and test
const graph = new StateGraph(StateAnnotation)
.addNode("generate", generate)
.addEdge("__start__", "generate")
.addEdge("generate", "__end__")
.compile();

const inputs = { question: "How are you?" };

const stream = await graph.stream(inputs, { streamMode: "messages" });

const aiMessageChunks = [];
for await (const [message] of stream) {
aiMessageChunks.push(message);
}
expect(aiMessageChunks.length).toBeGreaterThan(1);
});

it("State graph packets", async () => {
const AgentState = Annotation.Root({
messages: Annotation({
Expand Down Expand Up @@ -8310,11 +8348,11 @@ graph TD;
it("should work with streamMode messages and custom from within a subgraph", async () => {
const child = new StateGraph(MessagesAnnotation)
.addNode("c_one", () => ({
messages: [new HumanMessage("foo"), new AIMessage("bar")],
messages: [new HumanMessage("f"), new AIMessage("b")],
}))
.addNode("c_two", async (_, config) => {
const model = new FakeChatModel({
responses: [new AIMessage("123"), new AIMessage("baz")],
responses: [new AIMessage("1"), new AIMessage("2")],
}).withConfig({ tags: ["c_two_chat_model"] });
const stream = await model.stream("yo", {
...config,
Expand All @@ -8336,7 +8374,7 @@ graph TD;
const parent = new StateGraph(MessagesAnnotation)
.addNode("p_one", async (_, config) => {
const toolExecutor = RunnableLambda.from(async () => {
return [new ToolMessage({ content: "qux", tool_call_id: "test" })];
return [new ToolMessage({ content: "q", tool_call_id: "test" })];
});
config.writer?.({
from: "parent",
Expand All @@ -8348,7 +8386,7 @@ graph TD;
.addNode("p_two", child.compile())
.addNode("p_three", async (_, config) => {
const model = new FakeChatModel({
responses: [new AIMessage("parent")],
responses: [new AIMessage("x")],
});
await model.invoke("hey", config);
return { messages: [] };
Expand All @@ -8369,7 +8407,7 @@ graph TD;
[
new _AnyIdToolMessage({
tool_call_id: "test",
content: "qux",
content: "q",
}),
{
langgraph_step: 1,
Expand All @@ -8386,7 +8424,7 @@ graph TD;
],
[
new _AnyIdHumanMessage({
content: "foo",
content: "f",
}),
{
langgraph_step: 1,
Expand All @@ -8403,7 +8441,7 @@ graph TD;
],
[
new _AnyIdAIMessage({
content: "bar",
content: "b",
}),
{
langgraph_step: 1,
Expand Down Expand Up @@ -8455,12 +8493,11 @@ graph TD;
ls_provider: "FakeChatModel",
ls_stop: undefined,
tags: ["c_two_chat_model"],
name: "c_two_chat_model_stream",
},
],
[
new _AnyIdAIMessageChunk({
content: "3",
content: "2",
}),
{
langgraph_step: 2,
Expand All @@ -8471,35 +8508,14 @@ graph TD;
__pregel_resuming: false,
__pregel_task_id: expect.any(String),
checkpoint_ns: expect.stringMatching(/^p_two:/),
ls_model_type: "chat",
ls_provider: "FakeChatModel",
name: "c_two",
tags: ["graph:step:2"],
ls_stop: undefined,
tags: ["c_two_chat_model"],
name: "c_two_chat_model_stream",
},
],
[
new _AnyIdAIMessage({
content: "baz",
}),
{
langgraph_step: 2,
langgraph_node: "c_two",
langgraph_triggers: ["c_one"],
langgraph_path: [PULL, "c_two"],
langgraph_checkpoint_ns: expect.stringMatching(/^p_two:.*\|c_two:.*/),
__pregel_resuming: false,
__pregel_task_id: expect.any(String),
checkpoint_ns: expect.stringMatching(/^p_two:/),
ls_model_type: "chat",
ls_provider: "FakeChatModel",
ls_stop: undefined,
tags: ["c_two_chat_model"],
},
],
[
new _AnyIdAIMessage({
content: "parent",
new _AnyIdAIMessageChunk({
content: "x",
}),
{
langgraph_step: 3,
Expand Down Expand Up @@ -8530,14 +8546,6 @@ graph TD;
content: "1",
from: "subgraph",
},
{
content: "2",
from: "subgraph",
},
{
content: "3",
from: "subgraph",
},
]);

const streamedCombinedEvents: StateSnapshot[] = await gatherIterator(
Expand All @@ -8554,7 +8562,7 @@ graph TD;
[
new _AnyIdToolMessage({
tool_call_id: "test",
content: "qux",
content: "q",
}),
{
langgraph_step: 1,
Expand All @@ -8574,7 +8582,7 @@ graph TD;
"messages",
[
new _AnyIdHumanMessage({
content: "foo",
content: "f",
}),
{
langgraph_step: 1,
Expand All @@ -8595,7 +8603,7 @@ graph TD;
"messages",
[
new _AnyIdAIMessage({
content: "bar",
content: "b",
}),
{
langgraph_step: 1,
Expand Down Expand Up @@ -8657,41 +8665,14 @@ graph TD;
ls_provider: "FakeChatModel",
ls_stop: undefined,
tags: ["c_two_chat_model"],
name: "c_two_chat_model_stream",
},
],
],
["custom", { from: "subgraph", content: "2" }],
[
"messages",
[
new _AnyIdAIMessageChunk({
content: "3",
}),
{
langgraph_step: 2,
langgraph_node: "c_two",
langgraph_triggers: ["c_one"],
langgraph_path: [PULL, "c_two"],
langgraph_checkpoint_ns:
expect.stringMatching(/^p_two:.*\|c_two:.*/),
__pregel_resuming: false,
__pregel_task_id: expect.any(String),
checkpoint_ns: expect.stringMatching(/^p_two:/),
ls_model_type: "chat",
ls_provider: "FakeChatModel",
ls_stop: undefined,
tags: ["c_two_chat_model"],
name: "c_two_chat_model_stream",
},
],
],
["custom", { from: "subgraph", content: "3" }],
[
"messages",
[
new _AnyIdAIMessage({
content: "baz",
content: "2",
}),
{
langgraph_step: 2,
Expand All @@ -8703,18 +8684,16 @@ graph TD;
__pregel_resuming: false,
__pregel_task_id: expect.any(String),
checkpoint_ns: expect.stringMatching(/^p_two:/),
ls_model_type: "chat",
ls_provider: "FakeChatModel",
ls_stop: undefined,
tags: ["c_two_chat_model"],
tags: ["graph:step:2"],
name: "c_two",
},
],
],
[
"messages",
[
new _AnyIdAIMessage({
content: "parent",
new _AnyIdAIMessageChunk({
content: "x",
}),
{
langgraph_step: 3,
Expand Down
12 changes: 6 additions & 6 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1599,9 +1599,9 @@ __metadata:
languageName: node
linkType: hard

"@langchain/core@npm:^0.3.23":
version: 0.3.23
resolution: "@langchain/core@npm:0.3.23"
"@langchain/core@npm:^0.3.24":
version: 0.3.24
resolution: "@langchain/core@npm:0.3.24"
dependencies:
"@cfworker/json-schema": ^4.0.2
ansi-styles: ^5.0.0
Expand All @@ -1615,7 +1615,7 @@ __metadata:
uuid: ^10.0.0
zod: ^3.22.4
zod-to-json-schema: ^3.22.3
checksum: b8cb67c2201fb44feb6136ee0ab097217a760e918d6f33e8cb0152095c960ed9102605a23227b014127f57eadaa0a8aaf62b238557c18b4ef111feb8faf360cf
checksum: c2986e7ed8b7b869e27d633a14cd00d6a4777004ea59f4f70e99fc6b9db4d7e87d687aa8ad84e03684ea5053ea5f4b454c44716092401dc5cf8fd1b8d5cfe9d1
languageName: node
linkType: hard

Expand Down Expand Up @@ -1851,7 +1851,7 @@ __metadata:
"@jest/globals": ^29.5.0
"@langchain/anthropic": ^0.3.5
"@langchain/community": ^0.3.9
"@langchain/core": ^0.3.23
"@langchain/core": ^0.3.24
"@langchain/langgraph-checkpoint": ~0.0.13
"@langchain/langgraph-checkpoint-postgres": "workspace:*"
"@langchain/langgraph-checkpoint-sqlite": "workspace:*"
Expand Down Expand Up @@ -6649,7 +6649,7 @@ __metadata:
dependencies:
"@langchain/anthropic": ^0.3.5
"@langchain/community": ^0.3.9
"@langchain/core": ^0.3.23
"@langchain/core": ^0.3.24
"@langchain/groq": ^0.1.2
"@langchain/langgraph": "workspace:*"
"@langchain/mistralai": ^0.1.1
Expand Down

0 comments on commit 5969ea8

Please sign in to comment.