Adds test cases, fixes streaming for generators
jacoblee93 committed Aug 2, 2024
1 parent 632ad31 commit a44c38a
Showing 3 changed files with 168 additions and 3 deletions.
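The fix targets abort signals that arrive via the call options rather than as an explicit parameter: before this commit, a generator-backed stream could keep yielding chunks after such a signal fired. Below is a minimal sketch (not part of the diff) of the behavior the new tests pin down, using FakeListChatModel from this repo; the `@langchain/core/utils/testing` import path and top-level await are assumptions for illustration.

```ts
import { FakeListChatModel } from "@langchain/core/utils/testing";

const model = new FakeListChatModel({ responses: ["hey"], sleep: 500 });
const controller = new AbortController();

// Pass the signal through the call options; after this fix it reaches
// the generator driving the stream, not only the non-streaming paths.
const stream = await model.stream("hi", { signal: controller.signal });

try {
  for await (const chunk of stream) {
    console.log(chunk.content); // first chunk arrives...
    controller.abort(); // ...then we abort mid-stream
  }
} catch (e) {
  console.log("iteration rejected after abort, as the new tests expect");
}
```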
153 changes: 153 additions & 0 deletions langchain-core/src/runnables/tests/signal.test.ts
@@ -0,0 +1,153 @@
/* eslint-disable no-promise-executor-return */
/* eslint-disable @typescript-eslint/no-explicit-any */

import {
  Runnable,
  RunnableLambda,
  RunnableMap,
  RunnablePassthrough,
  RunnableSequence,
  RunnableWithMessageHistory,
} from "../index.js";
import {
  FakeChatMessageHistory,
  FakeListChatModel,
} from "../../utils/testing/index.js";

const chatModel = new FakeListChatModel({ responses: ["hey"], sleep: 500 });

const TEST_CASES = {
  map: {
    runnable: RunnableMap.from({
      question: new RunnablePassthrough(),
      context: async () => {
        await new Promise((resolve) => setTimeout(resolve, 500));
        return "SOME STUFF";
      },
    }),
    input: "testing",
  },
  binding: {
    runnable: RunnableLambda.from(
      () => new Promise((resolve) => setTimeout(resolve, 500))
    ),
    input: "testing",
  },
  fallbacks: {
    runnable: chatModel
      .bind({ thrownErrorString: "expected" })
      .withFallbacks({ fallbacks: [chatModel] }),
    input: "testing",
    skipStream: true,
  },
  sequence: {
    runnable: RunnableSequence.from([
      RunnablePassthrough.assign({
        test: () => chatModel,
      }),
      () => {},
    ]),
    input: { question: "testing" },
  },
  lambda: {
    runnable: RunnableLambda.from(
      () => new Promise((resolve) => setTimeout(resolve, 500))
    ),
    input: {},
  },
  history: {
    runnable: new RunnableWithMessageHistory({
      runnable: chatModel,
      config: {},
      getMessageHistory: () => new FakeChatMessageHistory(),
    }),
    input: "testing",
  },
};

describe.each(Object.keys(TEST_CASES))("Test runnable %s", (name) => {
  const {
    runnable,
    input,
    skipStream,
  }: { runnable: Runnable; input: any; skipStream?: boolean } =
    TEST_CASES[name as keyof typeof TEST_CASES];
  test("Test invoke with signal", async () => {
    await expect(async () => {
      const controller = new AbortController();
      await Promise.all([
        runnable.invoke(input, {
          signal: controller.signal,
        }),
        new Promise<void>((resolve) => {
          controller.abort();
          resolve();
        }),
      ]);
    }).rejects.toThrowError();
  });

  test("Test invoke with signal with a delay", async () => {
    await expect(async () => {
      const controller = new AbortController();
      await Promise.all([
        runnable.invoke(input, {
          signal: controller.signal,
        }),
        new Promise<void>((resolve) => {
          setTimeout(() => {
            controller.abort();
            resolve();
          }, 250);
        }),
      ]);
    }).rejects.toThrowError();
  });

  test("Test stream with signal", async () => {
    if (skipStream) {
      return;
    }
    const controller = new AbortController();
    await expect(async () => {
      const stream = await runnable.stream(input, {
        signal: controller.signal,
      });
      for await (const _ of stream) {
        controller.abort();
      }
    }).rejects.toThrowError();
  });

  test("Test batch with signal", async () => {
    await expect(async () => {
      const controller = new AbortController();
      await Promise.all([
        runnable.batch([input, input], {
          signal: controller.signal,
        }),
        new Promise<void>((resolve) => {
          controller.abort();
          resolve();
        }),
      ]);
    }).rejects.toThrowError();
  });

  test("Test batch with signal with a delay", async () => {
    await expect(async () => {
      const controller = new AbortController();
      await Promise.all([
        runnable.batch([input, input], {
          signal: controller.signal,
        }),
        new Promise<void>((resolve) => {
          setTimeout(() => {
            controller.abort();
            resolve();
          }, 250);
        }),
      ]);
    }).rejects.toThrowError();
  });
});
3 changes: 2 additions & 1 deletion langchain-core/src/utils/stream.ts
@@ -201,7 +201,8 @@ export class AsyncGeneratorWithSetup<
   }) {
     this.generator = params.generator;
     this.config = params.config;
-    this.signal = params.signal;
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    this.signal = params.signal ?? (this.config as any)?.signal;
     // setup is a promise that resolves only after the first iterator value
     // is available. this is useful when setup of several piped generators
     // needs to happen in logical order, ie. in the order in which input to
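The fix itself is the one-line fallback above: AsyncGeneratorWithSetup previously honored only a signal passed directly to its constructor, so a signal carried on the runnable config never reached the wrapped generator. A stripped-down, hypothetical sketch of the same pattern follows; the names are illustrative, and it assumes the wrapper checks its signal during iteration.

```ts
// Hypothetical, simplified version of the pattern: prefer an explicit
// signal, else fall back to one riding on the config object.
interface MinimalConfig {
  signal?: AbortSignal;
}

class MinimalGeneratorWrapper<T> implements AsyncIterable<T> {
  private signal?: AbortSignal;

  constructor(
    private generator: AsyncGenerator<T>,
    config?: MinimalConfig,
    signal?: AbortSignal
  ) {
    // The commit's change, in miniature: without this fallback, a signal
    // set only on the config is silently ignored for generators.
    this.signal = signal ?? config?.signal;
  }

  async *[Symbol.asyncIterator]() {
    for await (const value of this.generator) {
      if (this.signal?.aborted) {
        throw new Error("Aborted");
      }
      yield value;
    }
  }
}
```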
15 changes: 13 additions & 2 deletions langchain-core/src/utils/testing/index.ts
@@ -15,6 +15,7 @@ import {
 import { Document } from "../../documents/document.js";
 import {
   BaseChatModel,
+  BaseChatModelCallOptions,
   BaseChatModelParams,
 } from "../../language_models/chat_models.js";
 import { BaseLLMParams, LLM } from "../../language_models/llms.js";
@@ -324,6 +325,10 @@ export interface FakeChatInput extends BaseChatModelParams {
   emitCustomEvent?: boolean;
 }
 
+export interface FakeListChatModelCallOptions extends BaseChatModelCallOptions {
+  thrownErrorString?: string;
+}
+
 /**
  * A fake Chat Model that returns a predefined list of responses. It can be used
  * for testing purposes.
@@ -344,7 +349,7 @@ export interface FakeChatInput extends BaseChatModelParams {
  * console.log({ secondResponse });
  * ```
  */
-export class FakeListChatModel extends BaseChatModel {
+export class FakeListChatModel extends BaseChatModel<FakeListChatModelCallOptions> {
   static lc_name() {
     return "FakeListChatModel";
   }
@@ -378,6 +383,9 @@ export class FakeListChatModel extends BaseChatModel {
     runManager?: CallbackManagerForLLMRun
   ): Promise<ChatResult> {
     await this._sleepIfRequested();
+    if (options?.thrownErrorString) {
+      throw new Error(options.thrownErrorString);
+    }
     if (this.emitCustomEvent) {
       await runManager?.handleCustomEvent("some_test_event", {
         someval: true,
@@ -408,7 +416,7 @@
 
   async *_streamResponseChunks(
     _messages: BaseMessage[],
-    _options: this["ParsedCallOptions"],
+    options: this["ParsedCallOptions"],
     runManager?: CallbackManagerForLLMRun
   ): AsyncGenerator<ChatGenerationChunk> {
     const response = this._currentResponse();
@@ -421,6 +429,9 @@
 
     for await (const text of response) {
       await this._sleepIfRequested();
+      if (options?.thrownErrorString) {
+        throw new Error(options.thrownErrorString);
+      }
       const chunk = this._createResponseChunk(text);
       yield chunk;
       void runManager?.handleLLMNewToken(text);
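The new thrownErrorString call option is what the fallbacks test case above relies on: it makes the fake model fail on demand so withFallbacks can be exercised. A short usage sketch, mirroring the test file; the `@langchain/core/utils/testing` import path and top-level await are assumptions for illustration.

```ts
import { FakeListChatModel } from "@langchain/core/utils/testing";

const model = new FakeListChatModel({ responses: ["hey"] });

// Bind the new call option so the first model always throws, then fall
// back to an identical model that succeeds, as in the fallbacks test case.
const withFallback = model
  .bind({ thrownErrorString: "expected" })
  .withFallbacks({ fallbacks: [model] });

const result = await withFallback.invoke("hi");
console.log(result.content); // "hey", produced by the fallback model
```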
