Skip to content

Commit

Permalink
refactor: move downloadModels function into modelsManager (#133)
Browse files Browse the repository at this point in the history
Signed-off-by: lstocchi <[email protected]>
  • Loading branch information
lstocchi committed Jan 30, 2024
1 parent 29d869a commit 9050c97
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 259 deletions.
138 changes: 14 additions & 124 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/
import { type MockInstance, describe, expect, test, vi, beforeEach } from 'vitest';
import type { ContainerAttachedInfo, DownloadModelResult, ImageInfo, PodInfo } from './applicationManager';
import type { ContainerAttachedInfo, ImageInfo, PodInfo } from './applicationManager';
import { ApplicationManager } from './applicationManager';
import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry';
import type { GitManager } from './gitManager';
Expand All @@ -25,12 +25,14 @@ import fs from 'node:fs';
import type { Recipe } from '@shared/src/models/IRecipe';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { RecipeStatusUtils } from '../utils/recipeStatusUtils';
import type { ModelsManager } from './modelsManager';
import { ModelsManager } from './modelsManager';
import path from 'node:path';
import type { AIConfig, ContainerConfig } from '../models/AIConfig';
import * as portsUtils from '../utils/ports';
import { goarch } from '../utils/arch';
import * as utils from '../utils/utils';
import type { Webview } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';

const mocks = vi.hoisted(() => {
return {
Expand Down Expand Up @@ -81,9 +83,8 @@ describe('pullApplication', () => {
}
const setStatusMock = vi.fn();
const cloneRepositoryMock = vi.fn();
const isModelOnDiskMock = vi.fn();
const getLocalModelPathMock = vi.fn();
let manager: ApplicationManager;
let modelsManager: ModelsManager;
let doDownloadModelWrapperSpy: MockInstance<
[modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string],
Promise<string>
Expand Down Expand Up @@ -156,6 +157,7 @@ describe('pullApplication', () => {
mocks.createContainerMock.mockResolvedValue({
id: 'id',
});
modelsManager = new ModelsManager('appdir', {} as Webview, {} as CatalogManager);
manager = new ApplicationManager(
'/home/user/aistudio',
{
Expand All @@ -164,19 +166,16 @@ describe('pullApplication', () => {
{
setStatus: setStatusMock,
} as unknown as RecipeStatusRegistry,
{
isModelOnDisk: isModelOnDiskMock,
getLocalModelPath: getLocalModelPathMock,
} as unknown as ModelsManager,
modelsManager,
);
doDownloadModelWrapperSpy = vi.spyOn(manager, 'doDownloadModelWrapper');
doDownloadModelWrapperSpy.mockResolvedValue('path');
doDownloadModelWrapperSpy = vi.spyOn(modelsManager, 'doDownloadModelWrapper');
}
test('pullApplication should clone repository and call downloadModelMain and buildImage', async () => {
mockForPullApplication({
recipeFolderExists: false,
});
isModelOnDiskMock.mockReturnValue(false);
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false);
doDownloadModelWrapperSpy.mockResolvedValue('path');
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand Down Expand Up @@ -220,7 +219,8 @@ describe('pullApplication', () => {
mockForPullApplication({
recipeFolderExists: true,
});
isModelOnDiskMock.mockReturnValue(false);
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(modelsManager, 'doDownloadModelWrapper').mockResolvedValue('path');
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand All @@ -247,8 +247,8 @@ describe('pullApplication', () => {
mockForPullApplication({
recipeFolderExists: true,
});
isModelOnDiskMock.mockReturnValue(true);
getLocalModelPathMock.mockReturnValue('path');
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(true);
vi.spyOn(modelsManager, 'getLocalModelPath').mockReturnValue('path');
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand Down Expand Up @@ -423,70 +423,6 @@ describe('getConfiguration', () => {
});
});

describe('downloadModel', () => {
test('download model if not already on disk', async () => {
const isModelOnDiskMock = vi.fn().mockReturnValue(false);
const manager = new ApplicationManager(
'/home/user/aistudio',
{} as unknown as GitManager,
{} as unknown as RecipeStatusRegistry,
{ isModelOnDisk: isModelOnDiskMock } as unknown as ModelsManager,
);
const doDownloadModelWrapperMock = vi
.spyOn(manager, 'doDownloadModelWrapper')
.mockImplementation((_modelId: string, _url: string, _taskUtil: RecipeStatusUtils, _destFileName?: string) => {
return Promise.resolve('');
});
await manager.downloadModel(
{
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo,
taskUtils,
);
expect(doDownloadModelWrapperMock).toBeCalledWith('id', 'url', taskUtils);
expect(setTaskMock).toHaveBeenLastCalledWith({
id: 'id',
name: 'Downloading model name',
labels: {
'model-pulling': 'id',
},
state: 'loading',
});
});
test('retrieve model path if already on disk', async () => {
const isModelOnDiskMock = vi.fn().mockReturnValue(true);
const getLocalModelPathMock = vi.fn();
const manager = new ApplicationManager(
'/home/user/aistudio',
{} as unknown as GitManager,
{} as unknown as RecipeStatusRegistry,
{
isModelOnDisk: isModelOnDiskMock,
getLocalModelPath: getLocalModelPathMock,
} as unknown as ModelsManager,
);
await manager.downloadModel(
{
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo,
taskUtils,
);
expect(getLocalModelPathMock).toBeCalledWith('id');
expect(setTaskMock).toHaveBeenLastCalledWith({
id: 'id',
name: 'Model name already present on disk',
labels: {
'model-pulling': 'id',
},
state: 'success',
});
});
});

describe('filterContainers', () => {
test('return empty array when no container fit the system', () => {
const aiConfig: AIConfig = {
Expand Down Expand Up @@ -808,52 +744,6 @@ describe('createApplicationPod', () => {
});
});

describe('doDownloadModelWrapper', () => {
const manager = new ApplicationManager(
'/home/user/aistudio',
{} as unknown as GitManager,
{} as unknown as RecipeStatusRegistry,
{} as unknown as ModelsManager,
);
test('returning model path if model has been downloaded', async () => {
vi.spyOn(manager, 'doDownloadModel').mockImplementation(
(
_modelId: string,
_url: string,
_taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
_destFileName?: string,
) => {
callback({
successful: true,
path: 'path',
});
},
);
setTaskStateMock.mockReturnThis();
const result = await manager.doDownloadModelWrapper('id', 'url', taskUtils);
expect(result).toBe('path');
});
test('rejecting with error message if model has NOT been downloaded', async () => {
vi.spyOn(manager, 'doDownloadModel').mockImplementation(
(
_modelId: string,
_url: string,
_taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
_destFileName?: string,
) => {
callback({
successful: false,
error: 'error',
});
},
);
setTaskStateMock.mockReturnThis();
await expect(manager.doDownloadModelWrapper('id', 'url', taskUtils)).rejects.toThrowError('error');
});
});

describe('restartContainerWhenModelServiceIsUp', () => {
const containerAttachedInfo: ContainerAttachedInfo = {
name: 'name',
Expand Down
135 changes: 1 addition & 134 deletions packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import type { Recipe } from '@shared/src/models/IRecipe';
import type { GitCloneInfo, GitManager } from './gitManager';
import fs from 'fs';
import * as https from 'node:https';
import * as path from 'node:path';
import { type PodCreatePortOptions, containerEngine } from '@podman-desktop/api';
import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry';
Expand All @@ -36,18 +35,6 @@ import { isEndpointAlive, timeout } from '../utils/utils';

export const CONFIG_FILENAME = 'ai-studio.yaml';

export type DownloadModelResult = DownloadModelSuccessfulResult | DownloadModelFailureResult;

interface DownloadModelSuccessfulResult {
successful: true;
path: string;
}

interface DownloadModelFailureResult {
successful: false;
error: string;
}

interface AIContainers {
aiConfigFile: AIConfigFile;
containers: ContainerConfig[];
Expand Down Expand Up @@ -100,7 +87,7 @@ export class ApplicationManager {
const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.config, localFolder, taskUtil);

// get model by downloading it or retrieving locally
const modelPath = await this.downloadModel(model, taskUtil);
const modelPath = await this.modelsManager.downloadModel(model, taskUtil);

// build all images, one per container (for a basic sample we should have 2 containers = sample app + model service)
const images = await this.buildImages(
Expand Down Expand Up @@ -464,45 +451,6 @@ export class ApplicationManager {
);
}

async downloadModel(model: ModelInfo, taskUtil: RecipeStatusUtils) {
if (!this.modelsManager.isModelOnDisk(model.id)) {
// Download model
taskUtil.setTask({
id: model.id,
state: 'loading',
name: `Downloading model ${model.name}`,
labels: {
'model-pulling': model.id,
},
});

try {
return await this.doDownloadModelWrapper(model.id, model.url, taskUtil);
} catch (e) {
console.error(e);
taskUtil.setTask({
id: model.id,
state: 'error',
name: `Downloading model ${model.name}`,
labels: {
'model-pulling': model.id,
},
});
throw e;
}
} else {
taskUtil.setTask({
id: model.id,
state: 'success',
name: `Model ${model.name} already present on disk`,
labels: {
'model-pulling': model.id,
},
});
return this.modelsManager.getLocalModelPath(model.id);
}
}

getConfiguration(recipeConfig: string, localFolder: string): AIConfigFile {
let configFile: string;
if (recipeConfig !== undefined) {
Expand Down Expand Up @@ -572,85 +520,4 @@ export class ApplicationManager {
// Update task
taskUtil.setTask(checkoutTask);
}

doDownloadModelWrapper(
modelId: string,
url: string,
taskUtil: RecipeStatusUtils,
destFileName?: string,
): Promise<string> {
return new Promise((resolve, reject) => {
const downloadCallback = (result: DownloadModelResult) => {
if (result.successful === true) {
taskUtil.setTaskState(modelId, 'success');
resolve(result.path);
} else if (result.successful === false) {
taskUtil.setTaskState(modelId, 'error');
reject(result.error);
}
};

this.doDownloadModel(modelId, url, taskUtil, downloadCallback, destFileName);
});
}

doDownloadModel(
modelId: string,
url: string,
taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
destFileName?: string,
) {
const destDir = path.join(this.appUserDirectory, 'models', modelId);
if (!fs.existsSync(destDir)) {
fs.mkdirSync(destDir, { recursive: true });
}
if (!destFileName) {
destFileName = path.basename(url);
}
const destFile = path.resolve(destDir, destFileName);
const file = fs.createWriteStream(destFile);
let totalFileSize = 0;
let progress = 0;
https.get(url, resp => {
if (resp.headers.location) {
this.doDownloadModel(modelId, resp.headers.location, taskUtil, callback, destFileName);
return;
} else {
if (totalFileSize === 0 && resp.headers['content-length']) {
totalFileSize = parseFloat(resp.headers['content-length']);
}
}

let previousProgressValue = -1;
resp.on('data', chunk => {
progress += chunk.length;
const progressValue = (progress * 100) / totalFileSize;

if (progressValue === 100 || progressValue - previousProgressValue > 1) {
previousProgressValue = progressValue;
taskUtil.setTaskProgress(modelId, progressValue);
}

// send progress in percentage (ex. 1.2%, 2.6%, 80.1%) to frontend
//this.sendProgress(progressValue);
if (progressValue === 100) {
callback({
successful: true,
path: destFile,
});
}
});
file.on('finish', () => {
file.close();
});
file.on('error', e => {
callback({
successful: false,
error: e.message,
});
});
resp.pipe(file);
});
}
}
Loading

0 comments on commit 9050c97

Please sign in to comment.