core[patch]: Fix double streaming issue when streamEvents is called directly on chat models/LLMs (#6155)

* Fix double streaming issue when streamEvents is called directly on chat models/LLMs

* Fix lint, add docs

* Fix format
jacoblee93 authored Jul 19, 2024
1 parent 0ff34f9 commit c116ee1
Showing 4 changed files with 120 additions and 53 deletions.
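For context, here is a minimal sketch of the scenario this commit fixes: calling `streamEvents` directly on a chat model, with no surrounding chain. It assumes `FakeListChatModel` is exported from `@langchain/core/utils/testing`, as used in the updated test below. Before this patch each token surfaced twice (once from the model's token callback and once from the tapped output stream); after it, each chunk appears exactly once:

```ts
import { FakeListChatModel } from "@langchain/core/utils/testing";

// A fake chat model that streams back "abc" one character at a time.
const model = new FakeListChatModel({ responses: ["abc"] });

// Call streamEvents directly on the model, with no wrapping chain.
const eventStream = await model.streamEvents("hello", { version: "v2" });

for await (const event of eventStream) {
  if (event.event === "on_chat_model_stream") {
    // With this fix, "a", "b", and "c" are each logged exactly once.
    console.log(event.data.chunk.content);
  }
}
```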
2 changes: 2 additions & 0 deletions docs/core_docs/docs/how_to/streaming.ipynb
@@ -663,6 +663,8 @@
"| on_prompt_start | [template_name] | | {\"question\": \"hello\"} | |\n",
"| on_prompt_end | [template_name] | | {\"question\": \"hello\"} | ChatPromptValue(messages: [SystemMessage, ...]) |\n",
"\n",
"`streamEvents` will also emit dispatched custom events in `v2`. Please see [this guide](/docs/how_to/callbacks_custom_events/) for more.\n",
"\n",
"### Chat Model\n",
"\n",
"Let's start off by looking at the events produced by a chat model."
52 changes: 49 additions & 3 deletions langchain-core/src/runnables/base.ts
@@ -762,11 +762,11 @@ export abstract class Runnable<
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_llm_end | [model name] | | 'Hello human!' | |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_chain_start | format_docs | | | |
* | on_chain_start | some_runnable | | | |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_chain_stream | format_docs | "hello world!, goodbye world!" | | |
* | on_chain_stream | some_runnable | "hello world!, goodbye world!" | | |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_chain_end | format_docs | | [Document(...)] | "hello world!, goodbye world!" |
* | on_chain_end | some_runnable | | [Document(...)] | "hello world!, goodbye world!" |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_tool_start | some_tool | | {"x": 1, "y": "2"} | |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
@@ -780,6 +780,52 @@
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
* | on_prompt_end | [template_name] | | {"question": "hello"} | ChatPromptValue(messages: [SystemMessage, ...]) |
* +----------------------+------------------+---------------------------------+-----------------------------------------------+-------------------------------------------------+
*
* The "on_chain_*" events are the default for Runnables that don't fit one of the above categories.
*
* In addition to the standard events above, users can also dispatch custom events.
*
 * Custom events will only be surfaced in the `v2` version of the API!
 *
 * A custom event has the following format:
*
* +-----------+------+-----------------------------------------------------------------------------------------------------------+
* | Attribute | Type | Description |
* +===========+======+===========================================================================================================+
* | name | str | A user defined name for the event. |
* +-----------+------+-----------------------------------------------------------------------------------------------------------+
* | data | Any | The data associated with the event. This can be anything, though we suggest making it JSON serializable. |
* +-----------+------+-----------------------------------------------------------------------------------------------------------+
*
* Here's an example:
* @example
* ```ts
* import { RunnableLambda } from "@langchain/core/runnables";
* import { dispatchCustomEvent } from "@langchain/core/callbacks/dispatch";
* // Use this import for web environments that don't support "async_hooks"
* // and manually pass config to child runs.
* // import { dispatchCustomEvent } from "@langchain/core/callbacks/dispatch/web";
*
* const slowThing = RunnableLambda.from(async (someInput: string) => {
* // Placeholder for some slow operation
* await new Promise((resolve) => setTimeout(resolve, 100));
* await dispatchCustomEvent("progress_event", {
* message: "Finished step 1 of 2",
* });
* await new Promise((resolve) => setTimeout(resolve, 100));
* return "Done";
* });
*
* const eventStream = await slowThing.streamEvents("hello world", {
* version: "v2",
* });
*
* for await (const event of eventStream) {
* if (event.event === "on_custom_event") {
* console.log(event);
* }
* }
* ```
*/
streamEvents(
input: RunInput,
102 changes: 54 additions & 48 deletions langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts
@@ -80,6 +80,60 @@ test("Runnable streamEvents method", async () => {
]);
});

test("Runnable streamEvents method on a chat model", async () => {
const model = new FakeListChatModel({
responses: ["abc"],
});

const events = [];
const eventStream = await model.streamEvents("hello", { version: "v2" });
for await (const event of eventStream) {
events.push(event);
}
expect(events).toMatchObject([
{
data: { input: "hello" },
event: "on_chat_model_start",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ content: "a" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ content: "b" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { chunk: new AIMessageChunk({ content: "c" }) },
event: "on_chat_model_stream",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
{
data: { output: new AIMessageChunk({ content: "abc" }) },
event: "on_chat_model_end",
name: "FakeListChatModel",
metadata: expect.any(Object),
run_id: expect.any(String),
tags: [],
},
]);
});

test("Runnable streamEvents method with three runnables", async () => {
const r = RunnableLambda.from(reverse);

@@ -599,18 +653,6 @@ test("Runnable streamEvents method with llm", async () => {
a: "b",
},
},
{
event: "on_llm_stream",
run_id: expect.any(String),
name: "my_model",
tags: ["my_model"],
metadata: {
a: "b",
},
data: {
chunk: "h",
},
},
{
event: "on_llm_stream",
data: {
@@ -625,18 +667,6 @@
a: "b",
},
},
{
event: "on_llm_stream",
run_id: expect.any(String),
name: "my_model",
tags: ["my_model"],
metadata: {
a: "b",
},
data: {
chunk: "e",
},
},
{
event: "on_llm_stream",
data: {
@@ -651,18 +681,6 @@
a: "b",
},
},
{
event: "on_llm_stream",
run_id: expect.any(String),
name: "my_model",
tags: ["my_model"],
metadata: {
a: "b",
},
data: {
chunk: "y",
},
},
{
event: "on_llm_stream",
data: {
@@ -677,18 +695,6 @@
a: "b",
},
},
{
event: "on_llm_stream",
run_id: expect.any(String),
name: "my_model",
tags: ["my_model"],
metadata: {
a: "b",
},
data: {
chunk: "!",
},
},
{
event: "on_llm_end",
data: {
17 changes: 15 additions & 2 deletions langchain-core/src/tracers/event_stream.ts
@@ -244,6 +244,13 @@ export class EventStreamCallbackHandler extends BaseTracer {
yield firstChunk.value;
return;
}
// Match format from handlers below
function _formatOutputChunk(eventType: string, data: unknown) {
if (eventType === "llm" && typeof data === "string") {
return new GenerationChunk({ text: data });
}
return data;
}
let tappedPromise = this.tappedPromises.get(runId);
// if we are the first to tap, issue stream events
if (tappedPromise === undefined) {
@@ -264,7 +271,9 @@
await this.send(
{
...event,
data: { chunk: firstChunk.value },
data: {
chunk: _formatOutputChunk(runInfo.runType, firstChunk.value),
},
},
runInfo
);
@@ -276,7 +285,7 @@
{
...event,
data: {
chunk,
chunk: _formatOutputChunk(runInfo.runType, chunk),
},
},
runInfo
@@ -354,6 +363,10 @@
if (runInfo === undefined) {
throw new Error(`onLLMNewToken: Run ID ${run.id} not found in run map.`);
}
// Top-level streaming events are covered by tapOutputIterable
if (run.parent_run_id === undefined) {
return;
}
if (runInfo.runType === "chat_model") {
eventName = "on_chat_model_stream";
if (kwargs?.chunk === undefined) {
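The early return added above only applies to top-level runs, whose chunks are already emitted by `tapOutputIterable`; a model running as a child of another runnable still gets its stream events from `onLLMNewToken`. A hedged sketch of that second path (the `ChatPromptTemplate` and `FakeListChatModel` usage here is illustrative and not part of this diff):

```ts
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { FakeListChatModel } from "@langchain/core/utils/testing";

const prompt = ChatPromptTemplate.fromTemplate("Say {input}");
const model = new FakeListChatModel({ responses: ["abc"] });

// The model is now a child run of the sequence, so it has a parent_run_id
// and the guard above does not apply: its on_chat_model_stream events are
// emitted by the token callback path instead of the tapped output stream.
const chain = prompt.pipe(model);

const eventStream = await chain.streamEvents(
  { input: "hello" },
  { version: "v2" }
);

for await (const event of eventStream) {
  if (event.event === "on_chat_model_stream") {
    console.log(event.data.chunk.content); // logs "a", "b", "c" once each
  }
}
```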
