From b4851ca8522b7b11efa25380ddefa5b3ea435b94 Mon Sep 17 00:00:00 2001 From: Luca Stocchi <49404737+lstocchi@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:14:03 +0100 Subject: [PATCH] refactor: move downloadModels function into modelsManager (#133) (#183) Signed-off-by: lstocchi --- .../src/managers/applicationManager.spec.ts | 138 ++---------------- .../src/managers/applicationManager.ts | 135 +---------------- .../src/managers/modelsManager.spec.ts | 105 ++++++++++++- .../backend/src/managers/modelsManager.ts | 134 +++++++++++++++++ 4 files changed, 253 insertions(+), 259 deletions(-) diff --git a/packages/backend/src/managers/applicationManager.spec.ts b/packages/backend/src/managers/applicationManager.spec.ts index 78090ff08..2d26a4a08 100644 --- a/packages/backend/src/managers/applicationManager.spec.ts +++ b/packages/backend/src/managers/applicationManager.spec.ts @@ -16,7 +16,7 @@ * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ import { type MockInstance, describe, expect, test, vi, beforeEach } from 'vitest'; -import type { ContainerAttachedInfo, DownloadModelResult, ImageInfo, PodInfo } from './applicationManager'; +import type { ContainerAttachedInfo, ImageInfo, PodInfo } from './applicationManager'; import { ApplicationManager } from './applicationManager'; import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry'; import type { GitManager } from './gitManager'; @@ -25,12 +25,14 @@ import fs from 'node:fs'; import type { Recipe } from '@shared/src/models/IRecipe'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import { RecipeStatusUtils } from '../utils/recipeStatusUtils'; -import type { ModelsManager } from './modelsManager'; +import { ModelsManager } from './modelsManager'; import path from 'node:path'; import type { AIConfig, ContainerConfig } from '../models/AIConfig'; import * as portsUtils from '../utils/ports'; import { goarch } from '../utils/arch'; import * as utils from '../utils/utils'; +import type { Webview } from '@podman-desktop/api'; +import type { CatalogManager } from './catalogManager'; const mocks = vi.hoisted(() => { return { @@ -81,9 +83,8 @@ describe('pullApplication', () => { } const setStatusMock = vi.fn(); const cloneRepositoryMock = vi.fn(); - const isModelOnDiskMock = vi.fn(); - const getLocalModelPathMock = vi.fn(); let manager: ApplicationManager; + let modelsManager: ModelsManager; let doDownloadModelWrapperSpy: MockInstance< [modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string], Promise @@ -156,6 +157,7 @@ describe('pullApplication', () => { mocks.createContainerMock.mockResolvedValue({ id: 'id', }); + modelsManager = new ModelsManager('appdir', {} as Webview, {} as CatalogManager); manager = new ApplicationManager( '/home/user/aistudio', { @@ -164,19 +166,16 @@ describe('pullApplication', () => { { setStatus: setStatusMock, } as unknown as RecipeStatusRegistry, - { - isModelOnDisk: isModelOnDiskMock, - getLocalModelPath: getLocalModelPathMock, - } as unknown as ModelsManager, + modelsManager, ); - doDownloadModelWrapperSpy = vi.spyOn(manager, 'doDownloadModelWrapper'); - doDownloadModelWrapperSpy.mockResolvedValue('path'); + doDownloadModelWrapperSpy = vi.spyOn(modelsManager, 'doDownloadModelWrapper'); } test('pullApplication should clone repository and call downloadModelMain and buildImage', async () => { mockForPullApplication({ recipeFolderExists: false, }); - isModelOnDiskMock.mockReturnValue(false); + vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false); + doDownloadModelWrapperSpy.mockResolvedValue('path'); const recipe: Recipe = { id: 'recipe1', name: 'Recipe 1', @@ -220,7 +219,8 @@ describe('pullApplication', () => { mockForPullApplication({ recipeFolderExists: true, }); - isModelOnDiskMock.mockReturnValue(false); + vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false); + vi.spyOn(modelsManager, 'doDownloadModelWrapper').mockResolvedValue('path'); const recipe: Recipe = { id: 'recipe1', name: 'Recipe 1', @@ -247,8 +247,8 @@ describe('pullApplication', () => { mockForPullApplication({ recipeFolderExists: true, }); - isModelOnDiskMock.mockReturnValue(true); - getLocalModelPathMock.mockReturnValue('path'); + vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(true); + vi.spyOn(modelsManager, 'getLocalModelPath').mockReturnValue('path'); const recipe: Recipe = { id: 'recipe1', name: 'Recipe 1', @@ -423,70 +423,6 @@ describe('getConfiguration', () => { }); }); -describe('downloadModel', () => { - test('download model if not already on disk', async () => { - const isModelOnDiskMock = vi.fn().mockReturnValue(false); - const manager = new ApplicationManager( - '/home/user/aistudio', - {} as unknown as GitManager, - {} as unknown as RecipeStatusRegistry, - { isModelOnDisk: isModelOnDiskMock } as unknown as ModelsManager, - ); - const doDownloadModelWrapperMock = vi - .spyOn(manager, 'doDownloadModelWrapper') - .mockImplementation((_modelId: string, _url: string, _taskUtil: RecipeStatusUtils, _destFileName?: string) => { - return Promise.resolve(''); - }); - await manager.downloadModel( - { - id: 'id', - url: 'url', - name: 'name', - } as ModelInfo, - taskUtils, - ); - expect(doDownloadModelWrapperMock).toBeCalledWith('id', 'url', taskUtils); - expect(setTaskMock).toHaveBeenLastCalledWith({ - id: 'id', - name: 'Downloading model name', - labels: { - 'model-pulling': 'id', - }, - state: 'loading', - }); - }); - test('retrieve model path if already on disk', async () => { - const isModelOnDiskMock = vi.fn().mockReturnValue(true); - const getLocalModelPathMock = vi.fn(); - const manager = new ApplicationManager( - '/home/user/aistudio', - {} as unknown as GitManager, - {} as unknown as RecipeStatusRegistry, - { - isModelOnDisk: isModelOnDiskMock, - getLocalModelPath: getLocalModelPathMock, - } as unknown as ModelsManager, - ); - await manager.downloadModel( - { - id: 'id', - url: 'url', - name: 'name', - } as ModelInfo, - taskUtils, - ); - expect(getLocalModelPathMock).toBeCalledWith('id'); - expect(setTaskMock).toHaveBeenLastCalledWith({ - id: 'id', - name: 'Model name already present on disk', - labels: { - 'model-pulling': 'id', - }, - state: 'success', - }); - }); -}); - describe('filterContainers', () => { test('return empty array when no container fit the system', () => { const aiConfig: AIConfig = { @@ -808,52 +744,6 @@ describe('createApplicationPod', () => { }); }); -describe('doDownloadModelWrapper', () => { - const manager = new ApplicationManager( - '/home/user/aistudio', - {} as unknown as GitManager, - {} as unknown as RecipeStatusRegistry, - {} as unknown as ModelsManager, - ); - test('returning model path if model has been downloaded', async () => { - vi.spyOn(manager, 'doDownloadModel').mockImplementation( - ( - _modelId: string, - _url: string, - _taskUtil: RecipeStatusUtils, - callback: (message: DownloadModelResult) => void, - _destFileName?: string, - ) => { - callback({ - successful: true, - path: 'path', - }); - }, - ); - setTaskStateMock.mockReturnThis(); - const result = await manager.doDownloadModelWrapper('id', 'url', taskUtils); - expect(result).toBe('path'); - }); - test('rejecting with error message if model has NOT been downloaded', async () => { - vi.spyOn(manager, 'doDownloadModel').mockImplementation( - ( - _modelId: string, - _url: string, - _taskUtil: RecipeStatusUtils, - callback: (message: DownloadModelResult) => void, - _destFileName?: string, - ) => { - callback({ - successful: false, - error: 'error', - }); - }, - ); - setTaskStateMock.mockReturnThis(); - await expect(manager.doDownloadModelWrapper('id', 'url', taskUtils)).rejects.toThrowError('error'); - }); -}); - describe('restartContainerWhenModelServiceIsUp', () => { const containerAttachedInfo: ContainerAttachedInfo = { name: 'name', diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index ecf4a0832..7d52624ad 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -19,7 +19,6 @@ import type { Recipe } from '@shared/src/models/IRecipe'; import type { GitCloneInfo, GitManager } from './gitManager'; import fs from 'fs'; -import * as https from 'node:https'; import * as path from 'node:path'; import { type PodCreatePortOptions, containerEngine } from '@podman-desktop/api'; import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry'; @@ -36,18 +35,6 @@ import { isEndpointAlive, timeout } from '../utils/utils'; export const CONFIG_FILENAME = 'ai-studio.yaml'; -export type DownloadModelResult = DownloadModelSuccessfulResult | DownloadModelFailureResult; - -interface DownloadModelSuccessfulResult { - successful: true; - path: string; -} - -interface DownloadModelFailureResult { - successful: false; - error: string; -} - interface AIContainers { aiConfigFile: AIConfigFile; containers: ContainerConfig[]; @@ -100,7 +87,7 @@ export class ApplicationManager { const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.config, localFolder, taskUtil); // get model by downloading it or retrieving locally - const modelPath = await this.downloadModel(model, taskUtil); + const modelPath = await this.modelsManager.downloadModel(model, taskUtil); // build all images, one per container (for a basic sample we should have 2 containers = sample app + model service) const images = await this.buildImages( @@ -464,45 +451,6 @@ export class ApplicationManager { ); } - async downloadModel(model: ModelInfo, taskUtil: RecipeStatusUtils) { - if (!this.modelsManager.isModelOnDisk(model.id)) { - // Download model - taskUtil.setTask({ - id: model.id, - state: 'loading', - name: `Downloading model ${model.name}`, - labels: { - 'model-pulling': model.id, - }, - }); - - try { - return await this.doDownloadModelWrapper(model.id, model.url, taskUtil); - } catch (e) { - console.error(e); - taskUtil.setTask({ - id: model.id, - state: 'error', - name: `Downloading model ${model.name}`, - labels: { - 'model-pulling': model.id, - }, - }); - throw e; - } - } else { - taskUtil.setTask({ - id: model.id, - state: 'success', - name: `Model ${model.name} already present on disk`, - labels: { - 'model-pulling': model.id, - }, - }); - return this.modelsManager.getLocalModelPath(model.id); - } - } - getConfiguration(recipeConfig: string, localFolder: string): AIConfigFile { let configFile: string; if (recipeConfig !== undefined) { @@ -572,85 +520,4 @@ export class ApplicationManager { // Update task taskUtil.setTask(checkoutTask); } - - doDownloadModelWrapper( - modelId: string, - url: string, - taskUtil: RecipeStatusUtils, - destFileName?: string, - ): Promise { - return new Promise((resolve, reject) => { - const downloadCallback = (result: DownloadModelResult) => { - if (result.successful === true) { - taskUtil.setTaskState(modelId, 'success'); - resolve(result.path); - } else if (result.successful === false) { - taskUtil.setTaskState(modelId, 'error'); - reject(result.error); - } - }; - - this.doDownloadModel(modelId, url, taskUtil, downloadCallback, destFileName); - }); - } - - doDownloadModel( - modelId: string, - url: string, - taskUtil: RecipeStatusUtils, - callback: (message: DownloadModelResult) => void, - destFileName?: string, - ) { - const destDir = path.join(this.appUserDirectory, 'models', modelId); - if (!fs.existsSync(destDir)) { - fs.mkdirSync(destDir, { recursive: true }); - } - if (!destFileName) { - destFileName = path.basename(url); - } - const destFile = path.resolve(destDir, destFileName); - const file = fs.createWriteStream(destFile); - let totalFileSize = 0; - let progress = 0; - https.get(url, resp => { - if (resp.headers.location) { - this.doDownloadModel(modelId, resp.headers.location, taskUtil, callback, destFileName); - return; - } else { - if (totalFileSize === 0 && resp.headers['content-length']) { - totalFileSize = parseFloat(resp.headers['content-length']); - } - } - - let previousProgressValue = -1; - resp.on('data', chunk => { - progress += chunk.length; - const progressValue = (progress * 100) / totalFileSize; - - if (progressValue === 100 || progressValue - previousProgressValue > 1) { - previousProgressValue = progressValue; - taskUtil.setTaskProgress(modelId, progressValue); - } - - // send progress in percentage (ex. 1.2%, 2.6%, 80.1%) to frontend - //this.sendProgress(progressValue); - if (progressValue === 100) { - callback({ - successful: true, - path: destFile, - }); - } - }); - file.on('finish', () => { - file.close(); - }); - file.on('error', e => { - callback({ - successful: false, - error: e.message, - }); - }); - resp.pipe(file); - }); - } } diff --git a/packages/backend/src/managers/modelsManager.spec.ts b/packages/backend/src/managers/modelsManager.spec.ts index facb3c3e5..51b2e31e2 100644 --- a/packages/backend/src/managers/modelsManager.spec.ts +++ b/packages/backend/src/managers/modelsManager.spec.ts @@ -1,11 +1,14 @@ -import { type MockInstance, beforeEach, expect, test, vi } from 'vitest'; +import { type MockInstance, beforeEach, describe, expect, test, vi } from 'vitest'; import os from 'os'; import fs from 'node:fs'; import path from 'node:path'; +import type { DownloadModelResult } from './modelsManager'; import { ModelsManager } from './modelsManager'; import type { Webview } from '@podman-desktop/api'; import type { CatalogManager } from './catalogManager'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; +import { RecipeStatusUtils } from '../utils/recipeStatusUtils'; +import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry'; const mocks = vi.hoisted(() => { return { @@ -28,8 +31,17 @@ vi.mock('@podman-desktop/api', () => { }; }); +let setTaskMock: MockInstance; +let taskUtils: RecipeStatusUtils; +let setTaskStateMock: MockInstance; + beforeEach(() => { vi.resetAllMocks(); + taskUtils = new RecipeStatusUtils('recipe', { + setStatus: vi.fn(), + } as unknown as RecipeStatusRegistry); + setTaskMock = vi.spyOn(taskUtils, 'setTask'); + setTaskStateMock = vi.spyOn(taskUtils, 'setTaskState'); }); const dirent = [ @@ -304,3 +316,94 @@ test('deleteLocalModel fails to delete the model folder', async () => { }); expect(mocks.showErrorMessageMock).toHaveBeenCalledOnce(); }); + +describe('downloadModel', () => { + const manager = new ModelsManager('appdir', {} as Webview, {} as CatalogManager); + test('download model if not already on disk', async () => { + vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); + const doDownloadModelWrapperMock = vi + .spyOn(manager, 'doDownloadModelWrapper') + .mockImplementation((_modelId: string, _url: string, _taskUtil: RecipeStatusUtils, _destFileName?: string) => { + return Promise.resolve(''); + }); + await manager.downloadModel( + { + id: 'id', + url: 'url', + name: 'name', + } as ModelInfo, + taskUtils, + ); + expect(doDownloadModelWrapperMock).toBeCalledWith('id', 'url', taskUtils); + expect(setTaskMock).toHaveBeenLastCalledWith({ + id: 'id', + name: 'Downloading model name', + labels: { + 'model-pulling': 'id', + }, + state: 'loading', + }); + }); + test('retrieve model path if already on disk', async () => { + vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true); + const getLocalModelPathMock = vi.spyOn(manager, 'getLocalModelPath').mockReturnValue(''); + await manager.downloadModel( + { + id: 'id', + url: 'url', + name: 'name', + } as ModelInfo, + taskUtils, + ); + expect(getLocalModelPathMock).toBeCalledWith('id'); + expect(setTaskMock).toHaveBeenLastCalledWith({ + id: 'id', + name: 'Model name already present on disk', + labels: { + 'model-pulling': 'id', + }, + state: 'success', + }); + }); +}); + +describe('doDownloadModelWrapper', () => { + const manager = new ModelsManager('appdir', {} as Webview, {} as CatalogManager); + test('returning model path if model has been downloaded', async () => { + vi.spyOn(manager, 'doDownloadModel').mockImplementation( + ( + _modelId: string, + _url: string, + _taskUtil: RecipeStatusUtils, + callback: (message: DownloadModelResult) => void, + _destFileName?: string, + ) => { + callback({ + successful: true, + path: 'path', + }); + }, + ); + setTaskStateMock.mockReturnThis(); + const result = await manager.doDownloadModelWrapper('id', 'url', taskUtils); + expect(result).toBe('path'); + }); + test('rejecting with error message if model has NOT been downloaded', async () => { + vi.spyOn(manager, 'doDownloadModel').mockImplementation( + ( + _modelId: string, + _url: string, + _taskUtil: RecipeStatusUtils, + callback: (message: DownloadModelResult) => void, + _destFileName?: string, + ) => { + callback({ + successful: false, + error: 'error', + }); + }, + ); + setTaskStateMock.mockReturnThis(); + await expect(manager.doDownloadModelWrapper('id', 'url', taskUtils)).rejects.toThrowError('error'); + }); +}); diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts index b080dde8d..17190a7c8 100644 --- a/packages/backend/src/managers/modelsManager.ts +++ b/packages/backend/src/managers/modelsManager.ts @@ -1,11 +1,25 @@ import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; import fs from 'fs'; +import * as https from 'node:https'; import * as path from 'node:path'; import { type Webview, fs as apiFs } from '@podman-desktop/api'; import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages'; import type { CatalogManager } from './catalogManager'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import * as podmanDesktopApi from '@podman-desktop/api'; +import type { RecipeStatusUtils } from '../utils/recipeStatusUtils'; + +export type DownloadModelResult = DownloadModelSuccessfulResult | DownloadModelFailureResult; + +interface DownloadModelSuccessfulResult { + successful: true; + path: string; +} + +interface DownloadModelFailureResult { + successful: false; + error: string; +} export class ModelsManager { #modelsDir: string; @@ -127,4 +141,124 @@ export class ModelsManager { await this.sendModelsInfo(); } } + + async downloadModel(model: ModelInfo, taskUtil: RecipeStatusUtils) { + if (!this.isModelOnDisk(model.id)) { + // Download model + taskUtil.setTask({ + id: model.id, + state: 'loading', + name: `Downloading model ${model.name}`, + labels: { + 'model-pulling': model.id, + }, + }); + + try { + return await this.doDownloadModelWrapper(model.id, model.url, taskUtil); + } catch (e) { + console.error(e); + taskUtil.setTask({ + id: model.id, + state: 'error', + name: `Downloading model ${model.name}`, + labels: { + 'model-pulling': model.id, + }, + }); + throw e; + } + } else { + taskUtil.setTask({ + id: model.id, + state: 'success', + name: `Model ${model.name} already present on disk`, + labels: { + 'model-pulling': model.id, + }, + }); + return this.getLocalModelPath(model.id); + } + } + + doDownloadModelWrapper( + modelId: string, + url: string, + taskUtil: RecipeStatusUtils, + destFileName?: string, + ): Promise { + return new Promise((resolve, reject) => { + const downloadCallback = (result: DownloadModelResult) => { + if (result.successful === true) { + taskUtil.setTaskState(modelId, 'success'); + resolve(result.path); + } else if (result.successful === false) { + taskUtil.setTaskState(modelId, 'error'); + reject(result.error); + } + }; + + this.doDownloadModel(modelId, url, taskUtil, downloadCallback, destFileName); + }); + } + + doDownloadModel( + modelId: string, + url: string, + taskUtil: RecipeStatusUtils, + callback: (message: DownloadModelResult) => void, + destFileName?: string, + ) { + const destDir = path.join(this.appUserDirectory, 'models', modelId); + if (!fs.existsSync(destDir)) { + fs.mkdirSync(destDir, { recursive: true }); + } + if (!destFileName) { + destFileName = path.basename(url); + } + const destFile = path.resolve(destDir, destFileName); + const file = fs.createWriteStream(destFile); + let totalFileSize = 0; + let progress = 0; + https.get(url, resp => { + if (resp.headers.location) { + this.doDownloadModel(modelId, resp.headers.location, taskUtil, callback, destFileName); + return; + } else { + if (totalFileSize === 0 && resp.headers['content-length']) { + totalFileSize = parseFloat(resp.headers['content-length']); + } + } + + let previousProgressValue = -1; + resp.on('data', chunk => { + progress += chunk.length; + const progressValue = (progress * 100) / totalFileSize; + + if (progressValue === 100 || progressValue - previousProgressValue > 1) { + previousProgressValue = progressValue; + taskUtil.setTaskProgress(modelId, progressValue); + } + + // send progress in percentage (ex. 1.2%, 2.6%, 80.1%) to frontend + //this.sendProgress(progressValue); + if (progressValue === 100) { + callback({ + successful: true, + path: destFile, + }); + } + }); + file.on('finish', () => { + file.close(); + }); + file.on('error', e => { + callback({ + successful: false, + error: e.message, + }); + }); + resp.pipe(file); + }); + } }