Llms: streaming: cleanups
enricoros committed Apr 23, 2024
1 parent f316b89 commit 5d8084b
Showing 1 changed file with 86 additions and 57 deletions.
143 changes: 86 additions & 57 deletions src/modules/llms/server/llm.server.streaming.ts
@@ -19,7 +19,7 @@ import { OLLAMA_PATH_CHAT, ollamaAccess, ollamaAccessSchema, ollamaChatCompletio
 
 // OpenAI server imports
 import type { OpenAIWire } from './openai/openai.wiretypes';
-import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai/openai.router';
+import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from './openai/openai.router';
 
 
 // configuration
@@ -70,74 +70,46 @@ export type ChatStreamingPreambleModelSchema = z.infer<typeof chatStreamingFirst
 
 export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
 
-  // inputs - reuse the tRPC schema
+  // Parse the request
   const body = await req.json();
   const { access, model, history } = chatStreamingInputSchema.parse(body);
+  const prettyDialect = serverCapitalizeFirstLetter(access.dialect);
 
-  // access/dialect dependent setup:
-  // - requestAccess: the headers and URL to use for the upstream API call
-  // - muxingFormat: the format of the event stream (sse or json-nl)
-  // - vendorStreamParser: the parser to use for the event stream
-  let upstreamResponse: Response;
-  let requestAccess: { headers: HeadersInit, url: string } = { headers: {}, url: '' };
-  let muxingFormat: MuxingFormat = 'sse';
-  let vendorStreamParser: AIStreamParser;
-  try {
-
-    // prepare the API request data
-    let body: object;
-    switch (access.dialect) {
-      case 'anthropic':
-        requestAccess = anthropicAccess(access, '/v1/messages');
-        body = anthropicMessagesPayloadOrThrow(model, history, true);
-        vendorStreamParser = createStreamParserAnthropicMessages();
-        break;
-
-      case 'gemini':
-        requestAccess = geminiAccess(access, model.id, geminiModelsStreamGenerateContentPath);
-        body = geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1);
-        vendorStreamParser = createStreamParserGemini(model.id.replace('models/', ''));
-        break;
-
-      case 'ollama':
-        requestAccess = ollamaAccess(access, OLLAMA_PATH_CHAT);
-        body = ollamaChatCompletionPayload(model, history, true);
-        muxingFormat = 'json-nl';
-        vendorStreamParser = createStreamParserOllama();
-        break;
+  // Prepare the upstream API request and demuxer/parser
+  let requestData: ReturnType<typeof _prepareRequestData>;
+  try {
+    requestData = _prepareRequestData(access, model, history);
+  } catch (error: any) {
+    console.error(`[POST] /api/llms/stream: ${prettyDialect}: prepareRequestData issue:`, safeErrorString(error));
+    return new NextResponse(`**[Service Issue] ${prettyDialect}**: ${safeErrorString(error) || 'Unknown streaming error'}`, {
+      status: 422,
+    });
+  }
 
-      case 'azure':
-      case 'groq':
-      case 'lmstudio':
-      case 'localai':
-      case 'mistral':
-      case 'oobabooga':
-      case 'openai':
-      case 'openrouter':
-      case 'perplexity':
-      case 'togetherai':
-        requestAccess = openAIAccess(access, model.id, '/v1/chat/completions');
-        body = openAIChatCompletionPayload(access.dialect, model, history, null, null, 1, true);
-        vendorStreamParser = createStreamParserOpenAI();
-        break;
-    }
+  // Connect to the upstream (blocking)
+  let upstreamResponse: Response;
+  try {
 
     if (SERVER_DEBUG_WIRE)
-      console.log('-> streaming:', debugGenerateCurlCommand('POST', requestAccess.url, requestAccess.headers, body));
+      console.log('-> streaming:', debugGenerateCurlCommand('POST', requestData.url, requestData.headers, requestData.body));
 
-    // POST to our API route
-    upstreamResponse = await nonTrpcServerFetchOrThrow(requestAccess.url, 'POST', requestAccess.headers, body);
+    // [MAY TIMEOUT] on Vercel Edge calls; this times out on long requests to Anthropic, on 2024-04-23.
+    // The solution would be to return a new response with a 200 status code, and then stream the data
+    // in a new request, but we'd lose back-pressure and it complicates the logic.
+    upstreamResponse = await nonTrpcServerFetchOrThrow(requestData.url, 'POST', requestData.headers, requestData.body);
 
   } catch (error: any) {
 
     // server-side admin message
     const capDialect = serverCapitalizeFirstLetter(access.dialect);
     const fetchOrVendorError = safeErrorString(error) + (error?.cause ? ' · ' + JSON.stringify(error.cause) : '');
-    console.error(`[POST] /api/llms/stream: ${capDialect}: fetch issue:`, fetchOrVendorError, requestAccess?.url);
+    console.error(`[POST] /api/llms/stream: ${capDialect}: fetch issue:`, fetchOrVendorError, requestData?.url);
 
     // client-side user-visible message
     const statusCode = ((error instanceof ServerFetchError) && (error.statusCode >= 400)) ? error.statusCode : 422;
-    const devMessage = process.env.NODE_ENV === 'development' ? ` [DEV_URL: ${requestAccess?.url}]` : '';
+    const devMessage = process.env.NODE_ENV === 'development' ? ` [DEV_URL: ${requestData?.url}]` : '';
     return new NextResponse(`**[Service Issue] ${capDialect}**: ${fetchOrVendorError}${devMessage}`, {
      status: statusCode,
    });
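The hunk above replaces the inline per-dialect switch with a single call to _prepareRequestData (added at the bottom of this file), so payload-construction errors surface as an HTTP 422 before the upstream fetch is attempted. A minimal sketch of the resulting relay flow; everything here except _prepareRequestData and its return shape is illustrative scaffolding, not part of this commit:

// Sketch only: relayFlowSketch and its untyped parameters are assumed stand-ins.
async function relayFlowSketch(access: any, model: any, history: any): Promise<Response> {
  // 1) per-dialect setup; throws on invalid payloads (mapped to a 422 above)
  const requestData = _prepareRequestData(access, model, history);
  // 2) one blocking POST to the vendor API (may time out on Vercel Edge, per the comment above)
  const upstream = await fetch(requestData.url, {
    method: 'POST',
    headers: requestData.headers,
    body: JSON.stringify(requestData.body),
  });
  // 3) the body is then piped through the vendor-aware transformer (next hunk)
  return upstream;
}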
@@ -152,8 +124,8 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise<Respon
    * NOTE: we have not benchmarked to see if there is performance impact by using this approach - we do want to have
    * a 'healthy' level of inventory (i.e., pre-buffering) on the pipe to the client.
    */
-  const transformUpstreamToBigAgiClient = createEventStreamTransformer(
-    muxingFormat, vendorStreamParser, access.dialect,
+  const transformUpstreamToBigAgiClient = createUpstreamTransformer(
+    requestData.vendorMuxingFormat, requestData.vendorStreamParser, access.dialect,
   );
 
   const chatResponseStream =
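For context, the renamed createUpstreamTransformer returns a standard web TransformStream<Uint8Array, Uint8Array>, so the relay can pipe the upstream body straight through to the client while preserving back-pressure. A hedged sketch of that piping step, with illustrative variable names:

// Sketch: assumes upstreamResponse and requestData as set up above.
const transformer = createUpstreamTransformer(
  requestData.vendorMuxingFormat, requestData.vendorStreamParser, access.dialect,
);
const clientStream = upstreamResponse.body!.pipeThrough(transformer);
// Returning `new Response(clientStream)` keeps back-pressure end-to-end,
// at the cost of the Edge timeout noted in the [MAY TIMEOUT] comment.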
@@ -171,11 +143,16 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise<Respon
 
 // Event Stream Transformers
 
+/**
+ * The default demuxer for EventSource upstreams.
+ */
+const _createDemuxerEventSource: (onParse: EventSourceParseCallback) => EventSourceParser = createEventsourceParser;
+
 /**
  * Creates a parser for a 'JSON\n' non-event stream, to be swapped with an EventSource parser.
  * Ollama is the only vendor that uses this format.
  */
-function createDemuxerJsonNewline(onParse: EventSourceParseCallback): EventSourceParser {
+function _createDemuxerJsonNewline(onParse: EventSourceParseCallback): EventSourceParser {
   let accumulator: string = '';
   return {
     // feeds a new chunk to the parser - we accumulate in case of partial data, and only execute on full lines
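The JSON-newline demuxer buffers partial chunks and only emits complete lines. A standalone, runnable sketch of that accumulate-and-split behavior (simplified assumption: plain callbacks instead of the EventSourceParser interface used above):

// Accumulates chunks and invokes onLine once per complete '\n'-terminated line.
function makeLineSplitter(onLine: (line: string) => void) {
  let accumulator = '';
  return (chunk: string) => {
    accumulator += chunk;
    const lines = accumulator.split('\n');
    accumulator = lines.pop() ?? ''; // keep the trailing partial line buffered
    for (const line of lines)
      if (line.length) onLine(line);
  };
}

// usage: feed arbitrary chunk boundaries, get whole JSON lines out
const feed = makeLineSplitter((line) => console.log(JSON.parse(line)));
feed('{"a":1}\n{"b"');
feed(':2}\n'); // logs {a:1} then {b:2}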
@@ -206,7 +183,7 @@ function createDemuxerJsonNewline(onParse: EventSourceParseCallback): EventSourc
 /**
  * Creates a TransformStream that parses events from an EventSource stream using a custom parser.
  * @returns {TransformStream<Uint8Array, Uint8Array>} TransformStream parsing events.
  */
-function createEventStreamTransformer(muxingFormat: MuxingFormat, vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream<Uint8Array, Uint8Array> {
+function createUpstreamTransformer(muxingFormat: MuxingFormat, vendorTextParser: AIStreamParser, dialectLabel: string): TransformStream<Uint8Array, Uint8Array> {
   const textDecoder = new TextDecoder();
   const textEncoder = new TextEncoder();
   let eventSourceParser: EventSourceParser;
@@ -255,9 +232,9 @@ function createEventStreamTransformer(muxingFormat: MuxingFormat, vendorTextPars
     };
 
     if (muxingFormat === 'sse')
-      eventSourceParser = createEventsourceParser(onNewEvent);
+      eventSourceParser = _createDemuxerEventSource(onNewEvent);
     else if (muxingFormat === 'json-nl')
-      eventSourceParser = createDemuxerJsonNewline(onNewEvent);
+      eventSourceParser = _createDemuxerJsonNewline(onNewEvent);
   },
 
   // stream=true is set because the data is not guaranteed to be final and un-chunked
@@ -506,4 +483,56 @@ function createStreamParserOpenAI(): AIStreamParser {
     const close = !!json.choices[0].finish_reason;
     return { text, close };
   };
 }
+
+
+function _prepareRequestData(access: ChatStreamingInputSchema['access'], model: OpenAIModelSchema, history: OpenAIHistorySchema): {
+  headers: HeadersInit;
+  url: string;
+  body: object;
+  vendorMuxingFormat: MuxingFormat;
+  vendorStreamParser: AIStreamParser;
+} {
+  switch (access.dialect) {
+    case 'anthropic':
+      return {
+        ...anthropicAccess(access, '/v1/messages'),
+        body: anthropicMessagesPayloadOrThrow(model, history, true),
+        vendorMuxingFormat: 'sse',
+        vendorStreamParser: createStreamParserAnthropicMessages(),
+      };
+
+    case 'gemini':
+      return {
+        ...geminiAccess(access, model.id, geminiModelsStreamGenerateContentPath),
+        body: geminiGenerateContentTextPayload(model, history, access.minSafetyLevel, 1),
+        vendorMuxingFormat: 'sse',
+        vendorStreamParser: createStreamParserGemini(model.id.replace('models/', '')),
+      };
+
+    case 'ollama':
+      return {
+        ...ollamaAccess(access, OLLAMA_PATH_CHAT),
+        body: ollamaChatCompletionPayload(model, history, true),
+        vendorMuxingFormat: 'json-nl',
+        vendorStreamParser: createStreamParserOllama(),
+      };
+
+    case 'azure':
+    case 'groq':
+    case 'lmstudio':
+    case 'localai':
+    case 'mistral':
+    case 'oobabooga':
+    case 'openai':
+    case 'openrouter':
+    case 'perplexity':
+    case 'togetherai':
+      return {
+        ...openAIAccess(access, model.id, '/v1/chat/completions'),
+        body: openAIChatCompletionPayload(access.dialect, model, history, null, null, 1, true),
+        vendorMuxingFormat: 'sse',
+        vendorStreamParser: createStreamParserOpenAI(),
+      };
+  }
+}
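One hardening worth noting, as an assumption rather than part of this commit: _prepareRequestData only returns from its switch cases, so adding a new dialect to the union without a matching case could (depending on compiler flags such as noImplicitReturns) fall through and yield undefined at runtime. A hypothetical never-typed exhaustiveness guard turns that into a compile error:

// Assumption: hypothetical hardening, not in the commit.
function assertNever(x: never): never {
  throw new Error(`Unreachable dialect: ${JSON.stringify(x)}`);
}
// ...appended inside _prepareRequestData's switch:
//   default:
//     return assertNever(access.dialect);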
