From 5a6d88432f705d38098e7199e091271c02f4f555 Mon Sep 17 00:00:00 2001 From: James Date: Thu, 19 Oct 2023 17:11:19 -0700 Subject: [PATCH] update the temperature progress bar Signed-off-by: James --- plugins/data-plugin/index.ts | 200 +++++++++++++++--- plugins/data-plugin/package.json | 2 +- plugins/inference-plugin/index.ts | 39 ++-- web/app/_components/BotInfo/index.tsx | 18 +- web/app/_components/BotSetting/index.tsx | 108 +++++++++- .../_components/CreateBotContainer/index.tsx | 21 +- .../_components/CreateBotInAdvance/index.tsx | 15 +- .../CustomBotTemperature/index.tsx | 2 +- .../DraggableProgressBar/index.tsx | 13 +- web/app/_components/RightContainer/index.tsx | 2 +- web/app/_hooks/useUpdateBot.ts | 3 + web/app/_models/Bot.ts | 6 + 12 files changed, 356 insertions(+), 73 deletions(-) diff --git a/plugins/data-plugin/index.ts b/plugins/data-plugin/index.ts index d4fb29e42c..4dbd95d35c 100644 --- a/plugins/data-plugin/index.ts +++ b/plugins/data-plugin/index.ts @@ -1,4 +1,11 @@ -import { core, store, RegisterExtensionPoint, StoreService, DataService, PluginService } from "@janhq/core"; +import { + core, + store, + RegisterExtensionPoint, + StoreService, + DataService, + PluginService, +} from "@janhq/core"; /** * Create a collection on data store @@ -8,7 +15,13 @@ import { core, store, RegisterExtensionPoint, StoreService, DataService, PluginS * @returns Promise * */ -function createCollection({ name, schema }: { name: string; schema?: { [key: string]: any } }): Promise { +function createCollection({ + name, + schema, +}: { + name: string; + schema?: { [key: string]: any }; +}): Promise { return core.invokePluginFunc(MODULE_PATH, "createCollection", name, schema); } @@ -31,7 +44,13 @@ function deleteCollection(name: string): Promise { * @returns Promise * */ -function insertOne({ collectionName, value }: { collectionName: string; value: any }): Promise { +function insertOne({ + collectionName, + value, +}: { + collectionName: string; + value: any; +}): Promise { return core.invokePluginFunc(MODULE_PATH, "insertOne", collectionName, value); } @@ -44,8 +63,22 @@ function insertOne({ collectionName, value }: { collectionName: string; value: a * @returns Promise * */ -function updateOne({ collectionName, key, value }: { collectionName: string; key: string; value: any }): Promise { - return core.invokePluginFunc(MODULE_PATH, "updateOne", collectionName, key, value); +function updateOne({ + collectionName, + key, + value, +}: { + collectionName: string; + key: string; + value: any; +}): Promise { + return core.invokePluginFunc( + MODULE_PATH, + "updateOne", + collectionName, + key, + value + ); } /** @@ -64,7 +97,13 @@ function updateMany({ value: any; selector?: { [key: string]: any }; }): Promise { - return core.invokePluginFunc(MODULE_PATH, "updateMany", collectionName, value, selector); + return core.invokePluginFunc( + MODULE_PATH, + "updateMany", + collectionName, + value, + selector + ); } /** @@ -75,7 +114,13 @@ function updateMany({ * @returns Promise * */ -function deleteOne({ collectionName, key }: { collectionName: string; key: string }): Promise { +function deleteOne({ + collectionName, + key, +}: { + collectionName: string; + key: string; +}): Promise { return core.invokePluginFunc(MODULE_PATH, "deleteOne", collectionName, key); } @@ -94,7 +139,12 @@ function deleteMany({ collectionName: string; selector?: { [key: string]: any }; }): Promise { - return core.invokePluginFunc(MODULE_PATH, "deleteMany", collectionName, selector); + return core.invokePluginFunc( + MODULE_PATH, + "deleteMany", + collectionName, + selector + ); } /** @@ -103,7 +153,13 @@ function deleteMany({ * @param {string} key - The key of the record to retrieve. * @returns {Promise} A promise that resolves when the record is retrieved. */ -function findOne({ collectionName, key }: { collectionName: string; key: string }): Promise { +function findOne({ + collectionName, + key, +}: { + collectionName: string; + key: string; +}): Promise { return core.invokePluginFunc(MODULE_PATH, "findOne", collectionName, key); } @@ -123,7 +179,13 @@ function findMany({ selector: { [key: string]: any }; sort?: [{ [key: string]: any }]; }): Promise { - return core.invokePluginFunc(MODULE_PATH, "findMany", collectionName, selector, sort); + return core.invokePluginFunc( + MODULE_PATH, + "findMany", + collectionName, + selector, + sort + ); } function onStart() { @@ -135,8 +197,16 @@ function onStart() { // Register all the above functions and objects with the relevant extension points export function init({ register }: { register: RegisterExtensionPoint }) { register(PluginService.OnStart, PLUGIN_NAME, onStart); - register(StoreService.CreateCollection, createCollection.name, createCollection); - register(StoreService.DeleteCollection, deleteCollection.name, deleteCollection); + register( + StoreService.CreateCollection, + createCollection.name, + createCollection + ); + register( + StoreService.DeleteCollection, + deleteCollection.name, + deleteCollection + ); register(StoreService.InsertOne, insertOne.name, insertOne); register(StoreService.UpdateOne, updateOne.name, updateOne); register(StoreService.UpdateMany, updateMany.name, updateMany); @@ -145,15 +215,39 @@ export function init({ register }: { register: RegisterExtensionPoint }) { register(StoreService.FindOne, findOne.name, findOne); register(StoreService.FindMany, findMany.name, findMany); - register(DataService.GetConversations, getConversations.name, getConversations); - register(DataService.CreateConversation, createConversation.name, createConversation); - register(DataService.UpdateConversation, updateConversation.name, updateConversation); + register( + DataService.GetConversations, + getConversations.name, + getConversations + ); + register( + DataService.CreateConversation, + createConversation.name, + createConversation + ); + register( + DataService.UpdateConversation, + updateConversation.name, + updateConversation + ); register(DataService.UpdateMessage, updateMessage.name, updateMessage); - register(DataService.DeleteConversation, deleteConversation.name, deleteConversation); + register( + DataService.DeleteConversation, + deleteConversation.name, + deleteConversation + ); register(DataService.CreateMessage, createMessage.name, createMessage); - register(DataService.GetConversationMessages, getConversationMessages.name, getConversationMessages); + register( + DataService.GetConversationMessages, + getConversationMessages.name, + getConversationMessages + ); - register("getConversationById", getConversationById.name, getConversationById); + register( + "getConversationById", + getConversationById.name, + getConversationById + ); register("createBot", createBot.name, createBot); register("getBots", getBots.name, getBots); register("getBotById", getBotById.name, getBotById); @@ -186,29 +280,83 @@ function updateMessage(message: any): Promise { } function deleteConversation(id: any) { - return store.deleteOne("conversations", id).then(() => store.deleteMany("messages", { conversationId: id })); + return store + .deleteOne("conversations", id) + .then(() => store.deleteMany("messages", { conversationId: id })); } function getConversationMessages(conversationId: any) { - return store.findMany("messages", { conversationId }, [{ createdAt: "desc" }]); + return store.findMany("messages", { conversationId }, [ + { createdAt: "desc" }, + ]); } function createBot(bot: any): Promise { - return store.insertOne("bots", bot); + console.debug("Creating bot", JSON.stringify(bot, null, 2)); + return store + .insertOne("bots", bot) + .then(() => { + console.debug("Bot created", JSON.stringify(bot, null, 2)); + return Promise.resolve(); + }) + .catch((err) => { + console.error("Error creating bot", err); + return Promise.reject(err); + }); } function getBots(): Promise { - return store.findMany("bots", {}); + console.debug("Getting bots"); + return store + .findMany("bots", {name: { $gt: null }}) + .then((bots) => { + console.debug("Bots retrieved", JSON.stringify(bots, null, 2)); + return Promise.resolve(bots); + }) + .catch((err) => { + console.error("Error getting bots", err); + return Promise.reject(err); + }); } -function deleteBot(id: any): Promise { - return store.deleteOne("bots", id); +function deleteBot(id: string): Promise { + console.debug("Deleting bot", id); + return store + .deleteOne("bots", id) + .then(() => { + console.debug("Bot deleted", id); + return Promise.resolve(); + }) + .catch((err) => { + console.error("Error deleting bot", err); + return Promise.reject(err); + }); } function updateBot(bot: any): Promise { - return store.updateOne("bots", bot._id, bot); + console.debug("Updating bot", JSON.stringify(bot, null, 2)); + return store + .updateOne("bots", bot._id, bot) + .then(() => { + console.debug("Bot updated"); + return Promise.resolve(); + }) + .catch((err) => { + console.error("Error updating bot", err); + return Promise.reject(err); + }); } function getBotById(botId: string): Promise { - return store.findOne("bots", botId); + console.debug("Getting bot", botId); + return store + .findOne("bots", botId) + .then((bot) => { + console.debug("Bot retrieved", JSON.stringify(bot, null, 2)); + return Promise.resolve(bot); + }) + .catch((err) => { + console.error("Error getting bot", err); + return Promise.reject(err); + }); } diff --git a/plugins/data-plugin/package.json b/plugins/data-plugin/package.json index 5957d9aadd..07e5abdbeb 100644 --- a/plugins/data-plugin/package.json +++ b/plugins/data-plugin/package.json @@ -40,7 +40,7 @@ "node_modules" ], "dependencies": { - "@janhq/core": "file:../../core", // revert back this line to the original one + "@janhq/core": "file:../../core", "pouchdb-find": "^8.0.1", "pouchdb-node": "^8.0.1" }, diff --git a/plugins/inference-plugin/index.ts b/plugins/inference-plugin/index.ts index 92a7ffd063..d8f73a2181 100644 --- a/plugins/inference-plugin/index.ts +++ b/plugins/inference-plugin/index.ts @@ -15,8 +15,18 @@ const stopModel = () => { invokePluginFunc(MODULE_PATH, "killSubprocess"); }; -function requestInference(recentMessages: any[]): Observable { +function requestInference(recentMessages: any[], bot?: any): Observable { return new Observable((subscriber) => { + const requestBody = JSON.stringify({ + messages: recentMessages, + stream: true, + model: "gpt-3.5-turbo", + max_tokens: bot?.maxTokens ?? 2048, + frequency_penalty: bot?.frequencyPenalty ?? 0, + presence_penalty: bot?.presencePenalty ?? 0, + temperature: bot?.customTemperature ?? 0, + }); + console.debug(`Request body: ${requestBody}`); fetch(INFERENCE_URL, { method: "POST", headers: { @@ -24,12 +34,7 @@ function requestInference(recentMessages: any[]): Observable { Accept: "text/event-stream", "Access-Control-Allow-Origin": "*", }, - body: JSON.stringify({ - messages: recentMessages, - stream: true, - model: "gpt-3.5-turbo", - max_tokens: 500, - }), + body: requestBody, }) .then(async (response) => { const stream = response.body; @@ -62,14 +67,8 @@ function requestInference(recentMessages: any[]): Observable { }); } -async function retrieveLastTenMessages(conversationId: string) { +async function retrieveLastTenMessages(conversationId: string, bot?: any) { // TODO: Common collections should be able to access via core functions instead of store - const conversation = await store.findOne("conversations", conversationId); - let bot = undefined; - if (conversation.botId != null) { - bot = await store.findOne("bots", conversation.botId); - } - const messageHistory = (await store.findMany("messages", { conversationId }, [{ createdAt: "asc" }])) ?? []; let recentMessages = messageHistory @@ -88,13 +87,19 @@ async function retrieveLastTenMessages(conversationId: string) { },...recentMessages]; } - console.debug(`Sending: ${JSON.stringify(recentMessages)}`); + console.debug(`Last 10 messages: ${JSON.stringify(recentMessages, null, 2)}`); return recentMessages; } async function handleMessageRequest(data: NewMessageRequest) { - const recentMessages = await retrieveLastTenMessages(data.conversationId); + const conversation = await store.findOne("conversations", data.conversationId); + let bot = undefined; + if (conversation.botId != null) { + bot = await store.findOne("bots", conversation.botId); + } + + const recentMessages = await retrieveLastTenMessages(data.conversationId, bot); const message = { ...data, message: "", @@ -108,7 +113,7 @@ async function handleMessageRequest(data: NewMessageRequest) { message._id = id; events.emit(EventName.OnNewMessageResponse, message); - requestInference(recentMessages).subscribe({ + requestInference(recentMessages, bot).subscribe({ next: (content) => { message.message = content; events.emit(EventName.OnMessageResponseUpdate, message); diff --git a/web/app/_components/BotInfo/index.tsx b/web/app/_components/BotInfo/index.tsx index 78c6d847da..9085536c09 100644 --- a/web/app/_components/BotInfo/index.tsx +++ b/web/app/_components/BotInfo/index.tsx @@ -7,6 +7,8 @@ import useCreateConversation from "@/_hooks/useCreateConversation"; import useDeleteBot from "@/_hooks/useDeleteBot"; import { useAtomValue, useSetAtom } from "jotai"; import React from "react"; +import PrimaryButton from "../PrimaryButton"; +import ExpandableHeader from "../ExpandableHeader"; const BotInfo: React.FC = () => { const { deleteBot } = useDeleteBot(); @@ -33,20 +35,20 @@ const BotInfo: React.FC = () => { }; return ( -
- {/* Header */} -
Bot Info
+
+ {}} /> - {/* Body */}
- + {botInfo.description}
-
- Delete bot -
+
); }; diff --git a/web/app/_components/BotSetting/index.tsx b/web/app/_components/BotSetting/index.tsx index cbb1a80b10..5f39e6a1c1 100644 --- a/web/app/_components/BotSetting/index.tsx +++ b/web/app/_components/BotSetting/index.tsx @@ -1,6 +1,6 @@ import { activeBotAtom } from "@/_helpers/atoms/Bot.atom"; import { useAtomValue } from "jotai"; -import React from "react"; +import React, { useState } from "react"; import ExpandableHeader from "../ExpandableHeader"; import { useDebouncedCallback } from "use-debounce"; import useUpdateBot from "@/_hooks/useUpdateBot"; @@ -10,6 +10,18 @@ const delayBeforeUpdateInMs = 1000; const BotSetting: React.FC = () => { const activeBot = useAtomValue(activeBotAtom); + const [temperature, setTemperature] = useState( + activeBot?.customTemperature ?? 0 + ); + + const [maxTokens, setMaxTokens] = useState(activeBot?.maxTokens ?? 0); + const [frequencyPenalty, setFrequencyPenalty] = useState( + activeBot?.frequencyPenalty ?? 0 + ); + const [presencePenalty, setPresencePenalty] = useState( + activeBot?.presencePenalty ?? 0 + ); + const { updateBot } = useUpdateBot(); const debouncedTemperature = useDebouncedCallback((value) => { @@ -18,6 +30,24 @@ const BotSetting: React.FC = () => { updateBot(activeBot, { customTemperature: value }); }, delayBeforeUpdateInMs); + const debouncedMaxToken = useDebouncedCallback((value) => { + if (!activeBot) return; + if (activeBot.maxTokens === value) return; + updateBot(activeBot, { maxTokens: value }); + }, delayBeforeUpdateInMs); + + const debouncedFreqPenalty = useDebouncedCallback((value) => { + if (!activeBot) return; + if (activeBot.frequencyPenalty === value) return; + updateBot(activeBot, { frequencyPenalty: value }); + }, delayBeforeUpdateInMs); + + const debouncedPresencePenalty = useDebouncedCallback((value) => { + if (!activeBot) return; + if (activeBot.presencePenalty === value) return; + updateBot(activeBot, { presencePenalty: value }); + }, delayBeforeUpdateInMs); + const debouncedSystemPrompt = useDebouncedCallback((value) => { if (!activeBot) return; if (activeBot.systemPrompt === value) return; @@ -55,21 +85,89 @@ const BotSetting: React.FC = () => {
+ {/* TODO: clean up this code */} + {/* Max temp */} +

Max tokens

+
+ { + const value = Number(e.target.value); + setMaxTokens(value); + debouncedMaxToken(value); + }} + /> + + {formatTwoDigits(maxTokens)} + +
+ +

Frequency penalty

+
+ { + const value = Number(e.target.value); + setFrequencyPenalty(value); + debouncedFreqPenalty(value); + }} + /> + + {formatTwoDigits(frequencyPenalty)} + +
+ +

Presence penalty

+
+ { + const value = Number(e.target.value); + setPresencePenalty(value); + debouncedPresencePenalty(value); + }} + /> + + {formatTwoDigits(presencePenalty)} + +
+ {/* Custom temp */} +

Temperature

debouncedTemperature(e.target.value)} + onChange={(e) => { + const newTemp = Number(e.target.value); + setTemperature(newTemp); + debouncedTemperature(Number(e.target.value)); + }} /> - {/* - {formatTwoDigits(value)} - */} + + {formatTwoDigits(temperature)} +
diff --git a/web/app/_components/CreateBotContainer/index.tsx b/web/app/_components/CreateBotContainer/index.tsx index 82e7d69788..0f9c502013 100644 --- a/web/app/_components/CreateBotContainer/index.tsx +++ b/web/app/_components/CreateBotContainer/index.tsx @@ -11,7 +11,6 @@ import useCreateBot from "@/_hooks/useCreateBot"; import { Bot } from "@/_models/Bot"; import { SubmitHandler, useForm } from "react-hook-form"; import Avatar from "../Avatar"; -import SecondaryButton from "../SecondaryButton"; import { v4 as uuidv4 } from "uuid"; const CreateBotContainer: React.FC = () => { @@ -32,6 +31,9 @@ const CreateBotContainer: React.FC = () => { renderMarkdownContent: true, customTemperature: 0.7, enableCustomTemperature: false, + maxTokens: 2048, + frequencyPenalty: 0, + presencePenalty: 0, }, mode: "onChange", }); @@ -42,7 +44,15 @@ const CreateBotContainer: React.FC = () => { alert("Please select a model"); return; } - createBot({ ...data, name: data._id }); + const bot: Bot = { + ...data, + customTemperature: Number(data.customTemperature), + maxTokens: Number(data.maxTokens), + frequencyPenalty: Number(data.frequencyPenalty), + presencePenalty: Number(data.presencePenalty), + name: data._id, + }; + createBot(bot); }; const models = downloadedModels.map((model) => { @@ -57,7 +67,6 @@ const CreateBotContainer: React.FC = () => {
Create Bot
-
@@ -95,13 +104,13 @@ const CreateBotContainer: React.FC = () => { required /> - + /> */}
- {showAdvanced && ( + +

Max tokens

+ +

Custom temperature

+ +

Frequency penalty

+ +

Presence penalty

+ + + {/* {showAdvanced && ( = ({ control }) => { /> - )} + )} */} ); }; diff --git a/web/app/_components/CustomBotTemperature/index.tsx b/web/app/_components/CustomBotTemperature/index.tsx index 5f5fb22c7e..987419b3e2 100644 --- a/web/app/_components/CustomBotTemperature/index.tsx +++ b/web/app/_components/CustomBotTemperature/index.tsx @@ -25,7 +25,7 @@ const CutomBotTemperature: React.FC = ({ control }) => ( render={({ field: { value } }) => { if (!value) return
; return ( - + ); }} /> diff --git a/web/app/_components/DraggableProgressBar/index.tsx b/web/app/_components/DraggableProgressBar/index.tsx index 76150728ab..1c9edf191d 100644 --- a/web/app/_components/DraggableProgressBar/index.tsx +++ b/web/app/_components/DraggableProgressBar/index.tsx @@ -5,9 +5,12 @@ import { Controller, useController } from "react-hook-form"; type Props = { id: string; control: any; + min: number; + max: number; + step: number; }; -const DraggableProgressBar: React.FC = ({ id, control }) => { +const DraggableProgressBar: React.FC = ({ id, control, min, max, step }) => { const { field } = useController({ name: id, control: control, @@ -19,11 +22,9 @@ const DraggableProgressBar: React.FC = ({ id, control }) => { {...field} className="flex-1" type="range" - id="volume" - name="volume" - min="0" - max="1" - step="0.01" + min={min} + max={max} + step={step} /> { initial={false} animate={isVisible ? "show" : "hide"} variants={variants} - className="flex flex-col w-80 flex-shrink-0 py-3 border-l border-gray-200" + className="flex flex-col w-80 flex-shrink-0 py-3 border-l border-gray-200 overflow-y-auto scroll" > {isVisible && ( diff --git a/web/app/_hooks/useUpdateBot.ts b/web/app/_hooks/useUpdateBot.ts index 505f135dff..b4c541d3f0 100644 --- a/web/app/_hooks/useUpdateBot.ts +++ b/web/app/_hooks/useUpdateBot.ts @@ -29,6 +29,9 @@ export default function useUpdateBot() { } export type UpdatableField = { + presencePenalty?: number; + frequencyPenalty?: number; + maxTokens?: number; customTemperature?: number; systemPrompt?: number; }; diff --git a/web/app/_models/Bot.ts b/web/app/_models/Bot.ts index 31acc5f1c3..6d5f46cdcf 100644 --- a/web/app/_models/Bot.ts +++ b/web/app/_models/Bot.ts @@ -20,6 +20,12 @@ export type Bot = { */ customTemperature: number; + maxTokens: number; + + frequencyPenalty: number; + + presencePenalty: number; + modelId: string; createdAt?: number; updatedAt?: number;