diff --git a/cortex-js/src/infrastructure/controllers/models.controller.ts b/cortex-js/src/infrastructure/controllers/models.controller.ts index 851a9f91a..f4428cc74 100644 --- a/cortex-js/src/infrastructure/controllers/models.controller.ts +++ b/cortex-js/src/infrastructure/controllers/models.controller.ts @@ -3,7 +3,6 @@ import { Get, Post, Body, - Patch, Param, Delete, HttpCode, @@ -23,9 +22,7 @@ import { StartModelSuccessDto } from '@/infrastructure/dtos/models/start-model-s import { TransformInterceptor } from '../interceptors/transform.interceptor'; import { CortexUsecases } from '@/usecases/cortex/cortex.usecases'; import { ModelSettingsDto } from '../dtos/models/model-settings.dto'; -import { - EventName, -} from '@/domain/telemetry/telemetry.interface'; +import { EventName } from '@/domain/telemetry/telemetry.interface'; import { TelemetryUsecases } from '@/usecases/telemetry/telemetry.usecases'; import { CommonResponseDto } from '../dtos/common/common-response.dto'; import { HuggingFaceRepoSibling } from '@/domain/models/huggingface.interface'; @@ -118,23 +115,53 @@ export class ModelsController { }, ], }) - + @ApiParam({ + name: 'modelId', + required: true, + description: 'The unique identifier of the model.', + }) + @ApiParam({ + name: 'fileName', + required: false, + description: 'The file name of the model to download.', + }) + @ApiParam({ + name: 'persistedModelId', + required: false, + description: 'The unique identifier of the model in your local storage.', + }) @Get('download/:modelId(*)') - downloadModel(@Param('modelId') modelId: string, @Query('fileName') fileName: string, @Query('persistedModelId') persistedModelId?: string) { - this.modelsUsecases.pullModel(modelId, false, (files) => { - return new Promise(async (resolve, reject) => { - const file = files - .find((e) => e.quantization && e.rfilename === fileName) - if(!file) { - return reject(new BadRequestException('File not found')); - } - return resolve(file); - }); - }, persistedModelId).then(() => this.telemetryUsecases.addEventToQueue({ - name: EventName.DOWNLOAD_MODEL, - modelId, - }) - ); + downloadModel( + @Param('modelId') modelId: string, + @Query('fileName') fileName: string, + @Query('persistedModelId') persistedModelId?: string, + ) { + this.modelsUsecases + .pullModel( + modelId, + false, + (files) => { + return new Promise( + async (resolve, reject) => { + const file = files.find( + (e) => + e.quantization && (!fileName || e.rfilename === fileName), + ); + if (!file) { + return reject(new BadRequestException('File not found')); + } + return resolve(file); + }, + ); + }, + persistedModelId, + ) + .then(() => + this.telemetryUsecases.addEventToQueue({ + name: EventName.DOWNLOAD_MODEL, + modelId, + }), + ); return { message: 'Download model started successfully.', }; @@ -173,22 +200,48 @@ export class ModelsController { required: true, description: 'The unique identifier of the model.', }) + @ApiParam({ + name: 'fileName', + required: false, + description: 'The file name of the model to download.', + }) + @ApiParam({ + name: 'persistedModelId', + required: false, + description: 'The unique identifier of the model in your local storage.', + }) @Get('pull/:modelId(*)') - pullModel(@Param('modelId') modelId: string, @Query('fileName') fileName: string, @Query('persistedModelId') persistedModelId?: string) { - this.modelsUsecases.pullModel(modelId, false, (files) => { - return new Promise(async (resolve, reject) => { - const file = files - .find((e) => e.quantization && e.rfilename === fileName) - if(!file) { - return reject(new BadRequestException('File not found')); - } - return resolve(file); - }); - }, persistedModelId).then(() => this.telemetryUsecases.addEventToQueue({ - name: EventName.DOWNLOAD_MODEL, - modelId, - }) - ); + pullModel( + @Param('modelId') modelId: string, + @Query('fileName') fileName?: string, + @Query('persistedModelId') persistedModelId?: string, + ) { + this.modelsUsecases + .pullModel( + modelId, + false, + (files) => { + return new Promise( + async (resolve, reject) => { + const file = files.find( + (e) => + e.quantization && (!fileName || e.rfilename === fileName), + ); + if (!file) { + return reject(new BadRequestException('File not found')); + } + return resolve(file); + }, + ); + }, + persistedModelId, + ) + .then(() => + this.telemetryUsecases.addEventToQueue({ + name: EventName.DOWNLOAD_MODEL, + modelId, + }), + ); return { message: 'Download model started successfully.', }; diff --git a/cortex-js/src/infrastructure/services/file-manager/file-manager.service.ts b/cortex-js/src/infrastructure/services/file-manager/file-manager.service.ts index c4c3fd499..d1b6d3a5d 100644 --- a/cortex-js/src/infrastructure/services/file-manager/file-manager.service.ts +++ b/cortex-js/src/infrastructure/services/file-manager/file-manager.service.ts @@ -285,7 +285,8 @@ export class FileManagerService { async writeFile(filePath: string, data: any): Promise { try { const dirPath = filePath.split('/').slice(0, -1).join('/'); - await this.createFolderIfNotExistInDataFolder(dirPath); + const folderName = dirPath.split('/').slice(-1)[0]; + await this.createFolderIfNotExistInDataFolder(folderName); return promises.writeFile(filePath, data, { encoding: 'utf8', flag: 'w+',