Merge pull request #16 from projectatomic/feat/model-playground
feat: model playground
feloy authored Jan 12, 2024
2 parents 20288d7 + bbaaf51 commit fe7ad06
Showing 13 changed files with 239 additions and 9 deletions.
21 changes: 21 additions & 0 deletions packages/backend/src/ai.json
@@ -15,11 +15,32 @@
{
"id": "llama-2-7b-chat.Q5_K_S",
"name": "Llama-2-7B-Chat-GGUF",
"description": "Llama 2 is a family of state-of-the-art open-access large language models released by Meta today, and we’re excited to fully support the launch with comprehensive integration in Hugging Face. Llama 2 is being released with a very permissive community license and is available for commercial use. The code, pretrained models, and fine-tuned models are all being released today 🔥",
"hw": "CPU",
"registry": "Hugging Face",
"popularity": 3,
"license": "?",
"url": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q5_K_S.gguf"
},
{
"id": "albedobase-xl-1.3",
"name": "AlbedoBase XL 1.3",
"description": "Stable Diffusion XL has 6.6 billion parameters, which is about 6.6 times more than the SD v1.5 version. I believe that this is not just a number, but a number that can lead to a significant improvement in performance. It has been a while since we realized that the overall performance of SD v1.5 has improved beyond imagination thanks to the explosive contributions of our community. Therefore, I am working on completing this AlbedoBase XL model in order to optimally reproduce the performance improvement that occurred in v1.5 in this XL version as well. My goal is to directly test the performance of all Checkpoints and LoRAs that are publicly uploaded to Civitai, and merge only the resources that are judged to be optimal after passing through several filters. This will surpass the performance of image-generating AI of companies such as Midjourney. As of now, AlbedoBase XL v0.4 has merged exactly 55 selected checkpoints and 138 LoRAs.",
"hw": "CPU",
"registry": "Civital",
"popularity": 3,
"license": "openrail++",
"url": ""
},
{
"id": "sdxl-turbo",
"name": "SDXL Turbo",
"description": "SDXL Turbo achieves state-of-the-art performance with a new distillation technology, enabling single-step image generation with unprecedented quality, reducing the required step count from 50 to just one.",
"hw": "CPU",
"registry": "Hugging Face",
"popularity": 3,
"license": "sai-c-community",
"url": ""
}
]
}
2 changes: 1 addition & 1 deletion packages/backend/src/managers/applicationManager.ts
@@ -24,7 +24,7 @@ interface DownloadModelResult {
}

export class ApplicationManager {
- private readonly homeDirectory: string; // todo: make configurable
+ readonly homeDirectory: string; // todo: make configurable

constructor(private git: GitManager, private recipeStatusRegistry: RecipeStatusRegistry, private extensionContext: ExtensionContext) {
this.homeDirectory = os.homedir();
54 changes: 52 additions & 2 deletions packages/backend/src/playground.ts
@@ -1,5 +1,9 @@
import { provider, containerEngine, type ProviderContainerConnection, type ImageInfo } from '@podman-desktop/api';
import { LocalModelInfo } from '@shared/models/ILocalModelInfo';
import { ModelResponse } from '@shared/models/IModelResponse';

import path from 'node:path';
import * as http from 'node:http';

const LOCALAI_IMAGE = 'quay.io/go-skynet/local-ai:v2.5.1';

@@ -13,7 +17,7 @@ function findFirstProvider(): ProviderContainerConnection | undefined {

export class PlayGroundManager {
async selectImage(connection: ProviderContainerConnection, image: string): Promise<ImageInfo | undefined> {
- const images = (await containerEngine.listImages()).filter(im => im.RepoTags.some(tag => tag === image));
+ const images = (await containerEngine.listImages()).filter(im => im.RepoTags && im.RepoTags.some(tag => tag === image));
return images.length > 0 ? images[0] : undefined;
}

@@ -34,7 +38,7 @@ export class PlayGroundManager {
const result = await containerEngine.createContainer(image.engineId, {
Image: image.Id,
Detach: true,
- ExposedPorts: { '9000': '8080' },
+ ExposedPorts: { '9000': {} },
HostConfig: {
AutoRemove: true,
Mounts: [
@@ -44,6 +48,13 @@ export class PlayGroundManager {
Type: 'bind',
},
],
PortBindings: {
'8080/tcp': [
{
HostPort: '9000'
}
]
}
},
Cmd: ['--models-path', '/models', '--context-size', '700', '--threads', '4'],
});
@@ -57,4 +68,43 @@ }
}
return containerEngine.stopContainer(connection.providerId, playgroundId);
}

async askPlayground(modelInfo: LocalModelInfo, prompt: string): Promise<ModelResponse> {
return new Promise(resolve => {
const post_data = JSON.stringify({
"model": modelInfo.file,
"prompt": prompt,
"temperature": 0.7
});

const post_options: http.RequestOptions = {
host: 'localhost',
port: '9000',
path: '/v1/completions',
method: 'POST',
headers: {
'Content-Type': 'application/json'
}
};

const post_req = http.request(post_options, function (res) {
res.setEncoding('utf8');
const chunks: string[] = [];
res.on('data', (data) => chunks.push(data));
res.on('end', () => {
// join chunks without a separator so a multi-chunk JSON body stays parseable
const resBody = chunks.join('');
switch (res.headers['content-type']) {
case 'application/json': {
const result = JSON.parse(resBody);
console.log('result', result);
resolve(result as ModelResponse);
break;
}
}
});
});
// post the data
post_req.write(post_data);
post_req.end();
});
}
}
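
For context on the wiring in this file: the PortBindings entry publishes LocalAI's HTTP API (container port 8080) on host port 9000, and askPlayground posts a completion request to that host port. A minimal standalone sketch of the same request, outside the extension, could look like the following; the endpoint path and port come from the code above, while the model file name is an illustrative assumption:

// Sketch only (not part of this PR). Requires Node 18+ for the global fetch API.
async function tryCompletion(): Promise<void> {
  // POST an OpenAI-style completion request to the LocalAI container
  // started by startPlayground() (assumes it is already running).
  const response = await fetch('http://localhost:9000/v1/completions', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: 'llama-2-7b-chat.Q5_K_S.gguf', // assumed name of a downloaded model file
      prompt: 'Say hello in one short sentence.',
      temperature: 0.7,
    }),
  });
  const completion = await response.json();
  console.log(completion.choices?.[0]?.text);
}

tryCompletion().catch(console.error);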
30 changes: 29 additions & 1 deletion packages/backend/src/studio-api-impl.ts
@@ -2,17 +2,21 @@ import type { StudioAPI } from '@shared/StudioAPI';
import { Category } from '@shared/models/ICategory';
import { Recipe } from '@shared/models/IRecipe';
import content from './ai.json';
- import { ApplicationManager } from './managers/applicationManager';
+ import { AI_STUDIO_FOLDER, ApplicationManager } from './managers/applicationManager';
import { RecipeStatusRegistry } from './registries/RecipeStatusRegistry';
import { RecipeStatus } from '@shared/models/IRecipeStatus';
import { ModelInfo } from '@shared/models/IModelInfo';
import { Studio } from './studio';
import * as path from 'node:path';
import { ModelResponse } from '@shared/models/IModelResponse';

export const RECENT_CATEGORY_ID = 'recent-category';

export class StudioApiImpl implements StudioAPI {
constructor(
private applicationManager: ApplicationManager,
private recipeStatusRegistry: RecipeStatusRegistry,
private studio: Studio,
) {}

async openURL(url: string): Promise<void> {
@@ -48,6 +52,13 @@ export class StudioApiImpl implements StudioAPI {
throw new Error('Not found');
}

async getModelById(modelId: string): Promise<ModelInfo> {
const model = content.recipes.flatMap(r => (r.models as ModelInfo[]).filter(m => modelId === m.id));
if (model.length === 1) return model[0];
if (model.length === 0) throw new Error('Not found');
throw new Error('several models with same id');
}

async searchRecipes(query: string): Promise<Recipe[]> {
return []; // todo: not implemented
}
@@ -71,4 +82,21 @@ export class StudioApiImpl implements StudioAPI {
return content.recipes.flatMap(r => r.models.filter(m => localIds.includes(m.id)));
}

async startPlayground(modelId: string): Promise<void> {
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}
// resolve the on-disk location of the downloaded model file
const modelPath = path.resolve(this.applicationManager.homeDirectory, AI_STUDIO_FOLDER, 'models', modelId, localModelInfo[0].file);
await this.studio.playgroundManager.startPlayground(modelId, modelPath);
}

askPlayground(modelId: string, prompt: string): Promise<ModelResponse> {
const localModelInfo = this.applicationManager.getLocalModels().filter(m => m.id === modelId);
if (localModelInfo.length !== 1) {
throw new Error('model not found');
}
return this.studio.playgroundManager.askPlayground(localModelInfo[0], prompt);
}
}
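
The getModelById method added above flat-maps over every recipe's model list and assumes a model id is unique across recipes. A small self-contained illustration of that lookup, using made-up catalog data rather than the real ai.json, is sketched below:

interface ExampleModel { id: string; name: string }

// Hypothetical two-recipe catalog; ids and names are placeholders.
const catalog = {
  recipes: [
    { models: [{ id: 'model-a', name: 'Model A' }] },
    { models: [{ id: 'model-b', name: 'Model B' }] },
  ],
};

function getModelById(modelId: string): ExampleModel {
  const matches = catalog.recipes.flatMap(r => r.models.filter(m => m.id === modelId));
  if (matches.length === 1) return matches[0];
  if (matches.length === 0) throw new Error('Not found');
  throw new Error('several models with same id');
}

console.log(getModelById('model-b').name); // prints "Model B"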
4 changes: 4 additions & 0 deletions packages/backend/src/studio.ts
@@ -28,6 +28,7 @@ import * as fs from 'node:fs';
import * as https from 'node:https';
import * as path from 'node:path';
import type { LocalModelInfo } from '@shared/models/ILocalModelInfo';
import { PlayGroundManager } from './playground';

export class Studio {
readonly #extensionContext: ExtensionContext;
@@ -36,9 +37,11 @@ export class Studio {

rpcExtension: RpcExtension;
studioApi: StudioApiImpl;
playgroundManager: PlayGroundManager;

constructor(readonly extensionContext: ExtensionContext) {
this.#extensionContext = extensionContext;
this.playgroundManager = new PlayGroundManager();
}

public async activate(): Promise<void> {
@@ -98,6 +101,7 @@ export class Studio {
this.studioApi = new StudioApiImpl(
applicationManager,
recipeStatusRegistry,
this,
);
// Register the instance
this.rpcExtension.registerInstance<StudioApiImpl>(StudioApiImpl, this.studioApi);
5 changes: 5 additions & 0 deletions packages/frontend/src/App.svelte
@@ -11,6 +11,7 @@ import Preferences from '/@/pages/Preferences.svelte';
import Registries from '/@/pages/Registries.svelte';
import Models from '/@/pages/Models.svelte';
import Recipe from '/@/pages/Recipe.svelte';
import Model from './pages/Model.svelte';
router.mode.hash();
</script>
@@ -54,6 +55,10 @@ router.mode.hash();
<Route path="/recipes/:id/*" breadcrumb="Recipe Details" let:meta>
<Recipe recipeId="{meta.params.id}"/>
</Route>

<Route path="/models/:id/*" breadcrumb="Model Details" let:meta>
<Model modelId="{meta.params.id}"/>
</Route>
</div>
</main>
</Route>
37 changes: 37 additions & 0 deletions packages/frontend/src/pages/Model.svelte
@@ -0,0 +1,37 @@
<script lang="ts">
import NavPage from '/@/lib/NavPage.svelte';
import Tab from '/@/lib/Tab.svelte';
import Route from '/@/Route.svelte';
import MarkdownRenderer from '/@/lib/MarkdownRenderer.svelte';
import type { ModelInfo } from '@shared/models/IModelInfo';
import { studioClient } from '../utils/client';
import { onMount } from 'svelte';
import ModelPlayground from './ModelPlayground.svelte';
export let modelId: string;
let model: ModelInfo | undefined = undefined;
onMount(async () => {
model = await studioClient.getModelById(modelId);
})
</script>

<NavPage title="{model?.name || ''}">
<svelte:fragment slot="tabs">
<Tab title="Summary" url="{modelId}" />
<Tab title="Playground" url="{modelId}/playground" />
</svelte:fragment>
<svelte:fragment slot="content">
<Route path="/" breadcrumb="Summary" >
<div class="flex flex-row w-full">
<div class="flex-grow p-5">
<MarkdownRenderer source="{model?.description}"/>
</div>
</div>
</Route>
<Route path="/playground" breadcrumb="Playground">
<ModelPlayground model={model} />
</Route>

</svelte:fragment>
</NavPage>
9 changes: 7 additions & 2 deletions packages/frontend/src/pages/ModelColumnName.svelte
@@ -1,8 +1,13 @@
<script lang="ts">
import type { ModelInfo } from "@shared/models/IModelInfo";
import { router } from "tinro";
export let object: ModelInfo;
function openDetails() {
router.goto(`/models/${object.id}`);
}
</script>

<div class="text-sm text-gray-700">
<button class="text-sm text-gray-700" on:click="{() => openDetails()}">
{object.name}
</div>
</button>
55 changes: 55 additions & 0 deletions packages/frontend/src/pages/ModelPlayground.svelte
@@ -0,0 +1,55 @@
<script lang="ts">
import type { ModelInfo } from '@shared/models/IModelInfo';
import type { ModelResponseChoice } from '@shared/models/IModelResponse';
import Button from '../lib/button/Button.svelte';
import { onMount } from 'svelte';
import { studioClient } from '../utils/client';
export let model: ModelInfo | undefined;
let prompt = '';
let result: ModelResponseChoice | undefined = undefined;
let inProgress = false;
onMount(() => {
if (!model) {
return;
}
studioClient.startPlayground(model.id);
});
async function askPlayground() {
if (!model) {
return;
}
inProgress = true;
result = undefined;
const res = await studioClient.askPlayground(model.id, prompt)
inProgress = false;
if (res.choices.length) {
result = res.choices[0];
}
}
</script>

<div class="m-4 w-full flew flex-col">
<div class="mb-2">Prompt</div>
<textarea
bind:value={prompt}
rows="4"
class="w-full p-2 outline-none text-sm bg-charcoal-800 rounded-sm text-gray-700 placeholder-gray-700"
placeholder="Type your prompt here"></textarea>

<div class="mt-4 text-right">
<Button inProgress={inProgress} on:click={() => askPlayground()}>Send Request</Button>
</div>

{#if result}
<div class="mt-4 mb-2">Output</div>
<textarea
readonly
disabled
rows="20"
bind:value={result.text}
class="w-full p-2 outline-none text-sm bg-charcoal-800 rounded-sm text-gray-700 placeholder-gray-700"></textarea>
{/if}
</div>
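
The ModelResponse and ModelResponseChoice types imported here come from the new @shared/models/IModelResponse module, which is not expanded in this view. Inferred from how the fields are read above (res.choices, result.text), a minimal shape could be the sketch below; the real file may carry additional OpenAI-style completion fields:

// Sketch inferred from usage in this PR, not the actual file contents.
export interface ModelResponseChoice {
  text: string; // completion text shown in the Output textarea
}

export interface ModelResponse {
  choices: ModelResponseChoice[]; // askPlayground reads choices[0]
}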
2 changes: 1 addition & 1 deletion packages/shared/MessageProxy.ts
@@ -149,7 +149,7 @@ export class RpcBrowser {
return;
reject(new Error('Timeout'));
this.promises.delete(requestId);
- }, 10000);
+ }, 10000000);

// Create a Promise
return new Promise((resolve, reject) => {
8 changes: 6 additions & 2 deletions packages/shared/StudioAPI.ts
@@ -1,15 +1,16 @@
import type { Recipe } from '@shared/models/IRecipe';
import type { Category } from '@shared/models/ICategory';
import { RecipeStatus } from '@shared/models/IRecipeStatus';
import { Task } from '@shared/models/ITask';
- import { ModelInfo } from './models/IModelInfo';
+ import { ModelInfo } from '@shared/models/IModelInfo';
import { ModelResponse } from '@shared/models/IModelResponse';

export abstract class StudioAPI {
abstract ping(): Promise<string>;
abstract getRecentRecipes(): Promise<Recipe[]>;
abstract getCategories(): Promise<Category[]>;
abstract getRecipesByCategory(categoryId: string): Promise<Recipe[]>;
abstract getRecipeById(recipeId: string): Promise<Recipe>;
abstract getModelById(modelId: string): Promise<ModelInfo>;
abstract searchRecipes(query: string): Promise<Recipe[]>;
abstract getPullingStatus(recipeId: string): Promise<RecipeStatus>
abstract pullApplication(recipeId: string): Promise<void>;
@@ -18,5 +19,8 @@
* Get the information of models saved locally into the extension's storage directory
*/
abstract getLocalModels(): Promise<ModelInfo[]>;

abstract startPlayground(modelId: string): Promise<void>;
abstract askPlayground(modelId: string, prompt: string): Promise<ModelResponse>;
}

1 change: 1 addition & 0 deletions packages/shared/models/IModelInfo.ts
@@ -1,6 +1,7 @@
export interface ModelInfo {
id: string;
name: string;
description: string;
hw:string;
registry: string;
popularity: number;
