Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adding support for starting / stopping playground #93

Merged
merged 9 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 74 additions & 23 deletions packages/backend/src/managers/playground.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ import path from 'node:path';
import * as http from 'node:http';
import { getFreePort } from '../utils/ports';
import type { QueryState } from '@shared/src/models/IPlaygroundQueryState';
import { MSG_NEW_PLAYGROUND_QUERIES_STATE } from '@shared/Messages';
import { MSG_NEW_PLAYGROUND_QUERIES_STATE, MSG_PLAYGROUNDS_STATE_UPDATE } from '@shared/Messages';
import type { PlaygroundState, PlaygroundStatus } from '@shared/src/models/IPlaygroundState';

// TODO: this should not be hardcoded
const LOCALAI_IMAGE = 'quay.io/go-skynet/local-ai:v2.5.1';
Expand All @@ -43,14 +44,10 @@ function findFirstProvider(): ProviderContainerConnection | undefined {
return engines.length > 0 ? engines[0] : undefined;
}

export interface PlaygroundState {
containerId: string;
port: number;
}

export class PlayGroundManager {
private queryIdCounter = 0;

// Dict modelId => state
private playgrounds: Map<string, PlaygroundState>;
private queries: Map<number, QueryState>;

Expand All @@ -64,14 +61,44 @@ export class PlayGroundManager {
return images.length > 0 ? images[0] : undefined;
}

setPlaygroundStatus(modelId: string, status: PlaygroundStatus) {
return this.updatePlaygroundState(modelId, {
modelId: modelId,
...(this.playgrounds.get(modelId) || {}),
status: status,
});
}

updatePlaygroundState(modelId: string, state: PlaygroundState) {
this.playgrounds.set(modelId, state);
return this.webview.postMessage({
id: MSG_PLAYGROUNDS_STATE_UPDATE,
body: this.getPlaygroundsState(),
});
}

async startPlayground(modelId: string, modelPath: string): Promise<string> {
// TODO(feloy) remove previous query from state?

if (this.playgrounds.has(modelId)) {
throw new Error('model is already running');
// TODO: check manually if the contains has a matching state
switch (this.playgrounds.get(modelId).status) {
case 'running':
throw new Error('playground is already running');
case 'starting':
case 'stopping':
throw new Error('playground is transitioning');
case 'error':
case 'none':
case 'stopped':
break;
}
}

await this.setPlaygroundStatus(modelId, 'starting');

const connection = findFirstProvider();
if (!connection) {
await this.setPlaygroundStatus(modelId, 'error');
throw new Error('Unable to find an engine to start playground');
}

Expand All @@ -80,9 +107,11 @@ export class PlayGroundManager {
await containerEngine.pullImage(connection.connection, LOCALAI_IMAGE, () => {});
image = await this.selectImage(connection, LOCALAI_IMAGE);
if (!image) {
await this.setPlaygroundStatus(modelId, 'error');
throw new Error(`Unable to find ${LOCALAI_IMAGE} image`);
}
}

const freePort = await getFreePort();
const result = await containerEngine.createContainer(image.engineId, {
Image: image.Id,
Expand All @@ -107,24 +136,41 @@ export class PlayGroundManager {
},
Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'],
});
this.playgrounds.set(modelId, {
containerId: result.id,
port: freePort,

await this.updatePlaygroundState(modelId, {
container: {
containerId: result.id,
port: freePort,
engineId: image.engineId,
},
status: 'running',
modelId,
});

return result.id;
}

async stopPlayground(playgroundId: string): Promise<void> {
const connection = findFirstProvider();
if (!connection) {
throw new Error('Unable to find an engine to start playground');
async stopPlayground(modelId: string): Promise<void> {
const state = this.playgrounds.get(modelId);
if (state?.container === undefined) {
throw new Error('model is not running');
}
return containerEngine.stopContainer(connection.providerId, playgroundId);
await this.setPlaygroundStatus(modelId, 'stopping');
// We do not await since it can take a lot of time
containerEngine
.stopContainer(state.container.engineId, state.container.containerId)
.then(async () => {
await this.setPlaygroundStatus(modelId, 'stopped');
})
.catch(async (error: unknown) => {
console.error(error);
await this.setPlaygroundStatus(modelId, 'error');
});
}

async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise<number> {
const state = this.playgrounds.get(modelInfo.id);
if (!state) {
if (state?.container === undefined) {
throw new Error('model is not running');
}

Expand All @@ -142,7 +188,7 @@ export class PlayGroundManager {

const post_options: http.RequestOptions = {
host: 'localhost',
port: '' + state.port,
port: '' + state.container.port,
path: '/v1/completions',
method: 'POST',
headers: {
Expand All @@ -164,7 +210,7 @@ export class PlayGroundManager {
}
q.response = result as ModelResponse;
this.queries.set(query.id, q);
this.sendState().catch((err: unknown) => {
this.sendQueriesState().catch((err: unknown) => {
console.error('playground: unable to send the response to the frontend', err);
});
}
Expand All @@ -175,20 +221,25 @@ export class PlayGroundManager {
post_req.end();

this.queries.set(query.id, query);
await this.sendState();
await this.sendQueriesState();
return query.id;
}

getNextQueryId() {
return ++this.queryIdCounter;
}
getState(): QueryState[] {
getQueriesState(): QueryState[] {
return Array.from(this.queries.values());
}
async sendState() {

getPlaygroundsState(): PlaygroundState[] {
return Array.from(this.playgrounds.values());
}

async sendQueriesState() {
await this.webview.postMessage({
id: MSG_NEW_PLAYGROUND_QUERIES_STATE,
body: this.getState(),
body: this.getQueriesState(),
});
}
}
54 changes: 15 additions & 39 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
***********************************************************************/

import type { StudioAPI } from '@shared/src/StudioAPI';
import type { Category } from '@shared/src/models/ICategory';
import type { Recipe } from '@shared/src/models/IRecipe';
import type { ApplicationManager } from './managers/applicationManager';
import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry';
import type { RecipeStatus } from '@shared/src/models/IRecipeStatus';
Expand All @@ -32,8 +30,7 @@ import type { QueryState } from '@shared/src/models/IPlaygroundQueryState';
import * as path from 'node:path';
import type { CatalogManager } from './managers/catalogManager';
import type { Catalog } from '@shared/src/models/ICatalog';

export const RECENT_CATEGORY_ID = 'recent-category';
import type { PlaygroundState } from '@shared/src/models/IPlaygroundState';

export class StudioApiImpl implements StudioAPI {
constructor(
Expand All @@ -56,28 +53,6 @@ export class StudioApiImpl implements StudioAPI {
return this.recipeStatusRegistry.getStatus(recipeId);
}

async getRecentRecipes(): Promise<Recipe[]> {
return []; // no recent implementation for now
}

async getCategories(): Promise<Category[]> {
return this.catalogManager.getCategories();
}

async getRecipesByCategory(categoryId: string): Promise<Recipe[]> {
if (categoryId === RECENT_CATEGORY_ID) return this.getRecentRecipes();

// TODO: move logic to catalog manager
return this.catalogManager.getRecipes().filter(recipe => recipe.categories.includes(categoryId));
}

async getRecipeById(recipeId: string): Promise<Recipe> {
// TODO: move logic to catalog manager
const recipe = this.catalogManager.getRecipes().find(recipe => recipe.id === recipeId);
if (recipe) return recipe;
throw new Error('Not found');
}

async getModelById(modelId: string): Promise<ModelInfo> {
// TODO: move logic to catalog manager
const model = this.catalogManager.getModels().find(m => modelId === m.id);
Expand All @@ -87,18 +62,9 @@ export class StudioApiImpl implements StudioAPI {
return model;
}

async getModelsByIds(ids: string[]): Promise<ModelInfo[]> {
// TODO: move logic to catalog manager
return this.catalogManager.getModels().filter(m => ids.includes(m.id)) ?? [];
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
async searchRecipes(_query: string): Promise<Recipe[]> {
return []; // todo: not implemented
}

async pullApplication(recipeId: string): Promise<void> {
const recipe: Recipe = await this.getRecipeById(recipeId);
const recipe = this.catalogManager.getRecipes().find(recipe => recipe.id === recipeId);
if (!recipe) throw new Error('Not found');

// the user should have selected one model, we use the first one for the moment
const modelId = recipe.models[0];
Expand All @@ -122,16 +88,22 @@ export class StudioApiImpl implements StudioAPI {
}

async startPlayground(modelId: string): Promise<void> {
// TODO: improve the following
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}

// TODO: we need to stop doing that.
const modelPath = path.resolve(this.applicationManager.appUserDirectory, 'models', modelId, localModelInfo[0].file);

await this.playgroundManager.startPlayground(modelId, modelPath);
}

async stopPlayground(modelId: string): Promise<void> {
await this.playgroundManager.stopPlayground(modelId);
}

askPlayground(modelId: string, prompt: string): Promise<number> {
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
Expand All @@ -140,8 +112,12 @@ export class StudioApiImpl implements StudioAPI {
return this.playgroundManager.askPlayground(localModelInfo[0], prompt);
}

async getPlaygroundStates(): Promise<QueryState[]> {
return this.playgroundManager.getState();
async getPlaygroundQueriesState(): Promise<QueryState[]> {
return this.playgroundManager.getQueriesState();
}

async getPlaygroundsState(): Promise<PlaygroundState[]> {
return this.playgroundManager.getPlaygroundsState();
}

async getCatalog(): Promise<Catalog> {
Expand Down
6 changes: 4 additions & 2 deletions packages/frontend/src/lib/Card.svelte
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
<script lang="ts">
import Fa from 'svelte-fa';
import type { IconDefinition } from '@fortawesome/free-regular-svg-icons';
import { createEventDispatcher } from 'svelte'
const dispatch = createEventDispatcher()

export let title: string | undefined = undefined;
export let classes: string = "";
Expand All @@ -17,9 +19,9 @@ export let primaryBackground: string = "bg-charcoal-800"
<div class="flex flex-row">
<div class="flex flex-row items-center">
{#if icon}
<div class="{primaryBackground} rounded-full w-8 h-8 flex items-center justify-center mr-3">
<button on:click={() => dispatch('click')} class="{primaryBackground} rounded-full w-8 h-8 flex items-center justify-center mr-3">
<Fa size="20" class="text-purple-500 cursor-pointer" icon="{icon}" />
</div>
</button>
{/if}
{#if title}
<div class="flex flex-col text-gray-400 whitespace-nowrap" aria-label="context-name">
Expand Down
Loading
Loading