diff --git a/packages/backend/src/managers/applicationManager.ts b/packages/backend/src/managers/applicationManager.ts index 8b336add3..39effa785 100644 --- a/packages/backend/src/managers/applicationManager.ts +++ b/packages/backend/src/managers/applicationManager.ts @@ -9,9 +9,8 @@ import { containerEngine, ExtensionContext, provider } from '@podman-desktop/api import { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry'; import { AIConfig, parseYaml } from '../models/AIConfig'; import { Task } from '@shared/models/ITask'; -import { TaskUtils } from '../utils/taskUtils'; +import { RecipeStatusUtils } from '../utils/recipeStatusUtils'; import { getParentDirectory } from '../utils/pathUtils'; -import { a } from 'vitest/dist/suite-dF4WyktM'; import type { LocalModelInfo } from '@shared/models/ILocalModelInfo'; // TODO: Need to be configured @@ -32,7 +31,7 @@ export class ApplicationManager { async pullApplication(recipe: Recipe) { // Create a TaskUtils object to help us - const taskUtil = new TaskUtils(recipe.id, this.recipeStatusRegistry); + const taskUtil = new RecipeStatusUtils(recipe.id, this.recipeStatusRegistry); const localFolder = path.join(this.homeDirectory, AI_STUDIO_FOLDER, recipe.id); @@ -41,6 +40,9 @@ export class ApplicationManager { id: 'checkout', name: 'Checkout repository', state: 'loading', + labels: { + 'git': 'checkout', + }, } taskUtil.setTask(checkoutTask); @@ -125,6 +127,9 @@ export class ApplicationManager { id: model.id, state: 'loading', name: `Downloading model ${model.name}`, + labels: { + "model-pulling": model.id, + } }); await this.downloadModelMain(model.id, model.url, taskUtil) @@ -181,7 +186,7 @@ export class ApplicationManager { } - downloadModelMain(modelId: string, url: string, taskUtil: TaskUtils, destFileName?: string): Promise { + downloadModelMain(modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string): Promise { return new Promise((resolve, reject) => { const downloadCallback = (result: DownloadModelResult) => { if (result.result) { @@ -193,11 +198,17 @@ export class ApplicationManager { } } + if(fs.existsSync(destFileName)) { + taskUtil.setTaskState(modelId, 'success'); + taskUtil.setTaskProgress(modelId, 100); + return; + } + this.downloadModel(modelId, url, taskUtil, downloadCallback, destFileName) }) } - downloadModel(modelId: string, url: string, taskUtil: TaskUtils, callback: (message: DownloadModelResult) => void, destFileName?: string) { + private downloadModel(modelId: string, url: string, taskUtil: RecipeStatusUtils, callback: (message: DownloadModelResult) => void, destFileName?: string) { const destDir = path.join(this.homeDirectory, AI_STUDIO_FOLDER, 'models', modelId); if (!fs.existsSync(destDir)) { fs.mkdirSync(destDir, { recursive: true }); diff --git a/packages/backend/src/registries/RecipeStatusRegistry.ts b/packages/backend/src/registries/RecipeStatusRegistry.ts index ef2a46da8..8f636a6f3 100644 --- a/packages/backend/src/registries/RecipeStatusRegistry.ts +++ b/packages/backend/src/registries/RecipeStatusRegistry.ts @@ -1,9 +1,16 @@ import { RecipeStatus } from '@shared/models/IRecipeStatus'; +import { TaskRegistry } from './TaskRegistry'; export class RecipeStatusRegistry { private statuses: Map = new Map(); + constructor(private taskRegistry: TaskRegistry) { } + setStatus(recipeId: string, status: RecipeStatus) { + // Update the TaskRegistry + if(status.tasks && status.tasks.length > 0) { + status.tasks.map((task) => this.taskRegistry.set(task)); + } this.statuses.set(recipeId, status); } diff --git a/packages/backend/src/registries/TaskRegistry.ts b/packages/backend/src/registries/TaskRegistry.ts new file mode 100644 index 000000000..699fa3a4f --- /dev/null +++ b/packages/backend/src/registries/TaskRegistry.ts @@ -0,0 +1,18 @@ +import { Task } from '@shared/models/ITask'; + +export class TaskRegistry { + private tasks: Map = new Map(); + + getTasksByLabel(label: string): Task[] { + return Array.from(this.tasks.values()) + .filter((task) => label in (task.labels || {})) + } + + set(task: Task) { + this.tasks.set(task.id, task); + } + + delete(taskId: string) { + this.tasks.delete(taskId); + } +} diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index 5940bd6fd..1ae2fec62 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -6,9 +6,12 @@ import { AI_STUDIO_FOLDER, ApplicationManager } from './managers/applicationMana import { RecipeStatusRegistry } from './registries/RecipeStatusRegistry'; import { RecipeStatus } from '@shared/models/IRecipeStatus'; import { ModelInfo } from '@shared/models/IModelInfo'; +import { TaskRegistry } from './registries/TaskRegistry'; +import { Task } from '@shared/models/ITask'; import { Studio } from './studio'; import * as path from 'node:path'; import { ModelResponse } from '@shared/models/IModelResponse'; +import { PlayGroundManager } from './playground'; export const RECENT_CATEGORY_ID = 'recent-category'; @@ -16,7 +19,8 @@ export class StudioApiImpl implements StudioAPI { constructor( private applicationManager: ApplicationManager, private recipeStatusRegistry: RecipeStatusRegistry, - private studio: Studio, + private taskRegistry: TaskRegistry, + private playgroundManager: PlayGroundManager, ) {} async openURL(url: string): Promise { @@ -82,14 +86,19 @@ export class StudioApiImpl implements StudioAPI { return content.recipes.flatMap(r => r.models.filter(m => localIds.includes(m.id))); } + async getTasksByLabel(label: string): Promise { + return this.taskRegistry.getTasksByLabel(label); + } + async startPlayground(modelId: string): Promise { const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId); if (localModelInfo.length !== 1) { throw new Error('model not found'); } - const destDir = path.join(); + const modelPath = path.resolve(this.applicationManager.homeDirectory, AI_STUDIO_FOLDER, 'models', modelId, localModelInfo[0].file); - this.studio.playgroundManager.startPlayground(modelId, modelPath); + + await this.playgroundManager.startPlayground(modelId, modelPath); } askPlayground(modelId: string, prompt: string): Promise { @@ -97,6 +106,6 @@ export class StudioApiImpl implements StudioAPI { if (localModelInfo.length !== 1) { throw new Error('model not found'); } - return this.studio.playgroundManager.askPlayground(localModelInfo[0], prompt); + return this.playgroundManager.askPlayground(localModelInfo[0], prompt); } } diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index efaff9e48..eb13c103b 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -28,6 +28,7 @@ import * as fs from 'node:fs'; import * as https from 'node:https'; import * as path from 'node:path'; import type { LocalModelInfo } from '@shared/models/ILocalModelInfo'; +import { TaskRegistry } from './registries/TaskRegistry'; import { PlayGroundManager } from './playground'; export class Studio { @@ -92,7 +93,8 @@ export class Studio { // Let's create the api that the front will be able to call this.rpcExtension = new RpcExtension(this.#panel.webview); const gitManager = new GitManager(); - const recipeStatusRegistry = new RecipeStatusRegistry(); + const taskRegistry = new TaskRegistry(); + const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry); const applicationManager = new ApplicationManager( gitManager, recipeStatusRegistry, @@ -101,7 +103,8 @@ export class Studio { this.studioApi = new StudioApiImpl( applicationManager, recipeStatusRegistry, - this, + taskRegistry, + this.playgroundManager, ); // Register the instance this.rpcExtension.registerInstance(StudioApiImpl, this.studioApi); diff --git a/packages/backend/src/utils/taskUtils.ts b/packages/backend/src/utils/recipeStatusUtils.ts similarity index 97% rename from packages/backend/src/utils/taskUtils.ts rename to packages/backend/src/utils/recipeStatusUtils.ts index 8269a79f7..93c3ddc46 100644 --- a/packages/backend/src/utils/taskUtils.ts +++ b/packages/backend/src/utils/recipeStatusUtils.ts @@ -3,7 +3,7 @@ import type { Task, TaskState } from '@shared/models/ITask'; import { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry'; -export class TaskUtils { +export class RecipeStatusUtils { private tasks: Map = new Map(); private state: RecipeStatusState = 'loading'; diff --git a/packages/frontend/src/pages/Models.svelte b/packages/frontend/src/pages/Models.svelte index 90fa0a951..49a0e5ebc 100644 --- a/packages/frontend/src/pages/Models.svelte +++ b/packages/frontend/src/pages/Models.svelte @@ -1,39 +1,106 @@
+ {#if loading} + + {/if}
- {#if $localModels && $localModels.length} - -
- {:else} -
There is no model yet
+ {#if !loading} + {#if tasks.length > 0} +
+ +
+
Downloading models
+ +
+
+
+ {/if} + {#if filteredModels.length > 0} + +
+ {:else} +
There is no model yet
+ {/if} {/if}
diff --git a/packages/frontend/src/pages/Recipe.svelte b/packages/frontend/src/pages/Recipe.svelte index 98865371e..bdca5e59b 100644 --- a/packages/frontend/src/pages/Recipe.svelte +++ b/packages/frontend/src/pages/Recipe.svelte @@ -73,7 +73,7 @@ const onClickRepository = () => {
-
+
Repository
diff --git a/packages/shared/StudioAPI.ts b/packages/shared/StudioAPI.ts index 32dd909af..b76fa4f1c 100644 --- a/packages/shared/StudioAPI.ts +++ b/packages/shared/StudioAPI.ts @@ -3,6 +3,7 @@ import type { Category } from '@shared/models/ICategory'; import { RecipeStatus } from '@shared/models/IRecipeStatus'; import { ModelInfo } from '@shared/models/IModelInfo'; import { ModelResponse } from '@shared/models/IModelResponse'; +import { Task } from './models/ITask'; export abstract class StudioAPI { abstract ping(): Promise; @@ -16,11 +17,17 @@ export abstract class StudioAPI { abstract pullApplication(recipeId: string): Promise; abstract openURL(url: string): Promise; /** - * Get the information of models saved locally into the extension's storage directory + * Get the information of models saved locally into the extension's storage directory */ abstract getLocalModels(): Promise; abstract startPlayground(modelId: string): Promise; abstract askPlayground(modelId: string, prompt: string): Promise; + + /** + * Get task by label + * @param label + */ + abstract getTasksByLabel(label: string): Promise; } diff --git a/packages/shared/models/ITask.ts b/packages/shared/models/ITask.ts index a6aec82e7..c26df28d0 100644 --- a/packages/shared/models/ITask.ts +++ b/packages/shared/models/ITask.ts @@ -1,8 +1,9 @@ export type TaskState = 'loading' | 'error' | 'success' export interface Task { - id: string, + id: string; state: TaskState; - progress?: number + progress?: number; name: string; + labels?: {[id: string]: string} }