diff --git a/packages/backend/src/ai-user-test.json b/packages/backend/src/ai-user-test.json new file mode 100644 index 000000000..cc44b1aab --- /dev/null +++ b/packages/backend/src/ai-user-test.json @@ -0,0 +1,49 @@ +{ + "recipes": [ + { + "id": "recipe 1", + "description" : "Recipe 1", + "name" : "Recipe 1", + "repository": "https://recipe1.example.com", + "icon": "natural-language-processing", + "categories": [ + "category1" + ], + "config": "chatbot/ai-studio.yaml", + "readme": "Readme for recipe 1", + "models": [ + "model1", + "model2" + ] + } + ], + "models": [ + { + "id": "model1", + "name": "Model 1", + "description": "Readme for model 1", + "hw": "CPU", + "registry": "Hugging Face", + "popularity": 3, + "license": "?", + "url": "https://model1.example.com" + }, + { + "id": "model2", + "name": "Model 2", + "description": "Readme for model 2", + "hw": "CPU", + "registry": "Civital", + "popularity": 3, + "license": "?", + "url": "" + } + ], + "categories": [ + { + "id": "category1", + "name": "Category 1", + "description" : "Readme for category 1" + } + ] +} diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index 1fc93d1ec..971bfe1fd 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -25,21 +25,21 @@ interface DownloadModelResult { } export class ApplicationManager { - readonly homeDirectory: string; // todo: make configurable + readonly appUserDirectory: string; // todo: make configurable constructor( private git: GitManager, private recipeStatusRegistry: RecipeStatusRegistry, private extensionContext: ExtensionContext, ) { - this.homeDirectory = os.homedir(); + this.appUserDirectory = path.join(os.homedir(), AI_STUDIO_FOLDER); } async pullApplication(recipe: Recipe, model: ModelInfo) { // Create a TaskUtils object to help us const taskUtil = new RecipeStatusUtils(recipe.id, this.recipeStatusRegistry); - const localFolder = path.join(this.homeDirectory, AI_STUDIO_FOLDER, recipe.id); + const localFolder = path.join(this.appUserDirectory, recipe.id); // Adding checkout task const checkoutTask: Task = { @@ -218,7 +218,7 @@ export class ApplicationManager { callback: (message: DownloadModelResult) => void, destFileName?: string, ) { - const destDir = path.join(this.homeDirectory, AI_STUDIO_FOLDER, 'models', modelId); + const destDir = path.join(this.appUserDirectory, 'models', modelId); if (!fs.existsSync(destDir)) { fs.mkdirSync(destDir, { recursive: true }); } @@ -269,7 +269,7 @@ export class ApplicationManager { // todo: move somewhere else (dedicated to models) getLocalModels(): LocalModelInfo[] { const result: LocalModelInfo[] = []; - const modelsDir = path.join(this.homeDirectory, AI_STUDIO_FOLDER, 'models'); + const modelsDir = path.join(this.appUserDirectory, 'models'); const entries = fs.readdirSync(modelsDir, { withFileTypes: true }); const dirs = entries.filter(dir => dir.isDirectory()); for (const d of dirs) { diff --git a/packages/backend/src/studio-api-impl.spec.ts b/packages/backend/src/studio-api-impl.spec.ts index 5bf8606fc..ab3423765 100644 --- a/packages/backend/src/studio-api-impl.spec.ts +++ b/packages/backend/src/studio-api-impl.spec.ts @@ -18,63 +18,104 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { expect, test, vi } from 'vitest'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; import content from './ai-test.json'; +import userContent from './ai-user-test.json'; import type { ApplicationManager } from './managers/applicationManager'; import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; import { StudioApiImpl } from './studio-api-impl'; import type { PlayGroundManager } from './playground'; import type { TaskRegistry } from './registries/TaskRegistry'; +import * as fs from 'node:fs'; + vi.mock('./ai.json', () => { return { default: content, }; }); -const studioApiImpl = new StudioApiImpl( - {} as unknown as ApplicationManager, - {} as unknown as RecipeStatusRegistry, - {} as unknown as TaskRegistry, - {} as unknown as PlayGroundManager, -); +let studioApiImpl: StudioApiImpl; -test('expect correct model is returned with valid id', async () => { - const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S'); - expect(model).toBeDefined(); - expect(model.name).toEqual('Llama-2-7B-Chat-GGUF'); - expect(model.registry).toEqual('Hugging Face'); - expect(model.url).toEqual( - 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', +beforeEach(async () => { + studioApiImpl = new StudioApiImpl( + { + appUserDirectory: '.', + } as unknown as ApplicationManager, + {} as unknown as RecipeStatusRegistry, + {} as unknown as TaskRegistry, + {} as unknown as PlayGroundManager, ); + vi.resetAllMocks(); + vi.mock('node:fs'); }); -test('expect error if id does not correspond to any model', async () => { - await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown'); +describe('invalid user catalog', () => { + beforeEach(async () => { + vi.spyOn(fs.promises, 'readFile').mockResolvedValue('invalid json'); + await studioApiImpl.loadCatalog(); + }); + + test('expect correct model is returned with valid id', async () => { + const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S'); + expect(model).toBeDefined(); + expect(model.name).toEqual('Llama-2-7B-Chat-GGUF'); + expect(model.registry).toEqual('Hugging Face'); + expect(model.url).toEqual( + 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', + ); + }); + + test('expect error if id does not correspond to any model', async () => { + await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown'); + }); + + test('expect array of models based on list of ids', async () => { + const models = await studioApiImpl.getModelsByIds(['llama-2-7b-chat.Q5_K_S', 'albedobase-xl-1.3']); + expect(models).toBeDefined(); + expect(models.length).toBe(2); + expect(models[0].name).toEqual('Llama-2-7B-Chat-GGUF'); + expect(models[0].registry).toEqual('Hugging Face'); + expect(models[0].url).toEqual( + 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', + ); + expect(models[1].name).toEqual('AlbedoBase XL 1.3'); + expect(models[1].registry).toEqual('Civital'); + expect(models[1].url).toEqual(''); + }); + + test('expect empty array if input list is empty', async () => { + const models = await studioApiImpl.getModelsByIds([]); + expect(models).toBeDefined(); + expect(models.length).toBe(0); + }); + + test('expect empty array if input list has ids that are not in the catalog', async () => { + const models = await studioApiImpl.getModelsByIds(['1', '2']); + expect(models).toBeDefined(); + expect(models.length).toBe(0); + }); }); -test('expect array of models based on list of ids', async () => { - const models = await studioApiImpl.getModelsByIds(['llama-2-7b-chat.Q5_K_S', 'albedobase-xl-1.3']); - expect(models).toBeDefined(); - expect(models.length).toBe(2); - expect(models[0].name).toEqual('Llama-2-7B-Chat-GGUF'); - expect(models[0].registry).toEqual('Hugging Face'); - expect(models[0].url).toEqual( +test('expect correct model is returned from default catalog with valid id when no user catalog exists', async () => { + vi.spyOn(fs, 'existsSync').mockReturnValue(false); + await studioApiImpl.loadCatalog(); + const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S'); + expect(model).toBeDefined(); + expect(model.name).toEqual('Llama-2-7B-Chat-GGUF'); + expect(model.registry).toEqual('Hugging Face'); + expect(model.url).toEqual( 'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf', ); - expect(models[1].name).toEqual('AlbedoBase XL 1.3'); - expect(models[1].registry).toEqual('Civital'); - expect(models[1].url).toEqual(''); -}); - -test('expect empty array if input list is empty', async () => { - const models = await studioApiImpl.getModelsByIds([]); - expect(models).toBeDefined(); - expect(models.length).toBe(0); }); -test('expect empty array if input list has ids that are not in the catalog', async () => { - const models = await studioApiImpl.getModelsByIds(['1', '2']); - expect(models).toBeDefined(); - expect(models.length).toBe(0); +test('expect correct model is returned with valid id when the user catalog is valid', async () => { + vi.spyOn(fs, 'existsSync').mockReturnValue(true); + vi.spyOn(fs.promises, 'readFile').mockResolvedValue(JSON.stringify(userContent)); + await studioApiImpl.loadCatalog(); + const model = await studioApiImpl.getModelById('model1'); + expect(model).toBeDefined(); + expect(model.name).toEqual('Model 1'); + expect(model.registry).toEqual('Hugging Face'); + expect(model.url).toEqual('https://model1.example.com'); }); diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 79e8a576a..30f0c5ab6 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -1,28 +1,61 @@ import type { StudioAPI } from '@shared/src/StudioAPI'; import type { Category } from '@shared/src/models/ICategory'; import type { Recipe } from '@shared/src/models/IRecipe'; -import content from './ai.json'; +import defaultCatalog from './ai.json'; import type { ApplicationManager } from './managers/applicationManager'; -import { AI_STUDIO_FOLDER } from './managers/applicationManager'; import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; import type { RecipeStatus } from '@shared/src/models/IRecipeStatus'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; import type { TaskRegistry } from './registries/TaskRegistry'; import type { Task } from '@shared/src/models/ITask'; -import * as path from 'node:path'; import type { PlayGroundManager } from './playground'; import * as podmanDesktopApi from '@podman-desktop/api'; import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; +import type { Catalog } from '@shared/src/models/ICatalog'; + +import * as path from 'node:path'; +import * as fs from 'node:fs'; export const RECENT_CATEGORY_ID = 'recent-category'; export class StudioApiImpl implements StudioAPI { + private catalog: Catalog; + constructor( private applicationManager: ApplicationManager, private recipeStatusRegistry: RecipeStatusRegistry, private taskRegistry: TaskRegistry, private playgroundManager: PlayGroundManager, - ) {} + ) { + // We start with an empty catalog, for the methods to work before the catalog is loaded + this.catalog = { + categories: [], + models: [], + recipes: [], + }; + } + + async loadCatalog() { + const catalogPath = path.resolve(this.applicationManager.appUserDirectory, 'catalog.json'); + try { + if (!fs.existsSync(catalogPath)) { + this.setCatalog(defaultCatalog); + return; + } + // TODO(feloy): watch catalog file and update catalog with new content + const data = await fs.promises.readFile(catalogPath, 'utf-8'); + const cat = JSON.parse(data) as Catalog; + this.setCatalog(cat); + } catch (err: unknown) { + console.error('unable to read catalog file, reverting to default catalog', err); + this.setCatalog(defaultCatalog); + } + } + + setCatalog(newCatalog: Catalog) { + // TODO(feloy): send message to frontend with new catalog + this.catalog = newCatalog; + } async openURL(url: string): Promise { return await podmanDesktopApi.env.openExternal(podmanDesktopApi.Uri.parse(url)); @@ -41,23 +74,23 @@ export class StudioApiImpl implements StudioAPI { } async getCategories(): Promise { - return content.categories; + return this.catalog.categories; } async getRecipesByCategory(categoryId: string): Promise { if (categoryId === RECENT_CATEGORY_ID) return this.getRecentRecipes(); - return content.recipes.filter(recipe => recipe.categories.includes(categoryId)); + return this.catalog.recipes.filter(recipe => recipe.categories.includes(categoryId)); } async getRecipeById(recipeId: string): Promise { - const recipe = (content.recipes as Recipe[]).find(recipe => recipe.id === recipeId); + const recipe = (this.catalog.recipes as Recipe[]).find(recipe => recipe.id === recipeId); if (recipe) return recipe; throw new Error('Not found'); } async getModelById(modelId: string): Promise { - const model = content.models.find(m => modelId === m.id); + const model = this.catalog.models.find(m => modelId === m.id); if (!model) { throw new Error(`No model found having id ${modelId}`); } @@ -65,7 +98,7 @@ export class StudioApiImpl implements StudioAPI { } async getModelsByIds(ids: string[]): Promise { - return content.models.filter(m => ids.includes(m.id)) ?? []; + return this.catalog.models.filter(m => ids.includes(m.id)) ?? []; } // eslint-disable-next-line @typescript-eslint/no-unused-vars @@ -91,7 +124,7 @@ export class StudioApiImpl implements StudioAPI { async getLocalModels(): Promise { const local = this.applicationManager.getLocalModels(); const localIds = local.map(l => l.id); - return content.models.filter(m => localIds.includes(m.id)); + return this.catalog.models.filter(m => localIds.includes(m.id)); } async getTasksByLabel(label: string): Promise { @@ -104,13 +137,7 @@ export class StudioApiImpl implements StudioAPI { throw new Error('model not found'); } - const modelPath = path.resolve( - this.applicationManager.homeDirectory, - AI_STUDIO_FOLDER, - 'models', - modelId, - localModelInfo[0].file, - ); + const modelPath = path.resolve(this.applicationManager.appUserDirectory, 'models', modelId, localModelInfo[0].file); await this.playgroundManager.startPlayground(modelId, modelPath); } diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index a03c47dfb..7ec079d86 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -94,6 +94,7 @@ export class Studio { const applicationManager = new ApplicationManager(gitManager, recipeStatusRegistry, this.#extensionContext); this.playgroundManager = new PlayGroundManager(this.#panel.webview); this.studioApi = new StudioApiImpl(applicationManager, recipeStatusRegistry, taskRegistry, this.playgroundManager); + await this.studioApi.loadCatalog(); // Register the instance this.rpcExtension.registerInstance(StudioApiImpl, this.studioApi); } diff --git a/packages/shared/src/models/ICatalog.ts b/packages/shared/src/models/ICatalog.ts new file mode 100644 index 000000000..5d3e64343 --- /dev/null +++ b/packages/shared/src/models/ICatalog.ts @@ -0,0 +1,9 @@ +import type { Category } from './ICategory'; +import type { ModelInfo } from './IModelInfo'; +import type { Recipe } from './IRecipe'; + +export interface Catalog { + recipes: Recipe[]; + models: ModelInfo[]; + categories: Category[]; +}