Skip to content

Commit

Permalink
feat: add upscaler & bg remover (#6967)
Browse files Browse the repository at this point in the history
  • Loading branch information
darkskygit committed May 16, 2024
1 parent f37bbb0 commit a3f3d09
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 6 deletions.
17 changes: 17 additions & 0 deletions packages/backend/server/src/data/migrations/utils/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,23 @@ export const prompts: Prompt[] = [
model: 'fast-turbo-diffusion',
messages: [],
},
{
name: 'debug:action:fal-upscaler',
action: 'image',
model: 'clarity-upscaler',
messages: [
{
role: 'user',
content: 'best quality, 8K resolution, highres, clarity, {{content}}',
},
],
},
{
name: 'debug:action:fal-remove-bg',
action: 'image',
model: 'imageutils/rembg',
messages: [],
},
{
name: 'Summary',
action: 'Summary',
Expand Down
1 change: 1 addition & 0 deletions packages/backend/server/src/plugins/copilot/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ export class CopilotController {
if (err instanceof HttpException) {
ret.status = err.getStatus();
}
return ret;
}
return err;
}
Expand Down
10 changes: 8 additions & 2 deletions packages/backend/server/src/plugins/copilot/providers/fal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export type FalConfig = {
};

/**
 * Shape of the JSON body returned by the fal endpoint.
 *
 * `detail` carries error information and is either a list of `{ msg }`
 * entries or a plain string, so callers must narrow it (for example with
 * `Array.isArray`) before reading a message out of it.
 * `images` lists the URLs of the generated images.
 */
export type FalResponse = {
  detail: Array<{ msg: string }> | string;
  images: Array<{ url: string }>;
};

Expand All @@ -32,6 +32,8 @@ export class FalProvider
'fast-turbo-diffusion',
// image to image
'lcm-sd15-i2i',
'clarity-upscaler',
'imageutils/rembg',
];

constructor(private readonly config: FalConfig) {
Expand Down Expand Up @@ -87,7 +89,11 @@ export class FalProvider
}).then(res => res.json())) as FalResponse;

if (!data.images?.length) {
const error = data.detail?.[0]?.msg;
const error = Array.isArray(data.detail)
? data.detail[0]?.msg
: typeof data.detail === 'string'
? data.detail
: '';
throw new Error(
error ? `Invalid message: ${error}` : 'No images generated'
);
Expand Down
11 changes: 11 additions & 0 deletions packages/backend/server/src/plugins/copilot/providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@ export function registerCopilotProvider<
});
}

/**
 * Removes a provider type from the module-level provider registries.
 *
 * The type is deleted from `ASSERT_CONFIG` and `COPILOT_PROVIDER`, and its
 * first occurrence is removed from every provider list stored in
 * `PROVIDER_CAPABILITY_MAP`. Lists that do not contain the type are left
 * untouched.
 *
 * @param type - the provider type to unregister
 */
export function unregisterCopilotProvider(type: CopilotProviderType) {
  ASSERT_CONFIG.delete(type);
  COPILOT_PROVIDER.delete(type);

  for (const registeredTypes of PROVIDER_CAPABILITY_MAP.values()) {
    const position = registeredTypes.indexOf(type);
    if (position === -1) {
      // this capability was never served by the provider
      continue;
    }
    // the list is mutated in place so the map entry itself stays the same
    registeredTypes.splice(position, 1);
  }
}

/// Asserts that the config is valid for any registered providers
export function assertProvidersConfigs(config: Config) {
return (
Expand Down
57 changes: 56 additions & 1 deletion packages/backend/server/tests/copilot.e2e.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ import Sinon from 'sinon';

import { AuthService } from '../src/core/auth';
import { WorkspaceModule } from '../src/core/workspaces';
import { prompts } from '../src/data/migrations/utils/prompts';
import { ConfigModule } from '../src/fundamentals/config';
import { CopilotModule } from '../src/plugins/copilot';
import { PromptService } from '../src/plugins/copilot/prompt';
import {
CopilotProviderService,
FalProvider,
OpenAIProvider,
registerCopilotProvider,
unregisterCopilotProvider,
} from '../src/plugins/copilot/providers';
import { CopilotStorage } from '../src/plugins/copilot/storage';
import {
Expand Down Expand Up @@ -80,11 +84,17 @@ test.beforeEach(async t => {
const user = await signUp(app, 'test', '[email protected]', '123456');
token = user.token.token;

unregisterCopilotProvider(OpenAIProvider.type);
unregisterCopilotProvider(FalProvider.type);
registerCopilotProvider(MockCopilotTestProvider);

await prompt.set(promptName, 'test', [
{ role: 'system', content: 'hello {{word}}' },
]);

for (const p of prompts) {
await prompt.set(p.name, p.model, p.messages);
}
});

test.afterEach.always(async t => {
Expand Down Expand Up @@ -218,7 +228,7 @@ test('should be able to chat with api', async t => {
t.is(
ret3,
textToEventStream(
['https://example.com/image.jpg'],
['https://example.com/test.jpg', 'generate text to text stream'],
messageId,
'attachment'
),
Expand All @@ -228,6 +238,51 @@ test('should be able to chat with api', async t => {
Sinon.restore();
});

// Runs an image chat against several image prompts and checks the event
// stream for each one: the first attachment URL is built from the model
// configured for the prompt, and the second attachment is the final prompt
// text that was sent to the provider.
test('should be able to chat with special image model', async t => {
  const { app, storage } = t.context;

  // make `handleRemoteLink` resolve with its third argument
  Sinon.stub(storage, 'handleRemoteLink').resolvesArg(2);

  const workspace = await createWorkspace(app, token);

  // [prompt name, final prompt text expected in the stream]
  const cases: Array<[string, string]> = [
    ['debug:action:fal-sd15', 'some-tag'],
    [
      'debug:action:fal-upscaler',
      'best quality, 8K resolution, highres, clarity, some-tag',
    ],
    ['debug:action:fal-remove-bg', 'some-tag'],
  ];

  // the cases run strictly one after another, each in its own session
  for (const [promptName, finalPrompt] of cases) {
    const model = prompts.find(entry => entry.name === promptName)?.model;

    const sessionId = await createCopilotSession(
      app,
      token,
      workspace.id,
      randomUUID(),
      promptName
    );
    const attachments = [`https://example.com/${promptName}.jpg`];
    const messageId = await createCopilotMessage(
      app,
      token,
      sessionId,
      'some-tag',
      attachments
    );

    const actualStream = await chatWithImages(
      app,
      token,
      sessionId,
      messageId
    );
    const expectedStream = textToEventStream(
      [`https://example.com/${model}.jpg`, finalPrompt],
      messageId,
      'attachment'
    );

    t.is(actualStream, expectedStream, 'should be able to chat with images');
  }

  Sinon.restore();
});

test('should be able to retry with api', async t => {
const { app, storage } = t.context;

Expand Down
13 changes: 10 additions & 3 deletions packages/backend/server/tests/utils/copilot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ export class MockCopilotTestProvider
CopilotImageToImageProvider,
CopilotImageToTextProvider
{
override readonly availableModels = ['test'];
override readonly availableModels = [
'test',
'fast-turbo-diffusion',
'lcm-sd15-i2i',
'clarity-upscaler',
'imageutils/rembg',
];
static override readonly capabilities = [
CopilotCapability.TextToText,
CopilotCapability.TextToEmbedding,
Expand Down Expand Up @@ -107,7 +113,7 @@ export class MockCopilotTestProvider
// ====== text to image ======
override async generateImages(
messages: PromptMessage[],
_model: string = 'test',
model: string = 'test',
_options: {
signal?: AbortSignal;
user?: string;
Expand All @@ -118,7 +124,8 @@ export class MockCopilotTestProvider
throw new Error('Prompt is required');
}

return ['https://example.com/image.jpg'];
// also return the prompt so that test cases can easily verify the final prompt
return [`https://example.com/${model}.jpg`, prompt];
}

override async *generateImagesStream(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ export const promptKeys = [
'debug:action:vision4',
'debug:action:dalle3',
'debug:action:fal-sd15',
'debug:action:fal-upscaler',
'debug:action:fal-rembg',
'chat:gpt4',
'Summary',
'Summary the webpage',
Expand Down

0 comments on commit a3f3d09

Please sign in to comment.