Skip to content

Commit

Permalink
feat(inference): introducing InferenceProviders (#1161)
Browse files Browse the repository at this point in the history
* feat(inference): introducing InferenceProviders

Signed-off-by: axel7083 <[email protected]>

* feat: improve inference provider integration

Signed-off-by: axel7083 <[email protected]>

* fix: compilation

Signed-off-by: axel7083 <[email protected]>

* fix: labels propagation

Signed-off-by: axel7083 <[email protected]>

* fix: error message

Signed-off-by: axel7083 <[email protected]>

* fix: revert to podman desktop api 1.10.3

Signed-off-by: axel7083 <[email protected]>

* fix: typecheck

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 authored Jun 7, 2024
1 parent 987938c commit d2ea36b
Show file tree
Hide file tree
Showing 18 changed files with 844 additions and 453 deletions.
2 changes: 1 addition & 1 deletion packages/backend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"xml-js": "^1.6.11"
},
"devDependencies": {
"@podman-desktop/api": "0.0.202404101645-5d46ba5",
"@podman-desktop/api": "1.10.3",
"@types/js-yaml": "^4.0.9",
"@types/node": "^20",
"@types/postman-collection": "^3.5.10",
Expand Down
2 changes: 1 addition & 1 deletion packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ import { ApplicationRegistry } from '../registries/ApplicationRegistry';
import type { TaskRegistry } from '../registries/TaskRegistry';
import { Publisher } from '../utils/Publisher';
import { isQEMUMachine } from '../utils/podman';
import { SECOND } from '../utils/inferenceUtils';
import { getModelPropertiesForEnvironment } from '../utils/modelsUtils';
import { getRandomName } from '../utils/randomUtils';
import type { BuilderManager } from './recipes/BuilderManager';
import type { PodManager } from './recipes/PodManager';
import { SECOND } from '../workers/provider/LlamaCppPython';

export const LABEL_MODEL_ID = 'ai-lab-model-id';
export const LABEL_MODEL_PORTS = 'ai-lab-model-ports';
Expand Down
212 changes: 52 additions & 160 deletions packages/backend/src/managers/inference/inferenceManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,36 @@
***********************************************************************/
import {
containerEngine,
provider,
type Webview,
type TelemetryLogger,
type ImageInfo,
type ContainerInfo,
type ContainerInspectInfo,
type ProviderContainerConnection,
} from '@podman-desktop/api';
import type { ContainerRegistry } from '../../registries/ContainerRegistry';
import type { PodmanConnection } from '../podmanConnection';
import { beforeEach, expect, describe, test, vi } from 'vitest';
import { InferenceManager } from './inferenceManager';
import type { ModelsManager } from '../modelsManager';
import { LABEL_INFERENCE_SERVER, INFERENCE_SERVER_IMAGE } from '../../utils/inferenceUtils';
import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
import type { TaskRegistry } from '../../registries/TaskRegistry';
import { Messages } from '@shared/Messages';
import type { InferenceProviderRegistry } from '../../registries/InferenceProviderRegistry';
import type { InferenceProvider } from '../../workers/provider/InferenceProvider';

vi.mock('@podman-desktop/api', async () => {
return {
containerEngine: {
startContainer: vi.fn(),
stopContainer: vi.fn(),
listContainers: vi.fn(),
inspectContainer: vi.fn(),
pullImage: vi.fn(),
listImages: vi.fn(),
createContainer: vi.fn(),
deleteContainer: vi.fn(),
listContainers: vi.fn(),
},
Disposable: {
from: vi.fn(),
create: vi.fn(),
},
provider: {
getContainerConnections: vi.fn(),
},
};
});

Expand Down Expand Up @@ -87,6 +80,11 @@ const taskRegistryMock = {
getTasksByLabels: vi.fn(),
} as unknown as TaskRegistry;

const inferenceProviderRegistryMock = {
getAll: vi.fn(),
get: vi.fn(),
} as unknown as InferenceProviderRegistry;

const getInitializedInferenceManager = async (): Promise<InferenceManager> => {
const manager = new InferenceManager(
webviewMock,
Expand All @@ -95,6 +93,7 @@ const getInitializedInferenceManager = async (): Promise<InferenceManager> => {
modelsManager,
telemetryMock,
taskRegistryMock,
inferenceProviderRegistryMock,
);
manager.init();
await vi.waitUntil(manager.isInitialize.bind(manager), {
Expand All @@ -119,26 +118,6 @@ beforeEach(() => {
Health: undefined,
},
} as unknown as ContainerInspectInfo);
vi.mocked(provider.getContainerConnections).mockReturnValue([
{
providerId: 'test@providerId',
connection: {
type: 'podman',
name: 'test@connection',
status: () => 'started',
},
} as unknown as ProviderContainerConnection,
]);
vi.mocked(containerEngine.listImages).mockResolvedValue([
{
Id: 'dummyImageId',
engineId: 'dummyEngineId',
RepoTags: [INFERENCE_SERVER_IMAGE],
},
] as unknown as ImageInfo[]);
vi.mocked(containerEngine.createContainer).mockResolvedValue({
id: 'dummyCreatedContainerId',
});
vi.mocked(taskRegistryMock.getTasksByLabels).mockReturnValue([]);
vi.mocked(modelsManager.getLocalModelPath).mockReturnValue('/local/model.guff');
vi.mocked(modelsManager.uploadModelToPodmanMachine).mockResolvedValue('/mnt/path/model.guff');
Expand Down Expand Up @@ -233,119 +212,59 @@ describe('init Inference Manager', () => {
* Testing the creation logic
*/
describe('Create Inference Server', () => {
test('unknown providerId', async () => {
const inferenceManager = await getInitializedInferenceManager();
await expect(
inferenceManager.createInferenceServer(
{
providerId: 'unknown',
} as unknown as InferenceServerConfig,
'dummyTrackingId',
),
).rejects.toThrowError('cannot find any started container provider.');
test('no provider available should throw an error', async () => {
vi.mocked(inferenceProviderRegistryMock.getAll).mockReturnValue([]);

expect(provider.getContainerConnections).toHaveBeenCalled();
});

test('unknown imageId', async () => {
const inferenceManager = await getInitializedInferenceManager();
await expect(
inferenceManager.createInferenceServer(
{
providerId: 'test@providerId',
image: 'unknown',
} as unknown as InferenceServerConfig,
'dummyTrackingId',
),
).rejects.toThrowError('image unknown not found.');

expect(containerEngine.listImages).toHaveBeenCalled();
inferenceManager.createInferenceServer({
inferenceProvider: undefined,
labels: {},
modelsInfo: [],
port: 8888,
}),
).rejects.toThrowError('no enabled provider could be found.');
});

test('empty modelsInfo', async () => {
test('inference provider provided should use get from InferenceProviderRegistry', async () => {
vi.mocked(inferenceProviderRegistryMock.get).mockReturnValue({
enabled: () => false,
} as unknown as InferenceProvider);

const inferenceManager = await getInitializedInferenceManager();
await expect(
inferenceManager.createInferenceServer(
{
providerId: 'test@providerId',
image: INFERENCE_SERVER_IMAGE,
modelsInfo: [],
} as unknown as InferenceServerConfig,
'dummyTrackingId',
),
).rejects.toThrowError('Need at least one model info to start an inference server.');
inferenceManager.createInferenceServer({
inferenceProvider: 'dummy-inference-provider',
labels: {},
modelsInfo: [],
port: 8888,
}),
).rejects.toThrowError('provider requested is not enabled.');
expect(inferenceProviderRegistryMock.get).toHaveBeenCalledWith('dummy-inference-provider');
});

test('valid InferenceServerConfig', async () => {
test('selected inference provider should receive config', async () => {
const provider: InferenceProvider = {
enabled: () => true,
name: 'dummy-inference-provider',
dispose: () => {},
perform: vi.fn().mockResolvedValue({ id: 'dummy-container-id', engineId: 'dummy-engine-id' }),
} as unknown as InferenceProvider;
vi.mocked(inferenceProviderRegistryMock.get).mockReturnValue(provider);

const inferenceManager = await getInitializedInferenceManager();
await inferenceManager.createInferenceServer(
{
port: 8888,
providerId: 'test@providerId',
image: INFERENCE_SERVER_IMAGE,
modelsInfo: [
{
id: 'dummyModelId',
file: {
file: 'model.guff',
path: '/mnt/path',
},
},
],
} as unknown as InferenceServerConfig,
'dummyTrackingId',
);

expect(modelsManager.uploadModelToPodmanMachine).toHaveBeenCalledWith(
{
id: 'dummyModelId',
file: {
file: 'model.guff',
path: '/mnt/path',
},
},
{
trackingId: 'dummyTrackingId',
},
);
expect(taskRegistryMock.createTask).toHaveBeenNthCalledWith(
1,
expect.stringContaining(
'Pulling ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat:',
),
'loading',
{
trackingId: 'dummyTrackingId',
},
);
expect(taskRegistryMock.createTask).toHaveBeenNthCalledWith(2, 'Creating container.', 'loading', {
trackingId: 'dummyTrackingId',
});
expect(taskRegistryMock.updateTask).toHaveBeenLastCalledWith({
state: 'success',
});
expect(containerEngine.createContainer).toHaveBeenCalled();
expect(inferenceManager.getServers()).toStrictEqual([
{
connection: {
port: 8888,
},
container: {
containerId: 'dummyCreatedContainerId',
engineId: 'dummyEngineId',
},
models: [
{
file: {
file: 'model.guff',
path: '/mnt/path',
},
id: 'dummyModelId',
},
],
status: 'running',
},
]);
const config: InferenceServerConfig = {
inferenceProvider: 'dummy-inference-provider',
labels: {},
modelsInfo: [],
port: 8888,
};
const result = await inferenceManager.createInferenceServer(config);

expect(provider.perform).toHaveBeenCalledWith(config);

expect(result).toBe('dummy-container-id');
});
});

Expand Down Expand Up @@ -511,33 +430,6 @@ describe('Request Create Inference Server', () => {
trackingId: identifier,
});
});

test('Pull image error should be reflected in task registry', async () => {
vi.mocked(containerEngine.pullImage).mockRejectedValue(new Error('dummy pull image error'));

const inferenceManager = await getInitializedInferenceManager();
inferenceManager.requestCreateInferenceServer({
port: 8888,
providerId: 'test@providerId',
image: 'quay.io/bootsy/playground:v0',
modelsInfo: [
{
id: 'dummyModelId',
file: {
file: 'dummyFile',
path: 'dummyPath',
},
},
],
} as unknown as InferenceServerConfig);

await vi.waitFor(() => {
expect(taskRegistryMock.updateTask).toHaveBeenLastCalledWith({
state: 'error',
error: 'Something went wrong while trying to create an inference server Error: dummy pull image error.',
});
});
});
});

describe('containerRegistry events', () => {
Expand Down
Loading

0 comments on commit d2ea36b

Please sign in to comment.