Skip to content

Commit

Permalink
fix(js): Adds fix for circular references in inputs and outputs (#985)
Browse files Browse the repository at this point in the history
~~This does add a few stringify calls - we could do this more
efficiently if needed but I am not sure it's worth prematurely
optimizing~~

Fixes #962
  • Loading branch information
jacoblee93 authored Sep 9, 2024
1 parent c6f4cad commit b5e583f
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 5 deletions.
11 changes: 6 additions & 5 deletions js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import {
parsePromptIdentifier,
} from "./utils/prompts.js";
import { raiseForStatus } from "./utils/error.js";
import { stringifyForTracing } from "./utils/serde.js";
import { _getFetchImplementation } from "./singletons/fetch.js";

export interface ClientConfig {
Expand Down Expand Up @@ -800,7 +801,7 @@ export class Client {
{
method: "POST",
headers,
body: JSON.stringify(mergedRunCreateParams[0]),
body: stringifyForTracing(mergedRunCreateParams[0]),
signal: AbortSignal.timeout(this.timeout_ms),
...this.fetchOptions,
}
Expand Down Expand Up @@ -897,12 +898,12 @@ export class Client {
const batchItems = rawBatch[key].reverse();
let batchItem = batchItems.pop();
while (batchItem !== undefined) {
const stringifiedBatchItem = JSON.stringify(batchItem);
const stringifiedBatchItem = stringifyForTracing(batchItem);
if (
currentBatchSizeBytes > 0 &&
currentBatchSizeBytes + stringifiedBatchItem.length > sizeLimitBytes
) {
await this._postBatchIngestRuns(JSON.stringify(batchChunks));
await this._postBatchIngestRuns(stringifyForTracing(batchChunks));
currentBatchSizeBytes = 0;
batchChunks.post = [];
batchChunks.patch = [];
Expand All @@ -913,7 +914,7 @@ export class Client {
}
}
if (batchChunks.post.length > 0 || batchChunks.patch.length > 0) {
await this._postBatchIngestRuns(JSON.stringify(batchChunks));
await this._postBatchIngestRuns(stringifyForTracing(batchChunks));
}
}

Expand Down Expand Up @@ -975,7 +976,7 @@ export class Client {
{
method: "PATCH",
headers,
body: JSON.stringify(run),
body: stringifyForTracing(run),
signal: AbortSignal.timeout(this.timeout_ms),
...this.fetchOptions,
}
Expand Down
78 changes: 78 additions & 0 deletions js/src/tests/batch_client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { jest } from "@jest/globals";
import { v4 as uuidv4 } from "uuid";
import { Client } from "../client.js";
import { convertToDottedOrderFormat } from "../run_trees.js";
import { CIRCULAR_VALUE_REPLACEMENT_STRING } from "../utils/serde.js";
import { _getFetchImplementation } from "../singletons/fetch.js";

describe("Batch client tracing", () => {
Expand Down Expand Up @@ -511,4 +512,81 @@ describe("Batch client tracing", () => {
expect.objectContaining({ body: expect.any(String) })
);
});

it("Should handle circular values", async () => {
const client = new Client({
apiKey: "test-api-key",
autoBatchTracing: true,
});
const callSpy = jest
.spyOn((client as any).batchIngestCaller, "call")
.mockResolvedValue({
ok: true,
text: () => "",
});
jest
.spyOn(client as any, "batchEndpointIsSupported")
.mockResolvedValue(true);
const projectName = "__test_batch";
const a: Record<string, any> = {};
const b: Record<string, any> = {};
a.b = b;
b.a = a;

const runId = uuidv4();
const dottedOrder = convertToDottedOrderFormat(
new Date().getTime() / 1000,
runId
);
await client.createRun({
id: runId,
project_name: projectName,
name: "test_run",
run_type: "llm",
inputs: a,
trace_id: runId,
dotted_order: dottedOrder,
});

const endTime = Math.floor(new Date().getTime() / 1000);

await client.updateRun(runId, {
outputs: b,
dotted_order: dottedOrder,
trace_id: runId,
end_time: endTime,
});

await new Promise((resolve) => setTimeout(resolve, 100));

const calledRequestParam: any = callSpy.mock.calls[0][2];
expect(JSON.parse(calledRequestParam?.body)).toEqual({
post: [
expect.objectContaining({
id: runId,
run_type: "llm",
inputs: {
b: {
a: {
result: CIRCULAR_VALUE_REPLACEMENT_STRING,
},
},
},
outputs: {
result: CIRCULAR_VALUE_REPLACEMENT_STRING,
},
end_time: endTime,
trace_id: runId,
dotted_order: dottedOrder,
}),
],
patch: [],
});

expect(callSpy).toHaveBeenCalledWith(
_getFetchImplementation(),
"https://api.smith.langchain.com/runs/batch",
expect.objectContaining({ body: expect.any(String) })
);
});
});
36 changes: 36 additions & 0 deletions js/src/tests/traceable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { RunTree, RunTreeConfig } from "../run_trees.js";
import { ROOT, traceable, withRunTree } from "../traceable.js";
import { getAssumedTreeFromCalls } from "./utils/tree.js";
import { mockClient } from "./utils/mock_client.js";
import { CIRCULAR_VALUE_REPLACEMENT_STRING } from "../utils/serde.js";

test("basic traceable implementation", async () => {
const { client, callSpy } = mockClient();
Expand Down Expand Up @@ -70,6 +71,41 @@ test("nested traceable implementation", async () => {
});
});

test("trace circular input and output objects", async () => {
const { client, callSpy } = mockClient();
const a: Record<string, any> = {};
const b: Record<string, any> = {};
a.b = b;
b.a = a;
const llm = traceable(
async function foo(_: any) {
return a;
},
{ client, tracingEnabled: true }
);

await llm(a);

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["foo:0"],
edges: [],
data: {
"foo:0": {
inputs: {
b: {
a: {
result: CIRCULAR_VALUE_REPLACEMENT_STRING,
},
},
},
outputs: {
result: CIRCULAR_VALUE_REPLACEMENT_STRING,
},
},
},
});
});

test("passing run tree manually", async () => {
const { client, callSpy } = mockClient();
const child = traceable(
Expand Down
22 changes: 22 additions & 0 deletions js/src/utils/serde.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
export const CIRCULAR_VALUE_REPLACEMENT_STRING = "[Circular]";

/**
* JSON.stringify version that handles circular references by replacing them
* with an object marking them as such ({ result: "[Circular]" }).
*/
export const stringifyForTracing = (value: any): string => {
const seen = new WeakSet();

const serializer = (_: string, value: any): any => {
if (typeof value === "object" && value !== null) {
if (seen.has(value)) {
return {
result: CIRCULAR_VALUE_REPLACEMENT_STRING,
};
}
seen.add(value);
}
return value;
};
return JSON.stringify(value, serializer);
};

0 comments on commit b5e583f

Please sign in to comment.