diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx index 9066a2d1f7..09427018ca 100644 --- a/web/containers/Providers/ModelHandler.tsx +++ b/web/containers/Providers/ModelHandler.tsx @@ -1,4 +1,4 @@ -import { Fragment, use, useCallback, useEffect, useRef } from 'react' +import { Fragment, useCallback, useEffect, useRef } from 'react' import { ChatCompletionMessage, diff --git a/web/hooks/useCreateNewThread.test.ts b/web/hooks/useCreateNewThread.test.ts index 25589c0988..d98983830d 100644 --- a/web/hooks/useCreateNewThread.test.ts +++ b/web/hooks/useCreateNewThread.test.ts @@ -67,7 +67,7 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set + expect(mockSetAtom).toHaveBeenCalledTimes(1) expect(extensionManager.get).toHaveBeenCalled() }) @@ -104,7 +104,7 @@ describe('useCreateNewThread', () => { await result.current.requestCreateNewThread({ id: 'assistant1', name: 'Assistant 1', - instructions: "Hello Jan Assistant", + instructions: 'Hello Jan Assistant', model: { id: 'model1', parameters: [], @@ -113,16 +113,8 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set + expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set expect(extensionManager.get).toHaveBeenCalled() - expect(mockSetAtom).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ - assistants: expect.arrayContaining([ - expect.objectContaining({ instructions: 'Hello Jan Assistant' }), - ]), - }) - ) }) it('should create a new thread with previous instructions', async () => { @@ -166,16 +158,8 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set + expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set expect(extensionManager.get).toHaveBeenCalled() - expect(mockSetAtom).toHaveBeenNthCalledWith( - 2, - expect.objectContaining({ - assistants: expect.arrayContaining([ - expect.objectContaining({ instructions: 'Hello Jan' }), - ]), - }) - ) }) it('should show a warning toast if trying to create an empty thread', async () => { @@ -212,13 +196,12 @@ describe('useCreateNewThread', () => { const { result } = renderHook(() => useCreateNewThread()) - const mockThread = { id: 'thread1', title: 'Test Thread' } + const mockThread = { id: 'thread1', title: 'Test Thread', assistants: [{}] } await act(async () => { await result.current.updateThreadMetadata(mockThread as any) }) expect(mockUpdateThread).toHaveBeenCalledWith(mockThread) - expect(extensionManager.get).toHaveBeenCalled() }) }) diff --git a/web/hooks/useDeleteThread.test.ts b/web/hooks/useDeleteThread.test.ts index d3a6138d07..50b0c7511b 100644 --- a/web/hooks/useDeleteThread.test.ts +++ b/web/hooks/useDeleteThread.test.ts @@ -2,8 +2,7 @@ import { renderHook, act } from '@testing-library/react' import { useAtom, useAtomValue, useSetAtom } from 'jotai' import useDeleteThread from './useDeleteThread' import { extensionManager } from '@/extension/ExtensionManager' -import { toaster } from '@/containers/Toast' - +import { useCreateNewThread } from './useCreateNewThread' // Mock the necessary dependencies // Mock dependencies jest.mock('jotai', () => ({ @@ -12,6 +11,7 @@ jest.mock('jotai', () => ({ useAtom: jest.fn(), atom: jest.fn(), })) +jest.mock('./useCreateNewThread') jest.mock('@/extension/ExtensionManager') jest.mock('@/containers/Toast') @@ -27,8 +27,13 @@ describe('useDeleteThread', () => { ] const mockSetThreads = jest.fn() ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads]) + ;(useSetAtom as jest.Mock).mockReturnValue(() => {}) + ;(useCreateNewThread as jest.Mock).mockReturnValue({}) + + const mockDeleteThread = jest.fn().mockImplementation(() => ({ + catch: () => jest.fn, + })) - const mockDeleteThread = jest.fn() extensionManager.get = jest.fn().mockReturnValue({ deleteThread: mockDeleteThread, }) @@ -50,12 +55,17 @@ describe('useDeleteThread', () => { const mockCleanMessages = jest.fn() ;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages) ;(useAtomValue as jest.Mock).mockReturnValue(['thread 1']) + const mockCreateNewThread = jest.fn() + ;(useCreateNewThread as jest.Mock).mockReturnValue({ + requestCreateNewThread: mockCreateNewThread, + }) - const mockWriteMessages = jest.fn() const mockSaveThread = jest.fn() + const mockDeleteThread = jest.fn().mockResolvedValue({}) extensionManager.get = jest.fn().mockReturnValue({ - writeMessages: mockWriteMessages, saveThread: mockSaveThread, + getThreadAssistant: jest.fn().mockResolvedValue({}), + deleteThread: mockDeleteThread, }) const { result } = renderHook(() => useDeleteThread()) @@ -64,20 +74,18 @@ describe('useDeleteThread', () => { await result.current.cleanThread('thread1') }) - expect(mockWriteMessages).toHaveBeenCalled() - expect(mockSaveThread).toHaveBeenCalledWith( - expect.objectContaining({ - id: 'thread1', - title: 'New Thread', - metadata: expect.objectContaining({ lastMessage: undefined }), - }) - ) + expect(mockDeleteThread).toHaveBeenCalled() + expect(mockCreateNewThread).toHaveBeenCalled() }) it('should handle errors when deleting a thread', async () => { const mockThreads = [{ id: 'thread1', title: 'Thread 1' }] const mockSetThreads = jest.fn() ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads]) + const mockCreateNewThread = jest.fn() + ;(useCreateNewThread as jest.Mock).mockReturnValue({ + requestCreateNewThread: mockCreateNewThread, + }) const mockDeleteThread = jest .fn() @@ -98,8 +106,6 @@ describe('useDeleteThread', () => { expect(mockDeleteThread).toHaveBeenCalledWith('thread1') expect(consoleErrorSpy).toHaveBeenCalledWith(expect.any(Error)) - expect(mockSetThreads).not.toHaveBeenCalled() - expect(toaster).not.toHaveBeenCalled() consoleErrorSpy.mockRestore() }) diff --git a/web/hooks/useThread.test.ts b/web/hooks/useThread.test.ts index a40c709be6..4db7f87aca 100644 --- a/web/hooks/useThread.test.ts +++ b/web/hooks/useThread.test.ts @@ -78,7 +78,7 @@ describe('useThreads', () => { // Mock extensionManager const mockGetThreads = jest.fn().mockResolvedValue(mockThreads) ;(extensionManager.get as jest.Mock).mockReturnValue({ - getThreads: mockGetThreads, + listThreads: mockGetThreads, }) const { result } = renderHook(() => useThreads()) @@ -119,7 +119,7 @@ describe('useThreads', () => { it('should handle empty threads', async () => { // Mock empty threads ;(extensionManager.get as jest.Mock).mockReturnValue({ - getThreads: jest.fn().mockResolvedValue([]), + listThreads: jest.fn().mockResolvedValue([]), }) const mockSetThreadStates = jest.fn() diff --git a/web/hooks/useUpdateModelParameters.test.ts b/web/hooks/useUpdateModelParameters.test.ts index bc60aa631c..6c7ceb8b03 100644 --- a/web/hooks/useUpdateModelParameters.test.ts +++ b/web/hooks/useUpdateModelParameters.test.ts @@ -1,7 +1,12 @@ import { renderHook, act } from '@testing-library/react' +import { useAtom } from 'jotai' // Mock dependencies jest.mock('ulidx') jest.mock('@/extension') +jest.mock('jotai', () => ({ + ...jest.requireActual('jotai'), + useAtom: jest.fn(), +})) import useUpdateModelParameters from './useUpdateModelParameters' import { extensionManager } from '@/extension' @@ -13,7 +18,8 @@ let model: any = { } let extension: any = { - saveThread: jest.fn(), + modifyThread: jest.fn(), + modifyThreadAssistant: jest.fn(), } const mockThread: any = { @@ -35,6 +41,7 @@ const mockThread: any = { describe('useUpdateModelParameters', () => { beforeAll(() => { jest.clearAllMocks() + jest.useFakeTimers() jest.mock('./useRecommendedModel', () => ({ useRecommendedModel: () => ({ recommendedModel: model, @@ -45,6 +52,12 @@ describe('useUpdateModelParameters', () => { }) it('should update model parameters and save thread when params are valid', async () => { + ;(useAtom as jest.Mock).mockReturnValue([ + { + id: 'assistant-1', + }, + jest.fn(), + ]) const mockValidParameters: any = { params: { // Inference @@ -76,7 +89,8 @@ describe('useUpdateModelParameters', () => { // Spy functions jest.spyOn(extensionManager, 'get').mockReturnValue(extension) - jest.spyOn(extension, 'saveThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({}) const { result } = renderHook(() => useUpdateModelParameters()) @@ -84,44 +98,46 @@ describe('useUpdateModelParameters', () => { await result.current.updateModelParameter(mockThread, mockValidParameters) }) + jest.runAllTimers() + // Check if the model parameters are valid before persisting - expect(extension.saveThread).toHaveBeenCalledWith({ - assistants: [ - { - model: { - parameters: { - stop: ['', ''], - temperature: 0.5, - token_limit: 1000, - top_k: 0.7, - top_p: 0.1, - stream: true, - max_tokens: 1000, - frequency_penalty: 0.3, - presence_penalty: 0.2, - }, - settings: { - ctx_len: 1024, - ngl: 12, - embedding: true, - n_parallel: 2, - cpu_threads: 4, - prompt_template: 'template', - llama_model_path: 'path', - mmproj: 'mmproj', - }, - }, + expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', { + id: 'assistant-1', + model: { + parameters: { + stop: ['', ''], + temperature: 0.5, + token_limit: 1000, + top_k: 0.7, + top_p: 0.1, + stream: true, + max_tokens: 1000, + frequency_penalty: 0.3, + presence_penalty: 0.2, }, - ], - created: 0, - id: 'thread-1', - object: 'thread', - title: 'New Thread', - updated: 0, + settings: { + ctx_len: 1024, + ngl: 12, + embedding: true, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + }, + id: 'model-1', + engine: 'nitro', + }, }) }) it('should not update invalid model parameters', async () => { + ;(useAtom as jest.Mock).mockReturnValue([ + { + id: 'assistant-1', + }, + jest.fn(), + ]) const mockInvalidParameters: any = { params: { // Inference @@ -153,7 +169,8 @@ describe('useUpdateModelParameters', () => { // Spy functions jest.spyOn(extensionManager, 'get').mockReturnValue(extension) - jest.spyOn(extension, 'saveThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({}) const { result } = renderHook(() => useUpdateModelParameters()) @@ -164,36 +181,38 @@ describe('useUpdateModelParameters', () => { ) }) + jest.runAllTimers() + // Check if the model parameters are valid before persisting - expect(extension.saveThread).toHaveBeenCalledWith({ - assistants: [ - { - model: { - parameters: { - max_tokens: 1000, - token_limit: 1000, - }, - settings: { - cpu_threads: 4, - ctx_len: 1024, - prompt_template: 'template', - llama_model_path: 'path', - mmproj: 'mmproj', - n_parallel: 2, - ngl: 12, - }, - }, + expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', { + id: 'assistant-1', + model: { + engine: 'nitro', + id: 'model-1', + parameters: { + token_limit: 1000, + max_tokens: 1000, + }, + settings: { + cpu_threads: 4, + ctx_len: 1024, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', + n_parallel: 2, + ngl: 12, }, - ], - created: 0, - id: 'thread-1', - object: 'thread', - title: 'New Thread', - updated: 0, + }, }) }) it('should update valid model parameters only', async () => { + ;(useAtom as jest.Mock).mockReturnValue([ + { + id: 'assistant-1', + }, + jest.fn(), + ]) const mockInvalidParameters: any = { params: { // Inference @@ -225,8 +244,8 @@ describe('useUpdateModelParameters', () => { // Spy functions jest.spyOn(extensionManager, 'get').mockReturnValue(extension) - jest.spyOn(extension, 'saveThread').mockReturnValue({}) - + jest.spyOn(extension, 'modifyThread').mockReturnValue({}) + jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({}) const { result } = renderHook(() => useUpdateModelParameters()) await act(async () => { @@ -235,80 +254,33 @@ describe('useUpdateModelParameters', () => { mockInvalidParameters ) }) + jest.runAllTimers() // Check if the model parameters are valid before persisting - expect(extension.saveThread).toHaveBeenCalledWith({ - assistants: [ - { - model: { - parameters: { - stop: [''], - top_k: 0.7, - top_p: 0.1, - stream: true, - token_limit: 100, - max_tokens: 1000, - presence_penalty: 0.2, - }, - settings: { - ctx_len: 1024, - ngl: 0, - n_parallel: 2, - cpu_threads: 4, - prompt_template: 'template', - llama_model_path: 'path', - mmproj: 'mmproj', - }, - }, + expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', { + id: 'assistant-1', + model: { + engine: 'nitro', + id: 'model-1', + parameters: { + stop: [''], + top_k: 0.7, + top_p: 0.1, + stream: true, + token_limit: 100, + max_tokens: 1000, + presence_penalty: 0.2, }, - ], - created: 0, - id: 'thread-1', - object: 'thread', - title: 'New Thread', - updated: 0, - }) - }) - - it('should handle missing modelId and engine gracefully', async () => { - const mockParametersWithoutModelIdAndEngine: any = { - params: { - stop: ['', ''], - temperature: 0.5, - }, - } - - // Spy functions - jest.spyOn(extensionManager, 'get').mockReturnValue(extension) - jest.spyOn(extension, 'saveThread').mockReturnValue({}) - - const { result } = renderHook(() => useUpdateModelParameters()) - - await act(async () => { - await result.current.updateModelParameter( - mockThread, - mockParametersWithoutModelIdAndEngine - ) - }) - - // Check if the model parameters are valid before persisting - expect(extension.saveThread).toHaveBeenCalledWith({ - assistants: [ - { - model: { - parameters: { - stop: ['', ''], - temperature: 0.5, - }, - settings: {}, - }, + settings: { + ctx_len: 1024, + ngl: 0, + n_parallel: 2, + cpu_threads: 4, + prompt_template: 'template', + llama_model_path: 'path', + mmproj: 'mmproj', }, - ], - created: 0, - id: 'thread-1', - object: 'thread', - title: 'New Thread', - updated: 0, + }, }) }) }) diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts index 977ebd10be..dab2f6e284 100644 --- a/web/hooks/useUpdateModelParameters.ts +++ b/web/hooks/useUpdateModelParameters.ts @@ -82,6 +82,7 @@ export default function useUpdateModelParameters() { }, } setActiveAssistant(assistantInfo) + updateAssistantCallback(thread.id, assistantInfo) }, [ diff --git a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx index 96ff6f559a..9b4e67ffbf 100644 --- a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx +++ b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.test.tsx @@ -7,6 +7,8 @@ import { useAtomValue, useSetAtom } from 'jotai' import { useActiveModel } from '@/hooks/useActiveModel' import { useCreateNewThread } from '@/hooks/useCreateNewThread' import AssistantSetting from './index' +import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' +import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' jest.mock('jotai', () => { const originalModule = jest.requireActual('jotai') @@ -68,6 +70,7 @@ describe('AssistantSetting Component', () => { beforeEach(() => { jest.clearAllMocks() + jest.useFakeTimers() }) test('renders AssistantSetting component with proper data', async () => { @@ -75,7 +78,14 @@ describe('AssistantSetting Component', () => { ;(useSetAtom as jest.Mock).mockImplementationOnce( () => setEngineParamsUpdate ) - ;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread) + ;(useAtomValue as jest.Mock).mockImplementation((atom) => { + switch (atom) { + case activeThreadAtom: + return mockActiveThread + case activeAssistantAtom: + return {} + } + }) const updateThreadMetadata = jest.fn() ;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel: jest.fn() }) ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({ @@ -98,7 +108,14 @@ describe('AssistantSetting Component', () => { const setEngineParamsUpdate = jest.fn() const updateThreadMetadata = jest.fn() const stopModel = jest.fn() - ;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread) + ;(useAtomValue as jest.Mock).mockImplementation((atom) => { + switch (atom) { + case activeThreadAtom: + return mockActiveThread + case activeAssistantAtom: + return {} + } + }) ;(useSetAtom as jest.Mock).mockImplementation(() => setEngineParamsUpdate) ;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel }) ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({