diff --git a/packages/backend/src/managers/playground.ts b/packages/backend/src/managers/playground.ts index 7bf9e8c38..22645f962 100644 --- a/packages/backend/src/managers/playground.ts +++ b/packages/backend/src/managers/playground.ts @@ -30,7 +30,8 @@ import path from 'node:path'; import * as http from 'node:http'; import { getFreePort } from '../utils/ports'; import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; -import { MSG_NEW_PLAYGROUND_QUERIES_STATE } from '@shared/Messages'; +import { MSG_NEW_PLAYGROUND_QUERIES_STATE, MSG_PLAYGROUNDS_STATE_UPDATE } from '@shared/Messages'; +import type { PlaygroundState, PlaygroundStatus } from '@shared/src/models/IPlaygroundState'; // TODO: this should not be hardcoded const LOCALAI_IMAGE = 'quay.io/go-skynet/local-ai:v2.5.1'; @@ -43,14 +44,10 @@ function findFirstProvider(): ProviderContainerConnection | undefined { return engines.length > 0 ? engines[0] : undefined; } -export interface PlaygroundState { - containerId: string; - port: number; -} - export class PlayGroundManager { private queryIdCounter = 0; + // Dict modelId => state private playgrounds: Map; private queries: Map; @@ -64,14 +61,44 @@ export class PlayGroundManager { return images.length > 0 ? images[0] : undefined; } + setPlaygroundStatus(modelId: string, status: PlaygroundStatus) { + return this.updatePlaygroundState(modelId, { + modelId: modelId, + ...(this.playgrounds.get(modelId) || {}), + status: status, + }); + } + + updatePlaygroundState(modelId: string, state: PlaygroundState) { + this.playgrounds.set(modelId, state); + return this.webview.postMessage({ + id: MSG_PLAYGROUNDS_STATE_UPDATE, + body: this.getPlaygroundsState(), + }); + } + async startPlayground(modelId: string, modelPath: string): Promise { // TODO(feloy) remove previous query from state? - if (this.playgrounds.has(modelId)) { - throw new Error('model is already running'); + // TODO: check manually if the contains has a matching state + switch (this.playgrounds.get(modelId).status) { + case 'running': + throw new Error('playground is already running'); + case 'starting': + case 'stopping': + throw new Error('playground is transitioning'); + case 'error': + case 'none': + case 'stopped': + break; + } } + + await this.setPlaygroundStatus(modelId, 'starting'); + const connection = findFirstProvider(); if (!connection) { + await this.setPlaygroundStatus(modelId, 'error'); throw new Error('Unable to find an engine to start playground'); } @@ -80,9 +107,11 @@ export class PlayGroundManager { await containerEngine.pullImage(connection.connection, LOCALAI_IMAGE, () => {}); image = await this.selectImage(connection, LOCALAI_IMAGE); if (!image) { + await this.setPlaygroundStatus(modelId, 'error'); throw new Error(`Unable to find ${LOCALAI_IMAGE} image`); } } + const freePort = await getFreePort(); const result = await containerEngine.createContainer(image.engineId, { Image: image.Id, @@ -107,24 +136,41 @@ export class PlayGroundManager { }, Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'], }); - this.playgrounds.set(modelId, { - containerId: result.id, - port: freePort, + + await this.updatePlaygroundState(modelId, { + container: { + containerId: result.id, + port: freePort, + engineId: image.engineId, + }, + status: 'running', + modelId, }); + return result.id; } - async stopPlayground(playgroundId: string): Promise { - const connection = findFirstProvider(); - if (!connection) { - throw new Error('Unable to find an engine to start playground'); + async stopPlayground(modelId: string): Promise { + const state = this.playgrounds.get(modelId); + if (state?.container === undefined) { + throw new Error('model is not running'); } - return containerEngine.stopContainer(connection.providerId, playgroundId); + await this.setPlaygroundStatus(modelId, 'stopping'); + // We do not await since it can take a lot of time + containerEngine + .stopContainer(state.container.engineId, state.container.containerId) + .then(async () => { + await this.setPlaygroundStatus(modelId, 'stopped'); + }) + .catch(async (error: unknown) => { + console.error(error); + await this.setPlaygroundStatus(modelId, 'error'); + }); } async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise { const state = this.playgrounds.get(modelInfo.id); - if (!state) { + if (state?.container === undefined) { throw new Error('model is not running'); } @@ -142,7 +188,7 @@ export class PlayGroundManager { const post_options: http.RequestOptions = { host: 'localhost', - port: '' + state.port, + port: '' + state.container.port, path: '/v1/completions', method: 'POST', headers: { @@ -164,7 +210,7 @@ export class PlayGroundManager { } q.response = result as ModelResponse; this.queries.set(query.id, q); - this.sendState().catch((err: unknown) => { + this.sendQueriesState().catch((err: unknown) => { console.error('playground: unable to send the response to the frontend', err); }); } @@ -175,20 +221,25 @@ export class PlayGroundManager { post_req.end(); this.queries.set(query.id, query); - await this.sendState(); + await this.sendQueriesState(); return query.id; } getNextQueryId() { return ++this.queryIdCounter; } - getState(): QueryState[] { + getQueriesState(): QueryState[] { return Array.from(this.queries.values()); } - async sendState() { + + getPlaygroundsState(): PlaygroundState[] { + return Array.from(this.playgrounds.values()); + } + + async sendQueriesState() { await this.webview.postMessage({ id: MSG_NEW_PLAYGROUND_QUERIES_STATE, - body: this.getState(), + body: this.getQueriesState(), }); } } diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 73a73b16b..a36f268e5 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -17,8 +17,6 @@ ***********************************************************************/ import type { StudioAPI } from '@shared/src/StudioAPI'; -import type { Category } from '@shared/src/models/ICategory'; -import type { Recipe } from '@shared/src/models/IRecipe'; import type { ApplicationManager } from './managers/applicationManager'; import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; import type { RecipeStatus } from '@shared/src/models/IRecipeStatus'; @@ -32,8 +30,7 @@ import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; import * as path from 'node:path'; import type { CatalogManager } from './managers/catalogManager'; import type { Catalog } from '@shared/src/models/ICatalog'; - -export const RECENT_CATEGORY_ID = 'recent-category'; +import type { PlaygroundState } from '@shared/src/models/IPlaygroundState'; export class StudioApiImpl implements StudioAPI { constructor( @@ -56,28 +53,6 @@ export class StudioApiImpl implements StudioAPI { return this.recipeStatusRegistry.getStatus(recipeId); } - async getRecentRecipes(): Promise { - return []; // no recent implementation for now - } - - async getCategories(): Promise { - return this.catalogManager.getCategories(); - } - - async getRecipesByCategory(categoryId: string): Promise { - if (categoryId === RECENT_CATEGORY_ID) return this.getRecentRecipes(); - - // TODO: move logic to catalog manager - return this.catalogManager.getRecipes().filter(recipe => recipe.categories.includes(categoryId)); - } - - async getRecipeById(recipeId: string): Promise { - // TODO: move logic to catalog manager - const recipe = this.catalogManager.getRecipes().find(recipe => recipe.id === recipeId); - if (recipe) return recipe; - throw new Error('Not found'); - } - async getModelById(modelId: string): Promise { // TODO: move logic to catalog manager const model = this.catalogManager.getModels().find(m => modelId === m.id); @@ -87,18 +62,9 @@ export class StudioApiImpl implements StudioAPI { return model; } - async getModelsByIds(ids: string[]): Promise { - // TODO: move logic to catalog manager - return this.catalogManager.getModels().filter(m => ids.includes(m.id)) ?? []; - } - - // eslint-disable-next-line @typescript-eslint/no-unused-vars - async searchRecipes(_query: string): Promise { - return []; // todo: not implemented - } - async pullApplication(recipeId: string): Promise { - const recipe: Recipe = await this.getRecipeById(recipeId); + const recipe = this.catalogManager.getRecipes().find(recipe => recipe.id === recipeId); + if (!recipe) throw new Error('Not found'); // the user should have selected one model, we use the first one for the moment const modelId = recipe.models[0]; @@ -122,16 +88,22 @@ export class StudioApiImpl implements StudioAPI { } async startPlayground(modelId: string): Promise { + // TODO: improve the following const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId); if (localModelInfo.length !== 1) { throw new Error('model not found'); } + // TODO: we need to stop doing that. const modelPath = path.resolve(this.applicationManager.appUserDirectory, 'models', modelId, localModelInfo[0].file); await this.playgroundManager.startPlayground(modelId, modelPath); } + async stopPlayground(modelId: string): Promise { + await this.playgroundManager.stopPlayground(modelId); + } + askPlayground(modelId: string, prompt: string): Promise { const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId); if (localModelInfo.length !== 1) { @@ -140,8 +112,12 @@ export class StudioApiImpl implements StudioAPI { return this.playgroundManager.askPlayground(localModelInfo[0], prompt); } - async getPlaygroundStates(): Promise { - return this.playgroundManager.getState(); + async getPlaygroundQueriesState(): Promise { + return this.playgroundManager.getQueriesState(); + } + + async getPlaygroundsState(): Promise { + return this.playgroundManager.getPlaygroundsState(); } async getCatalog(): Promise { diff --git a/packages/frontend/src/lib/Card.svelte b/packages/frontend/src/lib/Card.svelte index d0a0b7c7b..11ae9e1aa 100644 --- a/packages/frontend/src/lib/Card.svelte +++ b/packages/frontend/src/lib/Card.svelte @@ -1,6 +1,8 @@
+ +
+ {#key playgroundState?.status} + Playground {playgroundState?.status} +
+
Prompt
- + {#key playgroundState?.status} + + {/key}
{#if result} @@ -89,3 +183,4 @@ class="w-full p-2 outline-none text-sm bg-charcoal-800 rounded-sm text-gray-700 placeholder-gray-700"> {/if}
+ diff --git a/packages/frontend/src/stores/playground-queries.ts b/packages/frontend/src/stores/playground-queries.ts index 9196ffd9d..dece000d9 100644 --- a/packages/frontend/src/stores/playground-queries.ts +++ b/packages/frontend/src/stores/playground-queries.ts @@ -9,7 +9,7 @@ export const playgroundQueries: Readable = readable( set(msg); }); // Initialize the store manually - studioClient.getPlaygroundStates().then(state => { + studioClient.getPlaygroundQueriesState().then(state => { set(state); }); return () => { diff --git a/packages/frontend/src/stores/playground-states.ts b/packages/frontend/src/stores/playground-states.ts new file mode 100644 index 000000000..934ef6262 --- /dev/null +++ b/packages/frontend/src/stores/playground-states.ts @@ -0,0 +1,18 @@ +import type { Readable } from 'svelte/store'; +import { readable } from 'svelte/store'; +import { MSG_PLAYGROUNDS_STATE_UPDATE } from '@shared/Messages'; +import { rpcBrowser, studioClient } from '/@/utils/client'; +import type { PlaygroundState } from '@shared/src/models/IPlaygroundState'; + +export const playgroundStates: Readable = readable([], set => { + const sub = rpcBrowser.subscribe(MSG_PLAYGROUNDS_STATE_UPDATE, msg => { + set(msg); + }); + // Initialize the store manually + studioClient.getPlaygroundsState().then(state => { + set(state); + }); + return () => { + sub.unsubscribe(); + }; +}); diff --git a/packages/shared/Messages.ts b/packages/shared/Messages.ts index ede982299..a0a31767d 100644 --- a/packages/shared/Messages.ts +++ b/packages/shared/Messages.ts @@ -1,2 +1,3 @@ +export const MSG_PLAYGROUNDS_STATE_UPDATE = 'playgrounds-state-update'; export const MSG_NEW_PLAYGROUND_QUERIES_STATE = 'new-playground-queries-state'; export const MSG_NEW_CATALOG_STATE = 'new-catalog-state'; diff --git a/packages/shared/src/StudioAPI.ts b/packages/shared/src/StudioAPI.ts index 03edf8e3e..300ea8f23 100644 --- a/packages/shared/src/StudioAPI.ts +++ b/packages/shared/src/StudioAPI.ts @@ -3,6 +3,7 @@ import type { ModelInfo } from './models/IModelInfo'; import type { Task } from './models/ITask'; import type { QueryState } from './models/IPlaygroundQueryState'; import type { Catalog } from './models/ICatalog'; +import type { PlaygroundState } from './models/IPlaygroundState'; export abstract class StudioAPI { abstract ping(): Promise; @@ -16,6 +17,7 @@ export abstract class StudioAPI { abstract getLocalModels(): Promise; abstract startPlayground(modelId: string): Promise; + abstract stopPlayground(modelId: string): Promise; abstract askPlayground(modelId: string, prompt: string): Promise; /** @@ -24,8 +26,7 @@ export abstract class StudioAPI { */ abstract getTasksByLabel(label: string): Promise; - /** - * Ask to send a message MSG_NEW_PLAYGROUND_QUERIES_STATE with the current Playground queries - */ - abstract getPlaygroundStates(): Promise; + abstract getPlaygroundQueriesState(): Promise; + + abstract getPlaygroundsState(): Promise; } diff --git a/packages/shared/src/models/IPlaygroundState.ts b/packages/shared/src/models/IPlaygroundState.ts new file mode 100644 index 000000000..9ff588228 --- /dev/null +++ b/packages/shared/src/models/IPlaygroundState.ts @@ -0,0 +1,11 @@ +export type PlaygroundStatus = 'none' | 'stopped' | 'running' | 'starting' | 'stopping' | 'error'; + +export interface PlaygroundState { + container?: { + containerId: string; + port: number; + engineId: string; + }; + modelId: string; + status: PlaygroundStatus; +}