Skip to content

Commit

Permalink
feat: pull model from jan hub (#606)
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan authored May 23, 2024
1 parent d31a788 commit 1084844
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 42 deletions.
185 changes: 147 additions & 38 deletions cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,55 @@ export class ModelsCliUsecases {
private readonly inquirerService: InquirerService,
) {}

/**
 * Start a model by ID.
 * Aborts (via getModelOrStop) when the model does not exist.
 * @param modelId id of the model to start
 */
async startModel(modelId: string): Promise<void> {
  // Resolve the model first so an unknown id aborts before the engine is touched.
  await this.getModelOrStop(modelId).then(() =>
    this.modelsUsecases.startModel(modelId),
  );
}

/**
 * Stop a running model by ID.
 * Aborts (via getModelOrStop) when the model does not exist.
 * @param modelId id of the model to stop
 */
async stopModel(modelId: string): Promise<void> {
  const existing = this.getModelOrStop(modelId);
  await existing;
  await this.modelsUsecases.stopModel(modelId);
}

/**
 * Update a model's engine settings, e.g. ngl, prompt_template.
 * @param modelId id of the model to update
 * @param settingParams setting values to persist
 * @returns the updated setting params
 */
async updateModelSettingParams(
  modelId: string,
  settingParams: ModelSettingParams,
): Promise<ModelSettingParams> {
  const updated = await this.modelsUsecases.updateModelSettingParams(
    modelId,
    settingParams,
  );
  return updated;
}

/**
 * Update a model's runtime parameters, e.g. max_tokens, temperature.
 * @param modelId id of the model to update
 * @param runtimeParams runtime values to persist
 * @returns the updated runtime params
 */
async updateModelRuntimeParams(
  modelId: string,
  runtimeParams: ModelRuntimeParams,
): Promise<ModelRuntimeParams> {
  const updated = await this.modelsUsecases.updateModelRuntimeParams(
    modelId,
    runtimeParams,
  );
  return updated;
}

/**
* Find a model or abort if not exist
* @param modelId
* @returns
*/
private async getModelOrStop(modelId: string): Promise<Model> {
const model = await this.modelsUsecases.findOne(modelId);
if (!model) {
Expand All @@ -84,25 +109,42 @@ export class ModelsCliUsecases {
return model;
}

/**
 * List every model known to the models use-case layer.
 * @returns all persisted models
 */
async listAllModels(): Promise<Model[]> {
  const models = await this.modelsUsecases.findAll();
  return models;
}

/**
 * Get a model by ID.
 * Aborts (via getModelOrStop) when the model does not exist.
 * @param modelId id of the model to fetch
 * @returns the resolved model
 */
async getModel(modelId: string): Promise<Model> {
  return this.getModelOrStop(modelId);
}

/**
 * Remove a model; the underlying use case also deletes the model files.
 * Aborts (via getModelOrStop) when the model does not exist.
 * @param modelId id of the model to delete
 * @returns the result of the underlying remove operation
 */
async removeModel(modelId: string) {
  return this.getModelOrStop(modelId).then(() =>
    this.modelsUsecases.remove(modelId),
  );
}

/**
* Pull model from Model repository (HF, Jan...)
* @param modelId
*/
async pullModel(modelId: string) {
if (modelId.includes('/')) {
if (modelId.includes('/') || modelId.includes(':')) {
await this.pullHuggingFaceModel(modelId);
}

const bar = new SingleBar({}, Presets.shades_classic);
bar.start(100, 0);
const callback = (progress: number) => {
Expand All @@ -111,21 +153,43 @@ export class ModelsCliUsecases {
await this.modelsUsecases.downloadModel(modelId, callback);
}

private async pullHuggingFaceModel(modelId: string) {
const data = await this.fetchHuggingFaceRepoData(modelId);
const { quantization } = await this.inquirerService.inquirer.prompt({
type: 'list',
name: 'quantization',
message: 'Select quantization',
choices: data.siblings
.map((e) => e.quantization)
.filter((e) => e != null),
});
//// PRIVATE METHODS ////

const sibling = data.siblings
.filter((e) => !!e.quantization)
.find((e: any) => e.quantization === quantization);
/**
* It's to pull model from HuggingFace repository
* It could be a model from Jan's repo or other authors
* @param modelId HuggingFace model id. e.g. "janhq/llama-3 or llama3:7b"
*/
private async pullHuggingFaceModel(modelId: string) {
let data: HuggingFaceRepoData;
if (modelId.includes('/'))
data = await this.fetchHuggingFaceRepoData(modelId);
else data = await this.fetchJanRepoData(modelId);

let sibling;

const listChoices = data.siblings
.filter((e) => e.quantization != null)
.map((e) => {
return {
name: e.quantization,
value: e.quantization,
};
});

if (listChoices.length > 1) {
const { quantization } = await this.inquirerService.inquirer.prompt({
type: 'list',
name: 'quantization',
message: 'Select quantization',
choices: listChoices,
});
sibling = data.siblings
.filter((e) => !!e.quantization)
.find((e: any) => e.quantization === quantization);
} else {
sibling = data.siblings.find((e) => e.rfilename.includes('.gguf'));
}
if (!sibling) throw 'No expected quantization found';

let stopWord = '';
Expand All @@ -141,9 +205,7 @@ export class ModelsCliUsecases {

// @ts-expect-error "tokenizer.ggml.tokens"
stopWord = metadata['tokenizer.ggml.tokens'][index] ?? '';
} catch (err) {
console.log('Failed to get stop word: ', err);
}
} catch (err) {}

const stopWords: string[] = [];
if (stopWord.length > 0) {
Expand All @@ -163,6 +225,7 @@ export class ModelsCliUsecases {
description: '',
settings: {
prompt_template: promptTemplate,
llama_model_path: sibling.rfilename,
},
parameters: {
stop: stopWords,
Expand Down Expand Up @@ -209,8 +272,71 @@ export class ModelsCliUsecases {
}
}

/**
 * Fetch model metadata from Jan's HuggingFace organisation ("janhq")
 * and normalise it to the HuggingFace repo data shape.
 * @param modelId Jan model id in "<repo>:<tree>" form, e.g. "llama-3:7b"
 * @returns repo data with siblings tagged by quantization
 * @throws Error when the tree API responds with an error payload
 */
private async fetchJanRepoData(modelId: string) {
  // "<repo>:<tree>" — e.g. "llama-3:7b" → repo "llama-3", revision "7b".
  // Split once instead of twice.
  const [repo, tree] = modelId.split(':');
  const url = this.getRepoModelsUrl(`janhq/${repo}`, tree);

  const res = await fetch(url);
  const response:
    | {
        path: string;
        size: number;
      }[]
    | { error: string } = await res.json();

  if ('error' in response && response.error != null) {
    throw new Error(response.error);
  }

  const data: HuggingFaceRepoData = {
    siblings: Array.isArray(response)
      ? response.map((e) => ({
          rfilename: e.path,
          downloadUrl: `https://huggingface.co/janhq/${repo}/resolve/${tree}/${e.path}`,
          fileSize: e.size ?? 0,
        }))
      : [],
    tags: ['gguf'],
    id: modelId,
    modelId: modelId,
    author: 'janhq',
    sha: '',
    downloads: 0,
    lastModified: '',
    private: false,
    disabled: false,
    gated: false,
    pipeline_tag: 'text-generation',
    cardData: {},
    createdAt: '',
  };

  // Tag each file with the first matching quantization label, in
  // AllQuantizations order — single pass instead of a nested forEach.
  data.siblings.forEach((sibling: any) => {
    if (!sibling.quantization) {
      const match = AllQuantizations.find((quantization) =>
        sibling.rfilename.includes(quantization),
      );
      if (match) sibling.quantization = match;
    }
  });

  data.modelUrl = url;
  return data;
}

/**
* Fetches the model data from HuggingFace API
* @param repoId HuggingFace model id. e.g. "janhq/llama-3"
* @returns
*/
private async fetchHuggingFaceRepoData(repoId: string) {
const sanitizedUrl = this.toHuggingFaceUrl(repoId);
const sanitizedUrl = this.getRepoModelsUrl(repoId);

const res = await fetch(sanitizedUrl);
const response = await res.json();
Expand Down Expand Up @@ -245,24 +371,7 @@ export class ModelsCliUsecases {
return data;
}

private toHuggingFaceUrl(repoId: string): string {
try {
const url = new URL(`https://huggingface.co/${repoId}`);
if (url.host !== 'huggingface.co') {
throw `Invalid Hugging Face repo URL: ${repoId}`;
}

const paths = url.pathname.split('/').filter((e) => e.trim().length > 0);
if (paths.length < 2) {
throw `Invalid Hugging Face repo URL: ${repoId}`;
}

return `${url.origin}/api/models/${paths[0]}/${paths[1]}`;
} catch (err) {
if (repoId.startsWith('https')) {
throw new Error(`Cannot parse url: ${repoId}`);
}
throw err;
}
/**
 * Build the HuggingFace models API URL for a repository, optionally
 * scoped to a tree (branch/revision).
 * @param repoId repository id, e.g. "janhq/llama-3"
 * @param tree optional revision, e.g. "7b"
 * @returns the API URL for the repo (or for the tree listing when given)
 */
private getRepoModelsUrl(repoId: string, tree?: string): string {
  const base = `https://huggingface.co/api/models/${repoId}`;
  return tree ? `${base}/tree/${tree}` : base;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/**
 * Make a model id safe for use as a single folder/path segment by
 * URL-encoding the ':' tag separator (e.g. "llama-3:7b" → "llama-3%3A7b").
 * NOTE(review): presumably because ':' is not a legal character in folder
 * names on some filesystems — confirm against the CortexProvider usage.
 * @param modelId the model id, possibly containing ':' separators
 * @returns the id with every ':' replaced by "%3A"
 */
export const normalizeModelId = (modelId: string): string => {
  // Use a global regex: String.replace with a plain-string pattern only
  // replaces the FIRST occurrence, which would leave later colons intact.
  return modelId.replace(/:/g, '%3A');
};
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { Model, ModelSettingParams } from '@/domain/models/model.interface';
import { HttpService } from '@nestjs/axios';
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';
import { readdirSync } from 'node:fs';
import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';

/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
Expand All @@ -32,7 +33,10 @@ export default class CortexProvider extends OAIEngineExtension {
): Promise<void> {
const modelsContainerDir = this.modelDir();

const modelFolderFullPath = join(modelsContainerDir, model.id);
const modelFolderFullPath = join(
modelsContainerDir,
normalizeModelId(model.id),
);
const ggufFiles = readdirSync(modelFolderFullPath).filter((file) => {
return file.endsWith('.gguf');
});
Expand Down
2 changes: 1 addition & 1 deletion cortex-js/src/usecases/cortex/cortex.usecases.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export class CortexUsecases {
);

if (!existsSync(cortexCppPath)) {
throw new Error('Cortex binary not found');
throw new Error('The engine is not available, please run "cortex init".');
}

// go up one level to get the binary folder, have to also work on windows
Expand Down
5 changes: 3 additions & 2 deletions cortex-js/src/usecases/models/models.usecases.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import { ExtensionRepository } from '@/domain/repositories/extension.interface';
import { EngineExtension } from '@/domain/abstracts/engine.abstract';
import { HttpService } from '@nestjs/axios';
import { ModelSettingParamsDto } from '@/infrastructure/dtos/models/model-setting-params.dto';
import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';

@Injectable()
export class ModelsUsecases {
Expand Down Expand Up @@ -106,7 +107,7 @@ export class ModelsUsecases {
return;
}

const modelFolder = join(modelsContainerDir, id);
const modelFolder = join(modelsContainerDir, normalizeModelId(id));

return this.modelRepository
.delete(id)
Expand Down Expand Up @@ -205,7 +206,7 @@ export class ModelsUsecases {
mkdirSync(modelsContainerDir, { recursive: true });
}

const modelFolder = join(modelsContainerDir, model.id);
const modelFolder = join(modelsContainerDir, normalizeModelId(model.id));
await promises.mkdir(modelFolder, { recursive: true });
const destination = join(modelFolder, fileName);

Expand Down

0 comments on commit 1084844

Please sign in to comment.