From 61dd1d867efd1be539c5a2287cbe0bac45d293d4 Mon Sep 17 00:00:00 2001
From: titaniumcloudwalk
Date: Thu, 7 Nov 2024 15:07:30 -0300
Subject: [PATCH 1/2] Add LM Studio support.

---
 src/complete/completers.sass                  |   1 +
 src/complete/completers.ts                    |   2 +
 .../completers/lmstudio/lmstudio.sass         |   4 +
 src/complete/completers/lmstudio/lmstudio.tsx | 191 ++++++++++++++++++
 .../completers/lmstudio/model_settings.tsx    | 126 ++++++++++++
 .../completers/lmstudio/provider_settings.tsx |  48 +++++
 6 files changed, 372 insertions(+)
 create mode 100644 src/complete/completers/lmstudio/lmstudio.sass
 create mode 100644 src/complete/completers/lmstudio/lmstudio.tsx
 create mode 100644 src/complete/completers/lmstudio/model_settings.tsx
 create mode 100644 src/complete/completers/lmstudio/provider_settings.tsx

diff --git a/src/complete/completers.sass b/src/complete/completers.sass
index a1c9d7c..b8151ec 100644
--- a/src/complete/completers.sass
+++ b/src/complete/completers.sass
@@ -1,4 +1,5 @@
 @import "completers/chatgpt/chatgpt.sass"
 @import "completers/ai21/ai21.sass"
 @import "completers/ollama/ollama.sass"
+@import "completers/lmstudio/lmstudio.sass"
 @import "completers/groq/groq.sass"
diff --git a/src/complete/completers.ts b/src/complete/completers.ts
index 9a23e06..a5f0951 100644
--- a/src/complete/completers.ts
+++ b/src/complete/completers.ts
@@ -6,6 +6,7 @@ import { GooseAIComplete } from "./completers/gooseai/gooseai";
 import { OobaboogaComplete } from "./completers/oobabooga/oobabooga";
 import { OllamaComplete } from "./completers/ollama/ollama";
 import { GroqComplete } from "./completers/groq/groq";
+import { LMStudioComplete } from "./completers/lmstudio/lmstudio";
 
 export const available: Completer[] = [
 	new ChatGPTComplete(),
@@ -15,4 +16,5 @@ export const available: Completer[] = [
 	new OobaboogaComplete(),
 	new OllamaComplete(),
 	new GroqComplete(),
+	new LMStudioComplete(),
 ];
diff --git 
a/src/complete/completers/lmstudio/lmstudio.sass b/src/complete/completers/lmstudio/lmstudio.sass
new file mode 100644
index 0000000..9c5f8d7
--- /dev/null
+++ b/src/complete/completers/lmstudio/lmstudio.sass
@@ -0,0 +1,4 @@
+.ai-complete-lmstudio-full-width
+    width: 100%
+    min-height: 120px
+    resize: none
diff --git a/src/complete/completers/lmstudio/lmstudio.tsx b/src/complete/completers/lmstudio/lmstudio.tsx
new file mode 100644
index 0000000..d385547
--- /dev/null
+++ b/src/complete/completers/lmstudio/lmstudio.tsx
@@ -0,0 +1,191 @@
+import { Completer, Model, Prompt } from "../../complete";
+import {
+	SettingsUI as ProviderSettingsUI,
+	Settings as ProviderSettings,
+	parse_settings as parse_provider_settings,
+} from "./provider_settings";
+import {
+	SettingsUI as ModelSettingsUI,
+	parse_settings as parse_model_settings,
+	Settings as ModelSettings,
+} from "./model_settings";
+import OpenAI from "openai";
+import { Notice } from "obsidian";
+import Mustache from "mustache";
+
+export default class LMStudioModel implements Model {
+	id: string;
+	name: string;
+	description: string;
+	rate_limit_notice: Notice | null = null;
+	rate_limit_notice_timeout: number | null = null;
+	Settings = ModelSettingsUI;
+
+	provider_settings: ProviderSettings;
+
+	constructor(
+		provider_settings: string,
+		id: string,
+		name: string,
+		description: string
+	) {
+		this.id = id;
+		this.name = name;
+		this.description = description;
+		this.provider_settings = parse_provider_settings(provider_settings);
+	}
+
+	get_api() {
+		return new OpenAI({
+			baseURL: this.provider_settings.endpoint.replace(/\/+$/, "") + "/v1", // tolerate trailing slash in configured endpoint
+			apiKey: "lm-studio", // LM Studio doesn't require a real API key
+			dangerouslyAllowBrowser: true,
+		});
+	}
+
+	async prepare(
+		prompt: Prompt,
+		settings: ModelSettings
+	): Promise<{
+		prefix: string;
+		suffix: string;
+		last_line: string;
+		context: string;
+	}> {
+		const cropped = {
+			prefix: prompt.prefix.slice(-(settings.prompt_length || 6000)),
+			suffix: 
prompt.suffix.slice(0, settings.prompt_length || 6000),
+		};
+		const last_line = cropped.prefix
+			.split("\n")
+			.filter((x) => x.length > 0)
+			.pop();
+		return {
+			...cropped,
+			last_line: last_line || "",
+			context: cropped.prefix
+				.split("\n")
+				.filter((x) => x !== last_line)
+				.join("\n"),
+		};
+	}
+
+	async complete(prompt: Prompt, settings: string): Promise<string> {
+		const model_settings = parse_model_settings(settings);
+
+		try {
+			const response = await this.get_api().chat.completions.create({
+				model: this.id,
+				messages: [
+					{
+						role: "system",
+						content: model_settings.system_prompt,
+					},
+					{
+						role: "user",
+						content: Mustache.render(
+							model_settings.user_prompt,
+							await this.prepare(prompt, model_settings)
+						),
+					},
+				],
+				temperature: model_settings.temperature,
+				max_tokens: model_settings.max_tokens,
+			});
+
+			return this.interpret(
+				prompt,
+				response.choices[0]?.message?.content || ""
+			);
+		} catch (e) {
+			throw new Error(`LM Studio API error: ${e instanceof Error ? e.message : String(e)}`);
+		}
+	}
+
+	async *iterate(prompt: Prompt, settings: string): AsyncGenerator<string> {
+		const model_settings = parse_model_settings(settings);
+
+		try {
+			const completion = await this.get_api().chat.completions.create({
+				model: this.id,
+				messages: [
+					{
+						role: "system",
+						content: model_settings.system_prompt,
+					},
+					{
+						role: "user",
+						content: Mustache.render(
+							model_settings.user_prompt,
+							await this.prepare(prompt, model_settings)
+						),
+					},
+				],
+				temperature: model_settings.temperature,
+				max_tokens: model_settings.max_tokens,
+				stream: true,
+			});
+
+			let initialized = false; // only the first token goes through interpret() for spacing
+			for await (const chunk of completion) {
+				const token = chunk.choices[0]?.delta?.content || "";
+				if (!initialized) {
+					yield this.interpret(prompt, token);
+					initialized = true;
+				} else {
+					yield token;
+				}
+			}
+		} catch (e) {
+			throw new Error(`LM Studio API error: ${e instanceof Error ? e.message : String(e)}`);
+		}
+	}
+
+	interpret(prompt: Prompt, completion: string) {
+		const response_punctuation = " \n.,?!:;";
+		const prompt_punctuation = " 
\n"; + + if ( + prompt.prefix.length !== 0 && + !prompt_punctuation.includes( + prompt.prefix[prompt.prefix.length - 1] + ) && + !response_punctuation.includes(completion[0]) + ) { + completion = " " + completion; + } + + return completion; + } +} + +export class LMStudioComplete implements Completer { + id: string = "lmstudio"; + name: string = "LM Studio"; + description: string = "Local LM Studio server for running local models"; + + async get_models(settings: string) { + const provider_settings = parse_provider_settings(settings); + const api = new OpenAI({ + baseURL: provider_settings.endpoint + "/v1", + apiKey: "lm-studio", + dangerouslyAllowBrowser: true, + }); + + try { + const models = await api.models.list(); + return models.data.map((model: any) => { + return new LMStudioModel( + settings, + model.id, + model.id, + `LM Studio model: ${model.id}` + ); + }); + } catch (e) { + throw new Error(`Failed to fetch LM Studio models: ${e.message}`); + } + } + + Settings = ProviderSettingsUI; +} diff --git a/src/complete/completers/lmstudio/model_settings.tsx b/src/complete/completers/lmstudio/model_settings.tsx new file mode 100644 index 0000000..b9bdaa3 --- /dev/null +++ b/src/complete/completers/lmstudio/model_settings.tsx @@ -0,0 +1,126 @@ +import * as React from "react"; +import SettingsItem from "../../../components/SettingsItem"; +import { z } from "zod"; + +export const settings_schema = z.object({ + system_prompt: z.string(), + user_prompt: z.string(), + temperature: z.number().optional(), + max_tokens: z.number().optional(), + prompt_length: z.number().optional(), +}); + +export type Settings = z.infer; + +const default_settings: Settings = { + system_prompt: "", + user_prompt: '{{#context}}Context:\n\n{{context}}\n\n=================================\n{{/context}}Do not start with "...". 
Continue the following paragraph:\n\n{{last_line}}',
+	max_tokens: 100,
+};
+
+export const parse_settings = (data: string | null): Settings => {
+	if (data == null) {
+		return default_settings;
+	}
+	try {
+		const settings: unknown = JSON.parse(data);
+		return settings_schema.parse(settings);
+	} catch {
+		return default_settings;
+	}
+};
+
+export function SettingsUI({
+	settings,
+	saveSettings,
+}: {
+	settings: string | null;
+	saveSettings: (settings: string) => void;
+}) {
+	const parsed_settings = parse_settings(settings);
+
+	return (
+		<>
+