Skip to content

Commit

Permalink
feat: reuse existing model downloading tasks (containers#388)
Browse files Browse the repository at this point in the history
* feat: reuse existing model

Signed-off-by: axel7083 <[email protected]>

* feat: properly use the tasks registry

Signed-off-by: axel7083 <[email protected]>

* fix: prettier&linter

Signed-off-by: axel7083 <[email protected]>

* fix: unit tests

Signed-off-by: axel7083 <[email protected]>

* tests: ensuring multiple download call do not result in multiple downloader created

Signed-off-by: axel7083 <[email protected]>

* fix: prettier&linter

Signed-off-by: axel7083 <[email protected]>

* feat(models): improve model download management

Signed-off-by: axel7083 <[email protected]>

* fix: remove console.log

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 authored Mar 6, 2024
1 parent bc03f30 commit 83c94d0
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 44 deletions.
2 changes: 2 additions & 0 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ const mocks = vi.hoisted(() => {
stopPodMock: vi.fn(),
removePodMock: vi.fn(),
performDownloadMock: vi.fn(),
getTargetMock: vi.fn(),
onEventDownloadMock: vi.fn(),
// TaskRegistry
getTaskMock: vi.fn(),
Expand All @@ -94,6 +95,7 @@ vi.mock('../utils/downloader', () => ({
Downloader: class {
onEvent = mocks.onEventDownloadMock;
perform = mocks.performDownloadMock;
getTarget = mocks.getTargetMock;
},
}));

Expand Down
2 changes: 1 addition & 1 deletion packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export class ApplicationManager {
const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.config, localFolder);

// get model by downloading it or retrieving locally
const modelPath = await this.modelsManager.downloadModel(model, {
const modelPath = await this.modelsManager.requestDownloadModel(model, {
'recipe-id': recipe.id,
'model-id': model.id,
});
Expand Down
89 changes: 87 additions & 2 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const mocks = vi.hoisted(() => {
logErrorMock: vi.fn(),
performDownloadMock: vi.fn(),
onEventDownloadMock: vi.fn(),
getTargetMock: vi.fn(),
getDownloaderCompleter: vi.fn(),
isCompletionEventMock: vi.fn(),
};
});

Expand All @@ -53,9 +56,14 @@ vi.mock('@podman-desktop/api', () => {
});

vi.mock('../utils/downloader', () => ({
isCompletionEvent: mocks.isCompletionEventMock,
Downloader: class {
get completed() {
return mocks.getDownloaderCompleter();
}
onEvent = mocks.onEventDownloadMock;
perform = mocks.performDownloadMock;
getTarget = mocks.getTargetMock;
},
}));

Expand All @@ -69,6 +77,8 @@ const telemetryLogger = {
beforeEach(() => {
vi.resetAllMocks();
taskRegistry = new TaskRegistry({ postMessage: vi.fn().mockResolvedValue(undefined) } as unknown as Webview);

mocks.isCompletionEventMock.mockReturnValue(true);
});

const dirent = [
Expand Down Expand Up @@ -411,7 +421,7 @@ describe('downloadModel', () => {
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99);
const updateTaskMock = vi.spyOn(taskRegistry, 'updateTask');
await manager.downloadModel({
await manager.requestDownloadModel({
id: 'id',
url: 'url',
name: 'name',
Expand Down Expand Up @@ -440,7 +450,7 @@ describe('downloadModel', () => {
const updateTaskMock = vi.spyOn(taskRegistry, 'updateTask');
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
const getLocalModelPathMock = vi.spyOn(manager, 'getLocalModelPath').mockReturnValue('');
await manager.downloadModel({
await manager.requestDownloadModel({
id: 'id',
url: 'url',
name: 'name',
Expand All @@ -455,4 +465,79 @@ describe('downloadModel', () => {
state: 'success',
});
});
test('multiple download request same model - second call after first completed', async () => {
mocks.getDownloaderCompleter.mockReturnValue(true);

const manager = new ModelsManager(
'appdir',
{} as Webview,
{
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99);

await manager.requestDownloadModel({
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo);

await manager.requestDownloadModel({
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo);

// Only called once
expect(mocks.performDownloadMock).toHaveBeenCalledTimes(1);
expect(mocks.onEventDownloadMock).toHaveBeenCalledTimes(1);
});

test('multiple download request same model - second call before first completed', async () => {
mocks.getDownloaderCompleter.mockReturnValue(false);

const manager = new ModelsManager(
'appdir',
{} as Webview,
{
getModels(): ModelInfo[] {
return [];
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
);

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99);

mocks.onEventDownloadMock.mockImplementation(listener => {
listener({
status: 'completed',
});
});

await manager.requestDownloadModel({
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo);

await manager.requestDownloadModel({
id: 'id',
url: 'url',
name: 'name',
} as ModelInfo);

// Only called once
expect(mocks.performDownloadMock).toHaveBeenCalledTimes(1);
expect(mocks.onEventDownloadMock).toHaveBeenCalledTimes(2);
});
});
121 changes: 93 additions & 28 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ export class ModelsManager implements Disposable {
#models: Map<string, ModelInfo>;
#watcher?: podmanDesktopApi.FileSystemWatcher;

#downloaders: Map<string, Downloader> = new Map<string, Downloader>();

constructor(
private appUserDirectory: string,
private webview: Webview,
Expand Down Expand Up @@ -171,37 +173,58 @@ export class ModelsManager implements Disposable {
}
}

async downloadModel(model: ModelInfo, labels?: { [key: string]: string }): Promise<string> {
const task: Task = this.taskRegistry.createTask(`Downloading model ${model.name}`, 'loading', {
...labels,
'model-pulling': model.id,
});
/**
* This method will resolve when the provided model will be downloaded.
*
* This can method can be call multiple time for the same model, it will reuse existing downloader and wait on
* their completion.
* @param model
* @param labels
*/
async requestDownloadModel(model: ModelInfo, labels?: { [key: string]: string }): Promise<string> {
// Create a task to follow progress
const task: Task = this.createDownloadTask(model, labels);

// Check if the model is already on disk.
if (this.isModelOnDisk(model.id)) {
// Check there is no existing downloader running
if (!this.#downloaders.has(model.id)) {
return this.downloadModel(model, task);
}

const existingDownloader = this.#downloaders.get(model.id);
if (existingDownloader.completed) {
task.state = 'success';
task.name = `Model ${model.name} already present on disk`;
this.taskRegistry.updateTask(task); // update task
this.taskRegistry.updateTask(task);

// return model path
return this.getLocalModelPath(model.id);
return existingDownloader.getTarget();
}

// update task to loading state
this.taskRegistry.updateTask(task);
// If we have an existing downloader running we subscribe on its events
return new Promise((resolve, reject) => {
const disposable = existingDownloader.onEvent(event => {
if (!isCompletionEvent(event)) return;

// Ensure path to model directory exist
const destDir = path.join(this.appUserDirectory, 'models', model.id);
if (!fs.existsSync(destDir)) {
fs.mkdirSync(destDir, { recursive: true });
}
switch (event.status) {
case 'completed':
resolve(existingDownloader.getTarget());
break;
default:
reject(new Error(event.message));
}
disposable.dispose();
});
});
}

const target = path.resolve(destDir, path.basename(model.url));
// Create a downloader
const downloader = new Downloader(model.url, target);
private onDownloadEvent(event: DownloadEvent): void {
// Always use the task registry as source of truth for tasks
const tasks = this.taskRegistry.getTasksByLabels({ 'model-pulling': event.id });
if (tasks.length === 0) {
// tasks might have been cleared but still an error.
console.error('received download event but no task is associated.');
return;
}

// Capture downloader events
downloader.onEvent((event: DownloadEvent) => {
tasks.forEach(task => {
if (isProgressEvent(event)) {
task.state = 'loading';
task.progress = event.value;
Expand All @@ -214,7 +237,7 @@ export class ModelsManager implements Disposable {

// telemetry usage
this.telemetry.logError('model.download', {
'model.id': model.id,
'model.id': event.id,
message: 'error downloading model',
error: event.message,
durationSeconds: event.duration,
Expand All @@ -224,15 +247,57 @@ export class ModelsManager implements Disposable {
task.progress = 100;

// telemetry usage
this.telemetry.logUsage('model.download', { 'model.id': model.id, durationSeconds: event.duration });
this.telemetry.logUsage('model.download', { 'model.id': event.id, durationSeconds: event.duration });
}
}

this.taskRegistry.updateTask(task); // update task
});
}

private createDownloader(model: ModelInfo): Downloader {
// Ensure path to model directory exist
const destDir = path.join(this.appUserDirectory, 'models', model.id);
if (!fs.existsSync(destDir)) {
fs.mkdirSync(destDir, { recursive: true });
}

const target = path.resolve(destDir, path.basename(model.url));
// Create a downloader
const downloader = new Downloader(model.url, target);

this.#downloaders.set(model.id, downloader);

return downloader;
}

private createDownloadTask(model: ModelInfo, labels?: { [key: string]: string }): Task {
return this.taskRegistry.createTask(`Downloading model ${model.name}`, 'loading', {
...labels,
'model-pulling': model.id,
});
}

private async downloadModel(model: ModelInfo, task: Task): Promise<string> {
// Check if the model is already on disk.
if (this.isModelOnDisk(model.id)) {
task.state = 'success';
task.name = `Model ${model.name} already present on disk`;
this.taskRegistry.updateTask(task); // update task

// return model path
return this.getLocalModelPath(model.id);
}

// update task to loading state
this.taskRegistry.updateTask(task);

const downloader = this.createDownloader(model);

// Capture downloader events
downloader.onEvent(this.onDownloadEvent.bind(this));

// perform download
await downloader.perform();
return target;
await downloader.perform(model.id);
return downloader.getTarget();
}
}
26 changes: 18 additions & 8 deletions packages/backend/src/registries/TaskRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,26 @@ export class TaskRegistry {
* @returns An array of tasks that match the specified labels.
*/
getTasksByLabels(requestedLabels: { [key: string]: string }): Task[] {
return this.getTasks().filter(task => {
const labels = task.labels;
if (labels === undefined) return false;
return this.getTasks().filter(task => this.filter(task, requestedLabels));
}

for (const [key, value] of Object.entries(requestedLabels)) {
if (!(key in labels) || labels[key] !== value) return false;
}
/**
* Return the first task matching all the labels provided
* @param requestedLabels
*/
findTaskByLabels(requestedLabels: { [key: string]: string }): Task | undefined {
return this.getTasks().find(task => this.filter(task, requestedLabels));
}

return true;
});
private filter(task: Task, requestedLabels: { [key: string]: string }): boolean {
const labels = task.labels;
if (labels === undefined) return false;

for (const [key, value] of Object.entries(requestedLabels)) {
if (!(key in labels) || labels[key] !== value) return false;
}

return true;
}

/**
Expand Down
7 changes: 4 additions & 3 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ export class StudioApiImpl implements StudioAPI {
.withProgress({ location: podmanDesktopApi.ProgressLocation.TASK_WIDGET, title: `Pulling ${recipe.name}.` }, () =>
this.applicationManager.pullApplication(recipe, model),
)
.catch(() => {
.catch((err: unknown) => {
console.error('Something went wrong while trying to start application', err);
podmanDesktopApi.window
.showErrorMessage(`Error starting the application "${recipe.name}"`)
.showErrorMessage(`Error starting the application "${recipe.name}": ${String(err)}`)
.catch((err: unknown) => {
console.error(`Something went wrong with confirmation modals`, err);
});
Expand Down Expand Up @@ -250,7 +251,7 @@ export class StudioApiImpl implements StudioAPI {
const modelInfo: ModelInfo = this.modelsManager.getModelInfo(modelId);

// Do not wait for the download task as it is too long.
this.modelsManager.downloadModel(modelInfo).catch((err: unknown) => {
this.modelsManager.requestDownloadModel(modelInfo).catch((err: unknown) => {
console.error(`Something went wrong while trying to download the model ${modelId}`, err);
});
}
Expand Down
Loading

0 comments on commit 83c94d0

Please sign in to comment.