From b84eb2f7d8d0e8b5386edbe536d3816a0f6b169e Mon Sep 17 00:00:00 2001 From: lstocchi Date: Fri, 26 Jan 2024 10:45:07 +0100 Subject: [PATCH] fix: return correct path when downloading model Signed-off-by: lstocchi --- .../src/managers/applicationManager.spec.ts | 50 ++++++++++++++++++- .../src/managers/applicationManager.ts | 47 +++++++++++------ 2 files changed, 79 insertions(+), 18 deletions(-) diff --git a/packages/backend/src/managers/applicationManager.spec.ts b/packages/backend/src/managers/applicationManager.spec.ts index 82af4538a..0491ffd41 100644 --- a/packages/backend/src/managers/applicationManager.spec.ts +++ b/packages/backend/src/managers/applicationManager.spec.ts @@ -1,5 +1,5 @@ import { type MockInstance, describe, expect, test, vi, beforeEach } from 'vitest'; -import type { ContainerAttachedInfo, ImageInfo, PodInfo } from './applicationManager'; +import type { ContainerAttachedInfo, DownloadModelResult, ImageInfo, PodInfo } from './applicationManager'; import { ApplicationManager } from './applicationManager'; import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry'; import type { GitManager } from './gitManager'; @@ -733,17 +733,63 @@ describe('createApplicationPod', () => { }); }); -describe('restartContainerWhenEndpointIsUp', () => { +describe('doDownloadModelWrapper', () => { const manager = new ApplicationManager( '/home/user/aistudio', {} as unknown as GitManager, {} as unknown as RecipeStatusRegistry, {} as unknown as ModelsManager, ); + test('returning model path if model has been downloaded', async () => { + vi.spyOn(manager, 'doDownloadModel').mockImplementation( + ( + _modelId: string, + _url: string, + _taskUtil: RecipeStatusUtils, + callback: (message: DownloadModelResult) => void, + _destFileName?: string, + ) => { + callback({ + successful: true, + path: 'path', + }); + }, + ); + setTaskStateMock.mockReturnThis(); + const result = await manager.doDownloadModelWrapper('id', 'url', taskUtils); + expect(result).toBe('path'); + }); + test('rejecting with error message if model has NOT been downloaded', async () => { + vi.spyOn(manager, 'doDownloadModel').mockImplementation( + ( + _modelId: string, + _url: string, + _taskUtil: RecipeStatusUtils, + callback: (message: DownloadModelResult) => void, + _destFileName?: string, + ) => { + callback({ + successful: false, + error: 'error', + }); + }, + ); + setTaskStateMock.mockReturnThis(); + await expect(manager.doDownloadModelWrapper('id', 'url', taskUtils)).rejects.toThrowError('error'); + }); +}); + +describe('restartContainerWhenEndpointIsUp', () => { const containerAttachedInfo: ContainerAttachedInfo = { name: 'name', endPoint: 'endpoint', }; + const manager = new ApplicationManager( + '/home/user/aistudio', + {} as unknown as GitManager, + {} as unknown as RecipeStatusRegistry, + {} as unknown as ModelsManager, + ); test('restart container if endpoint is alive', async () => { vi.spyOn(utils, 'isEndpointAlive').mockResolvedValue(true); await manager.restartContainerWhenEndpointIsUp('engine', containerAttachedInfo); diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index df5a252a9..d21ff2d79 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -36,9 +36,16 @@ import { isEndpointAlive, timeout } from '../utils/utils'; export const CONFIG_FILENAME = 'ai-studio.yaml'; -interface DownloadModelResult { - result: 'ok' | 'failed'; - error?: string; +export type DownloadModelResult = DownloadModelSuccessfulResult | DownloadModelFailureResult; + +interface DownloadModelSuccessfulResult { + successful: true; + path: string; +} + +interface DownloadModelFailureResult { + successful: false; + error: string; } interface AIContainers { @@ -428,7 +435,20 @@ export class ApplicationManager { }, }); - return await this.doDownloadModelWrapper(model.id, model.url, taskUtil); + try { + return await this.doDownloadModelWrapper(model.id, model.url, taskUtil); + } catch (e) { + console.error(e); + taskUtil.setTask({ + id: model.id, + state: 'error', + name: `Downloading model ${model.name}`, + labels: { + 'model-pulling': model.id, + }, + }); + throw e; + } } else { taskUtil.setTask({ id: model.id, @@ -520,26 +540,20 @@ export class ApplicationManager { ): Promise { return new Promise((resolve, reject) => { const downloadCallback = (result: DownloadModelResult) => { - if (result.result) { + if (result.successful === true) { taskUtil.setTaskState(modelId, 'success'); - resolve(destFileName); - } else { + resolve(result.path); + } else if (result.successful === false) { taskUtil.setTaskState(modelId, 'error'); reject(result.error); } }; - if (fs.existsSync(destFileName)) { - taskUtil.setTaskState(modelId, 'success'); - taskUtil.setTaskProgress(modelId, 100); - return; - } - this.doDownloadModel(modelId, url, taskUtil, downloadCallback, destFileName); }); } - private doDownloadModel( + doDownloadModel( modelId: string, url: string, taskUtil: RecipeStatusUtils, @@ -581,7 +595,8 @@ export class ApplicationManager { //this.sendProgress(progressValue); if (progressValue === 100) { callback({ - result: 'ok', + successful: true, + path: destFile, }); } }); @@ -590,7 +605,7 @@ export class ApplicationManager { }); file.on('error', e => { callback({ - result: 'failed', + successful: false, error: e.message, }); });