From 8b63ac925b8db3e966f89adb1c0cbb0ffc052612 Mon Sep 17 00:00:00 2001 From: Philippe Martin Date: Tue, 23 Jan 2024 11:49:16 +0100 Subject: [PATCH] refactor: move getLocalModels to ModelsManager (#112) * refactor: move getLocalModels to ModelsManager * fix fs import --- .../src/managers/applicationManager.spec.ts | 87 +++---------------- .../src/managers/applicationManager.ts | 38 ++------ .../src/managers/modelsManager.spec.ts | 53 +++++++++++ .../backend/src/managers/modelsManager.ts | 26 ++++++ packages/backend/src/studio-api-impl.spec.ts | 9 +- packages/backend/src/studio-api-impl.ts | 13 +-- packages/backend/src/studio.ts | 24 ++++- 7 files changed, 131 insertions(+), 119 deletions(-) create mode 100644 packages/backend/src/managers/modelsManager.spec.ts create mode 100644 packages/backend/src/managers/modelsManager.ts diff --git a/packages/backend/src/managers/applicationManager.spec.ts b/packages/backend/src/managers/applicationManager.spec.ts index f6dd4927d..bd227baef 100644 --- a/packages/backend/src/managers/applicationManager.spec.ts +++ b/packages/backend/src/managers/applicationManager.spec.ts @@ -1,15 +1,13 @@ import { type MockInstance, describe, expect, test, vi, beforeEach } from 'vitest'; import { ApplicationManager } from './applicationManager'; import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry'; -import type { ExtensionContext } from '@podman-desktop/api'; import type { GitManager } from './gitManager'; import os from 'os'; -import fs, { Stats, type Dirent } from 'fs'; -import path from 'path'; +import fs from 'node:fs'; import type { Recipe } from '@shared/src/models/IRecipe'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; -import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; import type { RecipeStatusUtils } from '../utils/recipeStatusUtils'; +import type { ModelsManager } from './modelsManager'; const mocks = vi.hoisted(() => { return { @@ -32,61 +30,6 @@ beforeEach(() => { vi.resetAllMocks(); }); -test('appUserDirectory should be under home directory', () => { - vi.spyOn(os, 'homedir').mockReturnValue('/home/user'); - const manager = new ApplicationManager({} as GitManager, {} as RecipeStatusRegistry, {} as ExtensionContext); - if (process.platform === 'win32') { - expect(manager.appUserDirectory).toMatch(/^\\home\\user/); - } else { - expect(manager.appUserDirectory).toMatch(/^\/home\/user/); - } -}); - -test('getLocalModels should return models in local directory', () => { - vi.spyOn(os, 'homedir').mockReturnValue('/home/user'); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const readdirSyncMock = vi.spyOn(fs, 'readdirSync') as unknown as MockInstance< - [path: string], - string[] | fs.Dirent[] - >; - readdirSyncMock.mockImplementation((dir: string) => { - if (dir.endsWith('model-id-1') || dir.endsWith('model-id-2')) { - const base = path.basename(dir); - return [base + '-model']; - } else { - return [ - { - isDirectory: () => true, - path: '/home/user/appstudio-dir', - name: 'model-id-1', - }, - { - isDirectory: () => true, - path: '/home/user/appstudio-dir', - name: 'model-id-2', - }, - { - isDirectory: () => false, - path: '/home/user/appstudio-dir', - name: 'other-file-should-be-ignored.txt', - }, - ] as Dirent[]; - } - }); - const manager = new ApplicationManager({} as GitManager, {} as RecipeStatusRegistry, {} as ExtensionContext); - const models = manager.getLocalModels(); - expect(models).toEqual([ - { - id: 'model-id-1', - file: 'model-id-1-model', - }, - { - id: 'model-id-2', - file: 'model-id-2-model', - }, - ]); -}); - describe('pullApplication', () => { interface mockForPullApplicationOptions { recipeFolderExists: boolean; @@ -94,8 +37,8 @@ describe('pullApplication', () => { const setStatusMock = vi.fn(); const cloneRepositoryMock = vi.fn(); + const getLocalModelsMock = vi.fn(); let manager: ApplicationManager; - let getLocalModelsSpy: MockInstance<[], LocalModelInfo[]>; let downloadModelMainSpy: MockInstance< [modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string], Promise @@ -116,11 +59,11 @@ describe('pullApplication', () => { }); vi.spyOn(fs, 'statSync').mockImplementation((path: string) => { if (path.endsWith('recipe1')) { - const stat = new Stats(); + const stat = new fs.Stats(); stat.isDirectory = () => true; return stat; } else if (path.endsWith('ai-studio.yaml')) { - const stat = new Stats(); + const stat = new fs.Stats(); stat.isDirectory = () => false; return stat; } @@ -142,16 +85,18 @@ describe('pullApplication', () => { mocks.builImageMock.mockResolvedValue(undefined); manager = new ApplicationManager( + '/home/user/aistudio', { cloneRepository: cloneRepositoryMock, } as unknown as GitManager, { setStatus: setStatusMock, } as unknown as RecipeStatusRegistry, - {} as ExtensionContext, + { + getLocalModels: getLocalModelsMock, + } as unknown as ModelsManager, ); - getLocalModelsSpy = vi.spyOn(manager, 'getLocalModels'); downloadModelMainSpy = vi.spyOn(manager, 'downloadModelMain'); downloadModelMainSpy.mockResolvedValue(''); } @@ -160,7 +105,7 @@ describe('pullApplication', () => { mockForPullApplication({ recipeFolderExists: false, }); - getLocalModelsSpy.mockReturnValue([]); + getLocalModelsMock.mockReturnValue([]); const recipe: Recipe = { id: 'recipe1', @@ -183,13 +128,9 @@ describe('pullApplication', () => { await manager.pullApplication(recipe, model); if (process.platform === 'win32') { - expect(cloneRepositoryMock).toHaveBeenNthCalledWith( - 1, - 'repo', - '\\home\\user\\podman-desktop\\ai-studio\\recipe1', - ); + expect(cloneRepositoryMock).toHaveBeenNthCalledWith(1, 'repo', '\\home\\user\\aistudio\\recipe1'); } else { - expect(cloneRepositoryMock).toHaveBeenNthCalledWith(1, 'repo', '/home/user/podman-desktop/ai-studio/recipe1'); + expect(cloneRepositoryMock).toHaveBeenNthCalledWith(1, 'repo', '/home/user/aistudio/recipe1'); } expect(downloadModelMainSpy).toHaveBeenCalledOnce(); expect(mocks.builImageMock).toHaveBeenCalledOnce(); @@ -199,7 +140,7 @@ describe('pullApplication', () => { mockForPullApplication({ recipeFolderExists: true, }); - getLocalModelsSpy.mockReturnValue([]); + getLocalModelsMock.mockReturnValue([]); const recipe: Recipe = { id: 'recipe1', @@ -228,7 +169,7 @@ describe('pullApplication', () => { mockForPullApplication({ recipeFolderExists: true, }); - getLocalModelsSpy.mockReturnValue([ + getLocalModelsMock.mockReturnValue([ { id: 'model1', file: 'model1.file', diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index 8f707a1ce..4b9ee0e63 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -19,22 +19,19 @@ import type { Recipe } from '@shared/src/models/IRecipe'; import { arch } from 'node:os'; import type { GitManager } from './gitManager'; -import os from 'os'; import fs from 'fs'; import * as https from 'node:https'; import * as path from 'node:path'; -import { containerEngine, type ExtensionContext } from '@podman-desktop/api'; +import { containerEngine } from '@podman-desktop/api'; import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry'; import type { AIConfig } from '../models/AIConfig'; import { parseYaml } from '../models/AIConfig'; import type { Task } from '@shared/src/models/ITask'; import { RecipeStatusUtils } from '../utils/recipeStatusUtils'; import { getParentDirectory } from '../utils/pathUtils'; -import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; +import type { ModelsManager } from './modelsManager'; -// TODO: Need to be configured -export const AI_STUDIO_FOLDER = path.join('podman-desktop', 'ai-studio'); export const CONFIG_FILENAME = 'ai-studio.yaml'; interface DownloadModelResult { @@ -43,15 +40,12 @@ interface DownloadModelResult { } export class ApplicationManager { - readonly appUserDirectory: string; // todo: make configurable - constructor( + private appUserDirectory: string, private git: GitManager, private recipeStatusRegistry: RecipeStatusRegistry, - private extensionContext: ExtensionContext, - ) { - this.appUserDirectory = path.join(os.homedir(), AI_STUDIO_FOLDER); - } + private modelsManager: ModelsManager, + ) {} async pullApplication(recipe: Recipe, model: ModelInfo) { // Create a TaskUtils object to help us @@ -141,7 +135,7 @@ export class ApplicationManager { container => container.arch === undefined || container.arch === arch(), ); - const localModels = this.getLocalModels(); + const localModels = this.modelsManager.getLocalModels(); if (!localModels.map(m => m.id).includes(model.id)) { // Download model taskUtil.setTask({ @@ -298,24 +292,4 @@ export class ApplicationManager { resp.pipe(file); }); } - - // todo: move somewhere else (dedicated to models) - getLocalModels(): LocalModelInfo[] { - const result: LocalModelInfo[] = []; - const modelsDir = path.join(this.appUserDirectory, 'models'); - const entries = fs.readdirSync(modelsDir, { withFileTypes: true }); - const dirs = entries.filter(dir => dir.isDirectory()); - for (const d of dirs) { - const modelEntries = fs.readdirSync(path.resolve(d.path, d.name)); - if (modelEntries.length !== 1) { - // we support models with one file only for now - continue; - } - result.push({ - id: d.name, - file: modelEntries[0], - }); - } - return result; - } } diff --git a/packages/backend/src/managers/modelsManager.spec.ts b/packages/backend/src/managers/modelsManager.spec.ts new file mode 100644 index 000000000..7b5778ddc --- /dev/null +++ b/packages/backend/src/managers/modelsManager.spec.ts @@ -0,0 +1,53 @@ +import { type MockInstance, beforeEach, expect, test, vi } from 'vitest'; +import os from 'os'; +import fs from 'node:fs'; +import path from 'node:path'; +import { ModelsManager } from './modelsManager'; + +beforeEach(() => { + vi.resetAllMocks(); +}); + +test('getLocalModels should return models in local directory', () => { + vi.spyOn(os, 'homedir').mockReturnValue('/home/user'); + const readdirSyncMock = vi.spyOn(fs, 'readdirSync') as unknown as MockInstance< + [path: string], + string[] | fs.Dirent[] + >; + readdirSyncMock.mockImplementation((dir: string) => { + if (dir.endsWith('model-id-1') || dir.endsWith('model-id-2')) { + const base = path.basename(dir); + return [base + '-model']; + } else { + return [ + { + isDirectory: () => true, + path: '/home/user/appstudio-dir', + name: 'model-id-1', + }, + { + isDirectory: () => true, + path: '/home/user/appstudio-dir', + name: 'model-id-2', + }, + { + isDirectory: () => false, + path: '/home/user/appstudio-dir', + name: 'other-file-should-be-ignored.txt', + }, + ] as fs.Dirent[]; + } + }); + const manager = new ModelsManager('/home/user/aistudio'); + const models = manager.getLocalModels(); + expect(models).toEqual([ + { + id: 'model-id-1', + file: 'model-id-1-model', + }, + { + id: 'model-id-2', + file: 'model-id-2-model', + }, + ]); +}); diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts new file mode 100644 index 000000000..c69e3d97e --- /dev/null +++ b/packages/backend/src/managers/modelsManager.ts @@ -0,0 +1,26 @@ +import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; +import fs from 'fs'; +import * as path from 'node:path'; + +export class ModelsManager { + constructor(private appUserDirectory: string) {} + + getLocalModels(): LocalModelInfo[] { + const result: LocalModelInfo[] = []; + const modelsDir = path.join(this.appUserDirectory, 'models'); + const entries = fs.readdirSync(modelsDir, { withFileTypes: true }); + const dirs = entries.filter(dir => dir.isDirectory()); + for (const d of dirs) { + const modelEntries = fs.readdirSync(path.resolve(d.path, d.name)); + if (modelEntries.length !== 1) { + // we support models with one file only for now + continue; + } + result.push({ + id: d.name, + file: modelEntries[0], + }); + } + return result; + } +} diff --git a/packages/backend/src/studio-api-impl.spec.ts b/packages/backend/src/studio-api-impl.spec.ts index 6d8d9cd84..105fed58f 100644 --- a/packages/backend/src/studio-api-impl.spec.ts +++ b/packages/backend/src/studio-api-impl.spec.ts @@ -26,9 +26,10 @@ import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; import { StudioApiImpl } from './studio-api-impl'; import type { PlayGroundManager } from './managers/playground'; import type { Webview } from '@podman-desktop/api'; +import { CatalogManager } from './managers/catalogManager'; +import type { ModelsManager } from './managers/modelsManager'; import * as fs from 'node:fs'; -import { CatalogManager } from './managers/catalogManager'; vi.mock('./ai.json', () => { return { @@ -70,12 +71,12 @@ beforeEach(async () => { // Creating StudioApiImpl studioApiImpl = new StudioApiImpl( - { - appUserDirectory, - } as unknown as ApplicationManager, + appUserDirectory, + {} as unknown as ApplicationManager, {} as unknown as RecipeStatusRegistry, {} as unknown as PlayGroundManager, catalogManager, + {} as unknown as ModelsManager, ); vi.resetAllMocks(); vi.mock('node:fs'); diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 63b481aee..9347ab04f 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -29,13 +29,16 @@ import * as path from 'node:path'; import type { CatalogManager } from './managers/catalogManager'; import type { Catalog } from '@shared/src/models/ICatalog'; import type { PlaygroundState } from '@shared/src/models/IPlaygroundState'; +import type { ModelsManager } from './managers/modelsManager'; export class StudioApiImpl implements StudioAPI { constructor( + private appUserDirectory: string, private applicationManager: ApplicationManager, private recipeStatusRegistry: RecipeStatusRegistry, private playgroundManager: PlayGroundManager, private catalogManager: CatalogManager, + private modelsManager: ModelsManager, ) {} async ping(): Promise { @@ -55,7 +58,6 @@ export class StudioApiImpl implements StudioAPI { } async getModelById(modelId: string): Promise { - // TODO: move logic to catalog manager const model = this.catalogManager.getModels().find(m => modelId === m.id); if (!model) { throw new Error(`No model found having id ${modelId}`); @@ -78,21 +80,20 @@ export class StudioApiImpl implements StudioAPI { } async getLocalModels(): Promise { - // TODO: move logic to catalog manager - const local = this.applicationManager.getLocalModels(); + const local = this.modelsManager.getLocalModels(); const localIds = local.map(l => l.id); return this.catalogManager.getModels().filter(m => localIds.includes(m.id)); } async startPlayground(modelId: string): Promise { // TODO: improve the following - const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId); + const localModelInfo = this.modelsManager.getLocalModels().filter(m => m.id === modelId); if (localModelInfo.length !== 1) { throw new Error('model not found'); } // TODO: we need to stop doing that. - const modelPath = path.resolve(this.applicationManager.appUserDirectory, 'models', modelId, localModelInfo[0].file); + const modelPath = path.resolve(this.appUserDirectory, 'models', modelId, localModelInfo[0].file); await this.playgroundManager.startPlayground(modelId, modelPath); } @@ -102,7 +103,7 @@ export class StudioApiImpl implements StudioAPI { } askPlayground(modelId: string, prompt: string): Promise { - const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId); + const localModelInfo = this.modelsManager.getLocalModels().filter(m => m.id === modelId); if (localModelInfo.length !== 1) { throw new Error('model not found'); } diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index 87bb917bd..ff4e39026 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -23,11 +23,16 @@ import { StudioApiImpl } from './studio-api-impl'; import { ApplicationManager } from './managers/applicationManager'; import { GitManager } from './managers/gitManager'; import { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; - -import * as fs from 'node:fs'; import { TaskRegistry } from './registries/TaskRegistry'; import { PlayGroundManager } from './managers/playground'; import { CatalogManager } from './managers/catalogManager'; +import { ModelsManager } from './managers/modelsManager'; +import path from 'node:path'; +import os from 'os'; +import fs from 'node:fs'; + +// TODO: Need to be configured +export const AI_STUDIO_FOLDER = path.join('podman-desktop', 'ai-studio'); export class Studio { readonly #extensionContext: ExtensionContext; @@ -38,6 +43,7 @@ export class Studio { studioApi: StudioApiImpl; playgroundManager: PlayGroundManager; catalogManager: CatalogManager; + modelsManager: ModelsManager; constructor(readonly extensionContext: ExtensionContext) { this.#extensionContext = extensionContext; @@ -89,21 +95,31 @@ export class Studio { this.#panel.webview.html = indexHtml; // Let's create the api that the front will be able to call + const appUserDirectory = path.join(os.homedir(), AI_STUDIO_FOLDER); + this.rpcExtension = new RpcExtension(this.#panel.webview); const gitManager = new GitManager(); const taskRegistry = new TaskRegistry(); const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry, this.#panel.webview); - const applicationManager = new ApplicationManager(gitManager, recipeStatusRegistry, this.#extensionContext); + this.modelsManager = new ModelsManager(appUserDirectory); + const applicationManager = new ApplicationManager( + appUserDirectory, + gitManager, + recipeStatusRegistry, + this.modelsManager, + ); this.playgroundManager = new PlayGroundManager(this.#panel.webview); // Create catalog manager, responsible for loading the catalog files and watching for changes - this.catalogManager = new CatalogManager(applicationManager.appUserDirectory, this.#panel.webview); + this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview); // Creating StudioApiImpl this.studioApi = new StudioApiImpl( + appUserDirectory, applicationManager, recipeStatusRegistry, this.playgroundManager, this.catalogManager, + this.modelsManager, ); await this.catalogManager.loadCatalog();