Skip to content

Commit

Permalink
feat: add model settings and prompt template from hf
Browse files Browse the repository at this point in the history
Signed-off-by: James <[email protected]>
  • Loading branch information
namchuai committed May 21, 2024
1 parent 0ae2c27 commit 050fe6e
Show file tree
Hide file tree
Showing 10 changed files with 379 additions and 7 deletions.
2 changes: 2 additions & 0 deletions cortex-js/src/command.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { ModelRemoveCommand } from './infrastructure/commanders/models/model-rem
import { RunCommand } from './infrastructure/commanders/shortcuts/run.command';
import { InitCudaQuestions } from './infrastructure/commanders/questions/cuda.questions';
import { CliUsecasesModule } from './infrastructure/commanders/usecases/cli.usecases.module';
import { ModelUpdateCommand } from './infrastructure/commanders/models/model-update.command';

@Module({
imports: [
Expand Down Expand Up @@ -55,6 +56,7 @@ import { CliUsecasesModule } from './infrastructure/commanders/usecases/cli.usec
ModelGetCommand,
ModelRemoveCommand,
ModelPullCommand,
ModelUpdateCommand,

// Shortcuts
RunCommand,
Expand Down
2 changes: 1 addition & 1 deletion cortex-js/src/infrastructure/commanders/chat.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export class ChatCommand extends CommandRunner {
}

@Option({
flags: '--model <model_id>',
flags: '-m, --model <model_id>',
description: 'Model Id to start chat with',
})
parseModelId(value: string) {
Expand Down
2 changes: 2 additions & 0 deletions cortex-js/src/infrastructure/commanders/models.command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { ModelListCommand } from './models/model-list.command';
import { ModelStopCommand } from './models/model-stop.command';
import { ModelPullCommand } from './models/model-pull.command';
import { ModelRemoveCommand } from './models/model-remove.command';
import { ModelUpdateCommand } from './models/model-update.command';

@SubCommand({
name: 'models',
Expand All @@ -15,6 +16,7 @@ import { ModelRemoveCommand } from './models/model-remove.command';
ModelListCommand,
ModelGetCommand,
ModelRemoveCommand,
ModelUpdateCommand,
],
description: 'Subcommands for managing models',
})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { ModelsCliUsecases } from '../usecases/models.cli.usecases';
import { exit } from 'node:process';
import { ModelParameterParser } from '../utils/model-parameter.parser';
import {
ModelRuntimeParams,
ModelSettingParams,
} from '@/domain/models/model.interface';

type UpdateOptions = {
model?: string;
options?: string[];
};

@SubCommand({ name: 'update', description: 'Update configuration of a model.' })
export class ModelUpdateCommand extends CommandRunner {
  constructor(private readonly modelsCliUsecases: ModelsCliUsecases) {
    super();
  }

  /**
   * Parses `key=value` pairs supplied via `-c/--options`, sorts each key into
   * model setting vs. runtime parameters, and persists whichever groups end
   * up non-empty. Exits with code 1 when no model id is given and with code 0
   * when there is nothing to update.
   */
  async run(_input: string[], option: UpdateOptions): Promise<void> {
    const modelId = option.model;
    if (!modelId) {
      console.error('Model Id is required');
      exit(1);
    }

    const options = option.options;
    if (!options || options.length === 0) {
      console.log('Nothing to update');
      exit(0);
    }

    const parser = new ModelParameterParser();
    const settingParams: ModelSettingParams = {};
    const runtimeParams: ModelRuntimeParams = {};

    options.forEach((entry) => {
      // Split on the FIRST '=' only, so values that themselves contain '='
      // (e.g. a prompt template) are not truncated. `split('=')` would drop
      // everything after the second '='.
      const separatorIndex = entry.indexOf('=');
      if (separatorIndex === -1) {
        console.warn(`Skipping malformed option (expected key=value): ${entry}`);
        return;
      }
      const key = entry.slice(0, separatorIndex);
      const stringValue = entry.slice(separatorIndex + 1);

      if (parser.isModelSettingParam(key)) {
        const value = parser.parse(key, stringValue);
        // @ts-expect-error did the check so it's safe
        settingParams[key] = value;
      } else if (parser.isModelRuntimeParam(key)) {
        const value = parser.parse(key, stringValue);
        // @ts-expect-error did the check so it's safe
        runtimeParams[key] = value;
      } else {
        // Unknown keys were previously dropped silently; surface them so the
        // user learns about typos instead of seeing a silent no-op.
        console.warn(`Skipping unknown parameter: ${key}`);
      }
    });

    if (Object.keys(settingParams).length > 0) {
      const updatedSettingParams =
        await this.modelsCliUsecases.updateModelSettingParams(
          modelId,
          settingParams,
        );
      console.log(
        'Updated setting params! New setting params:',
        updatedSettingParams,
      );
    }

    if (Object.keys(runtimeParams).length > 0) {
      await this.modelsCliUsecases.updateModelRuntimeParams(
        modelId,
        runtimeParams,
      );
      console.log('Updated runtime params! New runtime params:', runtimeParams);
    }
  }

  @Option({
    flags: '-m, --model <model_id>',
    required: true,
    description: 'Model Id to update',
  })
  parseModelId(value: string) {
    return value;
  }

  @Option({
    flags: '-c, --options <options...>',
    description:
      'Specify the options to update the model. Syntax: -c option1=value1 option2=value2. For example: cortex models update -c max_tokens=100 temperature=0.5',
  })
  parseOptions(option: string, optionsAccumulator: string[] = []): string[] {
    optionsAccumulator.push(option);
    return optionsAccumulator;
  }
}
37 changes: 37 additions & 0 deletions cortex-js/src/infrastructure/commanders/prompt-constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//// HF Chat templates
// Jinja chat templates as published in HuggingFace tokenizer configs
// (metadata key `tokenizer.chat_template`). Used to recognize which prompt
// template a downloaded GGUF model expects.

// NOTE(review): empty placeholder — an empty string can never match against a
// real chat template, so the OpenChat mapping is effectively disabled until
// the actual OpenChat 3.5 jinja template is filled in here.
export const OPEN_CHAT_3_5_JINJA = ``;

export const ZEPHYR_JINJA = `{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>
' + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}`;

//// Corresponding prompt templates
// Cortex-style prompt templates matching the jinja templates above; the
// `{system_message}` / `{prompt}` placeholders are filled at inference time.
export const OPEN_CHAT_3_5 = `GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:`;

export const ZEPHYR = `<|system|>
{system_message}</s>
<|user|>
{prompt}</s>
<|assistant|>
`;

// Cohere Command R style turn tokens.
export const COMMAND_R = `<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{response}
`;

// Default template; taken from https://huggingface.co/TheBloke/Llama-2-70B-Chat-GGUF
export const LLAMA_2 = `[INST] <<SYS>>
{system_message}
<</SYS>>
{prompt}[/INST]`;
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { CommandRunner, SubCommand, Option } from 'nest-commander';
import { exit } from 'node:process';
import { ChatUsecases } from '@/usecases/chat/chat.usecases';
import { ChatCliUsecases } from '../usecases/chat.cli.usecases';
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';

type RunOptions = {
model?: string;
Expand All @@ -29,7 +30,11 @@ export class RunCommand extends CommandRunner {
exit(1);
}

await this.cortexUsecases.startCortex();
await this.cortexUsecases.startCortex(
defaultCortexCppHost,
defaultCortexCppPort,
false,
);
await this.modelsUsecases.startModel(modelId);
const chatCliUsecases = new ChatCliUsecases(
this.chatUsecases,
Expand All @@ -39,7 +44,7 @@ export class RunCommand extends CommandRunner {
}

@Option({
flags: '--model <model_id>',
flags: '-m, --model <model_id>',
description: 'Model Id to start chat with',
})
parseModelId(value: string) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import { exit } from 'node:process';
import { ModelsUsecases } from '@/usecases/models/models.usecases';
import { Model, ModelFormat } from '@/domain/models/model.interface';
import {
Model,
ModelFormat,
ModelRuntimeParams,
ModelSettingParams,
} from '@/domain/models/model.interface';
import { CreateModelDto } from '@/infrastructure/dtos/models/create-model.dto';
import { HuggingFaceRepoData } from '@/domain/models/huggingface.interface';
import { gguf } from '@huggingface/gguf';
import { InquirerService } from 'nest-commander';
import { Inject, Injectable } from '@nestjs/common';
import { Presets, SingleBar } from 'cli-progress';
import {
LLAMA_2,
OPEN_CHAT_3_5,
OPEN_CHAT_3_5_JINJA,
ZEPHYR,
ZEPHYR_JINJA,
} from '../prompt-constants';

const AllQuantizations = [
'Q3_K_S',
Expand Down Expand Up @@ -49,6 +61,20 @@ export class ModelsCliUsecases {
await this.modelsUsecases.stopModel(modelId);
}

/**
 * Persists new engine/setting parameters for the given model.
 * Thin CLI-side delegation to the core models usecases.
 */
async updateModelSettingParams(
  modelId: string,
  settingParams: ModelSettingParams,
): Promise<ModelSettingParams> {
  const updated = await this.modelsUsecases.updateModelSettingParams(
    modelId,
    settingParams,
  );
  return updated;
}

/**
 * Persists new runtime (inference-time) parameters for the given model.
 * Thin CLI-side delegation to the core models usecases.
 */
async updateModelRuntimeParams(
  modelId: string,
  runtimeParams: ModelRuntimeParams,
): Promise<ModelRuntimeParams> {
  const updated = await this.modelsUsecases.updateModelRuntimeParams(
    modelId,
    runtimeParams,
  );
  return updated;
}

private async getModelOrStop(modelId: string): Promise<Model> {
const model = await this.modelsUsecases.findOne(modelId);
if (!model) {
Expand Down Expand Up @@ -103,10 +129,16 @@ export class ModelsCliUsecases {
if (!sibling) throw 'No expected quantization found';

let stopWord = '';
let promptTemplate = LLAMA_2;

try {
const { metadata } = await gguf(sibling.downloadUrl!);
// @ts-expect-error "tokenizer.ggml.eos_token_id"
const index = metadata['tokenizer.ggml.eos_token_id'];
// @ts-expect-error "tokenizer.ggml.eos_token_id"
const hfChatTemplate = metadata['tokenizer.chat_template'];
promptTemplate = this.guessPromptTemplateFromHuggingFace(hfChatTemplate);

// @ts-expect-error "tokenizer.ggml.tokens"
stopWord = metadata['tokenizer.ggml.tokens'][index] ?? '';
} catch (err) {
Expand All @@ -129,7 +161,9 @@ export class ModelsCliUsecases {
version: '',
format: ModelFormat.GGUF,
description: '',
settings: {},
settings: {
prompt_template: promptTemplate,
},
parameters: {
stop: stopWords,
},
Expand All @@ -144,6 +178,37 @@ export class ModelsCliUsecases {
await this.modelsUsecases.create(model);
}

// TODO: move this to somewhere else, should be reused by API as well. Maybe in a separate service / provider?
/**
 * Maps a HuggingFace `tokenizer.chat_template` jinja string onto one of our
 * known prompt templates, falling back to LLAMA_2 whenever the input is
 * missing, not a string, or unrecognized.
 */
private guessPromptTemplateFromHuggingFace(jinjaCode?: string): string {
  // Guard clause: undefined / null / empty string all mean "no template".
  if (!jinjaCode) {
    console.log('No jinja code provided. Returning default LLAMA_2');
    return LLAMA_2;
  }

  // gguf metadata is untyped at runtime, so a truthy non-string can reach us.
  if (typeof jinjaCode !== 'string') {
    console.log(
      `Invalid jinja code provided (type is ${typeof jinjaCode}). Returning default LLAMA_2`,
    );
    return LLAMA_2;
  }

  if (jinjaCode === ZEPHYR_JINJA) {
    return ZEPHYR;
  }

  // NOTE(review): OPEN_CHAT_3_5_JINJA is currently an empty string, so this
  // branch can never match — empty input is rejected by the falsy guard
  // above. Fill in the real OpenChat template in prompt-constants to enable it.
  if (jinjaCode === OPEN_CHAT_3_5_JINJA) {
    return OPEN_CHAT_3_5;
  }

  console.log(
    'Unknown jinja code:',
    jinjaCode,
    'Returning default LLAMA_2',
  );
  return LLAMA_2;
}

private async fetchHuggingFaceRepoData(repoId: string) {
const sanitizedUrl = this.toHuggingFaceUrl(repoId);

Expand Down
Loading

0 comments on commit 050fe6e

Please sign in to comment.