From 941123f25c646fd601e8cbae183a0dd20c0d6294 Mon Sep 17 00:00:00 2001 From: Philippe Martin Date: Thu, 18 Jan 2024 11:44:33 +0100 Subject: [PATCH] load catalog from user's directory --- packages/backend/src/ai-user-test.json | 49 +++++++ packages/backend/src/studio-api-impl.spec.ts | 137 +++++++++++++------ packages/backend/src/studio-api-impl.ts | 64 +++++++-- packages/backend/src/studio.ts | 1 + 4 files changed, 203 insertions(+), 48 deletions(-) create mode 100644 packages/backend/src/ai-user-test.json diff --git a/packages/backend/src/ai-user-test.json b/packages/backend/src/ai-user-test.json new file mode 100644 index 000000000..cc44b1aab --- /dev/null +++ b/packages/backend/src/ai-user-test.json @@ -0,0 +1,49 @@ +{ + "recipes": [ + { + "id": "recipe 1", + "description" : "Recipe 1", + "name" : "Recipe 1", + "repository": "https://recipe1.example.com", + "icon": "natural-language-processing", + "categories": [ + "category1" + ], + "config": "chatbot/ai-studio.yaml", + "readme": "Readme for recipe 1", + "models": [ + "model1", + "model2" + ] + } + ], + "models": [ + { + "id": "model1", + "name": "Model 1", + "description": "Readme for model 1", + "hw": "CPU", + "registry": "Hugging Face", + "popularity": 3, + "license": "?", + "url": "https://model1.example.com" + }, + { + "id": "model2", + "name": "Model 2", + "description": "Readme for model 2", + "hw": "CPU", + "registry": "Civital", + "popularity": 3, + "license": "?", + "url": "" + } + ], + "categories": [ + { + "id": "category1", + "name": "Category 1", + "description" : "Readme for category 1" + } + ] +} diff --git a/packages/backend/src/studio-api-impl.spec.ts b/packages/backend/src/studio-api-impl.spec.ts index 5bf8606fc..e9d8f79cf 100644 --- a/packages/backend/src/studio-api-impl.spec.ts +++ b/packages/backend/src/studio-api-impl.spec.ts @@ -18,63 +18,122 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { expect, test, vi } from 'vitest'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; import content from './ai-test.json'; +import userContent from './ai-user-test.json'; import type { ApplicationManager } from './managers/applicationManager'; import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; import { StudioApiImpl } from './studio-api-impl'; import type { PlayGroundManager } from './playground'; import type { TaskRegistry } from './registries/TaskRegistry'; +import * as fs from 'node:fs'; + vi.mock('./ai.json', () => { return { default: content, }; }); -const studioApiImpl = new StudioApiImpl( - {} as unknown as ApplicationManager, - {} as unknown as RecipeStatusRegistry, - {} as unknown as TaskRegistry, - {} as unknown as PlayGroundManager, -); - -test('expect correct model is returned with valid id', async () => { - const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S'); - expect(model).toBeDefined(); - expect(model.name).toEqual('Llama-2-7B-Chat-GGUF'); - expect(model.registry).toEqual('Hugging Face'); - expect(model.url).toEqual( - 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', +let studioApiImpl: StudioApiImpl; + +beforeEach(async () => { + studioApiImpl = new StudioApiImpl( + { + appUserDirectory: '.', + } as unknown as ApplicationManager, + {} as unknown as RecipeStatusRegistry, + {} as unknown as TaskRegistry, + {} as unknown as PlayGroundManager, ); }); -test('expect error if id does not correspond to any model', async () => { - await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown'); -}); +describe('no valid user catalog', () => { + beforeEach(async () => { + vi.spyOn(fs.promises, 'readFile').mockResolvedValue('invalid json'); + await studioApiImpl.loadCatalog(); + }); -test('expect array of models based on list of ids', async () => { - const models = await studioApiImpl.getModelsByIds(['llama-2-7b-chat.Q5_K_S', 'albedobase-xl-1.3']); - expect(models).toBeDefined(); - expect(models.length).toBe(2); - expect(models[0].name).toEqual('Llama-2-7B-Chat-GGUF'); - expect(models[0].registry).toEqual('Hugging Face'); - expect(models[0].url).toEqual( - 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', - ); - expect(models[1].name).toEqual('AlbedoBase XL 1.3'); - expect(models[1].registry).toEqual('Civital'); - expect(models[1].url).toEqual(''); -}); + test('expect correct model is returned with valid id', async () => { + const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S'); + expect(model).toBeDefined(); + expect(model.name).toEqual('Llama-2-7B-Chat-GGUF'); + expect(model.registry).toEqual('Hugging Face'); + expect(model.url).toEqual( + 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', + ); + }); + + test('expect error if id does not correspond to any model', async () => { + await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown'); + }); + + test('expect array of models based on list of ids', async () => { + const models = await studioApiImpl.getModelsByIds(['llama-2-7b-chat.Q5_K_S', 'albedobase-xl-1.3']); + expect(models).toBeDefined(); + expect(models.length).toBe(2); + expect(models[0].name).toEqual('Llama-2-7B-Chat-GGUF'); + expect(models[0].registry).toEqual('Hugging Face'); + expect(models[0].url).toEqual( + 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', + ); + expect(models[1].name).toEqual('AlbedoBase XL 1.3'); + expect(models[1].registry).toEqual('Civital'); + expect(models[1].url).toEqual(''); + }); -test('expect empty array if input list is empty', async () => { - const models = await studioApiImpl.getModelsByIds([]); - expect(models).toBeDefined(); - expect(models.length).toBe(0); + test('expect empty array if input list is empty', async () => { + const models = await studioApiImpl.getModelsByIds([]); + expect(models).toBeDefined(); + expect(models.length).toBe(0); + }); + + test('expect empty array if input list has ids that are not in the catalog', async () => { + const models = await studioApiImpl.getModelsByIds(['1', '2']); + expect(models).toBeDefined(); + expect(models.length).toBe(0); + }); }); -test('expect empty array if input list has ids that are not in the catalog', async () => { - const models = await studioApiImpl.getModelsByIds(['1', '2']); - expect(models).toBeDefined(); - expect(models.length).toBe(0); +describe('valid user catalog', () => { + beforeEach(async () => { + vi.spyOn(fs.promises, 'readFile').mockResolvedValue(JSON.stringify(userContent)); + await studioApiImpl.loadCatalog(); + }); + + test('expect correct model is returned with valid id', async () => { + const model = await studioApiImpl.getModelById('model1'); + expect(model).toBeDefined(); + expect(model.name).toEqual('Model 1'); + expect(model.registry).toEqual('Hugging Face'); + expect(model.url).toEqual('https://model1.example.com'); + }); + + test('expect error if id does not correspond to any model', async () => { + await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown'); + }); + + test('expect array of models based on list of ids', async () => { + const models = await studioApiImpl.getModelsByIds(['model1', 'model2']); + expect(models).toBeDefined(); + expect(models.length).toBe(2); + expect(models[0].name).toEqual('Model 1'); + expect(models[0].registry).toEqual('Hugging Face'); + expect(models[0].url).toEqual('https://model1.example.com'); + expect(models[1].name).toEqual('Model 2'); + expect(models[1].registry).toEqual('Civital'); + expect(models[1].url).toEqual(''); + }); + + test('expect empty array if input list is empty', async () => { + const models = await studioApiImpl.getModelsByIds([]); + expect(models).toBeDefined(); + expect(models.length).toBe(0); + }); + + test('expect empty array if input list has ids that are not in the catalog', async () => { + const models = await studioApiImpl.getModelsByIds(['1', '2']); + expect(models).toBeDefined(); + expect(models.length).toBe(0); + }); }); diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 73bb6a389..55fd461db 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -1,27 +1,73 @@ import type { StudioAPI } from '@shared/src/StudioAPI'; import type { Category } from '@shared/src/models/ICategory'; import type { Recipe } from '@shared/src/models/IRecipe'; -import content from './ai.json'; +import defaultCatalog from './ai.json'; import type { ApplicationManager } from './managers/applicationManager'; import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; import type { RecipeStatus } from '@shared/src/models/IRecipeStatus'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import type { TaskRegistry } from './registries/TaskRegistry'; import type { Task } from '@shared/src/models/ITask'; -import * as path from 'node:path'; import type { PlayGroundManager } from './playground'; import * as podmanDesktopApi from '@podman-desktop/api'; import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; +import type { Catalog } from '@shared/src/models/ICatalog'; + +import * as path from 'node:path'; +import * as fs from 'node:fs'; export const RECENT_CATEGORY_ID = 'recent-category'; export class StudioApiImpl implements StudioAPI { + private catalog: Catalog; + constructor( private applicationManager: ApplicationManager, private recipeStatusRegistry: RecipeStatusRegistry, private taskRegistry: TaskRegistry, private playgroundManager: PlayGroundManager, - ) {} + ) { + // We start with an empty catalog, for the methods to work before the catalog is loaded + this.catalog = { + categories: [], + models: [], + recipes: [], + }; + } + + async loadCatalog() { + const catalogPath = path.resolve(this.applicationManager.appUserDirectory, 'catalog.json'); + try { + // TODO(feloy): watch catalog file and update catalog with new content + await fs.promises + .readFile(catalogPath, 'utf-8') + .then((data: string) => { + try { + const cat = JSON.parse(data) as Catalog; + // TODO(feloy): check version and format + console.log('using user catalog'); + this.setNewCatalog(cat); + } catch (err: unknown) { + console.error('unable to parse catalog file, reverting to default catalog', err); + this.setNewCatalog(defaultCatalog); + } + }) + .catch((err: unknown) => { + console.error('got err', err); + console.error('unable to read catalog file, reverting to default catalog', err); + this.setNewCatalog(defaultCatalog); + return; + }); + } catch (err: unknown) { + console.error('unable to read catalog file, reverting to default catalog', err); + this.setNewCatalog(defaultCatalog); + } + } + + setNewCatalog(newCatalog: Catalog) { + // TODO(feloy): send message to frontend with new catalog + this.catalog = newCatalog; + } async openURL(url: string): Promise { return await podmanDesktopApi.env.openExternal(podmanDesktopApi.Uri.parse(url)); @@ -40,23 +86,23 @@ export class StudioApiImpl implements StudioAPI { } async getCategories(): Promise { - return content.categories; + return this.catalog.categories; } async getRecipesByCategory(categoryId: string): Promise { if (categoryId === RECENT_CATEGORY_ID) return this.getRecentRecipes(); - return content.recipes.filter(recipe => recipe.categories.includes(categoryId)); + return this.catalog.recipes.filter(recipe => recipe.categories.includes(categoryId)); } async getRecipeById(recipeId: string): Promise { - const recipe = (content.recipes as Recipe[]).find(recipe => recipe.id === recipeId); + const recipe = (this.catalog.recipes as Recipe[]).find(recipe => recipe.id === recipeId); if (recipe) return recipe; throw new Error('Not found'); } async getModelById(modelId: string): Promise { - const model = content.models.find(m => modelId === m.id); + const model = this.catalog.models.find(m => modelId === m.id); if (!model) { throw new Error(`No model found having id ${modelId}`); } @@ -64,7 +110,7 @@ export class StudioApiImpl implements StudioAPI { } async getModelsByIds(ids: string[]): Promise { - return content.models.filter(m => ids.includes(m.id)) ?? []; + return this.catalog.models.filter(m => ids.includes(m.id)) ?? []; } // eslint-disable-next-line @typescript-eslint/no-unused-vars @@ -90,7 +136,7 @@ export class StudioApiImpl implements StudioAPI { async getLocalModels(): Promise { const local = this.applicationManager.getLocalModels(); const localIds = local.map(l => l.id); - return content.models.filter(m => localIds.includes(m.id)); + return this.catalog.models.filter(m => localIds.includes(m.id)); } async getTasksByLabel(label: string): Promise { diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index a03c47dfb..7ec079d86 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -94,6 +94,7 @@ export class Studio { const applicationManager = new ApplicationManager(gitManager, recipeStatusRegistry, this.#extensionContext); this.playgroundManager = new PlayGroundManager(this.#panel.webview); this.studioApi = new StudioApiImpl(applicationManager, recipeStatusRegistry, taskRegistry, this.playgroundManager); + await this.studioApi.loadCatalog(); // Register the instance this.rpcExtension.registerInstance(StudioApiImpl, this.studioApi); }