From 65e02a61195d77b10724a15dbc6092a680020f82 Mon Sep 17 00:00:00 2001 From: axel7083 <42176370+axel7083@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:08:48 +0200 Subject: [PATCH] feature: WSL upload support podman connections (#1535) * feature: WSL upload support podman connections Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> * fix: unit tests Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> --------- Signed-off-by: axel7083 <42176370+axel7083@users.noreply.github.com> --- .../application/applicationManager.spec.ts | 2 +- .../application/applicationManager.ts | 2 +- .../inference/inferenceManager.spec.ts | 11 ++ .../managers/inference/inferenceManager.ts | 13 +- .../src/managers/modelsManager.spec.ts | 110 +++++++------- .../backend/src/managers/modelsManager.ts | 41 ++++-- .../src/managers/podmanConnection.spec.ts | 24 ++-- .../backend/src/managers/podmanConnection.ts | 6 +- packages/backend/src/studio.ts | 1 + packages/backend/src/utils/podman.spec.ts | 135 ------------------ packages/backend/src/utils/podman.ts | 53 ++----- packages/backend/src/utils/uploader.spec.ts | 23 +-- packages/backend/src/utils/uploader.ts | 13 +- .../src/workers/uploader/UploaderOptions.ts | 25 ++++ .../src/workers/uploader/WSLUploader.spec.ts | 111 ++++++++------ .../src/workers/uploader/WSLUploader.ts | 29 ++-- 16 files changed, 272 insertions(+), 327 deletions(-) create mode 100644 packages/backend/src/workers/uploader/UploaderOptions.ts diff --git a/packages/backend/src/managers/application/applicationManager.spec.ts b/packages/backend/src/managers/application/applicationManager.spec.ts index f190fa8b2..182e47615 100644 --- a/packages/backend/src/managers/application/applicationManager.spec.ts +++ b/packages/backend/src/managers/application/applicationManager.spec.ts @@ -306,7 +306,7 @@ describe('pullApplication', () => { 'model-id': remoteModelMock.id, }); // upload model to podman machine - expect(modelsManagerMock.uploadModelToPodmanMachine).toHaveBeenCalledWith(remoteModelMock, { + expect(modelsManagerMock.uploadModelToPodmanMachine).toHaveBeenCalledWith(connectionMock, remoteModelMock, { 'test-label': 'test-value', 'recipe-id': recipeMock.id, 'model-id': remoteModelMock.id, diff --git a/packages/backend/src/managers/application/applicationManager.ts b/packages/backend/src/managers/application/applicationManager.ts index 6b61987d5..1a029ae5a 100644 --- a/packages/backend/src/managers/application/applicationManager.ts +++ b/packages/backend/src/managers/application/applicationManager.ts @@ -180,7 +180,7 @@ export class ApplicationManager extends Publisher implements }); // upload model to podman machine if user system is supported - const modelPath = await this.modelsManager.uploadModelToPodmanMachine(model, { + const modelPath = await this.modelsManager.uploadModelToPodmanMachine(connection, model, { ...labels, 'recipe-id': recipe.id, 'model-id': model.id, diff --git a/packages/backend/src/managers/inference/inferenceManager.spec.ts b/packages/backend/src/managers/inference/inferenceManager.spec.ts index 7ffd64592..6d4af9ca7 100644 --- a/packages/backend/src/managers/inference/inferenceManager.spec.ts +++ b/packages/backend/src/managers/inference/inferenceManager.spec.ts @@ -36,6 +36,7 @@ import type { InferenceProvider } from '../../workers/provider/InferenceProvider import type { CatalogManager } from '../catalogManager'; import type { InferenceServer } from '@shared/src/models/IInference'; import { InferenceType } from '@shared/src/models/IInference'; +import { VMType } from '@shared/src/models/IPodman'; vi.mock('@podman-desktop/api', async () => { return { @@ -64,6 +65,7 @@ const containerRegistryMock = { const podmanConnectionMock = { onPodmanConnectionEvent: vi.fn(), + findRunningContainerProviderConnection: vi.fn(), } as unknown as PodmanConnection; const modelsManager = { @@ -126,6 +128,15 @@ beforeEach(() => { Health: undefined, }, } as unknown as ContainerInspectInfo); + vi.mocked(podmanConnectionMock.findRunningContainerProviderConnection).mockReturnValue({ + name: 'Podman Machine', + vmType: VMType.UNKNOWN, + type: 'podman', + status: () => 'started', + endpoint: { + socketPath: 'socket.sock', + }, + }); vi.mocked(taskRegistryMock.getTasksByLabels).mockReturnValue([]); vi.mocked(modelsManager.getLocalModelPath).mockReturnValue('/local/model.guff'); vi.mocked(modelsManager.uploadModelToPodmanMachine).mockResolvedValue('/mnt/path/model.guff'); diff --git a/packages/backend/src/managers/inference/inferenceManager.ts b/packages/backend/src/managers/inference/inferenceManager.ts index ac9b8dd5f..fb953332c 100644 --- a/packages/backend/src/managers/inference/inferenceManager.ts +++ b/packages/backend/src/managers/inference/inferenceManager.ts @@ -18,7 +18,7 @@ import type { InferenceServer, InferenceServerStatus, InferenceType } from '@shared/src/models/IInference'; import type { PodmanConnection, PodmanConnectionEvent } from '../podmanConnection'; import { containerEngine, Disposable } from '@podman-desktop/api'; -import { type ContainerInfo, type TelemetryLogger, type Webview } from '@podman-desktop/api'; +import type { ContainerInfo, TelemetryLogger, Webview, ContainerProviderConnection } from '@podman-desktop/api'; import type { ContainerRegistry, ContainerStart } from '../../registries/ContainerRegistry'; import { getInferenceType, isTransitioning, LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils'; import { Publisher } from '../../utils/Publisher'; @@ -198,10 +198,19 @@ export class InferenceManager extends Publisher implements Di provider = providers[0]; } + let connection: ContainerProviderConnection | undefined = undefined; + if (config.connection) { + connection = this.podmanConnection.getContainerProviderConnection(config.connection); + } else { + connection = this.podmanConnection.findRunningContainerProviderConnection(); + } + + if (!connection) throw new Error('cannot find running container provider connection'); + // upload models to podman machine if user system is supported config.modelsInfo = await Promise.all( config.modelsInfo.map(modelInfo => - this.modelsManager.uploadModelToPodmanMachine(modelInfo, config.labels).then(path => ({ + this.modelsManager.uploadModelToPodmanMachine(connection, modelInfo, config.labels).then(path => ({ ...modelInfo, file: { path: dirname(path), diff --git a/packages/backend/src/managers/modelsManager.spec.ts b/packages/backend/src/managers/modelsManager.spec.ts index ade59f35f..0c303d9af 100644 --- a/packages/backend/src/managers/modelsManager.spec.ts +++ b/packages/backend/src/managers/modelsManager.spec.ts @@ -22,7 +22,7 @@ import fs, { type Stats, type PathLike } from 'node:fs'; import path from 'node:path'; import { ModelsManager } from './modelsManager'; import { env, process as coreProcess } from '@podman-desktop/api'; -import type { RunResult, TelemetryLogger, Webview } from '@podman-desktop/api'; +import type { RunResult, TelemetryLogger, Webview, ContainerProviderConnection } from '@podman-desktop/api'; import type { CatalogManager } from './catalogManager'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import * as utils from '../utils/utils'; @@ -31,6 +31,9 @@ import type { CancellationTokenRegistry } from '../registries/CancellationTokenR import * as sha from '../utils/sha'; import type { GGUFParseOutput } from '@huggingface/gguf'; import { gguf } from '@huggingface/gguf'; +import type { PodmanConnection } from './podmanConnection'; +import { VMType } from '@shared/src/models/IPodman'; +import { getPodmanMachineName } from '../utils/podman'; const mocks = vi.hoisted(() => { return { @@ -42,7 +45,6 @@ const mocks = vi.hoisted(() => { getTargetMock: vi.fn(), getDownloaderCompleter: vi.fn(), isCompletionEventMock: vi.fn(), - getFirstRunningMachineNameMock: vi.fn(), getPodmanCliMock: vi.fn(), }; }); @@ -52,8 +54,8 @@ vi.mock('@huggingface/gguf', () => ({ })); vi.mock('../utils/podman', () => ({ - getFirstRunningMachineName: mocks.getFirstRunningMachineNameMock, getPodmanCli: mocks.getPodmanCliMock, + getPodmanMachineName: vi.fn(), })); vi.mock('@podman-desktop/api', () => { @@ -92,6 +94,10 @@ vi.mock('../utils/downloader', () => ({ }, })); +const podmanConnectionMock = { + getContainerProviderConnections: vi.fn(), +} as unknown as PodmanConnection; + const cancellationTokenRegistryMock = { createCancellationTokenSource: vi.fn(), } as unknown as CancellationTokenRegistry; @@ -183,6 +189,7 @@ test('getModelsInfo should get models in local directory', async () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); await manager.loadLocalModels(); @@ -232,6 +239,7 @@ test('getModelsInfo should return an empty array if the models folder does not e telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); manager.getLocalModelsFromDisk(); @@ -272,6 +280,7 @@ test('getLocalModelsFromDisk should return undefined Date and size when stat fai telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); await manager.loadLocalModels(); @@ -330,6 +339,7 @@ test('getLocalModelsFromDisk should skip folders containing tmp files', async () telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); await manager.loadLocalModels(); @@ -370,6 +380,7 @@ test('loadLocalModels should post a message with the message on disk and on cata telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); await manager.loadLocalModels(); @@ -420,6 +431,7 @@ test('deleteModel deletes the model folder', async () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); await manager.loadLocalModels(); @@ -484,6 +496,7 @@ describe('deleting models', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); await manager.loadLocalModels(); @@ -550,6 +563,7 @@ describe('deleting models', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); await manager.loadLocalModels(); await manager.deleteModel('model-id-1'); @@ -557,14 +571,36 @@ describe('deleting models', () => { expect(removeUserModelMock).toBeCalledWith('model-id-1'); }); - test('deleting on windows should check if models is uploaded', async () => { - vi.mocked(env).isWindows = true; - vi.mocked(coreProcess.exec).mockResolvedValue({} as unknown as RunResult); - mocks.getFirstRunningMachineNameMock.mockReturnValue('dummyMachine'); + test('deleting on windows should check for all connections', async () => { + vi.mocked(coreProcess.exec).mockResolvedValue({} as RunResult); mocks.getPodmanCliMock.mockReturnValue('dummyCli'); + vi.mocked(env).isWindows = true; + const connections: ContainerProviderConnection[] = [ + { + name: 'Machine 1', + type: 'podman', + vmType: VMType.HYPERV, + endpoint: { + socketPath: '', + }, + status: () => 'started', + }, + { + name: 'Machine 2', + type: 'podman', + vmType: VMType.WSL, + endpoint: { + socketPath: '', + }, + status: () => 'started', + }, + ]; + vi.mocked(podmanConnectionMock.getContainerProviderConnections).mockReturnValue(connections); + vi.mocked(getPodmanMachineName).mockReturnValue('machine-2'); const rmSpy = vi.spyOn(fs.promises, 'rm'); rmSpy.mockResolvedValue(undefined); + const manager = new ModelsManager( '/home/user/aistudio', { @@ -587,62 +623,24 @@ describe('deleting models', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); await manager.loadLocalModels(); + // delete the model await manager.deleteModel('model-id-1'); - expect(coreProcess.exec).toHaveBeenNthCalledWith(1, 'dummyCli', [ - 'machine', - 'ssh', - 'dummyMachine', - 'stat', - '/home/user/ai-lab/models/dummyFile', - ]); - expect(coreProcess.exec).toHaveBeenNthCalledWith(2, 'dummyCli', [ + + expect(podmanConnectionMock.getContainerProviderConnections).toHaveBeenCalledOnce(); + + expect(coreProcess.exec).toHaveBeenCalledWith('dummyCli', [ 'machine', 'ssh', - 'dummyMachine', + 'machine-2', 'rm', '-f', '/home/user/ai-lab/models/dummyFile', ]); }); - - test('deleting on windows should check if models is uploaded', async () => { - vi.mocked(env).isWindows = false; - - const rmSpy = vi.spyOn(fs.promises, 'rm'); - rmSpy.mockResolvedValue(undefined); - const manager = new ModelsManager( - '/home/user/aistudio', - { - postMessage: vi.fn().mockResolvedValue(undefined), - } as unknown as Webview, - { - getModels: () => { - return [ - { - id: 'model-id-1', - url: 'model-url', - file: { - file: 'dummyFile', - path: 'dummyPath', - }, - }, - ] as ModelInfo[]; - }, - } as CatalogManager, - telemetryLogger, - taskRegistry, - cancellationTokenRegistryMock, - ); - - await manager.loadLocalModels(); - await manager.deleteModel('model-id-1'); - expect(coreProcess.exec).not.toHaveBeenCalled(); - expect(mocks.getFirstRunningMachineNameMock).not.toHaveBeenCalled(); - expect(mocks.getPodmanCliMock).not.toHaveBeenCalled(); - }); }); describe('downloadModel', () => { @@ -659,6 +657,7 @@ describe('downloadModel', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); @@ -693,6 +692,7 @@ describe('downloadModel', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); const updateTaskMock = vi.spyOn(taskRegistry, 'updateTask'); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true); @@ -724,6 +724,7 @@ describe('downloadModel', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); vi.spyOn(taskRegistry, 'updateTask'); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(true); @@ -754,6 +755,7 @@ describe('downloadModel', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); @@ -790,6 +792,7 @@ describe('downloadModel', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); vi.spyOn(manager, 'isModelOnDisk').mockReturnValue(false); @@ -837,6 +840,7 @@ describe('getModelMetadata', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); await expect(() => manager.getModelMetadata('unknown-model-id')).rejects.toThrowError( @@ -861,6 +865,7 @@ describe('getModelMetadata', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); @@ -901,6 +906,7 @@ describe('getModelMetadata', () => { telemetryLogger, taskRegistry, cancellationTokenRegistryMock, + podmanConnectionMock, ); manager.init(); diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts index 9c5b16f0a..1efa1d4cc 100644 --- a/packages/backend/src/managers/modelsManager.ts +++ b/packages/backend/src/managers/modelsManager.ts @@ -19,7 +19,7 @@ import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; import fs from 'node:fs'; import * as path from 'node:path'; -import { type Webview, fs as apiFs, type Disposable, env } from '@podman-desktop/api'; +import { type Webview, fs as apiFs, type Disposable, env, type ContainerProviderConnection } from '@podman-desktop/api'; import { Messages } from '@shared/Messages'; import type { CatalogManager } from './catalogManager'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; @@ -31,11 +31,13 @@ import type { BaseEvent } from '../models/baseEvent'; import { isCompletionEvent, isProgressEvent } from '../models/baseEvent'; import { Uploader } from '../utils/uploader'; import { deleteRemoteModel, getLocalModelFile, isModelUploaded } from '../utils/modelsUtils'; -import { getFirstRunningMachineName } from '../utils/podman'; +import { getPodmanMachineName } from '../utils/podman'; import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry'; import { hasValidSha } from '../utils/sha'; import type { GGUFParseOutput } from '@huggingface/gguf'; import { gguf } from '@huggingface/gguf'; +import type { PodmanConnection } from './podmanConnection'; +import { VMType } from '@shared/src/models/IPodman'; export class ModelsManager implements Disposable { #models: Map; @@ -51,6 +53,7 @@ export class ModelsManager implements Disposable { private telemetry: podmanDesktopApi.TelemetryLogger, private taskRegistry: TaskRegistry, private cancellationTokenRegistry: CancellationTokenRegistry, + private podmanConnection: PodmanConnection, ) { this.#models = new Map(); this.#disposables = []; @@ -219,17 +222,22 @@ export class ModelsManager implements Disposable { return; } - const machineName = getFirstRunningMachineName(); - if (!machineName) { - console.warn('No podman machine is running'); - return; - } + // get all container provider connections + const connections = this.podmanConnection.getContainerProviderConnections(); - // check if model already loaded on the podman machine - const existsRemote = await isModelUploaded(machineName, modelInfo); - if (!existsRemote) return; + // iterate over all connections + for (const connection of connections) { + // ignore non-wsl machines + if (connection.vmType !== VMType.WSL) continue; + // Get the corresponding machine name + const machineName = getPodmanMachineName(connection); - return deleteRemoteModel(machineName, modelInfo); + // check if model already loaded on the podman machine + const existsRemote = await isModelUploaded(machineName, modelInfo); + if (!existsRemote) return; + + await deleteRemoteModel(machineName, modelInfo); + } } /** @@ -412,13 +420,18 @@ export class ModelsManager implements Disposable { return downloader.getTarget(); } - async uploadModelToPodmanMachine(model: ModelInfo, labels?: { [key: string]: string }): Promise { - this.taskRegistry.createTask(`Copying model ${model.name} to Podman Machine`, 'loading', { + async uploadModelToPodmanMachine( + connection: ContainerProviderConnection, + model: ModelInfo, + labels?: { [key: string]: string }, + ): Promise { + this.taskRegistry.createTask(`Copying model ${model.name} to ${connection.name}`, 'loading', { ...labels, 'model-uploading': model.id, + connection: connection.name, }); - const uploader = new Uploader(model); + const uploader = new Uploader(connection, model); uploader.onEvent(event => this.onDownloadUploadEvent(event, 'upload'), this); // perform download diff --git a/packages/backend/src/managers/podmanConnection.spec.ts b/packages/backend/src/managers/podmanConnection.spec.ts index 9c567479d..0314c1e0a 100644 --- a/packages/backend/src/managers/podmanConnection.spec.ts +++ b/packages/backend/src/managers/podmanConnection.spec.ts @@ -20,6 +20,7 @@ import { beforeEach, describe, expect, test, vi } from 'vitest'; import { PodmanConnection } from './podmanConnection'; import type { ContainerProviderConnection, + ProviderContainerConnection, ProviderEvent, RegisterContainerConnectionEvent, RunResult, @@ -106,20 +107,18 @@ describe('podman connection initialization', () => { test('init should fetch all container connections', () => { const statusMock = vi.fn().mockReturnValue('started'); - - vi.mocked(provider.getContainerConnections).mockReturnValue([ - { - connection: { - type: 'podman', - status: statusMock, - name: 'Podman Machine', - endpoint: { - socketPath: './socket-path', - }, + const providerContainerConnection: ProviderContainerConnection = { + connection: { + type: 'podman', + status: statusMock, + name: 'Podman Machine', + endpoint: { + socketPath: './socket-path', }, - providerId: 'podman', }, - ]); + providerId: 'podman', + }; + vi.mocked(provider.getContainerConnections).mockReturnValue([providerContainerConnection]); const manager = new PodmanConnection(webviewMock); manager.init(); @@ -134,6 +133,7 @@ describe('podman connection initialization', () => { }, ]); + expect(manager.getContainerProviderConnections()).toStrictEqual([providerContainerConnection.connection]); expect(statusMock).toHaveBeenCalled(); }); }); diff --git a/packages/backend/src/managers/podmanConnection.ts b/packages/backend/src/managers/podmanConnection.ts index 552c3d782..f2d8adf02 100644 --- a/packages/backend/src/managers/podmanConnection.ts +++ b/packages/backend/src/managers/podmanConnection.ts @@ -54,6 +54,10 @@ export class PodmanConnection extends Publisher connection.name === mConnection.name, ); - if (!output) throw new Error(`no container provider connection found for connection name ${name}`); + if (!output) throw new Error(`no container provider connection found for connection name ${connection.name}`); return output; } diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index 63488630f..924eec75e 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -207,6 +207,7 @@ export class Studio { this.#telemetry, this.#taskRegistry, this.#cancellationTokenRegistry, + this.#podmanConnection, ); this.#modelsManager.init(); this.#extensionContext.subscriptions.push(this.#modelsManager); diff --git a/packages/backend/src/utils/podman.spec.ts b/packages/backend/src/utils/podman.spec.ts index bac4bbe88..41e8ef862 100644 --- a/packages/backend/src/utils/podman.spec.ts +++ b/packages/backend/src/utils/podman.spec.ts @@ -81,141 +81,6 @@ describe('getPodmanCli', () => { }); }); -describe('getFirstRunningMachineName', () => { - test('return machine name if connection name does contain default Podman Machine name', () => { - mocks.getContainerConnectionsMock.mockReturnValue([ - { - connection: { - name: 'Podman Machine', - status: () => 'started', - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman', - }, - ]); - const machineName = utils.getFirstRunningMachineName(); - expect(machineName).equals('podman-machine-default'); - }); - test('return machine name if connection name does contain custom Podman Machine name', () => { - mocks.getContainerConnectionsMock.mockReturnValue([ - { - connection: { - name: 'Podman Machine test', - status: () => 'started', - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman', - }, - ]); - const machineName = utils.getFirstRunningMachineName(); - expect(machineName).equals('podman-machine-test'); - }); - test('return machine name if connection name does not contain Podman Machine', () => { - mocks.getContainerConnectionsMock.mockReturnValue([ - { - connection: { - name: 'test', - status: () => 'started', - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman', - }, - ]); - const machineName = utils.getFirstRunningMachineName(); - expect(machineName).equals('test'); - }); - test('return undefined if there is no running connection', () => { - mocks.getContainerConnectionsMock.mockReturnValue([ - { - connection: { - name: 'machine', - status: () => 'stopped', - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman', - }, - ]); - const machineName = utils.getFirstRunningMachineName(); - expect(machineName).toBeUndefined(); - }); -}); - -describe('getFirstRunningPodmanConnection', () => { - test('should return undefined if failing at retrieving connection', async () => { - mocks.getConfigurationMock.mockRejectedValue('error'); - const result = utils.getFirstRunningPodmanConnection(); - expect(result).toBeUndefined(); - }); - test('should return undefined if default podman machine is not running', async () => { - mocks.getContainerConnectionsMock.mockReturnValue([ - { - connection: { - name: 'machine', - status: () => 'stopped', - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman', - }, - { - connection: { - name: 'machine2', - status: () => 'stopped', - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman2', - }, - ]); - const result = utils.getFirstRunningPodmanConnection(); - expect(result).toBeUndefined(); - }); - test('should return default running podman connection', async () => { - mocks.getContainerConnectionsMock.mockReturnValue([ - { - connection: { - name: 'machine', - status: () => 'stopped', - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman', - }, - { - connection: { - name: 'machine2', - status: () => 'started', - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman2', - }, - ]); - const result = utils.getFirstRunningPodmanConnection(); - expect(result?.connection.name).equal('machine2'); - }); -}); - describe('getPodmanConnection', () => { test('throw error if there is no podman connection with name', () => { mocks.getContainerConnectionsMock.mockReturnValue([ diff --git a/packages/backend/src/utils/podman.ts b/packages/backend/src/utils/podman.ts index 7cdd18d07..d66718ca8 100644 --- a/packages/backend/src/utils/podman.ts +++ b/packages/backend/src/utils/podman.ts @@ -15,7 +15,7 @@ * * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ -import type { ProviderContainerConnection } from '@podman-desktop/api'; +import type { ContainerProviderConnection, ProviderContainerConnection } from '@podman-desktop/api'; import { configuration, env, provider } from '@podman-desktop/api'; export const MIN_CPUS_VALUE = 4; @@ -52,49 +52,20 @@ export function getCustomBinaryPath(): string | undefined { } /** - * @deprecated uses {@link PodmanConnection.findRunningContainerProviderConnection} + * In the ${link ContainerProviderConnection.name} property the name is not usage, and we need to transform it + * @param connection */ -export function getFirstRunningMachineName(): string | undefined { - // the name of the podman connection is the name of the podman machine updated to make it more user friendly, - // so to retrieve the real machine name we need to revert the process - - // podman-machine-default -> Podman Machine - // podman-machine-{name} -> Podman Machine {name} - // {name} -> {name} - try { - const runningConnection = getFirstRunningPodmanConnection(); - if (!runningConnection) return undefined; - const runningConnectionName = runningConnection.connection.name; - if (runningConnectionName.startsWith('Podman Machine')) { - const machineName = runningConnectionName.replace(/Podman Machine\s*/, 'podman-machine-'); - if (machineName.endsWith('-')) { - return `${machineName}default`; - } - return machineName; - } else { - return runningConnectionName; +export function getPodmanMachineName(connection: ContainerProviderConnection): string { + const runningConnectionName = connection.name; + if (runningConnectionName.startsWith('Podman Machine')) { + const machineName = runningConnectionName.replace(/Podman Machine\s*/, 'podman-machine-'); + if (machineName.endsWith('-')) { + return `${machineName}default`; } - } catch (e) { - console.log(e); - } - - return undefined; -} - -/** - * @deprecated uses {@link PodmanConnection.findRunningContainerProviderConnection} - */ -export function getFirstRunningPodmanConnection(): ProviderContainerConnection | undefined { - let engine: ProviderContainerConnection | undefined = undefined; - try { - engine = provider - .getContainerConnections() - .filter(connection => connection.connection.type === 'podman') - .find(connection => connection.connection.status() === 'started'); - } catch (e) { - console.log(e); + return machineName; + } else { + return runningConnectionName; } - return engine; } /** diff --git a/packages/backend/src/utils/uploader.spec.ts b/packages/backend/src/utils/uploader.spec.ts index f6e47fa5a..a4d87ce83 100644 --- a/packages/backend/src/utils/uploader.spec.ts +++ b/packages/backend/src/utils/uploader.spec.ts @@ -22,12 +22,8 @@ import * as podmanDesktopApi from '@podman-desktop/api'; import { beforeEach } from 'node:test'; import { Uploader } from './uploader'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; - -const mocks = vi.hoisted(() => { - return { - execMock: vi.fn(), - }; -}); +import type { ContainerProviderConnection } from '@podman-desktop/api'; +import { VMType } from '@shared/src/models/IPodman'; vi.mock('@podman-desktop/api', async () => { return { @@ -35,7 +31,7 @@ vi.mock('@podman-desktop/api', async () => { isWindows: false, }, process: { - exec: mocks.execMock, + exec: vi.fn(), }, EventEmitter: vi.fn().mockImplementation(() => { return { @@ -44,7 +40,18 @@ vi.mock('@podman-desktop/api', async () => { }), }; }); -const uploader = new Uploader({ + +const connectionMock: ContainerProviderConnection = { + name: 'machine2', + type: 'podman', + status: () => 'started', + vmType: VMType.WSL, + endpoint: { + socketPath: 'socket.sock', + }, +}; + +const uploader = new Uploader(connectionMock, { id: 'dummyModelId', file: { file: 'dummyFile.guff', diff --git a/packages/backend/src/utils/uploader.ts b/packages/backend/src/utils/uploader.ts index ceaadc2cf..fbb13a872 100644 --- a/packages/backend/src/utils/uploader.ts +++ b/packages/backend/src/utils/uploader.ts @@ -16,20 +16,22 @@ * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ -import { EventEmitter, type Event } from '@podman-desktop/api'; +import { EventEmitter, type Event, type ContainerProviderConnection } from '@podman-desktop/api'; import { WSLUploader } from '../workers/uploader/WSLUploader'; import { getDurationSecondsSince } from './utils'; import type { CompletionEvent, BaseEvent } from '../models/baseEvent'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import { getLocalModelFile } from './modelsUtils'; import type { IWorker } from '../workers/IWorker'; +import type { UploaderOptions } from '../workers/uploader/UploaderOptions'; export class Uploader { readonly #_onEvent = new EventEmitter(); readonly onEvent: Event = this.#_onEvent.event; - readonly #workers: IWorker[] = []; + readonly #workers: IWorker[] = []; constructor( + private connection: ContainerProviderConnection, private modelInfo: ModelInfo, private abortSignal?: AbortSignal, ) { @@ -44,7 +46,7 @@ export class Uploader { */ async perform(id: string): Promise { // Find the uploader for the current operating system - const worker: IWorker | undefined = this.#workers.find(w => w.enabled()); + const worker: IWorker | undefined = this.#workers.find(w => w.enabled()); // If none are found, we return the current path if (worker === undefined) { @@ -62,7 +64,10 @@ export class Uploader { // measure performance const startTime = performance.now(); // get new path - const remotePath = await worker.perform(this.modelInfo); + const remotePath = await worker.perform({ + connection: this.connection, + model: this.modelInfo, + }); // compute full time const durationSeconds = getDurationSecondsSince(startTime); // fire events diff --git a/packages/backend/src/workers/uploader/UploaderOptions.ts b/packages/backend/src/workers/uploader/UploaderOptions.ts new file mode 100644 index 000000000..c4eaea963 --- /dev/null +++ b/packages/backend/src/workers/uploader/UploaderOptions.ts @@ -0,0 +1,25 @@ +/********************************************************************** + * Copyright (C) 2024 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ + +import type { ModelInfo } from '@shared/src/models/IModelInfo'; +import type { ContainerProviderConnection } from '@podman-desktop/api'; + +export interface UploaderOptions { + model: ModelInfo; + connection: ContainerProviderConnection; +} diff --git a/packages/backend/src/workers/uploader/WSLUploader.spec.ts b/packages/backend/src/workers/uploader/WSLUploader.spec.ts index d9a3b0a11..3f07ac147 100644 --- a/packages/backend/src/workers/uploader/WSLUploader.spec.ts +++ b/packages/backend/src/workers/uploader/WSLUploader.spec.ts @@ -16,82 +16,104 @@ * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ -import { expect, test, describe, vi } from 'vitest'; +import { expect, test, describe, vi, beforeEach } from 'vitest'; import { WSLUploader } from './WSLUploader'; -import * as podmanDesktopApi from '@podman-desktop/api'; -import * as utils from '../../utils/podman'; -import { beforeEach } from 'node:test'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; - -const mocks = vi.hoisted(() => { - return { - execMock: vi.fn(), - }; -}); +import { configuration, env, process, type ContainerProviderConnection, type RunResult } from '@podman-desktop/api'; +import { VMType } from '@shared/src/models/IPodman'; vi.mock('@podman-desktop/api', () => ({ env: { isWindows: false, }, process: { - exec: mocks.execMock, + exec: vi.fn(), + }, + configuration: { + getConfiguration: vi.fn(), }, })); +const connectionMock: ContainerProviderConnection = { + name: 'machine2', + type: 'podman', + status: () => 'started', + vmType: VMType.WSL, + endpoint: { + socketPath: 'socket.sock', + }, +}; + const wslUploader = new WSLUploader(); beforeEach(() => { vi.resetAllMocks(); + + vi.mocked(configuration.getConfiguration).mockReturnValue({ + get: () => 'podman.exe', + has: vi.fn(), + update: vi.fn(), + }); }); describe('canUpload', () => { test('should return false if system is not windows', () => { - vi.mocked(podmanDesktopApi.env).isWindows = false; + vi.mocked(env).isWindows = false; const result = wslUploader.enabled(); expect(result).toBeFalsy(); }); test('should return true if system is windows', () => { - vi.mocked(podmanDesktopApi.env).isWindows = true; + vi.mocked(env).isWindows = true; const result = wslUploader.enabled(); expect(result).toBeTruthy(); }); }); describe('upload', () => { - vi.spyOn(utils, 'getPodmanCli').mockReturnValue('podman'); - vi.spyOn(utils, 'getFirstRunningPodmanConnection').mockResolvedValue({ - connection: { - name: 'test', - status: vi.fn(), - endpoint: { - socketPath: '/endpoint.sock', - }, - type: 'podman', - }, - providerId: 'podman', - }); test('throw if localpath is not defined', async () => { await expect( wslUploader.perform({ - file: undefined, - } as unknown as ModelInfo), + connection: connectionMock, + model: { + file: undefined, + } as unknown as ModelInfo, + }), ).rejects.toThrowError('model is not available locally.'); }); + + test('non-WSL VMType should return the original path', async () => { + vi.mocked(process.exec).mockRejectedValueOnce('error'); + const result = await wslUploader.perform({ + connection: { + ...connectionMock, + vmType: VMType.UNKNOWN, + }, + model: { + id: 'dummyId', + file: { path: 'C:\\Users\\podman\\folder', file: 'dummy.guff' }, + } as unknown as ModelInfo, + }); + expect(process.exec).not.toHaveBeenCalled(); + expect(result.startsWith('C:\\Users\\podman\\folder')).toBeTruthy(); + }); + test('copy model if not exists on podman machine', async () => { - mocks.execMock.mockRejectedValueOnce('error'); - vi.spyOn(utils, 'getFirstRunningMachineName').mockReturnValue('machine2'); + vi.mocked(process.exec).mockRejectedValueOnce('error'); await wslUploader.perform({ - id: 'dummyId', - file: { path: 'C:\\Users\\podman\\folder', file: 'dummy.guff' }, - } as unknown as ModelInfo); - expect(mocks.execMock).toBeCalledWith('podman', [ + connection: connectionMock, + model: { + id: 'dummyId', + file: { path: 'C:\\Users\\podman\\folder', file: 'dummy.guff' }, + } as unknown as ModelInfo, + }); + expect(process.exec).toBeCalledWith('podman.exe', [ 'machine', 'ssh', 'machine2', 'stat', '/home/user/ai-lab/models/dummy.guff', ]); - expect(mocks.execMock).toBeCalledWith('podman', [ + expect(process.exec).toBeCalledWith('podman.exe', [ 'machine', 'ssh', 'machine2', @@ -99,7 +121,7 @@ describe('upload', () => { '-p', '/home/user/ai-lab/models/', ]); - expect(mocks.execMock).toBeCalledWith('podman', [ + expect(process.exec).toBeCalledWith('podman.exe', [ 'machine', 'ssh', 'machine2', @@ -107,23 +129,24 @@ describe('upload', () => { '/mnt/c/Users/podman/folder/dummy.guff', '/home/user/ai-lab/models/dummy.guff', ]); - mocks.execMock.mockClear(); }); + test('do not copy model if it exists on podman machine', async () => { - mocks.execMock.mockResolvedValue(''); - vi.spyOn(utils, 'getFirstRunningMachineName').mockReturnValue('machine2'); + vi.mocked(process.exec).mockResolvedValue({} as RunResult); await wslUploader.perform({ - id: 'dummyId', - file: { path: 'C:\\Users\\podman\\folder', file: 'dummy.guff' }, - } as unknown as ModelInfo); - expect(mocks.execMock).toBeCalledWith('podman', [ + connection: connectionMock, + model: { + id: 'dummyId', + file: { path: 'C:\\Users\\podman\\folder', file: 'dummy.guff' }, + } as unknown as ModelInfo, + }); + expect(process.exec).toBeCalledWith('podman.exe', [ 'machine', 'ssh', 'machine2', 'stat', '/home/user/ai-lab/models/dummy.guff', ]); - expect(mocks.execMock).toBeCalledTimes(1); - mocks.execMock.mockClear(); + expect(process.exec).toBeCalledTimes(1); }); }); diff --git a/packages/backend/src/workers/uploader/WSLUploader.ts b/packages/backend/src/workers/uploader/WSLUploader.ts index 1794c367b..d8c6ec72e 100644 --- a/packages/backend/src/workers/uploader/WSLUploader.ts +++ b/packages/backend/src/workers/uploader/WSLUploader.ts @@ -17,28 +17,33 @@ ***********************************************************************/ import * as podmanDesktopApi from '@podman-desktop/api'; -import { getFirstRunningMachineName, getPodmanCli } from '../../utils/podman'; +import { getPodmanCli, getPodmanMachineName } from '../../utils/podman'; import { getLocalModelFile, getRemoteModelFile, isModelUploaded, MACHINE_BASE_FOLDER } from '../../utils/modelsUtils'; -import type { ModelInfo } from '@shared/src/models/IModelInfo'; import { WindowsWorker } from '../WindowsWorker'; +import { VMType } from '@shared/src/models/IPodman'; +import type { UploaderOptions } from './UploaderOptions'; -export class WSLUploader extends WindowsWorker { - async perform(modelInfo: ModelInfo): Promise { - const localPath = getLocalModelFile(modelInfo); +export class WSLUploader extends WindowsWorker { + async perform(options: UploaderOptions): Promise { + const localPath = getLocalModelFile(options.model); + + // ensure the connection type is WSL + if (options.connection.vmType !== VMType.WSL) { + console.warn('cannot upload on non-WSL machine'); + return localPath; + } + + // the connection name cannot be used as it is + const machineName = getPodmanMachineName(options.connection); const driveLetter = localPath.charAt(0); const convertToMntPath = localPath .replace(`${driveLetter}:\\`, `/mnt/${driveLetter.toLowerCase()}/`) .replace(/\\/g, '/'); - const machineName = getFirstRunningMachineName(); - - if (!machineName) { - throw new Error('No podman machine is running'); - } // check if model already loaded on the podman machine - const existsRemote = await isModelUploaded(machineName, modelInfo); - const remoteFile = getRemoteModelFile(modelInfo); + const existsRemote = await isModelUploaded(machineName, options.model); + const remoteFile = getRemoteModelFile(options.model); // if not exists remotely it copies it from the local path if (!existsRemote) {