Skip to content

Commit

Permalink
feat(playground): introducing system prompt banner (#684)
Browse files Browse the repository at this point in the history
* feat: introducing system banner

Signed-off-by: axel7083 <[email protected]>

* fix: prettier&linter

Signed-off-by: axel7083 <[email protected]>

* test: ensuring system prompt proper behaviour

Signed-off-by: axel7083 <[email protected]>

* test: ensuring system prompt banner works as expected

Signed-off-by: axel7083 <[email protected]>

* fix: making clear action delete system prompt

Signed-off-by: axel7083 <[email protected]>

* fix: prettier

Signed-off-by: axel7083 <[email protected]>

* fix: side panel alignment

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 authored Mar 26, 2024
1 parent 4616acd commit 7fa0e15
Show file tree
Hide file tree
Showing 9 changed files with 471 additions and 62 deletions.
148 changes: 113 additions & 35 deletions packages/backend/src/managers/playgroundV2Manager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

import { expect, test, vi, beforeEach, afterEach } from 'vitest';
import { expect, test, vi, beforeEach, afterEach, describe } from 'vitest';
import OpenAI from 'openai';
import { PlaygroundV2Manager } from './playgroundV2Manager';
import type { Webview } from '@podman-desktop/api';
Expand Down Expand Up @@ -89,7 +89,7 @@ test('submit should throw an error if the server is stopped', async () => {
} as unknown as InferenceServer,
]);

await expect(manager.submit('0', 'dummyUserInput', '')).rejects.toThrowError('Inference server is not running.');
await expect(manager.submit('0', 'dummyUserInput')).rejects.toThrowError('Inference server is not running.');
});

test('submit should throw an error if the server is unhealthy', async () => {
Expand All @@ -109,7 +109,7 @@ test('submit should throw an error if the server is unhealthy', async () => {
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, '', 'tracking-1');
const playgroundId = manager.getPlaygrounds()[0].id;
await expect(manager.submit(playgroundId, 'dummyUserInput', '')).rejects.toThrowError(
await expect(manager.submit(playgroundId, 'dummyUserInput')).rejects.toThrowError(
'Inference server is not healthy, currently status: unhealthy.',
);
});
Expand Down Expand Up @@ -139,36 +139,6 @@ test('create playground should create conversation.', async () => {
expect(conversations.length).toBe(1);
});

test('create playground called with a system prompt should create conversation with a system message.', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
expect(manager.getConversations().length).toBe(0);
await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'a system prompt', 'tracking-1');

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
const conversation = conversations[0];
expect(conversation.messages).toHaveLength(1);
const systemMessage = conversation.messages[0];
expect(systemMessage.role).toEqual('system');
expect(systemMessage.content).toEqual('a system prompt');
});

test('valid submit should create IPlaygroundMessage and notify the webview', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
Expand Down Expand Up @@ -205,7 +175,7 @@ test('valid submit should create IPlaygroundMessage and notify the webview', asy
vi.setSystemTime(date);

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', '');
await manager.submit(playgrounds[0].id, 'dummyUserInput');

// Wait for assistant message to be completed
await vi.waitFor(() => {
Expand Down Expand Up @@ -271,7 +241,7 @@ test('submit should send options', async () => {
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, undefined, 'tracking-1');

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput', '', { temperature: 0.123, max_tokens: 45, top_p: 0.345 });
await manager.submit(playgrounds[0].id, 'dummyUserInput', { temperature: 0.123, max_tokens: 45, top_p: 0.345 });

const messages: unknown[] = [
{
Expand Down Expand Up @@ -514,3 +484,111 @@ test('requestCreatePlayground should call createPlayground and createTask, then
state: 'error',
});
});

describe('system prompt', () => {
test('create playground with system prompt should init the conversation with one message', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
models: [
{
id: 'model1',
},
],
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, 'dummySystemPrompt', 'tracking-1');

const conversations = manager.getConversations();
expect(conversations.length).toBe(1);
expect(conversations[0].messages.length).toBe(1);
expect(conversations[0].messages[0].role).toBe('system');
expect(conversations[0].messages[0].content).toBe('dummySystemPrompt');
});

test('set system prompt on non existing conversation should throw an error', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
models: [
{
id: 'model1',
},
],
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);

expect(() => {
manager.setSystemPrompt('invalid', 'content');
}).toThrowError('Conversation with id invalid does not exists.');
});

test('set system prompt should overwrite existing system prompt', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
models: [
{
id: 'model1',
},
],
} as unknown as InferenceServer,
]);
const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, 'dummySystemPrompt', 'tracking-1');

const conversations = manager.getConversations();
manager.setSystemPrompt(conversations[0].id, 'newSystemPrompt');
expect(manager.getConversations()[0].messages[0].content).toBe('newSystemPrompt');
});

test('set system prompt should throw an error if user already submit message', async () => {
vi.mocked(inferenceManagerMock.getServers).mockReturnValue([
{
status: 'running',
health: {
Status: 'healthy',
},
models: [
{
id: 'dummyModelId',
file: {
file: 'dummyModelFile',
},
},
],
connection: {
port: 8888,
},
} as unknown as InferenceServer,
]);
const createMock = vi.fn().mockResolvedValue([]);
vi.mocked(OpenAI).mockReturnValue({
chat: {
completions: {
create: createMock,
},
},
} as unknown as OpenAI);

const manager = new PlaygroundV2Manager(webviewMock, inferenceManagerMock, taskRegistryMock);
await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, undefined, 'tracking-1');

const date = new Date(2000, 1, 1, 13);
vi.setSystemTime(date);

const playgrounds = manager.getPlaygrounds();
await manager.submit(playgrounds[0].id, 'dummyUserInput');

// Wait for assistant message to be completed
await vi.waitFor(() => {
expect(manager.getConversations()[0].messages[1].content).toBeDefined();
});

expect(() => {
manager.setSystemPrompt(manager.getConversations()[0].id, 'newSystemPrompt');
}).toThrowError('Cannot change system prompt on started conversation.');
});
});
42 changes: 35 additions & 7 deletions packages/backend/src/managers/playgroundV2Manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import type { ModelOptions } from '@shared/src/models/IModelOptions';
import type { Stream } from 'openai/streaming';
import { ConversationRegistry } from '../registries/conversationRegistry';
import type { Conversation, PendingChat, SystemPrompt, UserChat } from '@shared/src/models/IPlaygroundMessage';
import { isSystemPrompt } from '@shared/src/models/IPlaygroundMessage';
import type { PlaygroundV2 } from '@shared/src/models/IPlaygroundV2';
import { Publisher } from '../utils/Publisher';
import { Messages } from '@shared/Messages';
Expand Down Expand Up @@ -108,7 +109,7 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di

this.#conversationRegistry.createConversation(id);

if (systemPrompt) {
if (systemPrompt !== undefined && systemPrompt.length > 0) {
this.#conversationRegistry.submit(id, {
content: systemPrompt,
role: 'system',
Expand Down Expand Up @@ -149,12 +150,43 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
return `playground-${++this.#UIDcounter}`;
}

/**
* Given a conversation, update the system prompt.
* If none exists, it will create one, otherwise it will replace the content with the new one
* @param conversationId the conversation id to set the system id
* @param content the new system prompt to use
*/
setSystemPrompt(conversationId: string, content: string | undefined): void {
const conversation = this.#conversationRegistry.get(conversationId);
if (conversation === undefined) throw new Error(`Conversation with id ${conversationId} does not exists.`);

if (conversation.messages.length === 0) {
this.#conversationRegistry.submit(conversationId, {
role: 'system',
content,
timestamp: Date.now(),
id: this.getUniqueId(),
} as SystemPrompt);
this.notify();
} else if (conversation.messages.length === 1 && isSystemPrompt(conversation.messages[0])) {
if (content !== undefined && content.length > 0) {
this.#conversationRegistry.update(conversationId, conversation.messages[0].id, {
content,
});
} else {
this.#conversationRegistry.removeMessage(conversationId, conversation.messages[0].id);
}
} else {
throw new Error('Cannot change system prompt on started conversation.');
}
}

/**
* @param playgroundId
* @param userInput the user input
* @param options the model configuration
*/
async submit(playgroundId: string, userInput: string, systemPrompt: string, options?: ModelOptions): Promise<void> {
async submit(playgroundId: string, userInput: string, options?: ModelOptions): Promise<void> {
const playground = this.#playgrounds.get(playgroundId);
if (playground === undefined) throw new Error('Playground not found.');

Expand Down Expand Up @@ -189,13 +221,9 @@ export class PlaygroundV2Manager extends Publisher<PlaygroundV2[]> implements Di
apiKey: 'dummy',
});

const messages = this.getFormattedMessages(playground.id);
if (systemPrompt) {
messages.push({ role: 'system', content: systemPrompt });
}
client.chat.completions
.create({
messages,
messages: this.getFormattedMessages(playground.id),
stream: true,
model: modelInfo.file.file,
...options,
Expand Down
16 changes: 16 additions & 0 deletions packages/backend/src/registries/conversationRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ export class ConversationRegistry extends Publisher<Conversation[]> implements D
return `conversation-${++this.#counter}`;
}

/**
* Remove a message from a conversation
* @param conversationId
* @param messageId
*/
removeMessage(conversationId: string, messageId: string) {
const conversation = this.#conversations.get(conversationId);

if (conversation === undefined) {
throw new Error(`conversation with id ${conversationId} does not exist.`);
}

conversation.messages = conversation.messages.filter(message => message.id !== messageId);
this.notify();
}

/**
* Utility method to update a message content in a given conversation
* @param conversationId
Expand Down
13 changes: 6 additions & 7 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,12 @@ export class StudioApiImpl implements StudioAPI {
return this.playgroundV2.getPlaygrounds();
}

submitPlaygroundMessage(
containerId: string,
userInput: string,
systemPrompt: string,
options?: ModelOptions,
): Promise<void> {
return this.playgroundV2.submit(containerId, userInput, systemPrompt, options);
submitPlaygroundMessage(containerId: string, userInput: string, options?: ModelOptions): Promise<void> {
return this.playgroundV2.submit(containerId, userInput, options);
}

async setPlaygroundSystemPrompt(conversationId: string, content: string | undefined): Promise<void> {
this.playgroundV2.setSystemPrompt(conversationId, content);
}

async getPlaygroundConversations(): Promise<Conversation[]> {
Expand Down
6 changes: 3 additions & 3 deletions packages/frontend/src/lib/ContentDetailsLayout.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ const toggle = () => {
<slot name="content" />
</div>
<div class="inline-grid max-lg:order-first">
<div class="lg:my-5 max-lg:w-full max-lg:min-w-full" class:w-[375px]="{open}" class:min-w-[375px]="{open}">
<div class="max-lg:w-full max-lg:min-w-full" class:w-[375px]="{open}" class:min-w-[375px]="{open}">
<div
class:hidden="{!open}"
class:block="{open}"
class="h-fit lg:bg-charcoal-800 lg:rounded-l-md lg:mt-4 lg:py-4 max-lg:block"
class="h-fit lg:bg-charcoal-800 lg:rounded-l-md lg:mt-5 lg:py-4 max-lg:block"
aria-label="{`${detailsLabel} panel`}">
<div class="flex flex-col px-4 space-y-4 mx-auto">
<div class="w-full flex flex-row justify-between max-lg:hidden">
Expand All @@ -31,7 +31,7 @@ const toggle = () => {
<div
class:hidden="{open}"
class:block="{!open}"
class="bg-charcoal-800 mt-4 p-4 rounded-md h-fit max-lg:hidden"
class="bg-charcoal-800 mt-5 p-4 rounded-md h-fit max-lg:hidden"
aria-label="{`toggle ${detailsLabel}`}">
<button on:click="{toggle}" aria-label="{`show ${detailsLabel}`}"
><i class="fas fa-angle-left text-gray-900"></i></button>
Expand Down
Loading

0 comments on commit 7fa0e15

Please sign in to comment.