diff --git a/packages/backend/src/managers/modelsManager.ts b/packages/backend/src/managers/modelsManager.ts index b46327bb9..d6429d23b 100644 --- a/packages/backend/src/managers/modelsManager.ts +++ b/packages/backend/src/managers/modelsManager.ts @@ -3,15 +3,22 @@ import fs from 'fs'; import * as path from 'node:path'; export class ModelsManager { - constructor(private appUserDirectory: string) {} + private readonly modelsDir: string; + + constructor(private appUserDirectory: string) { + this.modelsDir = path.join(this.appUserDirectory, 'models'); + } + + getModelsDirectory(): string { + return this.modelsDir; + } getLocalModels(): LocalModelInfo[] { const result: LocalModelInfo[] = []; - const modelsDir = path.join(this.appUserDirectory, 'models'); - if (!fs.existsSync(modelsDir)) { + if (!fs.existsSync(this.modelsDir)) { return []; } - const entries = fs.readdirSync(modelsDir, { withFileTypes: true }); + const entries = fs.readdirSync(this.modelsDir, { withFileTypes: true }); const dirs = entries.filter(dir => dir.isDirectory()); for (const d of dirs) { const modelEntries = fs.readdirSync(path.resolve(d.path, d.name)); diff --git a/packages/backend/src/studio-api-impl.spec.ts b/packages/backend/src/studio-api-impl.spec.ts index 105fed58f..4494681b5 100644 --- a/packages/backend/src/studio-api-impl.spec.ts +++ b/packages/backend/src/studio-api-impl.spec.ts @@ -71,7 +71,6 @@ beforeEach(async () => { // Creating StudioApiImpl studioApiImpl = new StudioApiImpl( - appUserDirectory, {} as unknown as ApplicationManager, {} as unknown as RecipeStatusRegistry, {} as unknown as PlayGroundManager, diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 122a6a361..40c05a3ac 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -34,7 +34,6 @@ import type { LocalModelInfo } from '@shared/src/models/ILocalModelInfo'; export class StudioApiImpl implements StudioAPI { constructor( - private appUserDirectory: string, private applicationManager: ApplicationManager, private recipeStatusRegistry: RecipeStatusRegistry, private playgroundManager: PlayGroundManager, @@ -101,7 +100,7 @@ export class StudioApiImpl implements StudioAPI { } // TODO: we need to stop doing that. - const modelPath = path.resolve(this.appUserDirectory, 'models', modelId, localModelInfo[0].file); + const modelPath = path.resolve(this.modelsManager.getModelsDirectory(), modelId, localModelInfo[0].file); await this.playgroundManager.startPlayground(modelId, modelPath); } @@ -129,4 +128,11 @@ export class StudioApiImpl implements StudioAPI { async getCatalog(): Promise { return this.catalogManager.getCatalog(); } + + async getModelsDirectory(): Promise { + return this.modelsManager.getModelsDirectory(); + } + async openDirectory(path: string): Promise { + void podmanDesktopApi.env.openExternal(podmanDesktopApi.Uri.file(path)); + } } diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index ff4e39026..9d750c5bc 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -114,7 +114,6 @@ export class Studio { // Creating StudioApiImpl this.studioApi = new StudioApiImpl( - appUserDirectory, applicationManager, recipeStatusRegistry, this.playgroundManager, diff --git a/packages/frontend/src/pages/Models.spec.ts b/packages/frontend/src/pages/Models.spec.ts new file mode 100644 index 000000000..571775f5b --- /dev/null +++ b/packages/frontend/src/pages/Models.spec.ts @@ -0,0 +1,68 @@ +import '@testing-library/jest-dom/vitest'; +import { vi, test, expect } from 'vitest'; +import { screen, fireEvent, render, waitFor } from '@testing-library/svelte'; +import Models from '/@/pages/Models.svelte'; + +const mocks = vi.hoisted(() => { + return { + getModelsDirectoryMock: vi.fn(), + openDirectoryMock: vi.fn(), + localModelsMock: { + subscribe: (_f: (msg: any) => void) => { + return () => {}; + }, + }, + modelsPullingSubscribeMock: vi.fn(), + modelsPullingMock: { + subscribe: (f: (msg: any) => void) => { + f(mocks.modelsPullingSubscribeMock()); + return () => {}; + }, + }, + }; +}); + +vi.mock('../utils/client', async () => { + return { + studioClient: { + getModelsDirectory: mocks.getModelsDirectoryMock, + openDirectory: mocks.openDirectoryMock, + }, + rpcBrowser: { + subscribe: () => { + return { + unsubscribe: () => {}, + }; + }, + }, + }; +}); + +vi.mock('../stores/local-models', async () => { + return { + localModels: mocks.localModelsMock, + }; +}); + +vi.mock('../stores/recipe', async () => { + return { + modelsPulling: mocks.modelsPullingMock, + }; +}); + +test('open models directory should call the api', async () => { + mocks.getModelsDirectoryMock.mockResolvedValue('fake'); + mocks.modelsPullingSubscribeMock.mockReturnValue([]); + render(Models); + + await waitFor(async () => { + const open = screen.getByTitle('open-models-directory'); + expect(open).toBeDefined(); + + await fireEvent.click(open); + }); + + await waitFor(() => { + expect(mocks.openDirectoryMock).toHaveBeenNthCalledWith(1, 'fake'); + }); +}); diff --git a/packages/frontend/src/pages/Models.svelte b/packages/frontend/src/pages/Models.svelte index 4420ec43b..b7e523d70 100644 --- a/packages/frontend/src/pages/Models.svelte +++ b/packages/frontend/src/pages/Models.svelte @@ -15,7 +15,10 @@ import Card from '/@/lib/Card.svelte'; import { modelsPulling } from '../stores/recipe'; import { onMount } from 'svelte'; import ModelColumnSize from '../lib/table/model/ModelColumnSize.svelte'; - import ModelColumnCreation from '../lib/table/model/ModelColumnCreation.svelte'; +import ModelColumnCreation from '../lib/table/model/ModelColumnCreation.svelte'; +import { studioClient } from '/@/utils/client'; +import { faFolderOpen } from '@fortawesome/free-solid-svg-icons'; +import Button from '/@/lib/button/Button.svelte'; const columns: Column[] = [ new Column('Name', { width: '3fr', renderer: ModelColumnName }), @@ -33,6 +36,7 @@ let loading: boolean = true; let tasks: Task[] = []; let models: ModelInfo[] = []; let filteredModels: ModelInfo[] = []; +let modelsDir: string | undefined = undefined; function filterModels(): void { // Let's collect the models we do not want to show (loading, error). @@ -49,7 +53,10 @@ function filterModels(): void { } onMount(() => { - // Pulling update + studioClient.getModelsDirectory().then((modelsDirectory) => { + modelsDir = modelsDirectory; + }); + const modelsPullingUnsubscribe = modelsPulling.subscribe(runningTasks => { tasks = runningTasks; loading = false; @@ -67,12 +74,31 @@ onMount(() => { localModelsUnsubscribe(); } }); + +const openModelsDir = () => { + if(modelsDir === undefined) + return; + studioClient.openDirectory(modelsDir); +}
+
+ {#if modelsDir} + +
+
+ Models are stored in + {modelsDir} +
+
+
+ {/if} +
{#if !loading} {#if tasks.length > 0}
diff --git a/packages/shared/src/StudioAPI.ts b/packages/shared/src/StudioAPI.ts index 5ea1990eb..45216229a 100644 --- a/packages/shared/src/StudioAPI.ts +++ b/packages/shared/src/StudioAPI.ts @@ -11,6 +11,7 @@ export abstract class StudioAPI { abstract getPullingStatuses(): Promise>; abstract pullApplication(recipeId: string): Promise; abstract openURL(url: string): Promise; + abstract openDirectory(path: string): Promise; /** * Get the information of models saved locally into the extension's storage directory */ @@ -21,4 +22,5 @@ export abstract class StudioAPI { abstract askPlayground(modelId: string, prompt: string): Promise; abstract getPlaygroundQueriesState(): Promise; abstract getPlaygroundsState(): Promise; + abstract getModelsDirectory(): Promise; }