Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: user catalog #916

Merged
merged 6 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,17 @@ For the time being, please consider the following actions:
4. Remove the containers related to AI
5. Cleanup your local clone of the recipes: `$HOME/podman-desktop/ai-lab`

### Providing a custom catalog
### 📖 Providing a custom catalog

The extension provides a default catalog, but you can build your own catalog by creating a file `$HOME/podman-desktop/ai-lab/catalog.json`.
The extension provides by default a curated list of recipes, models and categories. However, this system is extensible and you can define your own.

The catalog provides lists of categories, recipes, and models.
To enhance the existing catalog, you can create a file located in the extension storage folder `$HOME/.local/share/containers/podman-desktop/extensions-storage/redhat.ai-lab/user-catalog.json`.

Each recipe can belong to one or several categories. Each model can be used by one or several recipes.
It must follow the same format as the default catalog [in the sources of the extension](https://github.com/containers/podman-desktop-extension-ai-lab/blob/main/packages/backend/src/assets/ai.json).

The format of the catalog is not stable nor versioned yet, you can see the current catalog's format [in the sources of the extension](https://github.com/containers/podman-desktop-extension-ai-lab/blob/main/packages/backend/src/assets/ai.json).
> :information_source: The default behaviour is to append the items of the user's catalog to the default one.

> :warning: Each item (recipes, models or categories) has a unique id, when conflict between the default catalog and the user one are found, the user's items overwrite the defaults.

### Packaging sample applications

Expand Down
78 changes: 68 additions & 10 deletions packages/backend/src/managers/catalogManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import { promises, existsSync } from 'node:fs';
import type { ApplicationCatalog } from '@shared/src/models/IApplicationCatalog';
import path from 'node:path';

vi.mock('./ai.json', () => {
vi.mock('../assets/ai.json', () => {
return {
default: content,
};
Expand Down Expand Up @@ -104,12 +104,12 @@ describe('invalid user catalog', () => {
});

test('expect correct model is returned with valid id', () => {
const model = catalogManager.getModelById('hf.TheBloke.mistral-7b-instruct-v0.1.Q4_K_M');
const model = catalogManager.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('TheBloke/Mistral-7B-Instruct-v0.1-GGUF');
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf',
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

Expand All @@ -123,12 +123,12 @@ test('expect correct model is returned from default catalog with valid id when n
catalogManager.init();
await vi.waitUntil(() => catalogManager.getRecipes().length > 0);

const model = catalogManager.getModelById('hf.TheBloke.mistral-7b-instruct-v0.1.Q4_K_M');
const model = catalogManager.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('TheBloke/Mistral-7B-Instruct-v0.1-GGUF');
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf',
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

Expand All @@ -137,7 +137,7 @@ test('expect correct model is returned with valid id when the user catalog is va
vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(userContent));

catalogManager.init();
await vi.waitUntil(() => catalogManager.getRecipes().length > 0);
await vi.waitUntil(() => catalogManager.getModels().some(model => model.id === 'model1'));

const model = catalogManager.getModelById('model1');
expect(model).toBeDefined();
Expand Down Expand Up @@ -176,7 +176,7 @@ test('expect to call writeFile in addLocalModelsToCatalog with catalog updated',
});
const writeFileMock = vi.spyOn(promises, 'writeFile').mockResolvedValue();

await catalogManager.addLocalModelsToCatalog([
await catalogManager.importUserModels([
{
name: 'custom-model',
path: '/root/path/file.gguf',
Expand All @@ -199,7 +199,65 @@ test('expect to call writeFile in removeLocalModelFromCatalog with catalog updat
const updatedCatalog: ApplicationCatalog = Object.assign({}, userContent);
updatedCatalog.models = updatedCatalog.models.filter(m => m.id !== 'model1');

await catalogManager.removeLocalModelFromCatalog('model1');
await catalogManager.removeUserModel('model1');

expect(writeFileMock).toBeCalledWith('path', JSON.stringify(updatedCatalog, undefined, 2), 'utf-8');
});

test('catalog should be the combination of user catalog and default catalog', async () => {
vi.mocked(existsSync).mockReturnValue(true);
vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(userContent));

catalogManager.init();
await vi.waitUntil(() => catalogManager.getModels().length > userContent.models.length);

const mtimeDate = new Date('2024-04-03T09:51:15.766Z');
vi.spyOn(promises, 'stat').mockResolvedValue({
size: 1,
mtime: mtimeDate,
} as Stats);
vi.spyOn(path, 'resolve').mockReturnValue('path');

const catalog = catalogManager.getCatalog();

expect(catalog).toEqual({
recipes: [...content.recipes, ...userContent.recipes],
models: [...content.models, ...userContent.models],
categories: [...content.categories, ...userContent.categories],
});
});

test('catalog should use user items in favour of default', async () => {
vi.mocked(existsSync).mockReturnValue(true);

const overwriteFullCatalog: ApplicationCatalog = {
recipes: content.recipes.map(recipe => ({
...recipe,
name: 'user-recipe-overwrite',
})),
models: content.models.map(model => ({
...model,
name: 'user-model-overwrite',
})),
categories: content.categories.map(category => ({
...category,
name: 'user-model-overwrite',
})),
};

vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(overwriteFullCatalog));

catalogManager.init();
await vi.waitUntil(() => catalogManager.getModels().length > userContent.models.length);

const mtimeDate = new Date('2024-04-03T09:51:15.766Z');
vi.spyOn(promises, 'stat').mockResolvedValue({
size: 1,
mtime: mtimeDate,
} as Stats);
vi.spyOn(path, 'resolve').mockReturnValue('path');

const catalog = catalogManager.getCatalog();

expect(catalog).toEqual(overwriteFullCatalog);
});
192 changes: 150 additions & 42 deletions packages/backend/src/managers/catalogManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
***********************************************************************/

import type { ApplicationCatalog } from '@shared/src/models/IApplicationCatalog';
import { promises } from 'node:fs';
import fs, { promises } from 'node:fs';
import path from 'node:path';
import defaultCatalog from '../assets/ai.json';
import type { Recipe } from '@shared/src/models/IRecipe';
Expand All @@ -30,6 +30,8 @@ import type { LocalModelImportInfo } from '@shared/src/models/ILocalModelInfo';

export type catalogUpdateHandle = () => void;

export const USER_CATALOG = 'user-catalog.json';

export class CatalogManager extends Publisher<ApplicationCatalog> implements Disposable {
private catalog: ApplicationCatalog;
#catalogUpdateListeners: catalogUpdateHandle[];
Expand All @@ -51,31 +53,91 @@ export class CatalogManager extends Publisher<ApplicationCatalog> implements Dis
this.#disposables = [];
}

/**
* The init method will start a watcher on the user catalog.json
*/
init(): void {
// Creating a json watcher
const jsonWatcher: JsonWatcher<ApplicationCatalog> = new JsonWatcher(
path.resolve(this.appUserDirectory, 'catalog.json'),
defaultCatalog,
);
const jsonWatcher: JsonWatcher<ApplicationCatalog> = new JsonWatcher(this.getUserCatalogPath(), {
recipes: [],
models: [],
categories: [],
});
jsonWatcher.onContentUpdated(content => this.onCatalogUpdated(content));
jsonWatcher.init();

this.#disposables.push(jsonWatcher);
}

private loadDefaultCatalog(): void {
this.catalog = defaultCatalog;
this.notify();
}

private onCatalogUpdated(content: ApplicationCatalog): void {
// when reading the content on the catalog, the creation is just a string and we need to convert it back to a Date object
content.models
.filter(m => m.file?.creation)
.forEach(m => {
if (m.file?.creation) {
m.file.creation = new Date(m.file.creation);
if (typeof content !== 'object' || !('models' in content) || typeof content.models !== 'object') {
this.loadDefaultCatalog();
return;
}

const sanitize = this.sanitize(content);
this.catalog = {
models: [...defaultCatalog.models.filter(a => !sanitize.models.some(b => a.id === b.id)), ...sanitize.models],
recipes: [...defaultCatalog.recipes.filter(a => !sanitize.recipes.some(b => a.id === b.id)), ...sanitize.recipes],
categories: [
...defaultCatalog.categories.filter(a => !sanitize.categories.some(b => a.id === b.id)),
...sanitize.categories,
],
};

this.notify();
}

private sanitize(content: unknown): ApplicationCatalog {
const output: ApplicationCatalog = {
recipes: [],
models: [],
categories: [],
};

if (!content || typeof content !== 'object') {
console.warn('malformed application catalog content');
return output;
}

// ensure user's models are properly formatted
if ('models' in content && typeof content.models === 'object' && Array.isArray(content.models)) {
output.models = content.models.map(model => {
// parse the creation date properly
if (model.file?.creation) {
return {
...model,
file: {
...model.file,
creation: new Date(model.file.creation),
},
};
}
return model;
});
this.catalog = content;
}

// ensure user's recipes are properly formatted
if ('recipes' in content && typeof content.recipes === 'object' && Array.isArray(content.recipes)) {
output.recipes = content.recipes;
}

// ensure user's categories are properly formatted
if ('categories' in content && typeof content.categories === 'object' && Array.isArray(content.categories)) {
output.categories = content.categories;
}

return output;
}

override notify() {
super.notify();
this.#catalogUpdateListeners.forEach(listener => listener());
this.notify();
}

onCatalogUpdate(listener: catalogUpdateHandle): Disposable {
Expand Down Expand Up @@ -117,39 +179,85 @@ export class CatalogManager extends Publisher<ApplicationCatalog> implements Dis
return recipe;
}

async addLocalModelsToCatalog(models: LocalModelImportInfo[]): Promise<void> {
// we copy the current catalog in another object and update it with the model
// then write it to the custom catalog path. If it exists it will be overwritten by default
const tmpCatalog: ApplicationCatalog = Object.assign({}, this.catalog);

for (const model of models) {
const statFile = await promises.stat(model.path);
tmpCatalog.models.push({
id: model.path,
name: model.name,
description: `Model imported from ${model.path}`,
hw: 'CPU',
file: {
path: path.dirname(model.path),
file: path.basename(model.path),
size: statFile.size,
creation: statFile.mtime,
},
memory: statFile.size,
});
/**
* This method is used to imports user's local models.
* @param localModels the models to imports
*/
async importUserModels(localModels: LocalModelImportInfo[]): Promise<void> {
const userCatalogPath = this.getUserCatalogPath();
let content: ApplicationCatalog;

// check if we already have an existing user's catalog
if (fs.existsSync(userCatalogPath)) {
const raw = await promises.readFile(userCatalogPath, 'utf-8');
content = this.sanitize(JSON.parse(raw));
} else {
content = {
recipes: [],
models: [],
categories: [],
};
}

const customCatalog = path.resolve(this.appUserDirectory, 'catalog.json');
return promises.writeFile(customCatalog, JSON.stringify(tmpCatalog, undefined, 2), 'utf-8');
// Transform local models into ModelInfo
const models: ModelInfo[] = await Promise.all(
localModels.map(async local => {
const statFile = await promises.stat(local.path);
return {
id: local.path,
name: local.name,
description: `Model imported from ${local.path}`,
hw: 'CPU',
file: {
path: path.dirname(local.path),
file: path.basename(local.path),
size: statFile.size,
creation: statFile.mtime,
},
memory: statFile.size,
};
}),
);

// Add all our models infos to the user's models catalog
content.models.push(...models);

// overwrite the existing catalog
return promises.writeFile(userCatalogPath, JSON.stringify(content, undefined, 2), 'utf-8');
}

async removeLocalModelFromCatalog(modelId: string): Promise<void> {
// we copy the current catalog in another object and remove from it the model with modelId
// then write it to the custom catalog path.
const tmpCatalog: ApplicationCatalog = Object.assign({}, this.catalog);
tmpCatalog.models = tmpCatalog.models.filter(m => m.url !== '' && m.id !== modelId);
/**
* Remove a model from the user's catalog.
* @param modelId
*/
async removeUserModel(modelId: string): Promise<void> {
const userCatalogPath = this.getUserCatalogPath();
if (!fs.existsSync(userCatalogPath)) {
throw new Error('User catalog does not exist.');
}

const raw = await promises.readFile(userCatalogPath, 'utf-8');
const content = this.sanitize(JSON.parse(raw));

return promises.writeFile(
userCatalogPath,
JSON.stringify(
{
recipes: content.recipes,
models: content.models.filter(model => model.id !== modelId),
categories: content.categories,
},
undefined,
2,
),
'utf-8',
);
}

const customCatalog = path.resolve(this.appUserDirectory, 'catalog.json');
return promises.writeFile(customCatalog, JSON.stringify(tmpCatalog, undefined, 2), 'utf-8');
/**
* Return the path to the user catalog
*/
private getUserCatalogPath(): string {
return path.resolve(this.appUserDirectory, USER_CATALOG);
}
}
Loading
Loading