Skip to content

Commit

Permalink
Refactor race logic into a util
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Aug 1, 2024
1 parent 52e38d8 commit 5cb5cfc
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 69 deletions.
70 changes: 10 additions & 60 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
pipeGeneratorWithSetup,
AsyncGeneratorWithSetup,
} from "../utils/stream.js";
import { raceWithSignal } from "../utils/signal.js";
import {
DEFAULT_RECURSION_LIMIT,
RunnableConfig,
Expand Down Expand Up @@ -383,16 +384,7 @@ export abstract class Runnable<
let output;
try {
const promise = func.call(this, input, config, runManager);
output = options?.signal
? await Promise.race([
promise,
new Promise<never>((_, reject) => {
options.signal?.addEventListener("abort", () => {
reject(new Error("AbortError"));
});
}),
])
: await promise;
output = await raceWithSignal(promise, options?.signal);
} catch (e) {
await runManager?.handleChainError(e);
throw e;
Expand Down Expand Up @@ -451,16 +443,7 @@ export abstract class Runnable<
runManagers,
batchOptions
);
outputs = optionsList?.[0]?.signal
? await Promise.race([
promise,
new Promise<never>((_, reject) => {
optionsList?.[0]?.signal?.addEventListener("abort", () => {
reject(new Error("AbortError"));
});
}),
])
: await promise;
outputs = await raceWithSignal(promise, optionsList?.[0]?.signal);
} catch (e) {
await Promise.all(
runManagers.map((runManager) => runManager?.handleChainError(e))
Expand Down Expand Up @@ -1781,16 +1764,7 @@ export class RunnableSequence<
callbacks: runManager?.getChild(`seq:step:${i + 1}`),
})
);
nextStepInput = options?.signal
? await Promise.race([
promise,
new Promise<never>((_, reject) => {
options.signal?.addEventListener("abort", () =>
reject(new Error("Aborted"))
);
}),
])
: await promise;
nextStepInput = await raceWithSignal(promise, options?.signal);
}
// TypeScript can't detect that the last output of the sequence returns RunOutput, so call it out of the loop here
if (options?.signal?.aborted) {
Expand Down Expand Up @@ -1865,16 +1839,7 @@ export class RunnableSequence<
}),
batchOptions
);
nextStepInputs = configList[0]?.signal
? await Promise.race([
promise,
new Promise<never>((_, reject) => {
configList[0]?.signal?.addEventListener("abort", () =>
reject(new Error("Aborted"))
);
}),
])
: await promise;
nextStepInputs = await raceWithSignal(promise, configList[0]?.signal);
}
} catch (e) {
await Promise.all(
Expand Down Expand Up @@ -2161,16 +2126,10 @@ export class RunnableMap<
// until all iterators are done
while (tasks.size) {
const promise = Promise.race(tasks.values());
const { key, result, gen } = options?.signal
? await Promise.race([
promise,
new Promise<never>((_, reject) => {
options.signal?.addEventListener("abort", () =>
reject(new Error("Aborted"))
);
}),
])
: await promise;
const { key, result, gen } = await raceWithSignal(
promise,
options?.signal
);
tasks.delete(key);
if (!result.done) {
yield { [key]: result.value } as unknown as RunOutput;
Expand Down Expand Up @@ -2246,16 +2205,7 @@ export class RunnableTraceable<RunInput, RunOutput> extends Runnable<
input
) as Promise<RunOutput>;

return config?.signal
? Promise.race([
promise,
new Promise<never>((_, reject) => {
config.signal?.addEventListener("abort", () =>
reject(new Error("Aborted"))
);
}),
])
: await promise;
return raceWithSignal(promise, config?.signal);
}

async *_streamIterator(
Expand Down
5 changes: 4 additions & 1 deletion langchain-core/src/runnables/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ export function mergeConfigs<CallOptions extends RunnableConfig>(
} else if (options.signal !== undefined) {
if ("any" in AbortSignal) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
copy.signal = (AbortSignal as any).any([copy.signal, options.signal]);
copy.signal = (AbortSignal as any).any([
copy.signal,
options.signal,
]);
} else {
copy.signal = options.signal;
}
Expand Down
17 changes: 17 additions & 0 deletions langchain-core/src/utils/signal.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
export async function raceWithSignal<T>(
promise: Promise<T>,
signal?: AbortSignal
): Promise<T> {
if (signal === undefined) {
return promise;
}
if (signal.aborted) {
throw new Error("AbortError");
}
return Promise.race([
promise,
new Promise<never>((_, reject) => {
signal.addEventListener("abort", () => reject(new Error("Aborted")));
}),
]);
}
10 changes: 2 additions & 8 deletions langchain-core/src/utils/stream.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js";
import { raceWithSignal } from "./signal.js";

// Make this a type to override ReadableStream's async iterator type in case
// the popular web-streams-polyfill is imported - the supplied types
Expand Down Expand Up @@ -233,14 +234,7 @@ export class AsyncGeneratorWithSetup<
this.config,
this.signal
? async () => {
return Promise.race([
this.generator.next(...args),
new Promise<never>((_resolve, reject) => {
this.signal?.addEventListener("abort", () => {
reject(new Error("Aborted"));
});
}),
]);
return raceWithSignal(this.generator.next(...args), this.signal);
}
: async () => {
return this.generator.next(...args);
Expand Down

0 comments on commit 5cb5cfc

Please sign in to comment.