Skip to content

Commit

Permalink
Merge pull request #17 from projectatomic/feature/adding-tasks-labels
Browse files Browse the repository at this point in the history
feat: adding TaskRegistry
  • Loading branch information
axel7083 authored Jan 16, 2024
2 parents 280c18b + 6045ed5 commit faab64f
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 44 deletions.
21 changes: 16 additions & 5 deletions packages/backend/src/managers/applicationManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ import { containerEngine, ExtensionContext, provider } from '@podman-desktop/api
import { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry';
import { AIConfig, parseYaml } from '../models/AIConfig';
import { Task } from '@shared/models/ITask';
import { TaskUtils } from '../utils/taskUtils';
import { RecipeStatusUtils } from '../utils/recipeStatusUtils';
import { getParentDirectory } from '../utils/pathUtils';
import { a } from 'vitest/dist/suite-dF4WyktM';
import type { LocalModelInfo } from '@shared/models/ILocalModelInfo';

// TODO: Need to be configured
Expand All @@ -32,7 +31,7 @@ export class ApplicationManager {

async pullApplication(recipe: Recipe) {
// Create a TaskUtils object to help us
const taskUtil = new TaskUtils(recipe.id, this.recipeStatusRegistry);
const taskUtil = new RecipeStatusUtils(recipe.id, this.recipeStatusRegistry);

const localFolder = path.join(this.homeDirectory, AI_STUDIO_FOLDER, recipe.id);

Expand All @@ -41,6 +40,9 @@ export class ApplicationManager {
id: 'checkout',
name: 'Checkout repository',
state: 'loading',
labels: {
'git': 'checkout',
},
}
taskUtil.setTask(checkoutTask);

Expand Down Expand Up @@ -125,6 +127,9 @@ export class ApplicationManager {
id: model.id,
state: 'loading',
name: `Downloading model ${model.name}`,
labels: {
"model-pulling": model.id,
}
});

await this.downloadModelMain(model.id, model.url, taskUtil)
Expand Down Expand Up @@ -181,7 +186,7 @@ export class ApplicationManager {
}


downloadModelMain(modelId: string, url: string, taskUtil: TaskUtils, destFileName?: string): Promise<string> {
downloadModelMain(modelId: string, url: string, taskUtil: RecipeStatusUtils, destFileName?: string): Promise<string> {
return new Promise((resolve, reject) => {
const downloadCallback = (result: DownloadModelResult) => {
if (result.result) {
Expand All @@ -193,11 +198,17 @@ export class ApplicationManager {
}
}

if(fs.existsSync(destFileName)) {
taskUtil.setTaskState(modelId, 'success');
taskUtil.setTaskProgress(modelId, 100);
return;
}

this.downloadModel(modelId, url, taskUtil, downloadCallback, destFileName)
})
}

downloadModel(modelId: string, url: string, taskUtil: TaskUtils, callback: (message: DownloadModelResult) => void, destFileName?: string) {
private downloadModel(modelId: string, url: string, taskUtil: RecipeStatusUtils, callback: (message: DownloadModelResult) => void, destFileName?: string) {
const destDir = path.join(this.homeDirectory, AI_STUDIO_FOLDER, 'models', modelId);
if (!fs.existsSync(destDir)) {
fs.mkdirSync(destDir, { recursive: true });
Expand Down
7 changes: 7 additions & 0 deletions packages/backend/src/registries/RecipeStatusRegistry.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import { RecipeStatus } from '@shared/models/IRecipeStatus';
import { TaskRegistry } from './TaskRegistry';

export class RecipeStatusRegistry {
private statuses: Map<string, RecipeStatus> = new Map<string, RecipeStatus>();

constructor(private taskRegistry: TaskRegistry) { }

setStatus(recipeId: string, status: RecipeStatus) {
// Update the TaskRegistry
if(status.tasks && status.tasks.length > 0) {
status.tasks.map((task) => this.taskRegistry.set(task));
}
this.statuses.set(recipeId, status);
}

Expand Down
18 changes: 18 additions & 0 deletions packages/backend/src/registries/TaskRegistry.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { Task } from '@shared/models/ITask';

export class TaskRegistry {
private tasks: Map<string, Task> = new Map<string, Task>();

getTasksByLabel(label: string): Task[] {
return Array.from(this.tasks.values())
.filter((task) => label in (task.labels || {}))
}

set(task: Task) {
this.tasks.set(task.id, task);
}

delete(taskId: string) {
this.tasks.delete(taskId);
}
}
17 changes: 13 additions & 4 deletions packages/backend/src/studio-api-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@ import { AI_STUDIO_FOLDER, ApplicationManager } from './managers/applicationMana
import { RecipeStatusRegistry } from './registries/RecipeStatusRegistry';
import { RecipeStatus } from '@shared/models/IRecipeStatus';
import { ModelInfo } from '@shared/models/IModelInfo';
import { TaskRegistry } from './registries/TaskRegistry';
import { Task } from '@shared/models/ITask';
import { Studio } from './studio';
import * as path from 'node:path';
import { ModelResponse } from '@shared/models/IModelResponse';
import { PlayGroundManager } from './playground';

export const RECENT_CATEGORY_ID = 'recent-category';

export class StudioApiImpl implements StudioAPI {
constructor(
private applicationManager: ApplicationManager,
private recipeStatusRegistry: RecipeStatusRegistry,
private studio: Studio,
private taskRegistry: TaskRegistry,
private playgroundManager: PlayGroundManager,
) {}

async openURL(url: string): Promise<void> {
Expand Down Expand Up @@ -82,21 +86,26 @@ export class StudioApiImpl implements StudioAPI {
return content.recipes.flatMap(r => r.models.filter(m => localIds.includes(m.id)));
}

async getTasksByLabel(label: string): Promise<Task[]> {
return this.taskRegistry.getTasksByLabel(label);
}

async startPlayground(modelId: string): Promise<void> {
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}
const destDir = path.join();

const modelPath = path.resolve(this.applicationManager.homeDirectory, AI_STUDIO_FOLDER, 'models', modelId, localModelInfo[0].file);
this.studio.playgroundManager.startPlayground(modelId, modelPath);

await this.playgroundManager.startPlayground(modelId, modelPath);
}

askPlayground(modelId: string, prompt: string): Promise<ModelResponse> {
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}
return this.studio.playgroundManager.askPlayground(localModelInfo[0], prompt);
return this.playgroundManager.askPlayground(localModelInfo[0], prompt);
}
}
7 changes: 5 additions & 2 deletions packages/backend/src/studio.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import * as fs from 'node:fs';
import * as https from 'node:https';
import * as path from 'node:path';
import type { LocalModelInfo } from '@shared/models/ILocalModelInfo';
import { TaskRegistry } from './registries/TaskRegistry';
import { PlayGroundManager } from './playground';

export class Studio {
Expand Down Expand Up @@ -92,7 +93,8 @@ export class Studio {
// Let's create the api that the front will be able to call
this.rpcExtension = new RpcExtension(this.#panel.webview);
const gitManager = new GitManager();
const recipeStatusRegistry = new RecipeStatusRegistry();
const taskRegistry = new TaskRegistry();
const recipeStatusRegistry = new RecipeStatusRegistry(taskRegistry);
const applicationManager = new ApplicationManager(
gitManager,
recipeStatusRegistry,
Expand All @@ -101,7 +103,8 @@ export class Studio {
this.studioApi = new StudioApiImpl(
applicationManager,
recipeStatusRegistry,
this,
taskRegistry,
this.playgroundManager,
);
// Register the instance
this.rpcExtension.registerInstance<StudioApiImpl>(StudioApiImpl, this.studioApi);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { Task, TaskState } from '@shared/models/ITask';
import { RecipeStatusRegistry } from '../registries/RecipeStatusRegistry';


export class TaskUtils {
export class RecipeStatusUtils {
private tasks: Map<string, Task> = new Map<string, Task>();
private state: RecipeStatusState = 'loading';

Expand Down
123 changes: 95 additions & 28 deletions packages/frontend/src/pages/Models.svelte
Original file line number Diff line number Diff line change
@@ -1,39 +1,106 @@
<script lang="ts">
import type { ModelInfo } from '@shared/models/IModelInfo';
import NavPage from '../lib/NavPage.svelte';
import Table from '../lib/table/Table.svelte';
import { Column, Row } from '../lib/table/table';
import { localModels } from '../stores/local-models';
import ModelColumnName from './ModelColumnName.svelte';
import ModelColumnRegistry from './ModelColumnRegistry.svelte';
import ModelColumnPopularity from './ModelColumnPopularity.svelte';
import ModelColumnLicense from './ModelColumnLicense.svelte';
import ModelColumnHw from './ModelColumnHW.svelte';
const columns: Column<ModelInfo>[] = [
new Column<ModelInfo>('Name', { width: '4fr', renderer: ModelColumnName }),
new Column<ModelInfo>('HW Compat', { width: '1fr', renderer: ModelColumnHw }),
new Column<ModelInfo>('Registry', { width: '1fr', renderer: ModelColumnRegistry }),
new Column<ModelInfo>('Popularity', { width: '1fr', renderer: ModelColumnPopularity }),
new Column<ModelInfo>('License', { width: '1fr', renderer: ModelColumnLicense }),
];
const row = new Row<ModelInfo>({});
import type { ModelInfo } from '@shared/models/IModelInfo';
import NavPage from '../lib/NavPage.svelte';
import Table from '../lib/table/Table.svelte';
import { Column, Row } from '../lib/table/table';
import { localModels } from '../stores/local-models';
import ModelColumnName from './ModelColumnName.svelte';
import ModelColumnRegistry from './ModelColumnRegistry.svelte';
import ModelColumnPopularity from './ModelColumnPopularity.svelte';
import ModelColumnLicense from './ModelColumnLicense.svelte';
import ModelColumnHw from './ModelColumnHW.svelte';
import { onDestroy, onMount } from 'svelte';
import { studioClient } from '/@/utils/client';
import type { Category } from '@shared/models/ICategory';
import type { Task } from '@shared/models/ITask';
import TasksProgress from '/@/lib/progress/TasksProgress.svelte';
import { faRefresh } from '@fortawesome/free-solid-svg-icons';
import Card from '/@/lib/Card.svelte';
import Button from '/@/lib/button/Button.svelte';
import LinearProgress from '/@/lib/progress/LinearProgress.svelte';
const columns: Column<ModelInfo>[] = [
new Column<ModelInfo>('Name', { width: '4fr', renderer: ModelColumnName }),
new Column<ModelInfo>('HW Compat', { width: '1fr', renderer: ModelColumnHw }),
new Column<ModelInfo>('Registry', { width: '1fr', renderer: ModelColumnRegistry }),
new Column<ModelInfo>('Popularity', { width: '1fr', renderer: ModelColumnPopularity }),
new Column<ModelInfo>('License', { width: '1fr', renderer: ModelColumnLicense }),
];
const row = new Row<ModelInfo>({});
let loading: boolean = true;
let intervalId: ReturnType<typeof setInterval> | undefined = undefined;
let tasks: Task[] = [];
let models: ModelInfo[] = [];
let filteredModels: ModelInfo[] = [];
function filterModels(): void {
// Let's collect the models we do not want to show (loading, error).
const modelsId: string[] = tasks.reduce((previousValue, currentValue) => {
if(currentValue.state === 'success')
return previousValue;
if(currentValue.labels !== undefined) {
previousValue.push(currentValue.labels["model-pulling"]);
}
return previousValue;
}, [] as string[]);
filteredModels = models.filter((model) => !(model.id in modelsId));
}
onMount(() => {
// Pulling update
intervalId = setInterval(async () => {
tasks = await studioClient.getTasksByLabel("model-pulling");
loading = false;
filterModels();
}, 1000);
// Subscribe to the models store
return localModels.subscribe((value) => {
models = value;
filterModels();
})
});
onDestroy(() => {
if(intervalId !== undefined) {
clearInterval(intervalId);
intervalId = undefined;
}
});
</script>

<NavPage title="Models on disk">
<div slot="content" class="flex flex-col min-w-full min-h-full">
<div class="min-w-full min-h-full flex-1">
{#if loading}
<LinearProgress/>
{/if}
<div class="mt-4 px-5 space-y-5 h-full">
{#if $localModels && $localModels.length}
<Table
kind="model"
data="{$localModels}"
columns="{columns}"
row={row}>
</Table>
{:else}
<div>There is no model yet</div>
{#if !loading}
{#if tasks.length > 0}
<div class="mx-4">
<Card classes="bg-charcoal-800 mt-4">
<div slot="content" class="text-base font-normal p-2">
<div class="text-base mb-2">Downloading models</div>
<TasksProgress tasks="{tasks}"/>
</div>
</Card>
</div>
{/if}
{#if filteredModels.length > 0}
<Table
kind="model"
data="{filteredModels}"
columns="{columns}"
row={row}>
</Table>
{:else}
<div>There is no model yet</div>
{/if}
{/if}
</div>
</div>
Expand Down
2 changes: 1 addition & 1 deletion packages/frontend/src/pages/Recipe.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ const onClickRepository = () => {
<MarkdownRenderer source="{recipe?.readme}"/>
</div>
<!-- Right column -->
<div class="border-l border-l-charcoal-400 px-5 min-w-80">
<div class="border-l border-l-charcoal-400 px-5 max-w-80 min-w-80">
<Card classes="bg-charcoal-800 mt-5">
<div slot="content" class="text-base font-normal p-2">
<div class="text-base mb-2">Repository</div>
Expand Down
9 changes: 8 additions & 1 deletion packages/shared/StudioAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { Category } from '@shared/models/ICategory';
import { RecipeStatus } from '@shared/models/IRecipeStatus';
import { ModelInfo } from '@shared/models/IModelInfo';
import { ModelResponse } from '@shared/models/IModelResponse';
import { Task } from './models/ITask';

export abstract class StudioAPI {
abstract ping(): Promise<string>;
Expand All @@ -16,11 +17,17 @@ export abstract class StudioAPI {
abstract pullApplication(recipeId: string): Promise<void>;
abstract openURL(url: string): Promise<void>;
/**
* Get the information of models saved locally into the extension's storage directory
* Get the information of models saved locally into the extension's storage directory
*/
abstract getLocalModels(): Promise<ModelInfo[]>;

abstract startPlayground(modelId: string): Promise<void>;
abstract askPlayground(modelId: string, prompt: string): Promise<ModelResponse>;

/**
* Get task by label
* @param label
*/
abstract getTasksByLabel(label: string): Promise<Task[]>;
}

5 changes: 3 additions & 2 deletions packages/shared/models/ITask.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
export type TaskState = 'loading' | 'error' | 'success'

export interface Task {
id: string,
id: string;
state: TaskState;
progress?: number
progress?: number;
name: string;
labels?: {[id: string]: string}
}

0 comments on commit faab64f

Please sign in to comment.