diff --git a/README.md b/README.md index cafe0ab20..9fe2959b1 100644 --- a/README.md +++ b/README.md @@ -129,15 +129,17 @@ For the time being, please consider the following actions: 4. Remove the containers related to AI 5. Cleanup your local clone of the recipes: `$HOME/podman-desktop/ai-lab` -### Providing a custom catalog +### 📖 Providing a custom catalog -The extension provides a default catalog, but you can build your own catalog by creating a file `$HOME/podman-desktop/ai-lab/catalog.json`. +The extension provides by default a curated list of recipes, models and categories. However, this system is extensible and you can define your own. -The catalog provides lists of categories, recipes, and models. +To enhance the existing catalog, you can create a file located in the extension storage folder `$HOME/.local/share/containers/podman-desktop/extensions-storage/redhat.ai-lab/user-catalog.json`. -Each recipe can belong to one or several categories. Each model can be used by one or several recipes. +It must follow the same format as the default catalog [in the sources of the extension](https://github.com/containers/podman-desktop-extension-ai-lab/blob/main/packages/backend/src/assets/ai.json). -The format of the catalog is not stable nor versioned yet, you can see the current catalog's format [in the sources of the extension](https://github.com/containers/podman-desktop-extension-ai-lab/blob/main/packages/backend/src/assets/ai.json). +> :information_source: The default behaviour is to append the items of the user's catalog to the default one. + +> :warning: Each item (recipes, models or categories) has a unique id, when conflict between the default catalog and the user one are found, the user's items overwrite the defaults. ### Packaging sample applications diff --git a/packages/backend/src/managers/catalogManager.spec.ts b/packages/backend/src/managers/catalogManager.spec.ts index 84e1300e1..4eb5aba0f 100644 --- a/packages/backend/src/managers/catalogManager.spec.ts +++ b/packages/backend/src/managers/catalogManager.spec.ts @@ -29,7 +29,7 @@ import { promises, existsSync } from 'node:fs'; import type { ApplicationCatalog } from '@shared/src/models/IApplicationCatalog'; import path from 'node:path'; -vi.mock('./ai.json', () => { +vi.mock('../assets/ai.json', () => { return { default: content, }; @@ -104,12 +104,12 @@ describe('invalid user catalog', () => { }); test('expect correct model is returned with valid id', () => { - const model = catalogManager.getModelById('hf.TheBloke.mistral-7b-instruct-v0.1.Q4_K_M'); + const model = catalogManager.getModelById('llama-2-7b-chat.Q5_K_S'); expect(model).toBeDefined(); - expect(model.name).toEqual('TheBloke/Mistral-7B-Instruct-v0.1-GGUF'); + expect(model.name).toEqual('Llama-2-7B-Chat-GGUF'); expect(model.registry).toEqual('Hugging Face'); expect(model.url).toEqual( - 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf', + 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', ); }); @@ -123,12 +123,12 @@ test('expect correct model is returned from default catalog with valid id when n catalogManager.init(); await vi.waitUntil(() => catalogManager.getRecipes().length > 0); - const model = catalogManager.getModelById('hf.TheBloke.mistral-7b-instruct-v0.1.Q4_K_M'); + const model = catalogManager.getModelById('llama-2-7b-chat.Q5_K_S'); expect(model).toBeDefined(); - expect(model.name).toEqual('TheBloke/Mistral-7B-Instruct-v0.1-GGUF'); + expect(model.name).toEqual('Llama-2-7B-Chat-GGUF'); expect(model.registry).toEqual('Hugging Face'); expect(model.url).toEqual( - 'https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/resolve/main/mistral-7b-instruct-v0.1.Q4_K_M.gguf', + 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', ); }); @@ -137,7 +137,7 @@ test('expect correct model is returned with valid id when the user catalog is va vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(userContent)); catalogManager.init(); - await vi.waitUntil(() => catalogManager.getRecipes().length > 0); + await vi.waitUntil(() => catalogManager.getModels().some(model => model.id === 'model1')); const model = catalogManager.getModelById('model1'); expect(model).toBeDefined(); @@ -176,7 +176,7 @@ test('expect to call writeFile in addLocalModelsToCatalog with catalog updated', }); const writeFileMock = vi.spyOn(promises, 'writeFile').mockResolvedValue(); - await catalogManager.addLocalModelsToCatalog([ + await catalogManager.importUserModels([ { name: 'custom-model', path: '/root/path/file.gguf', @@ -199,7 +199,65 @@ test('expect to call writeFile in removeLocalModelFromCatalog with catalog updat const updatedCatalog: ApplicationCatalog = Object.assign({}, userContent); updatedCatalog.models = updatedCatalog.models.filter(m => m.id !== 'model1'); - await catalogManager.removeLocalModelFromCatalog('model1'); + await catalogManager.removeUserModel('model1'); expect(writeFileMock).toBeCalledWith('path', JSON.stringify(updatedCatalog, undefined, 2), 'utf-8'); }); + +test('catalog should be the combination of user catalog and default catalog', async () => { + vi.mocked(existsSync).mockReturnValue(true); + vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(userContent)); + + catalogManager.init(); + await vi.waitUntil(() => catalogManager.getModels().length > userContent.models.length); + + const mtimeDate = new Date('2024-04-03T09:51:15.766Z'); + vi.spyOn(promises, 'stat').mockResolvedValue({ + size: 1, + mtime: mtimeDate, + } as Stats); + vi.spyOn(path, 'resolve').mockReturnValue('path'); + + const catalog = catalogManager.getCatalog(); + + expect(catalog).toEqual({ + recipes: [...content.recipes, ...userContent.recipes], + models: [...content.models, ...userContent.models], + categories: [...content.categories, ...userContent.categories], + }); +}); + +test('catalog should use user items in favour of default', async () => { + vi.mocked(existsSync).mockReturnValue(true); + + const overwriteFullCatalog: ApplicationCatalog = { + recipes: content.recipes.map(recipe => ({ + ...recipe, + name: 'user-recipe-overwrite', + })), + models: content.models.map(model => ({ + ...model, + name: 'user-model-overwrite', + })), + categories: content.categories.map(category => ({ + ...category, + name: 'user-model-overwrite', + })), + }; + + vi.spyOn(promises, 'readFile').mockResolvedValue(JSON.stringify(overwriteFullCatalog)); + + catalogManager.init(); + await vi.waitUntil(() => catalogManager.getModels().length > 0); + + const mtimeDate = new Date('2024-04-03T09:51:15.766Z'); + vi.spyOn(promises, 'stat').mockResolvedValue({ + size: 1, + mtime: mtimeDate, + } as Stats); + vi.spyOn(path, 'resolve').mockReturnValue('path'); + + const catalog = catalogManager.getCatalog(); + + expect(catalog).toEqual(overwriteFullCatalog); +}); diff --git a/packages/backend/src/managers/catalogManager.ts b/packages/backend/src/managers/catalogManager.ts index 7c11af2b9..10daf9ec8 100644 --- a/packages/backend/src/managers/catalogManager.ts +++ b/packages/backend/src/managers/catalogManager.ts @@ -17,7 +17,7 @@ ***********************************************************************/ import type { ApplicationCatalog } from '@shared/src/models/IApplicationCatalog'; -import { promises } from 'node:fs'; +import fs, { promises } from 'node:fs'; import path from 'node:path'; import defaultCatalog from '../assets/ai.json'; import type { Recipe } from '@shared/src/models/IRecipe'; @@ -30,6 +30,8 @@ import type { LocalModelImportInfo } from '@shared/src/models/ILocalModelInfo'; export type catalogUpdateHandle = () => void; +export const USER_CATALOG = 'user-catalog.json'; + export class CatalogManager extends Publisher implements Disposable { private catalog: ApplicationCatalog; #catalogUpdateListeners: catalogUpdateHandle[]; @@ -51,31 +53,91 @@ export class CatalogManager extends Publisher implements Dis this.#disposables = []; } + /** + * The init method will start a watcher on the user catalog.json + */ init(): void { // Creating a json watcher - const jsonWatcher: JsonWatcher = new JsonWatcher( - path.resolve(this.appUserDirectory, 'catalog.json'), - defaultCatalog, - ); + const jsonWatcher: JsonWatcher = new JsonWatcher(this.getUserCatalogPath(), { + recipes: [], + models: [], + categories: [], + }); jsonWatcher.onContentUpdated(content => this.onCatalogUpdated(content)); jsonWatcher.init(); this.#disposables.push(jsonWatcher); } + private loadDefaultCatalog(): void { + this.catalog = defaultCatalog; + this.notify(); + } + private onCatalogUpdated(content: ApplicationCatalog): void { - // when reading the content on the catalog, the creation is just a string and we need to convert it back to a Date object - content.models - .filter(m => m.file?.creation) - .forEach(m => { - if (m.file?.creation) { - m.file.creation = new Date(m.file.creation); + if (typeof content !== 'object' || !('models' in content) || typeof content.models !== 'object') { + this.loadDefaultCatalog(); + return; + } + + const sanitize = this.sanitize(content); + this.catalog = { + models: [...defaultCatalog.models.filter(a => !sanitize.models.some(b => a.id === b.id)), ...sanitize.models], + recipes: [...defaultCatalog.recipes.filter(a => !sanitize.recipes.some(b => a.id === b.id)), ...sanitize.recipes], + categories: [ + ...defaultCatalog.categories.filter(a => !sanitize.categories.some(b => a.id === b.id)), + ...sanitize.categories, + ], + }; + + this.notify(); + } + + private sanitize(content: unknown): ApplicationCatalog { + const output: ApplicationCatalog = { + recipes: [], + models: [], + categories: [], + }; + + if (!content || typeof content !== 'object') { + console.warn('malformed application catalog content'); + return output; + } + + // ensure user's models are properly formatted + if ('models' in content && typeof content.models === 'object' && Array.isArray(content.models)) { + output.models = content.models.map(model => { + // parse the creation date properly + if (model.file?.creation) { + return { + ...model, + file: { + ...model.file, + creation: new Date(model.file.creation), + }, + }; } + return model; }); - this.catalog = content; + } + // ensure user's recipes are properly formatted + if ('recipes' in content && typeof content.recipes === 'object' && Array.isArray(content.recipes)) { + output.recipes = content.recipes; + } + + // ensure user's categories are properly formatted + if ('categories' in content && typeof content.categories === 'object' && Array.isArray(content.categories)) { + output.categories = content.categories; + } + + return output; + } + + override notify() { + super.notify(); this.#catalogUpdateListeners.forEach(listener => listener()); - this.notify(); } onCatalogUpdate(listener: catalogUpdateHandle): Disposable { @@ -117,39 +179,85 @@ export class CatalogManager extends Publisher implements Dis return recipe; } - async addLocalModelsToCatalog(models: LocalModelImportInfo[]): Promise { - // we copy the current catalog in another object and update it with the model - // then write it to the custom catalog path. If it exists it will be overwritten by default - const tmpCatalog: ApplicationCatalog = Object.assign({}, this.catalog); - - for (const model of models) { - const statFile = await promises.stat(model.path); - tmpCatalog.models.push({ - id: model.path, - name: model.name, - description: `Model imported from ${model.path}`, - hw: 'CPU', - file: { - path: path.dirname(model.path), - file: path.basename(model.path), - size: statFile.size, - creation: statFile.mtime, - }, - memory: statFile.size, - }); + /** + * This method is used to imports user's local models. + * @param localModels the models to imports + */ + async importUserModels(localModels: LocalModelImportInfo[]): Promise { + const userCatalogPath = this.getUserCatalogPath(); + let content: ApplicationCatalog; + + // check if we already have an existing user's catalog + if (fs.existsSync(userCatalogPath)) { + const raw = await promises.readFile(userCatalogPath, 'utf-8'); + content = this.sanitize(JSON.parse(raw)); + } else { + content = { + recipes: [], + models: [], + categories: [], + }; } - const customCatalog = path.resolve(this.appUserDirectory, 'catalog.json'); - return promises.writeFile(customCatalog, JSON.stringify(tmpCatalog, undefined, 2), 'utf-8'); + // Transform local models into ModelInfo + const models: ModelInfo[] = await Promise.all( + localModels.map(async local => { + const statFile = await promises.stat(local.path); + return { + id: local.path, + name: local.name, + description: `Model imported from ${local.path}`, + hw: 'CPU', + file: { + path: path.dirname(local.path), + file: path.basename(local.path), + size: statFile.size, + creation: statFile.mtime, + }, + memory: statFile.size, + }; + }), + ); + + // Add all our models infos to the user's models catalog + content.models.push(...models); + + // overwrite the existing catalog + return promises.writeFile(userCatalogPath, JSON.stringify(content, undefined, 2), 'utf-8'); } - async removeLocalModelFromCatalog(modelId: string): Promise { - // we copy the current catalog in another object and remove from it the model with modelId - // then write it to the custom catalog path. - const tmpCatalog: ApplicationCatalog = Object.assign({}, this.catalog); - tmpCatalog.models = tmpCatalog.models.filter(m => m.url !== '' && m.id !== modelId); + /** + * Remove a model from the user's catalog. + * @param modelId + */ + async removeUserModel(modelId: string): Promise { + const userCatalogPath = this.getUserCatalogPath(); + if (!fs.existsSync(userCatalogPath)) { + throw new Error('User catalog does not exist.'); + } + + const raw = await promises.readFile(userCatalogPath, 'utf-8'); + const content = this.sanitize(JSON.parse(raw)); + + return promises.writeFile( + userCatalogPath, + JSON.stringify( + { + recipes: content.recipes, + models: content.models.filter(model => model.id !== modelId), + categories: content.categories, + }, + undefined, + 2, + ), + 'utf-8', + ); + } - const customCatalog = path.resolve(this.appUserDirectory, 'catalog.json'); - return promises.writeFile(customCatalog, JSON.stringify(tmpCatalog, undefined, 2), 'utf-8'); + /** + * Return the path to the user catalog + */ + private getUserCatalogPath(): string { + return path.resolve(this.appUserDirectory, USER_CATALOG); } } diff --git a/packages/backend/src/managers/modelsManager.spec.ts b/packages/backend/src/managers/modelsManager.spec.ts index a22f82853..63647ae80 100644 --- a/packages/backend/src/managers/modelsManager.spec.ts +++ b/packages/backend/src/managers/modelsManager.spec.ts @@ -503,7 +503,7 @@ describe('deleting models', () => { test('delete local model should call catalogManager', async () => { vi.mocked(env).isWindows = false; const postMessageMock = vi.fn(); - const removeLocalModelFromCatalogMock = vi.fn(); + const removeUserModelMock = vi.fn(); const manager = new ModelsManager( 'appdir', { @@ -522,7 +522,7 @@ describe('deleting models', () => { }, ] as ModelInfo[]; }, - removeLocalModelFromCatalog: removeLocalModelFromCatalogMock, + removeUserModel: removeUserModelMock, } as unknown as CatalogManager, telemetryLogger, taskRegistry, @@ -531,7 +531,7 @@ describe('deleting models', () => { await manager.loadLocalModels(); await manager.deleteModel('model-id-1'); - expect(removeLocalModelFromCatalogMock).toBeCalledWith('model-id-1'); + expect(removeUserModelMock).toBeCalledWith('model-id-1'); }); test('deleting on windows should check if models is uploaded', async () => { diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts index 535ecf42c..9c23e3b7d 100644 --- a/packages/backend/src/managers/modelsManager.ts +++ b/packages/backend/src/managers/modelsManager.ts @@ -180,7 +180,7 @@ export class ModelsManager implements Disposable { if (!model.url) { modelPath = path.join(model.file.path, model.file.file); // remove it from the catalog as it cannot be downloaded anymore - await this.catalogManager.removeLocalModelFromCatalog(modelId); + await this.catalogManager.removeUserModel(modelId); } else { modelPath = this.getLocalModelFolder(modelId); } diff --git a/packages/backend/src/studio-api-impl.spec.ts b/packages/backend/src/studio-api-impl.spec.ts index 0a308777b..3ea73130c 100644 --- a/packages/backend/src/studio-api-impl.spec.ts +++ b/packages/backend/src/studio-api-impl.spec.ts @@ -240,7 +240,7 @@ test('openDialog should call podmanDesktopAPi showOpenDialog', async () => { test('importModels should call catalogManager', async () => { const addLocalModelsMock = vi - .spyOn(catalogManager, 'addLocalModelsToCatalog') + .spyOn(catalogManager, 'importUserModels') .mockImplementation((_models: LocalModelImportInfo[]) => Promise.resolve()); const models: LocalModelImportInfo[] = [ { diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index b8b0e78c8..64bad96e4 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -459,7 +459,7 @@ export class StudioApiImpl implements StudioAPI { } async importModels(models: LocalModelImportInfo[]): Promise { - return this.catalogManager.addLocalModelsToCatalog(models); + return this.catalogManager.importUserModels(models); } async checkInvalidModels(models: string[]): Promise {