From ac9a732d27c9af947f200affaed1f2f6c2c28f88 Mon Sep 17 00:00:00 2001 From: Philippe Martin Date: Wed, 24 Jan 2024 08:56:52 +0100 Subject: [PATCH] watch local models --- .../src/managers/applicationManager.spec.ts | 16 ++-- .../src/managers/applicationManager.ts | 3 +- .../src/managers/modelsManager.spec.ts | 73 +++++++++++++++--- .../backend/src/managers/modelsManager.ts | 74 ++++++++++++++++--- packages/backend/src/studio-api-impl.ts | 30 +------- packages/backend/src/studio.spec.ts | 2 + packages/backend/src/studio.ts | 9 ++- packages/frontend/src/stores/local-models.ts | 10 ++- packages/shared/Messages.ts | 1 + 9 files changed, 156 insertions(+), 62 deletions(-) diff --git a/packages/backend/src/managers/applicationManager.spec.ts b/packages/backend/src/managers/applicationManager.spec.ts index bd227baef..cdbe1a8b3 100644 --- a/packages/backend/src/managers/applicationManager.spec.ts +++ b/packages/backend/src/managers/applicationManager.spec.ts @@ -37,7 +37,7 @@ describe('pullApplication', () => { const setStatusMock = vi.fn(); const cloneRepositoryMock = vi.fn(); - const getLocalModelsMock = vi.fn(); + const isModelOnDiskMock = vi.fn(); let manager: ApplicationManager; let downloadModelMainSpy: MockInstance< [modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string], @@ -93,7 +93,7 @@ describe('pullApplication', () => { setStatus: setStatusMock, } as unknown as RecipeStatusRegistry, { - getLocalModels: getLocalModelsMock, + isModelOnDisk: isModelOnDiskMock, } as unknown as ModelsManager, ); @@ -105,7 +105,7 @@ describe('pullApplication', () => { mockForPullApplication({ recipeFolderExists: false, }); - getLocalModelsMock.mockReturnValue([]); + isModelOnDiskMock.mockReturnValue(false); const recipe: Recipe = { id: 'recipe1', @@ -140,7 +140,7 @@ describe('pullApplication', () => { mockForPullApplication({ recipeFolderExists: true, }); - getLocalModelsMock.mockReturnValue([]); + isModelOnDiskMock.mockReturnValue(false); const recipe: Recipe = { id: 'recipe1', @@ -169,13 +169,7 @@ describe('pullApplication', () => { mockForPullApplication({ recipeFolderExists: true, }); - getLocalModelsMock.mockReturnValue([ - { - id: 'model1', - file: 'model1.file', - }, - ]); - + isModelOnDiskMock.mockReturnValue(true); const recipe: Recipe = { id: 'recipe1', name: 'Recipe 1', diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index 4b9ee0e63..940b562c2 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -135,8 +135,7 @@ export class ApplicationManager { container => container.arch === undefined || container.arch === arch(), ); - const localModels = this.modelsManager.getLocalModels(); - if (!localModels.map(m => m.id).includes(model.id)) { + if (!this.modelsManager.isModelOnDisk(model.id)) { // Download model taskUtil.setTask({ id: model.id, diff --git a/packages/backend/src/managers/modelsManager.spec.ts b/packages/backend/src/managers/modelsManager.spec.ts index 94f6d5033..f57b7ef4f 100644 --- a/packages/backend/src/managers/modelsManager.spec.ts +++ b/packages/backend/src/managers/modelsManager.spec.ts @@ -3,12 +3,15 @@ import os from 'os'; import fs from 'node:fs'; import path from 'node:path'; import { ModelsManager } from './modelsManager'; +import type { Webview } from '@podman-desktop/api'; +import type { CatalogManager } from './catalogManager'; +import type { ModelInfo } from '@shared/src/models/IModelInfo'; beforeEach(() => { vi.resetAllMocks(); }); -test('getLocalModels should return models in local directory', () => { +function mockFiles(now: Date) { vi.spyOn(os, 'homedir').mockReturnValue('/home/user'); const existsSyncSpy = vi.spyOn(fs, 'existsSync'); existsSyncSpy.mockImplementation((path: string) => { @@ -21,7 +24,6 @@ test('getLocalModels should return models in local directory', () => { }); const statSyncSpy = vi.spyOn(fs, 'statSync'); const info = new fs.Stats(); - const now = new Date(); info.size = 32000; info.mtime = now; statSyncSpy.mockReturnValue(info); @@ -53,9 +55,14 @@ test('getLocalModels should return models in local directory', () => { ] as fs.Dirent[]; } }); - const manager = new ModelsManager('/home/user/aistudio'); - const models = manager.getLocalModels(); - expect(models).toEqual([ +} + +test('getLocalModelsFromDisk should get models in local directory', () => { + const now = new Date(); + mockFiles(now); + const manager = new ModelsManager('/home/user/aistudio', {} as Webview, {} as CatalogManager); + manager.getLocalModelsFromDisk(); + expect(manager.getLocalModels()).toEqual([ { id: 'model-id-1', file: 'model-id-1-model', @@ -71,16 +78,64 @@ test('getLocalModels should return models in local directory', () => { ]); }); -test('getLocalModels should return an empty array if the models folder does not exist', () => { +test('getLocalModelsFromDisk should return an empty array if the models folder does not exist', () => { vi.spyOn(os, 'homedir').mockReturnValue('/home/user'); const existsSyncSpy = vi.spyOn(fs, 'existsSync'); existsSyncSpy.mockReturnValue(false); - const manager = new ModelsManager('/home/user/aistudio'); - const models = manager.getLocalModels(); - expect(models).toEqual([]); + const manager = new ModelsManager('/home/user/aistudio', {} as Webview, {} as CatalogManager); + manager.getLocalModelsFromDisk(); + expect(manager.getLocalModels()).toEqual([]); if (process.platform === 'win32') { expect(existsSyncSpy).toHaveBeenCalledWith('\\home\\user\\aistudio\\models'); } else { expect(existsSyncSpy).toHaveBeenCalledWith('/home/user/aistudio/models'); } }); + +test('loadLocalModels should post a message with the message on disk and on catalog', async () => { + const now = new Date(); + mockFiles(now); + + vi.mock('@podman-desktop/api', () => { + return { + fs: { + createFileSystemWatcher: () => ({ + onDidCreate: vi.fn(), + onDidDelete: vi.fn(), + onDidChange: vi.fn(), + }), + }, + }; + }); + const postMessageMock = vi.fn(); + const manager = new ModelsManager( + '/home/user/aistudio', + { + postMessage: postMessageMock, + } as unknown as Webview, + { + getModels: () => { + return [ + { + id: 'model-id-1', + }, + ] as ModelInfo[]; + }, + } as CatalogManager, + ); + await manager.loadLocalModels(); + expect(postMessageMock).toHaveBeenNthCalledWith(1, { + id: 'new-local-models-state', + body: [ + { + file: { + creation: now, + file: 'model-id-1-model', + id: 'model-id-1', + size: 32000, + }, + id: 'model-id-1', + }, + ], + }); +}); diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts index b46327bb9..61744deb6 100644 --- a/packages/backend/src/managers/modelsManager.ts +++ b/packages/backend/src/managers/modelsManager.ts @@ -1,17 +1,53 @@ import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; import fs from 'fs'; import * as path from 'node:path'; +import { type Webview, fs as apiFs } from '@podman-desktop/api'; +import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages'; +import type { CatalogManager } from './catalogManager'; export class ModelsManager { - constructor(private appUserDirectory: string) {} + #modelsDir: string; + #localModels: Map; - getLocalModels(): LocalModelInfo[] { - const result: LocalModelInfo[] = []; - const modelsDir = path.join(this.appUserDirectory, 'models'); - if (!fs.existsSync(modelsDir)) { - return []; + constructor( + private appUserDirectory: string, + private webview: Webview, + private catalogManager: CatalogManager, + ) { + this.#modelsDir = path.join(this.appUserDirectory, 'models'); + this.#localModels = new Map(); + } + + async loadLocalModels() { + const reloadLocalModels = async () => { + this.getLocalModelsFromDisk(); + const models = this.getModelsInfo(); + await this.webview.postMessage({ + id: MSG_NEW_LOCAL_MODELS_STATE, + body: models, + }); + }; + const watcher = apiFs.createFileSystemWatcher(this.#modelsDir); + watcher.onDidCreate(reloadLocalModels); + watcher.onDidDelete(reloadLocalModels); + watcher.onDidChange(reloadLocalModels); + // Initialize the local models manually + await reloadLocalModels(); + } + + getModelsInfo() { + return this.catalogManager + .getModels() + .filter(m => this.#localModels.has(m.id)) + .map(m => ({ ...m, file: this.#localModels.get(m.id) })); + } + + getLocalModelsFromDisk(): void { + if (!fs.existsSync(this.#modelsDir)) { + return; } - const entries = fs.readdirSync(modelsDir, { withFileTypes: true }); + const result = new Map(); + const entries = fs.readdirSync(this.#modelsDir, { withFileTypes: true }); const dirs = entries.filter(dir => dir.isDirectory()); for (const d of dirs) { const modelEntries = fs.readdirSync(path.resolve(d.path, d.name)); @@ -21,13 +57,33 @@ export class ModelsManager { } const modelFile = modelEntries[0]; const info = fs.statSync(path.resolve(d.path, d.name, modelFile)); - result.push({ + result.set(d.name, { id: d.name, file: modelFile, size: info.size, creation: info.mtime, }); } - return result; + this.#localModels = result; + } + + isModelOnDisk(modelId: string) { + return this.#localModels.has(modelId); + } + + getLocalModelInfo(modelId: string): LocalModelInfo { + if (!this.isModelOnDisk(modelId)) { + throw new Error('model is not on disk'); + } + return this.#localModels.get(modelId); + } + + getLocalModelPath(modelId: string): string { + const info = this.getLocalModelInfo(modelId); + return path.resolve(this.#modelsDir, modelId, info.file); + } + + getLocalModels(): LocalModelInfo[] { + return Array.from(this.#localModels.values()); } } diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 122a6a361..cc340f9f4 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -25,12 +25,10 @@ import type { PlayGroundManager } from './managers/playground'; import * as podmanDesktopApi from '@podman-desktop/api'; import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; -import * as path from 'node:path'; import type { CatalogManager } from './managers/catalogManager'; import type { Catalog } from '@shared/src/models/ICatalog'; import type { PlaygroundState } from '@shared/src/models/IPlaygroundState'; import type { ModelsManager } from './managers/modelsManager'; -import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; export class StudioApiImpl implements StudioAPI { constructor( @@ -81,28 +79,11 @@ export class StudioApiImpl implements StudioAPI { } async getLocalModels(): Promise { - const local = this.modelsManager.getLocalModels(); - const localMap = new Map(); - for (const l of local) { - localMap.set(l.id, l); - } - const localIds = local.map(l => l.id); - return this.catalogManager - .getModels() - .filter(m => localIds.includes(m.id)) - .map(m => ({ ...m, file: localMap.get(m.id) })); + return this.modelsManager.getModelsInfo(); } async startPlayground(modelId: string): Promise { - // TODO: improve the following - const localModelInfo = this.modelsManager.getLocalModels().filter(m => m.id === modelId); - if (localModelInfo.length !== 1) { - throw new Error('model not found'); - } - - // TODO: we need to stop doing that. - const modelPath = path.resolve(this.appUserDirectory, 'models', modelId, localModelInfo[0].file); - + const modelPath = this.modelsManager.getLocalModelPath(modelId); await this.playgroundManager.startPlayground(modelId, modelPath); } @@ -111,11 +92,8 @@ export class StudioApiImpl implements StudioAPI { } askPlayground(modelId: string, prompt: string): Promise { - const localModelInfo = this.modelsManager.getLocalModels().filter(m => m.id === modelId); - if (localModelInfo.length !== 1) { - throw new Error('model not found'); - } - return this.playgroundManager.askPlayground(localModelInfo[0], prompt); + const localModelInfo = this.modelsManager.getLocalModelInfo(modelId); + return this.playgroundManager.askPlayground(localModelInfo, prompt); } async getPlaygroundQueriesState(): Promise { diff --git a/packages/backend/src/studio.spec.ts b/packages/backend/src/studio.spec.ts index f456d46fd..f15bec5d8 100644 --- a/packages/backend/src/studio.spec.ts +++ b/packages/backend/src/studio.spec.ts @@ -24,6 +24,8 @@ import type { ExtensionContext } from '@podman-desktop/api'; import * as fs from 'node:fs'; +vi.mock('./managers/modelsManager'); + const mockedExtensionContext = { subscriptions: [], } as unknown as ExtensionContext; diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index ff4e39026..550d9deff 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -101,16 +101,16 @@ export class Studio { const gitManager = new GitManager(); const taskRegistry = new TaskRegistry(); const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry, this.#panel.webview); - this.modelsManager = new ModelsManager(appUserDirectory); + this.playgroundManager = new PlayGroundManager(this.#panel.webview); + // Create catalog manager, responsible for loading the catalog files and watching for changes + this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview); + this.modelsManager = new ModelsManager(appUserDirectory, this.#panel.webview, this.catalogManager); const applicationManager = new ApplicationManager( appUserDirectory, gitManager, recipeStatusRegistry, this.modelsManager, ); - this.playgroundManager = new PlayGroundManager(this.#panel.webview); - // Create catalog manager, responsible for loading the catalog files and watching for changes - this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview); // Creating StudioApiImpl this.studioApi = new StudioApiImpl( @@ -123,6 +123,7 @@ export class Studio { ); await this.catalogManager.loadCatalog(); + await this.modelsManager.loadLocalModels(); // Register the instance this.rpcExtension.registerInstance(StudioApiImpl, this.studioApi); diff --git a/packages/frontend/src/stores/local-models.ts b/packages/frontend/src/stores/local-models.ts index c40e32b87..311b308b8 100644 --- a/packages/frontend/src/stores/local-models.ts +++ b/packages/frontend/src/stores/local-models.ts @@ -1,10 +1,18 @@ import type { ModelInfo } from '@shared/src/models/IModelInfo'; import type { Readable } from 'svelte/store'; import { readable } from 'svelte/store'; -import { studioClient } from '/@/utils/client'; +import { rpcBrowser, studioClient } from '/@/utils/client'; +import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages'; export const localModels: Readable = readable([], set => { + const sub = rpcBrowser.subscribe(MSG_NEW_LOCAL_MODELS_STATE, msg => { + set(msg); + }); + // Initialize the store manually studioClient.getLocalModels().then(v => { set(v); }); + return () => { + sub.unsubscribe(); + }; }); diff --git a/packages/shared/Messages.ts b/packages/shared/Messages.ts index 39ab64907..637b7ef76 100644 --- a/packages/shared/Messages.ts +++ b/packages/shared/Messages.ts @@ -2,3 +2,4 @@ export const MSG_PLAYGROUNDS_STATE_UPDATE = 'playgrounds-state-update'; export const MSG_NEW_PLAYGROUND_QUERIES_STATE = 'new-playground-queries-state'; export const MSG_NEW_CATALOG_STATE = 'new-catalog-state'; export const MSG_NEW_RECIPE_STATE = 'new-recipe-state'; +export const MSG_NEW_LOCAL_MODELS_STATE = 'new-local-models-state';