Skip to content

Commit

Permalink
chore: remove the dependency of download task to the RecipeStatusUtils (
Browse files Browse the repository at this point in the history
#340)

* feat: splitting download system

Signed-off-by: axel7083 <[email protected]>

* fix: prettier&linter

Signed-off-by: axel7083 <[email protected]>

* fix: recipe update registry

Signed-off-by: axel7083 <[email protected]>

* fix: prettier

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 authored Feb 15, 2024
1 parent 4b65712 commit 60aaa3e
Show file tree
Hide file tree
Showing 11 changed files with 334 additions and 249 deletions.
33 changes: 18 additions & 15 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import type {
podStopHandle,
startupHandle,
} from './podmanConnection';
import { TaskRegistry } from '../registries/TaskRegistry';

const mocks = vi.hoisted(() => {
return {
Expand Down Expand Up @@ -73,11 +74,21 @@ const mocks = vi.hoisted(() => {
listPodsMock: vi.fn(),
stopPodMock: vi.fn(),
removePodMock: vi.fn(),
performDownloadMock: vi.fn(),
onEventDownloadMock: vi.fn(),
};
});
vi.mock('../models/AIConfig', () => ({
parseYamlFile: mocks.parseYamlFileMock,
}));

vi.mock('../utils/downloader', () => ({
Downloader: class {
onEvent = mocks.onEventDownloadMock;
perform = mocks.performDownloadMock;
},
}));

vi.mock('@podman-desktop/api', () => ({
provider: {
getContainerConnections: mocks.getContainerConnectionsMock,
Expand Down Expand Up @@ -129,10 +140,6 @@ describe('pullApplication', () => {
const cloneRepositoryMock = vi.fn();
let manager: ApplicationManager;
let modelsManager: ModelsManager;
let doDownloadModelWrapperSpy: MockInstance<
[modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string],
Promise<string>
>;
vi.spyOn(utils, 'timeout').mockResolvedValue();
function mockForPullApplication(options: mockForPullApplicationOptions) {
vi.spyOn(os, 'homedir').mockReturnValue('/home/user');
Expand Down Expand Up @@ -210,6 +217,7 @@ describe('pullApplication', () => {
},
} as CatalogManager,
telemetryLogger,
new TaskRegistry({ postMessage: vi.fn().mockResolvedValue(undefined) } as unknown as Webview),
);
manager = new ApplicationManager(
'/home/user/aistudio',
Expand All @@ -225,14 +233,13 @@ describe('pullApplication', () => {
modelsManager,
telemetryLogger,
);
doDownloadModelWrapperSpy = vi.spyOn(modelsManager, 'doDownloadModelWrapper');
}
test('pullApplication should clone repository and call downloadModelMain and buildImage', async () => {
mockForPullApplication({
recipeFolderExists: false,
});
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false);
doDownloadModelWrapperSpy.mockResolvedValue('path');
mocks.performDownloadMock.mockResolvedValue('path');
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand Down Expand Up @@ -270,7 +277,7 @@ describe('pullApplication', () => {
gitCloneOptions.targetDirectory = '/home/user/aistudio/recipe1';
expect(cloneRepositoryMock).toHaveBeenNthCalledWith(1, gitCloneOptions);
}
expect(doDownloadModelWrapperSpy).toHaveBeenCalledOnce();
expect(mocks.performDownloadMock).toHaveBeenCalledOnce();
expect(mocks.buildImageMock).toHaveBeenCalledOnce();
expect(mocks.buildImageMock).toHaveBeenCalledWith(
`${gitCloneOptions.targetDirectory}${path.sep}contextdir1`,
Expand All @@ -283,11 +290,7 @@ describe('pullApplication', () => {
},
},
);
expect(mocks.logUsageMock).toHaveBeenNthCalledWith(1, 'model.download', {
'model.id': 'model1',
durationSeconds: 99,
});
expect(mocks.logUsageMock).toHaveBeenNthCalledWith(2, 'recipe.pull', {
expect(mocks.logUsageMock).toHaveBeenNthCalledWith(1, 'recipe.pull', {
'recipe.id': 'recipe1',
'recipe.name': 'Recipe 1',
durationSeconds: 99,
Expand All @@ -298,7 +301,7 @@ describe('pullApplication', () => {
recipeFolderExists: true,
});
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(modelsManager, 'doDownloadModelWrapper').mockResolvedValue('path');
mocks.performDownloadMock.mockResolvedValue('path');
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand Down Expand Up @@ -348,7 +351,7 @@ describe('pullApplication', () => {
};
await manager.pullApplication(recipe, model);
expect(cloneRepositoryMock).not.toHaveBeenCalled();
expect(doDownloadModelWrapperSpy).not.toHaveBeenCalled();
expect(mocks.performDownloadMock).not.toHaveBeenCalled();
});

test('pullApplication should mark the loading config as error if not container are found', async () => {
Expand Down Expand Up @@ -385,7 +388,7 @@ describe('pullApplication', () => {
await expect(manager.pullApplication(recipe, model)).rejects.toThrowError('No containers available.');

expect(cloneRepositoryMock).not.toHaveBeenCalled();
expect(doDownloadModelWrapperSpy).not.toHaveBeenCalled();
expect(mocks.performDownloadMock).not.toHaveBeenCalled();
});
});
describe('doCheckout', () => {
Expand Down
11 changes: 10 additions & 1 deletion packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,17 @@ export class ApplicationManager {
// and backend (that define which model supports)
const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.config, localFolder, taskUtil);

// Create the task on the recipe (which will be propagated to the TaskRegistry
taskUtil.setTask({
id: model.id,
state: 'loading',
name: `Downloading model ${model.name}`,
labels: {
'model-pulling': model.id,
},
});
// get model by downloading it or retrieving locally
const modelPath = await this.modelsManager.downloadModel(model, taskUtil);
const modelPath = await this.modelsManager.downloadModel(model);

// build all images, one per container (for a basic sample we should have 2 containers = sample app + model service)
const images = await this.buildImages(
Expand Down
154 changes: 55 additions & 99 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@ import { type MockInstance, beforeEach, describe, expect, test, vi } from 'vites
import os from 'os';
import fs, { type Stats, type PathLike } from 'node:fs';
import path from 'node:path';
import type { DownloadModelResult } from './modelsManager';
import { ModelsManager } from './modelsManager';
import type { TelemetryLogger, Webview } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { RecipeStatusUtils } from '../utils/recipeStatusUtils';
import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry';
import * as utils from '../utils/utils';
import { TaskRegistry } from '../registries/TaskRegistry';

const mocks = vi.hoisted(() => {
return {
showErrorMessageMock: vi.fn(),
logUsageMock: vi.fn(),
logErrorMock: vi.fn(),
performDownloadMock: vi.fn(),
onEventDownloadMock: vi.fn(),
};
});

Expand All @@ -52,10 +52,14 @@ vi.mock('@podman-desktop/api', () => {
};
});

let setTaskMock: MockInstance;
let taskUtils: RecipeStatusUtils;
let setTaskStateMock: MockInstance;
let setTaskErrorMock: MockInstance;
vi.mock('../utils/downloader', () => ({
Downloader: class {
onEvent = mocks.onEventDownloadMock;
perform = mocks.performDownloadMock;
},
}));

let taskRegistry: TaskRegistry;

const telemetryLogger = {
logUsage: mocks.logUsageMock,
Expand All @@ -64,12 +68,7 @@ const telemetryLogger = {

beforeEach(() => {
vi.resetAllMocks();
taskUtils = new RecipeStatusUtils('recipe', {
setStatus: vi.fn(),
} as unknown as RecipeStatusRegistry);
setTaskMock = vi.spyOn(taskUtils, 'setTask');
setTaskStateMock = vi.spyOn(taskUtils, 'setTaskState');
setTaskErrorMock = vi.spyOn(taskUtils, 'setTaskError');
taskRegistry = new TaskRegistry({ postMessage: vi.fn().mockResolvedValue(undefined) } as unknown as Webview);
});

const dirent = [
Expand Down Expand Up @@ -143,6 +142,7 @@ test('getModelsInfo should get models in local directory', async () => {
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
expect(manager.getModelsInfo()).toEqual([
Expand Down Expand Up @@ -188,6 +188,7 @@ test('getModelsInfo should return an empty array if the models folder does not e
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
manager.getLocalModelsFromDisk();
expect(manager.getModelsInfo()).toEqual([]);
Expand Down Expand Up @@ -224,6 +225,7 @@ test('getLocalModelsFromDisk should return undefined Date and size when stat fai
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
expect(manager.getModelsInfo()).toEqual([
Expand Down Expand Up @@ -266,6 +268,7 @@ test('loadLocalModels should post a message with the message on disk and on cata
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
expect(postMessageMock).toHaveBeenNthCalledWith(1, {
Expand Down Expand Up @@ -311,6 +314,7 @@ test('deleteLocalModel deletes the model folder', async () => {
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
await manager.deleteLocalModel('model-id-1');
Expand Down Expand Up @@ -360,6 +364,7 @@ test('deleteLocalModel fails to delete the model folder', async () => {
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
await manager.deleteLocalModel('model-id-1');
Expand Down Expand Up @@ -390,56 +395,58 @@ test('deleteLocalModel fails to delete the model folder', async () => {
});

describe('downloadModel', () => {
const manager = new ModelsManager(
'appdir',
{} as Webview,
{
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
);
test('download model if not already on disk', async () => {
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
const doDownloadModelWrapperMock = vi
.spyOn(manager, 'doDownloadModelWrapper')
.mockImplementation((_modelId: string, _url: string, _taskUtil: RecipeStatusUtils, _destFileName?: string) => {
return Promise.resolve('');
});
vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99);
await manager.downloadModel(
const manager = new ModelsManager(
'appdir',
{} as Webview,
{
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo,
taskUtils,
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
expect(doDownloadModelWrapperMock).toBeCalledWith('id', 'url', taskUtils);
expect(setTaskMock).toHaveBeenLastCalledWith({

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99);
const setMock = vi.spyOn(taskRegistry, 'set');
await manager.downloadModel({
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo);
expect(setMock).toHaveBeenLastCalledWith({
id: 'id',
name: 'Downloading model name',
labels: {
'model-pulling': 'id',
},
state: 'loading',
});
expect(mocks.logUsageMock).toHaveBeenNthCalledWith(1, 'model.download', { 'model.id': 'id', durationSeconds: 99 });
});
test('retrieve model path if already on disk', async () => {
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
const getLocalModelPathMock = vi.spyOn(manager, 'getLocalModelPath').mockReturnValue('');
await manager.downloadModel(
const manager = new ModelsManager(
'appdir',
{} as Webview,
{
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo,
taskUtils,
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
const setMock = vi.spyOn(taskRegistry, 'set');
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
const getLocalModelPathMock = vi.spyOn(manager, 'getLocalModelPath').mockReturnValue('');
await manager.downloadModel({
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo);
expect(getLocalModelPathMock).toBeCalledWith('id');
expect(setTaskMock).toHaveBeenLastCalledWith({
expect(setMock).toHaveBeenLastCalledWith({
id: 'id',
name: 'Model name already present on disk',
labels: {
Expand All @@ -449,54 +456,3 @@ describe('downloadModel', () => {
});
});
});

describe('doDownloadModelWrapper', () => {
const manager = new ModelsManager(
'appdir',
{} as Webview,
{
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
);
test('returning model path if model has been downloaded', async () => {
vi.spyOn(manager, 'doDownloadModel').mockImplementation(
(
_modelId: string,
_url: string,
_taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
_destFileName?: string,
) => {
callback({
successful: true,
path: 'path',
});
},
);
setTaskStateMock.mockReturnThis();
const result = await manager.doDownloadModelWrapper('id', 'url', taskUtils);
expect(result).toBe('path');
});
test('rejecting with error message if model has NOT been downloaded', async () => {
vi.spyOn(manager, 'doDownloadModel').mockImplementation(
(
_modelId: string,
_url: string,
_taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
_destFileName?: string,
) => {
callback({
successful: false,
error: 'error',
});
},
);
setTaskStateMock.mockReturnThis();
setTaskErrorMock.mockReturnThis();
await expect(manager.doDownloadModelWrapper('id', 'url', taskUtils)).rejects.toThrowError('error');
});
});
Loading

0 comments on commit 60aaa3e

Please sign in to comment.