Skip to content

Commit

Permalink
checkpoint[patch]: Avoid deeply serializing LangChain classes (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Aug 27, 2024
1 parent d3c4494 commit 1806051
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 172 deletions.
75 changes: 57 additions & 18 deletions libs/checkpoint/src/serde/jsonplus.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,60 @@
import { load } from "@langchain/core/load";
import { SerializerProtocol } from "./base.js";

function isLangChainSerializable(value: Record<string, unknown>) {
return (
typeof value.lc_serializable === "boolean" && Array.isArray(value.lc_id)
);
}

function isLangChainSerializedObject(value: Record<string, unknown>) {
return (
value !== null &&
value.lc === 1 &&
value.type === "constructor" &&
Array.isArray(value.id)
);
}

const _serialize = (value: any, seen = new WeakSet()): string => {
const defaultValue = _default("", value);

if (defaultValue === null) {
return "null";
} else if (typeof defaultValue === "string") {
return JSON.stringify(defaultValue);
} else if (
typeof defaultValue === "number" ||
typeof defaultValue === "boolean"
) {
return defaultValue.toString();
} else if (typeof defaultValue === "object") {
if (seen.has(defaultValue)) {
throw new TypeError("Circular reference detected");
}
seen.add(defaultValue);

if (Array.isArray(defaultValue)) {
const result = `[${defaultValue
.map((item) => _serialize(item, seen))
.join(",")}]`;
seen.delete(defaultValue);
return result;
} else if (isLangChainSerializable(defaultValue)) {
return JSON.stringify(defaultValue);
} else {
const entries = Object.entries(defaultValue).map(
([k, v]) => `${JSON.stringify(k)}:${_serialize(v, seen)}`
);
const result = `{${entries.join(",")}}`;
seen.delete(defaultValue);
return result;
}
}
// Only be reached for functions or symbols
return JSON.stringify(defaultValue);
};

async function _reviver(value: any): Promise<any> {
if (value && typeof value === "object") {
if (value.lc === 2 && value.type === "undefined") {
Expand Down Expand Up @@ -40,7 +94,7 @@ async function _reviver(value: any): Promise<any> {
} catch (error) {
return value;
}
} else if (value.lc === 1) {
} else if (isLangChainSerializedObject(value)) {
return load(JSON.stringify(value));
} else if (Array.isArray(value)) {
return Promise.all(value.map((item) => _reviver(item)));
Expand Down Expand Up @@ -93,23 +147,8 @@ function _default(_key: string, obj: any): any {

export class JsonPlusSerializer implements SerializerProtocol {
protected _dumps(obj: any): Uint8Array {
const jsonString = JSON.stringify(obj, (key, value) => {
if (value && typeof value === "object") {
if (Array.isArray(value)) {
// Handle arrays
return value.map((item) => _default(key, item));
} else {
// Handle objects
const serialized: any = {};
for (const [k, v] of Object.entries(value)) {
serialized[k] = _default(k, v);
}
return serialized;
}
}
return _default(key, value);
});
return new TextEncoder().encode(jsonString);
const encoder = new TextEncoder();
return encoder.encode(_serialize(obj));
}

dumpsTyped(obj: any): [string, Uint8Array] {
Expand Down
54 changes: 54 additions & 0 deletions libs/checkpoint/src/serde/tests/jsonplus.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,44 @@ import { AIMessage, HumanMessage } from "@langchain/core/messages";
import { uuid6 } from "../../id.js";
import { JsonPlusSerializer } from "../jsonplus.js";

const messageWithToolCall = new AIMessage({
content: "",
tool_calls: [
{
name: "current_weather_sf",
args: {
input: "",
},
type: "tool_call",
id: "call_Co6nrPmiAdWWZQHCNdEZUjTe",
},
],
invalid_tool_calls: [],
additional_kwargs: {
function_call: undefined,
tool_calls: [
{
id: "call_Co6nrPmiAdWWZQHCNdEZUjTe",
type: "function",
function: {
name: "current_weather_sf",
arguments: '{"input":""}',
},
},
],
},
response_metadata: {
tokenUsage: {
completionTokens: 15,
promptTokens: 84,
totalTokens: 99,
},
finish_reason: "tool_calls",
system_fingerprint: "fp_a2ff031fb5",
},
id: "chatcmpl-A0s8Rd97RnFo6xMlYgpJDDfV8J1cl",
});

const complexValue = {
number: 1,
id: uuid6(-1),
Expand All @@ -14,6 +52,7 @@ const complexValue = {
]),
regex: /foo*/gi,
message: new AIMessage("test message"),
messageWithToolCall,
array: [
new Error("nestedfoo"),
5,
Expand All @@ -40,6 +79,7 @@ const VALUES = [
["empty string", ""],
["simple string", "foobar"],
["various data types", complexValue],
["an AIMessage with a tool call", messageWithToolCall],
] satisfies [string, unknown][];

it.each(VALUES)(
Expand All @@ -51,3 +91,17 @@ it.each(VALUES)(
expect(deserialized).toEqual(value);
}
);

it("Should throw an error for circular JSON inputs", async () => {
const a: Record<string, unknown> = {};
const b: Record<string, unknown> = {};
a.b = b;
b.a = a;

const circular = {
a,
b,
};
const serde = new JsonPlusSerializer();
expect(() => serde.dumpsTyped(circular)).toThrow();
});
Loading

0 comments on commit 1806051

Please sign in to comment.