Skip to content

Commit

Permalink
fix: cleaning a thread should just clear out messages
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Dec 23, 2024
1 parent 5163e12 commit 5b42df7
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 65 deletions.
6 changes: 4 additions & 2 deletions extensions/conversational-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ export default class CortexConversationalExtension extends ConversationalExtensi
*/
async getThreadAssistant(threadId: string): Promise<ThreadAssistantInfo> {
return this.queue.add(() =>
ky.get(`${API_URL}/v1/assistants/${threadId}?limit=-1`).json<ThreadAssistantInfo>()
ky
.get(`${API_URL}/v1/assistants/${threadId}?limit=-1`)
.json<ThreadAssistantInfo>()
) as Promise<ThreadAssistantInfo>
}
/**
Expand Down Expand Up @@ -188,7 +190,7 @@ export default class CortexConversationalExtension extends ConversationalExtensi
* Do health check on cortex.cpp
* @returns
*/
healthz(): Promise<void> {
async healthz(): Promise<void> {
return ky
.get(`${API_URL}/healthz`, {
retry: { limit: 20, delay: () => 500, methods: ['get'] },
Expand Down
20 changes: 20 additions & 0 deletions web/helpers/atoms/Thread.atom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,26 @@ export const waitingToSendMessage = atom<boolean | undefined>(undefined)
*/
export const isGeneratingResponseAtom = atom<boolean | undefined>(undefined)

/**
* Create a new thread and add it to the thread list
*/
export const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
// create thread state for this new thread
const currentState = { ...get(threadStatesAtom) }

Check warning on line 133 in web/helpers/atoms/Thread.atom.ts

View workflow job for this annotation

GitHub Actions / coverage-check

133 line is not covered with tests

const threadState: ThreadState = {

Check warning on line 135 in web/helpers/atoms/Thread.atom.ts

View workflow job for this annotation

GitHub Actions / coverage-check

135 line is not covered with tests
hasMore: false,
waitingForResponse: false,
lastMessage: undefined,
}
currentState[newThread.id] = threadState
set(threadStatesAtom, currentState)

Check warning on line 141 in web/helpers/atoms/Thread.atom.ts

View workflow job for this annotation

GitHub Actions / coverage-check

140-141 lines are not covered with tests

// add the new thread on top of the thread list to the state
const threads = get(threadsAtom)
set(threadsAtom, [newThread, ...threads])

Check warning on line 145 in web/helpers/atoms/Thread.atom.ts

View workflow job for this annotation

GitHub Actions / coverage-check

144-145 lines are not covered with tests
})

/**
* Remove a thread state from the atom
*/
Expand Down
19 changes: 1 addition & 18 deletions web/hooks/useCreateNewThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,12 @@ import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
threadsAtom,
threadStatesAtom,
updateThreadAtom,
setThreadModelParamsAtom,
isGeneratingResponseAtom,
createNewThreadAtom,
} from '@/helpers/atoms/Thread.atom'

const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
// create thread state for this new thread
const currentState = { ...get(threadStatesAtom) }

const threadState: ThreadState = {
hasMore: false,
waitingForResponse: false,
lastMessage: undefined,
}
currentState[newThread.id] = threadState
set(threadStatesAtom, currentState)

// add the new thread on top of the thread list to the state
const threads = get(threadsAtom)
set(threadsAtom, [newThread, ...threads])
})

export const useCreateNewThread = () => {
const createNewThread = useSetAtom(createNewThreadAtom)
const { setActiveThread } = useSetActiveThread()
Expand Down
20 changes: 12 additions & 8 deletions web/hooks/useDeleteThread.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,21 @@ describe('useDeleteThread', () => {
const mockCleanMessages = jest.fn()
;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages)
;(useAtomValue as jest.Mock).mockReturnValue(['thread 1'])
const mockCreateNewThread = jest.fn()
;(useCreateNewThread as jest.Mock).mockReturnValue({
requestCreateNewThread: mockCreateNewThread,
})

const mockSaveThread = jest.fn()
const mockDeleteThread = jest.fn().mockResolvedValue({})
const mockDeleteMessage = jest.fn().mockResolvedValue({})
const mockModifyThread = jest.fn().mockResolvedValue({})
extensionManager.get = jest.fn().mockReturnValue({
saveThread: mockSaveThread,
getThreadAssistant: jest.fn().mockResolvedValue({}),
deleteThread: mockDeleteThread,
listMessages: jest.fn().mockResolvedValue([
{
id: 'message1',
text: 'Message 1',
},
]),
deleteMessage: mockDeleteMessage,
modifyThread: mockModifyThread,
})

const { result } = renderHook(() => useDeleteThread())
Expand All @@ -74,8 +78,8 @@ describe('useDeleteThread', () => {
await result.current.cleanThread('thread1')
})

expect(mockDeleteThread).toHaveBeenCalled()
expect(mockCreateNewThread).toHaveBeenCalled()
expect(mockDeleteMessage).toHaveBeenCalled()
expect(mockModifyThread).toHaveBeenCalled()
})

it('should handle errors when deleting a thread', async () => {
Expand Down
67 changes: 30 additions & 37 deletions web/hooks/useDeleteThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ import {
threadsAtom,
setActiveThreadIdAtom,
deleteThreadStateAtom,
updateThreadAtom,
} from '@/helpers/atoms/Thread.atom'

export default function useDeleteThread() {
const [threads, setThreads] = useAtom(threadsAtom)
const { requestCreateNewThread } = useCreateNewThread()
const assistants = useAtomValue(assistantsAtom)
const models = useAtomValue(downloadedModelsAtom)
const updateThread = useSetAtom(updateThreadAtom)

const setCurrentPrompt = useSetAtom(currentPromptAtom)
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
Expand All @@ -35,43 +34,37 @@ export default function useDeleteThread() {

const cleanThread = useCallback(
async (threadId: string) => {
const thread = threads.find((c) => c.id === threadId)
if (!thread) return
const availableThreads = threads.filter((c) => c.id !== threadId)
setThreads(availableThreads)

// delete the thread state
deleteThreadState(threadId)

const assistantInfo = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.getThreadAssistant(thread.id)
.catch(console.error)

if (!assistantInfo) return
const model = models.find((c) => c.id === assistantInfo?.model?.id)

requestCreateNewThread(
{
...assistantInfo,
id: assistants[0].id,
name: assistants[0].name,
},
model
? {
...model,
parameters: assistantInfo?.model?.parameters ?? {},
settings: assistantInfo?.model?.settings ?? {},
}
: undefined
)
// Delete this thread
await extensionManager
const messages = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteThread(threadId)
?.listMessages(threadId)
.catch(console.error)
if (messages) {
messages.forEach((message) => {
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteMessage(threadId, message.id)
.catch(console.error)
})
const thread = threads.find((e) => e.id === threadId)
if (thread) {
const updatedThread = {
...thread,
metadata: {
...thread.metadata,
title: 'New Thread',
lastMessage: '',
},
}
extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyThread(updatedThread)
.catch(console.error)
updateThread(updatedThread)
}
}
deleteMessages(threadId)
},
[assistants, models, requestCreateNewThread, threads]
[deleteMessages, threads, updateThread]
)

const deleteThread = async (threadId: string) => {
Expand Down

0 comments on commit 5b42df7

Please sign in to comment.