diff --git a/packages/backend/src/managers/applicationManager.spec.ts b/packages/backend/src/managers/applicationManager.spec.ts index bc1dffdd5..ddabbd5c1 100644 --- a/packages/backend/src/managers/applicationManager.spec.ts +++ b/packages/backend/src/managers/applicationManager.spec.ts @@ -167,7 +167,16 @@ describe('pullApplication', () => { mocks.createContainerMock.mockResolvedValue({ id: 'id', }); - modelsManager = new ModelsManager('appdir', {} as Webview, {} as CatalogManager, telemetryLogger); + modelsManager = new ModelsManager( + 'appdir', + {} as Webview, + { + getModels(): ModelInfo[] { + return []; + }, + } as CatalogManager, + telemetryLogger, + ); manager = new ApplicationManager( '/home/user/aistudio', { diff --git a/packages/backend/src/managers/modelsManager.spec.ts b/packages/backend/src/managers/modelsManager.spec.ts index b02c20b40..8c7d72632 100644 --- a/packages/backend/src/managers/modelsManager.spec.ts +++ b/packages/backend/src/managers/modelsManager.spec.ts @@ -120,7 +120,7 @@ function mockFiles(now: Date) { }); } -test('getLocalModelsFromDisk should get models in local directory', () => { +test('getModelsInfo should get models in local directory', async () => { const now = new Date(); mockFiles(now); let appdir: string; @@ -129,27 +129,47 @@ test('getLocalModelsFromDisk should get models in local directory', () => { } else { appdir = '/home/user/aistudio'; } - const manager = new ModelsManager(appdir, {} as Webview, {} as CatalogManager, telemetryLogger); - manager.getLocalModelsFromDisk(); - expect(manager.getLocalModels()).toEqual([ + const manager = new ModelsManager( + appdir, + { + postMessage: vi.fn(), + } as unknown as Webview, + { + getModels(): ModelInfo[] { + return [ + { id: 'model-id-1', name: 'model-id-1-model' } as ModelInfo, + { id: 'model-id-2', name: 'model-id-2-model' } as ModelInfo, + ]; + }, + } as CatalogManager, + telemetryLogger, + ); + await manager.loadLocalModels(); + expect(manager.getModelsInfo()).toEqual([ { id: 'model-id-1', - file: 'model-id-1-model', - size: 32000, - creation: now, - path: path.resolve(dirent[0].path, dirent[0].name), + name: 'model-id-1-model', + file: { + size: 32000, + creation: now, + path: path.resolve(dirent[0].path, dirent[0].name), + file: 'model-id-1-model', + }, }, { id: 'model-id-2', - file: 'model-id-2-model', - size: 32000, - creation: now, - path: path.resolve(dirent[1].path, dirent[1].name), + name: 'model-id-2-model', + file: { + size: 32000, + creation: now, + path: path.resolve(dirent[1].path, dirent[1].name), + file: 'model-id-2-model', + }, }, ]); }); -test('getLocalModelsFromDisk should return an empty array if the models folder does not exist', () => { +test('getModelsInfo should return an empty array if the models folder does not exist', () => { vi.spyOn(os, 'homedir').mockReturnValue('/home/user'); const existsSyncSpy = vi.spyOn(fs, 'existsSync'); existsSyncSpy.mockReturnValue(false); @@ -159,9 +179,18 @@ test('getLocalModelsFromDisk should return an empty array if the models folder d } else { appdir = '/home/user/aistudio'; } - const manager = new ModelsManager(appdir, {} as Webview, {} as CatalogManager, telemetryLogger); + const manager = new ModelsManager( + appdir, + {} as Webview, + { + getModels(): ModelInfo[] { + return []; + }, + } as CatalogManager, + telemetryLogger, + ); manager.getLocalModelsFromDisk(); - expect(manager.getLocalModels()).toEqual([]); + expect(manager.getModelsInfo()).toEqual([]); if (process.platform === 'win32') { expect(existsSyncSpy).toHaveBeenCalledWith('C:\\home\\user\\aistudio\\models'); } else { @@ -198,13 +227,12 @@ test('loadLocalModels should post a message with the message on disk and on cata ); await manager.loadLocalModels(); expect(postMessageMock).toHaveBeenNthCalledWith(1, { - id: 'new-local-models-state', + id: 'new-models-state', body: [ { file: { creation: now, file: 'model-id-1-model', - id: 'model-id-1', size: 32000, path: path.resolve(dirent[0].path, dirent[0].name), }, @@ -242,7 +270,7 @@ test('deleteLocalModel deletes the model folder', async () => { } as CatalogManager, telemetryLogger, ); - manager.getLocalModelsFromDisk(); + await manager.loadLocalModels(); await manager.deleteLocalModel('model-id-1'); // check that the model's folder is removed from disk if (process.platform === 'win32') { @@ -250,29 +278,16 @@ test('deleteLocalModel deletes the model folder', async () => { } else { expect(rmSpy).toBeCalledWith('/home/user/aistudio/models/model-id-1', { recursive: true }); } - expect(postMessageMock).toHaveBeenCalledTimes(2); - // check that a state is sent with the model being deleted - expect(postMessageMock).toHaveBeenCalledWith({ - id: 'new-local-models-state', + expect(postMessageMock).toHaveBeenCalledTimes(3); + // check that a new state is sent with the model removed + expect(postMessageMock).toHaveBeenNthCalledWith(3, { + id: 'new-models-state', body: [ { - file: { - creation: now, - file: 'model-id-1-model', - id: 'model-id-1', - size: 32000, - path: path.resolve(dirent[0].path, dirent[0].name), - }, id: 'model-id-1', - state: 'deleting', }, ], }); - // check that a new state is sent with the model removed - expect(postMessageMock).toHaveBeenCalledWith({ - id: 'new-local-models-state', - body: [], - }); expect(mocks.logUsageMock).toHaveBeenNthCalledWith(1, 'model.delete', { 'model.id': 'model-id-1' }); }); @@ -304,7 +319,7 @@ test('deleteLocalModel fails to delete the model folder', async () => { } as CatalogManager, telemetryLogger, ); - manager.getLocalModelsFromDisk(); + await manager.loadLocalModels(); await manager.deleteLocalModel('model-id-1'); // check that the model's folder is removed from disk if (process.platform === 'win32') { @@ -312,36 +327,12 @@ test('deleteLocalModel fails to delete the model folder', async () => { } else { expect(rmSpy).toBeCalledWith('/home/user/aistudio/models/model-id-1', { recursive: true }); } - expect(postMessageMock).toHaveBeenCalledTimes(2); - // check that a state is sent with the model being deleted - expect(postMessageMock).toHaveBeenCalledWith({ - id: 'new-local-models-state', - body: [ - { - file: { - creation: now, - file: 'model-id-1-model', - id: 'model-id-1', - size: 32000, - path: path.resolve(dirent[0].path, dirent[0].name), - }, - id: 'model-id-1', - state: 'deleting', - }, - ], - }); + expect(postMessageMock).toHaveBeenCalledTimes(3); // check that a new state is sent with the model non removed - expect(postMessageMock).toHaveBeenCalledWith({ - id: 'new-local-models-state', + expect(postMessageMock).toHaveBeenNthCalledWith(3, { + id: 'new-models-state', body: [ { - file: { - creation: now, - file: 'model-id-1-model', - id: 'model-id-1', - size: 32000, - path: path.resolve(dirent[0].path, dirent[0].name), - }, id: 'model-id-1', }, ], @@ -351,7 +342,16 @@ test('deleteLocalModel fails to delete the model folder', async () => { }); describe('downloadModel', () => { - const manager = new ModelsManager('appdir', {} as Webview, {} as CatalogManager, telemetryLogger); + const manager = new ModelsManager( + 'appdir', + {} as Webview, + { + getModels(): ModelInfo[] { + return []; + }, + } as CatalogManager, + telemetryLogger, + ); test('download model if not already on disk', async () => { vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); const doDownloadModelWrapperMock = vi @@ -403,7 +403,16 @@ describe('downloadModel', () => { }); describe('doDownloadModelWrapper', () => { - const manager = new ModelsManager('appdir', {} as Webview, {} as CatalogManager, telemetryLogger); + const manager = new ModelsManager( + 'appdir', + {} as Webview, + { + getModels(): ModelInfo[] { + return []; + }, + } as CatalogManager, + telemetryLogger, + ); test('returning model path if model has been downloaded', async () => { vi.spyOn(manager, 'doDownloadModel').mockImplementation( ( diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts index 10bc11517..0ebd5bba3 100644 --- a/packages/backend/src/managers/modelsManager.ts +++ b/packages/backend/src/managers/modelsManager.ts @@ -21,7 +21,7 @@ import fs from 'fs'; import * as https from 'node:https'; import * as path from 'node:path'; import { type Webview, fs as apiFs } from '@podman-desktop/api'; -import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages'; +import { MSG_NEW_MODELS_STATE } from '@shared/Messages'; import type { CatalogManager } from './catalogManager'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import * as podmanDesktopApi from '@podman-desktop/api'; @@ -42,9 +42,7 @@ interface DownloadModelFailureResult { export class ModelsManager { #modelsDir: string; - #localModels: Map; - // models being deleted - #deleted: Set; + #models: Map; constructor( private appUserDirectory: string, @@ -53,11 +51,11 @@ export class ModelsManager { private telemetry: podmanDesktopApi.TelemetryLogger, ) { this.#modelsDir = path.join(this.appUserDirectory, 'models'); - this.#localModels = new Map(); - this.#deleted = new Set(); + this.#models = new Map(); } async loadLocalModels() { + this.catalogManager.getModels().forEach(m => this.#models.set(m.id, m)); const reloadLocalModels = async () => { this.getLocalModelsFromDisk(); await this.sendModelsInfo(); @@ -71,23 +69,13 @@ export class ModelsManager { } getModelsInfo() { - return this.catalogManager - .getModels() - .filter(m => this.#localModels.has(m.id)) - .map( - m => - ({ - ...m, - file: this.#localModels.get(m.id), - state: this.#deleted.has(m.id) ? 'deleting' : undefined, - }) as ModelInfo, - ); + return [...this.#models.values()]; } async sendModelsInfo() { const models = this.getModelsInfo(); await this.webview.postMessage({ - id: MSG_NEW_LOCAL_MODELS_STATE, + id: MSG_NEW_MODELS_STATE, body: models, }); } @@ -100,7 +88,6 @@ export class ModelsManager { if (!fs.existsSync(this.#modelsDir)) { return; } - const result = new Map(); const entries = fs.readdirSync(this.#modelsDir, { withFileTypes: true }); const dirs = entries.filter(dir => dir.isDirectory()); for (const d of dirs) { @@ -112,26 +99,35 @@ export class ModelsManager { const modelFile = modelEntries[0]; const fullPath = path.resolve(d.path, d.name, modelFile); const info = fs.statSync(fullPath); - result.set(d.name, { - id: d.name, - file: modelFile, - path: path.resolve(d.path, d.name), - size: info.size, - creation: info.mtime, - }); + const model = this.#models.get(d.name); + if (model) { + model.file = { + file: modelFile, + path: path.resolve(d.path, d.name), + size: info.size, + creation: info.mtime, + }; + } } - this.#localModels = result; } isModelOnDisk(modelId: string) { - return this.#localModels.has(modelId); + return this.#models.get(modelId)?.file !== undefined; } getLocalModelInfo(modelId: string): LocalModelInfo { if (!this.isModelOnDisk(modelId)) { throw new Error('model is not on disk'); } - return this.#localModels.get(modelId); + return this.#models.get(modelId).file; + } + + getModelInfo(modelId: string): ModelInfo { + const model = this.#models.get(modelId); + if (!model) { + throw new Error('model is not loaded'); + } + return model; } getLocalModelPath(modelId: string): string { @@ -143,28 +139,26 @@ export class ModelsManager { return path.resolve(this.#modelsDir, modelId); } - getLocalModels(): LocalModelInfo[] { - return Array.from(this.#localModels.values()); - } - async deleteLocalModel(modelId: string): Promise { - const modelDir = this.getLocalModelFolder(modelId); - this.#deleted.add(modelId); - await this.sendModelsInfo(); - try { - await fs.promises.rm(modelDir, { recursive: true }); - this.#localModels.delete(modelId); - this.telemetry.logUsage('model.delete', { 'model.id': modelId }); - } catch (err: unknown) { - this.telemetry.logError('model.delete', { - 'model.id': modelId, - message: 'error deleting model from disk', - error: err, - }); - await podmanDesktopApi.window.showErrorMessage(`Error deleting model ${modelId}. ${String(err)}`); - } finally { - this.#deleted.delete(modelId); + const model = this.#models.get(modelId); + if (model) { + const modelDir = this.getLocalModelFolder(modelId); + model.state = 'deleting'; await this.sendModelsInfo(); + try { + await fs.promises.rm(modelDir, { recursive: true }); + this.telemetry.logUsage('model.delete', { 'model.id': modelId }); + } catch (err: unknown) { + this.telemetry.logError('model.delete', { + 'model.id': modelId, + message: 'error deleting model from disk', + error: err, + }); + await podmanDesktopApi.window.showErrorMessage(`Error deleting model ${modelId}. ${String(err)}`); + } finally { + model.file = model.state = undefined; + await this.sendModelsInfo(); + } } } diff --git a/packages/backend/src/managers/playground.ts b/packages/backend/src/managers/playground.ts index 41f70d3d9..f96e27404 100644 --- a/packages/backend/src/managers/playground.ts +++ b/packages/backend/src/managers/playground.ts @@ -24,7 +24,6 @@ import { provider, type TelemetryLogger, } from '@podman-desktop/api'; -import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; import path from 'node:path'; import { getFreePort } from '../utils/ports'; @@ -35,6 +34,7 @@ import type { ContainerRegistry } from '../registries/ContainerRegistry'; import type { PodmanConnection } from './podmanConnection'; import OpenAI from 'openai'; import { getDurationSecondsSince, timeout } from '../utils/utils'; +import type { ModelInfo } from '@shared/src/models/IModelInfo'; export const LABEL_MODEL_ID = 'ai-studio-model-id'; export const LABEL_MODEL_PORT = 'ai-studio-model-port'; @@ -303,7 +303,7 @@ export class PlayGroundManager { this.telemetry.logUsage('playground.stop', { 'model.id': modelId, durationSeconds }); } - async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise { + async askPlayground(modelInfo: ModelInfo, prompt: string): Promise { const startTime = performance.now(); const state = this.playgrounds.get(modelInfo.id); if (state?.container === undefined) { @@ -320,7 +320,7 @@ export class PlayGroundManager { const client = new OpenAI({ baseURL: `http://localhost:${state.container.port}/v1`, apiKey: 'dummy' }); const response = await client.completions.create({ - model: modelInfo.file, + model: modelInfo.file.file, prompt, temperature: 0.7, max_tokens: 1024, diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index bb113763b..5aaedb3d4 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -81,7 +81,7 @@ export class StudioApiImpl implements StudioAPI { }); } - async getLocalModels(): Promise { + async getModelsInfo(): Promise { return this.modelsManager.getModelsInfo(); } @@ -99,8 +99,8 @@ export class StudioApiImpl implements StudioAPI { } async askPlayground(modelId: string, prompt: string): Promise { - const localModelInfo = this.modelsManager.getLocalModelInfo(modelId); - return this.playgroundManager.askPlayground(localModelInfo, prompt); + const modelInfo = this.modelsManager.getModelInfo(modelId); + return this.playgroundManager.askPlayground(modelInfo, prompt); } async getPlaygroundQueriesState(): Promise { diff --git a/packages/frontend/src/lib/table/model/ModelColumnActions.svelte b/packages/frontend/src/lib/table/model/ModelColumnActions.svelte index 5ff31bdf8..9601978f1 100644 --- a/packages/frontend/src/lib/table/model/ModelColumnActions.svelte +++ b/packages/frontend/src/lib/table/model/ModelColumnActions.svelte @@ -23,10 +23,11 @@ function openModelFolder() { icon={faFolderOpen} onClick={() => openModelFolder()} title="Open Model Folder" + enabled="{object.file !== undefined && !object.state}" /> deleteModel()} title="Delete Model" - enabled={!object.state} + enabled={object.file !== undefined && !object.state} /> diff --git a/packages/frontend/src/lib/table/model/ModelColumnCreation.spec.ts b/packages/frontend/src/lib/table/model/ModelColumnCreation.spec.ts index 5e754ddd8..4710f68d3 100644 --- a/packages/frontend/src/lib/table/model/ModelColumnCreation.spec.ts +++ b/packages/frontend/src/lib/table/model/ModelColumnCreation.spec.ts @@ -36,7 +36,6 @@ test('Expect simple column styling', async () => { registry: '', url: '', file: { - id: 'my-model', file: 'file', creation: d, size: 1000, diff --git a/packages/frontend/src/lib/table/model/ModelColumnSize.spec.ts b/packages/frontend/src/lib/table/model/ModelColumnSize.spec.ts index c5825769b..6ce9c0dd9 100644 --- a/packages/frontend/src/lib/table/model/ModelColumnSize.spec.ts +++ b/packages/frontend/src/lib/table/model/ModelColumnSize.spec.ts @@ -33,7 +33,6 @@ test('Expect simple column styling', async () => { registry: '', url: '', file: { - id: 'my-model', file: 'file', creation: new Date(), size: 1000, diff --git a/packages/frontend/src/pages/Models.spec.ts b/packages/frontend/src/pages/Models.spec.ts index da8637a52..fcffaafe9 100644 --- a/packages/frontend/src/pages/Models.spec.ts +++ b/packages/frontend/src/pages/Models.spec.ts @@ -8,21 +8,21 @@ const mocks = vi.hoisted(() => { getCatalogMock: vi.fn(), getPullingStatusesMock: vi.fn().mockResolvedValue(new Map()), getLocalModelsMock: vi.fn().mockResolvedValue([]), - localModelsSubscribeMock: vi.fn(), + modelsInfoSubscribeMock: vi.fn(), localModelsQueriesMock: { subscribe: (f: (msg: any) => void) => { - f(mocks.localModelsSubscribeMock()); + f(mocks.modelsInfoSubscribeMock()); return () => {}; }, }, - getLocalModels: vi.fn().mockResolvedValue([]), + getModelsInfoMock: vi.fn().mockResolvedValue([]), }; }); vi.mock('/@/utils/client', async () => { return { studioClient: { - getLocalModels: mocks.getLocalModels, + getModelsInfo: mocks.getModelsInfoMock, getPullingStatuses: mocks.getPullingStatusesMock, }, rpcBrowser: { @@ -42,7 +42,7 @@ vi.mock('../stores/local-models', async () => { }); test('should display There is no model yet', async () => { - mocks.localModelsSubscribeMock.mockReturnValue([]); + mocks.modelsInfoSubscribeMock.mockReturnValue([]); render(Models); @@ -51,7 +51,7 @@ test('should display There is no model yet', async () => { }); test('should display There is no model yet and have a task running', async () => { - mocks.localModelsSubscribeMock.mockReturnValue([]); + mocks.modelsInfoSubscribeMock.mockReturnValue([]); const map = new Map(); map.set('random', { recipeId: 'random-recipe-id', diff --git a/packages/frontend/src/pages/Models.svelte b/packages/frontend/src/pages/Models.svelte index abc91747d..16a4f3c83 100644 --- a/packages/frontend/src/pages/Models.svelte +++ b/packages/frontend/src/pages/Models.svelte @@ -3,7 +3,7 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo'; import NavPage from '../lib/NavPage.svelte'; import Table from '../lib/table/Table.svelte'; import { Column, Row } from '../lib/table/table'; -import { localModels } from '../stores/local-models'; +import { modelsInfo } from '../stores/modelsInfo'; import ModelColumnName from '../lib/table/model/ModelColumnName.svelte'; import ModelColumnRegistry from '../lib/table/model/ModelColumnRegistry.svelte'; import ModelColumnPopularity from '../lib/table/model/ModelColumnPopularity.svelte'; @@ -57,7 +57,7 @@ onMount(() => { }); // Subscribe to the models store - const localModelsUnsubscribe = localModels.subscribe((value) => { + const localModelsUnsubscribe = modelsInfo.subscribe((value) => { models = value; filterModels(); }) @@ -69,7 +69,7 @@ onMount(() => { }); - +
diff --git a/packages/frontend/src/stores/local-models.ts b/packages/frontend/src/stores/modelsInfo.ts similarity index 55% rename from packages/frontend/src/stores/local-models.ts rename to packages/frontend/src/stores/modelsInfo.ts index 311b308b8..e77aaf989 100644 --- a/packages/frontend/src/stores/local-models.ts +++ b/packages/frontend/src/stores/modelsInfo.ts @@ -2,14 +2,14 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo'; import type { Readable } from 'svelte/store'; import { readable } from 'svelte/store'; import { rpcBrowser, studioClient } from '/@/utils/client'; -import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages'; +import { MSG_NEW_MODELS_STATE } from '@shared/Messages'; -export const localModels: Readable = readable([], set => { - const sub = rpcBrowser.subscribe(MSG_NEW_LOCAL_MODELS_STATE, msg => { +export const modelsInfo: Readable = readable([], set => { + const sub = rpcBrowser.subscribe(MSG_NEW_MODELS_STATE, msg => { set(msg); }); // Initialize the store manually - studioClient.getLocalModels().then(v => { + studioClient.getModelsInfo().then(v => { set(v); }); return () => { diff --git a/packages/shared/Messages.ts b/packages/shared/Messages.ts index eec53c489..ef748c6cc 100644 --- a/packages/shared/Messages.ts +++ b/packages/shared/Messages.ts @@ -2,5 +2,6 @@ export const MSG_PLAYGROUNDS_STATE_UPDATE = 'playgrounds-state-update'; export const MSG_NEW_PLAYGROUND_QUERIES_STATE = 'new-playground-queries-state'; export const MSG_NEW_CATALOG_STATE = 'new-catalog-state'; export const MSG_NEW_RECIPE_STATE = 'new-recipe-state'; -export const MSG_NEW_LOCAL_MODELS_STATE = 'new-local-models-state'; +export const MSG_NEW_MODELS_STATE = 'new-models-state'; export const MSG_ENVIRONMENTS_STATE_UPDATE = 'environments-state-update'; + diff --git a/packages/shared/src/StudioAPI.ts b/packages/shared/src/StudioAPI.ts index 5f22e110d..508373a52 100644 --- a/packages/shared/src/StudioAPI.ts +++ b/packages/shared/src/StudioAPI.ts @@ -17,7 +17,7 @@ export abstract class StudioAPI { /** * Get the information of models saved locally into the user's directory */ - abstract getLocalModels(): Promise; + abstract getModelsInfo(): Promise; /** * Delete the folder containing the model from local storage */ diff --git a/packages/shared/src/models/ILocalModelInfo.ts b/packages/shared/src/models/ILocalModelInfo.ts index ecdeb88b0..d631a0cdf 100644 --- a/packages/shared/src/models/ILocalModelInfo.ts +++ b/packages/shared/src/models/ILocalModelInfo.ts @@ -1,5 +1,4 @@ export interface LocalModelInfo { - id: string; file: string; path: string; size: number;