diff --git a/langchain-core/src/runnables/tests/signal.test.ts b/langchain-core/src/runnables/tests/signal.test.ts
new file mode 100644
index 000000000000..565fc3267b18
--- /dev/null
+++ b/langchain-core/src/runnables/tests/signal.test.ts
@@ -0,0 +1,153 @@
+/* eslint-disable no-promise-executor-return */
+/* eslint-disable @typescript-eslint/no-explicit-any */
+
+import {
+  Runnable,
+  RunnableLambda,
+  RunnableMap,
+  RunnablePassthrough,
+  RunnableSequence,
+  RunnableWithMessageHistory,
+} from "../index.js";
+import {
+  FakeChatMessageHistory,
+  FakeListChatModel,
+} from "../../utils/testing/index.js";
+
+const chatModel = new FakeListChatModel({ responses: ["hey"], sleep: 500 });
+
+const TEST_CASES = {
+  map: {
+    runnable: RunnableMap.from({
+      question: new RunnablePassthrough(),
+      context: async () => {
+        await new Promise((resolve) => setTimeout(resolve, 500));
+        return "SOME STUFF";
+      },
+    }),
+    input: "testing",
+  },
+  binding: {
+    runnable: RunnableLambda.from(
+      () => new Promise((resolve) => setTimeout(resolve, 500))
+    ),
+    input: "testing",
+  },
+  fallbacks: {
+    runnable: chatModel
+      .bind({ thrownErrorString: "expected" })
+      .withFallbacks({ fallbacks: [chatModel] }),
+    input: "testing",
+    skipStream: true,
+  },
+  sequence: {
+    runnable: RunnableSequence.from([
+      RunnablePassthrough.assign({
+        test: () => chatModel,
+      }),
+      () => {},
+    ]),
+    input: { question: "testing" },
+  },
+  lambda: {
+    runnable: RunnableLambda.from(
+      () => new Promise((resolve) => setTimeout(resolve, 500))
+    ),
+    input: {},
+  },
+  history: {
+    runnable: new RunnableWithMessageHistory({
+      runnable: chatModel,
+      config: {},
+      getMessageHistory: () => new FakeChatMessageHistory(),
+    }),
+    input: "testing",
+  },
+};
+
+describe.each(Object.keys(TEST_CASES))("Test runnable %s", (name) => {
+  const {
+    runnable,
+    input,
+    skipStream,
+  }: { runnable: Runnable; input: any; skipStream?: boolean } =
+    TEST_CASES[name as keyof typeof TEST_CASES];
+  test("Test invoke with signal", async () => {
+    await expect(async () => {
+      const controller = new AbortController();
+      await Promise.all([
+        runnable.invoke(input, {
+          signal: controller.signal,
+        }),
+        new Promise<void>((resolve) => {
+          controller.abort();
+          resolve();
+        }),
+      ]);
+    }).rejects.toThrowError();
+  });
+
+  test("Test invoke with signal with a delay", async () => {
+    await expect(async () => {
+      const controller = new AbortController();
+      await Promise.all([
+        runnable.invoke(input, {
+          signal: controller.signal,
+        }),
+        new Promise<void>((resolve) => {
+          setTimeout(() => {
+            controller.abort();
+            resolve();
+          }, 250);
+        }),
+      ]);
+    }).rejects.toThrowError();
+  });
+
+  test("Test stream with signal", async () => {
+    if (skipStream) {
+      return;
+    }
+    const controller = new AbortController();
+    await expect(async () => {
+      const stream = await runnable.stream(input, {
+        signal: controller.signal,
+      });
+      for await (const _ of stream) {
+        controller.abort();
+      }
+    }).rejects.toThrowError();
+  });
+
+  test("Test batch with signal", async () => {
+    await expect(async () => {
+      const controller = new AbortController();
+      await Promise.all([
+        runnable.batch([input, input], {
+          signal: controller.signal,
+        }),
+        new Promise<void>((resolve) => {
+          controller.abort();
+          resolve();
+        }),
+      ]);
+    }).rejects.toThrowError();
+  });
+
+  test("Test batch with signal with a delay", async () => {
+    await expect(async () => {
+      const controller = new AbortController();
+      await Promise.all([
+        runnable.batch([input, input], {
+          signal: controller.signal,
+        }),
+        new Promise<void>((resolve) => {
+          setTimeout(() => {
+            controller.abort();
+            resolve();
+          }, 250);
+        }),
+      ]);
+    }).rejects.toThrowError();
+  });
+});
diff --git a/langchain-core/src/utils/stream.ts b/langchain-core/src/utils/stream.ts
index f40e997d23fb..91a9810e2d25 100644
--- a/langchain-core/src/utils/stream.ts
+++ b/langchain-core/src/utils/stream.ts
@@ -201,7 +201,8 @@ export class AsyncGeneratorWithSetup<
   }) {
     this.generator = params.generator;
     this.config = params.config;
-    this.signal = params.signal;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    this.signal = params.signal ?? (this.config as any)?.signal;
     // setup is a promise that resolves only after the first iterator value
     // is available. this is useful when setup of several piped generators
     // needs to happen in logical order, ie. in the order in which input to
diff --git a/langchain-core/src/utils/testing/index.ts b/langchain-core/src/utils/testing/index.ts
index 65d197f6c23e..f14629794293 100644
--- a/langchain-core/src/utils/testing/index.ts
+++ b/langchain-core/src/utils/testing/index.ts
@@ -15,6 +15,7 @@ import {
 import { Document } from "../../documents/document.js";
 import {
   BaseChatModel,
+  BaseChatModelCallOptions,
   BaseChatModelParams,
 } from "../../language_models/chat_models.js";
 import { BaseLLMParams, LLM } from "../../language_models/llms.js";
@@ -324,6 +325,10 @@ export interface FakeChatInput extends BaseChatModelParams {
   emitCustomEvent?: boolean;
 }
 
+export interface FakeListChatModelCallOptions extends BaseChatModelCallOptions {
+  thrownErrorString?: string;
+}
+
 /**
  * A fake Chat Model that returns a predefined list of responses. It can be used
  * for testing purposes.
@@ -344,7 +349,7 @@ export interface FakeChatInput extends BaseChatModelParams {
  * console.log({ secondResponse });
  * ```
  */
-export class FakeListChatModel extends BaseChatModel {
+export class FakeListChatModel extends BaseChatModel<FakeListChatModelCallOptions> {
   static lc_name() {
     return "FakeListChatModel";
   }
@@ -378,6 +383,9 @@ export class FakeListChatModel extends BaseChatModel {
     runManager?: CallbackManagerForLLMRun
   ): Promise<ChatResult> {
     await this._sleepIfRequested();
+    if (options?.thrownErrorString) {
+      throw new Error(options.thrownErrorString);
+    }
     if (this.emitCustomEvent) {
       await runManager?.handleCustomEvent("some_test_event", {
         someval: true,
@@ -408,7 +416,7 @@ export class FakeListChatModel extends BaseChatModel {
 
   async *_streamResponseChunks(
     _messages: BaseMessage[],
-    _options: this["ParsedCallOptions"],
+    options: this["ParsedCallOptions"],
     runManager?: CallbackManagerForLLMRun
   ): AsyncGenerator<ChatGenerationChunk> {
     const response = this._currentResponse();
@@ -421,6 +429,9 @@ export class FakeListChatModel extends BaseChatModel {
 
     for await (const text of response) {
       await this._sleepIfRequested();
+      if (options?.thrownErrorString) {
+        throw new Error(options.thrownErrorString);
+      }
       const chunk = this._createResponseChunk(text);
       yield chunk;
       void runManager?.handleLLMNewToken(text);
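
Taken together, these changes let an `AbortSignal` supplied in the per-call config propagate into `AsyncGeneratorWithSetup` (so in-flight streams can be cancelled) and give `FakeListChatModel` a `thrownErrorString` call option so the fallback test case can force a failure. The sketch below is a minimal usage example of the behavior the new tests assert; it is not part of the diff and assumes only the public APIs exercised above (`RunnableLambda.from` and the `signal` field on the call config), with the standard `@langchain/core/runnables` entry point:

```typescript
import { RunnableLambda } from "@langchain/core/runnables";

async function main() {
  // A runnable that would take 5 seconds to resolve on its own.
  const slow = RunnableLambda.from(
    () => new Promise((resolve) => setTimeout(() => resolve("done"), 5000))
  );

  const controller = new AbortController();
  // Abort shortly after starting; invoke() should reject with an abort error
  // instead of waiting out the full five seconds.
  setTimeout(() => controller.abort(), 100);

  try {
    await slow.invoke({}, { signal: controller.signal });
  } catch (e) {
    console.error("Run was cancelled:", e);
  }
}

void main();
```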