From f0657e540dc277d2cb0acddb438feb8cc7010e95 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 17 May 2024 03:07:18 +0200 Subject: [PATCH 1/3] fix(handoff): support streaming returned generators --- js/src/langchain.ts | 44 +++++++++++- js/src/run_trees.ts | 11 +++ js/src/tests/traceable_langchain.test.ts | 89 +++++++++++++++++++++++- js/src/traceable.ts | 64 +++-------------- js/src/utils/asserts.ts | 51 ++++++++++++++ 5 files changed, 201 insertions(+), 58 deletions(-) create mode 100644 js/src/utils/asserts.ts diff --git a/js/src/langchain.ts b/js/src/langchain.ts index 505e05424..fc51d561f 100644 --- a/js/src/langchain.ts +++ b/js/src/langchain.ts @@ -1,6 +1,11 @@ import { CallbackManager } from "@langchain/core/callbacks/manager"; import { LangChainTracer } from "@langchain/core/tracers/tracer_langchain"; -import { Runnable, RunnableConfig } from "@langchain/core/runnables"; +import { + Runnable, + RunnableConfig, + patchConfig, + getCallbackManagerForConfig, +} from "@langchain/core/runnables"; import { RunTree } from "./run_trees.js"; import { Run } from "./schemas.js"; @@ -9,6 +14,7 @@ import { getCurrentRunTree, isTraceableFunction, } from "./traceable.js"; +import { isAsyncIterable, isIteratorLike } from "./utils/asserts.js"; /** * Converts the current run tree active within a traceable-wrapped function @@ -113,7 +119,41 @@ export class RunnableTraceable extends Runnable< async invoke(input: RunInput, options?: Partial) { const [config] = this._getOptionsList(options ?? {}, 1); - return (await this.func(config, input)) as RunOutput; + + const callbacks = await getCallbackManagerForConfig(config); + + return (await this.func( + patchConfig(config, { callbacks }), + input + )) as RunOutput; + } + + async *_streamIterator( + input: RunInput, + options?: Partial + ): AsyncGenerator { + const result = await this.invoke(input, options); + + if (isAsyncIterable(result)) { + const iterator = result[Symbol.asyncIterator](); + while (true) { + const { done, value } = await iterator.next(); + if (done) break; + yield value as RunOutput; + } + return; + } + + if (isIteratorLike(result)) { + while (true) { + const state: IteratorResult = result.next(); + if (state.done) break; + yield state.value as RunOutput; + } + return; + } + + yield result; } static from(func: AnyTraceableFunction) { diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index 91c812928..639925b3c 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -188,6 +188,17 @@ export class RunTree implements BaseRun { tracingEnabled = tracingEnabled || !!langChainTracer; } + if (!parentRun) { + return new RunTree({ + name: props.name, + client, + tracingEnabled: isTracingEnabled(), + project_name: projectName, + tags: props.tags, + metadata: props.metadata, + }); + } + const parentRunTree = new RunTree({ name: parentRun?.name ?? "", id: parentRun?.id, diff --git a/js/src/tests/traceable_langchain.test.ts b/js/src/tests/traceable_langchain.test.ts index b3056ac53..779d4d7bf 100644 --- a/js/src/tests/traceable_langchain.test.ts +++ b/js/src/tests/traceable_langchain.test.ts @@ -63,7 +63,7 @@ describe("to langchain", () => { }); }); - test("to langchain stream", async () => { + test("stream", async () => { const { client, callSpy } = mockClient(); const main = traceable( @@ -100,7 +100,7 @@ describe("to langchain", () => { }); }); - test("to langchain batch", async () => { + test("batch", async () => { const { client, callSpy } = mockClient(); const main = traceable( @@ -191,6 +191,91 @@ describe("to traceable", () => { ], }); }); + + test("array stream", async () => { + const { client, callSpy } = mockClient(); + + const source = RunnableTraceable.from( + traceable(function (input: { text: string }) { + return input.text.split(" "); + }) + ); + + const tokens: unknown[] = []; + for await (const chunk of await source.stream( + { text: "Hello world" }, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore client might be of different type + { callbacks: [new LangChainTracer({ client })] } + )) { + tokens.push(chunk); + } + + expect(tokens).toEqual([["Hello", "world"]]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [":0"], + edges: [], + }); + }); + + test("generator stream", async () => { + const { client, callSpy } = mockClient(); + + const source = RunnableTraceable.from( + traceable(function* (input: { text: string }) { + const chunks = input.text.split(" "); + for (const chunk of chunks) { + yield chunk; + } + }) + ); + + const tokens: unknown[] = []; + for await (const chunk of await source.stream( + { text: "Hello world" }, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore client might be of different type + { callbacks: [new LangChainTracer({ client })] } + )) { + tokens.push(chunk); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [":0"], + edges: [], + }); + }); + + test("async generator stream", async () => { + const { client, callSpy } = mockClient(); + const source = RunnableTraceable.from( + traceable(async function* (input: { text: string }) { + const chunks = input.text.split(" "); + for (const chunk of chunks) { + yield chunk; + } + }) + ); + + const tokens: unknown[] = []; + for await (const chunk of await source.stream( + { text: "Hello world" }, + { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore client might be of different type + callbacks: [new LangChainTracer({ client })], + } + )) { + tokens.push(chunk); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [":0"], + edges: [], + }); + }); }); test("explicit nested", async () => { diff --git a/js/src/traceable.ts b/js/src/traceable.ts index 82f86b616..173f1cba3 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -14,65 +14,20 @@ import { AsyncLocalStorageProviderSingleton, } from "./singletons/traceable.js"; import { TraceableFunction } from "./singletons/types.js"; +import { + isKVMap, + isReadableStream, + isAsyncIterable, + isIteratorLike, + isThenable, + isGenerator, + isPromiseMethod, +} from "./utils/asserts.js"; AsyncLocalStorageProviderSingleton.initializeGlobalInstance( new AsyncLocalStorage() ); -function isPromiseMethod( - x: string | symbol -): x is "then" | "catch" | "finally" { - if (x === "then" || x === "catch" || x === "finally") { - return true; - } - return false; -} - -function isKVMap(x: unknown): x is Record { - if (typeof x !== "object" || x == null) { - return false; - } - - const prototype = Object.getPrototypeOf(x); - return ( - (prototype === null || - prototype === Object.prototype || - Object.getPrototypeOf(prototype) === null) && - !(Symbol.toStringTag in x) && - !(Symbol.iterator in x) - ); -} - -const isAsyncIterable = (x: unknown): x is AsyncIterable => - x != null && - typeof x === "object" && - // eslint-disable-next-line @typescript-eslint/no-explicit-any - typeof (x as any)[Symbol.asyncIterator] === "function"; - -const isIteratorLike = (x: unknown): x is Iterator => - x != null && - typeof x === "object" && - "next" in x && - typeof x.next === "function"; - -const GeneratorFunction = function* () {}.constructor; - -const isGenerator = (x: unknown): x is Generator => - // eslint-disable-next-line no-instanceof/no-instanceof - x != null && typeof x === "function" && x instanceof GeneratorFunction; - -const isThenable = (x: unknown): x is Promise => - x != null && - typeof x === "object" && - "then" in x && - typeof x.then === "function"; - -const isReadableStream = (x: unknown): x is ReadableStream => - x != null && - typeof x === "object" && - "getReader" in x && - typeof x.getReader === "function"; - const handleRunInputs = (rawInputs: unknown[]): KVMap => { const firstInput = rawInputs[0]; @@ -83,6 +38,7 @@ const handleRunInputs = (rawInputs: unknown[]): KVMap => { if (rawInputs.length > 1) { return { args: rawInputs }; } + if (isKVMap(firstInput)) { return firstInput; } diff --git a/js/src/utils/asserts.ts b/js/src/utils/asserts.ts new file mode 100644 index 000000000..55bc260db --- /dev/null +++ b/js/src/utils/asserts.ts @@ -0,0 +1,51 @@ +export function isPromiseMethod( + x: string | symbol +): x is "then" | "catch" | "finally" { + if (x === "then" || x === "catch" || x === "finally") { + return true; + } + return false; +} + +export function isKVMap(x: unknown): x is Record { + if (typeof x !== "object" || x == null) { + return false; + } + + const prototype = Object.getPrototypeOf(x); + return ( + (prototype === null || + prototype === Object.prototype || + Object.getPrototypeOf(prototype) === null) && + !(Symbol.toStringTag in x) && + !(Symbol.iterator in x) + ); +} +export const isAsyncIterable = (x: unknown): x is AsyncIterable => + x != null && + typeof x === "object" && + // eslint-disable-next-line @typescript-eslint/no-explicit-any + typeof (x as any)[Symbol.asyncIterator] === "function"; + +export const isIteratorLike = (x: unknown): x is Iterator => + x != null && + typeof x === "object" && + "next" in x && + typeof x.next === "function"; + +const GeneratorFunction = function* () {}.constructor; +export const isGenerator = (x: unknown): x is Generator => + // eslint-disable-next-line no-instanceof/no-instanceof + x != null && typeof x === "function" && x instanceof GeneratorFunction; + +export const isThenable = (x: unknown): x is Promise => + x != null && + typeof x === "object" && + "then" in x && + typeof x.then === "function"; + +export const isReadableStream = (x: unknown): x is ReadableStream => + x != null && + typeof x === "object" && + "getReader" in x && + typeof x.getReader === "function"; From 65a80d6bfe1094fd1765b0eb433522857069d16d Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Fri, 17 May 2024 04:01:19 +0200 Subject: [PATCH 2/3] Fix tests --- js/src/langchain.ts | 1 - js/src/run_trees.ts | 10 +++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/js/src/langchain.ts b/js/src/langchain.ts index fc51d561f..9aa0825dc 100644 --- a/js/src/langchain.ts +++ b/js/src/langchain.ts @@ -119,7 +119,6 @@ export class RunnableTraceable extends Runnable< async invoke(input: RunInput, options?: Partial) { const [config] = this._getOptionsList(options ?? {}, 1); - const callbacks = await getCallbackManagerForConfig(config); return (await this.func( diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index 639925b3c..4ecf82173 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -190,18 +190,18 @@ export class RunTree implements BaseRun { if (!parentRun) { return new RunTree({ - name: props.name, client, - tracingEnabled: isTracingEnabled(), + tracingEnabled, project_name: projectName, + name: props.name, tags: props.tags, metadata: props.metadata, }); } const parentRunTree = new RunTree({ - name: parentRun?.name ?? "", - id: parentRun?.id, + name: parentRun.name, + id: parentRun.id, client, tracingEnabled, project_name: projectName, @@ -217,7 +217,7 @@ export class RunTree implements BaseRun { }); return parentRunTree.createChild({ - name: props?.name ?? "", + name: props.name, tags: props.tags, metadata: props.metadata, }); From ed3463438b0e2d6fb1bc65257d035938620bf47d Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Tue, 21 May 2024 18:04:22 +0200 Subject: [PATCH 3/3] Use for await, make sure readable stream works as well --- js/src/langchain.ts | 7 ++--- js/src/tests/traceable_langchain.test.ts | 35 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/js/src/langchain.ts b/js/src/langchain.ts index 9aa0825dc..a3a4de845 100644 --- a/js/src/langchain.ts +++ b/js/src/langchain.ts @@ -134,11 +134,8 @@ export class RunnableTraceable extends Runnable< const result = await this.invoke(input, options); if (isAsyncIterable(result)) { - const iterator = result[Symbol.asyncIterator](); - while (true) { - const { done, value } = await iterator.next(); - if (done) break; - yield value as RunOutput; + for await (const item of result) { + yield item as RunOutput; } return; } diff --git a/js/src/tests/traceable_langchain.test.ts b/js/src/tests/traceable_langchain.test.ts index 779d4d7bf..c5d03e027 100644 --- a/js/src/tests/traceable_langchain.test.ts +++ b/js/src/tests/traceable_langchain.test.ts @@ -247,6 +247,41 @@ describe("to traceable", () => { }); }); + test("readable stream", async () => { + const { client, callSpy } = mockClient(); + + const source = RunnableTraceable.from( + traceable(async function (input: { text: string }) { + const readStream = new ReadableStream({ + async pull(controller) { + for (const item of input.text.split(" ")) { + controller.enqueue(item); + } + controller.close(); + }, + }); + + return readStream; + }) + ); + + const tokens: unknown[] = []; + for await (const chunk of await source.stream( + { text: "Hello world" }, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore client might be of different type + { callbacks: [new LangChainTracer({ client })] } + )) { + tokens.push(chunk); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [":0"], + edges: [], + }); + }); + test("async generator stream", async () => { const { client, callSpy } = mockClient(); const source = RunnableTraceable.from(