From 15b73dae9e176aa52041f2a5a33845e2a61dcc40 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Mon, 6 May 2024 15:16:26 +0200 Subject: [PATCH 1/6] fix(js): add support for tracing generators --- js/src/tests/traceable.test.ts | 81 ++++++++++++++++++++++++++++++++++ js/src/traceable.ts | 66 +++++++++++++++++++++++---- 2 files changed, 139 insertions(+), 8 deletions(-) diff --git a/js/src/tests/traceable.test.ts b/js/src/tests/traceable.test.ts index 8e8a72f82..cc5a8a38a 100644 --- a/js/src/tests/traceable.test.ts +++ b/js/src/tests/traceable.test.ts @@ -441,3 +441,84 @@ describe("langchain", () => { }); }); }); + +describe("generator", () => { + function gatherAll(iterator: Iterator) { + const chunks: unknown[] = []; + // eslint-disable-next-line no-constant-condition + while (true) { + const next = iterator.next(); + chunks.push(next.value); + if (next.done) break; + } + + return chunks; + } + + test("yield", async () => { + const { client, callSpy } = mockClient(); + + function* generator() { + for (let i = 0; i < 3; ++i) yield i; + } + + const traced = traceable(generator, { client, tracingEnabled: true }); + + expect(gatherAll(await traced())).toEqual(gatherAll(generator())); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["generator:0"], + edges: [], + data: { + "generator:0": { + outputs: { outputs: [0, 1, 2] }, + }, + }, + }); + }); + + test("with return", async () => { + const { client, callSpy } = mockClient(); + + function* generator() { + for (let i = 0; i < 3; ++i) yield i; + return 3; + } + + const traced = traceable(generator, { client, tracingEnabled: true }); + + expect(gatherAll(await traced())).toEqual(gatherAll(generator())); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["generator:0"], + edges: [], + data: { "generator:0": { outputs: { outputs: [0, 1, 2, 3] } } }, + }); + }); + + test("nested", async () => { + const { client, callSpy } = mockClient(); + + function* generator() { + function* child() { + for (let i = 0; i < 3; ++i) yield i; + } + + for (let i = 0; i < 2; ++i) { + for (const num of child()) yield num; + } + + return 3; + } + + const traced = traceable(generator, { client, tracingEnabled: true }); + expect(gatherAll(await traced())).toEqual(gatherAll(generator())); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["generator:0"], + edges: [], + data: { + "generator:0": { + outputs: { outputs: [0, 1, 2, 0, 1, 2, 3] }, + }, + }, + }); + }); +}); diff --git a/js/src/traceable.ts b/js/src/traceable.ts index 32acadd04..30f85f4fb 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -104,6 +104,18 @@ const isAsyncIterable = (x: unknown): x is AsyncIterable => // eslint-disable-next-line @typescript-eslint/no-explicit-any typeof (x as any)[Symbol.asyncIterator] === "function"; +const GeneratorFunction = function* () {}.constructor; + +const isIteratorLike = (x: unknown): x is Iterator => + x != null && + typeof x === "object" && + "next" in x && + typeof x.next === "function"; + +const isGenerator = (x: unknown): x is Generator => + // eslint-disable-next-line no-instanceof/no-instanceof + x != null && typeof x === "function" && x instanceof GeneratorFunction; + const tracingIsEnabled = (tracingEnabled?: boolean): boolean => { if (tracingEnabled !== undefined) { return tracingEnabled; @@ -350,6 +362,18 @@ export function traceable any>( await currentRunTree?.patchRun(); } + function gatherAll(iterator: Iterator) { + const chunks: IteratorResult[] = []; + // eslint-disable-next-line no-constant-condition + while (true) { + const next = iterator.next(); + chunks.push(next); + if (next.done) break; + } + + return chunks; + } + let returnValue: unknown; try { returnValue = wrappedFunc(...rawInputs); @@ -371,14 +395,40 @@ export function traceable any>( return resolve( wrapAsyncGeneratorForTracing(rawOutput, snapshot) ); - } else { - try { - await currentRunTree?.end(handleRunOutputs(rawOutput)); - await handleEnd(); - } finally { - // eslint-disable-next-line no-unsafe-finally - return rawOutput; - } + } + + if (isGenerator(wrappedFunc) && isIteratorLike(rawOutput)) { + const chunks = gatherAll(rawOutput); + + await currentRunTree?.end( + handleRunOutputs( + await handleChunks( + chunks.reduce((memo, { value, done }) => { + if (!done || typeof value !== "undefined") { + memo.push(value); + } + + return memo; + }, []) + ) + ) + ); + await handleEnd(); + + return (function* () { + for (const ret of chunks) { + if (ret.done) return ret.value; + yield ret.value; + } + })(); + } + + try { + await currentRunTree?.end(handleRunOutputs(rawOutput)); + await handleEnd(); + } finally { + // eslint-disable-next-line no-unsafe-finally + return rawOutput; } }, async (error: unknown) => { From c45e7c6eadf993228fa3159f14b1d06818d14e27 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Mon, 6 May 2024 23:34:54 +0200 Subject: [PATCH 2/6] Send inputs even when patching run --- js/src/run_trees.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index f89ef16e0..267b22c5e 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -326,6 +326,7 @@ export class RunTree implements BaseRun { const runUpdate: RunUpdate = { end_time: this.end_time, error: this.error, + inputs: this.inputs, outputs: this.outputs, parent_run_id: this.parent_run?.id, reference_example_id: this.reference_example_id, From cff6cd09d11b05bbd4a2107a5144c6d4cb7bc71c Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Mon, 6 May 2024 23:35:26 +0200 Subject: [PATCH 3/6] Collect deferred values during invocation via proxy --- js/src/tests/traceable.test.ts | 108 +++++++++++++++++ js/src/traceable.ts | 204 +++++++++++++++++++++++++++++++-- 2 files changed, 302 insertions(+), 10 deletions(-) diff --git a/js/src/tests/traceable.test.ts b/js/src/tests/traceable.test.ts index cc5a8a38a..46cefbd4b 100644 --- a/js/src/tests/traceable.test.ts +++ b/js/src/tests/traceable.test.ts @@ -407,6 +407,114 @@ describe("async generators", () => { }); }); +describe("deferred input", () => { + test("generator", async () => { + const { client, callSpy } = mockClient(); + const parrotStream = traceable( + async function* parrotStream(input: Generator) { + for (const token of input) { + yield token; + } + }, + { client, tracingEnabled: true } + ); + + const inputGenerator = function* () { + for (const token of "Hello world".split(" ")) { + yield token; + } + }; + + const tokens: string[] = []; + for await (const token of parrotStream(inputGenerator())) { + tokens.push(token); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["parrotStream:0"], + edges: [], + data: { + "parrotStream:0": { + inputs: { input: ["Hello", "world"] }, + outputs: { outputs: ["Hello", "world"] }, + }, + }, + }); + }); + + test("async generator", async () => { + const { client, callSpy } = mockClient(); + const inputStream = async function* inputStream() { + for (const token of "Hello world".split(" ")) { + yield token; + } + }; + + const parrotStream = traceable( + async function* parrotStream(input: AsyncGenerator) { + for await (const token of input) { + yield token; + } + }, + { client, tracingEnabled: true } + ); + + const tokens: string[] = []; + for await (const token of parrotStream(inputStream())) { + tokens.push(token); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["parrotStream:0"], + edges: [], + data: { + "parrotStream:0": { + inputs: { input: ["Hello", "world"] }, + outputs: { outputs: ["Hello", "world"] }, + }, + }, + }); + }); + + test("promise", async () => { + const { client, callSpy } = mockClient(); + const parrotStream = traceable( + async function* parrotStream(input: Promise) { + // eslint-disable-next-line no-instanceof/no-instanceof + if (!(input instanceof Promise)) { + throw new Error("Input must be a promise"); + } + + for (const token of await input) { + yield token; + } + }, + { client, tracingEnabled: true } + ); + + const tokens: string[] = []; + for await (const token of parrotStream( + Promise.resolve(["Hello", "world"]) + )) { + tokens.push(token); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["parrotStream:0"], + edges: [], + data: { + "parrotStream:0": { + inputs: { input: ["Hello", "world"] }, + outputs: { outputs: ["Hello", "world"] }, + }, + }, + }); + }); +}); + describe("langchain", () => { test.skip("bound", async () => { const { client, callSpy } = mockClient(); diff --git a/js/src/traceable.ts b/js/src/traceable.ts index 30f85f4fb..f2fe53e8d 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -116,6 +116,12 @@ const isGenerator = (x: unknown): x is Generator => // eslint-disable-next-line no-instanceof/no-instanceof x != null && typeof x === "function" && x instanceof GeneratorFunction; +const isThenable = (x: unknown): x is Promise => + x != null && + typeof x === "object" && + "then" in x && + typeof x.then === "function"; + const tracingIsEnabled = (tracingEnabled?: boolean): boolean => { if (tracingEnabled !== undefined) { return tracingEnabled; @@ -133,6 +139,7 @@ const tracingIsEnabled = (tracingEnabled?: boolean): boolean => { const handleRunInputs = (rawInputs: unknown[]): KVMap => { const firstInput = rawInputs[0]; + if (firstInput == null) { return {}; } @@ -167,6 +174,169 @@ const getTracingRunTree = ( return runTree; }; +// idea: store the state of the promise outside +// but only when the promise is "consumed" +const getSerializablePromise = (arg: Promise) => { + const proxyState: { + current: ["resolve", unknown] | ["reject", unknown] | undefined; + } = { current: undefined }; + + const promiseProxy = new Proxy(arg, { + get(target, prop, receiver) { + if (prop === "then") { + const boundThen = arg[prop].bind(arg); + return ( + resolve: (value: unknown) => unknown, + reject?: (error: unknown) => unknown + ) => { + return boundThen( + (value) => { + proxyState.current = ["resolve", value]; + return resolve(value); + }, + (error) => { + proxyState.current = ["reject", error]; + return reject?.(error); + } + ); + }; + } + + if (prop === "catch") { + const boundCatch = arg[prop].bind(arg); + return (reject: (error: unknown) => unknown) => { + return boundCatch((error) => { + proxyState.current = ["reject", error]; + return reject(error); + }); + }; + } + + if (prop === "toJSON") { + return () => { + if (!proxyState.current) return undefined; + const [type, value] = proxyState.current ?? []; + if (type === "resolve") return value; + return { error: value }; + }; + } + + return Reflect.get(target, prop, receiver); + }, + }); + + return promiseProxy as Promise & { toJSON: () => unknown }; +}; + +// attempt to +const convertSerializableArg = (arg: unknown): unknown => { + if (isAsyncIterable(arg)) { + const proxyState: { + current: (Promise> & { + toJSON: () => unknown; + })[]; + } = { current: [] }; + + return new Proxy(arg, { + get(target, prop, receiver) { + if (prop === Symbol.asyncIterator) { + return () => { + const boundIterator = arg[Symbol.asyncIterator].bind(arg); + const iterator = boundIterator(); + + return new Proxy(iterator, { + get(target, prop, receiver) { + if (prop === "next" || prop === "return" || prop === "throw") { + const bound = iterator.next.bind(iterator); + + return ( + ...args: Parameters< + Exclude< + AsyncIterator["next" | "return" | "throw"], + undefined + > + > + ) => { + // @ts-expect-error TS cannot infer the argument types for the bound function + const wrapped = getSerializablePromise(bound(...args)); + proxyState.current.push(wrapped); + return wrapped; + }; + } + + if (prop === "return" || prop === "throw") { + return iterator.next.bind(iterator); + } + + return Reflect.get(target, prop, receiver); + }, + }); + }; + } + + if (prop === "toJSON") { + return () => { + const onlyNexts = proxyState.current; + const serialized = onlyNexts.map( + (next) => next.toJSON() as IteratorResult + ); + + const chunks = serialized.reduce((memo, next) => { + if (next.value) memo.push(next.value); + return memo; + }, []); + + return chunks; + }; + } + + return Reflect.get(target, prop, receiver); + }, + }); + } + + if (!Array.isArray(arg) && isIteratorLike(arg)) { + const proxyState: Array> = []; + + return new Proxy(arg, { + get(target, prop, receiver) { + if (prop === "next" || prop === "return" || prop === "throw") { + const bound = arg[prop]?.bind(arg); + return ( + ...args: Parameters< + Exclude["next" | "return" | "throw"], undefined> + > + ) => { + // @ts-expect-error TS cannot infer the argument types for the bound function + const next = bound?.(...args); + if (next != null) proxyState.push(next); + return next; + }; + } + + if (prop === "toJSON") { + return () => { + const chunks = proxyState.reduce((memo, next) => { + if (next.value) memo.push(next.value); + return memo; + }, []); + + return chunks; + }; + } + + return Reflect.get(target, prop, receiver); + }, + }); + } + + if (isThenable(arg)) { + return getSerializablePromise(arg); + } + + return arg; +}; + /** * Higher-order function that takes function as input and returns a * "TraceableFunction" - a wrapped version of the input that @@ -247,8 +417,14 @@ export function traceable any>( }; } + // TODO: deal with possible nested promises and async iterables + const processedArgs = args as unknown as Inputs; + for (let i = 0; i < processedArgs.length; i++) { + processedArgs[i] = convertSerializableArg(processedArgs[i]); + } + const [currentRunTree, rawInputs] = ((): [RunTree | undefined, Inputs] => { - const [firstArg, ...restArgs] = args; + const [firstArg, ...restArgs] = processedArgs; // used for handoff between LangChain.JS and traceable functions if (isRunnableConfigLike(firstArg)) { @@ -289,16 +465,19 @@ export function traceable any>( const prevRunFromStore = asyncLocalStorage.getStore(); if (prevRunFromStore) { return [ - getTracingRunTree(prevRunFromStore.createChild(ensuredConfig), args), - args as Inputs, + getTracingRunTree( + prevRunFromStore.createChild(ensuredConfig), + processedArgs + ), + processedArgs as Inputs, ]; } const currentRunTree = getTracingRunTree( new RunTree(ensuredConfig), - args + processedArgs ); - return [currentRunTree, args as Inputs]; + return [currentRunTree, processedArgs as Inputs]; })(); return asyncLocalStorage.run(currentRunTree, () => { @@ -490,12 +669,17 @@ export function isTraceableFunction( } function isKVMap(x: unknown): x is Record { + if (typeof x !== "object" || x == null) { + return false; + } + + const prototype = Object.getPrototypeOf(x); return ( - typeof x === "object" && - x != null && - !Array.isArray(x) && - // eslint-disable-next-line no-instanceof/no-instanceof - !(x instanceof Date) + (prototype === null || + prototype === Object.prototype || + Object.getPrototypeOf(prototype) === null) && + !(Symbol.toStringTag in x) && + !(Symbol.iterator in x) ); } From e5d37694eca91d43e0c41936cea2f6645c6989d8 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Mon, 6 May 2024 23:49:06 +0200 Subject: [PATCH 4/6] Handle ReadableStream in input --- js/src/tests/traceable.test.ts | 85 +++++++++++++++++++++++++++++++++- js/src/traceable.ts | 27 +++++++++-- 2 files changed, 108 insertions(+), 4 deletions(-) diff --git a/js/src/tests/traceable.test.ts b/js/src/tests/traceable.test.ts index 46cefbd4b..a92ce523c 100644 --- a/js/src/tests/traceable.test.ts +++ b/js/src/tests/traceable.test.ts @@ -370,7 +370,7 @@ describe("async generators", () => { }); }); - test("ReadableStream", async () => { + test("readable stream", async () => { const { client, callSpy } = mockClient(); const stream = traceable( @@ -478,6 +478,89 @@ describe("deferred input", () => { }); }); + test("readable stream", async () => { + const { client, callSpy } = mockClient(); + const parrotStream = traceable( + async function* parrotStream(input: ReadableStream) { + for await (const token of input) { + yield token; + } + }, + { client, tracingEnabled: true } + ); + + const readStream = new ReadableStream({ + async start(controller) { + for (const token of "Hello world".split(" ")) { + controller.enqueue(token); + } + controller.close(); + }, + }); + + const tokens: string[] = []; + for await (const token of parrotStream(readStream)) { + tokens.push(token); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["parrotStream:0"], + edges: [], + data: { + "parrotStream:0": { + inputs: { input: ["Hello", "world"] }, + outputs: { outputs: ["Hello", "world"] }, + }, + }, + }); + }); + + test("readable stream reader", async () => { + const { client, callSpy } = mockClient(); + const parrotStream = traceable( + async function* parrotStream(input: ReadableStream) { + const reader = input.getReader(); + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + yield value; + } + } finally { + reader.releaseLock(); + } + }, + { client, tracingEnabled: true } + ); + + const readStream = new ReadableStream({ + async start(controller) { + for (const token of "Hello world".split(" ")) { + controller.enqueue(token); + } + controller.close(); + }, + }); + + const tokens: string[] = []; + for await (const token of parrotStream(readStream)) { + tokens.push(token); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["parrotStream:0"], + edges: [], + data: { + "parrotStream:0": { + inputs: { input: ["Hello", "world"] }, + outputs: { outputs: ["Hello", "world"] }, + }, + }, + }); + }); + test("promise", async () => { const { client, callSpy } = mockClient(); const parrotStream = traceable( diff --git a/js/src/traceable.ts b/js/src/traceable.ts index f2fe53e8d..cda8b046c 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -122,6 +122,12 @@ const isThenable = (x: unknown): x is Promise => "then" in x && typeof x.then === "function"; +const isReadableStream = (x: unknown): x is ReadableStream => + x != null && + typeof x === "object" && + "getReader" in x && + typeof x.getReader === "function"; + const tracingIsEnabled = (tracingEnabled?: boolean): boolean => { if (tracingEnabled !== undefined) { return tracingEnabled; @@ -228,8 +234,23 @@ const getSerializablePromise = (arg: Promise) => { return promiseProxy as Promise & { toJSON: () => unknown }; }; -// attempt to const convertSerializableArg = (arg: unknown): unknown => { + if (isReadableStream(arg)) { + const proxyState: unknown[] = []; + const transform = new TransformStream({ + start: () => void 0, + transform: (chunk, controller) => { + proxyState.push(chunk); + controller.enqueue(chunk); + }, + flush: () => void 0, + }); + + const pipeThrough = arg.pipeThrough(transform); + Object.assign(pipeThrough, { toJSON: () => proxyState }); + return pipeThrough; + } + if (isAsyncIterable(arg)) { const proxyState: { current: (Promise> & { @@ -278,11 +299,11 @@ const convertSerializableArg = (arg: unknown): unknown => { return () => { const onlyNexts = proxyState.current; const serialized = onlyNexts.map( - (next) => next.toJSON() as IteratorResult + (next) => next.toJSON() as IteratorResult | undefined ); const chunks = serialized.reduce((memo, next) => { - if (next.value) memo.push(next.value); + if (next?.value) memo.push(next.value); return memo; }, []); From 4c19b0533861061c09c8d2c866429b669f3b1a0d Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Tue, 7 May 2024 00:32:27 +0200 Subject: [PATCH 5/6] Handle promise rejection by using callbacks --- js/src/tests/traceable.test.ts | 50 ++++++++++++++++++++++++++++++++++ js/src/traceable.ts | 4 ++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/js/src/tests/traceable.test.ts b/js/src/tests/traceable.test.ts index a92ce523c..6a7810cd2 100644 --- a/js/src/tests/traceable.test.ts +++ b/js/src/tests/traceable.test.ts @@ -596,6 +596,56 @@ describe("deferred input", () => { }, }); }); + + test("promise rejection", async () => { + const { client, callSpy } = mockClient(); + const parrotStream = traceable( + async function parrotStream(input: Promise) { + return await input; + }, + { client, tracingEnabled: true } + ); + + await expect(async () => { + await parrotStream(Promise.reject(new Error("Rejected!"))); + }).rejects.toThrow("Rejected!"); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["parrotStream:0"], + edges: [], + data: { + "parrotStream:0": { + inputs: { input: { error: {} } }, + error: "Error: Rejected!", + }, + }, + }); + }); + + test("promise rejection, callback handling", async () => { + const { client, callSpy } = mockClient(); + const parrotStream = traceable( + async function parrotStream(input: Promise) { + return input.then((value) => value); + }, + { client, tracingEnabled: true } + ); + + await expect(async () => { + await parrotStream(Promise.reject(new Error("Rejected!"))); + }).rejects.toThrow("Rejected!"); + + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: ["parrotStream:0"], + edges: [], + data: { + "parrotStream:0": { + inputs: { input: { error: {} } }, + error: "Error: Rejected!", + }, + }, + }); + }); }); describe("langchain", () => { diff --git a/js/src/traceable.ts b/js/src/traceable.ts index cda8b046c..f689c9824 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -193,7 +193,9 @@ const getSerializablePromise = (arg: Promise) => { const boundThen = arg[prop].bind(arg); return ( resolve: (value: unknown) => unknown, - reject?: (error: unknown) => unknown + reject: (error: unknown) => unknown = (x) => { + throw x; + } ) => { return boundThen( (value) => { From ce38b30750eb647589b5c372d15ff184fe0bc6c6 Mon Sep 17 00:00:00 2001 From: Tat Dat Duong Date: Tue, 7 May 2024 00:32:56 +0200 Subject: [PATCH 6/6] No need to make reject optional --- js/src/traceable.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/src/traceable.ts b/js/src/traceable.ts index f689c9824..cfec78c76 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -204,7 +204,7 @@ const getSerializablePromise = (arg: Promise) => { }, (error) => { proxyState.current = ["reject", error]; - return reject?.(error); + return reject(error); } ); };