From 5ba9e2e9b169dd6b175537e7f1825995cefae44c Mon Sep 17 00:00:00 2001 From: darkskygit Date: Mon, 27 May 2024 09:57:39 +0000 Subject: [PATCH] fix: choose provider correctly (#7081) fix no provider error in caption generate action --- .../migration.sql | 8 +++ packages/backend/server/schema.prisma | 2 +- .../server/src/plugins/copilot/controller.ts | 66 +++++++++++++------ .../server/src/plugins/copilot/prompt.ts | 4 +- 4 files changed, 57 insertions(+), 23 deletions(-) create mode 100644 packages/backend/server/migrations/20240527095524_fix_prompt_schema/migration.sql diff --git a/packages/backend/server/migrations/20240527095524_fix_prompt_schema/migration.sql b/packages/backend/server/migrations/20240527095524_fix_prompt_schema/migration.sql new file mode 100644 index 0000000000000..72673a92b3744 --- /dev/null +++ b/packages/backend/server/migrations/20240527095524_fix_prompt_schema/migration.sql @@ -0,0 +1,8 @@ +/* + Warnings: + + - Made the column `model` on table `ai_prompts_metadata` required. This step will fail if there are existing NULL values in that column. + +*/ +-- AlterTable +ALTER TABLE "ai_prompts_metadata" ALTER COLUMN "model" SET NOT NULL; diff --git a/packages/backend/server/schema.prisma b/packages/backend/server/schema.prisma index 3312f795ec81c..0f1e2801ce0fd 100644 --- a/packages/backend/server/schema.prisma +++ b/packages/backend/server/schema.prisma @@ -455,7 +455,7 @@ model AiPrompt { // an mark identifying which view to use to display the session // it is only used in the frontend and does not affect the backend action String? @db.VarChar - model String? @db.VarChar + model String @db.VarChar createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) messages AiPromptMessage[] diff --git a/packages/backend/server/src/plugins/copilot/controller.ts b/packages/backend/server/src/plugins/copilot/controller.ts index 097fc28d79dd3..0b6ec84dc7cc7 100644 --- a/packages/backend/server/src/plugins/copilot/controller.ts +++ b/packages/backend/server/src/plugins/copilot/controller.ts @@ -34,7 +34,11 @@ import { Config } from '../../fundamentals'; import { CopilotProviderService } from './providers'; import { ChatSession, ChatSessionService } from './session'; import { CopilotStorage } from './storage'; -import { CopilotCapability } from './types'; +import { + CopilotCapability, + CopilotImageToTextProvider, + CopilotTextToTextProvider, +} from './types'; export interface ChatEvent { type: 'attachment' | 'message' | 'error'; @@ -71,7 +75,7 @@ export class CopilotController { const ret: CheckResult = { model: session.model }; - if (messageId) { + if (messageId && typeof messageId === 'string') { const message = await session.getMessageById(messageId); ret.hasAttachment = Array.isArray(message.attachments) && !!message.attachments.length; @@ -80,6 +84,34 @@ export class CopilotController { return ret; } + private async chooseTextProvider( + userId: string, + sessionId: string, + messageId?: string + ): Promise { + const { hasAttachment, model } = await this.checkRequest( + userId, + sessionId, + messageId + ); + let provider = await this.provider.getProviderByCapability( + CopilotCapability.TextToText, + model + ); + // fallback to image to text if text to text is not available + if (!provider && hasAttachment) { + provider = await this.provider.getProviderByCapability( + CopilotCapability.ImageToText, + model + ); + } + if (!provider) { + throw new InternalServerErrorException('No provider available'); + } + + return provider; + } + private async appendSessionMessage( sessionId: string, messageId?: string @@ -139,18 +171,15 @@ export class CopilotController { @Param('sessionId') sessionId: string, @Query() params: Record ): Promise { - const { model } = await this.checkRequest(user.id, sessionId); - const provider = await this.provider.getProviderByCapability( - CopilotCapability.TextToText, - model - ); - if (!provider) { - throw new InternalServerErrorException('No provider available'); - } - const messageId = Array.isArray(params.messageId) ? params.messageId[0] : params.messageId; + const provider = await this.chooseTextProvider( + user.id, + sessionId, + messageId + ); + const session = await this.appendSessionMessage(sessionId, messageId); try { @@ -187,18 +216,15 @@ export class CopilotController { @Query() params: Record ): Promise> { try { - const { model } = await this.checkRequest(user.id, sessionId); - const provider = await this.provider.getProviderByCapability( - CopilotCapability.TextToText, - model - ); - if (!provider) { - throw new InternalServerErrorException('No provider available'); - } - const messageId = Array.isArray(params.messageId) ? params.messageId[0] : params.messageId; + const provider = await this.chooseTextProvider( + user.id, + sessionId, + messageId + ); + const session = await this.appendSessionMessage(sessionId, messageId); delete params.messageId; diff --git a/packages/backend/server/src/plugins/copilot/prompt.ts b/packages/backend/server/src/plugins/copilot/prompt.ts index a513b0ae9bbab..d23ece0ca334a 100644 --- a/packages/backend/server/src/plugins/copilot/prompt.ts +++ b/packages/backend/server/src/plugins/copilot/prompt.ts @@ -42,7 +42,7 @@ export class ChatPrompt { return new ChatPrompt( options.name, options.action || undefined, - options.model || undefined, + options.model, options.messages ); } @@ -50,7 +50,7 @@ export class ChatPrompt { constructor( public readonly name: string, public readonly action: string | undefined, - public readonly model: string | undefined, + public readonly model: string, private readonly messages: PromptMessage[] ) { this.encoder = getTokenEncoder(model);