diff --git a/cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts b/cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts index e3e242997..1d118f647 100644 --- a/cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts +++ b/cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts @@ -51,16 +51,30 @@ export class ModelsCliUsecases { private readonly inquirerService: InquirerService, ) {} + /** + * Start a model by ID + * @param modelId + */ async startModel(modelId: string): Promise { await this.getModelOrStop(modelId); await this.modelsUsecases.startModel(modelId); } + /** + * Stop a model by ID + * @param modelId + */ async stopModel(modelId: string): Promise { await this.getModelOrStop(modelId); await this.modelsUsecases.stopModel(modelId); } + /** + * Update model's settings. E.g. ngl, prompt_template, etc. + * @param modelId + * @param settingParams + * @returns + */ async updateModelSettingParams( modelId: string, settingParams: ModelSettingParams, @@ -68,6 +82,12 @@ export class ModelsCliUsecases { return this.modelsUsecases.updateModelSettingParams(modelId, settingParams); } + /** + * Update model's runtime parameters. E.g. max_tokens, temperature, etc. + * @param modelId + * @param runtimeParams + * @returns + */ async updateModelRuntimeParams( modelId: string, runtimeParams: ModelRuntimeParams, @@ -75,6 +95,11 @@ export class ModelsCliUsecases { return this.modelsUsecases.updateModelRuntimeParams(modelId, runtimeParams); } + /** + * Find a model or abort if not exist + * @param modelId + * @returns + */ private async getModelOrStop(modelId: string): Promise { const model = await this.modelsUsecases.findOne(modelId); if (!model) { @@ -84,25 +109,42 @@ export class ModelsCliUsecases { return model; } + /** + * List all of the models + * @returns + */ async listAllModels(): Promise { return this.modelsUsecases.findAll(); } + /** + * Get a model by ID + * @param modelId + * @returns + */ async getModel(modelId: string): Promise { const model = await this.getModelOrStop(modelId); return model; } + /** + * Remove a model, this would also delete model files + * @param modelId + * @returns + */ async removeModel(modelId: string) { await this.getModelOrStop(modelId); return this.modelsUsecases.remove(modelId); } + /** + * Pull model from Model repository (HF, Jan...) + * @param modelId + */ async pullModel(modelId: string) { - if (modelId.includes('/')) { + if (modelId.includes('/') || modelId.includes(':')) { await this.pullHuggingFaceModel(modelId); } - const bar = new SingleBar({}, Presets.shades_classic); bar.start(100, 0); const callback = (progress: number) => { @@ -111,21 +153,43 @@ export class ModelsCliUsecases { await this.modelsUsecases.downloadModel(modelId, callback); } - private async pullHuggingFaceModel(modelId: string) { - const data = await this.fetchHuggingFaceRepoData(modelId); - const { quantization } = await this.inquirerService.inquirer.prompt({ - type: 'list', - name: 'quantization', - message: 'Select quantization', - choices: data.siblings - .map((e) => e.quantization) - .filter((e) => e != null), - }); + //// PRIVATE METHODS //// - const sibling = data.siblings - .filter((e) => !!e.quantization) - .find((e: any) => e.quantization === quantization); + /** + * It's to pull model from HuggingFace repository + * It could be a model from Jan's repo or other authors + * @param modelId HuggingFace model id. e.g. "janhq/llama-3 or llama3:7b" + */ + private async pullHuggingFaceModel(modelId: string) { + let data: HuggingFaceRepoData; + if (modelId.includes('/')) + data = await this.fetchHuggingFaceRepoData(modelId); + else data = await this.fetchJanRepoData(modelId); + + let sibling; + + const listChoices = data.siblings + .filter((e) => e.quantization != null) + .map((e) => { + return { + name: e.quantization, + value: e.quantization, + }; + }); + if (listChoices.length > 1) { + const { quantization } = await this.inquirerService.inquirer.prompt({ + type: 'list', + name: 'quantization', + message: 'Select quantization', + choices: listChoices, + }); + sibling = data.siblings + .filter((e) => !!e.quantization) + .find((e: any) => e.quantization === quantization); + } else { + sibling = data.siblings.find((e) => e.rfilename.includes('.gguf')); + } if (!sibling) throw 'No expected quantization found'; let stopWord = ''; @@ -141,9 +205,7 @@ export class ModelsCliUsecases { // @ts-expect-error "tokenizer.ggml.tokens" stopWord = metadata['tokenizer.ggml.tokens'][index] ?? ''; - } catch (err) { - console.log('Failed to get stop word: ', err); - } + } catch (err) {} const stopWords: string[] = []; if (stopWord.length > 0) { @@ -163,6 +225,7 @@ export class ModelsCliUsecases { description: '', settings: { prompt_template: promptTemplate, + llama_model_path: sibling.rfilename, }, parameters: { stop: stopWords, @@ -209,8 +272,71 @@ export class ModelsCliUsecases { } } + /** + * Fetch the model data from Jan's repo + * @param modelId HuggingFace model id. e.g. "llama-3:7b" + * @returns + */ + private async fetchJanRepoData(modelId: string) { + const repo = modelId.split(':')[0]; + const tree = modelId.split(':')[1]; + const url = this.getRepoModelsUrl(`janhq/${repo}`, tree); + const res = await fetch(url); + const response: + | { + path: string; + size: number; + }[] + | { error: string } = await res.json(); + + if ('error' in response && response.error != null) { + throw new Error(response.error); + } + + const data: HuggingFaceRepoData = { + siblings: Array.isArray(response) + ? response.map((e) => { + return { + rfilename: e.path, + downloadUrl: `https://huggingface.co/janhq/${repo}/resolve/${tree}/${e.path}`, + fileSize: e.size ?? 0, + }; + }) + : [], + tags: ['gguf'], + id: modelId, + modelId: modelId, + author: 'janhq', + sha: '', + downloads: 0, + lastModified: '', + private: false, + disabled: false, + gated: false, + pipeline_tag: 'text-generation', + cardData: {}, + createdAt: '', + }; + + AllQuantizations.forEach((quantization) => { + data.siblings.forEach((sibling: any) => { + if (!sibling.quantization && sibling.rfilename.includes(quantization)) { + sibling.quantization = quantization; + } + }); + }); + + data.modelUrl = url; + return data; + } + + /** + * Fetches the model data from HuggingFace API + * @param repoId HuggingFace model id. e.g. "janhq/llama-3" + * @returns + */ private async fetchHuggingFaceRepoData(repoId: string) { - const sanitizedUrl = this.toHuggingFaceUrl(repoId); + const sanitizedUrl = this.getRepoModelsUrl(repoId); const res = await fetch(sanitizedUrl); const response = await res.json(); @@ -245,24 +371,7 @@ export class ModelsCliUsecases { return data; } - private toHuggingFaceUrl(repoId: string): string { - try { - const url = new URL(`https://huggingface.co/${repoId}`); - if (url.host !== 'huggingface.co') { - throw `Invalid Hugging Face repo URL: ${repoId}`; - } - - const paths = url.pathname.split('/').filter((e) => e.trim().length > 0); - if (paths.length < 2) { - throw `Invalid Hugging Face repo URL: ${repoId}`; - } - - return `${url.origin}/api/models/${paths[0]}/${paths[1]}`; - } catch (err) { - if (repoId.startsWith('https')) { - throw new Error(`Cannot parse url: ${repoId}`); - } - throw err; - } + private getRepoModelsUrl(repoId: string, tree?: string): string { + return `https://huggingface.co/api/models/${repoId}${tree ? `/tree/${tree}` : ''}`; } } diff --git a/cortex-js/src/infrastructure/commanders/utils/normalize-model-id.ts b/cortex-js/src/infrastructure/commanders/utils/normalize-model-id.ts new file mode 100644 index 000000000..c36cb339e --- /dev/null +++ b/cortex-js/src/infrastructure/commanders/utils/normalize-model-id.ts @@ -0,0 +1,3 @@ +export const normalizeModelId = (modelId: string): string => { + return modelId.replace(':', '%3A'); +}; diff --git a/cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts b/cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts index 941f1b860..e9174787a 100644 --- a/cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts +++ b/cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts @@ -6,6 +6,7 @@ import { Model, ModelSettingParams } from '@/domain/models/model.interface'; import { HttpService } from '@nestjs/axios'; import { defaultCortexCppHost, defaultCortexCppPort } from 'constant'; import { readdirSync } from 'node:fs'; +import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id'; /** * A class that implements the InferenceExtension interface from the @janhq/core package. @@ -32,7 +33,10 @@ export default class CortexProvider extends OAIEngineExtension { ): Promise { const modelsContainerDir = this.modelDir(); - const modelFolderFullPath = join(modelsContainerDir, model.id); + const modelFolderFullPath = join( + modelsContainerDir, + normalizeModelId(model.id), + ); const ggufFiles = readdirSync(modelFolderFullPath).filter((file) => { return file.endsWith('.gguf'); }); diff --git a/cortex-js/src/usecases/cortex/cortex.usecases.ts b/cortex-js/src/usecases/cortex/cortex.usecases.ts index 035aa1486..f5ef2a87f 100644 --- a/cortex-js/src/usecases/cortex/cortex.usecases.ts +++ b/cortex-js/src/usecases/cortex/cortex.usecases.ts @@ -32,7 +32,7 @@ export class CortexUsecases { ); if (!existsSync(cortexCppPath)) { - throw new Error('Cortex binary not found'); + throw new Error('The engine is not available, please run "cortex init".'); } // go up one level to get the binary folder, have to also work on windows diff --git a/cortex-js/src/usecases/models/models.usecases.ts b/cortex-js/src/usecases/models/models.usecases.ts index c5257648b..2ec0ffbba 100644 --- a/cortex-js/src/usecases/models/models.usecases.ts +++ b/cortex-js/src/usecases/models/models.usecases.ts @@ -23,6 +23,7 @@ import { ExtensionRepository } from '@/domain/repositories/extension.interface'; import { EngineExtension } from '@/domain/abstracts/engine.abstract'; import { HttpService } from '@nestjs/axios'; import { ModelSettingParamsDto } from '@/infrastructure/dtos/models/model-setting-params.dto'; +import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id'; @Injectable() export class ModelsUsecases { @@ -106,7 +107,7 @@ export class ModelsUsecases { return; } - const modelFolder = join(modelsContainerDir, id); + const modelFolder = join(modelsContainerDir, normalizeModelId(id)); return this.modelRepository .delete(id) @@ -205,7 +206,7 @@ export class ModelsUsecases { mkdirSync(modelsContainerDir, { recursive: true }); } - const modelFolder = join(modelsContainerDir, model.id); + const modelFolder = join(modelsContainerDir, normalizeModelId(model.id)); await promises.mkdir(modelFolder, { recursive: true }); const destination = join(modelFolder, fileName);