diff --git a/web/helpers/atoms/Thread.atom.ts b/web/helpers/atoms/Thread.atom.ts index c94d287b56..4bf5a855e6 100644 --- a/web/helpers/atoms/Thread.atom.ts +++ b/web/helpers/atoms/Thread.atom.ts @@ -173,6 +173,21 @@ export const updateThreadWaitingForResponseAtom = atom( } ) +/** + * Reset the thread waiting for response state + */ +export const resetThreadWaitingForResponseAtom = atom(null, (get, set) => { + const currentState = { ...get(threadStatesAtom) } + Object.keys(currentState).forEach((threadId) => { + currentState[threadId] = { + ...currentState[threadId], + waitingForResponse: false, + error: undefined, + } + }) + set(threadStatesAtom, currentState) +}) + /** * Update the thread last message */ diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index ed704dd612..14d8819776 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -10,6 +10,10 @@ import { LAST_USED_MODEL_ID } from './useRecommendedModel' import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { + isGeneratingResponseAtom, + resetThreadWaitingForResponseAtom, +} from '@/helpers/atoms/Thread.atom' export const activeModelAtom = atom(undefined) export const loadModelErrorAtom = atom(undefined) @@ -34,6 +38,10 @@ export function useActiveModel() { const pendingModelLoad = useRef(false) const isVulkanEnabled = useAtomValue(vulkanEnabledAtom) const activeAssistant = useAtomValue(activeAssistantAtom) + const setGeneratingResponse = useSetAtom(isGeneratingResponseAtom) + const resetThreadWaitingForResponseState = useSetAtom( + resetThreadWaitingForResponseAtom + ) const downloadedModelsRef = useRef([]) @@ -139,6 +147,8 @@ export function useActiveModel() { return const engine = EngineManager.instance().get(stoppingModel.engine) + setGeneratingResponse(false) + resetThreadWaitingForResponseState() return engine ?.unloadModel(stoppingModel) .catch((e) => console.error(e)) @@ -148,7 +158,14 @@ export function useActiveModel() { pendingModelLoad.current = false }) }, - [activeModel, setStateModel, setActiveModel, stateModel] + [ + activeModel, + setStateModel, + setActiveModel, + stateModel, + setGeneratingResponse, + resetThreadWaitingForResponseState, + ] ) const stopInference = useCallback(async () => { diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx index 0ba50880b6..24499bd305 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx @@ -302,9 +302,7 @@ const ChatInput = () => { )} - {messages[messages.length - 1]?.status !== MessageStatus.Pending && - !isGeneratingResponse && - !isStreamingResponse ? ( + {!isGeneratingResponse && !isStreamingResponse ? ( <> {currentPrompt.length !== 0 && (