Skip to content

Commit

Permalink
feat: chat operation for all sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
Echo-minn committed Apr 9, 2024
1 parent 0062d97 commit 4f84bd6
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 33 deletions.
3 changes: 3 additions & 0 deletions openaoe/frontend/src/pages/chat/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import Chat from '@pages/chat/components/chat/chat.tsx';
import ModelList from '@pages/chat/components/model-list/model-list.tsx';
import PromptInput from '@pages/chat/components/prompt-input/prompt-input.tsx';
import Loading from '@components/loading/loading.tsx';
import ChatOperation from '@pages/chat/components/chat-operations/chat-operation.tsx';
import styles from './chat.module.less';

const useHasHydrated = () => {
Expand All @@ -28,6 +29,8 @@ const ChatPage: React.FC = () => {
<div className={styles.homeChats}>
<Chat />
</div>
{/* Operation for all sessions */}
<ChatOperation modelName="gpt-3.5-turbo" />
{/* Input editor */}
<PromptInput />
{/* Model selector */}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { Tooltip } from 'sea-lion-ui';
import { getNeedEventCallback, scrollToBottom } from '@utils/utils.ts';
import { BASE_IMG_URL, CLEAR_CONTEXT, SERIAL_SESSION } from '@constants/models.ts';
import {
BASE_IMG_URL, CLEAR_CONTEXT, PARALLEL_MODE, SERIAL_MODE, SERIAL_SESSION
} from '@constants/models.ts';
import React, { useContext, useEffect } from 'react';
import { GlobalConfigContext } from '@components/global-config/global-config-context.tsx';
import styles from './chat-operation.module.less';
Expand Down Expand Up @@ -122,51 +124,112 @@ const RetryIcon = () => {
};
const ChatOperation = (props: ChatOperationProps) => {
const { models, streamModels } = useContext(GlobalConfigContext);
const { modelName } = props;
const chatStore = useChatStore();
const configStore = useConfigStore();
const { sessions } = chatStore;
const currSession = sessions.find((session) => session.name === modelName);
const botStore = useBotStore();
// True when at least one session relevant to the current display mode has
// messages: parallel mode inspects every non-serial session, serial mode
// inspects only the serial session. Drives whether the operation buttons show.
const hasMessage = sessions.some((session) => {
    if (configStore.mode === PARALLEL_MODE && session.name !== SERIAL_SESSION) {
        return session.messages.length > 0;
    }
    if (configStore.mode === SERIAL_MODE && session.name === SERIAL_SESSION) {
        return session.messages.length > 0;
    }
    return false;
});

// True when any relevant session is still generating: the `stream` flag on the
// most recent message means the bot reply is mid-flight. Used to swap between
// the "Stop generating" button and the clear/retry controls.
const hasStreamSession = sessions.some((session) => {
    if (configStore.mode === PARALLEL_MODE && session.name !== SERIAL_SESSION) {
        return chatStore.lastMessage(session.name).stream;
    }
    if (configStore.mode === SERIAL_MODE && session.name === SERIAL_SESSION) {
        return chatStore.lastMessage(SERIAL_SESSION).stream;
    }
    return false;
});

/**
* clear context for current session, then scroll to bottom automatically
*/
/**
 * Clear the conversation context for every session relevant to the current
 * display mode, then scroll each affected chat pane to the bottom.
 *
 * Parallel mode: applies to every session except the serial one.
 * Serial mode: applies only to the serial session.
 *
 * A session whose last message is already a CLEAR_CONTEXT marker is skipped so
 * that repeated clicks do not stack divider messages. Fix vs. the original
 * span: stale pre-change lines and a duplicated SERIAL_MODE post sequence are
 * removed — the original would have inserted the divider message twice.
 */
const handleClearContext = () => {
    if (configStore.mode === PARALLEL_MODE) {
        chatStore.sessions.forEach((session, sessionIdx) => {
            if (session.name === SERIAL_SESSION) return;
            // Already cleared: don't stack another divider.
            if (chatStore.lastMessage(session.name).sender_type === CLEAR_CONTEXT) return;

            const newSession = session;
            newSession.clearContextIndex = session.messages.length || 0;
            chatStore.updateSession(sessionIdx, newSession);
            // Synthetic admin message renders the "context cleared" divider.
            chatStore.onNewMessage(createMessage({
                model: 'admin',
                text: '',
                sender_type: CLEAR_CONTEXT,
                id: Date.now(),
                stream: false,
                isError: true,
            }), sessionIdx);
            scrollToBottom(`chat-wrapper-${newSession.id}`);
        });
    } else if (configStore.mode === SERIAL_MODE) {
        const sessionIdx = chatStore.sessions.findIndex((session) => session.name === SERIAL_SESSION);
        const newSession = chatStore.sessions[sessionIdx];
        // Already cleared: don't stack another divider.
        if (chatStore.lastMessage(newSession.name).sender_type === CLEAR_CONTEXT) return;

        newSession.clearContextIndex = newSession.messages.length || 0;
        chatStore.updateSession(sessionIdx, newSession);
        // Synthetic admin message renders the "context cleared" divider.
        chatStore.onNewMessage(createMessage({
            model: 'admin',
            text: '',
            sender_type: CLEAR_CONTEXT,
            id: Date.now(),
            stream: false,
            isError: true,
        }), sessionIdx);
        scrollToBottom(`chat-wrapper-${newSession.id}`);
    }
};

/**
* clear history for current session
*/
/**
 * Remove all chat history (messages and the clear-context marker) from the
 * sessions relevant to the current display mode.
 *
 * Fix vs. the original span: the stale single-session lines (which looked up
 * `modelName` and unconditionally cleared one session before branching) are
 * removed so history is cleared exactly once per relevant session.
 */
const handleClearHistory = () => {
    if (configStore.mode === PARALLEL_MODE) {
        chatStore.sessions.forEach((session, sessionIdx) => {
            if (session.name === SERIAL_SESSION) return;
            chatStore.updateSession(sessionIdx, { messages: [], clearContextIndex: 0 });
        });
    } else if (configStore.mode === SERIAL_MODE) {
        const sessionIdx = sessions.findIndex((session) => session.name === SERIAL_SESSION);
        chatStore.updateSession(sessionIdx, { messages: [], clearContextIndex: 0 });
    }
};

/**
 * Abort the in-flight streaming responses for the sessions relevant to the
 * current display mode by closing their fetch controllers.
 *
 * Fix vs. the original span: the stale `closeController(modelName)` line from
 * the pre-change single-session version is removed — it ran before the mode
 * branch and closed one controller twice (or a wrong one) in parallel mode.
 */
const handleStopStream = () => {
    if (configStore.mode === PARALLEL_MODE) {
        chatStore.sessions.forEach((session) => {
            if (session.name === SERIAL_SESSION) return;
            chatStore.closeController(session.name);
        });
    } else if (configStore.mode === SERIAL_MODE) {
        chatStore.closeController(SERIAL_SESSION);
    }
};

/**
 * Re-send the last user prompt for every session relevant to the current
 * display mode, resolving each session's provider and streaming capability
 * from the global model config.
 *
 * Fix vs. the original span: the stale single-session lines (which read
 * `currSession` from the pre-change version) are removed so retry runs
 * exactly once per relevant session.
 */
const handleRetry = () => {
    if (configStore.mode === PARALLEL_MODE) {
        chatStore.sessions.forEach((session) => {
            if (session.name === SERIAL_SESSION) return;
            // The model to retry with is whichever bot answered last in this session.
            const model = chatStore.lastBotMessage(session.name)?.model;
            const provider = models[model]?.provider || '';
            const isStream = streamModels.includes(model);
            chatStore.retry(session.name, provider, model, isStream);
        });
    } else if (configStore.mode === SERIAL_MODE) {
        // Serial mode always retries with the currently selected bot.
        const model = botStore.currentBot;
        const provider = models[model]?.provider || '';
        const isStream = streamModels.includes(model);
        chatStore.retry(SERIAL_SESSION, provider, model, isStream);
    }
};

useEffect(() => {
Expand All @@ -175,10 +238,11 @@ const ChatOperation = (props: ChatOperationProps) => {
handleStopStream();
};
}, []);

return (
<div className={styles.homeOperation}>
{/** if the last message is generating,which means: stream=true */}
{chatStore.lastMessage(modelName).stream && (
{hasMessage && hasStreamSession && (
<Tooltip title="Stop generating" className={styles.opBtn}>
<div {...getNeedEventCallback(handleStopStream)}>
<img
Expand All @@ -191,7 +255,7 @@ const ChatOperation = (props: ChatOperationProps) => {
</Tooltip>
)}
{/** if message list is not empty && stream=false */}
{!!currSession.messages?.length && !chatStore.lastMessage(modelName).stream && (
{hasMessage && !hasStreamSession && (
<>
<Tooltip title="Clear context" className={styles.opBtn}>
<div {...getNeedEventCallback(handleClearContext)}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
height: 100%;
width: 100%;
gap: 48px;
padding: 12px 0;
padding: 0;
position: relative;
}
.preview-sessions {
Expand Down
2 changes: 1 addition & 1 deletion openaoe/frontend/src/pages/chat/components/chat/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ function ChatSession(props: { session: ChatSessionProps }) {
);
})}
</div>
<ChatOperation modelName={session.name} />
{/* <ChatOperation modelName={session.name} /> */}
</div>
);
}
Expand Down
13 changes: 9 additions & 4 deletions openaoe/frontend/src/store/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { persist } from 'zustand/middleware';
import { scrollToBottom } from '@utils/utils.ts';
import { getHeaders, getPayload, getUrl } from '@services/fetch.ts';
import { fetchBotAnswer } from '@services/home.ts';
import { DEFAULT_BOT, SERIAL_SESSION, STREAM_BOT } from '@constants/models.ts';
import { DEFAULT_BOT, SERIAL_SESSION } from '@constants/models.ts';

export interface ChatMessage {
text: string;
Expand Down Expand Up @@ -126,7 +126,10 @@ export const useChatStore = create<ChatStore>()(
}
},
retry(bot: '', provider: '', model: '', isStreamApi = true) {
const text = get().lastUserMessage(bot).text;
const text = get().lastUserMessage(bot)?.text;
if (!text) {
return;
}
if ((get().lastMessage(bot).id === get().lastBotMessage(bot).id) && get().getSession(bot).clearContextIndex !== get().getSession(bot).messages.length) {
// If the last message is a reply from the bot and the context is not cleared yet, replace the last two messages,
// Otherwise resend the last conversation
Expand Down Expand Up @@ -319,7 +322,8 @@ export const useChatStore = create<ChatStore>()(
return createMessage({ text: '' });
},
lastUserMessage(sessionName) {
const session = get().sessions.find((session) => session.name === sessionName);
const session = get().sessions
.find((session) => session.name === sessionName);
if (!session) return createMessage({ text: '' });
const messages = [...session.messages];
if (Array.isArray(messages) && messages.length > 0) {
Expand All @@ -328,7 +332,8 @@ export const useChatStore = create<ChatStore>()(
return createMessage({ text: '' });
},
lastBotMessage(sessionName) {
const session = get().sessions.find((session) => session.name === sessionName);
const session = get().sessions
.find((session) => session.name === sessionName);
if (!session) return createMessage({ text: '' });
const messages = [...session.messages];
if (Array.isArray(messages) && messages.length > 0) {
Expand Down

0 comments on commit 4f84bd6

Please sign in to comment.