From 2f8e879976384a93b333afdf8df21d9fa3af9265 Mon Sep 17 00:00:00 2001
From: Enrico Ros
Date: Tue, 23 Apr 2024 01:45:27 -0700
Subject: [PATCH] Llms: fix Streaming timeouts

---
 .../llms/server/llm.server.streaming.ts      | 23 +++++++--
 .../llms/vendors/unifiedStreamingClient.ts   | 50 +++++++++++++------
 2 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/src/modules/llms/server/llm.server.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts
index f4e3b4d82..1d056e6ae 100644
--- a/src/modules/llms/server/llm.server.streaming.ts
+++ b/src/modules/llms/server/llm.server.streaming.ts
@@ -54,10 +54,18 @@ const chatStreamingInputSchema = z.object({
 });
 export type ChatStreamingInputSchema = z.infer<typeof chatStreamingInputSchema>;
 
+// the purpose is to send something out even before the upstream stream starts, so that we keep the connection up
+const chatStreamingStartOutputPacketSchema = z.object({
+  type: z.enum(['start']),
+});
+export type ChatStreamingPreambleStartSchema = z.infer<typeof chatStreamingStartOutputPacketSchema>;
+
+// the purpose is to have a first packet that contains the model name, so that the client can display it
+// this is a hack until we have a better streaming format
 const chatStreamingFirstOutputPacketSchema = z.object({
   model: z.string(),
 });
-export type ChatStreamingFirstOutputPacketSchema = z.infer<typeof chatStreamingFirstOutputPacketSchema>;
+export type ChatStreamingPreambleModelSchema = z.infer<typeof chatStreamingFirstOutputPacketSchema>;
 
 
 export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
@@ -147,6 +155,11 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
 
     start: async (controller): Promise<void> => {
 
+      // Send initial packet indicating the start of the stream
+      const startPacket: ChatStreamingPreambleStartSchema = { type: 'start' };
+      controller.enqueue(textEncoder.encode(JSON.stringify(startPacket)));
+
+
       // only used for debugging
       let debugLastMs: number | null = null;
 
@@ -293,7 +306,7 @@ function createStreamParserAnthropicMessages(): AIStreamParser {
         responseMessage = anthropicWireMessagesResponseSchema.parse(message);
         // hack: prepend the model name to the first packet
         if (firstMessage) {
-          const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: responseMessage.model };
+          const firstPacket: ChatStreamingPreambleModelSchema = { model: responseMessage.model };
           text = JSON.stringify(firstPacket);
         }
         break;
@@ -408,7 +421,7 @@ function createStreamParserGemini(modelName: string): AIStreamParser {
     // hack: prepend the model name to the first packet
     if (!hasBegun) {
       hasBegun = true;
-      const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: modelName };
+      const firstPacket: ChatStreamingPreambleModelSchema = { model: modelName };
       text = JSON.stringify(firstPacket) + text;
     }
 
@@ -444,7 +457,7 @@ function createStreamParserOllama(): AIStreamParser {
     // hack: prepend the model name to the first packet
     if (!hasBegun && chunk.model) {
       hasBegun = true;
-      const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: chunk.model };
+      const firstPacket: ChatStreamingPreambleModelSchema = { model: chunk.model };
       text = JSON.stringify(firstPacket) + text;
     }
 
@@ -485,7 +498,7 @@ function createStreamParserOpenAI(): AIStreamParser {
     // hack: prepend the model name to the first packet
     if (!hasBegun) {
       hasBegun = true;
-      const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model };
+      const firstPacket: ChatStreamingPreambleModelSchema = { model: json.model };
       text = JSON.stringify(firstPacket) + text;
     }
 
diff --git a/src/modules/llms/vendors/unifiedStreamingClient.ts b/src/modules/llms/vendors/unifiedStreamingClient.ts
index a245fde91..07094baa6 100644
--- a/src/modules/llms/vendors/unifiedStreamingClient.ts
+++ b/src/modules/llms/vendors/unifiedStreamingClient.ts
@@ -1,7 +1,7 @@
 import { apiAsync } from '~/common/util/trpc.client';
 import { frontendSideFetch } from '~/common/util/clientFetchers';
 
-import type { ChatStreamingFirstOutputPacketSchema, ChatStreamingInputSchema } from '../server/llm.server.streaming';
+import type { ChatStreamingInputSchema, ChatStreamingPreambleModelSchema, ChatStreamingPreambleStartSchema } from '../server/llm.server.streaming';
 import type { DLLMId } from '../store-llms';
 import type { VChatFunctionIn, VChatMessageIn } from '../llm.client';
 
@@ -58,6 +58,7 @@ export async function unifiedStreamingClient
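Notes (not part of the patch): after this change the relay emits up to two JSON preamble packets ahead of any chat text. First comes { "type": "start" }, enqueued from the TransformStream start() callback before any upstream bytes arrive, so idle-connection timeouts in proxies and serverless hosts do not fire while the model warms up. Then, once the vendor stream reports a model name, the existing { "model": "..." } packet follows, so a response on the wire begins like {"type":"start"}{"model":"..."}Hello. The sketch below shows one way a client could peel off this preamble. It is illustrative only, not the code from unifiedStreamingClient.ts: the names readChatStream and onUpdate are assumptions, and it relies on the preamble packets being flat JSON objects (no nested braces), which holds for both schemas in the first hunk.

  // Illustrative sketch: read the relay stream, strip the two preamble
  // packets ({ type: 'start' } keep-alive, then { model: string }), and
  // forward the remaining plain text to the caller.
  async function readChatStream(
    response: Response,
    onUpdate: (update: { text: string, model?: string }) => void,
  ): Promise<void> {
    if (!response.body)
      throw new Error('No response body');
    const reader = response.body.getReader();
    const textDecoder = new TextDecoder();

    let buffer = '';
    let model: string | undefined = undefined;
    let inPreamble = true;

    while (true) {
      const { value, done } = await reader.read();
      if (done) break;
      buffer += textDecoder.decode(value, { stream: true });

      // preamble packets are flat JSON objects, so the first '}' closes one
      while (inPreamble && buffer.startsWith('{')) {
        const end = buffer.indexOf('}');
        if (end === -1) break; // packet split across chunks: wait for more bytes
        let packet: any;
        try {
          packet = JSON.parse(buffer.slice(0, end + 1));
        } catch {
          inPreamble = false; // not a preamble packet after all: treat as content
          break;
        }
        if (packet.type === 'start') {
          // keep-alive received: the relay connection is up
        } else if (typeof packet.model === 'string') {
          model = packet.model;
          inPreamble = false; // the model packet is the last preamble packet
        } else {
          inPreamble = false; // unknown object: treat it as content
          break;
        }
        buffer = buffer.slice(end + 1);
      }

      if (buffer) {
        inPreamble = false;
        onUpdate({ text: buffer, model });
        buffer = '';
      }
    }
  }

Consuming at most one packet per schema, and bailing out on the first parse failure, keeps ordinary content that happens to start with '{' from being swallowed as preamble.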