diff --git a/packages/backend/src/workers/provider/LlamaCppPython.spec.ts b/packages/backend/src/workers/provider/LlamaCppPython.spec.ts
index 1be3d28a0..78cb4ab2e 100644
--- a/packages/backend/src/workers/provider/LlamaCppPython.spec.ts
+++ b/packages/backend/src/workers/provider/LlamaCppPython.spec.ts
@@ -208,8 +208,8 @@ describe('perform', () => {
         AutoRemove: false,
         Mounts: [
           {
-            Source: 'dummy-path',
-            Target: '/models',
+            Source: 'dummy-path/dummy-file.guff',
+            Target: '/models/dummy-file.guff',
             Type: 'bind',
           },
         ],
diff --git a/packages/backend/src/workers/provider/LlamaCppPython.ts b/packages/backend/src/workers/provider/LlamaCppPython.ts
index 6b1d31af0..d0a4c38d9 100644
--- a/packages/backend/src/workers/provider/LlamaCppPython.ts
+++ b/packages/backend/src/workers/provider/LlamaCppPython.ts
@@ -24,7 +24,7 @@ import type {
 } from '@podman-desktop/api';
 import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
 import { InferenceProvider } from './InferenceProvider';
-import { getModelPropertiesForEnvironment } from '../../utils/modelsUtils';
+import { getLocalModelFile, getModelPropertiesForEnvironment } from '../../utils/modelsUtils';
 import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
 import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
 import type { TaskRegistry } from '../../registries/TaskRegistry';
@@ -81,17 +81,23 @@ export class LlamaCppPython extends InferenceProvider {
       [LABEL_INFERENCE_SERVER]: JSON.stringify(config.modelsInfo.map(model => model.id)),
     };
 
-    const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000'];
-    envs.push(...getModelPropertiesForEnvironment(modelInfo));
+    // get model mount settings
+    const filename = getLocalModelFile(modelInfo);
+    const target = `/models/${modelInfo.file.file}`;
 
+    // mount only the model file, to avoid exposing the other files in its directory to the container
     const mounts: MountConfig = [
       {
-        Target: '/models',
-        Source: modelInfo.file.path,
+        Target: target,
+        Source: filename,
         Type: 'bind',
       },
     ];
 
+    // provide envs
+    const envs: string[] = [`MODEL_PATH=${target}`, 'HOST=0.0.0.0', 'PORT=8000'];
+    envs.push(...getModelPropertiesForEnvironment(modelInfo));
+
     const deviceRequests: DeviceRequest[] = [];
     const devices: Device[] = [];
     let entrypoint: string | undefined = undefined;
diff --git a/packages/backend/src/workers/provider/WhisperCpp.spec.ts b/packages/backend/src/workers/provider/WhisperCpp.spec.ts
index e4c2c666f..4a32312ca 100644
--- a/packages/backend/src/workers/provider/WhisperCpp.spec.ts
+++ b/packages/backend/src/workers/provider/WhisperCpp.spec.ts
@@ -235,4 +235,33 @@ test('provided connection should be used for pulling the image', async () => {
   expect(getImageInfo).toHaveBeenCalledWith(connectionMock, 'localhost/whisper-cpp:custom', expect.any(Function));
   expect(podmanConnection.getContainerProviderConnection).toHaveBeenCalledWith(connection);
   expect(podmanConnection.findRunningContainerProviderConnection).not.toHaveBeenCalled();
+  // ensure createContainer is called with the appropriate arguments
+  expect(containerEngine.createContainer).toHaveBeenCalledWith('dummy-engine-id', {
+    Detach: true,
+    Env: ['MODEL_PATH=/models/random-file', 'HOST=0.0.0.0', 'PORT=8000'],
+    HostConfig: {
+      AutoRemove: false,
+      Mounts: [
+        {
+          Source: 'path-to-file/random-file',
+          Target: '/models/random-file',
+          Type: 'bind',
+        },
+      ],
+      PortBindings: {
+        '8000/tcp': [
+          {
+            HostPort: '8888',
+          },
+        ],
+      },
+      SecurityOpt: ['label=disable'],
+    },
+    Image: 'dummy-image-id',
+    Labels: {
+      'ai-lab-inference-server': '["whisper-cpp"]',
+      api: 'http://localhost:8888/inference',
+      hello: 'world',
+    },
+  });
 });
diff --git a/packages/backend/src/workers/provider/WhisperCpp.ts b/packages/backend/src/workers/provider/WhisperCpp.ts
index 26f5c7936..cd1a5b5bd 100644
--- a/packages/backend/src/workers/provider/WhisperCpp.ts
+++ b/packages/backend/src/workers/provider/WhisperCpp.ts
@@ -25,6 +25,7 @@ import type { ContainerProviderConnection, MountConfig } from '@podman-desktop/a
 import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
 import { whispercpp } from '../../assets/inference-images.json';
 import type { PodmanConnection } from '../../managers/podmanConnection';
+import { getLocalModelFile } from '../../utils/modelsUtils';
 
 export class WhisperCpp extends InferenceProvider {
   constructor(
@@ -67,17 +68,22 @@ export class WhisperCpp extends InferenceProvider {
 
     if (!connection) throw new Error('no running connection could be found');
 
-    const imageInfo = await this.pullImage(connection, config.image ?? whispercpp.default, labels);
-    const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000'];
+    // get model mount settings
+    const filename = getLocalModelFile(modelInfo);
+    const target = `/models/${modelInfo.file.file}`;
 
+    // mount only the model file, to avoid exposing the other files in its directory to the container
     const mounts: MountConfig = [
      {
-        Target: '/models',
-        Source: modelInfo.file.path,
+        Target: target,
+        Source: filename,
         Type: 'bind',
       },
     ];
 
+    const imageInfo = await this.pullImage(connection, config.image ?? whispercpp.default, labels);
+    const envs: string[] = [`MODEL_PATH=${target}`, 'HOST=0.0.0.0', 'PORT=8000'];
+
     labels['api'] = `http://localhost:${config.port}/inference`;
     const containerInfo = await this.createContainer(
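
Note: both providers now resolve the bind-mount source through the getLocalModelFile helper from packages/backend/src/utils/modelsUtils.ts, whose body is not part of this diff. The following is a minimal sketch only, inferred from the spec expectations above (Source changes from 'dummy-path' to 'dummy-path/dummy-file.guff'): the helper plausibly joins the model's directory and file name. The ModelInfo import path and the missing-file guard are assumptions, not the actual implementation.

import { join } from 'node:path';
import type { ModelInfo } from '@shared/src/models/IModelInfo'; // assumed import path

// Sketch: resolve the full local path of a downloaded model file,
// assuming modelInfo.file carries the directory (`path`) and file name (`file`).
export function getLocalModelFile(modelInfo: ModelInfo): string {
  if (modelInfo.file === undefined) throw new Error('model has no local file'); // assumed guard
  return join(modelInfo.file.path, modelInfo.file.file);
}

With the single file mounted at a per-model target (/models/<file>) and MODEL_PATH pointing at that same target, the container no longer sees the rest of the host model directory.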