Skip to content

Commit

Permalink
fix(handoff): support streaming returned generators (#708)
Browse files Browse the repository at this point in the history
  • Loading branch information
dqbd authored May 21, 2024
2 parents bc9b65f + ed34634 commit 7cf3074
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 61 deletions.
40 changes: 38 additions & 2 deletions js/src/langchain.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { CallbackManager } from "@langchain/core/callbacks/manager";
import { LangChainTracer } from "@langchain/core/tracers/tracer_langchain";
import { Runnable, RunnableConfig } from "@langchain/core/runnables";
import {
Runnable,
RunnableConfig,
patchConfig,
getCallbackManagerForConfig,
} from "@langchain/core/runnables";

import { RunTree } from "./run_trees.js";
import { Run } from "./schemas.js";
Expand All @@ -9,6 +14,7 @@ import {
getCurrentRunTree,
isTraceableFunction,
} from "./traceable.js";
import { isAsyncIterable, isIteratorLike } from "./utils/asserts.js";

/**
* Converts the current run tree active within a traceable-wrapped function
Expand Down Expand Up @@ -113,7 +119,37 @@ export class RunnableTraceable<RunInput, RunOutput> extends Runnable<

async invoke(input: RunInput, options?: Partial<RunnableConfig>) {
const [config] = this._getOptionsList(options ?? {}, 1);
return (await this.func(config, input)) as RunOutput;
const callbacks = await getCallbackManagerForConfig(config);

return (await this.func(
patchConfig(config, { callbacks }),
input
)) as RunOutput;
}

async *_streamIterator(
input: RunInput,
options?: Partial<RunnableConfig>
): AsyncGenerator<RunOutput> {
const result = await this.invoke(input, options);

if (isAsyncIterable(result)) {
for await (const item of result) {
yield item as RunOutput;
}
return;
}

if (isIteratorLike(result)) {
while (true) {
const state: IteratorResult<unknown> = result.next();
if (state.done) break;
yield state.value as RunOutput;
}
return;
}

yield result;
}

static from(func: AnyTraceableFunction) {
Expand Down
17 changes: 14 additions & 3 deletions js/src/run_trees.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,20 @@ export class RunTree implements BaseRun {
tracingEnabled = tracingEnabled || !!langChainTracer;
}

if (!parentRun) {
return new RunTree({
client,
tracingEnabled,
project_name: projectName,
name: props.name,
tags: props.tags,
metadata: props.metadata,
});
}

const parentRunTree = new RunTree({
name: parentRun?.name ?? "<parent>",
id: parentRun?.id,
name: parentRun.name,
id: parentRun.id,
client,
tracingEnabled,
project_name: projectName,
Expand All @@ -198,7 +209,7 @@ export class RunTree implements BaseRun {
});

return parentRunTree.createChild({
name: props?.name ?? "<lambda>",
name: props.name,
tags: props.tags,
metadata: props.metadata,
});
Expand Down
124 changes: 122 additions & 2 deletions js/src/tests/traceable_langchain.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ describe("to langchain", () => {
});
});

test("to langchain stream", async () => {
test("stream", async () => {
const { client, callSpy } = mockClient();

const main = traceable(
Expand Down Expand Up @@ -100,7 +100,7 @@ describe("to langchain", () => {
});
});

test("to langchain batch", async () => {
test("batch", async () => {
const { client, callSpy } = mockClient();

const main = traceable(
Expand Down Expand Up @@ -191,6 +191,126 @@ describe("to traceable", () => {
],
});
});

test("array stream", async () => {
const { client, callSpy } = mockClient();

const source = RunnableTraceable.from(
traceable(function (input: { text: string }) {
return input.text.split(" ");
})
);

const tokens: unknown[] = [];
for await (const chunk of await source.stream(
{ text: "Hello world" },
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore client might be of different type
{ callbacks: [new LangChainTracer({ client })] }
)) {
tokens.push(chunk);
}

expect(tokens).toEqual([["Hello", "world"]]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["<lambda>:0"],
edges: [],
});
});

test("generator stream", async () => {
const { client, callSpy } = mockClient();

const source = RunnableTraceable.from(
traceable(function* (input: { text: string }) {
const chunks = input.text.split(" ");
for (const chunk of chunks) {
yield chunk;
}
})
);

const tokens: unknown[] = [];
for await (const chunk of await source.stream(
{ text: "Hello world" },
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore client might be of different type
{ callbacks: [new LangChainTracer({ client })] }
)) {
tokens.push(chunk);
}

expect(tokens).toEqual(["Hello", "world"]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["<lambda>:0"],
edges: [],
});
});

test("readable stream", async () => {
const { client, callSpy } = mockClient();

const source = RunnableTraceable.from(
traceable(async function (input: { text: string }) {
const readStream = new ReadableStream({
async pull(controller) {
for (const item of input.text.split(" ")) {
controller.enqueue(item);
}
controller.close();
},
});

return readStream;
})
);

const tokens: unknown[] = [];
for await (const chunk of await source.stream(
{ text: "Hello world" },
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore client might be of different type
{ callbacks: [new LangChainTracer({ client })] }
)) {
tokens.push(chunk);
}

expect(tokens).toEqual(["Hello", "world"]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["<lambda>:0"],
edges: [],
});
});

test("async generator stream", async () => {
const { client, callSpy } = mockClient();
const source = RunnableTraceable.from(
traceable(async function* (input: { text: string }) {
const chunks = input.text.split(" ");
for (const chunk of chunks) {
yield chunk;
}
})
);

const tokens: unknown[] = [];
for await (const chunk of await source.stream(
{ text: "Hello world" },
{
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore client might be of different type
callbacks: [new LangChainTracer({ client })],
}
)) {
tokens.push(chunk);
}

expect(tokens).toEqual(["Hello", "world"]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["<lambda>:0"],
edges: [],
});
});
});

test("explicit nested", async () => {
Expand Down
64 changes: 10 additions & 54 deletions js/src/traceable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,20 @@ import {
AsyncLocalStorageProviderSingleton,
} from "./singletons/traceable.js";
import { TraceableFunction } from "./singletons/types.js";
import {
isKVMap,
isReadableStream,
isAsyncIterable,
isIteratorLike,
isThenable,
isGenerator,
isPromiseMethod,
} from "./utils/asserts.js";

AsyncLocalStorageProviderSingleton.initializeGlobalInstance(
new AsyncLocalStorage<RunTree | undefined>()
);

function isPromiseMethod(
x: string | symbol
): x is "then" | "catch" | "finally" {
if (x === "then" || x === "catch" || x === "finally") {
return true;
}
return false;
}

function isKVMap(x: unknown): x is Record<string, unknown> {
if (typeof x !== "object" || x == null) {
return false;
}

const prototype = Object.getPrototypeOf(x);
return (
(prototype === null ||
prototype === Object.prototype ||
Object.getPrototypeOf(prototype) === null) &&
!(Symbol.toStringTag in x) &&
!(Symbol.iterator in x)
);
}

const isAsyncIterable = (x: unknown): x is AsyncIterable<unknown> =>
x != null &&
typeof x === "object" &&
// eslint-disable-next-line @typescript-eslint/no-explicit-any
typeof (x as any)[Symbol.asyncIterator] === "function";

const isIteratorLike = (x: unknown): x is Iterator<unknown> =>
x != null &&
typeof x === "object" &&
"next" in x &&
typeof x.next === "function";

const GeneratorFunction = function* () {}.constructor;

const isGenerator = (x: unknown): x is Generator =>
// eslint-disable-next-line no-instanceof/no-instanceof
x != null && typeof x === "function" && x instanceof GeneratorFunction;

const isThenable = (x: unknown): x is Promise<unknown> =>
x != null &&
typeof x === "object" &&
"then" in x &&
typeof x.then === "function";

const isReadableStream = (x: unknown): x is ReadableStream =>
x != null &&
typeof x === "object" &&
"getReader" in x &&
typeof x.getReader === "function";

const handleRunInputs = (rawInputs: unknown[]): KVMap => {
const firstInput = rawInputs[0];

Expand All @@ -83,6 +38,7 @@ const handleRunInputs = (rawInputs: unknown[]): KVMap => {
if (rawInputs.length > 1) {
return { args: rawInputs };
}

if (isKVMap(firstInput)) {
return firstInput;
}
Expand Down
51 changes: 51 additions & 0 deletions js/src/utils/asserts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
export function isPromiseMethod(
x: string | symbol
): x is "then" | "catch" | "finally" {
if (x === "then" || x === "catch" || x === "finally") {
return true;
}
return false;
}

export function isKVMap(x: unknown): x is Record<string, unknown> {
if (typeof x !== "object" || x == null) {
return false;
}

const prototype = Object.getPrototypeOf(x);
return (
(prototype === null ||
prototype === Object.prototype ||
Object.getPrototypeOf(prototype) === null) &&
!(Symbol.toStringTag in x) &&
!(Symbol.iterator in x)
);
}
export const isAsyncIterable = (x: unknown): x is AsyncIterable<unknown> =>
x != null &&
typeof x === "object" &&
// eslint-disable-next-line @typescript-eslint/no-explicit-any
typeof (x as any)[Symbol.asyncIterator] === "function";

export const isIteratorLike = (x: unknown): x is Iterator<unknown> =>
x != null &&
typeof x === "object" &&
"next" in x &&
typeof x.next === "function";

const GeneratorFunction = function* () {}.constructor;
export const isGenerator = (x: unknown): x is Generator =>
// eslint-disable-next-line no-instanceof/no-instanceof
x != null && typeof x === "function" && x instanceof GeneratorFunction;

export const isThenable = (x: unknown): x is Promise<unknown> =>
x != null &&
typeof x === "object" &&
"then" in x &&
typeof x.then === "function";

export const isReadableStream = (x: unknown): x is ReadableStream =>
x != null &&
typeof x === "object" &&
"getReader" in x &&
typeof x.getReader === "function";

0 comments on commit 7cf3074

Please sign in to comment.