Skip to content

Commit

Permalink
add support of changing default serving model
Browse files Browse the repository at this point in the history
  • Loading branch information
t83714 committed Jul 23, 2024
1 parent fbe2674 commit 554ceb6
Show file tree
Hide file tree
Showing 14 changed files with 284 additions and 131 deletions.
151 changes: 76 additions & 75 deletions README.md

Large diffs are not rendered by default.

14 changes: 2 additions & 12 deletions deploy/magda-embedding-api/templates/configmap.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
{{- $root := . }}
{{- if .Values.config }}
apiVersion: v1
kind: ConfigMap
metadata:
name: "{{ template "magda.fullname" . }}-config"
name: "{{ template "magda.fullname" . }}-appConfig"
labels:
{{- include "magda.common.labels.standard" . | nindent 4 }}
data:
{{- range $configName, $configYaml := .Values.config }}
{{ $configName }}: |
{{- if (eq (kindOf $configYaml) "map")}}
{{- tpl (toYaml $configYaml) $root | nindent 4 }}
{{- else -}}
{{- tpl $configYaml $root | nindent 4 }}
{{- end -}}
{{- end -}}
{{- end -}}
appConfig.json: {{ .Values.appConfig | mustToRawJson | quote }}
24 changes: 9 additions & 15 deletions deploy/magda-embedding-api/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ spec:
{{ $key }}: {{ $value | quote }}
{{- end }}
{{- /* This forces a restart if the configmap has changed */}}
{{- if .Values.config }}
configchecksum: {{ include (print .Template.BasePath "/configmap.yaml") . | sha256sum | trunc 63 }}
{{- end }}
appconfigchecksum: {{ include (print .Template.BasePath "/configmap.yaml") . | sha256sum | trunc 63 }}
{{- end }}
spec:
{{- if and .Values.priorityClassName .Values.global.enablePriorityClass }}
Expand All @@ -51,14 +49,12 @@ spec:
hostAliases: {{ toYaml .Values.hostAliases | nindent 6 }}
{{- end }}
volumes:
{{- if .Values.config }}
- name: config
- name: appConfig
configMap:
name: "{{ template "magda.fullname" . }}-config"
{{- if .Values.opensearchDashboardsYml.defaultMode }}
defaultMode: {{ .Values.opensearchDashboardsYml.defaultMode }}
{{- end }}
{{- end }}
name: "{{ template "magda.fullname" . }}-appConfig"
items:
- key: appConfig.json
path: appConfig.json
{{- if .Values.extraVolumes }}
# Currently some extra blocks accept strings
# to continue with backwards compatibility this is being kept
Expand Down Expand Up @@ -157,11 +153,9 @@ spec:
resources:
{{ toYaml .Values.resources | indent 10 }}
volumeMounts:
{{- range $path, $config := .Values.config }}
- name: config
mountPath: /usr/share/opensearch-dashboards/config/{{ $path }}
subPath: {{ $path }}
{{- end }}
- name: appConfig
mountPath: /etc/config/appConfig.json
subPath: appConfig.json
{{- if .Values.extraVolumeMounts }}
# Currently some extra blocks accept strings
# to continue with backwards compatibility this is being kept
Expand Down
20 changes: 20 additions & 0 deletions deploy/magda-embedding-api/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ bodyLimit: 10485760
# @default -- Default to 25000 (25s).
closeGraceDelay: 25000

# -- (object) Application configuration of the service.
# You can supply a list of key-value pairs to be used as the application configuration.
# Currently, the only supported config field is `modelList`.
# Via the `modelList` field, you can specify a list of LLM models that the service supports.
# Although you can specify multiple models, only one model will be used at this moment.
# Each model item has the following fields:
# - `name` (string): The huggingface registered model name. We only support ONNX models at this moment. This field is required.
# - `default` (bool): Optional; Whether this model is the default model. If not specified, the first model in the list will be the default model. Only the default model will be loaded.
# - `quantized` (bool): Optional; Whether the quantized version of the model will be used. If not specified, the quantized version of the model will be loaded.
# - `config` (object): Optional; The configuration object that will be passed to the model.
# - `cache_dir` (string): Optional; The cache directory for the downloaded models. If not specified, the default cache directory will be used.
# - `local_files_only` (bool): Optional; Whether to only load the model from local files. If not specified, the model will be downloaded from the huggingface model hub.
# - `revision` (string): Optional;
# - `model_file_name` (string): Optional;
# Please note: The released docker image only contains the "Alibaba-NLP/gte-base-en-v1.5" model.
# If you specify other models, the server will download the model from the huggingface model hub at startup.
# You might want to adjust the `startupProbe` settings to accommodate the model downloading time.
# Depending on the model size, you might also want to adjust the `resources.limits.memory` & `resources.requests.memory` values.
appConfig: {}

# image setting loading order: (from higher priority to lower priority)
# - Values.image.x
# - Values.defaultImage.x
Expand Down
2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@
"fastify": "^4.28.1",
"fastify-cli": "^6.2.1",
"fastify-plugin": "^4.5.1",
"fs-extra": "^11.2.0",
"onnxruntime-node": "^1.14.0"
},
"devDependencies": {
"@langchain/openai": "^0.2.1",
"@magda/ci-utils": "^1.0.5",
"@magda/docker-utils": "^4.2.1",
"@types/fs-extra": "^11.0.4",
"@types/node": "^18.19.3",
"concurrently": "^8.2.2",
"eslint": "^9.6.0",
Expand Down
5 changes: 4 additions & 1 deletion src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@ const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);

export type AppOptions = {
appConfigFile: string;
// Place your custom options for app below here.
} & Partial<AutoloadPluginOptions>;

// Pass --options via CLI arguments in command to enable these options.
const options: AppOptions = {};
const options: AppOptions = {
appConfigFile: ""
};

const app: FastifyPluginAsync<AppOptions> = async (
fastify,
Expand Down
120 changes: 105 additions & 15 deletions src/libs/EmbeddingGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,25 @@ import {
PretrainedOptions
} from "@xenova/transformers";

export const supportModels = ["Alibaba-NLP/gte-base-en-v1.5"];
export const defaultModel = "Alibaba-NLP/gte-base-en-v1.5";
// Built-in default model entry, used when no `modelList` is supplied via app config.
// `quantized: false` requests the full-precision ONNX weights for this model.
export const defaultModel = {
name: "Alibaba-NLP/gte-base-en-v1.5",
quantized: false
};

// One model entry as supplied through the app config `modelList` field.
// All fields other than `name` are optional loader options.
export interface ModelItem {
// The huggingface registered model name (ONNX models only).
name: string;
// whether or not this model is the default model
// if all models are not default, the first one will be used as default
default?: boolean;
// Whether the quantized weights should be used; when unset the loader's
// own default applies (presumably the quantized version) — verify.
quantized?: boolean | null;
// Extra configuration object passed through to the model loader.
config?: any;
// Cache directory for downloaded model files.
cache_dir?: string;
// When true, only load from local files; never download from the hub.
local_files_only?: boolean;
// Model revision (e.g. branch/tag) — presumably forwarded to the hub loader; verify.
revision?: string;
// Override for the model file name within the repo — TODO confirm semantics.
model_file_name?: string;
}

// A model list entry may be a bare model-name string or a full ModelItem object.
export type ModelListItem = string | ModelItem;

class EmbeddingGenerator {
protected ready: boolean = false;
Expand All @@ -21,20 +38,78 @@ class EmbeddingGenerator {
private tokenizer: PreTrainedTokenizer | null = null;
private model: PreTrainedModel | null = null;

private supportModelNames: string[] = [];
// although we allow users to pass in a list of models, only the first one (the default model) is used for now
private defaultModel: string = "";
private modelList: ModelListItem[] = [defaultModel];

private pipelineLayout = {
tokenizer: AutoTokenizer,
model: AutoModel,
default: {
model: defaultModel
}
model: AutoModel
};

constructor() {
constructor(modelList: ModelListItem[] = []) {
if (modelList?.length) {
this.modelList = [...modelList];
}
this.processModelList();
this.readPromise = this.init();
}

/**
* Validate this.modelList and derive this.defaultModel and
* this.supportModelNames from it.
*
* String entries are taken as model names directly; object entries must
* carry a string `name` field. An entry flagged `default: true` becomes
* the default (a later flagged entry overwrites an earlier one); when
* none is flagged, the first listed model is used as the default.
*
* @private
* @memberof EmbeddingGenerator
* @throws {Error} when an object entry lacks a string `name` field
*/
private processModelList() {
const modelNames: string[] = [];
let defaultModel: string = "";

for (const model of this.modelList) {
if (typeof model === "string") {
// bare string entry: the string itself is the model name
modelNames.push(model);
} else {
if (typeof model.name !== "string") {
throw new Error(
"Invalid model list supplied, when list item is not a string, it must contain a string `name` field"
);
}
modelNames.push(model.name);
if (model?.default === true) {
// later `default: true` entries overwrite earlier ones
defaultModel = model.name;
}
}
}
if (!defaultModel) {
// No explicit default: fall back to the first listed model.
// NOTE(review): if this.modelList were empty this would be undefined;
// presumably the constructor always leaves at least one entry — verify.
defaultModel = modelNames[0];
}

this.defaultModel = defaultModel;
this.supportModelNames = modelNames;
}

/**
* Look up the full ModelItem for `modelName` in this.modelList.
* Bare string entries match by value and are wrapped into a minimal
* ModelItem containing only the `name` field.
*
* @private
* @memberof EmbeddingGenerator
* @throws {Error} when no list entry matches `modelName`
*/
private getModelByName(modelName: string): ModelItem {
for (const model of this.modelList) {
if (typeof model === "string") {
if (model === modelName) {
return {
name: modelName
};
}
} else if (model?.name === modelName) {
return model;
}
}
throw new Error(`Model \`${modelName}\` not found`);
}

get supportModels() {
return supportModels;
return this.supportModelNames;
}

get defaultModelName() {
return this.defaultModel;
}

isReady() {
Expand All @@ -49,8 +124,8 @@ class EmbeddingGenerator {
model: string | null = null,
pretrainedOptions: PretrainedOptions = {}
) {
if (model === null) {
model = this.pipelineLayout.default.model;
if (!model) {
model = this.defaultModel;
}

const defaultPretrainedOptions = {
Expand Down Expand Up @@ -129,19 +204,34 @@ class EmbeddingGenerator {

protected async init() {
// Create feature extraction pipeline
await this.createPipeline("Alibaba-NLP/gte-base-en-v1.5", {
quantized: false // Comment out this line to use the quantized version
});
const { name: modelName, ...modelOpts } = this.getModelByName(
this.defaultModel
);
await this.createPipeline(modelName, modelOpts);
this.ready = true;
}

/**
* Dispose the currently loaded model and load `model` in its place.
* Defaults to the configured default model.
*
* NOTE(review): `this.ready` is not toggled during the switch, so a
* request arriving between dispose() and createPipeline() may hit a
* disposed model — confirm callers serialize access around this.
*
* @throws {Error} when `model` is not in the supported model list
*/
async switchModel(model: string = this.defaultModel) {
if (model && this.supportModels.indexOf(model) === -1) {
throw new Error(
`Model \`${model}\` is not supported. Supported models: ${this.supportModels.join(", ")}`
);
}
const { name: modelName, ...modelOpts } = this.getModelByName(model);
await this.dispose();
await this.createPipeline(modelName, modelOpts);
}

/**
* Release the resources held by the currently loaded model, if any.
* Safe to call when no model has been created yet.
*/
async dispose() {
if (this.model) {
await this.model.dispose();
}
}

async generate(sentences: string | string[], model: string = defaultModel) {
async generate(
sentences: string | string[],
model: string = this.defaultModel
) {
if (model && this.supportModels.indexOf(model) === -1) {
throw new Error(
`Model \`${model}\` is not supported. Supported models: ${this.supportModels.join(", ")}`
Expand All @@ -160,7 +250,7 @@ class EmbeddingGenerator {

async tokenize(
texts: string | string[],
model: string = defaultModel,
model: string = this.defaultModel,
opts: {
text_pair?: string | string[];
padding?: boolean | "max_length";
Expand Down
29 changes: 29 additions & 0 deletions src/plugins/loadAppConfig.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import fp from "fastify-plugin";
import fse from "fs-extra/esm";

// Augment fastify's instance type so `fastify.appConfig` is visible to
// every other plugin and route handler.
declare module "fastify" {
    export interface FastifyInstance {
        appConfig: {
            [key: string]: any;
        };
    }
}

export interface SupportPluginOptions {
    appConfigFile?: string;
}

/**
 * Plugin that exposes the application config as `fastify.appConfig`.
 * When `appConfigFile` is supplied, that JSON file is read at startup and
 * becomes the config object; otherwise `appConfig` stays an empty object.
 */
export default fp<SupportPluginOptions>(
    async (fastify, opts) => {
        // Always decorate first so `fastify.appConfig` exists even without a file.
        fastify.decorate("appConfig", {} as any);
        const configFilePath = opts?.appConfigFile;
        if (!configFilePath) {
            return;
        }
        fastify.log.info(`Loading app config from ${configFilePath}`);
        fastify.appConfig = await fse.readJson(configFilePath);
    },
    {
        fastify: "4.x",
        name: "loadAppConfig"
    }
);
6 changes: 4 additions & 2 deletions src/plugins/setupEmbeddingGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ const WAIT_TIME_MS = 500;
// to export the decorators to the outer scope
export default fp<SupportPluginOptions>(
async (fastify, opts) => {
const embeddingGenerator = new EmbeddingGenerator();
const embeddingGenerator = new EmbeddingGenerator(
fastify?.appConfig?.modelList
);
fastify.decorate("embeddingGenerator", embeddingGenerator);

fastify.addHook("onRequest", function (request, reply, next) {
Expand All @@ -43,6 +45,6 @@ export default fp<SupportPluginOptions>(
{
fastify: "4.x",
name: "setupEmbeddingGenerator",
dependencies: ["@fastify/sensible"]
dependencies: ["@fastify/sensible", "loadAppConfig"]
}
);
11 changes: 4 additions & 7 deletions src/routes/v1/embeddings/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import { FastifyPluginAsync } from "fastify";
import { Type } from "@sinclair/typebox";
import { StringEnum } from "../../../libs/types.js";
import { TypeBoxTypeProvider } from "@fastify/type-provider-typebox";
import {
supportModels,
defaultModel
} from "../../../libs/EmbeddingGenerator.js";

const schemaEmebeddingObject = Type.Object({
index: Type.Integer(),
Expand All @@ -22,7 +17,7 @@ const schema = {
200: Type.Object({
object: Type.Const("list"),
data: Type.Array(schemaEmebeddingObject),
model: StringEnum(supportModels),
model: Type.String(),
usage: Type.Object({
prompt_tokens: Type.Integer(),
total_tokens: Type.Integer()
Expand All @@ -38,6 +33,7 @@ const embeddings: FastifyPluginAsync = async (
const fastify = fastifyInstance.withTypeProvider<TypeBoxTypeProvider>();

fastify.post("/", { schema }, async function (request, reply) {
const supportModels = this.embeddingGenerator.supportModels;
if (
request.body.model &&
supportModels.indexOf(request.body.model) === -1
Expand All @@ -46,7 +42,8 @@ const embeddings: FastifyPluginAsync = async (
`Model \`${request.body.model}\` is not supported. Supported models: ${supportModels.join(", ")}`
);
}
const model = request.body.model || defaultModel;
const model =
request.body.model || fastify.embeddingGenerator.defaultModelName;
const inputItems = Array.isArray(request.body.input)
? request.body.input
: [request.body.input];
Expand Down
Loading

0 comments on commit 554ceb6

Please sign in to comment.