From 5d8084b6507e6134cc5d9e48220ac280f4482ba9 Mon Sep 17 00:00:00 2001
From: Enrico Ros
Date: Tue, 23 Apr 2024 05:07:55 -0700
Subject: [PATCH] Llms: streaming: cleanups

---
 .../llms/server/llm.server.streaming.ts       | 143 +++++++++++-------
 1 file changed, 86 insertions(+), 57 deletions(-)

diff --git a/src/modules/llms/server/llm.server.streaming.ts b/src/modules/llms/server/llm.server.streaming.ts
index 1d056e6ae..ba1f7002c 100644
--- a/src/modules/llms/server/llm.server.streaming.ts
+++ b/src/modules/llms/server/llm.server.streaming.ts
@@ -19,7 +19,7 @@ import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletio
 
 // OpenAI server imports
 import type { OpenAIWire } from './openai/openai.wiretypes';
-import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai/openai.router';
+import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from './openai/openai.router';
 
 
 // configuration
@@ -70,74 +70,46 @@ export type ChatStreamingPreambleModelSchema = z.infer<typeof chatStreamingPreambleModelSchema>;
 
 
 export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
 
-  // inputs - reuse the tRPC schema
+  // Parse the request
   const body = await req.json();
   const { access, model, history } = chatStreamingInputSchema.parse(body);
+  const prettyDialect = serverCapitalizeFirstLetter(access.dialect);
 
-  // access/dialect dependent setup:
-  // - requestAccess: the headers and URL to use for the upstream API call
-  // - muxingFormat: the format of the event stream (sse or json-nl)
-  // - vendorStreamParser: the parser to use for the event stream
-  let upstreamResponse: Response;
-  let requestAccess: { headers: HeadersInit, url: string } = { headers: {}, url: '' };
-  let muxingFormat: MuxingFormat = 'sse';
-  let vendorStreamParser: AIStreamParser;
-  try {
-
-    // prepare the API request data
-    let body: object;
-    switch (access.dialect) {
-      case 'anthropic':
-        requestAccess = anthropicAccess(access, '/v1/messages');
-        body = anthropicMessagesPayloadOrThrow(model, history, true);
-        vendorStreamParser = createStreamParserAnthropicMessages();
-        break;
-      case 'gemini':
-        requestAccess = geminiAccess(access, model.id, geminiModelsStreamGenerateContentPath);
-        body = geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1);
-        vendorStreamParser = createStreamParserGemini(model.id.replace('models/', ''));
-        break;
-
-      case 'ollama':
-        requestAccess = ollamaAccess(access, OLLAMA_PATH_CHAT);
-        body = ollamaChatCompletionPayload(model, history, true);
-        muxingFormat = 'json-nl';
-        vendorStreamParser = createStreamParserOllama();
-        break;
+  // Prepare the upstream API request and demuxer/parser
+  let requestData: ReturnType<typeof _prepareRequestData>;
+  try {
+    requestData = _prepareRequestData(access, model, history);
+  } catch (error: any) {
+    console.error(`[POST] /api/llms/stream: ${prettyDialect}: prepareRequestData issue:`, safeErrorString(error));
+    return new NextResponse(`**[Service Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown streaming error'}`, {
+      status: 422,
+    });
+  }
 
-      case 'azure':
-      case 'groq':
-      case 'lmstudio':
-      case 'localai':
-      case 'mistral':
-      case 'oobabooga':
-      case 'openai':
-      case 'openrouter':
-      case 'perplexity':
-      case 'togetherai':
-        requestAccess = openAIAccess(access, model.id, '/v1/chat/completions');
-        body = openAIChatCompletionPayload(access.dialect, model, history, null, null, 1, true);
-        vendorStreamParser = createStreamParserOpenAI();
-        break;
-    }
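+  // Note: requestData's type is inferred with TypeScript's ReturnType<typeof _prepareRequestData>
+  // utility, so this handler stays in sync with the { headers, url, body, vendorMuxingFormat,
+  // vendorStreamParser } shape returned by the helper at the bottom of this file, without
+  // restating that shape here.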
+  // Connect to the upstream (blocking)
+  let upstreamResponse: Response;
+  try {
     if (SERVER_DEBUG_WIRE)
-      console.log('-> streaming:', debugGenerateCurlCommand('POST', requestAccess.url, requestAccess.headers, body));
+      console.log('-> streaming:', debugGenerateCurlCommand('POST', requestData.url, requestData.headers, requestData.body));
 
     // POST to our API route
-    upstreamResponse = await nonTrpcServerFetchOrThrow(requestAccess.url, 'POST', requestAccess.headers, body);
+    // [MAY TIMEOUT] on Vercel Edge calls; this times out on long requests to Anthropic, as of 2024-04-23.
+    // The solution would be to return a new response with a 200 status code and then stream the data
+    // in a new request, but we would lose back-pressure and it complicates the logic.
+    upstreamResponse = await nonTrpcServerFetchOrThrow(requestData.url, 'POST', requestData.headers, requestData.body);
 
   } catch (error: any) {
 
     // server-side admins message
     const capDialect = serverCapitalizeFirstLetter(access.dialect);
     const fetchOrVendorError = safeErrorString(error) + (error?.cause ? ' · ' + JSON.stringify(error.cause) : '');
-    console.error(`[POST] /api/llms/stream: ${capDialect}: fetch issue:`, fetchOrVendorError, requestAccess?.url);
+    console.error(`[POST] /api/llms/stream: ${capDialect}: fetch issue:`, fetchOrVendorError, requestData?.url);
 
     // client-side users visible message
     const statusCode = ((error instanceof ServerFetchError) && (error.statusCode >= 400)) ? error.statusCode : 422;
-    const devMessage = process.env.NODE_ENV === 'development' ? ` [DEV_URL: ${requestAccess?.url}]` : '';
+    const devMessage = process.env.NODE_ENV === 'development' ? ` [DEV_URL: ${requestData?.url}]` : '';
     return new NextResponse(`**[Service Issue] ${capDialect}**: ${fetchOrVendorError}${devMessage}`, {
       status: statusCode,
     });
@@ -152,8 +124,8 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
+const _createDemuxerEventSource: (onParse: EventSourceParseCallback) => EventSourceParser = createEventsourceParser;
+
 /**
  * Creates a parser for a 'JSON\n' non-event stream, to be swapped with an EventSource parser.
  * Ollama is the only vendor that uses this format.
  */
-function createDemuxerJsonNewline(onParse: EventSourceParseCallback): EventSourceParser {
+function _createDemuxerJsonNewline(onParse: EventSourceParseCallback): EventSourceParser {
   let accumulator: string = '';
   return {
     // feeds a new chunk to the parser - we accumulate in case of partial data, and only execute on full lines
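(Annotation, outside the patch: the 'JSON\n' demuxer above hinges on one technique:
accumulate incoming chunks, split on newlines, parse only the completed lines, and keep
the trailing partial line for the next feed. A minimal standalone sketch of the idea,
with hypothetical names, not the patch's code:)

  // feed() accepts arbitrary chunk boundaries; only whole lines reach JSON.parse
  function makeJsonNewlineFeeder(onJson: (json: unknown) => void): (chunk: string) => void {
    let accumulator = '';
    return (chunk: string): void => {
      accumulator += chunk;
      const lines = accumulator.split('\n');
      accumulator = lines.pop() ?? ''; // hold back the (possibly partial) last segment
      for (const line of lines)
        if (line.trim())
          onJson(JSON.parse(line));
    };
  }

  // usage: two chunks that split a JSON object mid-way still parse correctly
  const feed = makeJsonNewlineFeeder((j) => console.log(j));
  feed('{"a":1}\n{"b"');
  feed(':2}\n'); // logs { a: 1 } then { b: 2 }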
@@ -206,7 +183,7 @@ function createDemuxerJsonNewline(onParse: EventSourceParseCallback): EventSourc
  * Creates a TransformStream that parses events from an EventSource stream using a custom parser.
  * @returns {TransformStream} TransformStream parsing events.
  */
-function createEventStreamTransformer(muxingFormat: MuxingFormat, vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream<Uint8Array, Uint8Array> {
+function createUpstreamTransformer(muxingFormat: MuxingFormat, vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream<Uint8Array, Uint8Array> {
   const textDecoder = new TextDecoder();
   const textEncoder = new TextEncoder();
   let eventSourceParser: EventSourceParser;
@@ -255,9 +232,9 @@ function createEventStreamTransformer(muxingFormat: MuxingFormat, vendorTextPars
     };
 
     if (muxingFormat === 'sse')
-      eventSourceParser = createEventsourceParser(onNewEvent);
+      eventSourceParser = _createDemuxerEventSource(onNewEvent);
     else if (muxingFormat === 'json-nl')
-      eventSourceParser = createDemuxerJsonNewline(onNewEvent);
+      eventSourceParser = _createDemuxerJsonNewline(onNewEvent);
   },
 
   // stream=true is set because the data is not guaranteed to be final and un-chunked
@@ -506,4 +483,56 @@ function createStreamParserOpenAI(): AIStreamParser {
     const close = !!json.choices[0].finish_reason;
     return { text, close };
   };
-}
\ No newline at end of file
+}
+
+
+function _prepareRequestData(access: ChatStreamingInputSchema['access'], model: OpenAIModelSchema, history: OpenAIHistorySchema): {
+  headers: HeadersInit;
+  url: string;
+  body: object;
+  vendorMuxingFormat: MuxingFormat;
+  vendorStreamParser: AIStreamParser;
+} {
+  switch (access.dialect) {
+    case 'anthropic':
+      return {
+        ...anthropicAccess(access, '/v1/messages'),
+        body: anthropicMessagesPayloadOrThrow(model, history, true),
+        vendorMuxingFormat: 'sse',
+        vendorStreamParser: createStreamParserAnthropicMessages(),
+      };
+
+    case 'gemini':
+      return {
+        ...geminiAccess(access, model.id, geminiModelsStreamGenerateContentPath),
+        body: geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1),
+        vendorMuxingFormat: 'sse',
+        vendorStreamParser: createStreamParserGemini(model.id.replace('models/', '')),
+      };
+
+    case 'ollama':
+      return {
+        ...ollamaAccess(access, OLLAMA_PATH_CHAT),
+        body: ollamaChatCompletionPayload(model, history, true),
+        vendorMuxingFormat: 'json-nl',
+        vendorStreamParser: createStreamParserOllama(),
+      };
+
+    case 'azure':
+    case 'groq':
+    case 'lmstudio':
+    case 'localai':
+    case 'mistral':
+    case 'oobabooga':
+    case 'openai':
+    case 'openrouter':
+    case 'perplexity':
+    case 'togetherai':
+      return {
+        ...openAIAccess(access, model.id, '/v1/chat/completions'),
+        body: openAIChatCompletionPayload(access.dialect, model, history, null, null, 1, true),
+        vendorMuxingFormat: 'sse',
+        vendorStreamParser: createStreamParserOpenAI(),
+      };
+  }
+}
\ No newline at end of file
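
(Annotation, outside the patch: the refactor funnels every dialect into one record that both
the fetch and the stream pipeline consume. A rough standalone sketch of that wiring, using
only Web APIs; the names and the response header here are stand-ins, not the patch's code:)

  type PreparedRequest = { headers: HeadersInit; url: string; body: object };

  async function relaySketch(prepared: PreparedRequest, transformer: TransformStream<Uint8Array, Uint8Array>): Promise<Response> {
    // POST the prepared payload upstream
    const upstream = await fetch(prepared.url, {
      method: 'POST',
      headers: prepared.headers,
      body: JSON.stringify(prepared.body),
    });
    // fall back to an already-closed stream if the upstream has no body,
    // then pipe the bytes through the demuxing/parsing transformer
    const emptyStream = new ReadableStream<Uint8Array>({ start: (controller) => controller.close() });
    return new Response((upstream.body ?? emptyStream).pipeThrough(transformer), {
      status: 200,
      headers: { 'Content-Type': 'text/plain; charset=utf-8' },
    });
  }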