From 83c94d0b2416623fdff0da10a787b29d2a997655 Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Wed, 6 Mar 2024 16:23:37 +0100 Subject: [PATCH] feat: reuse existing model downloading tasks (#388) * feat: reuse existing model Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * feat: properly use the tasks registry Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * fix: prettier&linter Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * fix: unit tests Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * tests: ensuring multiple download call do not result in multiple downloader created Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * fix: prettier&linter Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * feat(models): improve model download management Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * fix: remove console.log Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> --------- Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> --- .../src/managers/applicationManager.spec.ts | 2 + .../src/managers/applicationManager.ts | 2 +- .../src/managers/modelsManager.spec.ts | 89 ++++++++++++- .../backend/src/managers/modelsManager.ts | 121 ++++++++++++++---- .../backend/src/registries/TaskRegistry.ts | 26 ++-- packages/backend/src/studio-api-impl.ts | 7 +- packages/backend/src/utils/downloader.ts | 17 ++- packages/frontend/src/pages/Models.svelte | 16 ++- 8 files changed, 236 insertions(+), 44 deletions(-) diff --git a/packages/backend/src/managers/applicationManager.spec.ts b/packages/backend/src/managers/applicationManager.spec.ts index 14310ea63..8cb61bb7c 100644 --- a/packages/backend/src/managers/applicationManager.spec.ts +++ b/packages/backend/src/managers/applicationManager.spec.ts @@ -74,6 +74,7 @@ const mocks = vi.hoisted(() => { stopPodMock: vi.fn(), removePodMock: vi.fn(), performDownloadMock: vi.fn(), + getTargetMock: vi.fn(), onEventDownloadMock: vi.fn(), // TaskRegistry getTaskMock: vi.fn(), @@ -94,6 +95,7 @@ vi.mock('../utils/downloader', () => ({ Downloader: class { onEvent = mocks.onEventDownloadMock; perform = mocks.performDownloadMock; + getTarget = mocks.getTargetMock; }, })); diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index d4379b675..c63b0e468 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -131,7 +131,7 @@ export class ApplicationManager { const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.config, localFolder); // get model by downloading it or retrieving locally - const modelPath = await this.modelsManager.downloadModel(model, { + const modelPath = await this.modelsManager.requestDownloadModel(model, { 'recipe-id': recipe.id, 'model-id': model.id, }); diff --git a/packages/backend/src/managers/modelsManager.spec.ts b/packages/backend/src/managers/modelsManager.spec.ts index 4c2dd1dd3..b5c2679b7 100644 --- a/packages/backend/src/managers/modelsManager.spec.ts +++ b/packages/backend/src/managers/modelsManager.spec.ts @@ -34,6 +34,9 @@ const mocks = vi.hoisted(() => { logErrorMock: vi.fn(), performDownloadMock: vi.fn(), onEventDownloadMock: vi.fn(), + getTargetMock: vi.fn(), + getDownloaderCompleter: vi.fn(), + isCompletionEventMock: vi.fn(), }; }); @@ -53,9 +56,14 @@ vi.mock('@podman-desktop/api', () => { }); vi.mock('../utils/downloader', () => ({ + isCompletionEvent: mocks.isCompletionEventMock, Downloader: class { + get completed() { + return mocks.getDownloaderCompleter(); + } onEvent = mocks.onEventDownloadMock; perform = mocks.performDownloadMock; + getTarget = mocks.getTargetMock; }, })); @@ -69,6 +77,8 @@ const telemetryLogger = { beforeEach(() => { vi.resetAllMocks(); taskRegistry = new TaskRegistry({ postMessage: vi.fn().mockResolvedValue(undefined) } as unknown as Webview); + + mocks.isCompletionEventMock.mockReturnValue(true); }); const dirent = [ @@ -411,7 +421,7 @@ describe('downloadModel', () => { vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99); const updateTaskMock = vi.spyOn(taskRegistry, 'updateTask'); - await manager.downloadModel({ + await manager.requestDownloadModel({ id: 'id', url: 'url', name: 'name', @@ -440,7 +450,7 @@ describe('downloadModel', () => { const updateTaskMock = vi.spyOn(taskRegistry, 'updateTask'); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true); const getLocalModelPathMock = vi.spyOn(manager, 'getLocalModelPath').mockReturnValue(''); - await manager.downloadModel({ + await manager.requestDownloadModel({ id: 'id', url: 'url', name: 'name', @@ -455,4 +465,79 @@ describe('downloadModel', () => { state: 'success', }); }); + test('multiple download request same model - second call after first completed', async () => { + mocks.getDownloaderCompleter.mockReturnValue(true); + + const manager = new ModelsManager( + 'appdir', + {} as Webview, + { + getModels(): ModelInfo[] { + return []; + }, + } as CatalogManager, + telemetryLogger, + taskRegistry, + ); + + vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); + vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99); + + await manager.requestDownloadModel({ + id: 'id', + url: 'url', + name: 'name', + } as ModelInfo); + + await manager.requestDownloadModel({ + id: 'id', + url: 'url', + name: 'name', + } as ModelInfo); + + // Only called once + expect(mocks.performDownloadMock).toHaveBeenCalledTimes(1); + expect(mocks.onEventDownloadMock).toHaveBeenCalledTimes(1); + }); + + test('multiple download request same model - second call before first completed', async () => { + mocks.getDownloaderCompleter.mockReturnValue(false); + + const manager = new ModelsManager( + 'appdir', + {} as Webview, + { + getModels(): ModelInfo[] { + return []; + }, + } as CatalogManager, + telemetryLogger, + taskRegistry, + ); + + vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); + vi.spyOn(utils, 'getDurationSecondsSince').mockReturnValue(99); + + mocks.onEventDownloadMock.mockImplementation(listener => { + listener({ + status: 'completed', + }); + }); + + await manager.requestDownloadModel({ + id: 'id', + url: 'url', + name: 'name', + } as ModelInfo); + + await manager.requestDownloadModel({ + id: 'id', + url: 'url', + name: 'name', + } as ModelInfo); + + // Only called once + expect(mocks.performDownloadMock).toHaveBeenCalledTimes(1); + expect(mocks.onEventDownloadMock).toHaveBeenCalledTimes(2); + }); }); diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts index 1b9fe99e5..d67f6bbf9 100644 --- a/packages/backend/src/managers/modelsManager.ts +++ b/packages/backend/src/managers/modelsManager.ts @@ -33,6 +33,8 @@ export class ModelsManager implements Disposable { #models: Map; #watcher?: podmanDesktopApi.FileSystemWatcher; + #downloaders: Map = new Map(); + constructor( private appUserDirectory: string, private webview: Webview, @@ -171,37 +173,58 @@ export class ModelsManager implements Disposable { } } - async downloadModel(model: ModelInfo, labels?: { [key: string]: string }): Promise { - const task: Task = this.taskRegistry.createTask(`Downloading model ${model.name}`, 'loading', { - ...labels, - 'model-pulling': model.id, - }); + /** + * This method will resolve when the provided model will be downloaded. + * + * This can method can be call multiple time for the same model, it will reuse existing downloader and wait on + * their completion. + * @param model + * @param labels + */ + async requestDownloadModel(model: ModelInfo, labels?: { [key: string]: string }): Promise { + // Create a task to follow progress + const task: Task = this.createDownloadTask(model, labels); - // Check if the model is already on disk. - if (this.isModelOnDisk(model.id)) { + // Check there is no existing downloader running + if (!this.#downloaders.has(model.id)) { + return this.downloadModel(model, task); + } + + const existingDownloader = this.#downloaders.get(model.id); + if (existingDownloader.completed) { task.state = 'success'; - task.name = `Model ${model.name} already present on disk`; - this.taskRegistry.updateTask(task); // update task + this.taskRegistry.updateTask(task); - // return model path - return this.getLocalModelPath(model.id); + return existingDownloader.getTarget(); } - // update task to loading state - this.taskRegistry.updateTask(task); + // If we have an existing downloader running we subscribe on its events + return new Promise((resolve, reject) => { + const disposable = existingDownloader.onEvent(event => { + if (!isCompletionEvent(event)) return; - // Ensure path to model directory exist - const destDir = path.join(this.appUserDirectory, 'models', model.id); - if (!fs.existsSync(destDir)) { - fs.mkdirSync(destDir, { recursive: true }); - } + switch (event.status) { + case 'completed': + resolve(existingDownloader.getTarget()); + break; + default: + reject(new Error(event.message)); + } + disposable.dispose(); + }); + }); + } - const target = path.resolve(destDir, path.basename(model.url)); - // Create a downloader - const downloader = new Downloader(model.url, target); + private onDownloadEvent(event: DownloadEvent): void { + // Always use the task registry as source of truth for tasks + const tasks = this.taskRegistry.getTasksByLabels({ 'model-pulling': event.id }); + if (tasks.length === 0) { + // tasks might have been cleared but still an error. + console.error('received download event but no task is associated.'); + return; + } - // Capture downloader events - downloader.onEvent((event: DownloadEvent) => { + tasks.forEach(task => { if (isProgressEvent(event)) { task.state = 'loading'; task.progress = event.value; @@ -214,7 +237,7 @@ export class ModelsManager implements Disposable { // telemetry usage this.telemetry.logError('model.download', { - 'model.id': model.id, + 'model.id': event.id, message: 'error downloading model', error: event.message, durationSeconds: event.duration, @@ -224,15 +247,57 @@ export class ModelsManager implements Disposable { task.progress = 100; // telemetry usage - this.telemetry.logUsage('model.download', { 'model.id': model.id, durationSeconds: event.duration }); + this.telemetry.logUsage('model.download', { 'model.id': event.id, durationSeconds: event.duration }); } } - this.taskRegistry.updateTask(task); // update task }); + } + + private createDownloader(model: ModelInfo): Downloader { + // Ensure path to model directory exist + const destDir = path.join(this.appUserDirectory, 'models', model.id); + if (!fs.existsSync(destDir)) { + fs.mkdirSync(destDir, { recursive: true }); + } + + const target = path.resolve(destDir, path.basename(model.url)); + // Create a downloader + const downloader = new Downloader(model.url, target); + + this.#downloaders.set(model.id, downloader); + + return downloader; + } + + private createDownloadTask(model: ModelInfo, labels?: { [key: string]: string }): Task { + return this.taskRegistry.createTask(`Downloading model ${model.name}`, 'loading', { + ...labels, + 'model-pulling': model.id, + }); + } + + private async downloadModel(model: ModelInfo, task: Task): Promise { + // Check if the model is already on disk. + if (this.isModelOnDisk(model.id)) { + task.state = 'success'; + task.name = `Model ${model.name} already present on disk`; + this.taskRegistry.updateTask(task); // update task + + // return model path + return this.getLocalModelPath(model.id); + } + + // update task to loading state + this.taskRegistry.updateTask(task); + + const downloader = this.createDownloader(model); + + // Capture downloader events + downloader.onEvent(this.onDownloadEvent.bind(this)); // perform download - await downloader.perform(); - return target; + await downloader.perform(model.id); + return downloader.getTarget(); } } diff --git a/packages/backend/src/registries/TaskRegistry.ts b/packages/backend/src/registries/TaskRegistry.ts index 0d974ee9c..3bc88727a 100644 --- a/packages/backend/src/registries/TaskRegistry.ts +++ b/packages/backend/src/registries/TaskRegistry.ts @@ -107,16 +107,26 @@ export class TaskRegistry { * @returns An array of tasks that match the specified labels. */ getTasksByLabels(requestedLabels: { [key: string]: string }): Task[] { - return this.getTasks().filter(task => { - const labels = task.labels; - if (labels === undefined) return false; + return this.getTasks().filter(task => this.filter(task, requestedLabels)); + } - for (const [key, value] of Object.entries(requestedLabels)) { - if (!(key in labels) || labels[key] !== value) return false; - } + /** + * Return the first task matching all the labels provided + * @param requestedLabels + */ + findTaskByLabels(requestedLabels: { [key: string]: string }): Task | undefined { + return this.getTasks().find(task => this.filter(task, requestedLabels)); + } - return true; - }); + private filter(task: Task, requestedLabels: { [key: string]: string }): boolean { + const labels = task.labels; + if (labels === undefined) return false; + + for (const [key, value] of Object.entries(requestedLabels)) { + if (!(key in labels) || labels[key] !== value) return false; + } + + return true; } /** diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index c2da531ea..71164dfa1 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -68,9 +68,10 @@ export class StudioApiImpl implements StudioAPI { .withProgress({ location: podmanDesktopApi.ProgressLocation.TASK_WIDGET, title: `Pulling ${recipe.name}.` }, () => this.applicationManager.pullApplication(recipe, model), ) - .catch(() => { + .catch((err: unknown) => { + console.error('Something went wrong while trying to start application', err); podmanDesktopApi.window - .showErrorMessage(`Error starting the application "${recipe.name}"`) + .showErrorMessage(`Error starting the application "${recipe.name}": ${String(err)}`) .catch((err: unknown) => { console.error(`Something went wrong with confirmation modals`, err); }); @@ -250,7 +251,7 @@ export class StudioApiImpl implements StudioAPI { const modelInfo: ModelInfo = this.modelsManager.getModelInfo(modelId); // Do not wait for the download task as it is too long. - this.modelsManager.downloadModel(modelInfo).catch((err: unknown) => { + this.modelsManager.requestDownloadModel(modelInfo).catch((err: unknown) => { console.error(`Something went wrong while trying to download the model ${modelId}`, err); }); } diff --git a/packages/backend/src/utils/downloader.ts b/packages/backend/src/utils/downloader.ts index d57495e13..7489df4c3 100644 --- a/packages/backend/src/utils/downloader.ts +++ b/packages/backend/src/utils/downloader.ts @@ -22,6 +22,7 @@ import https from 'node:https'; import { EventEmitter, type Event } from '@podman-desktop/api'; export interface DownloadEvent { + id: string; status: 'error' | 'completed' | 'progress' | 'canceled'; message?: string; } @@ -56,6 +57,9 @@ export const isProgressEvent = (value: unknown): value is ProgressEvent => { export class Downloader { private readonly _onEvent = new EventEmitter(); readonly onEvent: Event = this._onEvent.event; + private requestedIdentifier: string; + + completed: boolean; constructor( private url: string, @@ -63,13 +67,19 @@ export class Downloader { private abortSignal?: AbortSignal, ) {} - async perform() { + getTarget(): string { + return this.target; + } + + async perform(id: string) { + this.requestedIdentifier = id; const startTime = performance.now(); try { await this.download(this.url); const durationSeconds = getDurationSecondsSince(startTime); this._onEvent.fire({ + id: this.requestedIdentifier, status: 'completed', message: `Duration ${durationSeconds}s.`, duration: durationSeconds, @@ -77,15 +87,19 @@ export class Downloader { } catch (err: unknown) { if (!this.abortSignal?.aborted) { this._onEvent.fire({ + id: this.requestedIdentifier, status: 'error', message: `Something went wrong: ${String(err)}.`, }); } else { this._onEvent.fire({ + id: this.requestedIdentifier, status: 'canceled', message: `Request cancelled: ${String(err)}.`, }); } + } finally { + this.completed = true; } } @@ -124,6 +138,7 @@ export class Downloader { if (progressValue === 100 || progressValue - previousProgressValue > 1) { previousProgressValue = progressValue; this._onEvent.fire({ + id: this.requestedIdentifier, status: 'progress', value: progressValue, } as ProgressEvent); diff --git a/packages/frontend/src/pages/Models.svelte b/packages/frontend/src/pages/Models.svelte index 6e0b18a16..2e54b2f3b 100644 --- a/packages/frontend/src/pages/Models.svelte +++ b/packages/frontend/src/pages/Models.svelte @@ -55,7 +55,21 @@ function filterModels(): void { onMount(() => { // Subscribe to the tasks store const tasksUnsubscribe = tasks.subscribe(value => { - pullingTasks = value.filter(task => task.state === 'loading' && task.labels && 'model-pulling' in task.labels); + // Filter out duplicates + const modelIds = new Set(); + pullingTasks = value.reduce((filtered: Task[], task: Task) => { + if ( + task.state === 'loading' && + task.labels !== undefined && + 'model-pulling' in task.labels && + !modelIds.has(task.labels['model-pulling']) + ) { + modelIds.add(task.labels['model-pulling']); + filtered.push(task); + } + return filtered; + }, []); + loading = false; filterModels(); });