Skip to content

Commit

Permalink
fix: return correct path when downloading model
Browse files Browse the repository at this point in the history
Signed-off-by: lstocchi <[email protected]>
  • Loading branch information
lstocchi committed Jan 26, 2024
1 parent 6c2133e commit b84eb2f
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 18 deletions.
50 changes: 48 additions & 2 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { type MockInstance, describe, expect, test, vi, beforeEach } from 'vitest';
import type { ContainerAttachedInfo, ImageInfo, PodInfo } from './applicationManager';
import type { ContainerAttachedInfo, DownloadModelResult, ImageInfo, PodInfo } from './applicationManager';
import { ApplicationManager } from './applicationManager';
import type { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry';
import type { GitManager } from './gitManager';
Expand Down Expand Up @@ -733,17 +733,63 @@ describe('createApplicationPod', () => {
});
});

describe('restartContainerWhenEndpointIsUp', () => {
describe('doDownloadModelWrapper', () => {
const manager = new ApplicationManager(
'/home/user/aistudio',
{} as unknown as GitManager,
{} as unknown as RecipeStatusRegistry,
{} as unknown as ModelsManager,
);
test('returning model path if model has been downloaded', async () => {
vi.spyOn(manager, 'doDownloadModel').mockImplementation(
(
_modelId: string,
_url: string,
_taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
_destFileName?: string,
) => {
callback({
successful: true,
path: 'path',
});
},
);
setTaskStateMock.mockReturnThis();
const result = await manager.doDownloadModelWrapper('id', 'url', taskUtils);
expect(result).toBe('path');
});
test('rejecting with error message if model has NOT been downloaded', async () => {
vi.spyOn(manager, 'doDownloadModel').mockImplementation(
(
_modelId: string,
_url: string,
_taskUtil: RecipeStatusUtils,
callback: (message: DownloadModelResult) => void,
_destFileName?: string,
) => {
callback({
successful: false,
error: 'error',
});
},
);
setTaskStateMock.mockReturnThis();
await expect(manager.doDownloadModelWrapper('id', 'url', taskUtils)).rejects.toThrowError('error');
});
});

describe('restartContainerWhenEndpointIsUp', () => {
const containerAttachedInfo: ContainerAttachedInfo = {
name: 'name',
endPoint: 'endpoint',
};
const manager = new ApplicationManager(
'/home/user/aistudio',
{} as unknown as GitManager,
{} as unknown as RecipeStatusRegistry,
{} as unknown as ModelsManager,
);
test('restart container if endpoint is alive', async () => {
vi.spyOn(utils, 'isEndpointAlive').mockResolvedValue(true);
await manager.restartContainerWhenEndpointIsUp('engine', containerAttachedInfo);
Expand Down
47 changes: 31 additions & 16 deletions packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,16 @@ import { isEndpointAlive, timeout } from '../utils/utils';

export const CONFIG_FILENAME = 'ai-studio.yaml';

interface DownloadModelResult {
result: 'ok' | 'failed';
error?: string;
export type DownloadModelResult = DownloadModelSuccessfulResult | DownloadModelFailureResult;

interface DownloadModelSuccessfulResult {
successful: true;
path: string;
}

interface DownloadModelFailureResult {
successful: false;
error: string;
}

interface AIContainers {
Expand Down Expand Up @@ -428,7 +435,20 @@ export class ApplicationManager {
},
});

return await this.doDownloadModelWrapper(model.id, model.url, taskUtil);
try {
return await this.doDownloadModelWrapper(model.id, model.url, taskUtil);
} catch (e) {
console.error(e);
taskUtil.setTask({
id: model.id,
state: 'error',
name: `Downloading model ${model.name}`,
labels: {
'model-pulling': model.id,
},
});
throw e;
}
} else {
taskUtil.setTask({
id: model.id,
Expand Down Expand Up @@ -520,26 +540,20 @@ export class ApplicationManager {
): Promise<string> {
return new Promise((resolve, reject) => {
const downloadCallback = (result: DownloadModelResult) => {
if (result.result) {
if (result.successful === true) {
taskUtil.setTaskState(modelId, 'success');
resolve(destFileName);
} else {
resolve(result.path);
} else if (result.successful === false) {
taskUtil.setTaskState(modelId, 'error');
reject(result.error);
}
};

if (fs.existsSync(destFileName)) {
taskUtil.setTaskState(modelId, 'success');
taskUtil.setTaskProgress(modelId, 100);
return;
}

this.doDownloadModel(modelId, url, taskUtil, downloadCallback, destFileName);
});
}

private doDownloadModel(
doDownloadModel(
modelId: string,
url: string,
taskUtil: RecipeStatusUtils,
Expand Down Expand Up @@ -581,7 +595,8 @@ export class ApplicationManager {
//this.sendProgress(progressValue);
if (progressValue === 100) {
callback({
result: 'ok',
successful: true,
path: destFile,
});
}
});
Expand All @@ -590,7 +605,7 @@ export class ApplicationManager {
});
file.on('error', e => {
callback({
result: 'failed',
successful: false,
error: e.message,
});
});
Expand Down

0 comments on commit b84eb2f

Please sign in to comment.