Skip to content

Commit

Permalink
fix(core): Clear inherited config for called callbacks (#7174)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Nov 11, 2024
1 parent 05e5813 commit 831f9de
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 143 deletions.
44 changes: 2 additions & 42 deletions langchain-core/src/callbacks/promises.ts
Original file line number Diff line number Diff line change
@@ -1,43 +1,3 @@
import PQueueMod from "p-queue";
import { awaitAllCallbacks, consumeCallback } from "../singletons/callbacks.js";

let queue: typeof import("p-queue")["default"]["prototype"];

/**
* Creates a queue using the p-queue library. The queue is configured to
* auto-start and has a concurrency of 1, meaning it will process tasks
* one at a time.
*/
function createQueue() {
const PQueue = "default" in PQueueMod ? PQueueMod.default : PQueueMod;
return new PQueue({
autoStart: true,
concurrency: 1,
});
}

/**
* Consume a promise, either adding it to the queue or waiting for it to resolve
* @param promiseFn Promise to consume
* @param wait Whether to wait for the promise to resolve or resolve immediately
*/
export async function consumeCallback<T>(
promiseFn: () => Promise<T> | T | void,
wait: boolean
): Promise<void> {
if (wait === true) {
await promiseFn();
} else {
if (typeof queue === "undefined") {
queue = createQueue();
}
void queue.add(promiseFn);
}
}

/**
* Waits for all promises in the queue to resolve. If the queue is
* undefined, it immediately resolves a promise.
*/
export function awaitAllCallbacks(): Promise<void> {
return typeof queue !== "undefined" ? queue.onIdle() : Promise.resolve();
}
export { awaitAllCallbacks, consumeCallback };
38 changes: 38 additions & 0 deletions langchain-core/src/callbacks/tests/callbacks.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable no-promise-executor-return */
import { test, expect } from "@jest/globals";
import * as uuid from "uuid";
import { AsyncLocalStorage } from "node:async_hooks";
import { CallbackManager } from "../manager.js";
import { BaseCallbackHandler, type BaseCallbackHandlerInput } from "../base.js";
import type { Serialized } from "../../load/serializable.js";
Expand All @@ -10,6 +11,8 @@ import type { AgentAction, AgentFinish } from "../../agents.js";
import { BaseMessage, HumanMessage } from "../../messages/index.js";
import type { LLMResult } from "../../outputs.js";
import { RunnableLambda } from "../../runnables/base.js";
import { AsyncLocalStorageProviderSingleton } from "../../singletons/index.js";
import { awaitAllCallbacks } from "../promises.js";

class FakeCallbackHandler extends BaseCallbackHandler {
name = `fake-${uuid.v4()}`;
Expand Down Expand Up @@ -536,3 +539,38 @@ test("chain should still run if a normal callback handler throws an error", asyn
);
expect(res).toEqual("hello world");
});

test("runnables in callbacks should be root runs", async () => {
AsyncLocalStorageProviderSingleton.initializeGlobalInstance(
new AsyncLocalStorage()
);
const nestedChain = RunnableLambda.from(async () => {
const subRun = RunnableLambda.from(async () => "hello world");
return await subRun.invoke({ foo: "bar" });
});
let error;
let finalInputs;
const res = await nestedChain.invoke(
{},
{
callbacks: [
{
handleChainStart: (_chain, inputs) => {
finalInputs = inputs;
try {
expect(
AsyncLocalStorageProviderSingleton.getRunnableConfig()
).toEqual(undefined);
} catch (e) {
error = e;
}
},
},
],
}
);
await awaitAllCallbacks();
expect(res).toEqual("hello world");
expect(error).toBe(undefined);
expect(finalInputs).toEqual({ foo: "bar" });
});
20 changes: 20 additions & 0 deletions langchain-core/src/singletons/async_local_storage/globals.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
export interface AsyncLocalStorageInterface {
getStore: () => any | undefined;

run: <T>(store: any, callback: () => T) => T;

enterWith: (store: any) => void;
}

export const TRACING_ALS_KEY = Symbol.for("ls:tracing_async_local_storage");

export const setGlobalAsyncLocalStorageInstance = (
instance: AsyncLocalStorageInterface
) => {
(globalThis as any)[TRACING_ALS_KEY] = instance;
};

export const getGlobalAsyncLocalStorageInstance = () => {
return (globalThis as any)[TRACING_ALS_KEY];
};
98 changes: 98 additions & 0 deletions langchain-core/src/singletons/async_local_storage/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { RunTree } from "langsmith";
import {
AsyncLocalStorageInterface,
getGlobalAsyncLocalStorageInstance,
setGlobalAsyncLocalStorageInstance,
} from "./globals.js";
import { CallbackManager } from "../../callbacks/manager.js";
import { LangChainTracer } from "../../tracers/tracer_langchain.js";

export class MockAsyncLocalStorage implements AsyncLocalStorageInterface {
getStore(): any {
return undefined;
}

run<T>(_store: any, callback: () => T): T {
return callback();
}

enterWith(_store: any) {
return undefined;
}
}

const mockAsyncLocalStorage = new MockAsyncLocalStorage();

const LC_CHILD_KEY = Symbol.for("lc:child_config");

export const _CONTEXT_VARIABLES_KEY = Symbol.for("lc:context_variables");

class AsyncLocalStorageProvider {
getInstance(): AsyncLocalStorageInterface {
return getGlobalAsyncLocalStorageInstance() ?? mockAsyncLocalStorage;
}

getRunnableConfig() {
const storage = this.getInstance();
// this has the runnable config
// which means that we should also have an instance of a LangChainTracer
// with the run map prepopulated
return storage.getStore()?.extra?.[LC_CHILD_KEY];
}

runWithConfig<T>(
config: any,
callback: () => T,
avoidCreatingRootRunTree?: boolean
): T {
const callbackManager = CallbackManager._configureSync(
config?.callbacks,
undefined,
config?.tags,
undefined,
config?.metadata
);
const storage = this.getInstance();
const previousValue = storage.getStore();
const parentRunId = callbackManager?.getParentRunId();

const langChainTracer = callbackManager?.handlers?.find(
(handler) => handler?.name === "langchain_tracer"
) as LangChainTracer | undefined;

let runTree;
if (langChainTracer && parentRunId) {
runTree = langChainTracer.convertToRunTree(parentRunId);
} else if (!avoidCreatingRootRunTree) {
runTree = new RunTree({
name: "<runnable_lambda>",
tracingEnabled: false,
});
}

if (runTree) {
runTree.extra = { ...runTree.extra, [LC_CHILD_KEY]: config };
}

if (
previousValue !== undefined &&
previousValue[_CONTEXT_VARIABLES_KEY] !== undefined
) {
(runTree as any)[_CONTEXT_VARIABLES_KEY] =
previousValue[_CONTEXT_VARIABLES_KEY];
}

return storage.run(runTree, callback);
}

initializeGlobalInstance(instance: AsyncLocalStorageInterface) {
if (getGlobalAsyncLocalStorageInstance() === undefined) {
setGlobalAsyncLocalStorageInstance(instance);
}
}
}

const AsyncLocalStorageProviderSingleton = new AsyncLocalStorageProvider();

export { AsyncLocalStorageProviderSingleton, type AsyncLocalStorageInterface };
67 changes: 67 additions & 0 deletions langchain-core/src/singletons/callbacks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import PQueueMod from "p-queue";
import { getGlobalAsyncLocalStorageInstance } from "./async_local_storage/globals.js";

let queue: typeof import("p-queue")["default"]["prototype"];

/**
* Creates a queue using the p-queue library. The queue is configured to
* auto-start and has a concurrency of 1, meaning it will process tasks
* one at a time.
*/
function createQueue() {
const PQueue: any = "default" in PQueueMod ? PQueueMod.default : PQueueMod;
return new PQueue({
autoStart: true,
concurrency: 1,
});
}

export function getQueue() {
if (typeof queue === "undefined") {
queue = createQueue();
}
return queue;
}

/**
* Consume a promise, either adding it to the queue or waiting for it to resolve
* @param promiseFn Promise to consume
* @param wait Whether to wait for the promise to resolve or resolve immediately
*/
export async function consumeCallback<T>(
promiseFn: () => Promise<T> | T | void,
wait: boolean
): Promise<void> {
if (wait === true) {
// Clear config since callbacks are not part of the root run
// Avoid using global singleton due to circuluar dependency issues
if (getGlobalAsyncLocalStorageInstance() !== undefined) {
await getGlobalAsyncLocalStorageInstance().run(undefined, async () =>
promiseFn()
);
} else {
await promiseFn();
}
} else {
queue = getQueue();
void queue.add(async () => {
if (getGlobalAsyncLocalStorageInstance() !== undefined) {
await getGlobalAsyncLocalStorageInstance().run(undefined, async () =>
promiseFn()
);
} else {
await promiseFn();
}
});
}
}

/**
* Waits for all promises in the queue to resolve. If the queue is
* undefined, it immediately resolves a promise.
*/
export function awaitAllCallbacks(): Promise<void> {
return typeof queue !== "undefined" ? queue.onIdle() : Promise.resolve();
}
Loading

0 comments on commit 831f9de

Please sign in to comment.