From 3d539d8f99c2a94596c094e150dfd14c533786c0 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Fri, 9 Aug 2024 21:09:43 -0700 Subject: [PATCH] Update test --- langgraph/src/channels/base.ts | 2 +- langgraph/src/tests/pregel.test.ts | 509 +++++++++++++---------------- 2 files changed, 231 insertions(+), 280 deletions(-) diff --git a/langgraph/src/channels/base.ts b/langgraph/src/channels/base.ts index 136d293bf..2822bd4c9 100644 --- a/langgraph/src/channels/base.ts +++ b/langgraph/src/channels/base.ts @@ -94,6 +94,6 @@ export function createCheckpoint( channel_values: values, channel_versions: { ...checkpoint.channel_versions }, versions_seen: deepCopy(checkpoint.versions_seen), - pending_sends: [], + pending_sends: checkpoint.pending_sends ?? [], }; } diff --git a/langgraph/src/tests/pregel.test.ts b/langgraph/src/tests/pregel.test.ts index d26282913..67485c84f 100644 --- a/langgraph/src/tests/pregel.test.ts +++ b/langgraph/src/tests/pregel.test.ts @@ -1995,13 +1995,12 @@ describe("StateGraph", () => { }; }; - const graph = new StateGraph(AgentState) + const builder = new StateGraph(AgentState) .addNode("agent", agent) .addNode("tools", toolsNode) .addEdge("__start__", "agent") .addConditionalEdges("agent", shouldContinue) - .addEdge("tools", "agent") - .compile(); + .addEdge("tools", "agent"); const inputMessage = new HumanMessage({ id: "foo", content: "what is weather in sf", @@ -2061,21 +2060,20 @@ describe("StateGraph", () => { content: "answer", }), ]; - const res = await graph.invoke({ + const res = await builder.compile().invoke({ messages: [inputMessage], }); expect(res).toEqual({ messages: expectedOutputMessages, }); - const stream = await graph.stream({ + const stream = await builder.compile().stream({ messages: [inputMessage], }); - const chunks = []; + let chunks = []; for await (const chunk of stream) { chunks.push(chunk); } - console.log(chunks); const nodeOrder = ["agent", "tools", "agent", "tools", "tools", "agent"]; expect(nodeOrder.length).toEqual(chunks.length); expect(chunks).toEqual( @@ -2087,278 +2085,231 @@ describe("StateGraph", () => { }) ); - // app_w_interrupt = workflow.compile( - // checkpointer=MemorySaverAssertImmutable(serde=serde), - // interrupt_after=["agent"], - // ) - // config = {"configurable": {"thread_id": "1"}} - - // assert [ - // c - // for c in app_w_interrupt.stream( - // {"messages": HumanMessage(content="what is weather in sf")}, config - // ) - // ] == [ - // { - // "agent": { - // "messages": AIMessage( - // id="ai1", - // content="", - // tool_calls=[ - // { - // "id": "tool_call123", - // "name": "search_api", - // "args": {"query": "query"}, - // }, - // ], - // ) - // } - // }, - // ] - - // assert app_w_interrupt.get_state(config) == StateSnapshot( - // values={ - // "messages": [ - // _AnyIdHumanMessage(content="what is weather in sf"), - // AIMessage( - // id="ai1", - // content="", - // tool_calls=[ - // { - // "id": "tool_call123", - // "name": "search_api", - // "args": {"query": "query"}, - // }, - // ], - // ), - // ] - // }, - // next=("tools",), - // config=(app_w_interrupt.checkpointer.get_tuple(config)).config, - // created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], - // metadata={ - // "source": "loop", - // "step": 1, - // "writes": { - // "agent": { - // "messages": AIMessage( - // id="ai1", - // content="", - // tool_calls=[ - // { - // "id": "tool_call123", - // "name": "search_api", - // "args": {"query": "query"}, - // }, - // ], - // ) - // } - // }, - // }, - // parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, - // ) - - // # modify ai message - // last_message = (app_w_interrupt.get_state(config)).values["messages"][-1] - // last_message.tool_calls[0]["args"]["query"] = "a different query" - // app_w_interrupt.update_state( - // config, {"messages": last_message, "something_extra": "hi there"} - // ) - - // # message was replaced instead of appended - // assert app_w_interrupt.get_state(config) == StateSnapshot( - // values={ - // "messages": [ - // _AnyIdHumanMessage(content="what is weather in sf"), - // AIMessage( - // id="ai1", - // content="", - // tool_calls=[ - // { - // "id": "tool_call123", - // "name": "search_api", - // "args": {"query": "a different query"}, - // }, - // ], - // ), - // ] - // }, - // next=("tools",), - // config=app_w_interrupt.checkpointer.get_tuple(config).config, - // created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], - // metadata={ - // "source": "update", - // "step": 2, - // "writes": { - // "agent": { - // "messages": AIMessage( - // id="ai1", - // content="", - // tool_calls=[ - // { - // "id": "tool_call123", - // "name": "search_api", - // "args": {"query": "a different query"}, - // }, - // ], - // ), - // "something_extra": "hi there", - // } - // }, - // }, - // parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, - // ) - - // assert [c for c in app_w_interrupt.stream(None, config)] == [ - // { - // "tools": { - // "messages": ToolMessage( - // content="result for a different query", - // name="search_api", - // id=AnyStr(), - // tool_call_id="tool_call123", - // ) - // } - // }, - // { - // "agent": { - // "messages": AIMessage( - // id="ai2", - // content="", - // tool_calls=[ - // { - // "id": "tool_call234", - // "name": "search_api", - // "args": {"query": "another", "idx": 0}, - // }, - // { - // "id": "tool_call567", - // "name": "search_api", - // "args": {"query": "a third one", "idx": 1}, - // }, - // ], - // ) - // }, - // }, - // ] - - // assert app_w_interrupt.get_state(config) == StateSnapshot( - // values={ - // "messages": [ - // _AnyIdHumanMessage(content="what is weather in sf"), - // AIMessage( - // id="ai1", - // content="", - // tool_calls=[ - // { - // "id": "tool_call123", - // "name": "search_api", - // "args": {"query": "a different query"}, - // }, - // ], - // ), - // ToolMessage( - // content="result for a different query", - // name="search_api", - // id=AnyStr(), - // tool_call_id="tool_call123", - // ), - // AIMessage( - // id="ai2", - // content="", - // tool_calls=[ - // { - // "id": "tool_call234", - // "name": "search_api", - // "args": {"query": "another", "idx": 0}, - // }, - // { - // "id": "tool_call567", - // "name": "search_api", - // "args": {"query": "a third one", "idx": 1}, - // }, - // ], - // ), - // ] - // }, - // next=("tools", "tools"), - // config=app_w_interrupt.checkpointer.get_tuple(config).config, - // created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], - // metadata={ - // "source": "loop", - // "step": 4, - // "writes": { - // "agent": { - // "messages": AIMessage( - // id="ai2", - // content="", - // tool_calls=[ - // { - // "id": "tool_call234", - // "name": "search_api", - // "args": {"query": "another", "idx": 0}, - // }, - // { - // "id": "tool_call567", - // "name": "search_api", - // "args": {"query": "a third one", "idx": 1}, - // }, - // ], - // ) - // }, - // }, - // }, - // parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, - // ) - - // app_w_interrupt.update_state( - // config, - // { - // "messages": AIMessage(content="answer", id="ai2"), - // "something_extra": "hi there", - // }, - // ) - - // # replaces message even if object identity is different, as long as id is the same - // assert app_w_interrupt.get_state(config) == StateSnapshot( - // values={ - // "messages": [ - // _AnyIdHumanMessage(content="what is weather in sf"), - // AIMessage( - // id="ai1", - // content="", - // tool_calls=[ - // { - // "id": "tool_call123", - // "name": "search_api", - // "args": {"query": "a different query"}, - // }, - // ], - // ), - // ToolMessage( - // content="result for a different query", - // name="search_api", - // id=AnyStr(), - // tool_call_id="tool_call123", - // ), - // AIMessage(content="answer", id="ai2"), - // ] - // }, - // next=(), - // config=app_w_interrupt.checkpointer.get_tuple(config).config, - // created_at=(app_w_interrupt.checkpointer.get_tuple(config)).checkpoint["ts"], - // metadata={ - // "source": "update", - // "step": 5, - // "writes": { - // "agent": { - // "messages": AIMessage(content="answer", id="ai2"), - // "something_extra": "hi there", - // } - // }, - // }, - // parent_config=[*app_w_interrupt.checkpointer.list(config, limit=2)][-1].config, - // ) + const appWithInterrupt = builder.compile({ + checkpointer: new MemorySaverAssertImmutable(), + interruptAfter: ["agent"], + }); + const config = { configurable: { thread_id: "1" } }; + chunks = []; + for await (const chunk of await appWithInterrupt.stream( + { + messages: [inputMessage], + }, + config + )) { + chunks.push(chunk); + } + expect(chunks).toEqual([ + { + agent: { + messages: expectedOutputMessages[1], + }, + }, + ]); + const appWithInterruptState = await appWithInterrupt.getState(config); + expect(appWithInterruptState).toEqual({ + values: { + messages: expectedOutputMessages.slice(0, 2), + }, + // TODO: Populate, see Python test + next: [], + metadata: { + source: "loop", + step: 1, + writes: { + agent: { + messages: expectedOutputMessages[1], + }, + }, + }, + config: (await appWithInterrupt.checkpointer?.getTuple(config))?.config, + createdAt: (await appWithInterrupt.checkpointer?.getTuple(config)) + ?.checkpoint.ts, + // TODO: Populate, see Python test + parentConfig: undefined, + }); + + // modify ai message + const lastMessage = + appWithInterruptState!.values.messages[ + appWithInterruptState!.values.messages.length - 1 + ]; + lastMessage.tool_calls[0].args.query = "a different query"; + await appWithInterrupt.updateState(config, { + messages: lastMessage, + something_extra: "hi there", + }); + expect(await appWithInterrupt.getState(config)).toEqual({ + values: { + messages: [ + expectedOutputMessages[0], + new AIMessage({ + id: "ai1", + content: "", + tool_calls: [ + { + id: "tool_call123", + name: "search_api", + args: { query: "a different query" }, + type: "tool_call", + }, + ], + }), + ], + }, + // TODO: Populate, see Python test + next: [], + metadata: { + source: "update", + step: 2, + writes: { + agent: { + messages: new AIMessage({ + id: "ai1", + content: "", + tool_calls: [ + { + id: "tool_call123", + name: "search_api", + args: { query: "a different query" }, + type: "tool_call", + }, + ], + }), + something_extra: "hi there", + }, + }, + }, + config: (await appWithInterrupt.checkpointer?.getTuple(config))?.config, + createdAt: (await appWithInterrupt.checkpointer?.getTuple(config)) + ?.checkpoint.ts, + // TODO: Populate, see Python test + parentConfig: undefined, + }); + + chunks = []; + for await (const chunk of await appWithInterrupt.stream(null, config)) { + chunks.push(chunk); + } + expect(chunks).toEqual([ + { + tools: { + messages: new ToolMessage({ + id: "abc", + content: "result for a different query", + name: "search_api", + tool_call_id: "tool_call123", + }), + }, + }, + { + agent: { + messages: expectedOutputMessages[3], + }, + }, + ]); + + expect(await appWithInterrupt.getState(config)).toEqual({ + values: { + messages: [ + expectedOutputMessages[0], + new AIMessage({ + id: "ai1", + content: "", + tool_calls: [ + { + id: "tool_call123", + name: "search_api", + args: { query: "a different query" }, + type: "tool_call", + }, + ], + }), + new ToolMessage({ + id: "abc", + content: "result for a different query", + name: "search_api", + tool_call_id: "tool_call123", + }), + expectedOutputMessages[3], + ], + }, + // TODO: Populate, see Python test + next: [], + metadata: { + source: "loop", + step: 4, + writes: { + agent: { + messages: expectedOutputMessages[3], + }, + }, + }, + createdAt: (await appWithInterrupt.checkpointer?.getTuple(config)) + ?.checkpoint.ts, + config: (await appWithInterrupt.checkpointer?.getTuple(config))?.config, + // TODO: Populate, see Python test + parentConfig: undefined, + }); + + // replaces message even if object identity is different, as long as id is the same + await appWithInterrupt.updateState(config, { + messages: new AIMessage({ + id: "ai2", + content: "answer", + }), + something_extra: "hi there", + }); + + expect(await appWithInterrupt.getState(config)).toEqual({ + values: { + messages: [ + expectedOutputMessages[0], + new AIMessage({ + id: "ai1", + content: "", + tool_calls: [ + { + id: "tool_call123", + name: "search_api", + args: { query: "a different query" }, + type: "tool_call", + }, + ], + }), + new ToolMessage({ + id: "abc", + content: "result for a different query", + name: "search_api", + tool_call_id: "tool_call123", + }), + new AIMessage({ + content: "answer", + id: "ai2", + }), + ], + }, + // TODO: Populate, see Python test + next: [], + metadata: { + source: "update", + step: 5, + writes: { + agent: { + messages: new AIMessage({ + content: "answer", + id: "ai2", + }), + something_extra: "hi there", + }, + }, + }, + createdAt: (await appWithInterrupt.checkpointer?.getTuple(config)) + ?.checkpoint.ts, + config: (await appWithInterrupt.checkpointer?.getTuple(config))?.config, + // TODO: Populate, see Python test + parentConfig: undefined, + }); }); });