From 554ceb64556f435de429de5c4b1e3b3a1983cc8d Mon Sep 17 00:00:00 2001 From: Jacky Jiang Date: Wed, 24 Jul 2024 01:23:12 +1000 Subject: [PATCH] add support of changing default serving model --- README.md | 151 +++++++++--------- .../templates/configmap.yaml | 14 +- .../templates/deployment.yaml | 24 ++- deploy/magda-embedding-api/values.yaml | 20 +++ package.json | 2 + src/app.ts | 5 +- src/libs/EmbeddingGenerator.ts | 120 ++++++++++++-- src/plugins/loadAppConfig.ts | 29 ++++ src/plugins/setupEmbeddingGenerator.ts | 6 +- src/routes/v1/embeddings/index.ts | 11 +- test/integration.test.ts | 5 +- test/plugins/setupEmbeddingGenerator.test.ts | 4 + test/routes/v1/embeddings/index.test.ts | 7 +- yarn.lock | 17 ++ 14 files changed, 284 insertions(+), 131 deletions(-) create mode 100644 src/plugins/loadAppConfig.ts diff --git a/README.md b/README.md index 4e2c9c8..07bb7af 100644 --- a/README.md +++ b/README.md @@ -36,81 +36,82 @@ Kubernetes: `>= 1.21.0` ## Values -| Key | Type | Default | Description | -| ---------------------------------- | ------ | ------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| affinity | object | `{}` | | -| autoscaling.hpa.enabled | bool | `false` | | -| autoscaling.hpa.maxReplicas | int | `3` | | -| autoscaling.hpa.minReplicas | int | `1` | | -| autoscaling.hpa.targetCPU | int | `90` | | -| autoscaling.hpa.targetMemory | string | `""` | | -| bodyLimit | int | Default to 10485760 (10MB). | Defines the maximum payload, in bytes, that the server is allowed to accept | -| closeGraceDelay | int | Default to 25000 (25s). | The maximum amount of time before forcefully closing pending requests. This should set to a value lower than the Pod's termination grace period (which is default to 30s) | -| debug | bool | `false` | Start Fastify app in debug mode with nodejs inspector inspector port is 9320 | -| defaultImage.imagePullSecret | bool | `false` | | -| defaultImage.pullPolicy | string | `"IfNotPresent"` | | -| defaultImage.repository | string | `"ghcr.io/magda-io"` | | -| deploymentAnnotations | object | `{}` | | -| envFrom | list | `[]` | | -| extraContainers | string | `""` | | -| extraEnvs | list | `[]` | | -| extraInitContainers | string | `""` | | -| extraVolumeMounts | list | `[]` | | -| extraVolumes | list | `[]` | | -| fullnameOverride | string | `""` | | -| global.image | object | `{}` | | -| global.rollingUpdate | object | `{}` | | -| hostAliases | list | `[]` | | -| image.name | string | `"magda-embedding-api"` | | -| lifecycle | object | `{}` | pod lifecycle policies as outlined here: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks | -| livenessProbe.failureThreshold | int | `10` | | -| livenessProbe.httpGet.path | string | `"/status/liveness"` | | -| livenessProbe.httpGet.port | int | `3000` | | -| livenessProbe.initialDelaySeconds | int | `10` | | -| livenessProbe.periodSeconds | int | `20` | | -| livenessProbe.successThreshold | int | `1` | | -| livenessProbe.timeoutSeconds | int | `5` | | -| logLevel | string | `info`. | The log level of the application. one of 'fatal', 'error', 'warn', 'info', 'debug', 'trace'; also 'silent' is supported to disable logging. 
Any other value defines a custom level and requires supplying a level value via levelVal. | -| nameOverride | string | `""` | | -| nodeSelector | object | `{}` | | -| pluginTimeout | int | Default to 10000 (10 seconds). | The maximum amount of time in milliseconds in which a fastify plugin can load. If not, ready will complete with an Error with code 'ERR_AVVIO_PLUGIN_TIMEOUT'. | -| podAnnotations | object | `{}` | | -| podSecurityContext.runAsUser | int | `1000` | | -| priorityClassName | string | `"magda-9"` | | -| rbac.automountServiceAccountToken | bool | `false` | Controls whether or not the Service Account token is automatically mounted to /var/run/secrets/kubernetes.io/serviceaccount | -| rbac.create | bool | `false` | | -| rbac.serviceAccountAnnotations | object | `{}` | | -| rbac.serviceAccountName | string | `""` | | -| readinessProbe.failureThreshold | int | `10` | | -| readinessProbe.httpGet.path | string | `"/status/readiness"` | | -| readinessProbe.httpGet.port | int | `3000` | | -| readinessProbe.initialDelaySeconds | int | `10` | | -| readinessProbe.periodSeconds | int | `20` | | -| readinessProbe.successThreshold | int | `1` | | -| readinessProbe.timeoutSeconds | int | `5` | | -| replicas | int | `1` | | -| resources.limits.memory | string | `"2000M"` | the memory limit of the container Due to [this issue of ONNX runtime](https://github.com/microsoft/onnxruntime/issues/15080), the peak memory usage of the service is much higher than the model file size. When change the default model, be sure to test the peak memory usage of the service before setting the memory limit. | -| resources.requests.cpu | string | `"100m"` | | -| resources.requests.memory | string | `"850M"` | | -| service.annotations | object | `{}` | | -| service.httpPortName | string | `"http"` | | -| service.labels | object | `{}` | | -| service.loadBalancerIP | string | `""` | | -| service.loadBalancerSourceRanges | list | `[]` | | -| service.name | string | `"magda-embedding-api"` | | -| service.nodePort | string | `""` | | -| service.port | int | `80` | | -| service.targetPort | int | `3000` | | -| service.type | string | `"ClusterIP"` | | -| startupProbe.failureThreshold | int | `30` | | -| startupProbe.httpGet.path | string | `"/status/startup"` | | -| startupProbe.httpGet.port | int | `3000` | | -| startupProbe.initialDelaySeconds | int | `10` | | -| startupProbe.periodSeconds | int | `10` | | -| startupProbe.successThreshold | int | `1` | | -| startupProbe.timeoutSeconds | int | `5` | | -| tolerations | list | `[]` | | -| topologySpreadConstraints | list | `[]` | This is the pod topology spread constraints https://kubernetes.io/docs/concepts/workloads/pods/pod-topology-spread-constraints/ | +| Key | Type | Default | Description | +| ---------------------------------- | ------ | ------------------------------ | 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| affinity | object | `{}` | |
+| appConfig | object | `{}` | Application configuration of the service. You can supply a list of key-value pairs to be used as the application configuration. Currently, the only supported config field is `modelList`. Via the `modelList` field, you can specify a list of embedding models that the service supports. Although you can specify multiple models, only one model will be used at this moment. Each model item has the following fields: - `name` (string): The Hugging Face model name. Only ONNX models are supported at this moment. This field is required. - `default` (bool): Optional; whether this model is the default model. If no model is marked as default, the first model in the list will be the default model. Only the default model will be loaded. - `quantized` (bool): Optional; whether the quantized version of the model will be used. If not specified, the quantized version will be loaded. - `config` (object): Optional; the configuration object that will be passed to the model. - `cache_dir` (string): Optional; the cache directory for downloaded models. If not specified, the default cache directory will be used. - `local_files_only` (bool): Optional; whether to only load the model from local files. If not specified, the model will be downloaded from the Hugging Face model hub. - `revision` (string): Optional. - `model_file_name` (string): Optional. Please note: the released docker image only contains the "Alibaba-NLP/gte-base-en-v1.5" model. If you specify other models, the server will download them from the Hugging Face model hub at startup, so you might want to adjust the `startupProbe` settings to accommodate the model download time. Depending on the model size, you might also want to adjust the `resources.limits.memory` & `resources.requests.memory` values. See the example below this table. |
+| autoscaling.hpa.enabled | bool | `false` | |
+| autoscaling.hpa.maxReplicas | int | `3` | |
+| autoscaling.hpa.minReplicas | int | `1` | |
+| autoscaling.hpa.targetCPU | int | `90` | |
+| autoscaling.hpa.targetMemory | string | `""` | |
+| bodyLimit | int | Default to 10485760 (10MB). | Defines the maximum payload, in bytes, that the server is allowed to accept |
+| closeGraceDelay | int | Default to 25000 (25s). | The maximum amount of time before forcefully closing pending requests. This should be set to a value lower than the Pod's termination grace period (which defaults to 30s) |
+| debug | bool | `false` | Start the Fastify app in debug mode with the Node.js inspector; the inspector port is 9320 |
+| defaultImage.imagePullSecret | bool | `false` | |
+| defaultImage.pullPolicy | string | `"IfNotPresent"` | |
+| defaultImage.repository | string | `"ghcr.io/magda-io"` | |
+| deploymentAnnotations | object | `{}` | |
+| envFrom | list | `[]` | |
+| extraContainers | string | `""` | |
+| extraEnvs | list | `[]` | |
+| extraInitContainers | string | `""` | |
+| extraVolumeMounts | list | `[]` | |
+| extraVolumes | list | `[]` | |
+| fullnameOverride | string | `""` | |
+| global.image | object | `{}` | |
+| global.rollingUpdate | object | `{}` | |
+| hostAliases | list | `[]` | |
+| image.name | string | `"magda-embedding-api"` | |
+| lifecycle | object | `{}` | pod lifecycle policies as outlined here: https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/#container-hooks |
+| livenessProbe.failureThreshold | int | `10` | |
+| livenessProbe.httpGet.path | string | `"/status/liveness"` | |
+| livenessProbe.httpGet.port | int | `3000` | |
+| livenessProbe.initialDelaySeconds | int | `10` | |
+| livenessProbe.periodSeconds | int | `20` | |
+| livenessProbe.successThreshold | int | `1` | |
+| livenessProbe.timeoutSeconds | int | `5` | |
+| logLevel | string | `info`. | The log level of the application. One of 'fatal', 'error', 'warn', 'info', 'debug', 'trace'; 'silent' is also supported to disable logging. Any other value defines a custom level and requires supplying a level value via levelVal. |
+| nameOverride | string | `""` | |
+| nodeSelector | object | `{}` | |
+| pluginTimeout | int | Default to 10000 (10 seconds). | The maximum amount of time in milliseconds in which a fastify plugin can load. If a plugin does not load in time, `ready` will complete with an Error with code 'ERR_AVVIO_PLUGIN_TIMEOUT'. |
+| podAnnotations | object | `{}` | |
+| podSecurityContext.runAsUser | int | `1000` | |
+| priorityClassName | string | `"magda-9"` | |
+| rbac.automountServiceAccountToken | bool | `false` | Controls whether or not the Service Account token is automatically mounted to /var/run/secrets/kubernetes.io/serviceaccount |
+| rbac.create | bool | `false` | |
+| rbac.serviceAccountAnnotations | object | `{}` | |
+| rbac.serviceAccountName | string | `""` | |
+| readinessProbe.failureThreshold | int | `10` | |
+| readinessProbe.httpGet.path | string | `"/status/readiness"` | |
+| readinessProbe.httpGet.port | int | `3000` | |
+| readinessProbe.initialDelaySeconds | int | `10` | |
+| readinessProbe.periodSeconds | int | `20` | |
+| readinessProbe.successThreshold | int | `1` | |
+| readinessProbe.timeoutSeconds | int | `5` | |
+| replicas | int | `1` | |
+| resources.limits.memory | string | `"2000M"` | The memory limit of the container. Due to [this issue of ONNX runtime](https://github.com/microsoft/onnxruntime/issues/15080), the peak memory usage of the service is much higher than the model file size. When changing the default model, be sure to test the peak memory usage of the service before setting the memory limit. |
+| resources.requests.cpu | string | `"100m"` | |
+| resources.requests.memory | string | `"850M"` | |
+| service.annotations | object | `{}` | |
+| service.httpPortName | string | `"http"` | |
+| service.labels | object | `{}` | |
+| service.loadBalancerIP | string | `""` | |
+| service.loadBalancerSourceRanges | list | `[]` | |
+| service.name | string | `"magda-embedding-api"` | |
+| service.nodePort | string | `""` | |
+| service.port | int | `80` | |
+| service.targetPort | int | `3000` | |
+| service.type | string | `"ClusterIP"` | |
+| startupProbe.failureThreshold | int | `30` | |
+| startupProbe.httpGet.path | string | `"/status/startup"` | |
+| startupProbe.httpGet.port | int | `3000` | |
+| startupProbe.initialDelaySeconds | int | `10` | |
+| startupProbe.periodSeconds | int | `10` | |
+| startupProbe.successThreshold | int | `1` | |
+| startupProbe.timeoutSeconds | int | `5` | |
+| tolerations | list | `[]` | |
+| topologySpreadConstraints | list | `[]` | Pod topology spread constraints: https://kubernetes.io/docs/concepts/workloads/pods/pod-topology-spread-constraints/ |
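+
+For example, to change the default serving model, you could supply an `appConfig` value like the following (a sketch only; `Xenova/all-MiniLM-L6-v2` stands in for whatever ONNX model you choose and, not being bundled in the docker image, would be downloaded from the Hugging Face model hub at startup):
+
+```yaml
+appConfig:
+  modelList:
+    - name: "Alibaba-NLP/gte-base-en-v1.5"
+      quantized: false
+    - name: "Xenova/all-MiniLM-L6-v2"
+      default: true
+```
+
+Only the model marked `default: true` (or, failing that, the first item) is loaded and served.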

### Build & Run for Local Development

diff --git a/deploy/magda-embedding-api/templates/configmap.yaml b/deploy/magda-embedding-api/templates/configmap.yaml
index 23213d3..a7b8577 100644
--- a/deploy/magda-embedding-api/templates/configmap.yaml
+++ b/deploy/magda-embedding-api/templates/configmap.yaml
@@ -1,18 +1,8 @@
-{{- $root := . }}
-{{- if .Values.config }}
 apiVersion: v1
 kind: ConfigMap
 metadata:
-  name: "{{ template "magda.fullname" . }}-config"
+  name: "{{ template "magda.fullname" . }}-app-config"
   labels: {{- include "magda.common.labels.standard" . | nindent 4 }}
 data:
-{{- range $configName, $configYaml := .Values.config }}
-  {{ $configName }}: |
-    {{- if (eq (kindOf $configYaml) "map")}}
-    {{- tpl (toYaml $configYaml) $root | nindent 4 }}
-    {{- else -}}
-    {{- tpl $configYaml $root | nindent 4 }}
-    {{- end -}}
-{{- end -}}
-{{- end -}}
\ No newline at end of file
+  appConfig.json: {{ .Values.appConfig | mustToRawJson | quote }}
\ No newline at end of file
diff --git a/deploy/magda-embedding-api/templates/deployment.yaml b/deploy/magda-embedding-api/templates/deployment.yaml
index 6536c6f..322e3c7 100644
--- a/deploy/magda-embedding-api/templates/deployment.yaml
+++ b/deploy/magda-embedding-api/templates/deployment.yaml
@@ -28,9 +28,7 @@ spec:
         {{ $key }}: {{ $value | quote }}
       {{- end }}
       {{- /* This forces a restart if the configmap has changed */}}
-      {{- if .Values.config }}
-      configchecksum: {{ include (print .Template.BasePath "/configmap.yaml") . | sha256sum | trunc 63 }}
-      {{- end }}
+      appconfigchecksum: {{ include (print .Template.BasePath "/configmap.yaml") . | sha256sum | trunc 63 }}
      {{- end }}
    spec:
 {{- if and .Values.priorityClassName .Values.global.enablePriorityClass }}
@@ -51,14 +49,12 @@ spec:
       hostAliases: {{ toYaml .Values.hostAliases | nindent 6 }}
      {{- end }}
       volumes:
-        {{- if .Values.config }}
-        - name: config
+        - name: app-config
           configMap:
-            name: "{{ template "magda.fullname" . }}-config"
-            {{- if .Values.opensearchDashboardsYml.defaultMode }}
-            defaultMode: {{ .Values.opensearchDashboardsYml.defaultMode }}
-            {{- end }}
-        {{- end }}
+            name: "{{ template "magda.fullname" . }}-app-config"
+            items:
+              - key: appConfig.json
+                path: appConfig.json
       {{- if .Values.extraVolumes }}
       # Currently some extra blocks accept strings
       # to continue with backwards compatibility this is being kept
@@ -157,11 +153,9 @@ spec:
         resources:
{{ toYaml .Values.resources | indent 10 }}
         volumeMounts:
-          {{- range $path, $config := .Values.config }}
-          - name: config
-            mountPath: /usr/share/opensearch-dashboards/config/{{ $path }}
-            subPath: {{ $path }}
-          {{- end }}
+          - name: app-config
+            mountPath: /etc/config/appConfig.json
+            subPath: appConfig.json
       {{- if .Values.extraVolumeMounts }}
       # Currently some extra blocks accept strings
       # to continue with backwards compatibility this is being kept
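For illustration, with a one-model `appConfig` like the README example and a hypothetical release fullname of `magda-embedding-api`, the template would render roughly this ConfigMap (labels omitted):

```yaml
apiVersion: v1
kind: ConfigMap
metadata:
  name: "magda-embedding-api-app-config"
data:
  appConfig.json: "{\"modelList\":[{\"name\":\"Xenova/all-MiniLM-L6-v2\",\"default\":true}]}"
```

The deployment mounts this single key at `/etc/config/appConfig.json` via `subPath`, and the `appconfigchecksum` annotation forces a rolling restart whenever the rendered config changes.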
diff --git a/deploy/magda-embedding-api/values.yaml b/deploy/magda-embedding-api/values.yaml
index a6d8c54..e7cd567 100644
--- a/deploy/magda-embedding-api/values.yaml
+++ b/deploy/magda-embedding-api/values.yaml
@@ -27,6 +27,26 @@ bodyLimit: 10485760
 # @default -- Default to 25000 (25s).
 closeGraceDelay: 25000
 
+# -- (object) Application configuration of the service.
+# You can supply a list of key-value pairs to be used as the application configuration.
+# Currently, the only supported config field is `modelList`.
+# Via the `modelList` field, you can specify a list of embedding models that the service supports.
+# Although you can specify multiple models, only one model will be used at this moment.
+# Each model item has the following fields:
+# - `name` (string): The Hugging Face model name. Only ONNX models are supported at this moment. This field is required.
+# - `default` (bool): Optional; whether this model is the default model. If no model is marked as default, the first model in the list will be the default model. Only the default model will be loaded.
+# - `quantized` (bool): Optional; whether the quantized version of the model will be used. If not specified, the quantized version will be loaded.
+# - `config` (object): Optional; the configuration object that will be passed to the model.
+# - `cache_dir` (string): Optional; the cache directory for downloaded models. If not specified, the default cache directory will be used.
+# - `local_files_only` (bool): Optional; whether to only load the model from local files. If not specified, the model will be downloaded from the Hugging Face model hub.
+# - `revision` (string): Optional.
+# - `model_file_name` (string): Optional.
+# Please note: the released docker image only contains the "Alibaba-NLP/gte-base-en-v1.5" model.
+# If you specify other models, the server will download them from the Hugging Face model hub at startup.
+# You might want to adjust the `startupProbe` settings to accommodate the model download time.
+# Depending on the model size, you might also want to adjust the `resources.limits.memory` & `resources.requests.memory` values.
+appConfig: {}
+
 # image setting loadding order: (from higher priority to lower priority)
 # - Values.image.x
 # - Values.defaultImage.x
diff --git a/package.json b/package.json
index 67cb076..84560ed 100644
--- a/package.json
+++ b/package.json
@@ -46,12 +46,14 @@
         "fastify": "^4.28.1",
         "fastify-cli": "^6.2.1",
         "fastify-plugin": "^4.5.1",
+        "fs-extra": "^11.2.0",
         "onnxruntime-node": "^1.14.0"
     },
     "devDependencies": {
         "@langchain/openai": "^0.2.1",
         "@magda/ci-utils": "^1.0.5",
         "@magda/docker-utils": "^4.2.1",
+        "@types/fs-extra": "^11.0.4",
         "@types/node": "^18.19.3",
         "concurrently": "^8.2.2",
         "eslint": "^9.6.0",
diff --git a/src/app.ts b/src/app.ts
index f773b4f..1648d6a 100644
--- a/src/app.ts
+++ b/src/app.ts
@@ -8,11 +8,14 @@ const __filename = fileURLToPath(import.meta.url);
 const __dirname = path.dirname(__filename);
 
 export type AppOptions = {
+    appConfigFile: string;
     // Place your custom options for app below here.
 } & Partial<AutoloadPluginOptions>;
 
 // Pass --options via CLI arguments in command to enable these options.
-const options: AppOptions = {};
+const options: AppOptions = {
+    appConfigFile: ""
+};
 
 const app: FastifyPluginAsync<AppOptions> = async (
     fastify,
diff --git a/src/libs/EmbeddingGenerator.ts b/src/libs/EmbeddingGenerator.ts
index 7386f99..d7ee856 100644
--- a/src/libs/EmbeddingGenerator.ts
+++ b/src/libs/EmbeddingGenerator.ts
@@ -11,8 +11,25 @@ import {
     PretrainedOptions
 } from "@xenova/transformers";
 
-export const supportModels = ["Alibaba-NLP/gte-base-en-v1.5"];
-export const defaultModel = "Alibaba-NLP/gte-base-en-v1.5";
+export const defaultModel = {
+    name: "Alibaba-NLP/gte-base-en-v1.5",
+    quantized: false
+};
+
+export interface ModelItem {
+    name: string;
+    // whether or not this model is the default model
+    // if no model is marked as default, the first one will be used as the default
+    default?: boolean;
+    quantized?: boolean | null;
+    config?: any;
+    cache_dir?: string;
+    local_files_only?: boolean;
+    revision?: string;
+    model_file_name?: string;
+}
+
+export type ModelListItem = string | ModelItem;
 
 class EmbeddingGenerator {
     protected ready: boolean = false;
@@ -21,20 +38,78 @@ class EmbeddingGenerator {
     private tokenizer: PreTrainedTokenizer | null = null;
     private model: PreTrainedModel | null = null;
 
+    private supportModelNames: string[] = [];
+    // although we allow users to pass in a list of models, only the default model is used for now
+    private defaultModel: string = "";
+    private modelList: ModelListItem[] = [defaultModel];
+
     private pipelineLayout = {
         tokenizer: AutoTokenizer,
-        model: AutoModel,
-        default: {
-            model: defaultModel
-        }
+        model: AutoModel
     };
 
-    constructor() {
+    constructor(modelList: ModelListItem[] = []) {
+        if (modelList?.length) {
+            this.modelList = [...modelList];
+        }
+        this.processModelList();
         this.readPromise = this.init();
     }
 
+    /**
+     * Process this.modelList and set this.defaultModel and this.supportModelNames
+     *
+     * @private
+     * @memberof EmbeddingGenerator
+     */
+    private processModelList() {
+        const modelNames: string[] = [];
+        let defaultModel: string = "";
+
+        for (const model of this.modelList) {
+            if (typeof model === "string") {
+                modelNames.push(model);
+            } else {
+                if (typeof model.name !== "string") {
+                    throw new Error(
+                        "Invalid model list supplied: when a list item is not a string, it must contain a string `name` field"
+                    );
+                }
+                modelNames.push(model.name);
+                if (model?.default === true) {
+                    defaultModel = model.name;
+                }
+            }
+        }
+        if (!defaultModel) {
+            defaultModel = modelNames[0];
+        }
+
+        this.defaultModel = defaultModel;
+        this.supportModelNames = modelNames;
+    }
+
+    private getModelByName(modelName: string): ModelItem {
+        for (const model of this.modelList) {
+            if (typeof model === "string") {
+                if (model === modelName) {
+                    return {
+                        name: modelName
+                    };
+                }
+            } else if (model?.name === modelName) {
+                return model;
+            }
+        }
+        throw new Error(`Model \`${modelName}\` not found`);
+    }
+
     get supportModels() {
-        return supportModels;
+        return this.supportModelNames;
+    }
+
+    get defaultModelName() {
+        return this.defaultModel;
     }
 
     isReady() {
@@ -49,8 +124,8 @@ class EmbeddingGenerator {
         model: string | null = null,
         pretrainedOptions: PretrainedOptions = {}
     ) {
-        if (model === null) {
-            model = this.pipelineLayout.default.model;
+        if (!model) {
+            model = this.defaultModel;
         }
 
         const defaultPretrainedOptions = {
@@ -129,19 +204,34 @@ class EmbeddingGenerator {
 
     protected async init() {
         // Create feature extraction pipeline
-        await this.createPipeline("Alibaba-NLP/gte-base-en-v1.5", {
-            quantized: false // Comment out this line to use the quantized version
-        });
+        const { name: modelName, ...modelOpts } = this.getModelByName(
+            this.defaultModel
+        );
+        await this.createPipeline(modelName, modelOpts);
         this.ready = true;
     }
 
+    async switchModel(model: string = this.defaultModel) {
+        if (model && this.supportModels.indexOf(model) === -1) {
+            throw new Error(
+                `Model \`${model}\` is not supported. Supported models: ${this.supportModels.join(", ")}`
+            );
+        }
+        const { name: modelName, ...modelOpts } = this.getModelByName(model);
+        await this.dispose();
+        await this.createPipeline(modelName, modelOpts);
+    }
+
     async dispose() {
         if (this.model) {
             await this.model.dispose();
         }
     }
 
-    async generate(sentences: string | string[], model: string = defaultModel) {
+    async generate(
+        sentences: string | string[],
+        model: string = this.defaultModel
+    ) {
         if (model && this.supportModels.indexOf(model) === -1) {
             throw new Error(
                 `Model \`${model}\` is not supported. Supported models: ${this.supportModels.join(", ")}`
             );
@@ -160,7 +250,7 @@ class EmbeddingGenerator {
 
     async tokenize(
         texts: string | string[],
-        model: string = defaultModel,
+        model: string = this.defaultModel,
         opts: {
             text_pair?: string | string[];
             padding?: boolean | "max_length";
diff --git a/src/plugins/loadAppConfig.ts b/src/plugins/loadAppConfig.ts
new file mode 100644
index 0000000..15bf736
--- /dev/null
+++ b/src/plugins/loadAppConfig.ts
@@ -0,0 +1,29 @@
+import fp from "fastify-plugin";
+import fse from "fs-extra/esm";
+
+declare module "fastify" {
+    export interface FastifyInstance {
+        appConfig: {
+            [key: string]: any;
+        };
+    }
+}
+
+export interface SupportPluginOptions {
+    appConfigFile?: string;
+}
+
+export default fp<SupportPluginOptions>(
+    async (fastify, opts) => {
+        fastify.decorate("appConfig", {} as any);
+        if (opts?.appConfigFile) {
+            fastify.log.info(`Loading app config from ${opts.appConfigFile}`);
+            const appConfig = await fse.readJson(opts.appConfigFile);
+            fastify.appConfig = appConfig;
+        }
+    },
+    {
+        fastify: "4.x",
+        name: "loadAppConfig"
+    }
+);
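For illustration, the JSON file passed via the `appConfigFile` option (mounted at `/etc/config/appConfig.json` by the chart above) might look like the following sketch; both entries are illustrative, and items can be plain strings or `ModelItem` objects:

```json
{
    "modelList": [
        { "name": "Xenova/all-MiniLM-L6-v2", "default": true },
        "Alibaba-NLP/gte-base-en-v1.5"
    ]
}
```

Only the default model is loaded at startup.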
diff --git a/src/plugins/setupEmbeddingGenerator.ts b/src/plugins/setupEmbeddingGenerator.ts
index b3023e5..0a23eef 100644
--- a/src/plugins/setupEmbeddingGenerator.ts
+++ b/src/plugins/setupEmbeddingGenerator.ts
@@ -18,7 +18,9 @@ const WAIT_TIME_MS = 500;
 // to export the decorators to the outer scope
 export default fp(
     async (fastify, opts) => {
-        const embeddingGenerator = new EmbeddingGenerator();
+        const embeddingGenerator = new EmbeddingGenerator(
+            fastify?.appConfig?.modelList
+        );
         fastify.decorate("embeddingGenerator", embeddingGenerator);
 
         fastify.addHook("onRequest", function (request, reply, next) {
@@ -43,6 +45,6 @@ export default fp(
     {
         fastify: "4.x",
         name: "setupEmbeddingGenerator",
-        dependencies: ["@fastify/sensible"]
+        dependencies: ["@fastify/sensible", "loadAppConfig"]
     }
 );
diff --git a/src/routes/v1/embeddings/index.ts b/src/routes/v1/embeddings/index.ts
index 28cb026..e0352eb 100644
--- a/src/routes/v1/embeddings/index.ts
+++ b/src/routes/v1/embeddings/index.ts
@@ -1,11 +1,6 @@
 import { FastifyPluginAsync } from "fastify";
 import { Type } from "@sinclair/typebox";
-import { StringEnum } from "../../../libs/types.js";
 import { TypeBoxTypeProvider } from "@fastify/type-provider-typebox";
-import {
-    supportModels,
-    defaultModel
-} from "../../../libs/EmbeddingGenerator.js";
 
 const schemaEmebeddingObject = Type.Object({
     index: Type.Integer(),
@@ -22,7 +17,7 @@ const schema = {
         200: Type.Object({
             object: Type.Const("list"),
             data: Type.Array(schemaEmebeddingObject),
-            model: StringEnum(supportModels),
+            model: Type.String(),
             usage: Type.Object({
                 prompt_tokens: Type.Integer(),
                 total_tokens: Type.Integer()
@@ -38,6 +33,7 @@ const embeddings: FastifyPluginAsync = async (
     const fastify = fastifyInstance.withTypeProvider<TypeBoxTypeProvider>();
 
     fastify.post("/", { schema }, async function (request, reply) {
+        const supportModels = this.embeddingGenerator.supportModels;
         if (
             request.body.model &&
             supportModels.indexOf(request.body.model) === -1
@@ -46,7 +42,8 @@ const embeddings: FastifyPluginAsync = async (
                 `Model \`${request.body.model}\` is not supported. Supported models: ${supportModels.join(", ")}`
             );
         }
-        const model = request.body.model || defaultModel;
+        const model =
+            request.body.model || fastify.embeddingGenerator.defaultModelName;
         const inputItems = Array.isArray(request.body.input)
             ? request.body.input
             : [request.body.input];
diff --git a/test/integration.test.ts b/test/integration.test.ts
index af7c0e2..f0d58ce 100644
--- a/test/integration.test.ts
+++ b/test/integration.test.ts
@@ -1,7 +1,7 @@
 import { test } from "node:test";
 import * as assert from "node:assert";
-import { defaultModel } from "../src/libs/EmbeddingGenerator.js";
 import { build } from "./helper.js";
+import { defaultModel } from "../src/libs/EmbeddingGenerator.js";
 import { OpenAIEmbeddings } from "@langchain/openai";
 
 test("Should work with @langchain/openai", async (t) => {
@@ -10,7 +10,8 @@ test("Should work with @langchain/openai", async (t) => {
 
     const embeddings = new OpenAIEmbeddings({
         verbose: true,
-        model: defaultModel,
+        model:
+            typeof defaultModel === "string" ? defaultModel : defaultModel.name,
         configuration: {
             baseURL: `http://localhost:${fastify.server.address().port}/v1`
         },
diff --git a/test/plugins/setupEmbeddingGenerator.test.ts b/test/plugins/setupEmbeddingGenerator.test.ts
index 0381668..0433dfb 100644
--- a/test/plugins/setupEmbeddingGenerator.test.ts
+++ b/test/plugins/setupEmbeddingGenerator.test.ts
@@ -1,6 +1,7 @@
 import t from "tap";
 import Fastify from "fastify";
 import sensible from "@fastify/sensible";
+import loadAppConfig from "../../src/plugins/loadAppConfig.js";
 import type SetupEmbeddingGeneratorType from "../../src/plugins/setupEmbeddingGenerator.js";
 import { MockEmbeddingGenerator } from "../helper.js";
 
@@ -13,6 +14,7 @@ const SetupEmbeddingGenerator = await t.mockImport<
 t.test("should works for child plugin routes", async (t) => {
     const fastify = Fastify();
     fastify.register(sensible);
+    fastify.register(loadAppConfig);
     fastify.register(SetupEmbeddingGenerator);
     fastify.register(async (fastify, opts) => {
         fastify.get("/test", async function (request, reply) {
@@ -35,6 +37,7 @@ t.test(
     async (t) => {
         const fastify = Fastify();
         fastify.register(sensible);
+        fastify.register(loadAppConfig);
         fastify.register(SetupEmbeddingGenerator);
         fastify.register(async (fastify, opts) => {
             fastify.get("/test", async function (request, reply) {
@@ -63,6 +66,7 @@ t.test(
     async (t) => {
         const fastify = Fastify();
         fastify.register(sensible);
+        fastify.register(loadAppConfig);
         fastify.register(SetupEmbeddingGenerator);
         fastify.register(async (fastify, opts) => {
             fastify.get("/status/ready", async function (request, reply) {
diff --git a/test/routes/v1/embeddings/index.test.ts b/test/routes/v1/embeddings/index.test.ts
index 7bdb706..bd79a38 100644
--- a/test/routes/v1/embeddings/index.test.ts
+++ b/test/routes/v1/embeddings/index.test.ts
@@ -3,6 +3,9 @@ import * as assert from "node:assert";
 import { defaultModel } from "../../../../src/libs/EmbeddingGenerator.js";
 import { build } from "../../../helper.js";
 
+const 
defaultModelName = + typeof defaultModel === "string" ? defaultModel : defaultModel.name; + test("/v1/embeddings", async (t) => { const app = await build(t); @@ -11,7 +14,7 @@ test("/v1/embeddings", async (t) => { method: "POST", payload: { input: "This is a cake", - model: defaultModel + model: defaultModelName } }); assert.strictEqual(res.statusCode, 200); @@ -24,5 +27,5 @@ test("/v1/embeddings", async (t) => { for (let i = 0; i < resData.data[0].embedding.length; i++) { assert.equal(typeof resData.data[0].embedding[i], "number"); } - assert.equal(resData["model"], defaultModel); + assert.equal(resData["model"], defaultModelName); }); diff --git a/yarn.lock b/yarn.lock index ca9589f..4f1b7f0 100644 --- a/yarn.lock +++ b/yarn.lock @@ -835,6 +835,14 @@ resolved "https://registry.yarnpkg.com/@types/estree/-/estree-1.0.5.tgz#a6ce3e556e00fd9895dd872dd172ad0d4bd687f4" integrity sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw== +"@types/fs-extra@^11.0.4": + version "11.0.4" + resolved "https://registry.yarnpkg.com/@types/fs-extra/-/fs-extra-11.0.4.tgz#e16a863bb8843fba8c5004362b5a73e17becca45" + integrity sha512-yTbItCNreRooED33qjunPthRcSjERP1r4MqCZc7wv0u2sUkzTFp45tgUfS5+r7FrZPdmCCNflLhVSP/o+SemsQ== + dependencies: + "@types/jsonfile" "*" + "@types/node" "*" + "@types/istanbul-lib-coverage@^2.0.1": version "2.0.6" resolved "https://registry.yarnpkg.com/@types/istanbul-lib-coverage/-/istanbul-lib-coverage-2.0.6.tgz#7739c232a1fee9b4d3ce8985f314c0c6d33549d7" @@ -845,6 +853,13 @@ resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.15.tgz#596a1747233694d50f6ad8a7869fcb6f56cf5841" integrity sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA== +"@types/jsonfile@*": + version "6.1.4" + resolved "https://registry.yarnpkg.com/@types/jsonfile/-/jsonfile-6.1.4.tgz#614afec1a1164e7d670b4a7ad64df3e7beb7b702" + integrity sha512-D5qGUYwjvnNNextdU59/+fI+spnwtTFmyQP0h+PfIOSkNfpU6AOICUOkm4i0OnSk+NyjdPJrxCDro0sJsWlRpQ== + dependencies: + "@types/node" "*" + "@types/long@^4.0.1": version "4.0.2" resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.2.tgz#b74129719fc8d11c01868010082d483b7545591a" @@ -4188,6 +4203,7 @@ string-length@^6.0.0: strip-ansi "^7.1.0" "string-width-cjs@npm:string-width@^4.2.0", string-width@^4.1.0, string-width@^4.2.0, string-width@^4.2.3: + name string-width-cjs version "4.2.3" resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.3.tgz#269c7117d27b05ad2e536830a8ec895ef9c6d010" integrity sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g== @@ -4708,6 +4724,7 @@ word-wrap@^1.2.5: integrity sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA== "wrap-ansi-cjs@npm:wrap-ansi@^7.0.0", wrap-ansi@^7.0.0: + name wrap-ansi-cjs version "7.0.0" resolved "https://registry.yarnpkg.com/wrap-ansi/-/wrap-ansi-7.0.0.tgz#67e145cff510a6a6984bdf1152911d69d2eb9e43" integrity sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==
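To exercise the new behaviour end to end, a request body like the following (a sketch; host and port assumed) can be POSTed to `/v1/embeddings`. Since the response schema's `model` field is now `Type.String()` rather than `StringEnum(supportModels)`, model selection is validated in the handler against `embeddingGenerator.supportModels`, and an unsupported model name is rejected at runtime with the supported-model list in the error message.

```json
{
    "input": "This is a cake",
    "model": "Alibaba-NLP/gte-base-en-v1.5"
}
```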