Skip to content

Commit

Permalink
create a dedicated class for listening the registration of podman pro…
Browse files Browse the repository at this point in the history
…vider
  • Loading branch information
feloy committed Jan 31, 2024
1 parent 76f6ace commit 2ebd657
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 76 deletions.
34 changes: 15 additions & 19 deletions packages/backend/src/managers/playground.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,20 @@ import { beforeEach, expect, test, vi } from 'vitest';
import { PlayGroundManager } from './playground';
import type { ImageInfo, Webview } from '@podman-desktop/api';
import type { ContainerRegistry } from '../registries/ContainerRegistry';
import type { PodmanConnection } from './podmanConnection';

const mocks = vi.hoisted(() => ({
postMessage: vi.fn(),
getContainerConnections: vi.fn(),
pullImage: vi.fn(),
createContainer: vi.fn(),
stopContainer: vi.fn(),
getFreePort: vi.fn(),
containerRegistrySubscribeMock: vi.fn(),
getConnection: vi.fn(),
}));

vi.mock('@podman-desktop/api', async () => {
return {
provider: {
getContainerConnections: mocks.getContainerConnections,
},
containerEngine: {
pullImage: mocks.pullImage,
createContainer: mocks.createContainer,
Expand Down Expand Up @@ -64,27 +62,27 @@ beforeEach(() => {
postMessage: mocks.postMessage,
} as unknown as Webview,
containerRegistryMock,
{
getConnection: mocks.getConnection,
} as unknown as PodmanConnection,
);
});

test('startPlayground should fail if no provider', async () => {
mocks.postMessage.mockResolvedValue(undefined);
mocks.getContainerConnections.mockReturnValue([]);
await expect(manager.startPlayground('model1', '/path/to/model')).rejects.toThrowError(
'Unable to find an engine to start playground',
);
});

test('startPlayground should download image if not present then create container', async () => {
mocks.postMessage.mockResolvedValue(undefined);
mocks.getContainerConnections.mockReturnValue([
{
connection: {
type: 'podman',
status: () => 'started',
},
mocks.getConnection.mockReturnValue({
connection: {
type: 'podman',
status: () => 'started',
},
]);
});
vi.spyOn(manager, 'selectImage')
.mockResolvedValueOnce(undefined)
.mockResolvedValueOnce({
Expand Down Expand Up @@ -135,14 +133,12 @@ test('stopPlayground should fail if no playground is running', async () => {

test('stopPlayground should stop a started playground', async () => {
mocks.postMessage.mockResolvedValue(undefined);
mocks.getContainerConnections.mockReturnValue([
{
connection: {
type: 'podman',
status: () => 'started',
},
mocks.getConnection.mockReturnValue({
connection: {
type: 'podman',
status: () => 'started',
},
]);
});
vi.spyOn(manager, 'selectImage').mockResolvedValue({
Id: 'image1',
engineId: 'engine1',
Expand Down
95 changes: 40 additions & 55 deletions packages/backend/src/managers/playground.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,7 @@
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

import {
provider,
containerEngine,
type Webview,
type ProviderContainerConnection,
type ImageInfo,
type RegisterContainerConnectionEvent,
} from '@podman-desktop/api';
import { containerEngine, type Webview, type ImageInfo } from '@podman-desktop/api';
import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo';
import type { ModelResponse } from '@shared/src/models/IModelResponse';

Expand All @@ -34,21 +27,14 @@ import type { QueryState } from '@shared/src/models/IPlaygroundQueryState';
import { MSG_NEW_PLAYGROUND_QUERIES_STATE, MSG_PLAYGROUNDS_STATE_UPDATE } from '@shared/Messages';
import type { PlaygroundState, PlaygroundStatus } from '@shared/src/models/IPlaygroundState';
import type { ContainerRegistry } from '../registries/ContainerRegistry';
import type { PodmanConnection } from './podmanConnection';

const LABEL_MODEL_ID = 'ai-studio-model-id';
const LABEL_MODEL_PORT = 'ai-studio-model-port';

// TODO: this should not be hardcoded
const PLAYGROUND_IMAGE = 'quay.io/bootsy/playground:v0';

function findFirstProvider(): ProviderContainerConnection | undefined {
const engines = provider
.getContainerConnections()
.filter(connection => connection.connection.type === 'podman')
.filter(connection => connection.connection.status() === 'started');
return engines.length > 0 ? engines[0] : undefined;
}

export class PlayGroundManager {
private queryIdCounter = 0;

Expand All @@ -59,49 +45,48 @@ export class PlayGroundManager {
constructor(
private webview: Webview,
private containerRegistry: ContainerRegistry,
private podmanConnection: PodmanConnection,
) {
this.playgrounds = new Map<string, PlaygroundState>();
this.queries = new Map<number, QueryState>();
}

async adoptRunningPlaygrounds() {
provider.onDidRegisterContainerConnection(async (e: RegisterContainerConnectionEvent) => {
if (e.connection.type === 'podman' && e.connection.status() === 'started') {
await this.doAdoptRunningPlaygrounds();
}
adoptRunningPlaygrounds() {
this.podmanConnection.startupSubscribe(() => {
containerEngine
.listContainers()
.then(containers => {
const playgroundContainers = containers.filter(
c => LABEL_MODEL_ID in c.Labels && LABEL_MODEL_PORT in c.Labels && c.State === 'running',
);
for (const containerToAdopt of playgroundContainers) {
const modelId = containerToAdopt.Labels[LABEL_MODEL_ID];
if (this.playgrounds.has(modelId)) {
continue;
}
const modelPort = parseInt(containerToAdopt.Labels[LABEL_MODEL_PORT], 10);
if (isNaN(modelPort)) {
continue;
}
const state: PlaygroundState = {
modelId,
status: 'running',
container: {
containerId: containerToAdopt.Id,
engineId: containerToAdopt.engineId,
port: modelPort,
},
};
this.updatePlaygroundState(modelId, state);
}
})
.catch((err: unknown) => {
console.error('error during adoption of existing playground containers', err);
});
});
// Do it now in case providers are already registered
await this.doAdoptRunningPlaygrounds();
}

private async doAdoptRunningPlaygrounds() {
const containers = await containerEngine.listContainers();
const playgroundContainers = containers.filter(
c => LABEL_MODEL_ID in c.Labels && LABEL_MODEL_PORT in c.Labels && c.State === 'running',
);
for (const containerToAdopt of playgroundContainers) {
const modelId = containerToAdopt.Labels[LABEL_MODEL_ID];
if (this.playgrounds.has(modelId)) {
continue;
}
const modelPort = parseInt(containerToAdopt.Labels[LABEL_MODEL_PORT], 10);
if (isNaN(modelPort)) {
continue;
}
const state: PlaygroundState = {
modelId,
status: 'running',
container: {
containerId: containerToAdopt.Id,
engineId: containerToAdopt.engineId,
port: modelPort,
},
};
this.updatePlaygroundState(modelId, state);
}
}

async selectImage(connection: ProviderContainerConnection, image: string): Promise<ImageInfo | undefined> {
async selectImage(image: string): Promise<ImageInfo | undefined> {
const images = (await containerEngine.listImages()).filter(im => im.RepoTags?.some(tag => tag === image));
return images.length > 0 ? images[0] : undefined;
}
Expand Down Expand Up @@ -145,16 +130,16 @@ export class PlayGroundManager {

this.setPlaygroundStatus(modelId, 'starting');

const connection = findFirstProvider();
const connection = this.podmanConnection.getConnection();
if (!connection) {
this.setPlaygroundStatus(modelId, 'error');
throw new Error('Unable to find an engine to start playground');
}

let image = await this.selectImage(connection, PLAYGROUND_IMAGE);
let image = await this.selectImage(PLAYGROUND_IMAGE);
if (!image) {
await containerEngine.pullImage(connection.connection, PLAYGROUND_IMAGE, () => {});
image = await this.selectImage(connection, PLAYGROUND_IMAGE);
await containerEngine.pullImage(connection, PLAYGROUND_IMAGE, () => {});
image = await this.selectImage(PLAYGROUND_IMAGE);
if (!image) {
this.setPlaygroundStatus(modelId, 'error');
throw new Error(`Unable to find ${PLAYGROUND_IMAGE} image`);
Expand Down
73 changes: 73 additions & 0 deletions packages/backend/src/managers/podmanConnection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/**********************************************************************
* Copyright (C) 2024 Red Hat, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

import { type ContainerProviderConnection, type RegisterContainerConnectionEvent, provider } from '@podman-desktop/api';

type startupHandle = () => void;

export class PodmanConnection {
#firstFound = false;
#connection: ContainerProviderConnection | undefined = undefined;

#toExecuteAtStartup: startupHandle[] = [];

init(): void {
// In case the extension has not yet registered, we listen for new registrations
// and retain the first started podman provider
const disposable = provider.onDidRegisterContainerConnection((e: RegisterContainerConnectionEvent) => {
if (e.connection.type !== 'podman' || e.connection.status() !== 'started') {
return;
}
if (this.#firstFound) {
return;
}
this.#firstFound = true;
this.#connection = e.connection;
for (const f of this.#toExecuteAtStartup) {
f();
}
this.#toExecuteAtStartup = [];
disposable.dispose();
});

// In case at least one extension has already registered, we get one started podman provider
const engines = provider
.getContainerConnections()
.filter(connection => connection.connection.type === 'podman')
.filter(connection => connection.connection.status() === 'started');
if (engines.length > 0) {
disposable.dispose();
this.#firstFound = true;
this.#connection = engines[0].connection;
}
}

// startupSubscribe registers f to be executed when a podman container provider
// registers, or immediately if already registered
startupSubscribe(f: startupHandle): void {
if (this.#firstFound) {
f();
} else {
this.#toExecuteAtStartup.push(f);
}
}

getConnection(): ContainerProviderConnection | undefined {
return this.#connection;
}
}
3 changes: 3 additions & 0 deletions packages/backend/src/studio.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const studio = new Studio(mockedExtensionContext);

const mocks = vi.hoisted(() => ({
listContainers: vi.fn(),
getContainerConnections: vi.fn(),
}));

vi.mock('@podman-desktop/api', async () => {
Expand All @@ -56,6 +57,7 @@ vi.mock('@podman-desktop/api', async () => {
},
provider: {
onDidRegisterContainerConnection: vi.fn(),
getContainerConnections: mocks.getContainerConnections,
},
};
});
Expand All @@ -75,6 +77,7 @@ afterEach(() => {

test('check activate ', async () => {
mocks.listContainers.mockReturnValue([]);
mocks.getContainerConnections.mockReturnValue([]);
vi.spyOn(fs.promises, 'readFile').mockImplementation(() => {
return Promise.resolve('<html></html>');
});
Expand Down
8 changes: 6 additions & 2 deletions packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import path from 'node:path';
import os from 'os';
import fs from 'node:fs';
import { ContainerRegistry } from './registries/ContainerRegistry';
import { PodmanConnection } from './managers/podmanConnection';

// TODO: Need to be configured
export const AI_STUDIO_FOLDER = path.join('podman-desktop', 'ai-studio');
Expand Down Expand Up @@ -104,9 +105,11 @@ export class Studio {

this.rpcExtension = new RpcExtension(this.#panel.webview);
const gitManager = new GitManager();

const podmanConnection = new PodmanConnection();
const taskRegistry = new TaskRegistry();
const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry, this.#panel.webview);
this.playgroundManager = new PlayGroundManager(this.#panel.webview, containerRegistry);
this.playgroundManager = new PlayGroundManager(this.#panel.webview, containerRegistry, podmanConnection);
// Create catalog manager, responsible for loading the catalog files and watching for changes
this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview);
this.modelsManager = new ModelsManager(appUserDirectory, this.#panel.webview, this.catalogManager);
Expand All @@ -128,7 +131,8 @@ export class Studio {

await this.catalogManager.loadCatalog();
await this.modelsManager.loadLocalModels();
await this.playgroundManager.adoptRunningPlaygrounds();
podmanConnection.init();
this.playgroundManager.adoptRunningPlaygrounds();

// Register the instance
this.rpcExtension.registerInstance<StudioApiImpl>(StudioApiImpl, this.studioApi);
Expand Down

0 comments on commit 2ebd657

Please sign in to comment.