From abe7e18b146fa6ab83867a592ed4122da41b7d15 Mon Sep 17 00:00:00 2001 From: Philippe Martin Date: Tue, 16 Jan 2024 15:06:06 +0100 Subject: [PATCH] fix: manage playground queries state --- packages/backend/src/playground.ts | 129 ++++++++++++------ packages/backend/src/studio-api-impl.ts | 2 +- packages/backend/src/studio.ts | 6 +- packages/backend/src/utils/ports.ts | 66 +++++++++ .../frontend/src/pages/ModelPlayground.svelte | 11 +- .../frontend/src/stores/playground-queries.ts | 6 + packages/shared/Messages.ts | 1 + packages/shared/StudioAPI.ts | 2 +- .../shared/models/IPlaygroundQueryState.ts | 8 ++ 9 files changed, 184 insertions(+), 47 deletions(-) create mode 100644 packages/backend/src/utils/ports.ts create mode 100644 packages/frontend/src/stores/playground-queries.ts create mode 100644 packages/shared/Messages.ts create mode 100644 packages/shared/models/IPlaygroundQueryState.ts diff --git a/packages/backend/src/playground.ts b/packages/backend/src/playground.ts index 612e4f31f..847de2efe 100644 --- a/packages/backend/src/playground.ts +++ b/packages/backend/src/playground.ts @@ -1,9 +1,12 @@ -import { provider, containerEngine, type ProviderContainerConnection, type ImageInfo } from '@podman-desktop/api'; +import { provider, containerEngine, type Webview, type ProviderContainerConnection, type ImageInfo } from '@podman-desktop/api'; import { LocalModelInfo } from '@shared/models/ILocalModelInfo'; import { ModelResponse } from '@shared/models/IModelResponse'; import path from 'node:path'; import * as http from 'node:http'; +import { getFreePort } from './utils/ports'; +import { QueryState } from '@shared/models/IPlaygroundQueryState'; +import { MSG_NEW_PLAYGROUND_QUERIES_STATE } from '@shared/Messages'; const LOCALAI_IMAGE = 'quay.io/go-skynet/local-ai:v2.5.1'; @@ -15,13 +18,33 @@ function findFirstProvider(): ProviderContainerConnection | undefined { return engines.length > 0 ? engines[0] : undefined; } +export interface PlaygroundState { + containerId: string; + port: number; +} + export class PlayGroundManager { + private queryIdCounter = 0; + + private playgrounds: Map; + private queries: Map; + + constructor(private webview: Webview) { + this.playgrounds = new Map(); + this.queries = new Map(); + } + async selectImage(connection: ProviderContainerConnection, image: string): Promise { const images = (await containerEngine.listImages()).filter(im => im.RepoTags && im.RepoTags.some(tag => tag === image)); return images.length > 0 ? images[0] : undefined; } async startPlayground(modelId: string, modelPath: string): Promise { + // TODO(feloy) remove previous query from state? + + if (this.playgrounds.has(modelId)) { + throw new Error('model is already running'); + } const connection = findFirstProvider(); if (!connection) { throw new Error('Unable to find an engine to start playground'); @@ -29,16 +52,17 @@ export class PlayGroundManager { let image = await this.selectImage(connection, LOCALAI_IMAGE); if (!image) { - await containerEngine.pullImage(connection.connection, LOCALAI_IMAGE, () => {}); + await containerEngine.pullImage(connection.connection, LOCALAI_IMAGE, () => { }); image = await this.selectImage(connection, LOCALAI_IMAGE); if (!image) { throw new Error(`Unable to find ${LOCALAI_IMAGE} image`); } } + const freePort = await getFreePort(); const result = await containerEngine.createContainer(image.engineId, { Image: image.Id, Detach: true, - ExposedPorts: { '9000': {} }, + ExposedPorts: { ['' + freePort]: {} }, HostConfig: { AutoRemove: true, Mounts: [ @@ -51,13 +75,17 @@ export class PlayGroundManager { PortBindings: { '8080/tcp': [ { - HostPort: '9000' + HostPort: '' + freePort } ] } }, Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'], }); + this.playgrounds.set(modelId, { + containerId: result.id, + port: freePort, + }); return result.id; } @@ -69,42 +97,67 @@ export class PlayGroundManager { return containerEngine.stopContainer(connection.providerId, playgroundId); } - async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise { - return new Promise(resolve => { - let post_data = JSON.stringify({ - "model": modelInfo.file, - "prompt": prompt, - "temperature": 0.7 - }); + async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise { + const state = this.playgrounds.get(modelInfo.id); + if (!state) { + throw new Error('model is not running'); + } - let post_options: http.RequestOptions = { - host: 'localhost', - port: '9000', - path: '/v1/completions', - method: 'POST', - headers: { - 'Content-Type': 'application/json' + const query = { + id: this.getNextQueryId(), + modelId: modelInfo.id, + prompt: prompt, + } as QueryState; + + let post_data = JSON.stringify({ + "model": modelInfo.file, + "prompt": prompt, + "temperature": 0.7 + }); + + let post_options: http.RequestOptions = { + host: 'localhost', + port: '' + state.port, + path: '/v1/completions', + method: 'POST', + headers: { + 'Content-Type': 'application/json' + } + }; + + let post_req = http.request(post_options, res => { + res.setEncoding('utf8'); + const chunks = []; + res.on('data', (data) => chunks.push(data)); + res.on('end', () => { + let resBody = chunks.join(); + switch (res.headers['content-type']) { + case 'application/json': + const result = JSON.parse(resBody); + console.log('result', result); + const q = this.queries.get(query.id); + if (!q) { + throw new Error('query not found in state'); + } + q.response = result as ModelResponse; + this.queries.set(query.id, q); + this.webview.postMessage({ + id: MSG_NEW_PLAYGROUND_QUERIES_STATE, + body: Array.from(this.queries.values()), + }); + break; } - }; - - let post_req = http.request(post_options, function (res) { - res.setEncoding('utf8'); - const chunks = []; - res.on('data', (data) => chunks.push(data)); - res.on('end', () => { - let resBody = chunks.join(); - switch (res.headers['content-type']) { - case 'application/json': - const result = JSON.parse(resBody); - console.log('result', result); - resolve(result as ModelResponse); - break; - } - }); }); - // post the data - post_req.write(post_data); - post_req.end(); - }); + }); + // post the data + post_req.write(post_data); + post_req.end(); + + this.queries.set(query.id, query); + return query.id; + } + + getNextQueryId() { + return ++this.queryIdCounter; } } diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 1ae2fec62..634d8e56f 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -101,7 +101,7 @@ export class StudioApiImpl implements StudioAPI { await this.playgroundManager.startPlayground(modelId, modelPath); } - askPlayground(modelId: string, prompt: string): Promise { + askPlayground(modelId: string, prompt: string): Promise { const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId); if (localModelInfo.length !== 1) { throw new Error('model not found'); diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index eb13c103b..7b8d378c5 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -42,7 +42,6 @@ export class Studio { constructor(readonly extensionContext: ExtensionContext) { this.#extensionContext = extensionContext; - this.playgroundManager = new PlayGroundManager(); } public async activate(): Promise { @@ -99,7 +98,10 @@ export class Studio { gitManager, recipeStatusRegistry, this.#extensionContext, - ) + ); + this.#panel.webview + this.playgroundManager = new PlayGroundManager(this.#panel.webview); + this.studioApi = new StudioApiImpl( applicationManager, recipeStatusRegistry, diff --git a/packages/backend/src/utils/ports.ts b/packages/backend/src/utils/ports.ts new file mode 100644 index 000000000..91d12473a --- /dev/null +++ b/packages/backend/src/utils/ports.ts @@ -0,0 +1,66 @@ +/********************************************************************** + * Copyright (C) 2022 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ + +import * as net from 'net'; + +/** + * Find a free port starting from the given port + */ +export async function getFreePort(port = 0): Promise { + if (port < 1024) { + port = 9000; + } + let isFree = false; + while (!isFree) { + isFree = await isFreePort(port); + if (!isFree) { + port++; + } + } + + return port; +} + +/** + * Find a free port range + */ +export async function getFreePortRange(rangeSize: number): Promise { + let port = 9000; + let startPort = port; + + do { + if (await isFreePort(port)) { + ++port; + } else { + ++port; + startPort = port; + } + } while (port + 1 - startPort <= rangeSize); + + return `${startPort}-${port - 1}`; +} + +export function isFreePort(port: number): Promise { + const server = net.createServer(); + return new Promise((resolve, reject) => + server + .on('error', (error: NodeJS.ErrnoException) => (error.code === 'EADDRINUSE' ? resolve(false) : reject(error))) + .on('listening', () => server.close(() => resolve(true))) + .listen(port, '127.0.0.1'), + ); +} diff --git a/packages/frontend/src/pages/ModelPlayground.svelte b/packages/frontend/src/pages/ModelPlayground.svelte index e873e35e9..0f64b89f0 100644 --- a/packages/frontend/src/pages/ModelPlayground.svelte +++ b/packages/frontend/src/pages/ModelPlayground.svelte @@ -23,11 +23,12 @@ } inProgress = true; result = undefined; - const res = await studioClient.askPlayground(model.id, prompt) - inProgress = false; - if (res.choices.length) { - result = res.choices[0]; - } + const queryId = await studioClient.askPlayground(model.id, prompt); + console.log('==> queryId', queryId); +// inProgress = false; +// if (res.choices.length) { +// result = res.choices[0]; +// } } diff --git a/packages/frontend/src/stores/playground-queries.ts b/packages/frontend/src/stores/playground-queries.ts new file mode 100644 index 000000000..3bae8cf55 --- /dev/null +++ b/packages/frontend/src/stores/playground-queries.ts @@ -0,0 +1,6 @@ +import type { Readable } from 'svelte/store'; +import { readable } from 'svelte/store'; +import type { QueryState } from '@shared/models/IPlaygroundQueryState'; + +export const playgroundQueries: Readable = readable([], (set) => { +}); diff --git a/packages/shared/Messages.ts b/packages/shared/Messages.ts new file mode 100644 index 000000000..433b1c142 --- /dev/null +++ b/packages/shared/Messages.ts @@ -0,0 +1 @@ +export const MSG_NEW_PLAYGROUND_QUERIES_STATE = 'new-playground-queries-state'; diff --git a/packages/shared/StudioAPI.ts b/packages/shared/StudioAPI.ts index b76fa4f1c..ec23896ae 100644 --- a/packages/shared/StudioAPI.ts +++ b/packages/shared/StudioAPI.ts @@ -22,7 +22,7 @@ export abstract class StudioAPI { abstract getLocalModels(): Promise; abstract startPlayground(modelId: string): Promise; - abstract askPlayground(modelId: string, prompt: string): Promise; + abstract askPlayground(modelId: string, prompt: string): Promise; /** * Get task by label diff --git a/packages/shared/models/IPlaygroundQueryState.ts b/packages/shared/models/IPlaygroundQueryState.ts new file mode 100644 index 000000000..4f28f708a --- /dev/null +++ b/packages/shared/models/IPlaygroundQueryState.ts @@ -0,0 +1,8 @@ +import { ModelResponse } from "./IModelResponse"; + +export interface QueryState { + id: number; + modelId: string; + prompt: string; + response?: ModelResponse; +}