Skip to content

Commit

Permalink
fix(js): add support for tracing generators in input and output (#660)
Browse files Browse the repository at this point in the history
  • Loading branch information
dqbd authored May 9, 2024
2 parents 90c5ac5 + cca947b commit ff0d76f
Show file tree
Hide file tree
Showing 3 changed files with 599 additions and 19 deletions.
1 change: 1 addition & 0 deletions js/src/run_trees.ts
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ export class RunTree implements BaseRun {
const runUpdate: RunUpdate = {
end_time: this.end_time,
error: this.error,
inputs: this.inputs,
outputs: this.outputs,
parent_run_id: this.parent_run?.id,
reference_example_id: this.reference_example_id,
Expand Down
324 changes: 323 additions & 1 deletion js/src/tests/traceable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ describe("async generators", () => {
});
});

test("ReadableStream", async () => {
test("readable stream", async () => {
const { client, callSpy } = mockClient();

const stream = traceable(
Expand Down Expand Up @@ -407,6 +407,247 @@ describe("async generators", () => {
});
});

describe("deferred input", () => {
test("generator", async () => {
const { client, callSpy } = mockClient();
const parrotStream = traceable(
async function* parrotStream(input: Generator<string>) {
for (const token of input) {
yield token;
}
},
{ client, tracingEnabled: true }
);

const inputGenerator = function* () {
for (const token of "Hello world".split(" ")) {
yield token;
}
};

const tokens: string[] = [];
for await (const token of parrotStream(inputGenerator())) {
tokens.push(token);
}

expect(tokens).toEqual(["Hello", "world"]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["parrotStream:0"],
edges: [],
data: {
"parrotStream:0": {
inputs: { input: ["Hello", "world"] },
outputs: { outputs: ["Hello", "world"] },
},
},
});
});

test("async generator", async () => {
const { client, callSpy } = mockClient();
const inputStream = async function* inputStream() {
for (const token of "Hello world".split(" ")) {
yield token;
}
};

const parrotStream = traceable(
async function* parrotStream(input: AsyncGenerator<string>) {
for await (const token of input) {
yield token;
}
},
{ client, tracingEnabled: true }
);

const tokens: string[] = [];
for await (const token of parrotStream(inputStream())) {
tokens.push(token);
}

expect(tokens).toEqual(["Hello", "world"]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["parrotStream:0"],
edges: [],
data: {
"parrotStream:0": {
inputs: { input: ["Hello", "world"] },
outputs: { outputs: ["Hello", "world"] },
},
},
});
});

test("readable stream", async () => {
const { client, callSpy } = mockClient();
const parrotStream = traceable(
async function* parrotStream(input: ReadableStream<string>) {
for await (const token of input) {
yield token;
}
},
{ client, tracingEnabled: true }
);

const readStream = new ReadableStream({
async start(controller) {
for (const token of "Hello world".split(" ")) {
controller.enqueue(token);
}
controller.close();
},
});

const tokens: string[] = [];
for await (const token of parrotStream(readStream)) {
tokens.push(token);
}

expect(tokens).toEqual(["Hello", "world"]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["parrotStream:0"],
edges: [],
data: {
"parrotStream:0": {
inputs: { input: ["Hello", "world"] },
outputs: { outputs: ["Hello", "world"] },
},
},
});
});

test("readable stream reader", async () => {
const { client, callSpy } = mockClient();
const parrotStream = traceable(
async function* parrotStream(input: ReadableStream<string>) {
const reader = input.getReader();
try {
while (true) {
const { done, value } = await reader.read();
if (done) break;
yield value;
}
} finally {
reader.releaseLock();
}
},
{ client, tracingEnabled: true }
);

const readStream = new ReadableStream({
async start(controller) {
for (const token of "Hello world".split(" ")) {
controller.enqueue(token);
}
controller.close();
},
});

const tokens: string[] = [];
for await (const token of parrotStream(readStream)) {
tokens.push(token);
}

expect(tokens).toEqual(["Hello", "world"]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["parrotStream:0"],
edges: [],
data: {
"parrotStream:0": {
inputs: { input: ["Hello", "world"] },
outputs: { outputs: ["Hello", "world"] },
},
},
});
});

test("promise", async () => {
const { client, callSpy } = mockClient();
const parrotStream = traceable(
async function* parrotStream(input: Promise<string[]>) {
// eslint-disable-next-line no-instanceof/no-instanceof
if (!(input instanceof Promise)) {
throw new Error("Input must be a promise");
}

for (const token of await input) {
yield token;
}
},
{ client, tracingEnabled: true }
);

const tokens: string[] = [];
for await (const token of parrotStream(
Promise.resolve(["Hello", "world"])
)) {
tokens.push(token);
}

expect(tokens).toEqual(["Hello", "world"]);
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["parrotStream:0"],
edges: [],
data: {
"parrotStream:0": {
inputs: { input: ["Hello", "world"] },
outputs: { outputs: ["Hello", "world"] },
},
},
});
});

test("promise rejection", async () => {
const { client, callSpy } = mockClient();
const parrotStream = traceable(
async function parrotStream(input: Promise<string[]>) {
return await input;
},
{ client, tracingEnabled: true }
);

await expect(async () => {
await parrotStream(Promise.reject(new Error("Rejected!")));
}).rejects.toThrow("Rejected!");

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["parrotStream:0"],
edges: [],
data: {
"parrotStream:0": {
inputs: { input: { error: {} } },
error: "Error: Rejected!",
},
},
});
});

test("promise rejection, callback handling", async () => {
const { client, callSpy } = mockClient();
const parrotStream = traceable(
async function parrotStream(input: Promise<string[]>) {
return input.then((value) => value);
},
{ client, tracingEnabled: true }
);

await expect(async () => {
await parrotStream(Promise.reject(new Error("Rejected!")));
}).rejects.toThrow("Rejected!");

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["parrotStream:0"],
edges: [],
data: {
"parrotStream:0": {
inputs: { input: { error: {} } },
error: "Error: Rejected!",
},
},
});
});
});

describe("langchain", () => {
test.skip("bound", async () => {
const { client, callSpy } = mockClient();
Expand Down Expand Up @@ -442,6 +683,87 @@ describe("langchain", () => {
});
});

describe("generator", () => {
function gatherAll(iterator: Iterator<unknown>) {
const chunks: unknown[] = [];
// eslint-disable-next-line no-constant-condition
while (true) {
const next = iterator.next();
chunks.push(next.value);
if (next.done) break;
}

return chunks;
}

test("yield", async () => {
const { client, callSpy } = mockClient();

function* generator() {
for (let i = 0; i < 3; ++i) yield i;
}

const traced = traceable(generator, { client, tracingEnabled: true });

expect(gatherAll(await traced())).toEqual(gatherAll(generator()));
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["generator:0"],
edges: [],
data: {
"generator:0": {
outputs: { outputs: [0, 1, 2] },
},
},
});
});

test("with return", async () => {
const { client, callSpy } = mockClient();

function* generator() {
for (let i = 0; i < 3; ++i) yield i;
return 3;
}

const traced = traceable(generator, { client, tracingEnabled: true });

expect(gatherAll(await traced())).toEqual(gatherAll(generator()));
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["generator:0"],
edges: [],
data: { "generator:0": { outputs: { outputs: [0, 1, 2, 3] } } },
});
});

test("nested", async () => {
const { client, callSpy } = mockClient();

function* generator() {
function* child() {
for (let i = 0; i < 3; ++i) yield i;
}

for (let i = 0; i < 2; ++i) {
for (const num of child()) yield num;
}

return 3;
}

const traced = traceable(generator, { client, tracingEnabled: true });
expect(gatherAll(await traced())).toEqual(gatherAll(generator()));
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["generator:0"],
edges: [],
data: {
"generator:0": {
outputs: { outputs: [0, 1, 2, 0, 1, 2, 3] },
},
},
});
});
});

test("metadata", async () => {
const { client, callSpy } = mockClient();
const main = traceable(async (): Promise<number> => 42, {
Expand Down
Loading

0 comments on commit ff0d76f

Please sign in to comment.