Skip to content

Commit

Permalink
refacto: move getModelById implementation to catalogManager (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
feloy authored Feb 12, 2024
1 parent f7cf100 commit 97fb6d5
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 55 deletions.
134 changes: 134 additions & 0 deletions packages/backend/src/managers/catalogManager.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/**********************************************************************
* Copyright (C) 2024 Red Hat, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
***********************************************************************/

/* eslint-disable @typescript-eslint/no-explicit-any */

import { beforeEach, describe, expect, test, vi } from 'vitest';
import content from '../ai-test.json';
import userContent from '../ai-user-test.json';
import type { Webview } from '@podman-desktop/api';
import { CatalogManager } from '../managers/catalogManager';

import * as fs from 'node:fs';

vi.mock('./ai.json', () => {
return {
default: content,
};
});

vi.mock('node:fs', () => {
return {
existsSync: vi.fn(),
promises: {
readFile: vi.fn(),
},
};
});

vi.mock('@podman-desktop/api', () => {
return {
fs: {
createFileSystemWatcher: () => ({
onDidCreate: vi.fn(),
onDidDelete: vi.fn(),
onDidChange: vi.fn(),
}),
},
};
});

const mocks = vi.hoisted(() => ({
withProgressMock: vi.fn(),
}));

vi.mock('@podman-desktop/api', async () => {
return {
window: {
withProgress: mocks.withProgressMock,
},
ProgressLocation: {
TASK_WIDGET: 'TASK_WIDGET',
},
fs: {
createFileSystemWatcher: () => ({
onDidCreate: vi.fn(),
onDidDelete: vi.fn(),
onDidChange: vi.fn(),
}),
},
};
});

let catalogManager: CatalogManager;

beforeEach(async () => {
const appUserDirectory = '.';

// Creating CatalogManager
catalogManager = new CatalogManager(appUserDirectory, {
postMessage: vi.fn(),
} as unknown as Webview);
vi.resetAllMocks();
vi.mock('node:fs');
});

describe('invalid user catalog', () => {
beforeEach(async () => {
vi.spyOn(fs.promises, 'readFile').mockResolvedValue('invalid json');
await catalogManager.loadCatalog();
});

test('expect correct model is returned with valid id', () => {
const model = catalogManager.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

test('expect error if id does not correspond to any model', () => {
expect(() => catalogManager.getModelById('unknown')).toThrowError('No model found having id unknown');
});
});

test('expect correct model is returned from default catalog with valid id when no user catalog exists', async () => {
vi.spyOn(fs, 'existsSync').mockReturnValue(false);
await catalogManager.loadCatalog();
const model = catalogManager.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

test('expect correct model is returned with valid id when the user catalog is valid', async () => {
vi.spyOn(fs, 'existsSync').mockReturnValue(true);
vi.spyOn(fs.promises, 'readFile').mockResolvedValue(JSON.stringify(userContent));

await catalogManager.loadCatalog();
const model = catalogManager.getModelById('model1');
expect(model).toBeDefined();
expect(model.name).toEqual('Model 1');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual('https://model1.example.com');
});
9 changes: 9 additions & 0 deletions packages/backend/src/managers/catalogManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ export class CatalogManager {
public getModels(): ModelInfo[] {
return this.catalog.models;
}

public getModelById(modelId: string): ModelInfo {
const model = this.getModels().find(m => modelId === m.id);
if (!model) {
throw new Error(`No model found having id ${modelId}`);
}
return model;
}

public getRecipes(): Recipe[] {
return this.catalog.recipes;
}
Expand Down
47 changes: 1 addition & 46 deletions packages/backend/src/studio-api-impl.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

/* eslint-disable @typescript-eslint/no-explicit-any */

import { beforeEach, describe, expect, test, vi } from 'vitest';
import { beforeEach, expect, test, vi } from 'vitest';
import content from './ai-test.json';
import userContent from './ai-user-test.json';
import type { ApplicationManager } from './managers/applicationManager';
Expand Down Expand Up @@ -106,51 +106,6 @@ beforeEach(async () => {
vi.mock('node:fs');
});

describe('invalid user catalog', () => {
beforeEach(async () => {
vi.spyOn(fs.promises, 'readFile').mockResolvedValue('invalid json');
await catalogManager.loadCatalog();
});

test('expect correct model is returned with valid id', async () => {
const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

test('expect error if id does not correspond to any model', async () => {
await expect(() => studioApiImpl.getModelById('unknown')).rejects.toThrowError('No model found having id unknown');
});
});

test('expect correct model is returned from default catalog with valid id when no user catalog exists', async () => {
vi.spyOn(fs, 'existsSync').mockReturnValue(false);
await catalogManager.loadCatalog();
const model = await studioApiImpl.getModelById('llama-2-7b-chat.Q5_K_S');
expect(model).toBeDefined();
expect(model.name).toEqual('Llama-2-7B-Chat-GGUF');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual(
'https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf',
);
});

test('expect correct model is returned with valid id when the user catalog is valid', async () => {
vi.spyOn(fs, 'existsSync').mockReturnValue(true);
vi.spyOn(fs.promises, 'readFile').mockResolvedValue(JSON.stringify(userContent));

await catalogManager.loadCatalog();
const model = await studioApiImpl.getModelById('model1');
expect(model).toBeDefined();
expect(model.name).toEqual('Model 1');
expect(model.registry).toEqual('Hugging Face');
expect(model.url).toEqual('https://model1.example.com');
});

test('expect pull application to call the withProgress api method', async () => {
vi.spyOn(fs, 'existsSync').mockReturnValue(true);
vi.spyOn(fs.promises, 'readFile').mockResolvedValue(JSON.stringify(userContent));
Expand Down
10 changes: 1 addition & 9 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,13 @@ export class StudioApiImpl implements StudioAPI {
return this.recipeStatusRegistry.getStatuses();
}

async getModelById(modelId: string): Promise<ModelInfo> {
const model = this.catalogManager.getModels().find(m => modelId === m.id);
if (!model) {
throw new Error(`No model found having id ${modelId}`);
}
return model;
}

async pullApplication(recipeId: string): Promise<void> {
const recipe = this.catalogManager.getRecipes().find(recipe => recipe.id === recipeId);
if (!recipe) throw new Error('Not found');

// the user should have selected one model, we use the first one for the moment
const modelId = recipe.models[0];
const model = await this.getModelById(modelId);
const model = this.catalogManager.getModelById(modelId);

// Do not wait for the pull application, run it separately
podmanDesktopApi.window
Expand Down

0 comments on commit 97fb6d5

Please sign in to comment.