Skip to content

Commit

Permalink
Models page also uses the store
Browse files Browse the repository at this point in the history
  • Loading branch information
feloy committed Jan 22, 2024
1 parent be175f9 commit a956b16
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 56 deletions.
2 changes: 1 addition & 1 deletion packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ export class ApplicationManager {
progress += chunk.length;
const progressValue = (progress * 100) / totalFileSize;

if ((progressValue - previousProgressValue) > 1) {
if (progressValue === 100 || progressValue - previousProgressValue > 1) {
previousProgressValue = progressValue;
taskUtil.setTaskProgress(modelId, progressValue);
}
Expand Down
9 changes: 7 additions & 2 deletions packages/backend/src/registries/RecipeStatusRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,20 @@ import { MSG_NEW_RECIPE_STATE } from '@shared/Messages';
export class RecipeStatusRegistry {
private statuses: Map<string, RecipeStatus> = new Map<string, RecipeStatus>();

constructor(private taskRegistry: TaskRegistry, private webview: Webview) {}
constructor(
private taskRegistry: TaskRegistry,
private webview: Webview,
) {}

setStatus(recipeId: string, status: RecipeStatus) {
// Update the TaskRegistry
if (status.tasks && status.tasks.length > 0) {
status.tasks.map(task => this.taskRegistry.set(task));
}
this.statuses.set(recipeId, status);
this.dispatchState(); // we don't want to wait
this.dispatchState().catch((err: unknown) => {
console.error('error dispatching recipe statuses', err);
}); // we don't want to wait
}

getStatus(recipeId: string): RecipeStatus | undefined {
Expand Down
4 changes: 0 additions & 4 deletions packages/backend/src/registries/TaskRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ import type { Task } from '@shared/src/models/ITask';
export class TaskRegistry {
private tasks: Map<string, Task> = new Map<string, Task>();

getTasksByLabel(label: string): Task[] {
return Array.from(this.tasks.values()).filter(task => label in (task.labels || {}));
}

set(task: Task) {
this.tasks.set(task.id, task);
}
Expand Down
1 change: 0 additions & 1 deletion packages/backend/src/studio-api-impl.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ beforeEach(async () => {
appUserDirectory,
} as unknown as ApplicationManager,
{} as unknown as RecipeStatusRegistry,
{} as unknown as TaskRegistry,
{} as unknown as PlayGroundManager,
catalogManager,
);
Expand Down
7 changes: 0 additions & 7 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import type { ApplicationManager } from './managers/applicationManager';
import type { RecipeStatusRegistry } from './registries/RecipeStatusRegistry';
import type { RecipeStatus } from '@shared/src/models/IRecipeStatus';
import type { ModelInfo } from '@shared/src/models/IModelInfo';
import type { TaskRegistry } from './registries/TaskRegistry';
import type { Task } from '@shared/src/models/ITask';
import type { PlayGroundManager } from './managers/playground';
import * as podmanDesktopApi from '@podman-desktop/api';
import type { QueryState } from '@shared/src/models/IPlaygroundQueryState';
Expand All @@ -36,7 +34,6 @@ export class StudioApiImpl implements StudioAPI {
constructor(
private applicationManager: ApplicationManager,
private recipeStatusRegistry: RecipeStatusRegistry,
private taskRegistry: TaskRegistry,
private playgroundManager: PlayGroundManager,
private catalogManager: CatalogManager,
) {}
Expand Down Expand Up @@ -87,10 +84,6 @@ export class StudioApiImpl implements StudioAPI {
return this.catalogManager.getModels().filter(m => localIds.includes(m.id));
}

async getTasksByLabel(label: string): Promise<Task[]> {
return this.taskRegistry.getTasksByLabel(label);
}

async startPlayground(modelId: string): Promise<void> {
// TODO: improve the following
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
Expand Down
1 change: 0 additions & 1 deletion packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ export class Studio {
this.studioApi = new StudioApiImpl(
applicationManager,
recipeStatusRegistry,
taskRegistry,
this.playgroundManager,
this.catalogManager,
);
Expand Down
30 changes: 11 additions & 19 deletions packages/frontend/src/pages/Models.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,11 @@ import ModelColumnRegistry from '../lib/table/model/ModelColumnRegistry.svelte';
import ModelColumnPopularity from '../lib/table/model/ModelColumnPopularity.svelte';
import ModelColumnLicense from '../lib/table/model/ModelColumnLicense.svelte';
import ModelColumnHw from '../lib/table/model/ModelColumnHW.svelte';
import { onDestroy, onMount } from 'svelte';
import { studioClient } from '/@/utils/client';
import type { Category } from '@shared/models/ICategory';
import type { Task } from '@shared/models/ITask';
import type { Task } from '@shared/src/models/ITask';
import TasksProgress from '/@/lib/progress/TasksProgress.svelte';
import { faRefresh } from '@fortawesome/free-solid-svg-icons';
import Card from '/@/lib/Card.svelte';
import Button from '/@/lib/button/Button.svelte';
import LinearProgress from '/@/lib/progress/LinearProgress.svelte';
import { modelsPulling } from '../stores/recipe';
import { onMount } from 'svelte';
const columns: Column<ModelInfo>[] = [
new Column<ModelInfo>('Name', { width: '4fr', renderer: ModelColumnName }),
Expand All @@ -29,7 +25,6 @@ const columns: Column<ModelInfo>[] = [
const row = new Row<ModelInfo>({});
let loading: boolean = true;
let intervalId: ReturnType<typeof setInterval> | undefined = undefined;
let tasks: Task[] = [];
let models: ModelInfo[] = [];
Expand All @@ -46,31 +41,28 @@ function filterModels(): void {
}
return previousValue;
}, [] as string[]);
filteredModels = models.filter((model) => !(model.id in modelsId));
filteredModels = models.filter((model) => !modelsId.includes(model.id));
}
onMount(() => {
// Pulling update
intervalId = setInterval(async () => {
tasks = await studioClient.getTasksByLabel("model-pulling");
const modelsPullingUnsubscribe = modelsPulling.subscribe(runningTasks => {
tasks = runningTasks;
loading = false;
filterModels();
}, 1000);
});
// Subscribe to the models store
return localModels.subscribe((value) => {
const localModelsUnsubscribe = localModels.subscribe((value) => {
models = value;
filterModels();
})
});
onDestroy(() => {
if(intervalId !== undefined) {
clearInterval(intervalId);
intervalId = undefined;
return () => {
modelsPullingUnsubscribe();
localModelsUnsubscribe();
}
});
</script>

<NavPage title="Models on disk" searchEnabled="{false}" loading="{loading}">
Expand Down
3 changes: 3 additions & 0 deletions packages/frontend/src/pages/Recipe.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ import Recipe from './Recipe.svelte';
const mocks = vi.hoisted(() => {
return {
getCatalogMock: vi.fn(),
getPullingStatusesMock: vi.fn(),
};
});

vi.mock('../utils/client', async () => {
return {
studioClient: {
getCatalog: mocks.getCatalogMock,
getPullingStatuses: mocks.getPullingStatusesMock,
},
rpcBrowser: {
subscribe: () => {
Expand All @@ -29,6 +31,7 @@ test('should display recipe information', async () => {
expect(recipe).not.toBeUndefined();

mocks.getCatalogMock.mockResolvedValue(catalog);
mocks.getPullingStatusesMock.mockResolvedValue(new Map());
render(Recipe, {
recipeId: 'recipe 1',
});
Expand Down
33 changes: 21 additions & 12 deletions packages/frontend/src/stores/recipe.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import type { Readable } from 'svelte/store';
import { readable } from 'svelte/store';
import { derived, readable } from 'svelte/store';
import { MSG_NEW_RECIPE_STATE } from '@shared/Messages';
import { rpcBrowser, studioClient } from '/@/utils/client';
import type { RecipeStatus } from '@shared/src/models/IRecipeStatus';

export const recipes: Readable<Map<string, RecipeStatus>> = readable<Map<string, RecipeStatus>>(new Map<string, RecipeStatus>(), set => {
const sub = rpcBrowser.subscribe(MSG_NEW_RECIPE_STATE, msg => {
set(msg);
});
// Initialize the store manually
studioClient.getPullingStatuses().then(state => {
set(state);
});
return () => {
sub.unsubscribe();
};
export const recipes: Readable<Map<string, RecipeStatus>> = readable<Map<string, RecipeStatus>>(
new Map<string, RecipeStatus>(),
set => {
const sub = rpcBrowser.subscribe(MSG_NEW_RECIPE_STATE, msg => {
set(msg);
});
// Initialize the store manually
studioClient.getPullingStatuses().then(state => {
set(state);
});
return () => {
sub.unsubscribe();
};
},
);

export const modelsPulling = derived(recipes, $recipes => {
return Array.from($recipes.values())
.flatMap(recipe => recipe.tasks)
.filter(task => 'model-pulling' in (task.labels || {}));
});
9 changes: 0 additions & 9 deletions packages/shared/src/StudioAPI.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import type { RecipeStatus } from './models/IRecipeStatus';
import type { ModelInfo } from './models/IModelInfo';
import type { Task } from './models/ITask';
import type { QueryState } from './models/IPlaygroundQueryState';
import type { Catalog } from './models/ICatalog';
import type { PlaygroundState } from './models/IPlaygroundState';
Expand All @@ -20,14 +19,6 @@ export abstract class StudioAPI {
abstract startPlayground(modelId: string): Promise<void>;
abstract stopPlayground(modelId: string): Promise<void>;
abstract askPlayground(modelId: string, prompt: string): Promise<number>;

/**
* Get task by label
* @param label
*/
abstract getTasksByLabel(label: string): Promise<Task[]>;

abstract getPlaygroundQueriesState(): Promise<QueryState[]>;

abstract getPlaygroundsState(): Promise<PlaygroundState[]>;
}

0 comments on commit a956b16

Please sign in to comment.