Llms: fix Streaming timeouts
enricoros committed Apr 23, 2024
1 parent cc0ac5a commit 2f8e879
Showing 2 changed files with 53 additions and 20 deletions.
23 changes: 18 additions & 5 deletions src/modules/llms/server/llm.server.streaming.ts
@@ -54,10 +54,18 @@ const chatStreamingInputSchema = z.object({
});
export type ChatStreamingInputSchema = z.infer<typeof chatStreamingInputSchema>;

// the purpose is to send something out even before the upstream stream starts, so that we keep the connection up
const chatStreamingStartOutputPacketSchema = z.object({
type: z.enum(['start']),
});
export type ChatStreamingPreambleStartSchema = z.infer<typeof chatStreamingStartOutputPacketSchema>;

// the purpose is to have a first packet that contains the model name, so that the client can display it
// this is a hack until we have a better streaming format
const chatStreamingFirstOutputPacketSchema = z.object({
model: z.string(),
});
export type ChatStreamingFirstOutputPacketSchema = z.infer<typeof chatStreamingFirstOutputPacketSchema>;
export type ChatStreamingPreambleModelSchema = z.infer<typeof chatStreamingFirstOutputPacketSchema>;


export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
@@ -147,6 +155,7 @@ export async function llmStreamingRelayHandler(req: NextRequest): Promise<Response> {
const transformUpstreamToBigAgiClient = createEventStreamTransformer(
muxingFormat, vendorStreamParser, access.dialect,
);

const chatResponseStream =
(upstreamResponse.body || createEmptyReadableStream())
.pipeThrough(transformUpstreamToBigAgiClient);
@@ -206,6 +215,10 @@ function createEventStreamTransformer(muxingFormat: MuxingFormat, vendorTextPars
return new TransformStream({
start: async (controller): Promise<void> => {

// Send initial packet indicating the start of the stream
const startPacket: ChatStreamingPreambleStartSchema = { type: 'start' };
controller.enqueue(textEncoder.encode(JSON.stringify(startPacket)));

// only used for debugging
let debugLastMs: number | null = null;

@@ -293,7 +306,7 @@ function createStreamParserAnthropicMessages(): AIStreamParser {
responseMessage = anthropicWireMessagesResponseSchema.parse(message);
// hack: prepend the model name to the first packet
if (firstMessage) {
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: responseMessage.model };
const firstPacket: ChatStreamingPreambleModelSchema = { model: responseMessage.model };
text = JSON.stringify(firstPacket);
}
break;
@@ -408,7 +421,7 @@ function createStreamParserGemini(modelName: string): AIStreamParser {
// hack: prepend the model name to the first packet
if (!hasBegun) {
hasBegun = true;
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: modelName };
const firstPacket: ChatStreamingPreambleModelSchema = { model: modelName };
text = JSON.stringify(firstPacket) + text;
}

@@ -444,7 +457,7 @@ function createStreamParserOllama(): AIStreamParser {
// hack: prepend the model name to the first packet
if (!hasBegun && chunk.model) {
hasBegun = true;
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: chunk.model };
const firstPacket: ChatStreamingPreambleModelSchema = { model: chunk.model };
text = JSON.stringify(firstPacket) + text;
}

@@ -485,7 +498,7 @@ function createStreamParserOpenAI(): AIStreamParser {
// hack: prepend the model name to the first packet
if (!hasBegun) {
hasBegun = true;
const firstPacket: ChatStreamingFirstOutputPacketSchema = { model: json.model };
const firstPacket: ChatStreamingPreambleModelSchema = { model: json.model };
text = JSON.stringify(firstPacket) + text;
}

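The server-side change above boils down to emitting a tiny {"type":"start"} packet from the TransformStream's start() callback, before any upstream bytes arrive, so the relayed connection shows activity immediately instead of idling until the model produces its first token. A minimal standalone sketch of that idea follows; the helper name createKeepAlivePreambleStream is invented here for illustration and is not part of the commit.

const textEncoder = new TextEncoder();

// Emits a flat-JSON 'start' packet immediately, then forwards upstream bytes unchanged.
function createKeepAlivePreambleStream(): TransformStream<Uint8Array, Uint8Array> {
  return new TransformStream({
    start(controller) {
      // send something right away so the connection is not considered idle
      controller.enqueue(textEncoder.encode(JSON.stringify({ type: 'start' })));
    },
    transform(chunk, controller) {
      controller.enqueue(chunk); // upstream data passes through untouched
    },
  });
}

// hypothetical usage: upstreamResponse.body.pipeThrough(createKeepAlivePreambleStream())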
50 changes: 35 additions & 15 deletions src/modules/llms/vendors/unifiedStreamingClient.ts
@@ -1,7 +1,7 @@
import { apiAsync } from '~/common/util/trpc.client';
import { frontendSideFetch } from '~/common/util/clientFetchers';

import type { ChatStreamingFirstOutputPacketSchema, ChatStreamingInputSchema } from '../server/llm.server.streaming';
import type { ChatStreamingInputSchema, ChatStreamingPreambleModelSchema, ChatStreamingPreambleStartSchema } from '../server/llm.server.streaming';
import type { DLLMId } from '../store-llms';
import type { VChatFunctionIn, VChatMessageIn } from '../llm.client';

@@ -58,6 +58,7 @@ export async function unifiedStreamingClient<TSourceSetup = unknown, TLLMOptions
};

// connect to the server-side streaming endpoint
const timeFetch = performance.now();
const response = await frontendSideFetch('/api/llms/stream', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
@@ -75,7 +76,8 @@ export async function unifiedStreamingClient<TSourceSetup = unknown, TLLMOptions

// loop forever until the read is done, or the abort controller is triggered
let incrementalText = '';
let parsedFirstPacket = false;
let parsedPreambleStart = false;
let parsedPreableModel = false;
while (true) {
const { value, done } = await responseReader.read();

@@ -88,21 +90,39 @@ export async function unifiedStreamingClient<TSourceSetup = unknown, TLLMOptions

incrementalText += textDecoder.decode(value, { stream: true });

// (streaming workaround) there may be a JSON object at the beginning of the message,
// injected by us to transmit the model name
if (!parsedFirstPacket && incrementalText.startsWith('{')) {
// we have two packets with a serialized flat json object at the start; this is side data, before the text flow starts
while ((!parsedPreambleStart || !parsedPreableModel) && incrementalText.startsWith('{')) {

// extract a complete JSON object, if present
const endOfJson = incrementalText.indexOf('}');
if (endOfJson === -1)
continue;
const json = incrementalText.substring(0, endOfJson + 1);
if (endOfJson === -1) break;
const jsonString = incrementalText.substring(0, endOfJson + 1);
incrementalText = incrementalText.substring(endOfJson + 1);
parsedFirstPacket = true;
try {
const parsed: ChatStreamingFirstOutputPacketSchema = JSON.parse(json);
onUpdate({ originLLM: parsed.model }, false);
} catch (e) {
// error parsing JSON, ignore
console.log('unifiedStreamingClient: error parsing JSON:', e);

// first packet: preamble to let the Vercel edge function go over time
if (!parsedPreambleStart) {
parsedPreambleStart = true;
try {
const parsed: ChatStreamingPreambleStartSchema = JSON.parse(jsonString);
if (parsed.type !== 'start')
console.log('unifiedStreamingClient: unexpected preamble type:', parsed?.type, 'time:', performance.now() - timeFetch);
} catch (e) {
// error parsing JSON, ignore
console.log('unifiedStreamingClient: error parsing start JSON:', e);
}
continue;
}

// second packet: the model name
if (!parsedPreableModel) {
parsedPreableModel = true;
try {
const parsed: ChatStreamingPreambleModelSchema = JSON.parse(jsonString);
onUpdate({ originLLM: parsed.model }, false);
} catch (e) {
// error parsing JSON, ignore
console.log('unifiedStreamingClient: error parsing model JSON:', e);
}
}
}

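On the wire, the relayed stream therefore begins with two flat JSON objects, the {"type":"start"} keep-alive preamble followed by a {"model":"..."} packet, and only then the plain text tokens. Below is a minimal sketch of peeling those preambles off a buffered string, under the same assumption the client code above makes (flat objects, so the first '}' closes a packet); splitPreambles is an illustrative helper, not part of the commit.

// Returns up to two parsed preamble packets plus the remaining plain text.
function splitPreambles(buffered: string): { packets: object[]; rest: string } {
  const packets: object[] = [];
  let rest = buffered;
  while (packets.length < 2 && rest.startsWith('{')) {
    const end = rest.indexOf('}');
    if (end === -1) break; // packet not fully received yet
    try {
      packets.push(JSON.parse(rest.substring(0, end + 1)));
    } catch {
      break; // malformed preamble: treat the remainder as plain text
    }
    rest = rest.substring(end + 1);
  }
  return { packets, rest };
}

// e.g. splitPreambles('{"type":"start"}{"model":"gpt-4"}Hello') yields
// packets [{ type: 'start' }, { model: 'gpt-4' }] and rest 'Hello'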
