Skip to content

Commit

Permalink
Highlight issues with chatGenerateWithFunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
enricoros committed Jun 7, 2024
1 parent 1744b5b commit 3c8fedc
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions src/modules/llms/server/openai/openai.router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ export const llmOpenAIRouter = createTRPCRouter({

// [Azure]: use an older 'deployments' API to enumerate the models, and a modified OpenAI id to description mapping
if (access.dialect === 'azure') {
const azureModels = await openaiGET(access, `/openai/deployments?api-version=2023-03-15-preview`);
const azureModels = await openaiGETOrThrow(access, `/openai/deployments?api-version=2023-03-15-preview`);

const wireAzureListDeploymentsSchema = z.object({
data: z.array(z.object({
Expand Down Expand Up @@ -146,7 +146,7 @@ export const llmOpenAIRouter = createTRPCRouter({


// [non-Azure]: fetch openAI-style for all but Azure (will be then used in each dialect)
const openAIWireModelsResponse = await openaiGET<OpenAIWire.Models.Response>(access, '/v1/models');
const openAIWireModelsResponse = await openaiGETOrThrow<OpenAIWire.Models.Response>(access, '/v1/models');

// [Together] missing the .data property
if (access.dialect === 'togetherai')
Expand Down Expand Up @@ -271,13 +271,18 @@ export const llmOpenAIRouter = createTRPCRouter({
const isFunctionsCall = !!functions && functions.length > 0;

const completionsBody = openAIChatCompletionPayload(access.dialect, model, history, isFunctionsCall ? functions : null, forceFunctionName ?? null, 1, false);
const wireCompletions = await openaiPOST<OpenAIWire.ChatCompletion.Response, OpenAIWire.ChatCompletion.Request>(
const wireCompletions = await openaiPOSTOrThrow<OpenAIWire.ChatCompletion.Response, OpenAIWire.ChatCompletion.Request>(
access, model.id, completionsBody, '/v1/chat/completions',
);

// expect a single output
if (wireCompletions?.choices?.length !== 1)
throw new TRPCError({ code: 'INTERNAL_SERVER_ERROR', message: `[OpenAI Issue] Expected 1 completion, got ${wireCompletions?.choices?.length}` });
if (wireCompletions?.choices?.length !== 1) {
console.error(`[POST] llmOpenAI.chatGenerateWithFunctions: ${access.dialect}: unexpected output${forceFunctionName ? ` (fn: ${forceFunctionName})` : ''}:`, wireCompletions?.choices?.length);
throw new TRPCError({
code: 'UNPROCESSABLE_CONTENT',
message: `[OpenAI Issue] Expected 1 completion, got ${wireCompletions?.choices?.length}`,
});
}
let { message, finish_reason } = wireCompletions.choices[0];

// LocalAI hack/workaround, until https://github.com/go-skynet/LocalAI/issues/788 is fixed
Expand Down Expand Up @@ -318,7 +323,7 @@ export const llmOpenAIRouter = createTRPCRouter({
delete requestBody.response_format;

// create 1 image (dall-e-3 won't support more than 1, so better transfer the burden to the client)
const wireOpenAICreateImageOutput = await openaiPOST<WireOpenAICreateImageOutput, WireOpenAICreateImageRequest>(
const wireOpenAICreateImageOutput = await openaiPOSTOrThrow<WireOpenAICreateImageOutput, WireOpenAICreateImageRequest>(
access, null, requestBody, '/v1/images/generations',
);

Expand All @@ -340,7 +345,7 @@ export const llmOpenAIRouter = createTRPCRouter({
.mutation(async ({ input: { access, text } }): Promise<OpenAIWire.Moderation.Response> => {
try {

return await openaiPOST<OpenAIWire.Moderation.Response, OpenAIWire.Moderation.Request>(access, null, {
return await openaiPOSTOrThrow<OpenAIWire.Moderation.Response, OpenAIWire.Moderation.Request>(access, null, {
input: text,
model: 'text-moderation-latest',
}, '/v1/moderations');
Expand All @@ -361,7 +366,7 @@ export const llmOpenAIRouter = createTRPCRouter({
dialectLocalAI_galleryModelsAvailable: publicProcedure
.input(listModelsInputSchema)
.query(async ({ input: { access } }) => {
const wireLocalAIModelsAvailable = await openaiGET(access, '/models/available');
const wireLocalAIModelsAvailable = await openaiGETOrThrow(access, '/models/available');
return wireLocalAIModelsAvailableOutputSchema.parse(wireLocalAIModelsAvailable);
}),

Expand All @@ -374,7 +379,7 @@ export const llmOpenAIRouter = createTRPCRouter({
}))
.mutation(async ({ input: { access, galleryName, modelName } }) => {
const galleryModelId = `${galleryName}@${modelName}`;
const wireLocalAIModelApply = await openaiPOST(access, null, { id: galleryModelId }, '/models/apply');
const wireLocalAIModelApply = await openaiPOSTOrThrow(access, null, { id: galleryModelId }, '/models/apply');
return wilreLocalAIModelsApplyOutputSchema.parse(wireLocalAIModelApply);
}),

Expand All @@ -385,7 +390,7 @@ export const llmOpenAIRouter = createTRPCRouter({
jobId: z.string(),
}))
.query(async ({ input: { access, jobId } }) => {
const wireLocalAIModelsJobs = await openaiGET(access, `/models/jobs/${jobId}`);
const wireLocalAIModelsJobs = await openaiGETOrThrow(access, `/models/jobs/${jobId}`);
return wireLocalAIModelsListOutputSchema.parse(wireLocalAIModelsJobs);
}),

Expand Down Expand Up @@ -623,12 +628,12 @@ export function openAIChatCompletionPayload(dialect: OpenAIDialects, model: Open
};
}

async function openaiGET<TOut extends object>(access: OpenAIAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
/**
 * Performs a GET request against an OpenAI-compatible endpoint for the given
 * access configuration (dialect-aware URL and auth headers via openAIAccess).
 * On transport or non-2xx failures it throws (a TRPCError, per fetchJsonOrTRPCError)
 * instead of returning an error value — hence the "OrThrow" suffix.
 */
async function openaiGETOrThrow<TOut extends object>(access: OpenAIAccessSchema, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
  const destination = openAIAccess(access, null, apiPath);
  // the last argument labels error messages with the active dialect for easier debugging
  return await fetchJsonOrTRPCError<TOut>(destination.url, 'GET', destination.headers, undefined, `OpenAI/${access.dialect}`);
}

async function openaiPOST<TOut extends object, TPostBody extends object>(access: OpenAIAccessSchema, modelRefId: string | null, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
/**
 * Performs a POST of a JSON body to an OpenAI-compatible endpoint for the given
 * access configuration; modelRefId (may be null) participates in dialect-specific
 * URL/header resolution inside openAIAccess. On transport or non-2xx failures it
 * throws (a TRPCError, per fetchJsonOrTRPCError) instead of returning an error value.
 */
async function openaiPOSTOrThrow<TOut extends object, TPostBody extends object>(access: OpenAIAccessSchema, modelRefId: string | null, body: TPostBody, apiPath: string /*, signal?: AbortSignal*/): Promise<TOut> {
  const destination = openAIAccess(access, modelRefId, apiPath);
  // the last argument labels error messages with the active dialect for easier debugging
  return await fetchJsonOrTRPCError<TOut, TPostBody>(destination.url, 'POST', destination.headers, body, `OpenAI/${access.dialect}`);
}
Expand Down

0 comments on commit 3c8fedc

Please sign in to comment.