From 5daa8eec93e22f6b7978e74525a4cc31a337f12f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 2 Aug 2024 16:44:16 -0700 Subject: [PATCH] core: Add signal/timeout options to RunnableConfig (#6305) * core: Add signal/timeout options to RunnableConfig - Handled by all built-in runnables - Handled by all utility methods in base runnable, which should propagate to basically all runnables * Lint * Fix build * Refactor race logic into a util * Relax typing in tests * Fix types * Formatting * Fix type * Fix type * Fix runnable map, start adding tests * Adds test * More robust fix * Ignore thrown errors after aborting a signal * Adds test cases, fix streaming for generators * Remove redundant test * Fix --------- Co-authored-by: jacoblee93 --- langchain-core/src/language_models/base.ts | 12 -- .../src/language_models/chat_models.ts | 19 ++- langchain-core/src/language_models/llms.ts | 19 ++- langchain-core/src/runnables/base.ts | 51 ++++-- langchain-core/src/runnables/config.ts | 35 ++++ langchain-core/src/runnables/remote.ts | 91 +++++++---- .../tests/runnable_stream_events_v2.test.ts | 21 +++ .../src/runnables/tests/signal.test.ts | 154 ++++++++++++++++++ langchain-core/src/runnables/types.ts | 12 ++ langchain-core/src/utils/signal.ts | 26 +++ langchain-core/src/utils/stream.ts | 20 ++- langchain-core/src/utils/testing/index.ts | 15 +- .../src/integration_tests/chat_models.ts | 123 +++++--------- 13 files changed, 436 insertions(+), 162 deletions(-) create mode 100644 langchain-core/src/runnables/tests/signal.test.ts create mode 100644 langchain-core/src/utils/signal.ts diff --git a/langchain-core/src/language_models/base.ts b/langchain-core/src/language_models/base.ts index 8adea3e83b4e..0e8af1bc32bf 100644 --- a/langchain-core/src/language_models/base.ts +++ b/langchain-core/src/language_models/base.ts @@ -207,18 +207,6 @@ export interface BaseLanguageModelCallOptions extends RunnableConfig { * If not provided, the default stop tokens for the model will be used. */ stop?: string[]; - - /** - * Timeout for this call in milliseconds. - */ - timeout?: number; - - /** - * Abort signal for this call. - * If provided, the call will be aborted when the signal is aborted. - * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal - */ - signal?: AbortSignal; } export interface FunctionDefinition { diff --git a/langchain-core/src/language_models/chat_models.ts b/langchain-core/src/language_models/chat_models.ts index 89eca0b0329b..bb952e43d918 100644 --- a/langchain-core/src/language_models/chat_models.ts +++ b/langchain-core/src/language_models/chat_models.ts @@ -132,9 +132,10 @@ export abstract class BaseChatModel< // TODO: Fix the parameter order on the next minor version. OutputMessageType extends BaseMessageChunk = BaseMessageChunk > extends BaseLanguageModel { + // Backwards compatibility since fields have been moved to RunnableConfig declare ParsedCallOptions: Omit< CallOptions, - keyof RunnableConfig & "timeout" + Exclude >; // Only ever instantiated in main LangChain @@ -148,14 +149,13 @@ export abstract class BaseChatModel< ...llmOutputs: LLMResult["llmOutput"][] ): LLMResult["llmOutput"]; - protected _separateRunnableConfigFromCallOptions( + protected _separateRunnableConfigFromCallOptionsCompat( options?: Partial ): [RunnableConfig, this["ParsedCallOptions"]] { + // For backwards compat, keep `signal` in both runnableConfig and callOptions const [runnableConfig, callOptions] = super._separateRunnableConfigFromCallOptions(options); - if (callOptions?.timeout && !callOptions.signal) { - callOptions.signal = AbortSignal.timeout(callOptions.timeout); - } + (callOptions as this["ParsedCallOptions"]).signal = runnableConfig.signal; return [runnableConfig, callOptions as this["ParsedCallOptions"]]; } @@ -221,7 +221,7 @@ export abstract class BaseChatModel< const prompt = BaseChatModel._convertInputToPromptValue(input); const messages = prompt.toChatMessages(); const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptions(options); + this._separateRunnableConfigFromCallOptionsCompat(options); const inheritableMetadata = { ...runnableConfig.metadata, @@ -572,7 +572,7 @@ export abstract class BaseChatModel< ); const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptions(parsedOptions); + this._separateRunnableConfigFromCallOptionsCompat(parsedOptions); runnableConfig.callbacks = runnableConfig.callbacks ?? callbacks; if (!this.cache) { @@ -580,8 +580,9 @@ export abstract class BaseChatModel< } const { cache } = this; - const llmStringKey = - this._getSerializedCacheKeyParametersForCall(callOptions); + const llmStringKey = this._getSerializedCacheKeyParametersForCall( + callOptions as CallOptions + ); const { generations, missingPromptIndices } = await this._generateCached({ messages: baseMessages, diff --git a/langchain-core/src/language_models/llms.ts b/langchain-core/src/language_models/llms.ts index ef990198e1df..3aeb2a879bdc 100644 --- a/langchain-core/src/language_models/llms.ts +++ b/langchain-core/src/language_models/llms.ts @@ -49,9 +49,10 @@ export interface BaseLLMCallOptions extends BaseLanguageModelCallOptions {} export abstract class BaseLLM< CallOptions extends BaseLLMCallOptions = BaseLLMCallOptions > extends BaseLanguageModel { + // Backwards compatibility since fields have been moved to RunnableConfig declare ParsedCallOptions: Omit< CallOptions, - keyof RunnableConfig & "timeout" + Exclude >; // Only ever instantiated in main LangChain @@ -91,14 +92,13 @@ export abstract class BaseLLM< throw new Error("Not implemented."); } - protected _separateRunnableConfigFromCallOptions( + protected _separateRunnableConfigFromCallOptionsCompat( options?: Partial ): [RunnableConfig, this["ParsedCallOptions"]] { + // For backwards compat, keep `signal` in both runnableConfig and callOptions const [runnableConfig, callOptions] = super._separateRunnableConfigFromCallOptions(options); - if (callOptions?.timeout && !callOptions.signal) { - callOptions.signal = AbortSignal.timeout(callOptions.timeout); - } + (callOptions as this["ParsedCallOptions"]).signal = runnableConfig.signal; return [runnableConfig, callOptions as this["ParsedCallOptions"]]; } @@ -114,7 +114,7 @@ export abstract class BaseLLM< } else { const prompt = BaseLLM._convertInputToPromptValue(input); const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptions(options); + this._separateRunnableConfigFromCallOptionsCompat(options); const callbackManager_ = await CallbackManager.configure( runnableConfig.callbacks, this.callbacks, @@ -455,7 +455,7 @@ export abstract class BaseLLM< } const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptions(parsedOptions); + this._separateRunnableConfigFromCallOptionsCompat(parsedOptions); runnableConfig.callbacks = runnableConfig.callbacks ?? callbacks; if (!this.cache) { @@ -463,8 +463,9 @@ export abstract class BaseLLM< } const { cache } = this; - const llmStringKey = - this._getSerializedCacheKeyParametersForCall(callOptions); + const llmStringKey = this._getSerializedCacheKeyParametersForCall( + callOptions as CallOptions + ); const { generations, missingPromptIndices } = await this._generateCached({ prompts, cache, diff --git a/langchain-core/src/runnables/base.ts b/langchain-core/src/runnables/base.ts index 4cb65f379f22..7d1f836ab4df 100644 --- a/langchain-core/src/runnables/base.ts +++ b/langchain-core/src/runnables/base.ts @@ -30,6 +30,7 @@ import { pipeGeneratorWithSetup, AsyncGeneratorWithSetup, } from "../utils/stream.js"; +import { raceWithSignal } from "../utils/signal.js"; import { DEFAULT_RECURSION_LIMIT, RunnableConfig, @@ -339,6 +340,8 @@ export abstract class Runnable< recursionLimit: options.recursionLimit, maxConcurrency: options.maxConcurrency, runId: options.runId, + timeout: options.timeout, + signal: options.signal, }); } const callOptions = { ...(options as Partial) }; @@ -350,6 +353,8 @@ export abstract class Runnable< delete callOptions.recursionLimit; delete callOptions.maxConcurrency; delete callOptions.runId; + delete callOptions.timeout; + delete callOptions.signal; return [runnableConfig, callOptions]; } @@ -378,7 +383,8 @@ export abstract class Runnable< delete config.runId; let output; try { - output = await func.call(this, input, config, runManager); + const promise = func.call(this, input, config, runManager); + output = await raceWithSignal(promise, options?.signal); } catch (e) { await runManager?.handleChainError(e); throw e; @@ -430,13 +436,14 @@ export abstract class Runnable< ); let outputs: (RunOutput | Error)[]; try { - outputs = await func.call( + const promise = func.call( this, inputs, optionsList, runManagers, batchOptions ); + outputs = await raceWithSignal(promise, optionsList?.[0]?.signal); } catch (e) { await Promise.all( runManagers.map((runManager) => runManager?.handleChainError(e)) @@ -509,6 +516,7 @@ export abstract class Runnable< undefined, config.runName ?? this.getName() ), + options?.signal, config ); delete config.runId; @@ -1750,14 +1758,18 @@ export class RunnableSequence< const initialSteps = [this.first, ...this.middle]; for (let i = 0; i < initialSteps.length; i += 1) { const step = initialSteps[i]; - nextStepInput = await step.invoke( + const promise = step.invoke( nextStepInput, patchConfig(config, { callbacks: runManager?.getChild(`seq:step:${i + 1}`), }) ); + nextStepInput = await raceWithSignal(promise, options?.signal); } // TypeScript can't detect that the last output of the sequence returns RunOutput, so call it out of the loop here + if (options?.signal?.aborted) { + throw new Error("Aborted"); + } finalOutput = await this.last.invoke( nextStepInput, patchConfig(config, { @@ -1819,7 +1831,7 @@ export class RunnableSequence< try { for (let i = 0; i < this.steps.length; i += 1) { const step = this.steps[i]; - nextStepInputs = await step.batch( + const promise = step.batch( nextStepInputs, runManagers.map((runManager, j) => { const childRunManager = runManager?.getChild(`seq:step:${i + 1}`); @@ -1827,6 +1839,7 @@ export class RunnableSequence< }), batchOptions ); + nextStepInputs = await raceWithSignal(promise, configList[0]?.signal); } } catch (e) { await Promise.all( @@ -1880,6 +1893,7 @@ export class RunnableSequence< ); } for await (const chunk of finalGenerator) { + options?.signal?.throwIfAborted(); yield chunk; if (concatSupported) { if (finalOutput === undefined) { @@ -2058,16 +2072,17 @@ export class RunnableMap< // eslint-disable-next-line @typescript-eslint/no-explicit-any const output: Record = {}; try { - await Promise.all( - Object.entries(this.steps).map(async ([key, runnable]) => { + const promises = Object.entries(this.steps).map( + async ([key, runnable]) => { output[key] = await runnable.invoke( input, patchConfig(config, { callbacks: runManager?.getChild(`map:key:${key}`), }) ); - }) + } ); + await raceWithSignal(Promise.all(promises), options?.signal); } catch (e) { await runManager?.handleChainError(e); throw e; @@ -2101,7 +2116,11 @@ export class RunnableMap< // starting new iterations as needed, // until all iterators are done while (tasks.size) { - const { key, result, gen } = await Promise.race(tasks.values()); + const promise = Promise.race(tasks.values()); + const { key, result, gen } = await raceWithSignal( + promise, + options?.signal + ); tasks.delete(key); if (!result.done) { yield { [key]: result.value } as unknown as RunOutput; @@ -2172,21 +2191,24 @@ export class RunnableTraceable extends Runnable< async invoke(input: RunInput, options?: Partial) { const [config] = this._getOptionsList(options ?? {}, 1); const callbacks = await getCallbackManagerForConfig(config); - - return (await this.func( + const promise = this.func( patchConfig(config, { callbacks }), input - )) as RunOutput; + ) as Promise; + + return raceWithSignal(promise, config?.signal); } async *_streamIterator( input: RunInput, options?: Partial ): AsyncGenerator { + const [config] = this._getOptionsList(options ?? {}, 1); const result = await this.invoke(input, options); if (isAsyncIterable(result)) { for await (const item of result) { + config?.signal?.throwIfAborted(); yield item as RunOutput; } return; @@ -2194,6 +2216,7 @@ export class RunnableTraceable extends Runnable< if (isIterator(result)) { while (true) { + config?.signal?.throwIfAborted(); const state: IteratorResult = result.next(); if (state.done) break; yield state.value as RunOutput; @@ -2320,6 +2343,7 @@ export class RunnableLambda extends Runnable< childConfig, output )) { + config?.signal?.throwIfAborted(); if (finalOutput === undefined) { finalOutput = chunk as RunOutput; } else { @@ -2339,6 +2363,7 @@ export class RunnableLambda extends Runnable< childConfig, output )) { + config?.signal?.throwIfAborted(); if (finalOutput === undefined) { finalOutput = chunk as RunOutput; } else { @@ -2423,10 +2448,12 @@ export class RunnableLambda extends Runnable< childConfig, output )) { + config?.signal?.throwIfAborted(); yield chunk as RunOutput; } } else if (isIterableIterator(output)) { for (const chunk of consumeIteratorInContext(childConfig, output)) { + config?.signal?.throwIfAborted(); yield chunk as RunOutput; } } else { @@ -2517,6 +2544,7 @@ export class RunnableWithFallbacks extends Runnable< ); let firstError; for (const runnable of this.runnables()) { + config?.signal?.throwIfAborted(); try { const output = await runnable.invoke( input, @@ -2586,6 +2614,7 @@ export class RunnableWithFallbacks extends Runnable< // eslint-disable-next-line @typescript-eslint/no-explicit-any let firstError: any; for (const runnable of this.runnables()) { + configList[0].signal?.throwIfAborted(); try { const outputs = await runnable.batch( inputs, diff --git a/langchain-core/src/runnables/config.ts b/langchain-core/src/runnables/config.ts index 409d556eac8d..8fa9a244ee3d 100644 --- a/langchain-core/src/runnables/config.ts +++ b/langchain-core/src/runnables/config.ts @@ -31,6 +31,26 @@ export function mergeConfigs( copy[key] = [...new Set(baseKeys.concat(options[key] ?? []))]; } else if (key === "configurable") { copy[key] = { ...copy[key], ...options[key] }; + } else if (key === "timeout") { + if (copy.timeout === undefined) { + copy.timeout = options.timeout; + } else if (options.timeout !== undefined) { + copy.timeout = Math.min(copy.timeout, options.timeout); + } + } else if (key === "signal") { + if (copy.signal === undefined) { + copy.signal = options.signal; + } else if (options.signal !== undefined) { + if ("any" in AbortSignal) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + copy.signal = (AbortSignal as any).any([ + copy.signal, + options.signal, + ]); + } else { + copy.signal = options.signal; + } + } } else if (key === "callbacks") { const baseCallbacks = copy.callbacks; const providedCallbacks = options.callbacks; @@ -155,6 +175,21 @@ export function ensureConfig( } } } + if (empty.timeout !== undefined) { + if (empty.timeout <= 0) { + throw new Error("Timeout must be a positive number"); + } + const timeoutSignal = AbortSignal.timeout(empty.timeout); + if (empty.signal !== undefined) { + if ("any" in AbortSignal) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + empty.signal = (AbortSignal as any).any([empty.signal, timeoutSignal]); + } + } else { + empty.signal = timeoutSignal; + } + delete empty.timeout; + } return empty as CallOptions; } diff --git a/langchain-core/src/runnables/remote.ts b/langchain-core/src/runnables/remote.ts index 9ecd597556f9..dc08c731a501 100644 --- a/langchain-core/src/runnables/remote.ts +++ b/langchain-core/src/runnables/remote.ts @@ -214,11 +214,12 @@ function deserialize(str: string): RunOutput { return revive(obj); } -function removeCallbacks( +function removeCallbacksAndSignal( options?: RunnableConfig -): Omit { +): Omit { const rest = { ...options }; delete rest.callbacks; + delete rest.signal; return rest; } @@ -276,7 +277,7 @@ export class RemoteRunnable< this.options = options; } - private async post(path: string, body: Body) { + private async post(path: string, body: Body, signal?: AbortSignal) { return fetch(`${this.url}${path}`, { method: "POST", body: JSON.stringify(serialize(body)), @@ -284,7 +285,7 @@ export class RemoteRunnable< "Content-Type": "application/json", ...this.options?.headers, }, - signal: AbortSignal.timeout(this.options?.timeout ?? 60000), + signal: signal ?? AbortSignal.timeout(this.options?.timeout ?? 60000), }); } @@ -299,11 +300,15 @@ export class RemoteRunnable< input: RunInput; config?: RunnableConfig; kwargs?: Omit, keyof RunnableConfig>; - }>("/invoke", { - input, - config: removeCallbacks(config), - kwargs: kwargs ?? {}, - }); + }>( + "/invoke", + { + input, + config: removeCallbacksAndSignal(config), + kwargs: kwargs ?? {}, + }, + config.signal + ); if (!response.ok) { throw new Error(`${response.status} Error: ${await response.text()}`); } @@ -347,13 +352,17 @@ export class RemoteRunnable< inputs: RunInput[]; config?: (RunnableConfig & RunnableBatchOptions)[]; kwargs?: Omit, keyof RunnableConfig>[]; - }>("/batch", { - inputs, - config: (configs ?? []) - .map(removeCallbacks) - .map((config) => ({ ...config, ...batchOptions })), - kwargs, - }); + }>( + "/batch", + { + inputs, + config: (configs ?? []) + .map(removeCallbacksAndSignal) + .map((config) => ({ ...config, ...batchOptions })), + kwargs, + }, + options?.[0]?.signal + ); if (!response.ok) { throw new Error(`${response.status} Error: ${await response.text()}`); } @@ -422,11 +431,15 @@ export class RemoteRunnable< input: RunInput; config?: RunnableConfig; kwargs?: Omit, keyof RunnableConfig>; - }>("/stream", { - input, - config: removeCallbacks(config), - kwargs, - }); + }>( + "/stream", + { + input, + config: removeCallbacksAndSignal(config), + kwargs, + }, + config.signal + ); if (!response.ok) { const json = await response.json(); const error = new Error( @@ -502,13 +515,17 @@ export class RemoteRunnable< config?: RunnableConfig; kwargs?: Omit, keyof RunnableConfig>; diff: false; - }>("/stream_log", { - input, - config: removeCallbacks(config), - kwargs, - ...camelCaseStreamOptions, - diff: false, - }); + }>( + "/stream_log", + { + input, + config: removeCallbacksAndSignal(config), + kwargs, + ...camelCaseStreamOptions, + diff: false, + }, + config.signal + ); const { body, ok } = response; if (!ok) { throw new Error(`${response.status} Error: ${await response.text()}`); @@ -574,13 +591,17 @@ export class RemoteRunnable< config?: RunnableConfig; kwargs?: Omit, keyof RunnableConfig>; diff: false; - }>("/stream_events", { - input, - config: removeCallbacks(config), - kwargs, - ...camelCaseStreamOptions, - diff: false, - }); + }>( + "/stream_events", + { + input, + config: removeCallbacksAndSignal(config), + kwargs, + ...camelCaseStreamOptions, + diff: false, + }, + config.signal + ); const { body, ok } = response; if (!ok) { throw new Error(`${response.status} Error: ${await response.text()}`); diff --git a/langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts b/langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts index 17a39f9dde29..d1fc1ea7fc65 100644 --- a/langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts +++ b/langchain-core/src/runnables/tests/runnable_stream_events_v2.test.ts @@ -2120,3 +2120,24 @@ test("Runnable streamEvents method with text/event-stream encoding", async () => expect(decoder.decode(events[3])).toEqual("event: end\n\n"); }); + +test("Runnable streamEvents method should respect passed signal", async () => { + const r = RunnableLambda.from(reverse); + + const chain = r + .withConfig({ runName: "1" }) + .pipe(r.withConfig({ runName: "2" })) + .pipe(r.withConfig({ runName: "3" })); + + const controller = new AbortController(); + const eventStream = await chain.streamEvents("hello", { + version: "v2", + signal: controller.signal, + }); + await expect(async () => { + for await (const _ of eventStream) { + // Abort after the first chunk + controller.abort(); + } + }).rejects.toThrowError(); +}); diff --git a/langchain-core/src/runnables/tests/signal.test.ts b/langchain-core/src/runnables/tests/signal.test.ts new file mode 100644 index 000000000000..7413ea3794cf --- /dev/null +++ b/langchain-core/src/runnables/tests/signal.test.ts @@ -0,0 +1,154 @@ +/* eslint-disable no-promise-executor-return */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import { test, describe, expect } from "@jest/globals"; +import { + Runnable, + RunnableLambda, + RunnableMap, + RunnablePassthrough, + RunnableSequence, + RunnableWithMessageHistory, +} from "../index.js"; +import { + FakeChatMessageHistory, + FakeListChatModel, +} from "../../utils/testing/index.js"; + +const chatModel = new FakeListChatModel({ responses: ["hey"], sleep: 500 }); + +const TEST_CASES = { + map: { + runnable: RunnableMap.from({ + question: new RunnablePassthrough(), + context: async () => { + await new Promise((resolve) => setTimeout(resolve, 500)); + return "SOME STUFF"; + }, + }), + input: "testing", + }, + binding: { + runnable: RunnableLambda.from( + () => new Promise((resolve) => setTimeout(resolve, 500)) + ), + input: "testing", + }, + fallbacks: { + runnable: chatModel + .bind({ thrownErrorString: "expected" }) + .withFallbacks({ fallbacks: [chatModel] }), + input: "testing", + skipStream: true, + }, + sequence: { + runnable: RunnableSequence.from([ + RunnablePassthrough.assign({ + test: () => chatModel, + }), + () => {}, + ]), + input: { question: "testing" }, + }, + lambda: { + runnable: RunnableLambda.from( + () => new Promise((resolve) => setTimeout(resolve, 500)) + ), + input: {}, + }, + history: { + runnable: new RunnableWithMessageHistory({ + runnable: chatModel, + config: {}, + getMessageHistory: () => new FakeChatMessageHistory(), + }), + input: "testing", + }, +}; + +describe.each(Object.keys(TEST_CASES))("Test runnable %s", (name) => { + const { + runnable, + input, + skipStream, + }: { runnable: Runnable; input: any; skipStream?: boolean } = + TEST_CASES[name as keyof typeof TEST_CASES]; + test("Test invoke with signal", async () => { + await expect(async () => { + const controller = new AbortController(); + await Promise.all([ + runnable.invoke(input, { + signal: controller.signal, + }), + new Promise((resolve) => { + controller.abort(); + resolve(); + }), + ]); + }).rejects.toThrowError(); + }); + + test("Test invoke with signal with a delay", async () => { + await expect(async () => { + const controller = new AbortController(); + await Promise.all([ + runnable.invoke(input, { + signal: controller.signal, + }), + new Promise((resolve) => { + setTimeout(() => { + controller.abort(); + resolve(); + }, 250); + }), + ]); + }).rejects.toThrowError(); + }); + + test("Test stream with signal", async () => { + if (skipStream) { + return; + } + const controller = new AbortController(); + await expect(async () => { + const stream = await runnable.stream(input, { + signal: controller.signal, + }); + for await (const _ of stream) { + controller.abort(); + } + }).rejects.toThrowError(); + }); + + test("Test batch with signal", async () => { + await expect(async () => { + const controller = new AbortController(); + await Promise.all([ + runnable.batch([input, input], { + signal: controller.signal, + }), + new Promise((resolve) => { + controller.abort(); + resolve(); + }), + ]); + }).rejects.toThrowError(); + }); + + test("Test batch with signal with a delay", async () => { + await expect(async () => { + const controller = new AbortController(); + await Promise.all([ + runnable.batch([input, input], { + signal: controller.signal, + }), + new Promise((resolve) => { + setTimeout(() => { + controller.abort(); + resolve(); + }, 250); + }), + ]); + }).rejects.toThrowError(); + }); +}); diff --git a/langchain-core/src/runnables/types.ts b/langchain-core/src/runnables/types.ts index 569e8aa26c0e..e7ddfa8c3852 100644 --- a/langchain-core/src/runnables/types.ts +++ b/langchain-core/src/runnables/types.ts @@ -89,4 +89,16 @@ export interface RunnableConfig extends BaseCallbackConfig { /** Maximum number of parallel calls to make. */ maxConcurrency?: number; + + /** + * Timeout for this call in milliseconds. + */ + timeout?: number; + + /** + * Abort signal for this call. + * If provided, the call will be aborted when the signal is aborted. + * @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal + */ + signal?: AbortSignal; } diff --git a/langchain-core/src/utils/signal.ts b/langchain-core/src/utils/signal.ts new file mode 100644 index 000000000000..7ccb554429cf --- /dev/null +++ b/langchain-core/src/utils/signal.ts @@ -0,0 +1,26 @@ +export async function raceWithSignal( + promise: Promise, + signal?: AbortSignal +): Promise { + if (signal === undefined) { + return promise; + } + return Promise.race([ + promise.catch((err) => { + if (!signal?.aborted) { + throw err; + } else { + return undefined as T; + } + }), + new Promise((_, reject) => { + signal.addEventListener("abort", () => { + reject(new Error("Aborted")); + }); + // Must be here inside the promise to avoid a race condition + if (signal.aborted) { + reject(new Error("Aborted")); + } + }), + ]); +} diff --git a/langchain-core/src/utils/stream.ts b/langchain-core/src/utils/stream.ts index 234cec3b900f..91a9810e2d25 100644 --- a/langchain-core/src/utils/stream.ts +++ b/langchain-core/src/utils/stream.ts @@ -1,4 +1,5 @@ import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js"; +import { raceWithSignal } from "./signal.js"; // Make this a type to override ReadableStream's async iterator type in case // the popular web-streams-polyfill is imported - the supplied types @@ -186,6 +187,8 @@ export class AsyncGeneratorWithSetup< public config?: unknown; + public signal?: AbortSignal; + private firstResult: Promise>; private firstResultUsed = false; @@ -194,9 +197,12 @@ export class AsyncGeneratorWithSetup< generator: AsyncGenerator; startSetup?: () => Promise; config?: unknown; + signal?: AbortSignal; }) { this.generator = params.generator; this.config = params.config; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + this.signal = params.signal ?? (this.config as any)?.signal; // setup is a promise that resolves only after the first iterator value // is available. this is useful when setup of several piped generators // needs to happen in logical order, ie. in the order in which input to @@ -218,6 +224,8 @@ export class AsyncGeneratorWithSetup< } async next(...args: [] | [TNext]): Promise> { + this.signal?.throwIfAborted(); + if (!this.firstResultUsed) { this.firstResultUsed = true; return this.firstResult; @@ -225,9 +233,13 @@ export class AsyncGeneratorWithSetup< return AsyncLocalStorageProviderSingleton.runWithConfig( this.config, - async () => { - return this.generator.next(...args); - }, + this.signal + ? async () => { + return raceWithSignal(this.generator.next(...args), this.signal); + } + : async () => { + return this.generator.next(...args); + }, true ); } @@ -264,11 +276,13 @@ export async function pipeGeneratorWithSetup< ) => AsyncGenerator, generator: AsyncGenerator, startSetup: () => Promise, + signal: AbortSignal | undefined, ...args: A ) { const gen = new AsyncGeneratorWithSetup({ generator, startSetup, + signal, }); const setup = await gen.setup; return { output: to(gen, setup, ...args), setup }; diff --git a/langchain-core/src/utils/testing/index.ts b/langchain-core/src/utils/testing/index.ts index 65d197f6c23e..f14629794293 100644 --- a/langchain-core/src/utils/testing/index.ts +++ b/langchain-core/src/utils/testing/index.ts @@ -15,6 +15,7 @@ import { import { Document } from "../../documents/document.js"; import { BaseChatModel, + BaseChatModelCallOptions, BaseChatModelParams, } from "../../language_models/chat_models.js"; import { BaseLLMParams, LLM } from "../../language_models/llms.js"; @@ -324,6 +325,10 @@ export interface FakeChatInput extends BaseChatModelParams { emitCustomEvent?: boolean; } +export interface FakeListChatModelCallOptions extends BaseChatModelCallOptions { + thrownErrorString?: string; +} + /** * A fake Chat Model that returns a predefined list of responses. It can be used * for testing purposes. @@ -344,7 +349,7 @@ export interface FakeChatInput extends BaseChatModelParams { * console.log({ secondResponse }); * ``` */ -export class FakeListChatModel extends BaseChatModel { +export class FakeListChatModel extends BaseChatModel { static lc_name() { return "FakeListChatModel"; } @@ -378,6 +383,9 @@ export class FakeListChatModel extends BaseChatModel { runManager?: CallbackManagerForLLMRun ): Promise { await this._sleepIfRequested(); + if (options?.thrownErrorString) { + throw new Error(options.thrownErrorString); + } if (this.emitCustomEvent) { await runManager?.handleCustomEvent("some_test_event", { someval: true, @@ -408,7 +416,7 @@ export class FakeListChatModel extends BaseChatModel { async *_streamResponseChunks( _messages: BaseMessage[], - _options: this["ParsedCallOptions"], + options: this["ParsedCallOptions"], runManager?: CallbackManagerForLLMRun ): AsyncGenerator { const response = this._currentResponse(); @@ -421,6 +429,9 @@ export class FakeListChatModel extends BaseChatModel { for await (const text of response) { await this._sleepIfRequested(); + if (options?.thrownErrorString) { + throw new Error(options.thrownErrorString); + } const chunk = this._createResponseChunk(text); yield chunk; void runManager?.handleLLMNewToken(text); diff --git a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts index 8fff150f1cf5..016a85f2810a 100644 --- a/libs/langchain-standard-tests/src/integration_tests/chat_models.ts +++ b/libs/langchain-standard-tests/src/integration_tests/chat_models.ts @@ -1,3 +1,5 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + import { expect } from "@jest/globals"; import { BaseChatModelCallOptions } from "@langchain/core/language_models/chat_models"; import { @@ -112,12 +114,10 @@ export abstract class ChatModelIntegrationTests< * 1. The result is defined and is an instance of the correct type. * 2. The content of the response is a non-empty string. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testInvoke( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testInvoke(callOptions?: any) { // Create a new instance of the chat model const chatModel = new this.Cls(this.constructorArgs); @@ -147,12 +147,10 @@ export abstract class ChatModelIntegrationTests< * 2. The content of each token is a string. * 3. The total number of characters streamed is greater than zero. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testStream( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testStream(callOptions?: any) { const chatModel = new this.Cls(this.constructorArgs); let numChars = 0; @@ -183,12 +181,10 @@ export abstract class ChatModelIntegrationTests< * 2. The number of results matches the number of inputs. * 3. Each result is of the correct type and has non-empty content. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testBatch( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testBatch(callOptions?: any) { const chatModel = new this.Cls(this.constructorArgs); // Process two simple prompts in batch @@ -229,12 +225,10 @@ export abstract class ChatModelIntegrationTests< * * Finally, it verifies the final chunk's `event.data.output` field * matches the concatenated content of all `on_chat_model_stream` events. - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testStreamEvents( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testStreamEvents(callOptions?: any) { const chatModel = new this.Cls(this.constructorArgs); const stream = chatModel.streamEvents("Hello", { @@ -300,12 +294,10 @@ export abstract class ChatModelIntegrationTests< * 1. The result is defined and is an instance of the correct response type. * 2. The content of the response is a non-empty string. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testConversation( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testConversation(callOptions?: any) { // Create a new instance of the chat model const chatModel = new this.Cls(this.constructorArgs); @@ -343,12 +335,10 @@ export abstract class ChatModelIntegrationTests< * 3. The `usage_metadata` field contains `input_tokens`, `output_tokens`, and `total_tokens`, * all of which are numbers. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testUsageMetadata( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testUsageMetadata(callOptions?: any) { // Create a new instance of the chat model const chatModel = new this.Cls(this.constructorArgs); @@ -393,12 +383,10 @@ export abstract class ChatModelIntegrationTests< * 3. The `usage_metadata` field contains `input_tokens`, `output_tokens`, and `total_tokens`, * all of which are numbers. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testUsageMetadataStreaming( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testUsageMetadataStreaming(callOptions?: any) { const chatModel = new this.Cls(this.constructorArgs); let finalChunks: AIMessageChunk | undefined; @@ -451,12 +439,10 @@ export abstract class ChatModelIntegrationTests< * This test ensures that the model can correctly process and respond to complex message * histories that include tool calls with string-based content structures. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testToolMessageHistoriesStringContent( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testToolMessageHistoriesStringContent(callOptions?: any) { // Skip the test if the model doesn't support tool calling if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping..."); @@ -522,11 +508,9 @@ export abstract class ChatModelIntegrationTests< * This test ensures that the model can correctly process and respond to complex message * histories that include tool calls with list-based content structures. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. */ - async testToolMessageHistoriesListContent( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testToolMessageHistoriesListContent(callOptions?: any) { if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping..."); return; @@ -602,12 +586,10 @@ export abstract class ChatModelIntegrationTests< * the patterns demonstrated in few-shot examples, particularly when those * examples involve tool usage. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testStructuredFewShotExamples( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testStructuredFewShotExamples(callOptions?: any) { // Skip the test if the model doesn't support tool calling if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping..."); @@ -667,12 +649,10 @@ export abstract class ChatModelIntegrationTests< * This test is crucial for ensuring that the model can generate responses * in a specific format, which is useful for tasks requiring structured data output. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testWithStructuredOutput( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testWithStructuredOutput(callOptions?: any) { // Skip the test if the model doesn't support structured output if (!this.chatModelHasStructuredOutput) { console.log("Test requires withStructuredOutput. Skipping..."); @@ -726,12 +706,10 @@ export abstract class ChatModelIntegrationTests< * This test is crucial for ensuring that the model can generate responses in a specific format * while also providing access to the original, unprocessed model output. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testWithStructuredOutputIncludeRaw( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testWithStructuredOutputIncludeRaw(callOptions?: any) { // Skip the test if the model doesn't support structured output if (!this.chatModelHasStructuredOutput) { console.log("Test requires withStructuredOutput. Skipping..."); @@ -788,12 +766,10 @@ export abstract class ChatModelIntegrationTests< * This test is crucial for ensuring compatibility with OpenAI's function * calling format, which is a common standard in AI tool integration. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testBindToolsWithOpenAIFormattedTools( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testBindToolsWithOpenAIFormattedTools(callOptions?: any) { // Skip the test if the model doesn't support tool calling if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping..."); @@ -855,12 +831,10 @@ export abstract class ChatModelIntegrationTests< * from Runnable objects, which provides a flexible way to integrate * custom logic into the model's tool-calling capabilities. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testBindToolsWithRunnableToolLike( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testBindToolsWithRunnableToolLike(callOptions?: any) { // Skip the test if the model doesn't support tool calling if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping..."); @@ -923,12 +897,10 @@ export abstract class ChatModelIntegrationTests< * This test is crucial for ensuring that the caching mechanism works correctly * with various message structures, maintaining consistency and efficiency. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testCacheComplexMessageTypes( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testCacheComplexMessageTypes(callOptions?: any) { // Create a new instance of the chat model with caching enabled const model = new this.Cls({ ...this.constructorArgs, @@ -987,12 +959,10 @@ export abstract class ChatModelIntegrationTests< * 3. The usage metadata is present in the streamed result. * 4. Both input and output tokens are present and greater than zero in the usage metadata. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testStreamTokensWithToolCalls( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testStreamTokensWithToolCalls(callOptions?: any) { const model = new this.Cls(this.constructorArgs); if (!model.bindTools) { throw new Error("bindTools is undefined"); @@ -1053,12 +1023,10 @@ export abstract class ChatModelIntegrationTests< * 5. Send a followup request including the tool call and response. * 6. Verify the model generates a non-empty final response. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testModelCanUseToolUseAIMessage( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testModelCanUseToolUseAIMessage(callOptions?: any) { if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping..."); return; @@ -1147,12 +1115,10 @@ export abstract class ChatModelIntegrationTests< * 5. Stream a followup request including the tool call and response. * 6. Verify the model generates a non-empty final streamed response. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testModelCanUseToolUseAIMessageWithStreaming( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testModelCanUseToolUseAIMessageWithStreaming(callOptions?: any) { if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping..."); return; @@ -1253,12 +1219,10 @@ export abstract class ChatModelIntegrationTests< * This test is particularly important for ensuring compatibility with APIs * that may not accept JSON schemas with unknown object fields (e.g., Google's API). * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * These options will be applied to the model at runtime. */ - async testInvokeMoreComplexTools( - callOptions?: InstanceType["ParsedCallOptions"] - ) { + async testInvokeMoreComplexTools(callOptions?: any) { // Skip the test if the model doesn't support tool calling if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping..."); @@ -1333,13 +1297,10 @@ Extraction path: {extractionPath}`, * It ensures that the model can correctly process and respond to prompts requiring multiple tool calls, * both in streaming and non-streaming contexts, and can handle message histories with parallel tool calls. * - * @param {InstanceType["ParsedCallOptions"] | undefined} callOptions Optional call options to pass to the model. + * @param {any | undefined} callOptions Optional call options to pass to the model. * @param {boolean} onlyVerifyHistory If true, only verifies the message history test. */ - async testParallelToolCalling( - callOptions?: InstanceType["ParsedCallOptions"], - onlyVerifyHistory = false - ) { + async testParallelToolCalling(callOptions?: any, onlyVerifyHistory = false) { // Skip the test if the model doesn't support tool calling if (!this.chatModelHasToolCalling) { console.log("Test requires tool calling. Skipping...");