diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index de7819d6c..85bf17003 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -261,7 +261,7 @@ export class ApplicationManager { progress += chunk.length; const progressValue = (progress * 100) / totalFileSize; - if ((progressValue - previousProgressValue) > 1) { + if (progressValue === 100 || progressValue - previousProgressValue > 1) { previousProgressValue = progressValue; taskUtil.setTaskProgress(modelId, progressValue); } diff --git a/packages/backend/src/registries/RecipeStatusRegistry.ts b/packages/backend/src/registries/RecipeStatusRegistry.ts index ac96899e7..e614bec27 100644 --- a/packages/backend/src/registries/RecipeStatusRegistry.ts +++ b/packages/backend/src/registries/RecipeStatusRegistry.ts @@ -24,7 +24,10 @@ import { MSG_NEW_RECIPE_STATE } from '@shared/Messages'; export class RecipeStatusRegistry { private statuses: Map = new Map(); - constructor(private taskRegistry: TaskRegistry, private webview: Webview) {} + constructor( + private taskRegistry: TaskRegistry, + private webview: Webview, + ) {} setStatus(recipeId: string, status: RecipeStatus) { // Update the TaskRegistry @@ -32,7 +35,9 @@ export class RecipeStatusRegistry { status.tasks.map(task => this.taskRegistry.set(task)); } this.statuses.set(recipeId, status); - this.dispatchState(); // we don't want to wait + this.dispatchState().catch((err: unknown) => { + console.error('error dispatching recipe statuses', err); + }); // we don't want to wait } getStatus(recipeId: string): RecipeStatus | undefined { diff --git a/packages/backend/src/registries/TaskRegistry.ts b/packages/backend/src/registries/TaskRegistry.ts index 7c7f1cf6e..31bea2a72 100644 --- a/packages/backend/src/registries/TaskRegistry.ts +++ b/packages/backend/src/registries/TaskRegistry.ts @@ -21,10 +21,6 @@ import type { Task } from '@shared/src/models/ITask'; export class TaskRegistry { private tasks: Map = new Map(); - getTasksByLabel(label: string): Task[] { - return Array.from(this.tasks.values()).filter(task => label in (task.labels || {})); - } - set(task: Task) { this.tasks.set(task.id, task); } diff --git a/packages/backend/src/studio-api-impl.spec.ts b/packages/backend/src/studio-api-impl.spec.ts index 71c31bbd6..a2a8f62ab 100644 --- a/packages/backend/src/studio-api-impl.spec.ts +++ b/packages/backend/src/studio-api-impl.spec.ts @@ -75,7 +75,6 @@ beforeEach(async () => { appUserDirectory, } as unknown as ApplicationManager, {} as unknown as RecipeStatusRegistry, - {} as unknown as TaskRegistry, {} as unknown as PlayGroundManager, catalogManager, ); diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index f9ee00255..63b481aee 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -21,8 +21,6 @@ import type { ApplicationManager } from './managers/applicationManager'; import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; import type { RecipeStatus } from '@shared/src/models/IRecipeStatus'; import type { ModelInfo } from '@shared/src/models/IModelInfo'; -import type { TaskRegistry } from './registries/TaskRegistry'; -import type { Task } from '@shared/src/models/ITask'; import type { PlayGroundManager } from './managers/playground'; import * as podmanDesktopApi from '@podman-desktop/api'; import type { QueryState } from '@shared/src/models/IPlaygroundQueryState'; @@ -36,7 +34,6 @@ export class StudioApiImpl implements StudioAPI { constructor( private applicationManager: ApplicationManager, private recipeStatusRegistry: RecipeStatusRegistry, - private taskRegistry: TaskRegistry, private playgroundManager: PlayGroundManager, private catalogManager: CatalogManager, ) {} @@ -87,10 +84,6 @@ export class StudioApiImpl implements StudioAPI { return this.catalogManager.getModels().filter(m => localIds.includes(m.id)); } - async getTasksByLabel(label: string): Promise { - return this.taskRegistry.getTasksByLabel(label); - } - async startPlayground(modelId: string): Promise { // TODO: improve the following const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId); diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index 4f61f260c..87bb917bd 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -102,7 +102,6 @@ export class Studio { this.studioApi = new StudioApiImpl( applicationManager, recipeStatusRegistry, - taskRegistry, this.playgroundManager, this.catalogManager, ); diff --git a/packages/frontend/src/pages/Models.svelte b/packages/frontend/src/pages/Models.svelte index 63a69b44d..af6eca362 100644 --- a/packages/frontend/src/pages/Models.svelte +++ b/packages/frontend/src/pages/Models.svelte @@ -9,15 +9,11 @@ import ModelColumnRegistry from '../lib/table/model/ModelColumnRegistry.svelte'; import ModelColumnPopularity from '../lib/table/model/ModelColumnPopularity.svelte'; import ModelColumnLicense from '../lib/table/model/ModelColumnLicense.svelte'; import ModelColumnHw from '../lib/table/model/ModelColumnHW.svelte'; -import { onDestroy, onMount } from 'svelte'; -import { studioClient } from '/@/utils/client'; -import type { Category } from '@shared/models/ICategory'; -import type { Task } from '@shared/models/ITask'; +import type { Task } from '@shared/src/models/ITask'; import TasksProgress from '/@/lib/progress/TasksProgress.svelte'; -import { faRefresh } from '@fortawesome/free-solid-svg-icons'; import Card from '/@/lib/Card.svelte'; -import Button from '/@/lib/button/Button.svelte'; -import LinearProgress from '/@/lib/progress/LinearProgress.svelte'; + import { modelsPulling } from '../stores/recipe'; + import { onMount } from 'svelte'; const columns: Column[] = [ new Column('Name', { width: '4fr', renderer: ModelColumnName }), @@ -29,7 +25,6 @@ const columns: Column[] = [ const row = new Row({}); let loading: boolean = true; -let intervalId: ReturnType | undefined = undefined; let tasks: Task[] = []; let models: ModelInfo[] = []; @@ -46,31 +41,28 @@ function filterModels(): void { } return previousValue; }, [] as string[]); - filteredModels = models.filter((model) => !(model.id in modelsId)); + filteredModels = models.filter((model) => !modelsId.includes(model.id)); } onMount(() => { // Pulling update - intervalId = setInterval(async () => { - tasks = await studioClient.getTasksByLabel("model-pulling"); + const modelsPullingUnsubscribe = modelsPulling.subscribe(runningTasks => { + tasks = runningTasks; loading = false; filterModels(); - }, 1000); + }); // Subscribe to the models store - return localModels.subscribe((value) => { + const localModelsUnsubscribe = localModels.subscribe((value) => { models = value; filterModels(); }) -}); -onDestroy(() => { - if(intervalId !== undefined) { - clearInterval(intervalId); - intervalId = undefined; + return () => { + modelsPullingUnsubscribe(); + localModelsUnsubscribe(); } }); - diff --git a/packages/frontend/src/pages/Recipe.spec.ts b/packages/frontend/src/pages/Recipe.spec.ts index a12e52b3b..fdbf2cd3d 100644 --- a/packages/frontend/src/pages/Recipe.spec.ts +++ b/packages/frontend/src/pages/Recipe.spec.ts @@ -6,6 +6,7 @@ import Recipe from './Recipe.svelte'; const mocks = vi.hoisted(() => { return { getCatalogMock: vi.fn(), + getPullingStatusesMock: vi.fn(), }; }); @@ -13,6 +14,7 @@ vi.mock('../utils/client', async () => { return { studioClient: { getCatalog: mocks.getCatalogMock, + getPullingStatuses: mocks.getPullingStatusesMock, }, rpcBrowser: { subscribe: () => { @@ -29,6 +31,7 @@ test('should display recipe information', async () => { expect(recipe).not.toBeUndefined(); mocks.getCatalogMock.mockResolvedValue(catalog); + mocks.getPullingStatusesMock.mockResolvedValue(new Map()); render(Recipe, { recipeId: 'recipe 1', }); diff --git a/packages/frontend/src/stores/recipe.ts b/packages/frontend/src/stores/recipe.ts index 888251376..b4d7d400c 100644 --- a/packages/frontend/src/stores/recipe.ts +++ b/packages/frontend/src/stores/recipe.ts @@ -1,18 +1,27 @@ import type { Readable } from 'svelte/store'; -import { readable } from 'svelte/store'; +import { derived, readable } from 'svelte/store'; import { MSG_NEW_RECIPE_STATE } from '@shared/Messages'; import { rpcBrowser, studioClient } from '/@/utils/client'; import type { RecipeStatus } from '@shared/src/models/IRecipeStatus'; -export const recipes: Readable> = readable>(new Map(), set => { - const sub = rpcBrowser.subscribe(MSG_NEW_RECIPE_STATE, msg => { - set(msg); - }); - // Initialize the store manually - studioClient.getPullingStatuses().then(state => { - set(state); - }); - return () => { - sub.unsubscribe(); - }; +export const recipes: Readable> = readable>( + new Map(), + set => { + const sub = rpcBrowser.subscribe(MSG_NEW_RECIPE_STATE, msg => { + set(msg); + }); + // Initialize the store manually + studioClient.getPullingStatuses().then(state => { + set(state); + }); + return () => { + sub.unsubscribe(); + }; + }, +); + +export const modelsPulling = derived(recipes, $recipes => { + return Array.from($recipes.values()) + .flatMap(recipe => recipe.tasks) + .filter(task => 'model-pulling' in (task.labels || {})); }); diff --git a/packages/shared/src/StudioAPI.ts b/packages/shared/src/StudioAPI.ts index 2ec821058..5ea1990eb 100644 --- a/packages/shared/src/StudioAPI.ts +++ b/packages/shared/src/StudioAPI.ts @@ -1,6 +1,5 @@ import type { RecipeStatus } from './models/IRecipeStatus'; import type { ModelInfo } from './models/IModelInfo'; -import type { Task } from './models/ITask'; import type { QueryState } from './models/IPlaygroundQueryState'; import type { Catalog } from './models/ICatalog'; import type { PlaygroundState } from './models/IPlaygroundState'; @@ -20,14 +19,6 @@ export abstract class StudioAPI { abstract startPlayground(modelId: string): Promise; abstract stopPlayground(modelId: string): Promise; abstract askPlayground(modelId: string, prompt: string): Promise; - - /** - * Get task by label - * @param label - */ - abstract getTasksByLabel(label: string): Promise; - abstract getPlaygroundQueriesState(): Promise; - abstract getPlaygroundsState(): Promise; }