Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

watch local models #119

Merged
merged 1 commit into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions packages/backend/src/managers/applicationManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ describe('pullApplication', () => {

const setStatusMock = vi.fn();
const cloneRepositoryMock = vi.fn();
const getLocalModelsMock = vi.fn();
const isModelOnDiskMock = vi.fn();
let manager: ApplicationManager;
let downloadModelMainSpy: MockInstance<
[modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string],
Expand Down Expand Up @@ -93,7 +93,7 @@ describe('pullApplication', () => {
setStatus: setStatusMock,
} as unknown as RecipeStatusRegistry,
{
getLocalModels: getLocalModelsMock,
isModelOnDisk: isModelOnDiskMock,
} as unknown as ModelsManager,
);

Expand All @@ -105,7 +105,7 @@ describe('pullApplication', () => {
mockForPullApplication({
recipeFolderExists: false,
});
getLocalModelsMock.mockReturnValue([]);
isModelOnDiskMock.mockReturnValue(false);

const recipe: Recipe = {
id: 'recipe1',
Expand Down Expand Up @@ -140,7 +140,7 @@ describe('pullApplication', () => {
mockForPullApplication({
recipeFolderExists: true,
});
getLocalModelsMock.mockReturnValue([]);
isModelOnDiskMock.mockReturnValue(false);

const recipe: Recipe = {
id: 'recipe1',
Expand Down Expand Up @@ -169,13 +169,7 @@ describe('pullApplication', () => {
mockForPullApplication({
recipeFolderExists: true,
});
getLocalModelsMock.mockReturnValue([
{
id: 'model1',
file: 'model1.file',
},
]);

isModelOnDiskMock.mockReturnValue(true);
const recipe: Recipe = {
id: 'recipe1',
name: 'Recipe 1',
Expand Down
3 changes: 1 addition & 2 deletions packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ export class ApplicationManager {
container => container.arch === undefined || container.arch === arch(),
);

const localModels = this.modelsManager.getLocalModels();
if (!localModels.map(m => m.id).includes(model.id)) {
if (!this.modelsManager.isModelOnDisk(model.id)) {
// Download model
taskUtil.setTask({
id: model.id,
Expand Down
73 changes: 64 additions & 9 deletions packages/backend/src/managers/modelsManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ import os from 'os';
import fs from 'node:fs';
import path from 'node:path';
import { ModelsManager } from './modelsManager';
import type { Webview } from '@podman-desktop/api';
import type { CatalogManager } from './catalogManager';
import type { ModelInfo } from '@shared/src/models/IModelInfo';

beforeEach(() => {
vi.resetAllMocks();
});

test('getLocalModels should return models in local directory', () => {
function mockFiles(now: Date) {
vi.spyOn(os, 'homedir').mockReturnValue('/home/user');
const existsSyncSpy = vi.spyOn(fs, 'existsSync');
existsSyncSpy.mockImplementation((path: string) => {
Expand All @@ -21,7 +24,6 @@ test('getLocalModels should return models in local directory', () => {
});
const statSyncSpy = vi.spyOn(fs, 'statSync');
const info = new fs.Stats();
const now = new Date();
info.size = 32000;
info.mtime = now;
statSyncSpy.mockReturnValue(info);
Expand Down Expand Up @@ -53,9 +55,14 @@ test('getLocalModels should return models in local directory', () => {
] as fs.Dirent[];
}
});
const manager = new ModelsManager('/home/user/aistudio');
const models = manager.getLocalModels();
expect(models).toEqual([
}

test('getLocalModelsFromDisk should get models in local directory', () => {
const now = new Date();
mockFiles(now);
const manager = new ModelsManager('/home/user/aistudio', {} as Webview, {} as CatalogManager);
manager.getLocalModelsFromDisk();
expect(manager.getLocalModels()).toEqual([
{
id: 'model-id-1',
file: 'model-id-1-model',
Expand All @@ -71,16 +78,64 @@ test('getLocalModels should return models in local directory', () => {
]);
});

test('getLocalModels should return an empty array if the models folder does not exist', () => {
test('getLocalModelsFromDisk should return an empty array if the models folder does not exist', () => {
vi.spyOn(os, 'homedir').mockReturnValue('/home/user');
const existsSyncSpy = vi.spyOn(fs, 'existsSync');
existsSyncSpy.mockReturnValue(false);
const manager = new ModelsManager('/home/user/aistudio');
const models = manager.getLocalModels();
expect(models).toEqual([]);
const manager = new ModelsManager('/home/user/aistudio', {} as Webview, {} as CatalogManager);
manager.getLocalModelsFromDisk();
expect(manager.getLocalModels()).toEqual([]);
if (process.platform === 'win32') {
expect(existsSyncSpy).toHaveBeenCalledWith('\\home\\user\\aistudio\\models');
} else {
expect(existsSyncSpy).toHaveBeenCalledWith('/home/user/aistudio/models');
}
});

test('loadLocalModels should post a message with the message on disk and on catalog', async () => {
const now = new Date();
mockFiles(now);

vi.mock('@podman-desktop/api', () => {
return {
fs: {
createFileSystemWatcher: () => ({
onDidCreate: vi.fn(),
onDidDelete: vi.fn(),
onDidChange: vi.fn(),
}),
},
};
});
const postMessageMock = vi.fn();
const manager = new ModelsManager(
'/home/user/aistudio',
{
postMessage: postMessageMock,
} as unknown as Webview,
{
getModels: () => {
return [
{
id: 'model-id-1',
},
] as ModelInfo[];
},
} as CatalogManager,
);
await manager.loadLocalModels();
expect(postMessageMock).toHaveBeenNthCalledWith(1, {
id: 'new-local-models-state',
body: [
{
file: {
creation: now,
file: 'model-id-1-model',
id: 'model-id-1',
size: 32000,
},
id: 'model-id-1',
},
],
});
});
74 changes: 65 additions & 9 deletions packages/backend/src/managers/modelsManager.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,53 @@
import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo';
import fs from 'fs';
import * as path from 'node:path';
import { type Webview, fs as apiFs } from '@podman-desktop/api';
import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages';
import type { CatalogManager } from './catalogManager';

export class ModelsManager {
constructor(private appUserDirectory: string) {}
#modelsDir: string;
#localModels: Map<string, LocalModelInfo>;

getLocalModels(): LocalModelInfo[] {
const result: LocalModelInfo[] = [];
const modelsDir = path.join(this.appUserDirectory, 'models');
if (!fs.existsSync(modelsDir)) {
return [];
constructor(
private appUserDirectory: string,
private webview: Webview,
private catalogManager: CatalogManager,
) {
this.#modelsDir = path.join(this.appUserDirectory, 'models');
this.#localModels = new Map();
}

async loadLocalModels() {
const reloadLocalModels = async () => {
this.getLocalModelsFromDisk();
const models = this.getModelsInfo();
await this.webview.postMessage({
id: MSG_NEW_LOCAL_MODELS_STATE,
body: models,
});
};
const watcher = apiFs.createFileSystemWatcher(this.#modelsDir);
watcher.onDidCreate(reloadLocalModels);
watcher.onDidDelete(reloadLocalModels);
watcher.onDidChange(reloadLocalModels);
// Initialize the local models manually
await reloadLocalModels();
}

getModelsInfo() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we do not display models that are not in the catalog ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, we are displaying in the Models list only the models on the disk AND on the catalog, as the catalog is the source for many information (model name, license, etc).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather still list the models, and show some "?" or information/warning icon the missing field rather than ignoring them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@slemeur what do you think?

return this.catalogManager
.getModels()
.filter(m => this.#localModels.has(m.id))
.map(m => ({ ...m, file: this.#localModels.get(m.id) }));
}

getLocalModelsFromDisk(): void {
if (!fs.existsSync(this.#modelsDir)) {
return;
}
const entries = fs.readdirSync(modelsDir, { withFileTypes: true });
const result = new Map<string, LocalModelInfo>();
const entries = fs.readdirSync(this.#modelsDir, { withFileTypes: true });
const dirs = entries.filter(dir => dir.isDirectory());
for (const d of dirs) {
const modelEntries = fs.readdirSync(path.resolve(d.path, d.name));
Expand All @@ -21,13 +57,33 @@ export class ModelsManager {
}
const modelFile = modelEntries[0];
const info = fs.statSync(path.resolve(d.path, d.name, modelFile));
result.push({
result.set(d.name, {
id: d.name,
file: modelFile,
size: info.size,
creation: info.mtime,
});
}
return result;
this.#localModels = result;
}

isModelOnDisk(modelId: string) {
return this.#localModels.has(modelId);
}

getLocalModelInfo(modelId: string): LocalModelInfo {
if (!this.isModelOnDisk(modelId)) {
throw new Error('model is not on disk');
}
return this.#localModels.get(modelId);
}

getLocalModelPath(modelId: string): string {
const info = this.getLocalModelInfo(modelId);
return path.resolve(this.#modelsDir, modelId, info.file);
}

getLocalModels(): LocalModelInfo[] {
return Array.from(this.#localModels.values());
}
}
30 changes: 4 additions & 26 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ import type { PlayGroundManager } from './managers/playground';
import * as podmanDesktopApi from '@podman-desktop/api';
import type { QueryState } from '@shared/src/models/IPlaygroundQueryState';

import * as path from 'node:path';
import type { CatalogManager } from './managers/catalogManager';
import type { Catalog } from '@shared/src/models/ICatalog';
import type { PlaygroundState } from '@shared/src/models/IPlaygroundState';
import type { ModelsManager } from './managers/modelsManager';
import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo';

export class StudioApiImpl implements StudioAPI {
constructor(
Expand Down Expand Up @@ -81,28 +79,11 @@ export class StudioApiImpl implements StudioAPI {
}

async getLocalModels(): Promise<ModelInfo[]> {
const local = this.modelsManager.getLocalModels();
const localMap = new Map<string, LocalModelInfo>();
for (const l of local) {
localMap.set(l.id, l);
}
const localIds = local.map(l => l.id);
return this.catalogManager
.getModels()
.filter(m => localIds.includes(m.id))
.map(m => ({ ...m, file: localMap.get(m.id) }));
return this.modelsManager.getModelsInfo();
}

async startPlayground(modelId: string): Promise<void> {
// TODO: improve the following
const localModelInfo = this.modelsManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}

// TODO: we need to stop doing that.
const modelPath = path.resolve(this.appUserDirectory, 'models', modelId, localModelInfo[0].file);

const modelPath = this.modelsManager.getLocalModelPath(modelId);
await this.playgroundManager.startPlayground(modelId, modelPath);
}

Expand All @@ -111,11 +92,8 @@ export class StudioApiImpl implements StudioAPI {
}

askPlayground(modelId: string, prompt: string): Promise<number> {
const localModelInfo = this.modelsManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}
return this.playgroundManager.askPlayground(localModelInfo[0], prompt);
const localModelInfo = this.modelsManager.getLocalModelInfo(modelId);
return this.playgroundManager.askPlayground(localModelInfo, prompt);
}

async getPlaygroundQueriesState(): Promise<QueryState[]> {
Expand Down
2 changes: 2 additions & 0 deletions packages/backend/src/studio.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import type { ExtensionContext } from '@podman-desktop/api';

import * as fs from 'node:fs';

vi.mock('./managers/modelsManager');

const mockedExtensionContext = {
subscriptions: [],
} as unknown as ExtensionContext;
Expand Down
9 changes: 5 additions & 4 deletions packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@ export class Studio {
const gitManager = new GitManager();
const taskRegistry = new TaskRegistry();
const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry, this.#panel.webview);
this.modelsManager = new ModelsManager(appUserDirectory);
this.playgroundManager = new PlayGroundManager(this.#panel.webview);
// Create catalog manager, responsible for loading the catalog files and watching for changes
this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview);
this.modelsManager = new ModelsManager(appUserDirectory, this.#panel.webview, this.catalogManager);
const applicationManager = new ApplicationManager(
appUserDirectory,
gitManager,
recipeStatusRegistry,
this.modelsManager,
);
this.playgroundManager = new PlayGroundManager(this.#panel.webview);
// Create catalog manager, responsible for loading the catalog files and watching for changes
this.catalogManager = new CatalogManager(appUserDirectory, this.#panel.webview);

// Creating StudioApiImpl
this.studioApi = new StudioApiImpl(
Expand All @@ -123,6 +123,7 @@ export class Studio {
);

await this.catalogManager.loadCatalog();
await this.modelsManager.loadLocalModels();

// Register the instance
this.rpcExtension.registerInstance<StudioApiImpl>(StudioApiImpl, this.studioApi);
Expand Down
10 changes: 9 additions & 1 deletion packages/frontend/src/stores/local-models.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { Readable } from 'svelte/store';
import { readable } from 'svelte/store';
import { studioClient } from '/@/utils/client';
import { rpcBrowser, studioClient } from '/@/utils/client';
import { MSG_NEW_LOCAL_MODELS_STATE } from '@shared/Messages';

export const localModels: Readable<ModelInfo[]> = readable<ModelInfo[]>([], set => {
const sub = rpcBrowser.subscribe(MSG_NEW_LOCAL_MODELS_STATE, msg => {
set(msg);
});
// Initialize the store manually
studioClient.getLocalModels().then(v => {
set(v);
});
return () => {
sub.unsubscribe();
};
});
1 change: 1 addition & 0 deletions packages/shared/Messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ export const MSG_PLAYGROUNDS_STATE_UPDATE = 'playgrounds-state-update';
export const MSG_NEW_PLAYGROUND_QUERIES_STATE = 'new-playground-queries-state';
export const MSG_NEW_CATALOG_STATE = 'new-catalog-state';
export const MSG_NEW_RECIPE_STATE = 'new-recipe-state';
export const MSG_NEW_LOCAL_MODELS_STATE = 'new-local-models-state';
axel7083 marked this conversation as resolved.
Show resolved Hide resolved
Loading