diff --git a/packages/backend/src/managers/playground.spec.ts b/packages/backend/src/managers/playground.spec.ts index 7864954cc..b2c65982f 100644 --- a/packages/backend/src/managers/playground.spec.ts +++ b/packages/backend/src/managers/playground.spec.ts @@ -20,22 +20,20 @@ import { beforeEach, expect, test, vi } from 'vitest'; import { PlayGroundManager } from './playground'; import type { ImageInfo, Webview } from '@podman-desktop/api'; import type { ContainerRegistry } from '../registries/ContainerRegistry'; +import type { PodmanConnection } from './podmanConnection'; const mocks = vi.hoisted(() => ({ postMessage: vi.fn(), - getContainerConnections: vi.fn(), pullImage: vi.fn(), createContainer: vi.fn(), stopContainer: vi.fn(), getFreePort: vi.fn(), containerRegistrySubscribeMock: vi.fn(), + getConnection: vi.fn(), })); vi.mock('@podman-desktop/api', async () => { return { - provider: { - getContainerConnections: mocks.getContainerConnections, - }, containerEngine: { pullImage: mocks.pullImage, createContainer: mocks.createContainer, @@ -64,12 +62,14 @@ beforeEach(() => { postMessage: mocks.postMessage, } as unknown as Webview, containerRegistryMock, + { + getConnection: mocks.getConnection, + } as unknown as PodmanConnection, ); }); test('startPlayground should fail if no provider', async () => { mocks.postMessage.mockResolvedValue(undefined); - mocks.getContainerConnections.mockReturnValue([]); await expect(manager.startPlayground('model1', '/path/to/model')).rejects.toThrowError( 'Unable to find an engine to start playground', ); @@ -77,14 +77,12 @@ test('startPlayground should fail if no provider', async () => { test('startPlayground should download image if not present then create container', async () => { mocks.postMessage.mockResolvedValue(undefined); - mocks.getContainerConnections.mockReturnValue([ - { - connection: { - type: 'podman', - status: () => 'started', - }, + mocks.getConnection.mockReturnValue({ + connection: { + type: 'podman', + status: () => 'started', }, - ]); + }); vi.spyOn(manager, 'selectImage') .mockResolvedValueOnce(undefined) .mockResolvedValueOnce({ @@ -135,14 +133,12 @@ test('stopPlayground should fail if no playground is running', async () => { test('stopPlayground should stop a started playground', async () => { mocks.postMessage.mockResolvedValue(undefined); - mocks.getContainerConnections.mockReturnValue([ - { - connection: { - type: 'podman', - status: () => 'started', - }, + mocks.getConnection.mockReturnValue({ + connection: { + type: 'podman', + status: () => 'started', }, - ]); + }); vi.spyOn(manager, 'selectImage').mockResolvedValue({ Id: 'image1', engineId: 'engine1', diff --git a/packages/backend/src/managers/playground.ts b/packages/backend/src/managers/playground.ts index 7ad72a20f..52b4de11a 100644 --- a/packages/backend/src/managers/playground.ts +++ b/packages/backend/src/managers/playground.ts @@ -16,14 +16,7 @@ * SPDX-License-Identifier: Apache-2.0 ***********************************************************************/ -import { - provider, - containerEngine, - type Webview, - type ProviderContainerConnection, - type ImageInfo, - type RegisterContainerConnectionEvent, -} from '@podman-desktop/api'; +import { containerEngine, type Webview, type ImageInfo } from '@podman-desktop/api'; import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; import type { ModelResponse } from '@shared/src/models/IModelResponse'; @@ -34,6 +27,7 @@ import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; import { MSG_NEW_PLAYGROUND_QUERIES_STATE, MSG_PLAYGROUNDS_STATE_UPDATE } from '@shared/Messages'; import type { PlaygroundState, PlaygroundStatus } from '@shared/src/models/IPlaygroundState'; import type { ContainerRegistry } from '../registries/ContainerRegistry'; +import type { PodmanConnection } from './podmanConnection'; const LABEL_MODEL_ID = 'ai-studio-model-id'; const LABEL_MODEL_PORT = 'ai-studio-model-port'; @@ -41,14 +35,6 @@ const LABEL_MODEL_PORT = 'ai-studio-model-port'; // TODO: this should not be hardcoded const PLAYGROUND_IMAGE = 'quay.io/bootsy/playground:v0'; -function findFirstProvider(): ProviderContainerConnection | undefined { - const engines = provider - .getContainerConnections() - .filter(connection => connection.connection.type === 'podman') - .filter(connection => connection.connection.status() === 'started'); - return engines.length > 0 ? engines[0] : undefined; -} - export class PlayGroundManager { private queryIdCounter = 0; @@ -59,49 +45,48 @@ export class PlayGroundManager { constructor( private webview: Webview, private containerRegistry: ContainerRegistry, + private podmanConnection: PodmanConnection, ) { this.playgrounds = new Map(); this.queries = new Map(); } - async adoptRunningPlaygrounds() { - provider.onDidRegisterContainerConnection(async (e: RegisterContainerConnectionEvent) => { - if (e.connection.type === 'podman' && e.connection.status() === 'started') { - await this.doAdoptRunningPlaygrounds(); - } + adoptRunningPlaygrounds() { + this.podmanConnection.startupSubscribe(() => { + containerEngine + .listContainers() + .then(containers => { + const playgroundContainers = containers.filter( + c => LABEL_MODEL_ID in c.Labels && LABEL_MODEL_PORT in c.Labels && c.State === 'running', + ); + for (const containerToAdopt of playgroundContainers) { + const modelId = containerToAdopt.Labels[LABEL_MODEL_ID]; + if (this.playgrounds.has(modelId)) { + continue; + } + const modelPort = parseInt(containerToAdopt.Labels[LABEL_MODEL_PORT], 10); + if (isNaN(modelPort)) { + continue; + } + const state: PlaygroundState = { + modelId, + status: 'running', + container: { + containerId: containerToAdopt.Id, + engineId: containerToAdopt.engineId, + port: modelPort, + }, + }; + this.updatePlaygroundState(modelId, state); + } + }) + .catch((err: unknown) => { + console.error('error during adoption of existing playground containers', err); + }); }); - // Do it now in case providers are already registered - await this.doAdoptRunningPlaygrounds(); - } - - private async doAdoptRunningPlaygrounds() { - const containers = await containerEngine.listContainers(); - const playgroundContainers = containers.filter( - c => LABEL_MODEL_ID in c.Labels && LABEL_MODEL_PORT in c.Labels && c.State === 'running', - ); - for (const containerToAdopt of playgroundContainers) { - const modelId = containerToAdopt.Labels[LABEL_MODEL_ID]; - if (this.playgrounds.has(modelId)) { - continue; - } - const modelPort = parseInt(containerToAdopt.Labels[LABEL_MODEL_PORT], 10); - if (isNaN(modelPort)) { - continue; - } - const state: PlaygroundState = { - modelId, - status: 'running', - container: { - containerId: containerToAdopt.Id, - engineId: containerToAdopt.engineId, - port: modelPort, - }, - }; - this.updatePlaygroundState(modelId, state); - } } - async selectImage(connection: ProviderContainerConnection, image: string): Promise { + async selectImage(image: string): Promise { const images = (await containerEngine.listImages()).filter(im => im.RepoTags?.some(tag => tag === image)); return images.length > 0 ? images[0] : undefined; } @@ -145,16 +130,16 @@ export class PlayGroundManager { this.setPlaygroundStatus(modelId, 'starting'); - const connection = findFirstProvider(); + const connection = this.podmanConnection.getConnection(); if (!connection) { this.setPlaygroundStatus(modelId, 'error'); throw new Error('Unable to find an engine to start playground'); } - let image = await this.selectImage(connection, PLAYGROUND_IMAGE); + let image = await this.selectImage(PLAYGROUND_IMAGE); if (!image) { - await containerEngine.pullImage(connection.connection, PLAYGROUND_IMAGE, () => {}); - image = await this.selectImage(connection, PLAYGROUND_IMAGE); + await containerEngine.pullImage(connection, PLAYGROUND_IMAGE, () => {}); + image = await this.selectImage(PLAYGROUND_IMAGE); if (!image) { this.setPlaygroundStatus(modelId, 'error'); throw new Error(`Unable to find ${PLAYGROUND_IMAGE} image`); diff --git a/packages/backend/src/managers/podmanConnection.ts b/packages/backend/src/managers/podmanConnection.ts new file mode 100644 index 000000000..1d275ffd4 --- /dev/null +++ b/packages/backend/src/managers/podmanConnection.ts @@ -0,0 +1,73 @@ +/********************************************************************** + * Copyright (C) 2024 Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ***********************************************************************/ + +import { type ContainerProviderConnection, type RegisterContainerConnectionEvent, provider } from '@podman-desktop/api'; + +type startupHandle = () => void; + +export class PodmanConnection { + #firstFound = false; + #connection: ContainerProviderConnection | undefined = undefined; + + #toExecuteAtStartup: startupHandle[] = []; + + init(): void { + // In case the extension has not yet registered, we listen for new registrations + // and retain the first started podman provider + const disposable = provider.onDidRegisterContainerConnection((e: RegisterContainerConnectionEvent) => { + if (e.connection.type !== 'podman' || e.connection.status() !== 'started') { + return; + } + if (this.#firstFound) { + return; + } + this.#firstFound = true; + this.#connection = e.connection; + for (const f of this.#toExecuteAtStartup) { + f(); + } + this.#toExecuteAtStartup = []; + disposable.dispose(); + }); + + // In case at least one extension has already registered, we get one started podman provider + const engines = provider + .getContainerConnections() + .filter(connection => connection.connection.type === 'podman') + .filter(connection => connection.connection.status() === 'started'); + if (engines.length > 0) { + disposable.dispose(); + this.#firstFound = true; + this.#connection = engines[0].connection; + } + } + + // startupSubscribe registers f to be executed when a podman container provider + // registers, or immediately if already registered + startupSubscribe(f: startupHandle): void { + if (this.#firstFound) { + f(); + } else { + this.#toExecuteAtStartup.push(f); + } + } + + getConnection(): ContainerProviderConnection | undefined { + return this.#connection; + } +} diff --git a/packages/backend/src/studio.spec.ts b/packages/backend/src/studio.spec.ts index 4b7505dd2..75c58ba1d 100644 --- a/packages/backend/src/studio.spec.ts +++ b/packages/backend/src/studio.spec.ts @@ -34,6 +34,7 @@ const studio = new Studio(mockedExtensionContext); const mocks = vi.hoisted(() => ({ listContainers: vi.fn(), + getContainerConnections: vi.fn(), })); vi.mock('@podman-desktop/api', async () => { @@ -56,6 +57,7 @@ vi.mock('@podman-desktop/api', async () => { }, provider: { onDidRegisterContainerConnection: vi.fn(), + getContainerConnections: mocks.getContainerConnections, }, }; }); @@ -75,6 +77,7 @@ afterEach(() => { test('check activate ', async () => { mocks.listContainers.mockReturnValue([]); + mocks.getContainerConnections.mockReturnValue([]); vi.spyOn(fs.promises, 'readFile').mockImplementation(() => { return Promise.resolve(''); }); diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index cfe6c7560..5ebe7194b 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -31,6 +31,7 @@ import path from 'node:path'; import os from 'os'; import fs from 'node:fs'; import { ContainerRegistry } from './registries/ContainerRegistry'; +import { PodmanConnection } from './managers/podmanConnection'; // TODO: Need to be configured export const AI_STUDIO_FOLDER = path.join('podman-desktop', 'ai-studio'); @@ -104,9 +105,11 @@ export class Studio { this.rpcExtension = new RpcExtension(this.#panel.webview); const gitManager = new GitManager(); + + const podmanConnection = new PodmanConnection(); const taskRegistry = new TaskRegistry(); const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry, this.#panel.webview); - this.playgroundManager = new PlayGroundManager(this.#panel.webview, containerRegistry); + this.playgroundManager = new PlayGroundManager(this.#panel.webview, containerRegistry, podmanConnection); // Create catalog manager, responsible for loading the catalog files and watching for changes this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview); this.modelsManager = new ModelsManager(appUserDirectory, this.#panel.webview, this.catalogManager); @@ -128,7 +131,8 @@ export class Studio { await this.catalogManager.loadCatalog(); await this.modelsManager.loadLocalModels(); - await this.playgroundManager.adoptRunningPlaygrounds(); + podmanConnection.init(); + this.playgroundManager.adoptRunningPlaygrounds(); // Register the instance this.rpcExtension.registerInstance(StudioApiImpl, this.studioApi);