Skip to content

Commit

Permalink
watch local models
Browse files Browse the repository at this point in the history
  • Loading branch information
feloy committed Jan 24, 2024
1 parent 60eca54 commit fb335cc
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 61 deletions.
16 changes: 5 additions & 11 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ describe('pullApplication', () => {

const setStatusMock = vi.fn();
const cloneRepositoryMock = vi.fn();
const getLocalModelsMock = vi.fn();
const isModelOnDiskMock = vi.fn();
let manager: ApplicationManager;
let downloadModelMainSpy: MockInstance<
[modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string],
Expand Down Expand Up @@ -93,7 +93,7 @@ describe('pullApplication', () => {
setStatus: setStatusMock,
} as unknown as RecipeStatusRegistry,
{
getLocalModels: getLocalModelsMock,
isModelOnDisk: isModelOnDiskMock,
} as unknown as ModelsManager,
);

Expand All @@ -105,7 +105,7 @@ describe('pullApplication', () => {
mockForPullApplication({
recipeFolderExists: false,
});
getLocalModelsMock.mockReturnValue([]);
isModelOnDiskMock.mockReturnValue(false);

const recipe: Recipe = {
id: 'recipe1',
Expand Down Expand Up @@ -140,7 +140,7 @@ describe('pullApplication', () => {
mockForPullApplication({
recipeFolderExists: true,
});
getLocalModelsMock.mockReturnValue([]);
isModelOnDiskMock.mockReturnValue(false);

const recipe: Recipe = {
id: 'recipe1',
Expand Down Expand Up @@ -169,13 +169,7 @@ describe('pullApplication', () => {
mockForPullApplication({
recipeFolderExists: true,
});
getLocalModelsMock.mockReturnValue([
{
id: 'model1',
file: 'model1.file',
},
]);

isModelOnDiskMock.mockReturnValue(true);
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand Down
3 changes: 1 addition & 2 deletions packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ export class ApplicationManager {
container => container.arch === undefined || container.arch === arch(),
);

const localModels = this.modelsManager.getLocalModels();
if (!localModels.map(m => m.id).includes(model.id)) {
if (!this.modelsManager.isModelOnDisk(model.id)) {
// Download model
taskUtil.setTask({
id: model.id,
Expand Down
18 changes: 10 additions & 8 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ import os from 'os';
import fs from 'node:fs';
import path from 'node:path';
import { ModelsManager } from './modelsManager';
import type { Webview } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';

beforeEach(() => {
vi.resetAllMocks();
});

test('getLocalModels should return models in local directory', () => {
test('getLocalModelsFromDisk should get models in local directory', () => {
vi.spyOn(os, 'homedir').mockReturnValue('/home/user');
const existsSyncSpy = vi.spyOn(fs, 'existsSync');
existsSyncSpy.mockImplementation((path: string) => {
Expand Down Expand Up @@ -53,9 +55,9 @@ test('getLocalModels should return models in local directory', () => {
] as fs.Dirent[];
}
});
const manager = new ModelsManager('/home/user/aistudio');
const models = manager.getLocalModels();
expect(models).toEqual([
const manager = new ModelsManager('/home/user/aistudio', {} as Webview, {} as CatalogManager);
manager.getLocalModelsFromDisk();
expect(manager.getLocalModels()).toEqual([
{
id: 'model-id-1',
file: 'model-id-1-model',
Expand All @@ -71,13 +73,13 @@ test('getLocalModels should return models in local directory', () => {
]);
});

test('getLocalModels should return an empty array if the models folder does not exist', () => {
test('getLocalModelsFromDisk should return an empty array if the models folder does not exist', () => {
vi.spyOn(os, 'homedir').mockReturnValue('/home/user');
const existsSyncSpy = vi.spyOn(fs, 'existsSync');
existsSyncSpy.mockReturnValue(false);
const manager = new ModelsManager('/home/user/aistudio');
const models = manager.getLocalModels();
expect(models).toEqual([]);
const manager = new ModelsManager('/home/user/aistudio', {} as Webview, {} as CatalogManager);
manager.getLocalModelsFromDisk();
expect(manager.getLocalModels()).toEqual([]);
if (process.platform === 'win32') {
expect(existsSyncSpy).toHaveBeenCalledWith('\\home\\user\\aistudio\\models');
} else {
Expand Down
74 changes: 65 additions & 9 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,53 @@
import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo';
import fs from 'fs';
import * as path from 'node:path';
import { type Webview, fs as apiFs } from '@podman-desktop/api';
import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages';
import type { CatalogManager } from './catalogManager';

export class ModelsManager {
constructor(private appUserDirectory: string) {}
#modelsDir: string;
#localModels: Map<string, LocalModelInfo>;

getLocalModels(): LocalModelInfo[] {
const result: LocalModelInfo[] = [];
const modelsDir = path.join(this.appUserDirectory, 'models');
if (!fs.existsSync(modelsDir)) {
return [];
constructor(
private appUserDirectory: string,
private webview: Webview,
private catalogManager: CatalogManager,
) {
this.#modelsDir = path.join(this.appUserDirectory, 'models');
this.#localModels = new Map();
}

async loadLocalModels() {
const reloadLocalModels = async () => {
this.getLocalModelsFromDisk();
const models = this.getModelsInfo();
await this.webview.postMessage({
id: MSG_NEW_LOCAL_MODELS_STATE,
body: models,
});
};
const watcher = apiFs.createFileSystemWatcher(this.#modelsDir);
watcher.onDidCreate(reloadLocalModels);
watcher.onDidDelete(reloadLocalModels);
watcher.onDidChange(reloadLocalModels);
// Initialize the local models manually
await reloadLocalModels();
}

getModelsInfo() {
return this.catalogManager
.getModels()
.filter(m => this.#localModels.has(m.id))
.map(m => ({ ...m, file: this.#localModels.get(m.id) }));
}

getLocalModelsFromDisk(): void {
if (!fs.existsSync(this.#modelsDir)) {
return;
}
const entries = fs.readdirSync(modelsDir, { withFileTypes: true });
const result = new Map<string, LocalModelInfo>();
const entries = fs.readdirSync(this.#modelsDir, { withFileTypes: true });
const dirs = entries.filter(dir => dir.isDirectory());
for (const d of dirs) {
const modelEntries = fs.readdirSync(path.resolve(d.path, d.name));
Expand All @@ -21,13 +57,33 @@ export class ModelsManager {
}
const modelFile = modelEntries[0];
const info = fs.statSync(path.resolve(d.path, d.name, modelFile));
result.push({
result.set(d.name, {
id: d.name,
file: modelFile,
size: info.size,
creation: info.mtime,
});
}
return result;
this.#localModels = result;
}

isModelOnDisk(modelId: string) {
return this.#localModels.has(modelId);
}

getLocalModelInfo(modelId: string): LocalModelInfo {
if (!this.isModelOnDisk(modelId)) {
throw new Error('model is not on disk');
}
return this.#localModels.get(modelId);
}

getLocalModelPath(modelId: string): string {
const info = this.getLocalModelInfo(modelId);
return path.resolve(this.#modelsDir, modelId, info.file);
}

getLocalModels(): LocalModelInfo[] {
return Array.from(this.#localModels.values());
}
}
30 changes: 4 additions & 26 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ import type { PlayGroundManager } from './managers/playground';
import * as podmanDesktopApi from '@podman-desktop/api';
import type { QueryState } from '@shared/src/models/IPlaygroundQueryState';

import * as path from 'node:path';
import type { CatalogManager } from './managers/catalogManager';
import type { Catalog } from '@shared/src/models/ICatalog';
import type { PlaygroundState } from '@shared/src/models/IPlaygroundState';
import type { ModelsManager } from './managers/modelsManager';
import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo';

export class StudioApiImpl implements StudioAPI {
constructor(
Expand Down Expand Up @@ -81,28 +79,11 @@ export class StudioApiImpl implements StudioAPI {
}

async getLocalModels(): Promise<ModelInfo[]> {
const local = this.modelsManager.getLocalModels();
const localMap = new Map<string, LocalModelInfo>();
for (const l of local) {
localMap.set(l.id, l);
}
const localIds = local.map(l => l.id);
return this.catalogManager
.getModels()
.filter(m => localIds.includes(m.id))
.map(m => ({ ...m, file: localMap.get(m.id) }));
return this.modelsManager.getModelsInfo();
}

async startPlayground(modelId: string): Promise<void> {
// TODO: improve the following
const localModelInfo = this.modelsManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}

// TODO: we need to stop doing that.
const modelPath = path.resolve(this.appUserDirectory, 'models', modelId, localModelInfo[0].file);

const modelPath = this.modelsManager.getLocalModelPath(modelId);
await this.playgroundManager.startPlayground(modelId, modelPath);
}

Expand All @@ -111,11 +92,8 @@ export class StudioApiImpl implements StudioAPI {
}

askPlayground(modelId: string, prompt: string): Promise<number> {
const localModelInfo = this.modelsManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}
return this.playgroundManager.askPlayground(localModelInfo[0], prompt);
const localModelInfo = this.modelsManager.getLocalModelInfo(modelId);
return this.playgroundManager.askPlayground(localModelInfo, prompt);
}

async getPlaygroundQueriesState(): Promise<QueryState[]> {
Expand Down
2 changes: 2 additions & 0 deletions packages/backend/src/studio.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import type { ExtensionContext } from '@podman-desktop/api';

import * as fs from 'node:fs';

vi.mock('./managers/modelsManager');

const mockedExtensionContext = {
subscriptions: [],
} as unknown as ExtensionContext;
Expand Down
9 changes: 5 additions & 4 deletions packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@ export class Studio {
const gitManager = new GitManager();
const taskRegistry = new TaskRegistry();
const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry, this.#panel.webview);
this.modelsManager = new ModelsManager(appUserDirectory);
this.playgroundManager = new PlayGroundManager(this.#panel.webview);
// Create catalog manager, responsible for loading the catalog files and watching for changes
this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview);
this.modelsManager = new ModelsManager(appUserDirectory, this.#panel.webview, this.catalogManager);
const applicationManager = new ApplicationManager(
appUserDirectory,
gitManager,
recipeStatusRegistry,
this.modelsManager,
);
this.playgroundManager = new PlayGroundManager(this.#panel.webview);
// Create catalog manager, responsible for loading the catalog files and watching for changes
this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview);

// Creating StudioApiImpl
this.studioApi = new StudioApiImpl(
Expand All @@ -123,6 +123,7 @@ export class Studio {
);

await this.catalogManager.loadCatalog();
await this.modelsManager.loadLocalModels();

// Register the instance
this.rpcExtension.registerInstance<StudioApiImpl>(StudioApiImpl, this.studioApi);
Expand Down
10 changes: 9 additions & 1 deletion packages/frontend/src/stores/local-models.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { Readable } from 'svelte/store';
import { readable } from 'svelte/store';
import { studioClient } from '/@/utils/client';
import { rpcBrowser, studioClient } from '/@/utils/client';
import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages';

export const localModels: Readable<ModelInfo[]> = readable<ModelInfo[]>([], set => {
const sub = rpcBrowser.subscribe(MSG_NEW_LOCAL_MODELS_STATE, msg => {
set(msg);
});
// Initialize the store manually
studioClient.getLocalModels().then(v => {
set(v);
});
return () => {
sub.unsubscribe();
};
});
1 change: 1 addition & 0 deletions packages/shared/Messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ export const MSG_PLAYGROUNDS_STATE_UPDATE = 'playgrounds-state-update';
export const MSG_NEW_PLAYGROUND_QUERIES_STATE = 'new-playground-queries-state';
export const MSG_NEW_CATALOG_STATE = 'new-catalog-state';
export const MSG_NEW_RECIPE_STATE = 'new-recipe-state';
export const MSG_NEW_LOCAL_MODELS_STATE = 'new-local-models-state';

0 comments on commit fb335cc

Please sign in to comment.