Skip to content

Commit

Permalink
fix: manage playground queries state
Browse files Browse the repository at this point in the history
  • Loading branch information
feloy committed Jan 16, 2024
1 parent 76286b4 commit abe7e18
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 47 deletions.
129 changes: 91 additions & 38 deletions packages/backend/src/playground.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { provider, containerEngine, type ProviderContainerConnection, type ImageInfo } from '@podman-desktop/api';
import { provider, containerEngine, type Webview, type ProviderContainerConnection, type ImageInfo } from '@podman-desktop/api';
import { LocalModelInfo } from '@shared/models/ILocalModelInfo';
import { ModelResponse } from '@shared/models/IModelResponse';

import path from 'node:path';
import * as http from 'node:http';
import { getFreePort } from './utils/ports';
import { QueryState } from '@shared/models/IPlaygroundQueryState';
import { MSG_NEW_PLAYGROUND_QUERIES_STATE } from '@shared/Messages';

const LOCALAI_IMAGE = 'quay.io/go-skynet/local-ai:v2.5.1';

Expand All @@ -15,30 +18,51 @@ function findFirstProvider(): ProviderContainerConnection | undefined {
return engines.length > 0 ? engines[0] : undefined;
}

export interface PlaygroundState {
containerId: string;
port: number;
}

export class PlayGroundManager {
private queryIdCounter = 0;

private playgrounds: Map<string, PlaygroundState>;
private queries: Map<number, QueryState>;

constructor(private webview: Webview) {
this.playgrounds = new Map<string, PlaygroundState>();
this.queries = new Map<number, QueryState>();
}

async selectImage(connection: ProviderContainerConnection, image: string): Promise<ImageInfo | undefined> {
const images = (await containerEngine.listImages()).filter(im => im.RepoTags && im.RepoTags.some(tag => tag === image));
return images.length > 0 ? images[0] : undefined;
}

async startPlayground(modelId: string, modelPath: string): Promise<string> {
// TODO(feloy) remove previous query from state?

if (this.playgrounds.has(modelId)) {
throw new Error('model is already running');
}
const connection = findFirstProvider();
if (!connection) {
throw new Error('Unable to find an engine to start playground');
}

let image = await this.selectImage(connection, LOCALAI_IMAGE);
if (!image) {
await containerEngine.pullImage(connection.connection, LOCALAI_IMAGE, () => {});
await containerEngine.pullImage(connection.connection, LOCALAI_IMAGE, () => { });
image = await this.selectImage(connection, LOCALAI_IMAGE);
if (!image) {
throw new Error(`Unable to find ${LOCALAI_IMAGE} image`);
}
}
const freePort = await getFreePort();
const result = await containerEngine.createContainer(image.engineId, {
Image: image.Id,
Detach: true,
ExposedPorts: { '9000': {} },
ExposedPorts: { ['' + freePort]: {} },
HostConfig: {
AutoRemove: true,
Mounts: [
Expand All @@ -51,13 +75,17 @@ export class PlayGroundManager {
PortBindings: {
'8080/tcp': [
{
HostPort: '9000'
HostPort: '' + freePort
}
]
}
},
Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'],
});
this.playgrounds.set(modelId, {
containerId: result.id,
port: freePort,
});
return result.id;
}

Expand All @@ -69,42 +97,67 @@ export class PlayGroundManager {
return containerEngine.stopContainer(connection.providerId, playgroundId);
}

async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise<ModelResponse> {
return new Promise(resolve => {
let post_data = JSON.stringify({
"model": modelInfo.file,
"prompt": prompt,
"temperature": 0.7
});
async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise<number> {
const state = this.playgrounds.get(modelInfo.id);
if (!state) {
throw new Error('model is not running');
}

let post_options: http.RequestOptions = {
host: 'localhost',
port: '9000',
path: '/v1/completions',
method: 'POST',
headers: {
'Content-Type': 'application/json'
const query = {
id: this.getNextQueryId(),
modelId: modelInfo.id,
prompt: prompt,
} as QueryState;

let post_data = JSON.stringify({
"model": modelInfo.file,
"prompt": prompt,
"temperature": 0.7
});

let post_options: http.RequestOptions = {
host: 'localhost',
port: '' + state.port,
path: '/v1/completions',
method: 'POST',
headers: {
'Content-Type': 'application/json'
}
};

let post_req = http.request(post_options, res => {
res.setEncoding('utf8');
const chunks = [];
res.on('data', (data) => chunks.push(data));
res.on('end', () => {
let resBody = chunks.join();
switch (res.headers['content-type']) {
case 'application/json':
const result = JSON.parse(resBody);
console.log('result', result);
const q = this.queries.get(query.id);
if (!q) {
throw new Error('query not found in state');
}
q.response = result as ModelResponse;
this.queries.set(query.id, q);
this.webview.postMessage({
id: MSG_NEW_PLAYGROUND_QUERIES_STATE,
body: Array.from(this.queries.values()),
});
break;
}
};

let post_req = http.request(post_options, function (res) {
res.setEncoding('utf8');
const chunks = [];
res.on('data', (data) => chunks.push(data));
res.on('end', () => {
let resBody = chunks.join();
switch (res.headers['content-type']) {
case 'application/json':
const result = JSON.parse(resBody);
console.log('result', result);
resolve(result as ModelResponse);
break;
}
});
});
// post the data
post_req.write(post_data);
post_req.end();
});
});
// post the data
post_req.write(post_data);
post_req.end();

this.queries.set(query.id, query);
return query.id;
}

getNextQueryId() {
return ++this.queryIdCounter;
}
}
2 changes: 1 addition & 1 deletion packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ export class StudioApiImpl implements StudioAPI {
await this.playgroundManager.startPlayground(modelId, modelPath);
}

askPlayground(modelId: string, prompt: string): Promise<ModelResponse> {
askPlayground(modelId: string, prompt: string): Promise<number> {
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
Expand Down
6 changes: 4 additions & 2 deletions packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ export class Studio {

constructor(readonly extensionContext: ExtensionContext) {
this.#extensionContext = extensionContext;
this.playgroundManager = new PlayGroundManager();
}

public async activate(): Promise<void> {
Expand Down Expand Up @@ -99,7 +98,10 @@ export class Studio {
gitManager,
recipeStatusRegistry,
this.#extensionContext,
)
);
this.#panel.webview
this.playgroundManager = new PlayGroundManager(this.#panel.webview);

this.studioApi = new StudioApiImpl(
applicationManager,
recipeStatusRegistry,
Expand Down
66 changes: 66 additions & 0 deletions packages/backend/src/utils/ports.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/**********************************************************************
* Copyright (C) 2022 Red Hat, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

import * as net from 'net';

/**
* Find a free port starting from the given port
*/
export async function getFreePort(port = 0): Promise<number> {
if (port < 1024) {
port = 9000;
}
let isFree = false;
while (!isFree) {
isFree = await isFreePort(port);
if (!isFree) {
port++;
}
}

return port;
}

/**
* Find a free port range
*/
export async function getFreePortRange(rangeSize: number): Promise<string> {
let port = 9000;
let startPort = port;

do {
if (await isFreePort(port)) {
++port;
} else {
++port;
startPort = port;
}
} while (port + 1 - startPort <= rangeSize);

return `${startPort}-${port - 1}`;
}

export function isFreePort(port: number): Promise<boolean> {
const server = net.createServer();
return new Promise((resolve, reject) =>
server
.on('error', (error: NodeJS.ErrnoException) => (error.code === 'EADDRINUSE' ? resolve(false) : reject(error)))
.on('listening', () => server.close(() => resolve(true)))
.listen(port, '127.0.0.1'),
);
}
11 changes: 6 additions & 5 deletions packages/frontend/src/pages/ModelPlayground.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
}
inProgress = true;
result = undefined;
const res = await studioClient.askPlayground(model.id, prompt)
inProgress = false;
if (res.choices.length) {
result = res.choices[0];
}
const queryId = await studioClient.askPlayground(model.id, prompt);
console.log('==> queryId', queryId);
// inProgress = false;
// if (res.choices.length) {
// result = res.choices[0];
// }
}
</script>

Expand Down
6 changes: 6 additions & 0 deletions packages/frontend/src/stores/playground-queries.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import type { Readable } from 'svelte/store';
import { readable } from 'svelte/store';
import type { QueryState } from '@shared/models/IPlaygroundQueryState';

export const playgroundQueries: Readable<QueryState[]> = readable<QueryState[]>([], (set) => {
});
1 change: 1 addition & 0 deletions packages/shared/Messages.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export const MSG_NEW_PLAYGROUND_QUERIES_STATE = 'new-playground-queries-state';
2 changes: 1 addition & 1 deletion packages/shared/StudioAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export abstract class StudioAPI {
abstract getLocalModels(): Promise<ModelInfo[]>;

abstract startPlayground(modelId: string): Promise<void>;
abstract askPlayground(modelId: string, prompt: string): Promise<ModelResponse>;
abstract askPlayground(modelId: string, prompt: string): Promise<number>;

/**
* Get task by label
Expand Down
8 changes: 8 additions & 0 deletions packages/shared/models/IPlaygroundQueryState.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import { ModelResponse } from "./IModelResponse";

export interface QueryState {
id: number;
modelId: string;
prompt: string;
response?: ModelResponse;
}

0 comments on commit abe7e18

Please sign in to comment.