Skip to content

Commit

Permalink
fix: manage playground queries state
Browse files Browse the repository at this point in the history
Signed-off-by: Philippe Martin <[email protected]>
  • Loading branch information
feloy committed Jan 17, 2024
1 parent 874f138 commit b897e9b
Show file tree
Hide file tree
Showing 12 changed files with 439 additions and 67 deletions.
142 changes: 101 additions & 41 deletions packages/backend/src/playground.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { provider, containerEngine, type ProviderContainerConnection, type ImageInfo } from '@podman-desktop/api';
import { provider, containerEngine, type Webview, type ProviderContainerConnection, type ImageInfo } from '@podman-desktop/api';
import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo';
import type { ModelResponse } from '@shared/src/models/IModelResponse';

import path from 'node:path';
import * as http from 'node:http';
import { getFreePort } from './utils/ports';
import { QueryState } from '@shared/models/IPlaygroundQueryState';
import { MSG_NEW_PLAYGROUND_QUERIES_STATE } from '@shared/Messages';

const LOCALAI_IMAGE = 'quay.io/go-skynet/local-ai:v2.5.1';

Expand All @@ -15,30 +18,51 @@ function findFirstProvider(): ProviderContainerConnection | undefined {
return engines.length > 0 ? engines[0] : undefined;
}

export interface PlaygroundState {
containerId: string;
port: number;
}

export class PlayGroundManager {
private queryIdCounter = 0;

private playgrounds: Map<string, PlaygroundState>;
private queries: Map<number, QueryState>;

constructor(private webview: Webview) {
this.playgrounds = new Map<string, PlaygroundState>();
this.queries = new Map<number, QueryState>();
}

async selectImage(connection: ProviderContainerConnection, image: string): Promise<ImageInfo | undefined> {
const images = (await containerEngine.listImages()).filter(im => im.RepoTags?.some(tag => tag === image));
return images.length > 0 ? images[0] : undefined;
}

async startPlayground(modelId: string, modelPath: string): Promise<string> {
// TODO(feloy) remove previous query from state?

if (this.playgrounds.has(modelId)) {
throw new Error('model is already running');
}
const connection = findFirstProvider();
if (!connection) {
throw new Error('Unable to find an engine to start playground');
}

let image = await this.selectImage(connection, LOCALAI_IMAGE);
if (!image) {
await containerEngine.pullImage(connection.connection, LOCALAI_IMAGE, () => {});
await containerEngine.pullImage(connection.connection, LOCALAI_IMAGE, () => { });
image = await this.selectImage(connection, LOCALAI_IMAGE);
if (!image) {
throw new Error(`Unable to find ${LOCALAI_IMAGE} image`);
}
}
const freePort = await getFreePort();
const result = await containerEngine.createContainer(image.engineId, {
Image: image.Id,
Detach: true,
ExposedPorts: { '9000': {} },
ExposedPorts: { ['' + freePort]: {} },
HostConfig: {
AutoRemove: true,
Mounts: [
Expand All @@ -51,13 +75,17 @@ export class PlayGroundManager {
PortBindings: {
'8080/tcp': [
{
HostPort: '9000',
},
],
},
HostPort: '' + freePort
}
]
}
},
Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'],
});
this.playgrounds.set(modelId, {
containerId: result.id,
port: freePort,
});
return result.id;
}

Expand All @@ -69,42 +97,74 @@ export class PlayGroundManager {
return containerEngine.stopContainer(connection.providerId, playgroundId);
}

async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise<ModelResponse> {
return new Promise(resolve => {
const post_data = JSON.stringify({
model: modelInfo.file,
prompt: prompt,
temperature: 0.7,
});
async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise<number> {
const state = this.playgrounds.get(modelInfo.id);
if (!state) {
throw new Error('model is not running');
}

const post_options: http.RequestOptions = {
host: 'localhost',
port: '9000',
path: '/v1/completions',
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
};

const post_req = http.request(post_options, function (res) {
res.setEncoding('utf8');
const chunks = [];
res.on('data', data => chunks.push(data));
res.on('end', () => {
const resBody = chunks.join();
const result = JSON.parse(resBody);
console.log('result', result);
switch (res.headers['content-type']) {
case 'application/json':
resolve(result as ModelResponse);
break;
}
});
const query = {
id: this.getNextQueryId(),
modelId: modelInfo.id,
prompt: prompt,
} as QueryState;

let post_data = JSON.stringify({
"model": modelInfo.file,
"prompt": prompt,
"temperature": 0.7
});

let post_options: http.RequestOptions = {
host: 'localhost',
port: '' + state.port,
path: '/v1/completions',
method: 'POST',
headers: {
'Content-Type': 'application/json'
}
};

let post_req = http.request(post_options, res => {
res.setEncoding('utf8');
const chunks = [];
res.on('data', (data) => chunks.push(data));
res.on('end', () => {
let resBody = chunks.join();
switch (res.headers['content-type']) {
case 'application/json':
const result = JSON.parse(resBody);
console.log('result', result);
const q = this.queries.get(query.id);
if (!q) {
throw new Error('query not found in state');
}
q.response = result as ModelResponse;
this.queries.set(query.id, q);
this.sendState();
break;
}
});
// post the data
post_req.write(post_data);
post_req.end();
});
// post the data
post_req.write(post_data);
post_req.end();

this.queries.set(query.id, query);
this.sendState();
return query.id;
}

getNextQueryId() {
return ++this.queryIdCounter;
}
getState(): QueryState[] {
return Array.from(this.queries.values());
}
sendState() {
this.webview.postMessage({
id: MSG_NEW_PLAYGROUND_QUERIES_STATE,
body: this.getState(),
});
}
}
8 changes: 6 additions & 2 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { TaskRegistry } from './registries/TaskRegistry';
import type { Task } from '@shared/src/models/ITask';
import * as path from 'node:path';
import type { ModelResponse } from '@shared/src/models/IModelResponse';
import type { PlayGroundManager } from './playground';
import * as podmanDesktopApi from '@podman-desktop/api';
import { QueryState } from '@shared/models/IPlaygroundQueryState';

export const RECENT_CATEGORY_ID = 'recent-category';

Expand Down Expand Up @@ -115,11 +115,15 @@ export class StudioApiImpl implements StudioAPI {
await this.playgroundManager.startPlayground(modelId, modelPath);
}

askPlayground(modelId: string, prompt: string): Promise<ModelResponse> {
askPlayground(modelId: string, prompt: string): Promise<number> {
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}
return this.playgroundManager.askPlayground(localModelInfo[0], prompt);
}

async getPlaygroundStates(): Promise<QueryState[]> {
return this.playgroundManager.getState();
}
}
2 changes: 1 addition & 1 deletion packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ export class Studio {

constructor(readonly extensionContext: ExtensionContext) {
this.#extensionContext = extensionContext;
this.playgroundManager = new PlayGroundManager();
}

public async activate(): Promise<void> {
Expand Down Expand Up @@ -93,6 +92,7 @@ export class Studio {
const taskRegistry = new TaskRegistry();
const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry);
const applicationManager = new ApplicationManager(gitManager, recipeStatusRegistry, this.#extensionContext);
this.playgroundManager = new PlayGroundManager(this.#panel.webview);
this.studioApi = new StudioApiImpl(applicationManager, recipeStatusRegistry, taskRegistry, this.playgroundManager);
// Register the instance
this.rpcExtension.registerInstance<StudioApiImpl>(StudioApiImpl, this.studioApi);
Expand Down
66 changes: 66 additions & 0 deletions packages/backend/src/utils/ports.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/**********************************************************************
* Copyright (C) 2022 Red Hat, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

import * as net from 'net';

/**
* Find a free port starting from the given port
*/
export async function getFreePort(port = 0): Promise<number> {
if (port < 1024) {
port = 9000;
}
let isFree = false;
while (!isFree) {
isFree = await isFreePort(port);
if (!isFree) {
port++;
}
}

return port;
}

/**
* Find a free port range
*/
export async function getFreePortRange(rangeSize: number): Promise<string> {
let port = 9000;
let startPort = port;

do {
if (await isFreePort(port)) {
++port;
} else {
++port;
startPort = port;
}
} while (port + 1 - startPort <= rangeSize);

return `${startPort}-${port - 1}`;
}

export function isFreePort(port: number): Promise<boolean> {
const server = net.createServer();
return new Promise((resolve, reject) =>
server
.on('error', (error: NodeJS.ErrnoException) => (error.code === 'EADDRINUSE' ? resolve(false) : reject(error)))
.on('listening', () => server.close(() => resolve(true)))
.listen(port, '127.0.0.1'),
);
}
Loading

0 comments on commit b897e9b

Please sign in to comment.