Skip to content

Commit

Permalink
load catalog from user's directory
Browse files Browse the repository at this point in the history
  • Loading branch information
feloy committed Jan 18, 2024
1 parent 6d66dcb commit 941123f
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 48 deletions.
49 changes: 49 additions & 0 deletions packages/backend/src/ai-user-test.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"recipes": [
{
"id": "recipe 1",
"description" : "Recipe 1",
"name" : "Recipe 1",
"repository": "https://recipe1.example.com",
"icon": "natural-language-processing",
"categories": [
"category1"
],
"config": "chatbot/ai-studio.yaml",
"readme": "Readme for recipe 1",
"models": [
"model1",
"model2"
]
}
],
"models": [
{
"id": "model1",
"name": "Model 1",
"description": "Readme for model 1",
"hw": "CPU",
"registry": "Hugging Face",
"popularity": 3,
"license": "?",
"url": "https://model1.example.com"
},
{
"id": "model2",
"name": "Model 2",
"description": "Readme for model 2",
"hw": "CPU",
"registry": "Civital",
"popularity": 3,
"license": "?",
"url": ""
}
],
"categories": [
{
"id": "category1",
"name": "Category 1",
"description" : "Readme for category 1"
}
]
}
137 changes: 98 additions & 39 deletions packages/backend/src/studio-api-impl.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,63 +18,122 @@

/* eslint-disable @typescript-eslint/no-explicit-any */

import { expect, test, vi } from 'vitest';
import { beforeEach, describe, expect, test, vi } from 'vitest';
import content from './ai-test.json';
import userContent from './ai-user-test.json';
import type { ApplicationManager } from './managers/applicationManager';
import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry';
import { StudioApiImpl } from './studio-api-impl';
import type { PlayGroundManager } from './playground';
import type { TaskRegistry } from './registries/TaskRegistry';

import * as fs from 'node:fs';

vi.mock('./ai.json', () => {
return {
default: content,
};
});

const studioApiImpl = new StudioApiImpl(
{} as unknown as ApplicationManager,
{} as unknown as RecipeStatusRegistry,
{} as unknown as TaskRegistry,
{} as unknown as PlayGroundManager,
);

test('expect correct model is returned with valid id', async () => {
const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
let studioApiImpl: StudioApiImpl;

beforeEach(async () => {
studioApiImpl = new StudioApiImpl(
{
appUserDirectory: '.',
} as unknown as ApplicationManager,
{} as unknown as RecipeStatusRegistry,
{} as unknown as TaskRegistry,
{} as unknown as PlayGroundManager,
);
});

test('expect error if id does not correspond to any model', async () => {
await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown');
});
describe('no valid user catalog', () => {
beforeEach(async () => {
vi.spyOn(fs.promises, 'readFile').mockResolvedValue('invalid json');
await studioApiImpl.loadCatalog();
});

test('expect array of models based on list of ids', async () => {
const models = await studioApiImpl.getModelsByIds(['llama-2-7b-chat.Q5_K_S', 'albedobase-xl-1.3']);
expect(models).toBeDefined();
expect(models.length).toBe(2);
expect(models[0].name).toEqual('Llama-2-7B-Chat-GGUF');
expect(models[0].registry).toEqual('Hugging Face');
expect(models[0].url).toEqual(
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
expect(models[1].name).toEqual('AlbedoBase XL 1.3');
expect(models[1].registry).toEqual('Civital');
expect(models[1].url).toEqual('');
});
test('expect correct model is returned with valid id', async () => {
const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

test('expect error if id does not correspond to any model', async () => {
await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown');
});

test('expect array of models based on list of ids', async () => {
const models = await studioApiImpl.getModelsByIds(['llama-2-7b-chat.Q5_K_S', 'albedobase-xl-1.3']);
expect(models).toBeDefined();
expect(models.length).toBe(2);
expect(models[0].name).toEqual('Llama-2-7B-Chat-GGUF');
expect(models[0].registry).toEqual('Hugging Face');
expect(models[0].url).toEqual(
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
expect(models[1].name).toEqual('AlbedoBase XL 1.3');
expect(models[1].registry).toEqual('Civital');
expect(models[1].url).toEqual('');
});

test('expect empty array if input list is empty', async () => {
const models = await studioApiImpl.getModelsByIds([]);
expect(models).toBeDefined();
expect(models.length).toBe(0);
test('expect empty array if input list is empty', async () => {
const models = await studioApiImpl.getModelsByIds([]);
expect(models).toBeDefined();
expect(models.length).toBe(0);
});

test('expect empty array if input list has ids that are not in the catalog', async () => {
const models = await studioApiImpl.getModelsByIds(['1', '2']);
expect(models).toBeDefined();
expect(models.length).toBe(0);
});
});

test('expect empty array if input list has ids that are not in the catalog', async () => {
const models = await studioApiImpl.getModelsByIds(['1', '2']);
expect(models).toBeDefined();
expect(models.length).toBe(0);
describe('valid user catalog', () => {
beforeEach(async () => {
vi.spyOn(fs.promises, 'readFile').mockResolvedValue(JSON.stringify(userContent));
await studioApiImpl.loadCatalog();
});

test('expect correct model is returned with valid id', async () => {
const model = await studioApiImpl.getModelById('model1');
expect(model).toBeDefined();
expect(model.name).toEqual('Model 1');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual('https://model1.example.com');
});

test('expect error if id does not correspond to any model', async () => {
await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown');
});

test('expect array of models based on list of ids', async () => {
const models = await studioApiImpl.getModelsByIds(['model1', 'model2']);
expect(models).toBeDefined();
expect(models.length).toBe(2);
expect(models[0].name).toEqual('Model 1');
expect(models[0].registry).toEqual('Hugging Face');
expect(models[0].url).toEqual('https://model1.example.com');
expect(models[1].name).toEqual('Model 2');
expect(models[1].registry).toEqual('Civital');
expect(models[1].url).toEqual('');
});

test('expect empty array if input list is empty', async () => {
const models = await studioApiImpl.getModelsByIds([]);
expect(models).toBeDefined();
expect(models.length).toBe(0);
});

test('expect empty array if input list has ids that are not in the catalog', async () => {
const models = await studioApiImpl.getModelsByIds(['1', '2']);
expect(models).toBeDefined();
expect(models.length).toBe(0);
});
});
64 changes: 55 additions & 9 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,73 @@
import type { StudioAPI } from '@shared/src/StudioAPI';
import type { Category } from '@shared/src/models/ICategory';
import type { Recipe } from '@shared/src/models/IRecipe';
import content from './ai.json';
import defaultCatalog from './ai.json';
import type { ApplicationManager } from './managers/applicationManager';
import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry';
import type { RecipeStatus } from '@shared/src/models/IRecipeStatus';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { TaskRegistry } from './registries/TaskRegistry';
import type { Task } from '@shared/src/models/ITask';
import * as path from 'node:path';
import type { PlayGroundManager } from './playground';
import * as podmanDesktopApi from '@podman-desktop/api';
import type { QueryState } from '@shared/src/models/IPlaygroundQueryState';
import type { Catalog } from '@shared/src/models/ICatalog';

import * as path from 'node:path';
import * as fs from 'node:fs';

export const RECENT_CATEGORY_ID = 'recent-category';

export class StudioApiImpl implements StudioAPI {
private catalog: Catalog;

constructor(
private applicationManager: ApplicationManager,
private recipeStatusRegistry: RecipeStatusRegistry,
private taskRegistry: TaskRegistry,
private playgroundManager: PlayGroundManager,
) {}
) {
// We start with an empty catalog, for the methods to work before the catalog is loaded
this.catalog = {
categories: [],
models: [],
recipes: [],
};
}

async loadCatalog() {
const catalogPath = path.resolve(this.applicationManager.appUserDirectory, 'catalog.json');
try {
// TODO(feloy): watch catalog file and update catalog with new content
await fs.promises
.readFile(catalogPath, 'utf-8')
.then((data: string) => {
try {
const cat = JSON.parse(data) as Catalog;
// TODO(feloy): check version and format
console.log('using user catalog');
this.setNewCatalog(cat);
} catch (err: unknown) {
console.error('unable to parse catalog file, reverting to default catalog', err);
this.setNewCatalog(defaultCatalog);
}
})
.catch((err: unknown) => {
console.error('got err', err);
console.error('unable to read catalog file, reverting to default catalog', err);
this.setNewCatalog(defaultCatalog);
return;

Check failure on line 59 in packages/backend/src/studio-api-impl.ts

View workflow job for this annotation

GitHub Actions / linter, formatters and unit tests / windows-2022

Remove this redundant jump

Check failure on line 59 in packages/backend/src/studio-api-impl.ts

View workflow job for this annotation

GitHub Actions / linter, formatters and unit tests / ubuntu-22.04

Remove this redundant jump

Check failure on line 59 in packages/backend/src/studio-api-impl.ts

View workflow job for this annotation

GitHub Actions / linter, formatters and unit tests / macos-12

Remove this redundant jump
});
} catch (err: unknown) {
console.error('unable to read catalog file, reverting to default catalog', err);
this.setNewCatalog(defaultCatalog);
}
}

setNewCatalog(newCatalog: Catalog) {
// TODO(feloy): send message to frontend with new catalog
this.catalog = newCatalog;
}

async openURL(url: string): Promise<boolean> {
return await podmanDesktopApi.env.openExternal(podmanDesktopApi.Uri.parse(url));
Expand All @@ -40,31 +86,31 @@ export class StudioApiImpl implements StudioAPI {
}

async getCategories(): Promise<Category[]> {
return content.categories;
return this.catalog.categories;
}

async getRecipesByCategory(categoryId: string): Promise<Recipe[]> {
if (categoryId === RECENT_CATEGORY_ID) return this.getRecentRecipes();

return content.recipes.filter(recipe => recipe.categories.includes(categoryId));
return this.catalog.recipes.filter(recipe => recipe.categories.includes(categoryId));
}

async getRecipeById(recipeId: string): Promise<Recipe> {
const recipe = (content.recipes as Recipe[]).find(recipe => recipe.id === recipeId);
const recipe = (this.catalog.recipes as Recipe[]).find(recipe => recipe.id === recipeId);
if (recipe) return recipe;
throw new Error('Not found');
}

async getModelById(modelId: string): Promise<ModelInfo> {
const model = content.models.find(m => modelId === m.id);
const model = this.catalog.models.find(m => modelId === m.id);
if (!model) {
throw new Error(`No model found having id ${modelId}`);
}
return model;
}

async getModelsByIds(ids: string[]): Promise<ModelInfo[]> {
return content.models.filter(m => ids.includes(m.id)) ?? [];
return this.catalog.models.filter(m => ids.includes(m.id)) ?? [];
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
Expand All @@ -90,7 +136,7 @@ export class StudioApiImpl implements StudioAPI {
async getLocalModels(): Promise<ModelInfo[]> {
const local = this.applicationManager.getLocalModels();
const localIds = local.map(l => l.id);
return content.models.filter(m => localIds.includes(m.id));
return this.catalog.models.filter(m => localIds.includes(m.id));
}

async getTasksByLabel(label: string): Promise<Task[]> {
Expand Down
1 change: 1 addition & 0 deletions packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ export class Studio {
const applicationManager = new ApplicationManager(gitManager, recipeStatusRegistry, this.#extensionContext);
this.playgroundManager = new PlayGroundManager(this.#panel.webview);
this.studioApi = new StudioApiImpl(applicationManager, recipeStatusRegistry, taskRegistry, this.playgroundManager);
await this.studioApi.loadCatalog();
// Register the instance
this.rpcExtension.registerInstance<StudioApiImpl>(StudioApiImpl, this.studioApi);
}
Expand Down

0 comments on commit 941123f

Please sign in to comment.