Skip to content

Commit

Permalink
feat(models): improve model download management
Browse files Browse the repository at this point in the history
Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 committed Mar 4, 2024
1 parent 9ecf93d commit e63c4b8
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 45 deletions.
93 changes: 51 additions & 42 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,28 +173,32 @@ export class ModelsManager implements Disposable {
}
}

/**
* This method will resolve when the provided model will be downloaded.
*
* This can method can be call multiple time for the same model, it will reuse existing downloader and wait on
* their completion.
* @param model
* @param labels
*/
async requestDownloadModel(model: ModelInfo, labels?: { [key: string]: string }): Promise<string> {
// Create a task to follow progress
const task: Task = this.createDownloadTask(model, labels);

// Check there is no existing downloader running
if (!this.#downloaders.has(model.id)) {
console.debug('no downloader has been found.');
return this.downloadModel(model, labels);
}

const task = this.taskRegistry.findTaskByLabels({ 'model-pulling': model.id });
if (task !== undefined) {
task.labels = {
...labels,
...task.labels,
};
this.taskRegistry.updateTask(task);
return this.downloadModel(model, task);
}

const existingDownloader = this.#downloaders.get(model.id);
if (existingDownloader.completed) {
task.state = 'success';
this.taskRegistry.updateTask(task);

return existingDownloader.getTarget();
}

// If we have an existing downloader running we
// If we have an existing downloader running we subscribe on its events
return new Promise((resolve, reject) => {
const disposable = existingDownloader.onEvent(event => {
if (!isCompletionEvent(event)) return;
Expand All @@ -213,40 +217,43 @@ export class ModelsManager implements Disposable {

private onDownloadEvent(event: DownloadEvent): void {
// Always use the task registry as source of truth for tasks
const task = this.taskRegistry.findTaskByLabels({ 'model-pulling': event.id });
if (task === undefined) {
const tasks = this.taskRegistry.getTasksByLabels({ 'model-pulling': event.id });
if (tasks.length === 0) {
// tasks might have been cleared but still an error.
console.error('received download event but no task is associated.');
return;
}

if (isProgressEvent(event)) {
task.state = 'loading';
task.progress = event.value;
} else if (isCompletionEvent(event)) {
// status error or canceled
if (event.status === 'error' || event.status === 'canceled') {
task.state = 'error';
task.progress = undefined;
task.error = event.message;

// telemetry usage
this.telemetry.logError('model.download', {
'model.id': event.id,
message: 'error downloading model',
error: event.message,
durationSeconds: event.duration,
});
} else {
task.state = 'success';
task.progress = 100;

// telemetry usage
this.telemetry.logUsage('model.download', { 'model.id': event.id, durationSeconds: event.duration });
console.log(`onDownloadEvent updating ${tasks.length} tasks.`);

tasks.forEach(task => {
if (isProgressEvent(event)) {
task.state = 'loading';
task.progress = event.value;
} else if (isCompletionEvent(event)) {
// status error or canceled
if (event.status === 'error' || event.status === 'canceled') {
task.state = 'error';
task.progress = undefined;
task.error = event.message;

// telemetry usage
this.telemetry.logError('model.download', {
'model.id': event.id,
message: 'error downloading model',
error: event.message,
durationSeconds: event.duration,
});
} else {
task.state = 'success';
task.progress = 100;

// telemetry usage
this.telemetry.logUsage('model.download', { 'model.id': event.id, durationSeconds: event.duration });
}
}
}

this.taskRegistry.updateTask(task); // update task
this.taskRegistry.updateTask(task); // update task
});
}

private createDownloader(model: ModelInfo): Downloader {
Expand All @@ -267,12 +274,14 @@ export class ModelsManager implements Disposable {
return downloader;
}

private async downloadModel(model: ModelInfo, labels?: { [key: string]: string }): Promise<string> {
const task: Task = this.taskRegistry.createTask(`Downloading model ${model.name}`, 'loading', {
private createDownloadTask(model: ModelInfo, labels?: { [key: string]: string }): Task {
return this.taskRegistry.createTask(`Downloading model ${model.name}`, 'loading', {
...labels,
'model-pulling': model.id,
});
}

private async downloadModel(model: ModelInfo, task: Task): Promise<string> {
// Check if the model is already on disk.
if (this.isModelOnDisk(model.id)) {
task.state = 'success';
Expand Down
4 changes: 4 additions & 0 deletions packages/backend/src/registries/TaskRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ export class TaskRegistry {
return this.getTasks().filter(task => this.filter(task, requestedLabels));
}

/**
* Return the first task matching all the labels provided
* @param requestedLabels
*/
findTaskByLabels(requestedLabels: { [key: string]: string }): Task | undefined {
return this.getTasks().find(task => this.filter(task, requestedLabels));
}
Expand Down
5 changes: 3 additions & 2 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ export class StudioApiImpl implements StudioAPI {
.withProgress({ location: podmanDesktopApi.ProgressLocation.TASK_WIDGET, title: `Pulling ${recipe.name}.` }, () =>
this.applicationManager.pullApplication(recipe, model),
)
.catch(() => {
.catch((err: unknown) => {
console.error('Something went wrong while trying to start application', err);
podmanDesktopApi.window
.showErrorMessage(`Error starting the application "${recipe.name}"`)
.showErrorMessage(`Error starting the application "${recipe.name}": ${String(err)}`)
.catch((err: unknown) => {
console.error(`Something went wrong with confirmation modals`, err);
});
Expand Down
16 changes: 15 additions & 1 deletion packages/frontend/src/pages/Models.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,21 @@ function filterModels(): void {
onMount(() => {
// Subscribe to the tasks store
const tasksUnsubscribe = tasks.subscribe(value => {
pullingTasks = value.filter(task => task.state === 'loading' && task.labels && 'model-pulling' in task.labels);
// Filter out duplicates
const modelIds = new Set<string>();
pullingTasks = value.reduce((filtered: Task[], task: Task) => {
if (
task.state === 'loading' &&
task.labels !== undefined &&
'model-pulling' in task.labels &&
!modelIds.has(task.labels['model-pulling'])
) {
modelIds.add(task.labels['model-pulling']);
filtered.push(task);
}
return filtered;
}, []);
loading = false;
filterModels();
});
Expand Down

0 comments on commit e63c4b8

Please sign in to comment.