Skip to content

Commit

Permalink
fix: choose provider correctly (#7081)
Browse files Browse the repository at this point in the history
fix no provider error in caption generate action
  • Loading branch information
darkskygit committed May 27, 2024
1 parent 50dcce8 commit e1bece4
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 22 deletions.
66 changes: 46 additions & 20 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ import { Config } from '../../fundamentals';
import { CopilotProviderService } from './providers';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { CopilotCapability } from './types';
import {
CopilotCapability,
CopilotImageToTextProvider,
CopilotTextToTextProvider,
} from './types';

export interface ChatEvent {
type: 'attachment' | 'message' | 'error';
Expand Down Expand Up @@ -71,7 +75,7 @@ export class CopilotController {

const ret: CheckResult = { model: session.model };

if (messageId) {
if (messageId && typeof messageId === 'string') {
const message = await session.getMessageById(messageId);
ret.hasAttachment =
Array.isArray(message.attachments) && !!message.attachments.length;
Expand All @@ -80,6 +84,34 @@ export class CopilotController {
return ret;
}

private async chooseTextProvider(
userId: string,
sessionId: string,
messageId?: string
): Promise<CopilotTextToTextProvider | CopilotImageToTextProvider> {
const { hasAttachment, model } = await this.checkRequest(
userId,
sessionId,
messageId
);
let provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
// fallback to image to text if text to text is not available
if (!provider && hasAttachment) {
provider = await this.provider.getProviderByCapability(
CopilotCapability.ImageToText,
model
);
}
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

return provider;
}

private async appendSessionMessage(
sessionId: string,
messageId?: string
Expand Down Expand Up @@ -139,18 +171,15 @@ export class CopilotController {
@Param('sessionId') sessionId: string,
@Query() params: Record<string, string | string[]>
): Promise<string> {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const provider = await this.chooseTextProvider(
user.id,
sessionId,
messageId
);

const session = await this.appendSessionMessage(sessionId, messageId);

try {
Expand Down Expand Up @@ -187,18 +216,15 @@ export class CopilotController {
@Query() params: Record<string, string>
): Promise<Observable<ChatEvent>> {
try {
const { model } = await this.checkRequest(user.id, sessionId);
const provider = await this.provider.getProviderByCapability(
CopilotCapability.TextToText,
model
);
if (!provider) {
throw new InternalServerErrorException('No provider available');
}

const messageId = Array.isArray(params.messageId)
? params.messageId[0]
: params.messageId;
const provider = await this.chooseTextProvider(
user.id,
sessionId,
messageId
);

const session = await this.appendSessionMessage(sessionId, messageId);
delete params.messageId;

Expand Down
7 changes: 5 additions & 2 deletions packages/backend/server/src/plugins/copilot/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,21 @@ export class ChatPrompt {
messages: PromptMessage[];
}
) {
if (!options.model) {
throw new Error('Model field is required');
}
return new ChatPrompt(
options.name,
options.action || undefined,
options.model || undefined,
options.model,
options.messages
);
}

constructor(
public readonly name: string,
public readonly action: string | undefined,
public readonly model: string | undefined,
public readonly model: string,
private readonly messages: PromptMessage[]
) {
this.encoder = getTokenEncoder(model);
Expand Down

0 comments on commit e1bece4

Please sign in to comment.