Skip to content

Commit

Permalink
fix: upload model on podman machine on WSL to speed loading (#204)
Browse files Browse the repository at this point in the history
* fix: upload model on podman machine on WSL to speed loading

Signed-off-by: lstocchi <[email protected]>

* fix: add tests

Signed-off-by: lstocchi <[email protected]>

* fix: get running machine name and use it on podman cp/stat

Signed-off-by: lstocchi <[email protected]>

* fix: fix lint, format and add tests

Signed-off-by: lstocchi <[email protected]>

* fix: use machine name to execute podman cli

Signed-off-by: lstocchi <[email protected]>

* fix: fix tests and reorganize based on review

Signed-off-by: lstocchi <[email protected]>

* fix: rename progressiveEvent

Signed-off-by: lstocchi <[email protected]>

* fix: show error if failing

Signed-off-by: lstocchi <[email protected]>

* fix: calculate machine name from running connection

Signed-off-by: lstocchi <[email protected]>

* fix: update remote path where to store uploaded model

Signed-off-by: lstocchi <[email protected]>

---------

Signed-off-by: lstocchi <[email protected]>
  • Loading branch information
lstocchi authored Mar 13, 2024
1 parent 9803d1e commit d9dc209
Show file tree
Hide file tree
Showing 16 changed files with 765 additions and 91 deletions.
3 changes: 3 additions & 0 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ describe('pullApplication', () => {
});
mocks.listPodsMock.mockResolvedValue([]);
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(modelsManager, 'uploadModelToPodmanMachine').mockResolvedValue('path');
mocks.performDownloadMock.mockResolvedValue('path');
const recipe: Recipe = {
id: 'recipe1',
Expand Down Expand Up @@ -316,6 +317,7 @@ describe('pullApplication', () => {
});
mocks.listPodsMock.mockResolvedValue([]);
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(false);
vi.spyOn(modelsManager, 'uploadModelToPodmanMachine').mockResolvedValue('path');
mocks.performDownloadMock.mockResolvedValue('path');
const recipe: Recipe = {
id: 'recipe1',
Expand Down Expand Up @@ -344,6 +346,7 @@ describe('pullApplication', () => {
});
mocks.listPodsMock.mockResolvedValue([]);
vi.spyOn(modelsManager, 'isModelOnDisk').mockReturnValue(true);
vi.spyOn(modelsManager, 'uploadModelToPodmanMachine').mockResolvedValue('path');
vi.spyOn(modelsManager, 'getLocalModelPath').mockReturnValue('path');
const recipe: Recipe = {
id: 'recipe1',
Expand Down
8 changes: 7 additions & 1 deletion packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,13 @@ export class ApplicationManager extends Publisher<ApplicationState[]> {
const configAndFilteredContainers = this.getConfigAndFilterContainers(recipe.config, localFolder);

// get model by downloading it or retrieving locally
const modelPath = await this.modelsManager.requestDownloadModel(model, {
let modelPath = await this.modelsManager.requestDownloadModel(model, {
'recipe-id': recipe.id,
'model-id': model.id,
});

// upload model to podman machine if user system is supported
modelPath = await this.modelsManager.uploadModelToPodmanMachine(model, modelPath, {
'recipe-id': recipe.id,
'model-id': model.id,
});
Expand Down
2 changes: 2 additions & 0 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,9 @@ describe('downloadModel', () => {

mocks.onEventDownloadMock.mockImplementation(listener => {
listener({
id: 'id',
status: 'completed',
duration: 1000,
});
});

Expand Down
42 changes: 34 additions & 8 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ import { Messages } from '@shared/Messages';
import type { CatalogManager } from './catalogManager';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import * as podmanDesktopApi from '@podman-desktop/api';
import { Downloader, type DownloadEvent, isCompletionEvent, isProgressEvent } from '../utils/downloader';
import { Downloader } from '../utils/downloader';
import type { TaskRegistry } from '../registries/TaskRegistry';
import type { Task } from '@shared/src/models/ITask';
import type { BaseEvent } from '../models/baseEvent';
import { isCompletionEvent, isProgressEvent } from '../models/baseEvent';
import { Uploader } from '../utils/uploader';

export class ModelsManager implements Disposable {
#modelsDir: string;
Expand Down Expand Up @@ -215,12 +218,18 @@ export class ModelsManager implements Disposable {
});
}

private onDownloadEvent(event: DownloadEvent): void {
private onDownloadUploadEvent(event: BaseEvent, action: 'download' | 'upload'): void {
let taskLabel = 'model-pulling';
let eventName = 'model.download';
if (action === 'upload') {
taskLabel = 'model-uploading';
eventName = 'model.upload';
}
// Always use the task registry as source of truth for tasks
const tasks = this.taskRegistry.getTasksByLabels({ 'model-pulling': event.id });
const tasks = this.taskRegistry.getTasksByLabels({ [taskLabel]: event.id });
if (tasks.length === 0) {
// tasks might have been cleared but still an error.
console.error('received download event but no task is associated.');
console.error(`received ${action} event but no task is associated.`);
return;
}

Expand All @@ -236,9 +245,9 @@ export class ModelsManager implements Disposable {
task.error = event.message;

// telemetry usage
this.telemetry.logError('model.download', {
this.telemetry.logError(eventName, {
'model.id': event.id,
message: 'error downloading model',
message: `error ${action}ing model`,
error: event.message,
durationSeconds: event.duration,
});
Expand All @@ -247,7 +256,7 @@ export class ModelsManager implements Disposable {
task.progress = 100;

// telemetry usage
this.telemetry.logUsage('model.download', { 'model.id': event.id, durationSeconds: event.duration });
this.telemetry.logUsage(eventName, { 'model.id': event.id, durationSeconds: event.duration });
}
}
this.taskRegistry.updateTask(task); // update task
Expand Down Expand Up @@ -294,10 +303,27 @@ export class ModelsManager implements Disposable {
const downloader = this.createDownloader(model);

// Capture downloader events
downloader.onEvent(this.onDownloadEvent.bind(this));
downloader.onEvent(event => this.onDownloadUploadEvent(event, 'download'), this);

// perform download
await downloader.perform(model.id);
return downloader.getTarget();
}

/**
 * Creates a 'loading' task for the model upload, then delegates the transfer
 * of the local model file to an Uploader.
 *
 * @param model model to upload; its name appears in the task title and its id tags the task
 * @param localModelPath path of the model file handed to the Uploader
 * @param labels optional extra labels merged into the created task's labels
 * @returns the path resolved by the uploader
 *          (presumably the location to use when mounting the model; confirm against Uploader)
 */
async uploadModelToPodmanMachine(
  model: ModelInfo,
  localModelPath: string,
  labels?: { [key: string]: string },
): Promise<string> {
  // The 'model-uploading' label is the key used to look this task up when upload events arrive.
  const taskLabels = { ...labels, 'model-uploading': model.id };
  this.taskRegistry.createTask(`Uploading model ${model.name}`, 'loading', taskLabels);

  // Route every uploader event through the shared download/upload handler.
  const modelUploader = new Uploader(localModelPath);
  modelUploader.onEvent(event => this.onDownloadUploadEvent(event, 'upload'), this);

  // perform upload
  return modelUploader.perform(model.id);
}
}
31 changes: 17 additions & 14 deletions packages/backend/src/managers/playground.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const mocks = vi.hoisted(() => ({
listContainers: vi.fn(),
logUsage: vi.fn(),
logError: vi.fn(),
getFirstRunningPodmanConnectionMock: vi.fn(),
}));

vi.mock('@podman-desktop/api', async () => {
Expand All @@ -55,6 +56,12 @@ vi.mock('@podman-desktop/api', async () => {
};
});

// Swap the podman utils module for a stub so each test decides which connection is returned.
vi.mock('../utils/podman', () => ({
  getFirstRunningPodmanConnection: mocks.getFirstRunningPodmanConnectionMock,
}));

const containerRegistryMock = {
subscribe: mocks.containerRegistrySubscribeMock,
} as unknown as ContainerRegistry;
Expand Down Expand Up @@ -103,14 +110,12 @@ test('startPlayground should fail if no provider', async () => {

test('startPlayground should download image if not present then create container', async () => {
mocks.postMessage.mockResolvedValue(undefined);
mocks.getContainerConnections.mockReturnValue([
{
connection: {
type: 'podman',
status: () => 'started',
},
mocks.getFirstRunningPodmanConnectionMock.mockReturnValue({
connection: {
type: 'podman',
status: () => 'started',
},
]);
});
vi.spyOn(manager, 'selectImage')
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce({
Expand Down Expand Up @@ -162,14 +167,12 @@ test('stopPlayground should fail if no playground is running', async () => {

test('stopPlayground should stop a started playground', async () => {
mocks.postMessage.mockResolvedValue(undefined);
mocks.getContainerConnections.mockReturnValue([
{
connection: {
type: 'podman',
status: () => 'started',
},
mocks.getFirstRunningPodmanConnectionMock.mockReturnValue({
connection: {
type: 'podman',
status: () => 'started',
},
]);
});
vi.spyOn(manager, 'selectImage').mockResolvedValue({
Id: 'image1',
engineId: 'engine1',
Expand Down
20 changes: 3 additions & 17 deletions packages/backend/src/managers/playground.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,7 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

import {
containerEngine,
type Webview,
type ImageInfo,
type ProviderContainerConnection,
provider,
type TelemetryLogger,
} from '@podman-desktop/api';
import { containerEngine, type Webview, type ImageInfo, type TelemetryLogger } from '@podman-desktop/api';

import path from 'node:path';
import { getFreePort } from '../utils/ports';
Expand All @@ -35,6 +28,7 @@ import type { PodmanConnection } from './podmanConnection';
import OpenAI from 'openai';
import { DISABLE_SELINUX_LABEL_SECURITY_OPTION, getDurationSecondsSince, timeout } from '../utils/utils';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import { getFirstRunningPodmanConnection } from '../utils/podman';

export const LABEL_MODEL_ID = 'ai-studio-model-id';
export const LABEL_MODEL_PORT = 'ai-studio-model-port';
Expand All @@ -45,14 +39,6 @@ const PLAYGROUND_IMAGE = 'quay.io/bootsy/playground:v0';

const STARTING_TIME_MAX = 3600 * 1000;

/**
 * Finds the first registered container connection that is a podman engine
 * whose status is 'started'.
 *
 * @returns the first matching connection, or undefined when there is none
 */
function findFirstProvider(): ProviderContainerConnection | undefined {
  const podmanConnections = provider
    .getContainerConnections()
    .filter(candidate => candidate.connection.type === 'podman');
  const startedConnections = podmanConnections.filter(candidate => candidate.connection.status() === 'started');

  // Destructuring an empty array yields undefined, which is the "not found" result.
  const [firstStarted] = startedConnections;
  return firstStarted;
}

export class PlayGroundManager {
private queryIdCounter = 0;

Expand Down Expand Up @@ -172,7 +158,7 @@ export class PlayGroundManager {

this.setPlaygroundStatus(modelId, 'starting');

const connection = findFirstProvider();
const connection = getFirstRunningPodmanConnection();
if (!connection) {
const error = 'Unable to find an engine to start playground';
this.setPlaygroundError(modelId, error);
Expand Down
25 changes: 14 additions & 11 deletions packages/backend/src/managers/podmanConnection.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { PodmanConnection } from './podmanConnection';
import type { RegisterContainerConnectionEvent, UpdateContainerConnectionEvent } from '@podman-desktop/api';

const mocks = vi.hoisted(() => ({
getContainerConnections: vi.fn(),
getFirstRunningPodmanConnectionMock: vi.fn(),
onDidRegisterContainerConnection: vi.fn(),
onDidUpdateContainerConnection: vi.fn(),
}));
Expand All @@ -30,23 +30,26 @@ vi.mock('@podman-desktop/api', async () => {
return {
provider: {
onDidRegisterContainerConnection: mocks.onDidRegisterContainerConnection,
getContainerConnections: mocks.getContainerConnections,
onDidUpdateContainerConnection: mocks.onDidUpdateContainerConnection,
},
};
});

test('startupSubscribe should execute immediately if provider already registered', () => {
// Swap the podman utils module for a stub so each test decides which connection is returned.
vi.mock('../utils/podman', () => ({
  getFirstRunningPodmanConnection: mocks.getFirstRunningPodmanConnectionMock,
}));

test('startupSubscribe should execute immediately if provider already registered', async () => {
const manager = new PodmanConnection();
// one provider is already registered
mocks.getContainerConnections.mockReturnValue([
{
connection: {
type: 'podman',
status: () => 'started',
},
mocks.getFirstRunningPodmanConnectionMock.mockReturnValue({
connection: {
type: 'podman',
status: () => 'started',
},
]);
});
mocks.onDidRegisterContainerConnection.mockReturnValue({
dispose: vi.fn,
});
Expand All @@ -61,7 +64,7 @@ test('startupSubscribe should execute when provider is registered', async () =>
const manager = new PodmanConnection();

// no provider is already registered
mocks.getContainerConnections.mockReturnValue([]);
mocks.getFirstRunningPodmanConnectionMock.mockReturnValue(undefined);
mocks.onDidRegisterContainerConnection.mockImplementation((f: (e: RegisterContainerConnectionEvent) => void) => {
setTimeout(() => {
f({
Expand Down
8 changes: 3 additions & 5 deletions packages/backend/src/managers/podmanConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {
type PodInfo,
type Disposable,
} from '@podman-desktop/api';
import { getFirstRunningPodmanConnection } from '../utils/podman';

export type startupHandle = () => void;
export type machineStartHandle = () => void;
Expand Down Expand Up @@ -72,11 +73,8 @@ export class PodmanConnection implements Disposable {
});

// In case at least one extension has already registered, we get one started podman provider
const engines = provider
.getContainerConnections()
.filter(connection => connection.connection.type === 'podman')
.filter(connection => connection.connection.status() === 'started');
if (engines.length > 0) {
const engine = getFirstRunningPodmanConnection();
if (engine) {
disposable.dispose();
this.#firstFound = true;
}
Expand Down
49 changes: 49 additions & 0 deletions packages/backend/src/models/baseEvent.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/**********************************************************************
* Copyright (C) 2024 Red Hat, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

/**
 * Common shape of the events emitted while a long-running transfer
 * (download or upload) is in flight.
 */
export interface BaseEvent {
  /** Identifier of the transfer the event belongs to. */
  id: string;
  /** Current state of the transfer. */
  status:
    | 'error'
    | 'completed'
    | 'progress'
    | 'canceled';
  /** Optional human-readable detail, e.g. the error description. */
  message?: string;
}

/**
 * Terminal event: the transfer finished, failed or was canceled.
 */
export interface CompletionEvent extends BaseEvent {
  /** Restricted to the terminal states. */
  status:
    | 'completed'
    | 'error'
    | 'canceled';
  /** How long the transfer took (presumably in seconds; confirm against the emitter). */
  duration: number;
}

/**
 * Intermediate event reporting how far the transfer has advanced.
 */
export interface ProgressEvent extends BaseEvent {
  /** Always 'progress' for this event kind. */
  status: 'progress';
  /** Progress amount (presumably a percentage; confirm against the emitter). */
  value: number;
}

/**
 * Type guard for terminal events.
 *
 * Accepts any non-null object whose `status` property is one of the strings
 * 'canceled', 'completed' or 'error'.
 *
 * NOTE(review): only `status` is inspected; the presence of `id` and
 * `duration` is not verified here.
 *
 * @param value candidate to inspect
 * @returns true when `value` carries a terminal status
 */
export const isCompletionEvent = (value: unknown): value is CompletionEvent => {
  if (value === null || typeof value !== 'object' || !('status' in value)) {
    return false;
  }
  if (typeof value.status !== 'string') {
    return false;
  }
  switch (value.status) {
    case 'canceled':
    case 'completed':
    case 'error':
      return true;
    default:
      return false;
  }
};

/**
 * Type guard for progress events.
 *
 * Accepts any non-null object whose `status` is exactly 'progress' and which
 * has a `value` property (the type of `value` is not checked).
 *
 * @param value candidate to inspect
 * @returns true when `value` looks like a progress event
 */
export const isProgressEvent = (value: unknown): value is ProgressEvent => {
  if (value === null || typeof value !== 'object') {
    return false;
  }
  if (!('status' in value) || value.status !== 'progress') {
    return false;
  }
  return 'value' in value;
};
Loading

0 comments on commit d9dc209

Please sign in to comment.