Skip to content

Commit

Permalink
fix: user catalog (#916)
Browse files Browse the repository at this point in the history
* fix: catalog user

Signed-off-by: axel7083 <[email protected]>

* fix: models manager tests

Signed-off-by: axel7083 <[email protected]>

* fix: README

Signed-off-by: axel7083 <[email protected]>

* Update README.md

Co-authored-by: Luca Stocchi <[email protected]>
Signed-off-by: axel7083 <[email protected]>

* fix: allow the user catalog to overwrite default items

Signed-off-by: axel7083 <[email protected]>

* fix: tests

Signed-off-by: axel7083 <[email protected]>

---------

Signed-off-by: axel7083 <[email protected]>
Co-authored-by: Luca Stocchi <[email protected]>
  • Loading branch information
axel7083 and lstocchi authored Apr 18, 2024
1 parent c97f233 commit cf01470
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 63 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,17 @@ For the time being, please consider the following actions:
4. Remove the containers related to AI
5. Cleanup your local clone of the recipes: `$HOME/podman-desktop/ai-lab`

### Providing a custom catalog
### 📖 Providing a custom catalog

The extension provides a default catalog, but you can build your own catalog by creating a file `$HOME/podman-desktop/ai-lab/catalog.json`.
The extension provides by default a curated list of recipes, models and categories. However, this system is extensible and you can define your own.

The catalog provides lists of categories, recipes, and models.
To enhance the existing catalog, you can create a file located in the extension storage folder `$HOME/.local/share/containers/podman-desktop/extensions-storage/redhat.ai-lab/user-catalog.json`.

Each recipe can belong to one or several categories. Each model can be used by one or several recipes.
It must follow the same format as the default catalog [in the sources of the extension](https://github.com/containers/podman-desktop-extension-ai-lab/blob/main/packages/backend/src/assets/ai.json).

The format of the catalog is not stable nor versioned yet, you can see the current catalog's format [in the sources of the extension](https://github.com/containers/podman-desktop-extension-ai-lab/blob/main/packages/backend/src/assets/ai.json).
> :information_source: The default behaviour is to append the items of the user's catalog to the default one.
> :warning: Each item (recipes, models or categories) has a unique id, when conflict between the default catalog and the user one are found, the user's items overwrite the defaults.
### Packaging sample applications

Expand Down
78 changes: 68 additions & 10 deletions packages/backend/src/managers/catalogManager.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import { promises, existsSync } from 'node:fs';
import type { ApplicationCatalog } from '@shared/src/models/IApplicationCatalog';
import path from 'node:path';

vi.mock('./ai.json', () => {
vi.mock('../assets/ai.json', () => {
return {
default: content,
};
Expand Down Expand Up @@ -104,12 +104,12 @@ describe('invalid user catalog', () => {
});

test('expect correct model is returned with valid id', () => {
const model = catalogManager.getModelById('hf.TheBloke.mistral-7b-instruct-v0.1.Q4_K_M');
const model = catalogManager.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('TheBloke/Mistral-7B-Instruct-v0.1-GGUF');
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf',
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

Expand All @@ -123,12 +123,12 @@ test('expect correct model is returned from default catalog with valid id when n
catalogManager.init();
await vi.waitUntil(() => catalogManager.getRecipes().length > 0);

const model = catalogManager.getModelById('hf.TheBloke.mistral-7b-instruct-v0.1.Q4_K_M');
const model = catalogManager.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('TheBloke/Mistral-7B-Instruct-v0.1-GGUF');
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf',
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

Expand All @@ -137,7 +137,7 @@ test('expect correct model is returned with valid id when the user catalog is va
vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(userContent));

catalogManager.init();
await vi.waitUntil(() => catalogManager.getRecipes().length > 0);
await vi.waitUntil(() => catalogManager.getModels().some(model => model.id === 'model1'));

const model = catalogManager.getModelById('model1');
expect(model).toBeDefined();
Expand Down Expand Up @@ -176,7 +176,7 @@ test('expect to call writeFile in addLocalModelsToCatalog with catalog updated',
});
const writeFileMock = vi.spyOn(promises, 'writeFile').mockResolvedValue();

await catalogManager.addLocalModelsToCatalog([
await catalogManager.importUserModels([
{
name: 'custom-model',
path: '/root/path/file.gguf',
Expand All @@ -199,7 +199,65 @@ test('expect to call writeFile in removeLocalModelFromCatalog with catalog updat
const updatedCatalog: ApplicationCatalog = Object.assign({}, userContent);
updatedCatalog.models = updatedCatalog.models.filter(m => m.id !== 'model1');

await catalogManager.removeLocalModelFromCatalog('model1');
await catalogManager.removeUserModel('model1');

expect(writeFileMock).toBeCalledWith('path', JSON.stringify(updatedCatalog, undefined, 2), 'utf-8');
});

test('catalog should be the combination of user catalog and default catalog', async () => {
vi.mocked(existsSync).mockReturnValue(true);
vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(userContent));

catalogManager.init();
await vi.waitUntil(() => catalogManager.getModels().length > userContent.models.length);

const mtimeDate = new Date('2024-04-03T09:51:15.766Z');
vi.spyOn(promises, 'stat').mockResolvedValue({
size: 1,
mtime: mtimeDate,
} as Stats);
vi.spyOn(path, 'resolve').mockReturnValue('path');

const catalog = catalogManager.getCatalog();

expect(catalog).toEqual({
recipes: [...content.recipes, ...userContent.recipes],
models: [...content.models, ...userContent.models],
categories: [...content.categories, ...userContent.categories],
});
});

test('catalog should use user items in favour of default', async () => {
vi.mocked(existsSync).mockReturnValue(true);

const overwriteFullCatalog: ApplicationCatalog = {
recipes: content.recipes.map(recipe => ({
...recipe,
name: 'user-recipe-overwrite',
})),
models: content.models.map(model => ({
...model,
name: 'user-model-overwrite',
})),
categories: content.categories.map(category => ({
...category,
name: 'user-model-overwrite',
})),
};

vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(overwriteFullCatalog));

catalogManager.init();
await vi.waitUntil(() => catalogManager.getModels().length > 0);

const mtimeDate = new Date('2024-04-03T09:51:15.766Z');
vi.spyOn(promises, 'stat').mockResolvedValue({
size: 1,
mtime: mtimeDate,
} as Stats);
vi.spyOn(path, 'resolve').mockReturnValue('path');

const catalog = catalogManager.getCatalog();

expect(catalog).toEqual(overwriteFullCatalog);
});
192 changes: 150 additions & 42 deletions packages/backend/src/managers/catalogManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
***********************************************************************/

import type { ApplicationCatalog } from '@shared/src/models/IApplicationCatalog';
import { promises } from 'node:fs';
import fs, { promises } from 'node:fs';
import path from 'node:path';
import defaultCatalog from '../assets/ai.json';
import type { Recipe } from '@shared/src/models/IRecipe';
Expand All @@ -30,6 +30,8 @@ import type { LocalModelImportInfo } from '@shared/src/models/ILocalModelInfo';

export type catalogUpdateHandle = () => void;

export const USER_CATALOG = 'user-catalog.json';

export class CatalogManager extends Publisher<ApplicationCatalog> implements Disposable {
private catalog: ApplicationCatalog;
#catalogUpdateListeners: catalogUpdateHandle[];
Expand All @@ -51,31 +53,91 @@ export class CatalogManager extends Publisher<ApplicationCatalog> implements Dis
this.#disposables = [];
}

/**
* The init method will start a watcher on the user catalog.json
*/
init(): void {
// Creating a json watcher
const jsonWatcher: JsonWatcher<ApplicationCatalog> = new JsonWatcher(
path.resolve(this.appUserDirectory, 'catalog.json'),
defaultCatalog,
);
const jsonWatcher: JsonWatcher<ApplicationCatalog> = new JsonWatcher(this.getUserCatalogPath(), {
recipes: [],
models: [],
categories: [],
});
jsonWatcher.onContentUpdated(content => this.onCatalogUpdated(content));
jsonWatcher.init();

this.#disposables.push(jsonWatcher);
}

private loadDefaultCatalog(): void {
this.catalog = defaultCatalog;
this.notify();
}

private onCatalogUpdated(content: ApplicationCatalog): void {
// when reading the content on the catalog, the creation is just a string and we need to convert it back to a Date object
content.models
.filter(m => m.file?.creation)
.forEach(m => {
if (m.file?.creation) {
m.file.creation = new Date(m.file.creation);
if (typeof content !== 'object' || !('models' in content) || typeof content.models !== 'object') {
this.loadDefaultCatalog();
return;
}

const sanitize = this.sanitize(content);
this.catalog = {
models: [...defaultCatalog.models.filter(a => !sanitize.models.some(b => a.id === b.id)), ...sanitize.models],
recipes: [...defaultCatalog.recipes.filter(a => !sanitize.recipes.some(b => a.id === b.id)), ...sanitize.recipes],
categories: [
...defaultCatalog.categories.filter(a => !sanitize.categories.some(b => a.id === b.id)),
...sanitize.categories,
],
};

this.notify();
}

private sanitize(content: unknown): ApplicationCatalog {
const output: ApplicationCatalog = {
recipes: [],
models: [],
categories: [],
};

if (!content || typeof content !== 'object') {
console.warn('malformed application catalog content');
return output;
}

// ensure user's models are properly formatted
if ('models' in content && typeof content.models === 'object' && Array.isArray(content.models)) {
output.models = content.models.map(model => {
// parse the creation date properly
if (model.file?.creation) {
return {
...model,
file: {
...model.file,
creation: new Date(model.file.creation),
},
};
}
return model;
});
this.catalog = content;
}

// ensure user's recipes are properly formatted
if ('recipes' in content && typeof content.recipes === 'object' && Array.isArray(content.recipes)) {
output.recipes = content.recipes;
}

// ensure user's categories are properly formatted
if ('categories' in content && typeof content.categories === 'object' && Array.isArray(content.categories)) {
output.categories = content.categories;
}

return output;
}

override notify() {
super.notify();
this.#catalogUpdateListeners.forEach(listener => listener());
this.notify();
}

onCatalogUpdate(listener: catalogUpdateHandle): Disposable {
Expand Down Expand Up @@ -117,39 +179,85 @@ export class CatalogManager extends Publisher<ApplicationCatalog> implements Dis
return recipe;
}

async addLocalModelsToCatalog(models: LocalModelImportInfo[]): Promise<void> {
// we copy the current catalog in another object and update it with the model
// then write it to the custom catalog path. If it exists it will be overwritten by default
const tmpCatalog: ApplicationCatalog = Object.assign({}, this.catalog);

for (const model of models) {
const statFile = await promises.stat(model.path);
tmpCatalog.models.push({
id: model.path,
name: model.name,
description: `Model imported from ${model.path}`,
hw: 'CPU',
file: {
path: path.dirname(model.path),
file: path.basename(model.path),
size: statFile.size,
creation: statFile.mtime,
},
memory: statFile.size,
});
/**
* This method is used to imports user's local models.
* @param localModels the models to imports
*/
async importUserModels(localModels: LocalModelImportInfo[]): Promise<void> {
const userCatalogPath = this.getUserCatalogPath();
let content: ApplicationCatalog;

// check if we already have an existing user's catalog
if (fs.existsSync(userCatalogPath)) {
const raw = await promises.readFile(userCatalogPath, 'utf-8');
content = this.sanitize(JSON.parse(raw));
} else {
content = {
recipes: [],
models: [],
categories: [],
};
}

const customCatalog = path.resolve(this.appUserDirectory, 'catalog.json');
return promises.writeFile(customCatalog, JSON.stringify(tmpCatalog, undefined, 2), 'utf-8');
// Transform local models into ModelInfo
const models: ModelInfo[] = await Promise.all(
localModels.map(async local => {
const statFile = await promises.stat(local.path);
return {
id: local.path,
name: local.name,
description: `Model imported from ${local.path}`,
hw: 'CPU',
file: {
path: path.dirname(local.path),
file: path.basename(local.path),
size: statFile.size,
creation: statFile.mtime,
},
memory: statFile.size,
};
}),
);

// Add all our models infos to the user's models catalog
content.models.push(...models);

// overwrite the existing catalog
return promises.writeFile(userCatalogPath, JSON.stringify(content, undefined, 2), 'utf-8');
}

async removeLocalModelFromCatalog(modelId: string): Promise<void> {
// we copy the current catalog in another object and remove from it the model with modelId
// then write it to the custom catalog path.
const tmpCatalog: ApplicationCatalog = Object.assign({}, this.catalog);
tmpCatalog.models = tmpCatalog.models.filter(m => m.url !== '' && m.id !== modelId);
/**
* Remove a model from the user's catalog.
* @param modelId
*/
async removeUserModel(modelId: string): Promise<void> {
const userCatalogPath = this.getUserCatalogPath();
if (!fs.existsSync(userCatalogPath)) {
throw new Error('User catalog does not exist.');
}

const raw = await promises.readFile(userCatalogPath, 'utf-8');
const content = this.sanitize(JSON.parse(raw));

return promises.writeFile(
userCatalogPath,
JSON.stringify(
{
recipes: content.recipes,
models: content.models.filter(model => model.id !== modelId),
categories: content.categories,
},
undefined,
2,
),
'utf-8',
);
}

const customCatalog = path.resolve(this.appUserDirectory, 'catalog.json');
return promises.writeFile(customCatalog, JSON.stringify(tmpCatalog, undefined, 2), 'utf-8');
/**
* Return the path to the user catalog
*/
private getUserCatalogPath(): string {
return path.resolve(this.appUserDirectory, USER_CATALOG);
}
}
Loading

0 comments on commit cf01470

Please sign in to comment.