From 80dac02dccc3081eab2ef34d819d07adfc7685d9 Mon Sep 17 00:00:00 2001 From: Philippe Martin Date: Tue, 16 Jan 2024 15:06:06 +0100 Subject: [PATCH] fix: manage playground queries state Signed-off-by: Philippe Martin --- packages/backend/src/playground.ts | 135 +++++++++++----- packages/backend/src/studio-api-impl.ts | 8 +- packages/backend/src/studio.ts | 2 +- packages/backend/src/utils/ports.ts | 66 ++++++++ .../src/pages/ModelPlayground.spec.ts | 146 ++++++++++++++++++ .../frontend/src/pages/ModelPlayground.svelte | 46 +++++- .../frontend/src/stores/playground-queries.ts | 18 +++ packages/frontend/src/utils/client.ts | 2 +- packages/shared/Messages.ts | 1 + packages/shared/models/IRecipe.ts | 11 -- packages/shared/src/MessageProxy.ts | 61 ++++++-- packages/shared/src/StudioAPI.ts | 9 +- .../src/models/IPlaygroundQueryState.ts | 8 + 13 files changed, 442 insertions(+), 71 deletions(-) create mode 100644 packages/backend/src/utils/ports.ts create mode 100644 packages/frontend/src/pages/ModelPlayground.spec.ts create mode 100644 packages/frontend/src/stores/playground-queries.ts create mode 100644 packages/shared/Messages.ts delete mode 100644 packages/shared/models/IRecipe.ts create mode 100644 packages/shared/src/models/IPlaygroundQueryState.ts diff --git a/packages/backend/src/playground.ts b/packages/backend/src/playground.ts index 0669b049d..214c9349d 100644 --- a/packages/backend/src/playground.ts +++ b/packages/backend/src/playground.ts @@ -1,9 +1,18 @@ -import { provider, containerEngine, type ProviderContainerConnection, type ImageInfo } from '@podman-desktop/api'; +import { + provider, + containerEngine, + type Webview, + type ProviderContainerConnection, + type ImageInfo, +} from '@podman-desktop/api'; import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; import type { ModelResponse } from '@shared/src/models/IModelResponse'; import path from 'node:path'; import * as http from 'node:http'; +import { getFreePort } from './utils/ports'; +import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; +import { MSG_NEW_PLAYGROUND_QUERIES_STATE } from '@shared/Messages'; const LOCALAI_IMAGE = 'quay.io/go-skynet/local-ai:v2.5.1'; @@ -15,13 +24,33 @@ function findFirstProvider(): ProviderContainerConnection | undefined { return engines.length > 0 ? engines[0] : undefined; } +export interface PlaygroundState { + containerId: string; + port: number; +} + export class PlayGroundManager { + private queryIdCounter = 0; + + private playgrounds: Map; + private queries: Map; + + constructor(private webview: Webview) { + this.playgrounds = new Map(); + this.queries = new Map(); + } + async selectImage(connection: ProviderContainerConnection, image: string): Promise { const images = (await containerEngine.listImages()).filter(im => im.RepoTags?.some(tag => tag === image)); return images.length > 0 ? images[0] : undefined; } async startPlayground(modelId: string, modelPath: string): Promise { + // TODO(feloy) remove previous query from state? + + if (this.playgrounds.has(modelId)) { + throw new Error('model is already running'); + } const connection = findFirstProvider(); if (!connection) { throw new Error('Unable to find an engine to start playground'); @@ -35,10 +64,11 @@ export class PlayGroundManager { throw new Error(`Unable to find ${LOCALAI_IMAGE} image`); } } + const freePort = await getFreePort(); const result = await containerEngine.createContainer(image.engineId, { Image: image.Id, Detach: true, - ExposedPorts: { '9000': {} }, + ExposedPorts: { ['' + freePort]: {} }, HostConfig: { AutoRemove: true, Mounts: [ @@ -51,13 +81,17 @@ export class PlayGroundManager { PortBindings: { '8080/tcp': [ { - HostPort: '9000', + HostPort: '' + freePort, }, ], }, }, Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'], }); + this.playgrounds.set(modelId, { + containerId: result.id, + port: freePort, + }); return result.id; } @@ -69,42 +103,73 @@ export class PlayGroundManager { return containerEngine.stopContainer(connection.providerId, playgroundId); } - async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise { - return new Promise(resolve => { - const post_data = JSON.stringify({ - model: modelInfo.file, - prompt: prompt, - temperature: 0.7, - }); + async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise { + const state = this.playgrounds.get(modelInfo.id); + if (!state) { + throw new Error('model is not running'); + } - const post_options: http.RequestOptions = { - host: 'localhost', - port: '9000', - path: '/v1/completions', - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - }; - - const post_req = http.request(post_options, function (res) { - res.setEncoding('utf8'); - const chunks = []; - res.on('data', data => chunks.push(data)); - res.on('end', () => { - const resBody = chunks.join(); + const query = { + id: this.getNextQueryId(), + modelId: modelInfo.id, + prompt: prompt, + } as QueryState; + + const post_data = JSON.stringify({ + model: modelInfo.file, + prompt: prompt, + temperature: 0.7, + }); + + const post_options: http.RequestOptions = { + host: 'localhost', + port: '' + state.port, + path: '/v1/completions', + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + }; + + const post_req = http.request(post_options, res => { + res.setEncoding('utf8'); + const chunks = []; + res.on('data', data => chunks.push(data)); + res.on('end', () => { + const resBody = chunks.join(); + if (res.headers['content-type'] === 'application/json') { const result = JSON.parse(resBody); - console.log('result', result); - switch (res.headers['content-type']) { - case 'application/json': - resolve(result as ModelResponse); - break; + const q = this.queries.get(query.id); + if (!q) { + throw new Error('query not found in state'); } - }); + q.response = result as ModelResponse; + this.queries.set(query.id, q); + this.sendState().catch((err: unknown) => { + console.error('playground: unable to send the response to the frontend', err); + }); + } }); - // post the data - post_req.write(post_data); - post_req.end(); + }); + // post the data + post_req.write(post_data); + post_req.end(); + + this.queries.set(query.id, query); + await this.sendState(); + return query.id; + } + + getNextQueryId() { + return ++this.queryIdCounter; + } + getState(): QueryState[] { + return Array.from(this.queries.values()); + } + async sendState() { + await this.webview.postMessage({ + id: MSG_NEW_PLAYGROUND_QUERIES_STATE, + body: this.getState(), }); } } diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 58cfad47a..79e8a576a 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -10,9 +10,9 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo'; import type { TaskRegistry } from './registries/TaskRegistry'; import type { Task } from '@shared/src/models/ITask'; import * as path from 'node:path'; -import type { ModelResponse } from '@shared/src/models/IModelResponse'; import type { PlayGroundManager } from './playground'; import * as podmanDesktopApi from '@podman-desktop/api'; +import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; export const RECENT_CATEGORY_ID = 'recent-category'; @@ -115,11 +115,15 @@ export class StudioApiImpl implements StudioAPI { await this.playgroundManager.startPlayground(modelId, modelPath); } - askPlayground(modelId: string, prompt: string): Promise { + askPlayground(modelId: string, prompt: string): Promise { const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId); if (localModelInfo.length !== 1) { throw new Error('model not found'); } return this.playgroundManager.askPlayground(localModelInfo[0], prompt); } + + async getPlaygroundStates(): Promise { + return this.playgroundManager.getState(); + } } diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index 1cc5ed91d..a03c47dfb 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -39,7 +39,6 @@ export class Studio { constructor(readonly extensionContext: ExtensionContext) { this.#extensionContext = extensionContext; - this.playgroundManager = new PlayGroundManager(); } public async activate(): Promise { @@ -93,6 +92,7 @@ export class Studio { const taskRegistry = new TaskRegistry(); const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry); const applicationManager = new ApplicationManager(gitManager, recipeStatusRegistry, this.#extensionContext); + this.playgroundManager = new PlayGroundManager(this.#panel.webview); this.studioApi = new StudioApiImpl(applicationManager, recipeStatusRegistry, taskRegistry, this.playgroundManager); // Register the instance this.rpcExtension.registerInstance(StudioApiImpl, this.studioApi); diff --git a/packages/backend/src/utils/ports.ts b/packages/backend/src/utils/ports.ts new file mode 100644 index 000000000..91d12473a --- /dev/null +++ b/packages/backend/src/utils/ports.ts @@ -0,0 +1,66 @@ +/********************************************************************** + * Copyright (C) 2022 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ + +import * as net from 'net'; + +/** + * Find a free port starting from the given port + */ +export async function getFreePort(port = 0): Promise { + if (port < 1024) { + port = 9000; + } + let isFree = false; + while (!isFree) { + isFree = await isFreePort(port); + if (!isFree) { + port++; + } + } + + return port; +} + +/** + * Find a free port range + */ +export async function getFreePortRange(rangeSize: number): Promise { + let port = 9000; + let startPort = port; + + do { + if (await isFreePort(port)) { + ++port; + } else { + ++port; + startPort = port; + } + } while (port + 1 - startPort <= rangeSize); + + return `${startPort}-${port - 1}`; +} + +export function isFreePort(port: number): Promise { + const server = net.createServer(); + return new Promise((resolve, reject) => + server + .on('error', (error: NodeJS.ErrnoException) => (error.code === 'EADDRINUSE' ? resolve(false) : reject(error))) + .on('listening', () => server.close(() => resolve(true))) + .listen(port, '127.0.0.1'), + ); +} diff --git a/packages/frontend/src/pages/ModelPlayground.spec.ts b/packages/frontend/src/pages/ModelPlayground.spec.ts new file mode 100644 index 000000000..dd450ca02 --- /dev/null +++ b/packages/frontend/src/pages/ModelPlayground.spec.ts @@ -0,0 +1,146 @@ +import '@testing-library/jest-dom/vitest'; +import { vi, test, expect, beforeEach } from 'vitest'; +import { screen, fireEvent, render } from '@testing-library/svelte'; +import ModelPlayground from './ModelPlayground.svelte'; +import type { ModelInfo } from '@shared/src/models/IModelInfo'; +import userEvent from '@testing-library/user-event'; + +const mocks = vi.hoisted(() => { + return { + startPlaygroundMock: vi.fn(), + askPlaygroundMock: vi.fn(), + playgroundQueriesSubscribeMock: vi.fn(), + playgroundQueriesMock: { + subscribe: (f: (msg: any) => void) => { + f(mocks.playgroundQueriesSubscribeMock()); + return () => {}; + }, + }, + }; +}); + +vi.mock('../utils/client', async () => { + return { + studioClient: { + startPlayground: mocks.startPlaygroundMock, + askPlayground: mocks.askPlaygroundMock, + askPlaygroundQueries: () => {}, + }, + rpcBrowser: { + subscribe: () => { + return { + unsubscribe: () => {}, + }; + }, + }, + }; +}); + +vi.mock('../stores/playground-queries', async () => { + return { + playgroundQueries: mocks.playgroundQueriesMock, + }; +}); + +beforeEach(() => { + vi.clearAllMocks(); +}); + +test('should start playground at init time and call askPlayground when button clicked', async () => { + mocks.playgroundQueriesSubscribeMock.mockReturnValue([]); + render(ModelPlayground, { + model: { + id: 'model1', + name: 'Model 1', + description: 'A description', + hw: 'CPU', + registry: 'Hugging Face', + popularity: 3, + license: '?', + url: 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', + } as ModelInfo, + }); + await new Promise(resolve => setTimeout(resolve, 200)); + + expect(mocks.startPlaygroundMock).toHaveBeenCalledOnce(); + + const prompt = screen.getByPlaceholderText('Type your prompt here'); + expect(prompt).toBeInTheDocument(); + const user = userEvent.setup(); + user.type(prompt, 'what is it?'); + + const send = screen.getByRole('button', { name: 'Send Request' }); + expect(send).toBeInTheDocument(); + + expect(mocks.askPlaygroundMock).not.toHaveBeenCalled(); + await fireEvent.click(send); + expect(mocks.askPlaygroundMock).toHaveBeenCalledOnce(); +}); + +test('should display query without response', async () => { + mocks.playgroundQueriesSubscribeMock.mockReturnValue([ + { + id: 1, + modelId: 'model1', + prompt: 'what is 1+1?', + }, + ]); + render(ModelPlayground, { + model: { + id: 'model1', + name: 'Model 1', + description: 'A description', + hw: 'CPU', + registry: 'Hugging Face', + popularity: 3, + license: '?', + url: 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', + } as ModelInfo, + }); + await new Promise(resolve => setTimeout(resolve, 200)); + + const prompt = screen.getByPlaceholderText('Type your prompt here'); + expect(prompt).toBeInTheDocument(); + expect(prompt).toHaveValue('what is 1+1?'); + + const response = screen.queryByRole('textbox', { name: 'response' }); + expect(response).not.toBeInTheDocument(); +}); + +test('should display query without response', async () => { + mocks.playgroundQueriesSubscribeMock.mockReturnValue([ + { + id: 1, + modelId: 'model1', + prompt: 'what is 1+1?', + response: { + choices: [ + { + text: 'The response is 2', + }, + ], + }, + }, + ]); + render(ModelPlayground, { + model: { + id: 'model1', + name: 'Model 1', + description: 'A description', + hw: 'CPU', + registry: 'Hugging Face', + popularity: 3, + license: '?', + url: 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', + } as ModelInfo, + }); + await new Promise(resolve => setTimeout(resolve, 200)); + + const prompt = screen.getByPlaceholderText('Type your prompt here'); + expect(prompt).toBeInTheDocument(); + expect(prompt).toHaveValue('what is 1+1?'); + + const response = screen.queryByRole('textbox', { name: 'response' }); + expect(response).toBeInTheDocument(); + expect(response).toHaveValue('The response is 2'); +}); diff --git a/packages/frontend/src/pages/ModelPlayground.svelte b/packages/frontend/src/pages/ModelPlayground.svelte index e873e35e9..ab4811b62 100644 --- a/packages/frontend/src/pages/ModelPlayground.svelte +++ b/packages/frontend/src/pages/ModelPlayground.svelte @@ -4,9 +4,12 @@ import Button from '../lib/button/Button.svelte'; import { onMount } from 'svelte'; import { studioClient } from '../utils/client'; + import { playgroundQueries } from '../stores/playground-queries'; + import type { QueryState } from '@shared/models/IPlaygroundQueryState'; export let model: ModelInfo | undefined; let prompt = ''; + let queryId: number; let result: ModelResponseChoice | undefined = undefined; let inProgress = false; @@ -15,25 +18,57 @@ return; } studioClient.startPlayground(model.id); + + const unsubscribe = playgroundQueries.subscribe((queries: QueryState[]) => { + if (queryId === -1) { + return; + } + let myQuery = queries.find(q => q.id === queryId); + if (!myQuery) { + myQuery = queries.findLast(q => q.modelId === model?.id); + } + if (!myQuery) { + return; + } + displayQuery(myQuery); + }); + + return () => { + unsubscribe(); + }; }); + function displayQuery(query: QueryState) { + if (query.response) { + inProgress = false; + prompt = query.prompt; + if (query.response?.choices.length) { + result = query.response?.choices[0]; + } + } else { + inProgress = true; + prompt = query.prompt; + queryId = query.id; + } + } + async function askPlayground() { if (!model) { return; } inProgress = true; result = undefined; - const res = await studioClient.askPlayground(model.id, prompt) - inProgress = false; - if (res.choices.length) { - result = res.choices[0]; - } + // do not display anything before we get a response from askPlayground + // (we can receive a new queryState before the new QueryId) + queryId = -1; + queryId = await studioClient.askPlayground(model.id, prompt); }
Prompt