From 0af8659996273af33f265d8f36f322bccff28964 Mon Sep 17 00:00:00 2001
From: axel7083 <42176370+axel7083@users.noreply.github.com>
Date: Fri, 13 Dec 2024 18:36:29 +0100
Subject: [PATCH] feat: basic vllm support for hf cached models

Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com>
---
 packages/backend/package.json | 1 +
 .../backend/src/assets/inference-images.json | 3 +
 .../backend/src/managers/modelsManager.ts | 73 +++++++--
 packages/backend/src/studio.ts | 5 +
 packages/backend/src/workers/provider/VLLM.ts | 148 ++++++++++++++++++
 .../lib/table/model/ModelColumnName.svelte | 15 +-
 packages/shared/src/models/IInference.ts | 1 +
 pnpm-lock.yaml | 16 ++
 8 files changed, 247 insertions(+), 15 deletions(-)
 create mode 100644 packages/backend/src/workers/provider/VLLM.ts

diff --git a/packages/backend/package.json b/packages/backend/package.json
index a0f5a0148..7a76b1c7b 100644
--- a/packages/backend/package.json
+++ b/packages/backend/package.json
@@ -101,6 +101,7 @@
   },
   "dependencies": {
     "@huggingface/gguf": "^0.1.12",
+    "@huggingface/hub": "^0.21.0",
     "express": "^4.21.2",
     "express-openapi-validator": "^5.3.9",
     "isomorphic-git": "^1.27.2",
diff --git a/packages/backend/src/assets/inference-images.json b/packages/backend/src/assets/inference-images.json
index b6758bc32..6bdebeee6 100644
--- a/packages/backend/src/assets/inference-images.json
+++ b/packages/backend/src/assets/inference-images.json
@@ -6,5 +6,8 @@
     "default": "ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat@sha256:20734e9d60f047d27e4c9cf6a3b663e0627d48bd06d0a73b968f9d81c82de2f1",
     "cuda": "ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat-cuda@sha256:798acced911527254601d0e39a90c5a29ecad82755f28594bea9a587ea9e6043",
     "vulkan": "ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat-vulkan@sha256:22e11661fe66ace7c30b419703305b803eb937da10e19c23cb6767f03578256c"
+  },
+  "vllm": {
+    "default": "quay.io/rh-ee-astefani/vllm:cpu-1734105797"
   }
 }
diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts
index 5e42c0129..9336c798e 100644
--- a/packages/backend/src/managers/modelsManager.ts
+++ b/packages/backend/src/managers/modelsManager.ts
@@ -39,9 +39,14 @@ import { gguf } from '@huggingface/gguf';
 import type { PodmanConnection } from './podmanConnection';
 import { VMType } from '@shared/src/models/IPodman';
 import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry';
+import { InferenceType } from '@shared/src/models/IInference';
+import { scanCacheDir } from '@huggingface/hub';
+import { basename, join } from 'node:path';
 
 export class ModelsManager implements Disposable {
   #models: Map<string, ModelInfo>;
+  #hfCache: Map<string, ModelInfo>;
+
   #watcher?: podmanDesktopApi.FileSystemWatcher;
 
   #disposables: Disposable[];
@@ -58,6 +63,7 @@
     private configurationRegistry: ConfigurationRegistry,
   ) {
     this.#models = new Map();
+    this.#hfCache = new Map();
     this.#disposables = [];
   }
 
@@ -72,6 +78,44 @@
     this.loadLocalModels().catch((err: unknown) => {
       console.error('Something went wrong while trying to load local models', err);
     });
+
+    scanCacheDir()
+      .then(results => {
+        this.#hfCache.clear();
+        results.repos.forEach(repo => {
+          if (repo.revisions.length === 0) {
+            console.warn(`found hugging face cache repository ${repo.id.name} without any revision`);
+            return;
+          }
+
+          // ensure at least one safetensors file is available
+          if (!repo.revisions[0].files.some(file => file.path.endsWith('.safetensors'))) {
+            console.warn(
+              `hugging face cache repository ${repo.id.name} does not contain any .safetensors file: ignoring`,
+            );
+            return;
+          }
+
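+          // expose the cached repository as a ModelInfo entry: the id is the cache
+          // directory name (e.g. models--mistralai--Mistral-7B-v0.1), file.path points
+          // at its snapshots directory and file.file holds the revision's commit hash.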
+          const id = basename(repo.path);
+          this.#hfCache.set(id, {
+            id: id,
+            backend: InferenceType.VLLM,
+            file: {
+              file: repo.revisions[0].commitOid,
+              path: join(repo.path, 'snapshots'),
+              creation: repo.lastModifiedAt,
+              size: repo.size,
+            },
+            name: repo.id.name,
+            description: repo.id.name,
+            properties: {
+              origin: 'HF_CACHE',
+            },
+          });
+        });
+        this.notify();
+      })
+      .catch(console.error);
   }
 
   dispose(): void {
@@ -85,7 +129,7 @@
     this.catalogManager.getModels().forEach(m => this.#models.set(m.id, m));
     const reloadLocalModels = async (): Promise<void> => {
       this.getLocalModelsFromDisk();
-      await this.sendModelsInfo();
+      this.notify();
     };
     if (this.#watcher === undefined) {
       this.#watcher = apiFs.createFileSystemWatcher(this.modelsDir);
@@ -99,15 +143,17 @@
   }
 
   getModelsInfo(): ModelInfo[] {
-    return [...this.#models.values()];
+    return [...this.#models.values(), ...this.#hfCache.values()];
   }
 
-  async sendModelsInfo(): Promise<void> {
+  notify(): void {
     const models = this.getModelsInfo();
-    await this.webview.postMessage({
-      id: Messages.MSG_NEW_MODELS_STATE,
-      body: models,
-    });
+    this.webview
+      .postMessage({
+        id: Messages.MSG_NEW_MODELS_STATE,
+        body: models,
+      })
+      .catch(console.error);
   }
 
   getModelsDirectory(): string {
@@ -186,7 +232,7 @@
     }
 
     model.state = 'deleting';
-    await this.sendModelsInfo();
+    this.notify();
     try {
       await this.deleteRemoteModel(model);
       let modelPath;
@@ -214,7 +260,7 @@
       model.state = undefined;
       this.getLocalModelsFromDisk();
     } finally {
-      await this.sendModelsInfo();
+      this.notify();
     }
   }
 
@@ -331,9 +377,7 @@
       // refresh model lists on event completion
       this.getLocalModelsFromDisk();
-      this.sendModelsInfo().catch((err: unknown) => {
-        console.error('Something went wrong while sending models info.', err);
-      });
+      this.notify();
 
       // cleanup downloader
       this.#downloaders.delete(event.id);
@@ -433,6 +477,11 @@
       return getLocalModelFile(model);
     }
 
+    if (model.backend === InferenceType.VLLM) {
+      console.warn('Model upload for vllm is disabled');
+      return getLocalModelFile(model);
+    }
+
     this.taskRegistry.createTask(`Copying model ${model.name} to ${connection.name}`, 'loading', {
       ...labels,
       'model-uploading': model.id,
diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts
index eca82ebed..cf9bd1b7d 100644
--- a/packages/backend/src/studio.ts
+++ b/packages/backend/src/studio.ts
@@ -54,6 +54,7 @@ import { InstructlabApiImpl } from './instructlab-api-impl';
 import { NavigationRegistry } from './registries/NavigationRegistry';
 import { StudioAPI } from '@shared/src/StudioAPI';
 import { InstructlabAPI } from '@shared/src/InstructlabAPI';
+import { VLLM } from './workers/provider/VLLM';
 
 export class Studio {
   readonly #extensionContext: ExtensionContext;
@@ -260,6 +261,10 @@
       this.#inferenceProviderRegistry.register(new WhisperCpp(this.#taskRegistry, this.#podmanConnection)),
     );
 
+    this.#extensionContext.subscriptions.push(
+      this.#inferenceProviderRegistry.register(new VLLM(this.#taskRegistry, this.#podmanConnection)),
+    );
+
     /**
      * The inference manager create, stop, manage Inference servers
      */
diff --git a/packages/backend/src/workers/provider/VLLM.ts b/packages/backend/src/workers/provider/VLLM.ts
new file mode 100644
index 000000000..701afdbdb
--- /dev/null
+++ b/packages/backend/src/workers/provider/VLLM.ts
@@ -0,0 +1,148 @@
+/**********************************************************************
+ * Copyright (C) 2024 Red Hat, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ***********************************************************************/
+
+import { InferenceProvider } from './InferenceProvider';
+import type { TaskRegistry } from '../../registries/TaskRegistry';
+import type { PodmanConnection } from '../../managers/podmanConnection';
+import { type InferenceServer, InferenceType } from '@shared/src/models/IInference';
+import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
+import type { ContainerProviderConnection, MountConfig } from '@podman-desktop/api';
+import * as images from '../../assets/inference-images.json';
+import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
+import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
+import { basename, dirname } from 'node:path';
+import { join as joinposix } from 'node:path/posix';
+import { getLocalModelFile } from '../../utils/modelsUtils';
+
+export class VLLM extends InferenceProvider {
+  constructor(
+    taskRegistry: TaskRegistry,
+    private podmanConnection: PodmanConnection,
+  ) {
+    super(taskRegistry, InferenceType.VLLM, 'vllm');
+  }
+
+  dispose(): void {}
+
+  public enabled = (): boolean => true;
+
+  /**
+   * Here is an example
+   *
+   * podman run -it --rm
+   * -v C:\Users\axels\.cache\huggingface\hub\models--mistralai--Mistral-7B-v0.1:/cache/models--mistralai--Mistral-7B-v0.1
+   * -e HF_HUB_CACHE=/cache
+   * localhost/vllm-cpu-env:latest
+   * --model=/cache/models--mistralai--Mistral-7B-v0.1/snapshots/7231864981174d9bee8c7687c24c8344414eae6b
+   *
+   * @param config
+   */
+  override async perform(config: InferenceServerConfig): Promise<InferenceServer> {
+    if (config.modelsInfo.length !== 1)
+      throw new Error(`only one model is supported, received ${config.modelsInfo.length}`);
+
+    const modelInfo = config.modelsInfo[0];
+    if (modelInfo.backend !== InferenceType.VLLM) {
+      throw new Error(`VLLM requires models with backend type ${InferenceType.VLLM}, got ${modelInfo.backend}.`);
+    }
+
+    if (modelInfo.file === undefined) {
+      throw new Error('The model info file provided is undefined');
+    }
+
+    console.log('[VLLM]', config);
+    console.log('[VLLM] modelInfo.file', modelInfo.file);
+
+    const fullPath = getLocalModelFile(modelInfo);
+
+    // modelInfo.file.path must be under the form $(HF_HUB_CACHE)/models--<org>--<model>/snapshots/<commit>
+    const parent = dirname(fullPath);
+    const commitHash = basename(fullPath);
+    const name = basename(parent);
+    if (name !== 'snapshots') throw new Error('you must provide a snapshot path for vllm');
+    const modelCache = dirname(parent);
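+    // modelCache now points at the models--<org>--<model> directory of the local
+    // hugging face cache; it is bind mounted at /cache/<model id> inside the container
+    // below, and the snapshot path under that mount is passed to vllm through --model.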
+
+    let connection: ContainerProviderConnection | undefined;
+    if (config.connection) {
+      connection = this.podmanConnection.getContainerProviderConnection(config.connection);
+    } else {
+      connection = this.podmanConnection.findRunningContainerProviderConnection();
+    }
+
+    if (!connection) throw new Error('no running connection could be found');
+
+    const labels: Record<string, string> = {
+      ...config.labels,
+      [LABEL_INFERENCE_SERVER]: JSON.stringify(config.modelsInfo.map(model => model.id)),
+    };
+
+    const imageInfo = await this.pullImage(connection, config.image ?? images.vllm.default, labels);
+    // https://huggingface.co/docs/transformers/main/en/installation#offline-mode
+    // HF_HUB_OFFLINE in main
+    // TRANSFORMERS_OFFLINE for legacy
+    const envs: string[] = [`HF_HUB_CACHE=/cache`, 'TRANSFORMERS_OFFLINE=1', 'HF_HUB_OFFLINE=1'];
+
+    labels['api'] = `http://localhost:${config.port}/inference`;
+
+    const mounts: MountConfig = [
+      {
+        Target: `/cache/${modelInfo.id}`,
+        Source: modelCache,
+        Type: 'bind',
+      },
+    ];
+
+    const containerInfo = await this.createContainer(
+      imageInfo.engineId,
+      {
+        Image: imageInfo.Id,
+        Detach: true,
+        Labels: labels,
+        HostConfig: {
+          AutoRemove: false,
+          Mounts: mounts,
+          PortBindings: {
+            '8000/tcp': [
+              {
+                HostPort: `${config.port}`,
+              },
+            ],
+          },
+          SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION],
+        },
+        Env: envs,
+        Cmd: [`--model=${joinposix('/cache', modelInfo.id, 'snapshots', commitHash)}`],
+      },
+      labels,
+    );
+
+    return {
+      models: [modelInfo],
+      status: 'running',
+      connection: {
+        port: config.port,
+      },
+      container: {
+        containerId: containerInfo.id,
+        engineId: containerInfo.engineId,
+      },
+      type: InferenceType.VLLM,
+      labels: labels,
+    };
+  }
+}
diff --git a/packages/frontend/src/lib/table/model/ModelColumnName.svelte b/packages/frontend/src/lib/table/model/ModelColumnName.svelte
index de275aa0a..937686a98 100644
--- a/packages/frontend/src/lib/table/model/ModelColumnName.svelte
+++ b/packages/frontend/src/lib/table/model/ModelColumnName.svelte
@@ -2,14 +2,20 @@
 import type { ModelInfo } from '@shared/src/models/IModelInfo';
 import { router } from 'tinro';
 
-export let object: ModelInfo;
+interface Props {
+  object: ModelInfo;
+}
+
+let { object }: Props = $props();
+
+let hf: boolean = $state(object.properties?.['origin'] === 'HF_CACHE');
 
 function openDetails(): void {
   router.goto(`/model/${object.id}`);
 }
 
-
diff --git a/packages/shared/src/models/IInference.ts b/packages/shared/src/models/IInference.ts
index dac43bb23..2987eec5d 100644
--- a/packages/shared/src/models/IInference.ts
+++ b/packages/shared/src/models/IInference.ts
@@ -21,6 +21,7 @@ export enum InferenceType {
   LLAMA_CPP = 'llama-cpp',
   WHISPER_CPP = 'whisper-cpp',
   NONE = 'none',
+  VLLM = 'vllm',
 }
 
 export type InferenceServerStatus = 'stopped' | 'running' | 'deleting' | 'stopping' | 'error' | 'starting';
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index d6c62527a..1c619c022 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -117,6 +117,9 @@ importers:
       '@huggingface/gguf':
         specifier: ^0.1.12
         version: 0.1.12
+      '@huggingface/hub':
+        specifier: ^0.21.0
+        version: 0.21.0
       express:
         specifier: ^4.21.2
         version: 4.21.2
@@ -1191,6 +1194,13 @@
     resolution: {integrity: sha512-m+u/ms28wE74v2VVCTncfI/KB2v897MRMOoRuYSU62P85fJ6/B2exMlHCNyAXkgDLeXBWDivXl4gPq+XbHmkaA==}
     engines: {node: '>=20'}
 
+  '@huggingface/hub@0.21.0':
+    resolution: {integrity: sha512-DpitNhqobMJLTv8dUq/EMtrz1dpfs3UrSVCxe1aKpjLAdOs6Gm6rqrinUFNvC9G88RIRzIYzojUtYUqlkKwKnA==}
+    engines: {node: 
'>=18'} + + '@huggingface/tasks@0.13.13': + resolution: {integrity: sha512-jaU91/x9mn3q1pwHMzpUiXICqME56LgDgza/nyt4h3Jp6k84YW931YFK5ri32qBDHmtjn/1dR4OMw85+dx87dA==} + '@humanfs/core@0.19.1': resolution: {integrity: sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==} engines: {node: '>=18.18.0'} @@ -5886,6 +5896,12 @@ snapshots: '@huggingface/gguf@0.1.12': {} + '@huggingface/hub@0.21.0': + dependencies: + '@huggingface/tasks': 0.13.13 + + '@huggingface/tasks@0.13.13': {} + '@humanfs/core@0.19.1': {} '@humanfs/node@0.16.6':