diff --git a/js/src/langchain.ts b/js/src/langchain.ts index 505e05424..a3a4de845 100644 --- a/js/src/langchain.ts +++ b/js/src/langchain.ts @@ -1,6 +1,11 @@ import { CallbackManager } from "@langchain/core/callbacks/manager"; import { LangChainTracer } from "@langchain/core/tracers/tracer_langchain"; -import { Runnable, RunnableConfig } from "@langchain/core/runnables"; +import { + Runnable, + RunnableConfig, + patchConfig, + getCallbackManagerForConfig, +} from "@langchain/core/runnables"; import { RunTree } from "./run_trees.js"; import { Run } from "./schemas.js"; @@ -9,6 +14,7 @@ import { getCurrentRunTree, isTraceableFunction, } from "./traceable.js"; +import { isAsyncIterable, isIteratorLike } from "./utils/asserts.js"; /** * Converts the current run tree active within a traceable-wrapped function @@ -113,7 +119,37 @@ export class RunnableTraceable extends Runnable< async invoke(input: RunInput, options?: Partial) { const [config] = this._getOptionsList(options ?? {}, 1); - return (await this.func(config, input)) as RunOutput; + const callbacks = await getCallbackManagerForConfig(config); + + return (await this.func( + patchConfig(config, { callbacks }), + input + )) as RunOutput; + } + + async *_streamIterator( + input: RunInput, + options?: Partial + ): AsyncGenerator { + const result = await this.invoke(input, options); + + if (isAsyncIterable(result)) { + for await (const item of result) { + yield item as RunOutput; + } + return; + } + + if (isIteratorLike(result)) { + while (true) { + const state: IteratorResult = result.next(); + if (state.done) break; + yield state.value as RunOutput; + } + return; + } + + yield result; } static from(func: AnyTraceableFunction) { diff --git a/js/src/run_trees.ts b/js/src/run_trees.ts index c88c4c193..1c877bc5a 100644 --- a/js/src/run_trees.ts +++ b/js/src/run_trees.ts @@ -180,9 +180,20 @@ export class RunTree implements BaseRun { tracingEnabled = tracingEnabled || !!langChainTracer; } + if (!parentRun) { + return new RunTree({ + client, + tracingEnabled, + project_name: projectName, + name: props.name, + tags: props.tags, + metadata: props.metadata, + }); + } + const parentRunTree = new RunTree({ - name: parentRun?.name ?? "", - id: parentRun?.id, + name: parentRun.name, + id: parentRun.id, client, tracingEnabled, project_name: projectName, @@ -198,7 +209,7 @@ export class RunTree implements BaseRun { }); return parentRunTree.createChild({ - name: props?.name ?? "", + name: props.name, tags: props.tags, metadata: props.metadata, }); diff --git a/js/src/tests/traceable_langchain.test.ts b/js/src/tests/traceable_langchain.test.ts index b3056ac53..c5d03e027 100644 --- a/js/src/tests/traceable_langchain.test.ts +++ b/js/src/tests/traceable_langchain.test.ts @@ -63,7 +63,7 @@ describe("to langchain", () => { }); }); - test("to langchain stream", async () => { + test("stream", async () => { const { client, callSpy } = mockClient(); const main = traceable( @@ -100,7 +100,7 @@ describe("to langchain", () => { }); }); - test("to langchain batch", async () => { + test("batch", async () => { const { client, callSpy } = mockClient(); const main = traceable( @@ -191,6 +191,126 @@ describe("to traceable", () => { ], }); }); + + test("array stream", async () => { + const { client, callSpy } = mockClient(); + + const source = RunnableTraceable.from( + traceable(function (input: { text: string }) { + return input.text.split(" "); + }) + ); + + const tokens: unknown[] = []; + for await (const chunk of await source.stream( + { text: "Hello world" }, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore client might be of different type + { callbacks: [new LangChainTracer({ client })] } + )) { + tokens.push(chunk); + } + + expect(tokens).toEqual([["Hello", "world"]]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [":0"], + edges: [], + }); + }); + + test("generator stream", async () => { + const { client, callSpy } = mockClient(); + + const source = RunnableTraceable.from( + traceable(function* (input: { text: string }) { + const chunks = input.text.split(" "); + for (const chunk of chunks) { + yield chunk; + } + }) + ); + + const tokens: unknown[] = []; + for await (const chunk of await source.stream( + { text: "Hello world" }, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore client might be of different type + { callbacks: [new LangChainTracer({ client })] } + )) { + tokens.push(chunk); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [":0"], + edges: [], + }); + }); + + test("readable stream", async () => { + const { client, callSpy } = mockClient(); + + const source = RunnableTraceable.from( + traceable(async function (input: { text: string }) { + const readStream = new ReadableStream({ + async pull(controller) { + for (const item of input.text.split(" ")) { + controller.enqueue(item); + } + controller.close(); + }, + }); + + return readStream; + }) + ); + + const tokens: unknown[] = []; + for await (const chunk of await source.stream( + { text: "Hello world" }, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore client might be of different type + { callbacks: [new LangChainTracer({ client })] } + )) { + tokens.push(chunk); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [":0"], + edges: [], + }); + }); + + test("async generator stream", async () => { + const { client, callSpy } = mockClient(); + const source = RunnableTraceable.from( + traceable(async function* (input: { text: string }) { + const chunks = input.text.split(" "); + for (const chunk of chunks) { + yield chunk; + } + }) + ); + + const tokens: unknown[] = []; + for await (const chunk of await source.stream( + { text: "Hello world" }, + { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore client might be of different type + callbacks: [new LangChainTracer({ client })], + } + )) { + tokens.push(chunk); + } + + expect(tokens).toEqual(["Hello", "world"]); + expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({ + nodes: [":0"], + edges: [], + }); + }); }); test("explicit nested", async () => { diff --git a/js/src/traceable.ts b/js/src/traceable.ts index 82f86b616..173f1cba3 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -14,65 +14,20 @@ import { AsyncLocalStorageProviderSingleton, } from "./singletons/traceable.js"; import { TraceableFunction } from "./singletons/types.js"; +import { + isKVMap, + isReadableStream, + isAsyncIterable, + isIteratorLike, + isThenable, + isGenerator, + isPromiseMethod, +} from "./utils/asserts.js"; AsyncLocalStorageProviderSingleton.initializeGlobalInstance( new AsyncLocalStorage() ); -function isPromiseMethod( - x: string | symbol -): x is "then" | "catch" | "finally" { - if (x === "then" || x === "catch" || x === "finally") { - return true; - } - return false; -} - -function isKVMap(x: unknown): x is Record { - if (typeof x !== "object" || x == null) { - return false; - } - - const prototype = Object.getPrototypeOf(x); - return ( - (prototype === null || - prototype === Object.prototype || - Object.getPrototypeOf(prototype) === null) && - !(Symbol.toStringTag in x) && - !(Symbol.iterator in x) - ); -} - -const isAsyncIterable = (x: unknown): x is AsyncIterable => - x != null && - typeof x === "object" && - // eslint-disable-next-line @typescript-eslint/no-explicit-any - typeof (x as any)[Symbol.asyncIterator] === "function"; - -const isIteratorLike = (x: unknown): x is Iterator => - x != null && - typeof x === "object" && - "next" in x && - typeof x.next === "function"; - -const GeneratorFunction = function* () {}.constructor; - -const isGenerator = (x: unknown): x is Generator => - // eslint-disable-next-line no-instanceof/no-instanceof - x != null && typeof x === "function" && x instanceof GeneratorFunction; - -const isThenable = (x: unknown): x is Promise => - x != null && - typeof x === "object" && - "then" in x && - typeof x.then === "function"; - -const isReadableStream = (x: unknown): x is ReadableStream => - x != null && - typeof x === "object" && - "getReader" in x && - typeof x.getReader === "function"; - const handleRunInputs = (rawInputs: unknown[]): KVMap => { const firstInput = rawInputs[0]; @@ -83,6 +38,7 @@ const handleRunInputs = (rawInputs: unknown[]): KVMap => { if (rawInputs.length > 1) { return { args: rawInputs }; } + if (isKVMap(firstInput)) { return firstInput; } diff --git a/js/src/utils/asserts.ts b/js/src/utils/asserts.ts new file mode 100644 index 000000000..55bc260db --- /dev/null +++ b/js/src/utils/asserts.ts @@ -0,0 +1,51 @@ +export function isPromiseMethod( + x: string | symbol +): x is "then" | "catch" | "finally" { + if (x === "then" || x === "catch" || x === "finally") { + return true; + } + return false; +} + +export function isKVMap(x: unknown): x is Record { + if (typeof x !== "object" || x == null) { + return false; + } + + const prototype = Object.getPrototypeOf(x); + return ( + (prototype === null || + prototype === Object.prototype || + Object.getPrototypeOf(prototype) === null) && + !(Symbol.toStringTag in x) && + !(Symbol.iterator in x) + ); +} +export const isAsyncIterable = (x: unknown): x is AsyncIterable => + x != null && + typeof x === "object" && + // eslint-disable-next-line @typescript-eslint/no-explicit-any + typeof (x as any)[Symbol.asyncIterator] === "function"; + +export const isIteratorLike = (x: unknown): x is Iterator => + x != null && + typeof x === "object" && + "next" in x && + typeof x.next === "function"; + +const GeneratorFunction = function* () {}.constructor; +export const isGenerator = (x: unknown): x is Generator => + // eslint-disable-next-line no-instanceof/no-instanceof + x != null && typeof x === "function" && x instanceof GeneratorFunction; + +export const isThenable = (x: unknown): x is Promise => + x != null && + typeof x === "object" && + "then" in x && + typeof x.then === "function"; + +export const isReadableStream = (x: unknown): x is ReadableStream => + x != null && + typeof x === "object" && + "getReader" in x && + typeof x.getReader === "function";