From ea0c5a8be4f47ce8ad6a02f49fbe45468f243b79 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 18 Nov 2024 15:53:46 -0800 Subject: [PATCH] fix(openai): Support o1 streaming (#7229) --- libs/langchain-openai/src/chat_models.ts | 14 ---------- .../src/tests/chat_models.int.test.ts | 28 +++++++++++++++++++ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/libs/langchain-openai/src/chat_models.ts b/libs/langchain-openai/src/chat_models.ts index 1db33c8728ab..8d2145fc8be7 100644 --- a/libs/langchain-openai/src/chat_models.ts +++ b/libs/langchain-openai/src/chat_models.ts @@ -14,7 +14,6 @@ import { ToolMessageChunk, OpenAIToolCall, isAIMessage, - convertToChunk, UsageMetadata, } from "@langchain/core/messages"; import { @@ -1360,19 +1359,6 @@ export class ChatOpenAI< options: this["ParsedCallOptions"], runManager?: CallbackManagerForLLMRun ): AsyncGenerator { - if (this.model.includes("o1-")) { - console.warn( - "[WARNING]: OpenAI o1 models do not yet support token-level streaming. Streaming will yield single chunk." - ); - const result = await this._generate(messages, options, runManager); - const messageChunk = convertToChunk(result.generations[0].message); - yield new ChatGenerationChunk({ - message: messageChunk, - text: - typeof messageChunk.content === "string" ? messageChunk.content : "", - }); - return; - } const messagesMapped: OpenAICompletionParam[] = _convertMessagesToOpenAIParams(messages); const params = { diff --git a/libs/langchain-openai/src/tests/chat_models.int.test.ts b/libs/langchain-openai/src/tests/chat_models.int.test.ts index a5bed0811e61..be5f8c7d7a90 100644 --- a/libs/langchain-openai/src/tests/chat_models.int.test.ts +++ b/libs/langchain-openai/src/tests/chat_models.int.test.ts @@ -1166,3 +1166,31 @@ describe("Audio output", () => { ).toBeGreaterThan(1); }); }); + +test("Can stream o1 requests", async () => { + const model = new ChatOpenAI({ + model: "o1-mini", + }); + const stream = await model.stream( + "Write me a very simple hello world program in Python. Ensure it is wrapped in a function called 'hello_world' and has descriptive comments." + ); + let finalMsg: AIMessageChunk | undefined; + let numChunks = 0; + for await (const chunk of stream) { + finalMsg = finalMsg ? concat(finalMsg, chunk) : chunk; + numChunks += 1; + } + + expect(finalMsg).toBeTruthy(); + if (!finalMsg) { + throw new Error("No final message found"); + } + if (typeof finalMsg.content === "string") { + expect(finalMsg.content.length).toBeGreaterThan(10); + } else { + expect(finalMsg.content.length).toBeGreaterThanOrEqual(1); + } + + // A + expect(numChunks).toBeGreaterThan(3); +});