Skip to content

Commit

Permalink
core: Add signal/timeout options to RunnableConfig (#6305)
Browse files Browse the repository at this point in the history
* core: Add signal/timeout options to RunnableConfig

- Handled by all built-in runnables
- Handled by all utility methods in the base `Runnable` class, so the options should propagate to nearly all runnables

* Lint

* Fix build

* Refactor race logic into a util

* Relax typing in tests

* Fix types

* Formatting

* Fix type

* Fix type

* Fix runnable map, start adding tests

* Adds test

* More robust fix

* Ignore thrown errors after aborting a signal

* Add test cases, fix streaming for generators

* Remove redundant test

* Fix

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
nfcampos and jacoblee93 authored Aug 2, 2024
1 parent 5f4323d commit 5daa8ee
Show file tree
Hide file tree
Showing 13 changed files with 436 additions and 162 deletions.
12 changes: 0 additions & 12 deletions langchain-core/src/language_models/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,6 @@ export interface BaseLanguageModelCallOptions extends RunnableConfig {
* If not provided, the default stop tokens for the model will be used.
*/
stop?: string[];

/**
* Timeout for this call in milliseconds.
*/
timeout?: number;

/**
* Abort signal for this call.
* If provided, the call will be aborted when the signal is aborted.
* @see https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal
*/
signal?: AbortSignal;
}

export interface FunctionDefinition {
Expand Down
19 changes: 10 additions & 9 deletions langchain-core/src/language_models/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ export abstract class BaseChatModel<
// TODO: Fix the parameter order on the next minor version.
OutputMessageType extends BaseMessageChunk = BaseMessageChunk
> extends BaseLanguageModel<OutputMessageType, CallOptions> {
// Backwards compatibility since fields have been moved to RunnableConfig
declare ParsedCallOptions: Omit<
CallOptions,
keyof RunnableConfig & "timeout"
Exclude<keyof RunnableConfig, "signal" | "timeout" | "maxConcurrency">
>;

// Only ever instantiated in main LangChain
Expand All @@ -148,14 +149,13 @@ export abstract class BaseChatModel<
...llmOutputs: LLMResult["llmOutput"][]
): LLMResult["llmOutput"];

protected _separateRunnableConfigFromCallOptions(
protected _separateRunnableConfigFromCallOptionsCompat(
options?: Partial<CallOptions>
): [RunnableConfig, this["ParsedCallOptions"]] {
// For backwards compat, keep `signal` in both runnableConfig and callOptions
const [runnableConfig, callOptions] =
super._separateRunnableConfigFromCallOptions(options);
if (callOptions?.timeout && !callOptions.signal) {
callOptions.signal = AbortSignal.timeout(callOptions.timeout);
}
(callOptions as this["ParsedCallOptions"]).signal = runnableConfig.signal;
return [runnableConfig, callOptions as this["ParsedCallOptions"]];
}

Expand Down Expand Up @@ -221,7 +221,7 @@ export abstract class BaseChatModel<
const prompt = BaseChatModel._convertInputToPromptValue(input);
const messages = prompt.toChatMessages();
const [runnableConfig, callOptions] =
this._separateRunnableConfigFromCallOptions(options);
this._separateRunnableConfigFromCallOptionsCompat(options);

const inheritableMetadata = {
...runnableConfig.metadata,
Expand Down Expand Up @@ -572,16 +572,17 @@ export abstract class BaseChatModel<
);

const [runnableConfig, callOptions] =
this._separateRunnableConfigFromCallOptions(parsedOptions);
this._separateRunnableConfigFromCallOptionsCompat(parsedOptions);
runnableConfig.callbacks = runnableConfig.callbacks ?? callbacks;

if (!this.cache) {
return this._generateUncached(baseMessages, callOptions, runnableConfig);
}

const { cache } = this;
const llmStringKey =
this._getSerializedCacheKeyParametersForCall(callOptions);
const llmStringKey = this._getSerializedCacheKeyParametersForCall(
callOptions as CallOptions
);

const { generations, missingPromptIndices } = await this._generateCached({
messages: baseMessages,
Expand Down
19 changes: 10 additions & 9 deletions langchain-core/src/language_models/llms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ export interface BaseLLMCallOptions extends BaseLanguageModelCallOptions {}
export abstract class BaseLLM<
CallOptions extends BaseLLMCallOptions = BaseLLMCallOptions
> extends BaseLanguageModel<string, CallOptions> {
// Backwards compatibility since fields have been moved to RunnableConfig
declare ParsedCallOptions: Omit<
CallOptions,
keyof RunnableConfig & "timeout"
Exclude<keyof RunnableConfig, "signal" | "timeout" | "maxConcurrency">
>;

// Only ever instantiated in main LangChain
Expand Down Expand Up @@ -91,14 +92,13 @@ export abstract class BaseLLM<
throw new Error("Not implemented.");
}

protected _separateRunnableConfigFromCallOptions(
protected _separateRunnableConfigFromCallOptionsCompat(
options?: Partial<CallOptions>
): [RunnableConfig, this["ParsedCallOptions"]] {
// For backwards compat, keep `signal` in both runnableConfig and callOptions
const [runnableConfig, callOptions] =
super._separateRunnableConfigFromCallOptions(options);
if (callOptions?.timeout && !callOptions.signal) {
callOptions.signal = AbortSignal.timeout(callOptions.timeout);
}
(callOptions as this["ParsedCallOptions"]).signal = runnableConfig.signal;
return [runnableConfig, callOptions as this["ParsedCallOptions"]];
}

Expand All @@ -114,7 +114,7 @@ export abstract class BaseLLM<
} else {
const prompt = BaseLLM._convertInputToPromptValue(input);
const [runnableConfig, callOptions] =
this._separateRunnableConfigFromCallOptions(options);
this._separateRunnableConfigFromCallOptionsCompat(options);
const callbackManager_ = await CallbackManager.configure(
runnableConfig.callbacks,
this.callbacks,
Expand Down Expand Up @@ -455,16 +455,17 @@ export abstract class BaseLLM<
}

const [runnableConfig, callOptions] =
this._separateRunnableConfigFromCallOptions(parsedOptions);
this._separateRunnableConfigFromCallOptionsCompat(parsedOptions);
runnableConfig.callbacks = runnableConfig.callbacks ?? callbacks;

if (!this.cache) {
return this._generateUncached(prompts, callOptions, runnableConfig);
}

const { cache } = this;
const llmStringKey =
this._getSerializedCacheKeyParametersForCall(callOptions);
const llmStringKey = this._getSerializedCacheKeyParametersForCall(
callOptions as CallOptions
);
const { generations, missingPromptIndices } = await this._generateCached({
prompts,
cache,
Expand Down
51 changes: 40 additions & 11 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
pipeGeneratorWithSetup,
AsyncGeneratorWithSetup,
} from "../utils/stream.js";
import { raceWithSignal } from "../utils/signal.js";
import {
DEFAULT_RECURSION_LIMIT,
RunnableConfig,
Expand Down Expand Up @@ -339,6 +340,8 @@ export abstract class Runnable<
recursionLimit: options.recursionLimit,
maxConcurrency: options.maxConcurrency,
runId: options.runId,
timeout: options.timeout,
signal: options.signal,
});
}
const callOptions = { ...(options as Partial<CallOptions>) };
Expand All @@ -350,6 +353,8 @@ export abstract class Runnable<
delete callOptions.recursionLimit;
delete callOptions.maxConcurrency;
delete callOptions.runId;
delete callOptions.timeout;
delete callOptions.signal;
return [runnableConfig, callOptions];
}

Expand Down Expand Up @@ -378,7 +383,8 @@ export abstract class Runnable<
delete config.runId;
let output;
try {
output = await func.call(this, input, config, runManager);
const promise = func.call(this, input, config, runManager);
output = await raceWithSignal(promise, options?.signal);
} catch (e) {
await runManager?.handleChainError(e);
throw e;
Expand Down Expand Up @@ -430,13 +436,14 @@ export abstract class Runnable<
);
let outputs: (RunOutput | Error)[];
try {
outputs = await func.call(
const promise = func.call(
this,
inputs,
optionsList,
runManagers,
batchOptions
);
outputs = await raceWithSignal(promise, optionsList?.[0]?.signal);
} catch (e) {
await Promise.all(
runManagers.map((runManager) => runManager?.handleChainError(e))
Expand Down Expand Up @@ -509,6 +516,7 @@ export abstract class Runnable<
undefined,
config.runName ?? this.getName()
),
options?.signal,
config
);
delete config.runId;
Expand Down Expand Up @@ -1750,14 +1758,18 @@ export class RunnableSequence<
const initialSteps = [this.first, ...this.middle];
for (let i = 0; i < initialSteps.length; i += 1) {
const step = initialSteps[i];
nextStepInput = await step.invoke(
const promise = step.invoke(
nextStepInput,
patchConfig(config, {
callbacks: runManager?.getChild(`seq:step:${i + 1}`),
})
);
nextStepInput = await raceWithSignal(promise, options?.signal);
}
// TypeScript can't detect that the last output of the sequence returns RunOutput, so call it out of the loop here
if (options?.signal?.aborted) {
throw new Error("Aborted");
}
finalOutput = await this.last.invoke(
nextStepInput,
patchConfig(config, {
Expand Down Expand Up @@ -1819,14 +1831,15 @@ export class RunnableSequence<
try {
for (let i = 0; i < this.steps.length; i += 1) {
const step = this.steps[i];
nextStepInputs = await step.batch(
const promise = step.batch(
nextStepInputs,
runManagers.map((runManager, j) => {
const childRunManager = runManager?.getChild(`seq:step:${i + 1}`);
return patchConfig(configList[j], { callbacks: childRunManager });
}),
batchOptions
);
nextStepInputs = await raceWithSignal(promise, configList[0]?.signal);
}
} catch (e) {
await Promise.all(
Expand Down Expand Up @@ -1880,6 +1893,7 @@ export class RunnableSequence<
);
}
for await (const chunk of finalGenerator) {
options?.signal?.throwIfAborted();
yield chunk;
if (concatSupported) {
if (finalOutput === undefined) {
Expand Down Expand Up @@ -2058,16 +2072,17 @@ export class RunnableMap<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const output: Record<string, any> = {};
try {
await Promise.all(
Object.entries(this.steps).map(async ([key, runnable]) => {
const promises = Object.entries(this.steps).map(
async ([key, runnable]) => {
output[key] = await runnable.invoke(
input,
patchConfig(config, {
callbacks: runManager?.getChild(`map:key:${key}`),
})
);
})
}
);
await raceWithSignal(Promise.all(promises), options?.signal);
} catch (e) {
await runManager?.handleChainError(e);
throw e;
Expand Down Expand Up @@ -2101,7 +2116,11 @@ export class RunnableMap<
// starting new iterations as needed,
// until all iterators are done
while (tasks.size) {
const { key, result, gen } = await Promise.race(tasks.values());
const promise = Promise.race(tasks.values());
const { key, result, gen } = await raceWithSignal(
promise,
options?.signal
);
tasks.delete(key);
if (!result.done) {
yield { [key]: result.value } as unknown as RunOutput;
Expand Down Expand Up @@ -2172,28 +2191,32 @@ export class RunnableTraceable<RunInput, RunOutput> extends Runnable<
async invoke(input: RunInput, options?: Partial<RunnableConfig>) {
const [config] = this._getOptionsList(options ?? {}, 1);
const callbacks = await getCallbackManagerForConfig(config);

return (await this.func(
const promise = this.func(
patchConfig(config, { callbacks }),
input
)) as RunOutput;
) as Promise<RunOutput>;

return raceWithSignal(promise, config?.signal);
}

async *_streamIterator(
input: RunInput,
options?: Partial<RunnableConfig>
): AsyncGenerator<RunOutput> {
const [config] = this._getOptionsList(options ?? {}, 1);
const result = await this.invoke(input, options);

if (isAsyncIterable(result)) {
for await (const item of result) {
config?.signal?.throwIfAborted();
yield item as RunOutput;
}
return;
}

if (isIterator(result)) {
while (true) {
config?.signal?.throwIfAborted();
const state: IteratorResult<unknown> = result.next();
if (state.done) break;
yield state.value as RunOutput;
Expand Down Expand Up @@ -2320,6 +2343,7 @@ export class RunnableLambda<RunInput, RunOutput> extends Runnable<
childConfig,
output
)) {
config?.signal?.throwIfAborted();
if (finalOutput === undefined) {
finalOutput = chunk as RunOutput;
} else {
Expand All @@ -2339,6 +2363,7 @@ export class RunnableLambda<RunInput, RunOutput> extends Runnable<
childConfig,
output
)) {
config?.signal?.throwIfAborted();
if (finalOutput === undefined) {
finalOutput = chunk as RunOutput;
} else {
Expand Down Expand Up @@ -2423,10 +2448,12 @@ export class RunnableLambda<RunInput, RunOutput> extends Runnable<
childConfig,
output
)) {
config?.signal?.throwIfAborted();
yield chunk as RunOutput;
}
} else if (isIterableIterator(output)) {
for (const chunk of consumeIteratorInContext(childConfig, output)) {
config?.signal?.throwIfAborted();
yield chunk as RunOutput;
}
} else {
Expand Down Expand Up @@ -2517,6 +2544,7 @@ export class RunnableWithFallbacks<RunInput, RunOutput> extends Runnable<
);
let firstError;
for (const runnable of this.runnables()) {
config?.signal?.throwIfAborted();
try {
const output = await runnable.invoke(
input,
Expand Down Expand Up @@ -2586,6 +2614,7 @@ export class RunnableWithFallbacks<RunInput, RunOutput> extends Runnable<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let firstError: any;
for (const runnable of this.runnables()) {
configList[0].signal?.throwIfAborted();
try {
const outputs = await runnable.batch(
inputs,
Expand Down
Loading

0 comments on commit 5daa8ee

Please sign in to comment.