From 4f41dab9b58b93e1034b19d36436ae1086cfc002 Mon Sep 17 00:00:00 2001 From: Faisal Amir Date: Thu, 14 Dec 2023 16:33:42 +0700 Subject: [PATCH] feat: move stop inference button into the send button --- .../inference-nitro-extension/src/index.ts | 5 ++ .../inference-openai-extension/src/index.ts | 5 ++ web/screens/Chat/MessageToolbar/index.tsx | 48 +++++-------------- web/screens/Chat/index.tsx | 36 ++++++++++---- 4 files changed, 51 insertions(+), 43 deletions(-) diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index f2fbf0d345..4bcfe05b08 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -227,6 +227,11 @@ export default class JanInferenceNitroExtension implements InferenceExtension { events.emit(EventName.OnMessageUpdate, message); }, error: async (err) => { + if (instance.isCancelled) { + message.status = MessageStatus.Ready; + events.emit(EventName.OnMessageUpdate, message); + return; + } const messageContent: ThreadContent = { type: ContentType.Text, text: { diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts index c719e405fa..42c9c37980 100644 --- a/extensions/inference-openai-extension/src/index.ts +++ b/extensions/inference-openai-extension/src/index.ts @@ -217,6 +217,11 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { events.emit(EventName.OnMessageUpdate, message); }, error: async (err) => { + if (instance.isCancelled) { + message.status = MessageStatus.Ready; + events.emit(EventName.OnMessageUpdate, message); + return; + } const messageContent: ThreadContent = { type: ContentType.Text, text: { diff --git a/web/screens/Chat/MessageToolbar/index.tsx b/web/screens/Chat/MessageToolbar/index.tsx index 7f8e5ca7eb..c6214aef37 100644 --- a/web/screens/Chat/MessageToolbar/index.tsx +++ b/web/screens/Chat/MessageToolbar/index.tsx @@ -1,14 +1,12 @@ import { - EventName, MessageStatus, ExtensionType, ThreadMessage, - events, ChatCompletionRole, } from '@janhq/core' -import { ConversationalExtension, InferenceExtension } from '@janhq/core' +import { ConversationalExtension } from '@janhq/core' import { useAtomValue, useSetAtom } from 'jotai' -import { RefreshCcw, Copy, Trash2Icon, StopCircle } from 'lucide-react' +import { RefreshCcw, Copy, Trash2Icon } from 'lucide-react' import { twMerge } from 'tailwind-merge' @@ -29,17 +27,6 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => { const messages = useAtomValue(getCurrentChatMessagesAtom) const { resendChatMessage } = useSendChatMessage() - const onStopInferenceClick = async () => { - events.emit(EventName.OnInferenceStopped, {}) - - setTimeout(() => { - events.emit(EventName.OnMessageUpdate, { - ...message, - status: MessageStatus.Ready, - }) - }, 300) - } - const onDeleteClick = async () => { deleteMessage(message.id ?? '') if (thread) { @@ -60,26 +47,19 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => { resendChatMessage(message) } + if (message.status !== MessageStatus.Ready) return null + return (
- {message.status === MessageStatus.Pending && ( + {message.id === messages[messages.length - 1]?.id && (
- +
)} - {message.status !== MessageStatus.Pending && - message.id === messages[messages.length - 1]?.id && ( -
- -
- )}
{ @@ -91,14 +71,12 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => { >
- {message.status === MessageStatus.Ready && ( -
- -
- )} +
+ +
) diff --git a/web/screens/Chat/index.tsx b/web/screens/Chat/index.tsx index 88a97db71d..9f6dba38e6 100644 --- a/web/screens/Chat/index.tsx +++ b/web/screens/Chat/index.tsx @@ -1,9 +1,11 @@ import { ChangeEvent, Fragment, KeyboardEvent, useEffect, useRef } from 'react' +import { EventName, MessageStatus, events } from '@janhq/core' import { Button, Textarea } from '@janhq/uikit' import { useAtom, useAtomValue } from 'jotai' +import { StopCircle } from 'lucide-react' import { twMerge } from 'tailwind-merge' import LogoMark from '@/containers/Brand/Logo/Mark' @@ -26,6 +28,7 @@ import ThreadList from '@/screens/Chat/ThreadList' import Sidebar, { showRightSideBarAtom } from './Sidebar' +import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' import { activeThreadAtom, getActiveThreadIdAtom, @@ -40,6 +43,7 @@ const ChatScreen = () => { const { activeModel, stateModel } = useActiveModel() const { setMainViewState } = useMainViewState() + const messages = useAtomValue(getCurrentChatMessagesAtom) const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom) const activeThreadState = useAtomValue(activeThreadStateAtom) @@ -94,6 +98,10 @@ const ChatScreen = () => { } } + const onStopInferenceClick = async () => { + events.emit(EventName.OnInferenceStopped, {}) + } + return (
@@ -159,14 +167,26 @@ const ChatScreen = () => { onPromptChange(e) } /> - + {messages[messages.length - 1]?.status !== MessageStatus.Pending ? ( + + ) : ( + + )}