From 3c8fedce68aec3c6e5788f82522377510413d6a2 Mon Sep 17 00:00:00 2001 From: Enrico Ros Date: Fri, 7 Jun 2024 12:38:21 -0700 Subject: [PATCH] Highlight issues with chatGenerateWithFunctions --- .../llms/server/openai/openai.router.ts | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/modules/llms/server/openai/openai.router.ts b/src/modules/llms/server/openai/openai.router.ts index fedc124a3..76cb96392 100644 --- a/src/modules/llms/server/openai/openai.router.ts +++ b/src/modules/llms/server/openai/openai.router.ts @@ -108,7 +108,7 @@ export const llmOpenAIRouter = createTRPCRouter({ // [Azure]: use an older 'deployments' API to enumerate the models, and a modified OpenAI id to description mapping if (access.dialect === 'azure') { - const azureModels = await openaiGET(access, `/openai/deployments?api-version=2023-03-15-preview`); + const azureModels = await openaiGETOrThrow(access, `/openai/deployments?api-version=2023-03-15-preview`); const wireAzureListDeploymentsSchema = z.object({ data: z.array(z.object({ @@ -146,7 +146,7 @@ export const llmOpenAIRouter = createTRPCRouter({ // [non-Azure]: fetch openAI-style for all but Azure (will be then used in each dialect) - const openAIWireModelsResponse = await openaiGET(access, '/v1/models'); + const openAIWireModelsResponse = await openaiGETOrThrow(access, '/v1/models'); // [Together] missing the .data property if (access.dialect === 'togetherai') @@ -271,13 +271,18 @@ export const llmOpenAIRouter = createTRPCRouter({ const isFunctionsCall = !!functions && functions.length > 0; const completionsBody = openAIChatCompletionPayload(access.dialect, model, history, isFunctionsCall ? functions : null, forceFunctionName ?? null, 1, false); - const wireCompletions = await openaiPOST( + const wireCompletions = await openaiPOSTOrThrow( access, model.id, completionsBody, '/v1/chat/completions', ); // expect a single output - if (wireCompletions?.choices?.length !== 1) - throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] Expected 1 completion, got ${wireCompletions?.choices?.length}` }); + if (wireCompletions?.choices?.length !== 1) { + console.error(`[POST] llmOpenAI.chatGenerateWithFunctions: ${access.dialect}: unexpected output${forceFunctionName ? ` (fn: ${forceFunctionName})` : ''}:`, wireCompletions?.choices?.length); + throw new TRPCError({ + code: 'UNPROCESSABLE_CONTENT', + message: `[OpenAI Issue] Expected 1 completion, got ${wireCompletions?.choices?.length}`, + }); + } let { message, finish_reason } = wireCompletions.choices[0]; // LocalAI hack/workaround, until https://github.com/go-skynet/LocalAI/issues/788 is fixed @@ -318,7 +323,7 @@ export const llmOpenAIRouter = createTRPCRouter({ delete requestBody.response_format; // create 1 image (dall-e-3 won't support more than 1, so better transfer the burden to the client) - const wireOpenAICreateImageOutput = await openaiPOST( + const wireOpenAICreateImageOutput = await openaiPOSTOrThrow( access, null, requestBody, '/v1/images/generations', ); @@ -340,7 +345,7 @@ export const llmOpenAIRouter = createTRPCRouter({ .mutation(async ({ input: { access, text } }): Promise => { try { - return await openaiPOST(access, null, { + return await openaiPOSTOrThrow(access, null, { input: text, model: 'text-moderation-latest', }, '/v1/moderations'); @@ -361,7 +366,7 @@ export const llmOpenAIRouter = createTRPCRouter({ dialectLocalAI_galleryModelsAvailable: publicProcedure .input(listModelsInputSchema) .query(async ({ input: { access } }) => { - const wireLocalAIModelsAvailable = await openaiGET(access, '/models/available'); + const wireLocalAIModelsAvailable = await openaiGETOrThrow(access, '/models/available'); return wireLocalAIModelsAvailableOutputSchema.parse(wireLocalAIModelsAvailable); }), @@ -374,7 +379,7 @@ export const llmOpenAIRouter = createTRPCRouter({ })) .mutation(async ({ input: { access, galleryName, modelName } }) => { const galleryModelId = `${galleryName}@${modelName}`; - const wireLocalAIModelApply = await openaiPOST(access, null, { id: galleryModelId }, '/models/apply'); + const wireLocalAIModelApply = await openaiPOSTOrThrow(access, null, { id: galleryModelId }, '/models/apply'); return wilreLocalAIModelsApplyOutputSchema.parse(wireLocalAIModelApply); }), @@ -385,7 +390,7 @@ export const llmOpenAIRouter = createTRPCRouter({ jobId: z.string(), })) .query(async ({ input: { access, jobId } }) => { - const wireLocalAIModelsJobs = await openaiGET(access, `/models/jobs/${jobId}`); + const wireLocalAIModelsJobs = await openaiGETOrThrow(access, `/models/jobs/${jobId}`); return wireLocalAIModelsListOutputSchema.parse(wireLocalAIModelsJobs); }), @@ -623,12 +628,12 @@ export function openAIChatCompletionPayload(dialect: OpenAIDialects, model: Open }; } -async function openaiGET(access: OpenAIAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise { +async function openaiGETOrThrow(access: OpenAIAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise { const { headers, url } = openAIAccess(access, null, apiPath); return await fetchJsonOrTRPCError(url, 'GET', headers, undefined, `OpenAI/${access.dialect}`); } -async function openaiPOST(access: OpenAIAccessSchema, modelRefId: string | null, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise { +async function openaiPOSTOrThrow(access: OpenAIAccessSchema, modelRefId: string | null, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise { const { headers, url } = openAIAccess(access, modelRefId, apiPath); return await fetchJsonOrTRPCError(url, 'POST', headers, body, `OpenAI/${access.dialect}`); }