fix: only mount model file to inference container
Signed-off-by: axel7083 <[email protected]>
axel7083 committed Dec 18, 2024
1 parent f8856df commit f639d67
Showing 4 changed files with 52 additions and 11 deletions.
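
The change in a nutshell: both inference providers previously bind-mounted the model's entire host directory at /models, exposing every sibling file to the container; after this commit only the model file itself is mounted. A minimal sketch of the before/after mount configuration, using illustrative paths that do not appear in the diff:

```ts
import type { MountConfig } from '@podman-desktop/api';

// Before: the whole directory is exposed, so any sibling files in it
// (other models, partial downloads) also end up inside the container.
const before: MountConfig = [
  { Source: '/home/user/models/llama', Target: '/models', Type: 'bind' },
];

// After: only the single file the inference server needs is bind-mounted,
// and the in-container target mirrors the file name.
const after: MountConfig = [
  {
    Source: '/home/user/models/llama/model.gguf',
    Target: '/models/model.gguf',
    Type: 'bind',
  },
];
```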
4 changes: 2 additions & 2 deletions packages/backend/src/workers/provider/LlamaCppPython.spec.ts
@@ -208,8 +208,8 @@ describe('perform', () => {
         AutoRemove: false,
         Mounts: [
           {
-            Source: 'dummy-path',
-            Target: '/models',
+            Source: 'dummy-path/dummy-file.guff',
+            Target: '/models/dummy-file.guff',
             Type: 'bind',
           },
         ],
16 changes: 11 additions & 5 deletions packages/backend/src/workers/provider/LlamaCppPython.ts
@@ -24,7 +24,7 @@ import type {
 } from '@podman-desktop/api';
 import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
 import { InferenceProvider } from './InferenceProvider';
-import { getModelPropertiesForEnvironment } from '../../utils/modelsUtils';
+import { getLocalModelFile, getModelPropertiesForEnvironment } from '../../utils/modelsUtils';
 import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
 import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
 import type { TaskRegistry } from '../../registries/TaskRegistry';
@@ -81,17 +81,23 @@ export class LlamaCppPython extends InferenceProvider {
       [LABEL_INFERENCE_SERVER]: JSON.stringify(config.modelsInfo.map(model => model.id)),
     };
 
-    const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000'];
-    envs.push(...getModelPropertiesForEnvironment(modelInfo));
+    // get model mount settings
+    const filename = getLocalModelFile(modelInfo);
+    const target = `/models/${modelInfo.file.file}`;
 
+    // mount the file directory to avoid adding other files to the containers
     const mounts: MountConfig = [
       {
-        Target: '/models',
-        Source: modelInfo.file.path,
+        Target: target,
+        Source: filename,
         Type: 'bind',
       },
     ];
 
+    // provide envs
+    const envs: string[] = [`MODEL_PATH=${target}`, 'HOST=0.0.0.0', 'PORT=8000'];
+    envs.push(...getModelPropertiesForEnvironment(modelInfo));
+
     const deviceRequests: DeviceRequest[] = [];
     const devices: Device[] = [];
     let entrypoint: string | undefined = undefined;
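
getLocalModelFile itself is not part of this diff. As a rough sketch of what such a helper presumably does, resolving the full host path from the modelInfo.file fields used above (an assumption, not code from the repository):

```ts
import path from 'node:path';

// Hypothetical minimal shape of the metadata used above; the real
// ModelInfo type lives in the project's shared models and may differ.
interface ModelFileInfo {
  path: string; // host directory containing the downloaded model
  file: string; // model file name, e.g. 'model.gguf'
}

// Sketch of a getLocalModelFile-style helper: join directory and file
// name into the host path that gets bind-mounted into the container.
// Note that path.join uses the host separator ('\' on Windows), which
// is exactly what the failing windows-2022 assertion below trips over.
function getLocalModelFile(modelInfo: { file?: ModelFileInfo }): string {
  if (modelInfo.file === undefined) throw new Error('model has no local file.');
  return path.join(modelInfo.file.path, modelInfo.file.file);
}
```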
29 changes: 29 additions & 0 deletions packages/backend/src/workers/provider/WhisperCpp.spec.ts
@@ -235,4 +235,33 @@ test('provided connection should be used for pulling the image', async () => {
   expect(getImageInfo).toHaveBeenCalledWith(connectionMock, 'localhost/whisper-cpp:custom', expect.any(Function));
   expect(podmanConnection.getContainerProviderConnection).toHaveBeenCalledWith(connection);
   expect(podmanConnection.findRunningContainerProviderConnection).not.toHaveBeenCalled();
+  // ensure the create container is called with appropriate arguments
+  expect(containerEngine.createContainer).toHaveBeenCalledWith('dummy-engine-id', {
+    Detach: true,
+    Env: ['MODEL_PATH=/models/random-file', 'HOST=0.0.0.0', 'PORT=8000'],
+    HostConfig: {
+      AutoRemove: false,
+      Mounts: [
+        {
+          Source: 'path-to-file/random-file',
+          Target: '/models/random-file',
+          Type: 'bind',
+        },
+      ],
+      PortBindings: {
+        '8000/tcp': [
+          {
+            HostPort: '8888',
+          },
+        ],
+      },
+      SecurityOpt: ['label=disable'],
+    },
+    Image: 'dummy-image-id',
+    Labels: {
+      'ai-lab-inference-server': '["whisper-cpp"]',
+      api: 'http://localhost:8888/inference',
+      hello: 'world',
+    },
+  });
 });

GitHub Actions annotation on this commit (linter, formatters and unit tests / windows-2022), check failure on line 239 in packages/backend/src/workers/provider/WhisperCpp.spec.ts:

src/workers/provider/WhisperCpp.spec.ts > provided connection should be used for pulling the image
AssertionError: expected "spy" to be called with arguments: [ 'dummy-engine-id', …(1) ]
The only mismatch in the received call is the mount source: expected Source: "path-to-file/random-file", received Source: "path-to-file\\random-file" (Windows path separator).
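
The annotation above shows why the new assertion breaks on Windows: the expected Source is hard-coded with '/', while a path.join-based helper produces the host separator, '\' on Windows. One conventional way to make such an expectation platform-independent (a sketch under that assumption, not the fix applied in the repository) is to build the expected value with the same path.join:

```ts
import path from 'node:path';
import { expect } from 'vitest';
// containerEngine.createContainer is replaced with a spy via vi.mock in
// the real spec; this excerpt assumes that setup is already in place.
import { containerEngine } from '@podman-desktop/api';

// Build the expected host path with the platform separator so the
// assertion passes on Linux/macOS ('/') and Windows ('\') alike.
const expectedSource = path.join('path-to-file', 'random-file');

expect(containerEngine.createContainer).toHaveBeenCalledWith(
  'dummy-engine-id',
  expect.objectContaining({
    HostConfig: expect.objectContaining({
      Mounts: [
        expect.objectContaining({
          Source: expectedSource,
          Target: '/models/random-file',
          Type: 'bind',
        }),
      ],
    }),
  }),
);
```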
14 changes: 10 additions & 4 deletions packages/backend/src/workers/provider/WhisperCpp.ts
@@ -25,6 +25,7 @@ import type { ContainerProviderConnection, MountConfig } from '@podman-desktop/a
 import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
 import { whispercpp } from '../../assets/inference-images.json';
 import type { PodmanConnection } from '../../managers/podmanConnection';
+import { getLocalModelFile } from '../../utils/modelsUtils';
 
 export class WhisperCpp extends InferenceProvider {
   constructor(
@@ -67,17 +68,22 @@ export class WhisperCpp extends InferenceProvider {
 
     if (!connection) throw new Error('no running connection could be found');
 
-    const imageInfo = await this.pullImage(connection, config.image ?? whispercpp.default, labels);
-    const envs: string[] = [`MODEL_PATH=/models/${modelInfo.file.file}`, 'HOST=0.0.0.0', 'PORT=8000'];
+    // get model mount settings
+    const filename = getLocalModelFile(modelInfo);
+    const target = `/models/${modelInfo.file.file}`;
 
+    // mount the file directory to avoid adding other files to the containers
     const mounts: MountConfig = [
       {
-        Target: '/models',
-        Source: modelInfo.file.path,
+        Target: target,
+        Source: filename,
         Type: 'bind',
       },
     ];
 
+    const imageInfo = await this.pullImage(connection, config.image ?? whispercpp.default, labels);
+    const envs: string[] = [`MODEL_PATH=${target}`, 'HOST=0.0.0.0', 'PORT=8000'];
+
     labels['api'] = `http://localhost:${config.port}/inference`;
 
     const containerInfo = await this.createContainer(
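
A side effect worth noting: in both providers the env and the mount are now derived from the same target string, so MODEL_PATH and the mount target cannot drift apart. A condensed illustration of the resulting options, reusing the fixture values from the spec above:

```ts
// Illustrative only: the single `target` string feeds both the bind
// mount and MODEL_PATH, so the server inside the container always finds
// the model exactly where it was mounted.
const target = '/models/random-file';

const options = {
  Env: [`MODEL_PATH=${target}`, 'HOST=0.0.0.0', 'PORT=8000'],
  HostConfig: {
    Mounts: [{ Source: 'path-to-file/random-file', Target: target, Type: 'bind' }],
  },
};
```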
