Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: remove the dependency of download task to the RecipeStatusUtils #340

Merged
merged 4 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import type {
podStopHandle,
startupHandle,
} from './podmanConnection';
import { TaskRegistry } from '../registries/TaskRegistry';

const mocks = vi.hoisted(() => {
return {
Expand Down Expand Up @@ -73,11 +74,21 @@ const mocks = vi.hoisted(() => {
listPodsMock: vi.fn(),
stopPodMock: vi.fn(),
removePodMock: vi.fn(),
performDownloadMock: vi.fn(),
onEventDownloadMock: vi.fn(),
};
});
vi.mock('../models/AIConfig', () => ({
parseYamlFile: mocks.parseYamlFileMock,
}));

vi.mock('../utils/downloader', () => ({
Downloader: class {
onEvent = mocks.onEventDownloadMock;
perform = mocks.performDownloadMock;
},
}));

vi.mock('@podman-desktop/api', () => ({
provider: {
getContainerConnections: mocks.getContainerConnectionsMock,
Expand Down Expand Up @@ -129,10 +140,6 @@ describe('pullApplication', () => {
const cloneRepositoryMock = vi.fn();
let manager: ApplicationManager;
let modelsManager: ModelsManager;
let doDownloadModelWrapperSpy: MockInstance<
[modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string],
Promise<string>
>;
vi.spyOn(utils, 'timeout').mockResolvedValue();
function mockForPullApplication(options: mockForPullApplicationOptions) {
vi.spyOn(os, 'homedir').mockReturnValue('/home/user');
Expand Down Expand Up @@ -210,6 +217,7 @@ describe('pullApplication', () => {
},
} as CatalogManager,
telemetryLogger,
new TaskRegistry({ postMessage: vi.fn().mockResolvedValue(undefined) } as unknown as Webview),
);
manager = new ApplicationManager(
'/home/user/aistudio',
Expand All @@ -225,14 +233,13 @@ describe('pullApplication', () => {
modelsManager,
telemetryLogger,
);
doDownloadModelWrapperSpy = vi.spyOn(modelsManager, 'doDownloadModelWrapper');
}
test('pullApplication should clone repository and call downloadModelMain and buildImage', async () => {
mockForPullApplication({
recipeFolderExists: false,
});
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false);
doDownloadModelWrapperSpy.mockResolvedValue('path');
mocks.performDownloadMock.mockResolvedValue('path');
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand Down Expand Up @@ -270,7 +277,7 @@ describe('pullApplication', () => {
gitCloneOptions.targetDirectory = '/home/user/aistudio/recipe1';
expect(cloneRepositoryMock).toHaveBeenNthCalledWith(1, gitCloneOptions);
}
expect(doDownloadModelWrapperSpy).toHaveBeenCalledOnce();
expect(mocks.performDownloadMock).toHaveBeenCalledOnce();
expect(mocks.buildImageMock).toHaveBeenCalledOnce();
expect(mocks.buildImageMock).toHaveBeenCalledWith(
`${gitCloneOptions.targetDirectory}${path.sep}contextdir1`,
Expand All @@ -283,11 +290,7 @@ describe('pullApplication', () => {
},
},
);
expect(mocks.logUsageMock).toHaveBeenNthCalledWith(1, 'model.download', {
'model.id': 'model1',
durationSeconds: 99,
});
expect(mocks.logUsageMock).toHaveBeenNthCalledWith(2, 'recipe.pull', {
expect(mocks.logUsageMock).toHaveBeenNthCalledWith(1, 'recipe.pull', {
'recipe.id': 'recipe1',
'recipe.name': 'Recipe 1',
durationSeconds: 99,
Expand All @@ -298,7 +301,7 @@ describe('pullApplication', () => {
recipeFolderExists: true,
});
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(modelsManager, 'doDownloadModelWrapper').mockResolvedValue('path');
mocks.performDownloadMock.mockResolvedValue('path');
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand Down Expand Up @@ -348,7 +351,7 @@ describe('pullApplication', () => {
};
await manager.pullApplication(recipe, model);
expect(cloneRepositoryMock).not.toHaveBeenCalled();
expect(doDownloadModelWrapperSpy).not.toHaveBeenCalled();
expect(mocks.performDownloadMock).not.toHaveBeenCalled();
});

test('pullApplication should mark the loading config as error if not container are found', async () => {
Expand Down Expand Up @@ -385,7 +388,7 @@ describe('pullApplication', () => {
await expect(manager.pullApplication(recipe, model)).rejects.toThrowError('No containers available.');

expect(cloneRepositoryMock).not.toHaveBeenCalled();
expect(doDownloadModelWrapperSpy).not.toHaveBeenCalled();
expect(mocks.performDownloadMock).not.toHaveBeenCalled();
});
});
describe('doCheckout', () => {
Expand Down
11 changes: 10 additions & 1 deletion packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,17 @@ export class ApplicationManager {
// and backend (that define which model supports)
const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.config, localFolder, taskUtil);

// Create the task on the recipe (which will be propagated to the TaskRegistry
taskUtil.setTask({
id: model.id,
state: 'loading',
name: `Downloading model ${model.name}`,
labels: {
'model-pulling': model.id,
},
});
// get model by downloading it or retrieving locally
const modelPath = await this.modelsManager.downloadModel(model, taskUtil);
const modelPath = await this.modelsManager.downloadModel(model);

// build all images, one per container (for a basic sample we should have 2 containers = sample app + model service)
const images = await this.buildImages(
Expand Down
154 changes: 55 additions & 99 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@ import { type MockInstance, beforeEach, describe, expect, test, vi } from 'vites
import os from 'os';
import fs, { type Stats, type PathLike } from 'node:fs';
import path from 'node:path';
import type { DownloadModelResult } from './modelsManager';
import { ModelsManager } from './modelsManager';
import type { TelemetryLogger, Webview } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { RecipeStatusUtils } from '../utils/recipeStatusUtils';
import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry';
import * as utils from '../utils/utils';
import { TaskRegistry } from '../registries/TaskRegistry';

const mocks = vi.hoisted(() => {
return {
showErrorMessageMock: vi.fn(),
logUsageMock: vi.fn(),
logErrorMock: vi.fn(),
performDownloadMock: vi.fn(),
onEventDownloadMock: vi.fn(),
};
});

Expand All @@ -52,10 +52,14 @@ vi.mock('@podman-desktop/api', () => {
};
});

let setTaskMock: MockInstance;
let taskUtils: RecipeStatusUtils;
let setTaskStateMock: MockInstance;
let setTaskErrorMock: MockInstance;
vi.mock('../utils/downloader', () => ({
Downloader: class {
onEvent = mocks.onEventDownloadMock;
perform = mocks.performDownloadMock;
},
}));

let taskRegistry: TaskRegistry;

const telemetryLogger = {
logUsage: mocks.logUsageMock,
Expand All @@ -64,12 +68,7 @@ const telemetryLogger = {

beforeEach(() => {
vi.resetAllMocks();
taskUtils = new RecipeStatusUtils('recipe', {
setStatus: vi.fn(),
} as unknown as RecipeStatusRegistry);
setTaskMock = vi.spyOn(taskUtils, 'setTask');
setTaskStateMock = vi.spyOn(taskUtils, 'setTaskState');
setTaskErrorMock = vi.spyOn(taskUtils, 'setTaskError');
taskRegistry = new TaskRegistry({ postMessage: vi.fn().mockResolvedValue(undefined) } as unknown as Webview);
});

const dirent = [
Expand Down Expand Up @@ -143,6 +142,7 @@ test('getModelsInfo should get models in local directory', async () => {
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
expect(manager.getModelsInfo()).toEqual([
Expand Down Expand Up @@ -188,6 +188,7 @@ test('getModelsInfo should return an empty array if the models folder does not e
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
manager.getLocalModelsFromDisk();
expect(manager.getModelsInfo()).toEqual([]);
Expand Down Expand Up @@ -224,6 +225,7 @@ test('getLocalModelsFromDisk should return undefined Date and size when stat fai
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
expect(manager.getModelsInfo()).toEqual([
Expand Down Expand Up @@ -266,6 +268,7 @@ test('loadLocalModels should post a message with the message on disk and on cata
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
expect(postMessageMock).toHaveBeenNthCalledWith(1, {
Expand Down Expand Up @@ -311,6 +314,7 @@ test('deleteLocalModel deletes the model folder', async () => {
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
await manager.deleteLocalModel('model-id-1');
Expand Down Expand Up @@ -360,6 +364,7 @@ test('deleteLocalModel fails to delete the model folder', async () => {
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
await manager.loadLocalModels();
await manager.deleteLocalModel('model-id-1');
Expand Down Expand Up @@ -390,56 +395,58 @@ test('deleteLocalModel fails to delete the model folder', async () => {
});

describe('downloadModel', () => {
const manager = new ModelsManager(
'appdir',
{} as Webview,
{
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
);
test('download model if not already on disk', async () => {
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
const doDownloadModelWrapperMock = vi
.spyOn(manager, 'doDownloadModelWrapper')
.mockImplementation((_modelId: string, _url: string, _taskUtil: RecipeStatusUtils, _destFileName?: string) => {
return Promise.resolve('');
});
vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99);
await manager.downloadModel(
const manager = new ModelsManager(
'appdir',
{} as Webview,
{
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo,
taskUtils,
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
expect(doDownloadModelWrapperMock).toBeCalledWith('id', 'url', taskUtils);
expect(setTaskMock).toHaveBeenLastCalledWith({

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99);
const setMock = vi.spyOn(taskRegistry, 'set');
await manager.downloadModel({
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo);
expect(setMock).toHaveBeenLastCalledWith({
id: 'id',
name: 'Downloading model name',
labels: {
'model-pulling': 'id',
},
state: 'loading',
});
expect(mocks.logUsageMock).toHaveBeenNthCalledWith(1, 'model.download', { 'model.id': 'id', durationSeconds: 99 });
});
test('retrieve model path if already on disk', async () => {
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
const getLocalModelPathMock = vi.spyOn(manager, 'getLocalModelPath').mockReturnValue('');
await manager.downloadModel(
const manager = new ModelsManager(
'appdir',
{} as Webview,
{
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo,
taskUtils,
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);
const setMock = vi.spyOn(taskRegistry, 'set');
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
const getLocalModelPathMock = vi.spyOn(manager, 'getLocalModelPath').mockReturnValue('');
await manager.downloadModel({
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo);
expect(getLocalModelPathMock).toBeCalledWith('id');
expect(setTaskMock).toHaveBeenLastCalledWith({
expect(setMock).toHaveBeenLastCalledWith({
id: 'id',
name: 'Model name already present on disk',
labels: {
Expand All @@ -449,54 +456,3 @@ describe('downloadModel', () => {
});
});
});

describe('doDownloadModelWrapper', () => {
const manager = new ModelsManager(
'appdir',
{} as Webview,
{
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
);
test('returning model path if model has been downloaded', async () => {
vi.spyOn(manager, 'doDownloadModel').mockImplementation(
(
_modelId: string,
_url: string,
_taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
_destFileName?: string,
) => {
callback({
successful: true,
path: 'path',
});
},
);
setTaskStateMock.mockReturnThis();
const result = await manager.doDownloadModelWrapper('id', 'url', taskUtils);
expect(result).toBe('path');
});
test('rejecting with error message if model has NOT been downloaded', async () => {
vi.spyOn(manager, 'doDownloadModel').mockImplementation(
(
_modelId: string,
_url: string,
_taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
_destFileName?: string,
) => {
callback({
successful: false,
error: 'error',
});
},
);
setTaskStateMock.mockReturnThis();
setTaskErrorMock.mockReturnThis();
await expect(manager.doDownloadModelWrapper('id', 'url', taskUtils)).rejects.toThrowError('error');
});
});
Loading