Skip to content

Commit

Permalink
feature: WSL upload support podman connections (#1535)
Browse files Browse the repository at this point in the history
* feature: WSL upload support podman connections

Signed-off-by: axel7083 <[email protected]>

* fix: unit tests

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
  • Loading branch information
axel7083 authored Aug 26, 2024
1 parent 4e20f77 commit 65e02a6
Show file tree
Hide file tree
Showing 16 changed files with 272 additions and 327 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ describe('pullApplication', () => {
'model-id': remoteModelMock.id,
});
// upload model to podman machine
expect(modelsManagerMock.uploadModelToPodmanMachine).toHaveBeenCalledWith(remoteModelMock, {
expect(modelsManagerMock.uploadModelToPodmanMachine).toHaveBeenCalledWith(connectionMock, remoteModelMock, {
'test-label': 'test-value',
'recipe-id': recipeMock.id,
'model-id': remoteModelMock.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ export class ApplicationManager extends Publisher<ApplicationState[]> implements
});

// upload model to podman machine if user system is supported
const modelPath = await this.modelsManager.uploadModelToPodmanMachine(model, {
const modelPath = await this.modelsManager.uploadModelToPodmanMachine(connection, model, {
...labels,
'recipe-id': recipe.id,
'model-id': model.id,
Expand Down
11 changes: 11 additions & 0 deletions packages/backend/src/managers/inference/inferenceManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import type { InferenceProvider } from '../../workers/provider/InferenceProvider
import type { CatalogManager } from '../catalogManager';
import type { InferenceServer } from '@shared/src/models/IInference';
import { InferenceType } from '@shared/src/models/IInference';
import { VMType } from '@shared/src/models/IPodman';

vi.mock('@podman-desktop/api', async () => {
return {
Expand Down Expand Up @@ -64,6 +65,7 @@ const containerRegistryMock = {

const podmanConnectionMock = {
onPodmanConnectionEvent: vi.fn(),
findRunningContainerProviderConnection: vi.fn(),
} as unknown as PodmanConnection;

const modelsManager = {
Expand Down Expand Up @@ -126,6 +128,15 @@ beforeEach(() => {
Health: undefined,
},
} as unknown as ContainerInspectInfo);
vi.mocked(podmanConnectionMock.findRunningContainerProviderConnection).mockReturnValue({
name: 'Podman Machine',
vmType: VMType.UNKNOWN,
type: 'podman',
status: () => 'started',
endpoint: {
socketPath: 'socket.sock',
},
});
vi.mocked(taskRegistryMock.getTasksByLabels).mockReturnValue([]);
vi.mocked(modelsManager.getLocalModelPath).mockReturnValue('/local/model.guff');
vi.mocked(modelsManager.uploadModelToPodmanMachine).mockResolvedValue('/mnt/path/model.guff');
Expand Down
13 changes: 11 additions & 2 deletions packages/backend/src/managers/inference/inferenceManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import type { InferenceServer, InferenceServerStatus, InferenceType } from '@shared/src/models/IInference';
import type { PodmanConnection, PodmanConnectionEvent } from '../podmanConnection';
import { containerEngine, Disposable } from '@podman-desktop/api';
import { type ContainerInfo, type TelemetryLogger, type Webview } from '@podman-desktop/api';
import type { ContainerInfo, TelemetryLogger, Webview, ContainerProviderConnection } from '@podman-desktop/api';
import type { ContainerRegistry, ContainerStart } from '../../registries/ContainerRegistry';
import { getInferenceType, isTransitioning, LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
import { Publisher } from '../../utils/Publisher';
Expand Down Expand Up @@ -198,10 +198,19 @@ export class InferenceManager extends Publisher<InferenceServer[]> implements Di
provider = providers[0];
}

let connection: ContainerProviderConnection | undefined = undefined;
if (config.connection) {
connection = this.podmanConnection.getContainerProviderConnection(config.connection);
} else {
connection = this.podmanConnection.findRunningContainerProviderConnection();
}

if (!connection) throw new Error('cannot find running container provider connection');

// upload models to podman machine if user system is supported
config.modelsInfo = await Promise.all(
config.modelsInfo.map(modelInfo =>
this.modelsManager.uploadModelToPodmanMachine(modelInfo, config.labels).then(path => ({
this.modelsManager.uploadModelToPodmanMachine(connection, modelInfo, config.labels).then(path => ({
...modelInfo,
file: {
path: dirname(path),
Expand Down
110 changes: 58 additions & 52 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import fs, { type Stats, type PathLike } from 'node:fs';
import path from 'node:path';
import { ModelsManager } from './modelsManager';
import { env, process as coreProcess } from '@podman-desktop/api';
import type { RunResult, TelemetryLogger, Webview } from '@podman-desktop/api';
import type { RunResult, TelemetryLogger, Webview, ContainerProviderConnection } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import * as utils from '../utils/utils';
Expand All @@ -31,6 +31,9 @@ import type { CancellationTokenRegistry } from '../registries/CancellationTokenR
import * as sha from '../utils/sha';
import type { GGUFParseOutput } from '@huggingface/gguf';
import { gguf } from '@huggingface/gguf';
import type { PodmanConnection } from './podmanConnection';
import { VMType } from '@shared/src/models/IPodman';
import { getPodmanMachineName } from '../utils/podman';

const mocks = vi.hoisted(() => {
return {
Expand All @@ -42,7 +45,6 @@ const mocks = vi.hoisted(() => {
getTargetMock: vi.fn(),
getDownloaderCompleter: vi.fn(),
isCompletionEventMock: vi.fn(),
getFirstRunningMachineNameMock: vi.fn(),
getPodmanCliMock: vi.fn(),
};
});
Expand All @@ -52,8 +54,8 @@ vi.mock('@huggingface/gguf', () => ({
}));

vi.mock('../utils/podman', () => ({
getFirstRunningMachineName: mocks.getFirstRunningMachineNameMock,
getPodmanCli: mocks.getPodmanCliMock,
getPodmanMachineName: vi.fn(),
}));

vi.mock('@podman-desktop/api', () => {
Expand Down Expand Up @@ -92,6 +94,10 @@ vi.mock('../utils/downloader', () => ({
},
}));

const podmanConnectionMock = {
getContainerProviderConnections: vi.fn(),
} as unknown as PodmanConnection;

const cancellationTokenRegistryMock = {
createCancellationTokenSource: vi.fn(),
} as unknown as CancellationTokenRegistry;
Expand Down Expand Up @@ -183,6 +189,7 @@ test('getModelsInfo should get models in local directory', async () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -232,6 +239,7 @@ test('getModelsInfo should return an empty array if the models folder does not e
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
manager.init();
manager.getLocalModelsFromDisk();
Expand Down Expand Up @@ -272,6 +280,7 @@ test('getLocalModelsFromDisk should return undefined Date and size when stat fai
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -330,6 +339,7 @@ test('getLocalModelsFromDisk should skip folders containing tmp files', async ()
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -370,6 +380,7 @@ test('loadLocalModels should post a message with the message on disk and on cata
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -420,6 +431,7 @@ test('deleteModel deletes the model folder', async () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -484,6 +496,7 @@ describe('deleting models', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
manager.init();
await manager.loadLocalModels();
Expand Down Expand Up @@ -550,21 +563,44 @@ describe('deleting models', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
await manager.loadLocalModels();
await manager.deleteModel('model-id-1');

expect(removeUserModelMock).toBeCalledWith('model-id-1');
});

test('deleting on windows should check if models is uploaded', async () => {
vi.mocked(env).isWindows = true;
vi.mocked(coreProcess.exec).mockResolvedValue({} as unknown as RunResult);
mocks.getFirstRunningMachineNameMock.mockReturnValue('dummyMachine');
test('deleting on windows should check for all connections', async () => {
vi.mocked(coreProcess.exec).mockResolvedValue({} as RunResult);
mocks.getPodmanCliMock.mockReturnValue('dummyCli');
vi.mocked(env).isWindows = true;
const connections: ContainerProviderConnection[] = [
{
name: 'Machine 1',
type: 'podman',
vmType: VMType.HYPERV,
endpoint: {
socketPath: '',
},
status: () => 'started',
},
{
name: 'Machine 2',
type: 'podman',
vmType: VMType.WSL,
endpoint: {
socketPath: '',
},
status: () => 'started',
},
];
vi.mocked(podmanConnectionMock.getContainerProviderConnections).mockReturnValue(connections);
vi.mocked(getPodmanMachineName).mockReturnValue('machine-2');

const rmSpy = vi.spyOn(fs.promises, 'rm');
rmSpy.mockResolvedValue(undefined);

const manager = new ModelsManager(
'/home/user/aistudio',
{
Expand All @@ -587,62 +623,24 @@ describe('deleting models', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);

await manager.loadLocalModels();
// delete the model
await manager.deleteModel('model-id-1');
expect(coreProcess.exec).toHaveBeenNthCalledWith(1, 'dummyCli', [
'machine',
'ssh',
'dummyMachine',
'stat',
'/home/user/ai-lab/models/dummyFile',
]);
expect(coreProcess.exec).toHaveBeenNthCalledWith(2, 'dummyCli', [

expect(podmanConnectionMock.getContainerProviderConnections).toHaveBeenCalledOnce();

expect(coreProcess.exec).toHaveBeenCalledWith('dummyCli', [
'machine',
'ssh',
'dummyMachine',
'machine-2',
'rm',
'-f',
'/home/user/ai-lab/models/dummyFile',
]);
});

test('deleting on windows should check if models is uploaded', async () => {
vi.mocked(env).isWindows = false;

const rmSpy = vi.spyOn(fs.promises, 'rm');
rmSpy.mockResolvedValue(undefined);
const manager = new ModelsManager(
'/home/user/aistudio',
{
postMessage: vi.fn().mockResolvedValue(undefined),
} as unknown as Webview,
{
getModels: () => {
return [
{
id: 'model-id-1',
url: 'model-url',
file: {
file: 'dummyFile',
path: 'dummyPath',
},
},
] as ModelInfo[];
},
} as CatalogManager,
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
);

await manager.loadLocalModels();
await manager.deleteModel('model-id-1');
expect(coreProcess.exec).not.toHaveBeenCalled();
expect(mocks.getFirstRunningMachineNameMock).not.toHaveBeenCalled();
expect(mocks.getPodmanCliMock).not.toHaveBeenCalled();
});
});

describe('downloadModel', () => {
Expand All @@ -659,6 +657,7 @@ describe('downloadModel', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
Expand Down Expand Up @@ -693,6 +692,7 @@ describe('downloadModel', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
const updateTaskMock = vi.spyOn(taskRegistry, 'updateTask');
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
Expand Down Expand Up @@ -724,6 +724,7 @@ describe('downloadModel', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);
vi.spyOn(taskRegistry, 'updateTask');
vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true);
Expand Down Expand Up @@ -754,6 +755,7 @@ describe('downloadModel', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
Expand Down Expand Up @@ -790,6 +792,7 @@ describe('downloadModel', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);

vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false);
Expand Down Expand Up @@ -837,6 +840,7 @@ describe('getModelMetadata', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);

await expect(() => manager.getModelMetadata('unknown-model-id')).rejects.toThrowError(
Expand All @@ -861,6 +865,7 @@ describe('getModelMetadata', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);

manager.init();
Expand Down Expand Up @@ -901,6 +906,7 @@ describe('getModelMetadata', () => {
telemetryLogger,
taskRegistry,
cancellationTokenRegistryMock,
podmanConnectionMock,
);

manager.init();
Expand Down
Loading

0 comments on commit 65e02a6

Please sign in to comment.