Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): Clear inherited config for called callbacks #7174

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 2 additions & 42 deletions langchain-core/src/callbacks/promises.ts
Original file line number Diff line number Diff line change
@@ -1,43 +1,3 @@
import PQueueMod from "p-queue";
import { awaitAllCallbacks, consumeCallback } from "../singletons/callbacks.js";

let queue: typeof import("p-queue")["default"]["prototype"];

/**
* Creates a queue using the p-queue library. The queue is configured to
* auto-start and has a concurrency of 1, meaning it will process tasks
* one at a time.
*/
function createQueue() {
const PQueue = "default" in PQueueMod ? PQueueMod.default : PQueueMod;
return new PQueue({
autoStart: true,
concurrency: 1,
});
}

/**
* Consume a promise, either adding it to the queue or waiting for it to resolve
* @param promiseFn Promise to consume
* @param wait Whether to wait for the promise to resolve or resolve immediately
*/
export async function consumeCallback<T>(
promiseFn: () => Promise<T> | T | void,
wait: boolean
): Promise<void> {
if (wait === true) {
await promiseFn();
} else {
if (typeof queue === "undefined") {
queue = createQueue();
}
void queue.add(promiseFn);
}
}

/**
* Waits for all promises in the queue to resolve. If the queue is
* undefined, it immediately resolves a promise.
*/
export function awaitAllCallbacks(): Promise<void> {
return typeof queue !== "undefined" ? queue.onIdle() : Promise.resolve();
}
export { awaitAllCallbacks, consumeCallback };
38 changes: 38 additions & 0 deletions langchain-core/src/callbacks/tests/callbacks.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable no-promise-executor-return */
import { test, expect } from "@jest/globals";
import * as uuid from "uuid";
import { AsyncLocalStorage } from "node:async_hooks";
import { CallbackManager } from "../manager.js";
import { BaseCallbackHandler, type BaseCallbackHandlerInput } from "../base.js";
import type { Serialized } from "../../load/serializable.js";
Expand All @@ -10,6 +11,8 @@ import type { AgentAction, AgentFinish } from "../../agents.js";
import { BaseMessage, HumanMessage } from "../../messages/index.js";
import type { LLMResult } from "../../outputs.js";
import { RunnableLambda } from "../../runnables/base.js";
import { AsyncLocalStorageProviderSingleton } from "../../singletons/index.js";
import { awaitAllCallbacks } from "../promises.js";

class FakeCallbackHandler extends BaseCallbackHandler {
name = `fake-${uuid.v4()}`;
Expand Down Expand Up @@ -536,3 +539,38 @@ test("chain should still run if a normal callback handler throws an error", asyn
);
expect(res).toEqual("hello world");
});

test("runnables in callbacks should be root runs", async () => {
AsyncLocalStorageProviderSingleton.initializeGlobalInstance(
new AsyncLocalStorage()
);
const nestedChain = RunnableLambda.from(async () => {
const subRun = RunnableLambda.from(async () => "hello world");
return await subRun.invoke({ foo: "bar" });
});
let error;
let finalInputs;
const res = await nestedChain.invoke(
{},
{
callbacks: [
{
handleChainStart: (_chain, inputs) => {
finalInputs = inputs;
try {
expect(
AsyncLocalStorageProviderSingleton.getRunnableConfig()
).toEqual(undefined);
} catch (e) {
error = e;
}
},
},
],
}
);
await awaitAllCallbacks();
expect(res).toEqual("hello world");
expect(error).toBe(undefined);
expect(finalInputs).toEqual({ foo: "bar" });
});
20 changes: 20 additions & 0 deletions langchain-core/src/singletons/async_local_storage/globals.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
export interface AsyncLocalStorageInterface {
getStore: () => any | undefined;

run: <T>(store: any, callback: () => T) => T;

enterWith: (store: any) => void;
}

export const TRACING_ALS_KEY = Symbol.for("ls:tracing_async_local_storage");

export const setGlobalAsyncLocalStorageInstance = (
instance: AsyncLocalStorageInterface
) => {
(globalThis as any)[TRACING_ALS_KEY] = instance;
};

export const getGlobalAsyncLocalStorageInstance = () => {
return (globalThis as any)[TRACING_ALS_KEY];
};
98 changes: 98 additions & 0 deletions langchain-core/src/singletons/async_local_storage/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { RunTree } from "langsmith";
import {
AsyncLocalStorageInterface,
getGlobalAsyncLocalStorageInstance,
setGlobalAsyncLocalStorageInstance,
} from "./globals.js";
import { CallbackManager } from "../../callbacks/manager.js";
import { LangChainTracer } from "../../tracers/tracer_langchain.js";

export class MockAsyncLocalStorage implements AsyncLocalStorageInterface {
getStore(): any {
return undefined;
}

run<T>(_store: any, callback: () => T): T {
return callback();
}

enterWith(_store: any) {
return undefined;
}
}

const mockAsyncLocalStorage = new MockAsyncLocalStorage();

const LC_CHILD_KEY = Symbol.for("lc:child_config");

export const _CONTEXT_VARIABLES_KEY = Symbol.for("lc:context_variables");

class AsyncLocalStorageProvider {
getInstance(): AsyncLocalStorageInterface {
return getGlobalAsyncLocalStorageInstance() ?? mockAsyncLocalStorage;
}

getRunnableConfig() {
const storage = this.getInstance();
// this has the runnable config
// which means that we should also have an instance of a LangChainTracer
// with the run map prepopulated
return storage.getStore()?.extra?.[LC_CHILD_KEY];
}

runWithConfig<T>(
config: any,
callback: () => T,
avoidCreatingRootRunTree?: boolean
): T {
const callbackManager = CallbackManager._configureSync(
config?.callbacks,
undefined,
config?.tags,
undefined,
config?.metadata
);
const storage = this.getInstance();
const previousValue = storage.getStore();
const parentRunId = callbackManager?.getParentRunId();

const langChainTracer = callbackManager?.handlers?.find(
(handler) => handler?.name === "langchain_tracer"
) as LangChainTracer | undefined;

let runTree;
if (langChainTracer && parentRunId) {
runTree = langChainTracer.convertToRunTree(parentRunId);
} else if (!avoidCreatingRootRunTree) {
runTree = new RunTree({
name: "<runnable_lambda>",
tracingEnabled: false,
});
}

if (runTree) {
runTree.extra = { ...runTree.extra, [LC_CHILD_KEY]: config };
}

if (
previousValue !== undefined &&
previousValue[_CONTEXT_VARIABLES_KEY] !== undefined
) {
(runTree as any)[_CONTEXT_VARIABLES_KEY] =
previousValue[_CONTEXT_VARIABLES_KEY];
}

return storage.run(runTree, callback);
}

initializeGlobalInstance(instance: AsyncLocalStorageInterface) {
if (getGlobalAsyncLocalStorageInstance() === undefined) {
setGlobalAsyncLocalStorageInstance(instance);
}
}
}

const AsyncLocalStorageProviderSingleton = new AsyncLocalStorageProvider();

export { AsyncLocalStorageProviderSingleton, type AsyncLocalStorageInterface };
67 changes: 67 additions & 0 deletions langchain-core/src/singletons/callbacks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import PQueueMod from "p-queue";
import { getGlobalAsyncLocalStorageInstance } from "./async_local_storage/globals.js";

let queue: typeof import("p-queue")["default"]["prototype"];

/**
* Creates a queue using the p-queue library. The queue is configured to
* auto-start and has a concurrency of 1, meaning it will process tasks
* one at a time.
*/
function createQueue() {
const PQueue: any = "default" in PQueueMod ? PQueueMod.default : PQueueMod;
return new PQueue({
autoStart: true,
concurrency: 1,
});
}

export function getQueue() {
if (typeof queue === "undefined") {
queue = createQueue();
}
return queue;
}

/**
* Consume a promise, either adding it to the queue or waiting for it to resolve
* @param promiseFn Promise to consume
* @param wait Whether to wait for the promise to resolve or resolve immediately
*/
export async function consumeCallback<T>(
promiseFn: () => Promise<T> | T | void,
wait: boolean
): Promise<void> {
if (wait === true) {
// Clear config since callbacks are not part of the root run
// Avoid using global singleton due to circuluar dependency issues
if (getGlobalAsyncLocalStorageInstance() !== undefined) {
await getGlobalAsyncLocalStorageInstance().run(undefined, async () =>
promiseFn()
);
} else {
await promiseFn();
}
} else {
queue = getQueue();
void queue.add(async () => {
if (getGlobalAsyncLocalStorageInstance() !== undefined) {
await getGlobalAsyncLocalStorageInstance().run(undefined, async () =>
promiseFn()
);
} else {
await promiseFn();
}
});
}
}

/**
* Waits for all promises in the queue to resolve. If the queue is
* undefined, it immediately resolves a promise.
*/
export function awaitAllCallbacks(): Promise<void> {
return typeof queue !== "undefined" ? queue.onIdle() : Promise.resolve();
}
Loading
Loading