diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts index 854773556f..30766be45f 100644 --- a/extensions/conversational-extension/src/index.ts +++ b/extensions/conversational-extension/src/index.ts @@ -147,7 +147,9 @@ export default class CortexConversationalExtension extends ConversationalExtensi */ async getThreadAssistant(threadId: string): Promise { return this.queue.add(() => - ky.get(`${API_URL}/v1/assistants/${threadId}?limit=-1`).json() + ky + .get(`${API_URL}/v1/assistants/${threadId}?limit=-1`) + .json() ) as Promise } /** @@ -188,7 +190,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi * Do health check on cortex.cpp * @returns */ - healthz(): Promise { + async healthz(): Promise { return ky .get(`${API_URL}/healthz`, { retry: { limit: 20, delay: () => 500, methods: ['get'] }, diff --git a/web/helpers/atoms/Thread.atom.ts b/web/helpers/atoms/Thread.atom.ts index 7fb6f3c600..55527115fe 100644 --- a/web/helpers/atoms/Thread.atom.ts +++ b/web/helpers/atoms/Thread.atom.ts @@ -125,6 +125,26 @@ export const waitingToSendMessage = atom(undefined) */ export const isGeneratingResponseAtom = atom(undefined) +/** + * Create a new thread and add it to the thread list + */ +export const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => { + // create thread state for this new thread + const currentState = { ...get(threadStatesAtom) } + + const threadState: ThreadState = { + hasMore: false, + waitingForResponse: false, + lastMessage: undefined, + } + currentState[newThread.id] = threadState + set(threadStatesAtom, currentState) + + // add the new thread on top of the thread list to the state + const threads = get(threadsAtom) + set(threadsAtom, [newThread, ...threads]) +}) + /** * Remove a thread state from the atom */ diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index 1e1f7a8492..40e554945e 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -33,29 +33,12 @@ import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { threadsAtom, - threadStatesAtom, updateThreadAtom, setThreadModelParamsAtom, isGeneratingResponseAtom, + createNewThreadAtom, } from '@/helpers/atoms/Thread.atom' -const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => { - // create thread state for this new thread - const currentState = { ...get(threadStatesAtom) } - - const threadState: ThreadState = { - hasMore: false, - waitingForResponse: false, - lastMessage: undefined, - } - currentState[newThread.id] = threadState - set(threadStatesAtom, currentState) - - // add the new thread on top of the thread list to the state - const threads = get(threadsAtom) - set(threadsAtom, [newThread, ...threads]) -}) - export const useCreateNewThread = () => { const createNewThread = useSetAtom(createNewThreadAtom) const { setActiveThread } = useSetActiveThread() diff --git a/web/hooks/useDeleteThread.test.ts b/web/hooks/useDeleteThread.test.ts index 50b0c7511b..e9e4341648 100644 --- a/web/hooks/useDeleteThread.test.ts +++ b/web/hooks/useDeleteThread.test.ts @@ -55,17 +55,21 @@ describe('useDeleteThread', () => { const mockCleanMessages = jest.fn() ;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages) ;(useAtomValue as jest.Mock).mockReturnValue(['thread 1']) - const mockCreateNewThread = jest.fn() - ;(useCreateNewThread as jest.Mock).mockReturnValue({ - requestCreateNewThread: mockCreateNewThread, - }) const mockSaveThread = jest.fn() - const mockDeleteThread = jest.fn().mockResolvedValue({}) + const mockDeleteMessage = jest.fn().mockResolvedValue({}) + const mockModifyThread = jest.fn().mockResolvedValue({}) extensionManager.get = jest.fn().mockReturnValue({ saveThread: mockSaveThread, getThreadAssistant: jest.fn().mockResolvedValue({}), - deleteThread: mockDeleteThread, + listMessages: jest.fn().mockResolvedValue([ + { + id: 'message1', + text: 'Message 1', + }, + ]), + deleteMessage: mockDeleteMessage, + modifyThread: mockModifyThread, }) const { result } = renderHook(() => useDeleteThread()) @@ -74,8 +78,8 @@ describe('useDeleteThread', () => { await result.current.cleanThread('thread1') }) - expect(mockDeleteThread).toHaveBeenCalled() - expect(mockCreateNewThread).toHaveBeenCalled() + expect(mockDeleteMessage).toHaveBeenCalled() + expect(mockModifyThread).toHaveBeenCalled() }) it('should handle errors when deleting a thread', async () => { diff --git a/web/hooks/useDeleteThread.ts b/web/hooks/useDeleteThread.ts index f69ccd47ea..2d14460ffa 100644 --- a/web/hooks/useDeleteThread.ts +++ b/web/hooks/useDeleteThread.ts @@ -2,30 +2,25 @@ import { useCallback } from 'react' import { ExtensionTypeEnum, ConversationalExtension } from '@janhq/core' -import { useAtom, useAtomValue, useSetAtom } from 'jotai' +import { useAtom, useSetAtom } from 'jotai' import { currentPromptAtom } from '@/containers/Providers/Jotai' import { toaster } from '@/containers/Toast' -import { useCreateNewThread } from './useCreateNewThread' - import { extensionManager } from '@/extension/ExtensionManager' -import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' import { deleteChatMessageAtom as deleteChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { threadsAtom, setActiveThreadIdAtom, deleteThreadStateAtom, + updateThreadAtom, } from '@/helpers/atoms/Thread.atom' export default function useDeleteThread() { const [threads, setThreads] = useAtom(threadsAtom) - const { requestCreateNewThread } = useCreateNewThread() - const assistants = useAtomValue(assistantsAtom) - const models = useAtomValue(downloadedModelsAtom) + const updateThread = useSetAtom(updateThreadAtom) const setCurrentPrompt = useSetAtom(currentPromptAtom) const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) @@ -35,43 +30,37 @@ export default function useDeleteThread() { const cleanThread = useCallback( async (threadId: string) => { - const thread = threads.find((c) => c.id === threadId) - if (!thread) return - const availableThreads = threads.filter((c) => c.id !== threadId) - setThreads(availableThreads) - - // delete the thread state - deleteThreadState(threadId) - - const assistantInfo = await extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.getThreadAssistant(thread.id) - .catch(console.error) - - if (!assistantInfo) return - const model = models.find((c) => c.id === assistantInfo?.model?.id) - - requestCreateNewThread( - { - ...assistantInfo, - id: assistants[0].id, - name: assistants[0].name, - }, - model - ? { - ...model, - parameters: assistantInfo?.model?.parameters ?? {}, - settings: assistantInfo?.model?.settings ?? {}, - } - : undefined - ) - // Delete this thread - await extensionManager + const messages = await extensionManager .get(ExtensionTypeEnum.Conversational) - ?.deleteThread(threadId) + ?.listMessages(threadId) .catch(console.error) + if (messages) { + messages.forEach((message) => { + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.deleteMessage(threadId, message.id) + .catch(console.error) + }) + const thread = threads.find((e) => e.id === threadId) + if (thread) { + const updatedThread = { + ...thread, + metadata: { + ...thread.metadata, + title: 'New Thread', + lastMessage: '', + }, + } + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.modifyThread(updatedThread) + .catch(console.error) + updateThread(updatedThread) + } + } + deleteMessages(threadId) }, - [assistants, models, requestCreateNewThread, threads] + [deleteMessages, threads, updateThread] ) const deleteThread = async (threadId: string) => {