Skip to content

Commit

Permalink
feat: adding model params
Browse files Browse the repository at this point in the history
Signed-off-by: James <[email protected]>
  • Loading branch information
James committed Dec 6, 2023
1 parent 3bfe32a commit a639d30
Show file tree
Hide file tree
Showing 18 changed files with 244 additions and 83 deletions.
3 changes: 3 additions & 0 deletions core/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ export type ModelRuntimeParam = {
top_p?: number;
stream?: boolean;
max_tokens?: number;
stop?: string[];
frequency_penalty?: number;
presence_penalty?: number;
};

/**
Expand Down
2 changes: 1 addition & 1 deletion web/containers/DropdownListSidebar/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import { useMainViewState } from '@/hooks/useMainViewState'

import { toGigabytes } from '@/utils/converter'

import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'

export const selectedModelAtom = atom<Model | undefined>(undefined)

Expand Down
2 changes: 1 addition & 1 deletion web/containers/Layout/TopBar/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { useMainViewState } from '@/hooks/useMainViewState'

import { showRightSideBarAtom } from '@/screens/Chat/Sidebar'

import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'

const TopBar = () => {
const activeThread = useAtomValue(activeThreadAtom)
Expand Down
2 changes: 1 addition & 1 deletion web/containers/Providers/EventHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
import {
updateThreadWaitingForResponseAtom,
threadsAtom,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'

export default function EventHandler({ children }: { children: ReactNode }) {
const addNewMessage = useSetAtom(addNewMessageAtom)
Expand Down
27 changes: 27 additions & 0 deletions web/containers/Slider/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
type Props = {
min: number
max: number
step: number
value: number
onValueUpdated: (value: number) => void
}

const Slider: React.FC<Props> = ({ min, max, step, value, onValueUpdated }) => {
const onChange = (e: React.ChangeEvent<HTMLInputElement>) => {
e.preventDefault()
onValueUpdated(Number(e.target.value))
}

return (
<input
value={value}
onChange={onChange}
type="range"
min={min}
max={max}
step={step}
/>
)
}

export default Slider
19 changes: 11 additions & 8 deletions web/helpers/atoms/ChatMessage.atom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { atom } from 'jotai'
import {
getActiveThreadIdAtom,
updateThreadStateLastMessageAtom,
} from './Conversation.atom'
} from './Thread.atom'

/**
* Stores all chat messages for all threads
Expand Down Expand Up @@ -76,15 +76,18 @@ export const addNewMessageAtom = atom(
}
)

export const deleteConversationMessage = atom(null, (get, set, id: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
export const deleteChatMessageAtom = atom(
null,
(get, set, threadId: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
}
newData[threadId] = []
set(chatMessages, newData)
}
newData[id] = []
set(chatMessages, newData)
})
)

export const cleanConversationMessages = atom(null, (get, set, id: string) => {
export const cleanChatMessageAtom = atom(null, (get, set, id: string) => {
const newData: Record<string, ThreadMessage[]> = {
...get(chatMessages),
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import { Thread, ThreadContent, ThreadState } from '@janhq/core'
import {
ModelRuntimeParam,
Thread,
ThreadContent,
ThreadState,
} from '@janhq/core'
import { atom } from 'jotai'

/**
* Stores the current active conversation id.
* Stores the current active thread id.
*/
const activeThreadIdAtom = atom<string | undefined>(undefined)

export const getActiveThreadIdAtom = atom((get) => get(activeThreadIdAtom))

export const setActiveThreadIdAtom = atom(
null,
(_get, set, convoId: string | undefined) => set(activeThreadIdAtom, convoId)
(_get, set, threadId: string | undefined) => set(activeThreadIdAtom, threadId)
)

export const waitingToSendMessage = atom<boolean | undefined>(undefined)
Expand All @@ -20,47 +25,27 @@ export const waitingToSendMessage = atom<boolean | undefined>(undefined)
*/
export const threadStatesAtom = atom<Record<string, ThreadState>>({})
export const activeThreadStateAtom = atom<ThreadState | undefined>((get) => {
const activeConvoId = get(activeThreadIdAtom)
if (!activeConvoId) {
console.debug('Active convo id is undefined')
const threadId = get(activeThreadIdAtom)
if (!threadId) {
console.debug('Active thread id is undefined')
return undefined
}

return get(threadStatesAtom)[activeConvoId]
return get(threadStatesAtom)[threadId]
})

export const updateThreadWaitingForResponseAtom = atom(
null,
(get, set, conversationId: string, waitingForResponse: boolean) => {
(get, set, threadId: string, waitingForResponse: boolean) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = {
...currentState[conversationId],
currentState[threadId] = {
...currentState[threadId],
waitingForResponse,
error: undefined,
}
set(threadStatesAtom, currentState)
}
)
export const updateConversationErrorAtom = atom(
null,
(get, set, conversationId: string, error?: Error) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = {
...currentState[conversationId],
error,
}
set(threadStatesAtom, currentState)
}
)
export const updateConversationHasMoreAtom = atom(
null,
(get, set, conversationId: string, hasMore: boolean) => {
const currentState = { ...get(threadStatesAtom) }
currentState[conversationId] = { ...currentState[conversationId], hasMore }
set(threadStatesAtom, currentState)
}
)

export const updateThreadStateLastMessageAtom = atom(
null,
(get, set, threadId: string, lastContent?: ThreadContent[]) => {
Expand Down Expand Up @@ -100,3 +85,42 @@ export const threadsAtom = atom<Thread[]>([])
export const activeThreadAtom = atom<Thread | undefined>((get) =>
get(threadsAtom).find((c) => c.id === get(getActiveThreadIdAtom))
)

/**
* Store model params at thread level settings
*/
export const threadModelRuntimeParamsAtom = atom<
Record<string, ModelRuntimeParam>
>({})

export const getActiveThreadModelRuntimeParamsAtom = atom<
ModelRuntimeParam | undefined
>((get) => {
const threadId = get(activeThreadIdAtom)
if (!threadId) {
console.debug('Active thread id is undefined')
return undefined
}

return get(threadModelRuntimeParamsAtom)[threadId]
})

export const getThreadModelRuntimeParamsAtom = atom(
(get, threadId: string) => get(threadModelRuntimeParamsAtom)[threadId]
)

export const setThreadModelRuntimeParamsAtom = atom(
null,
(get, set, threadId: string, params: ModelRuntimeParam) => {
const currentState = { ...get(threadModelRuntimeParamsAtom) }
currentState[threadId] = params
console.debug(
`Update model params for thread ${threadId}, ${JSON.stringify(
params,
null,
2
)}`
)
set(threadModelRuntimeParamsAtom, currentState)
}
)
19 changes: 11 additions & 8 deletions web/hooks/useCreateNewThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ import {
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'

import { generateThreadId } from '@/utils/conversation'
import { generateThreadId } from '@/utils/thread'

import { extensionManager } from '@/extension'
import {
threadsAtom,
setActiveThreadIdAtom,
threadStatesAtom,
activeThreadAtom,
updateThreadAtom,
} from '@/helpers/atoms/Conversation.atom'
setThreadModelRuntimeParamsAtom,
} from '@/helpers/atoms/Thread.atom'

const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
// create thread state for this new thread
Expand All @@ -42,6 +42,10 @@ export const useCreateNewThread = () => {
const threads = useAtomValue(threadsAtom)
const updateThread = useSetAtom(updateThreadAtom)

const setThreadModelRuntimeParams = useSetAtom(
setThreadModelRuntimeParamsAtom
)

const requestCreateNewThread = async (assistant: Assistant) => {
const unfinishedThreads = threads.filter((t) => t.isFinishInit === false)
if (unfinishedThreads.length > 0) {
Expand Down Expand Up @@ -88,19 +92,18 @@ export const useCreateNewThread = () => {
lastMessage: undefined,
}
setThreadStates({ ...threadStates, [threadId]: threadState })
setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters)

// add the new thread on top of the thread list to the state
createNewThread(thread)
setActiveThreadId(thread.id)
}

function updateThreadMetadata(thread: Thread) {
const updatedThread: Thread = {
...thread,
}
updateThread(updatedThread)
updateThread(thread)
extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.saveThread(updatedThread)
?.saveThread(thread)
}

return {
Expand Down
30 changes: 14 additions & 16 deletions web/hooks/useDeleteConversation.ts → web/hooks/useDeleteThread.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,32 @@ import { useActiveModel } from './useActiveModel'
import { extensionManager } from '@/extension/ExtensionManager'

import {
cleanConversationMessages,
deleteConversationMessage,
cleanChatMessageAtom as cleanChatMessagesAtom,
deleteChatMessageAtom as deleteChatMessagesAtom,
getCurrentChatMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import {
threadsAtom,
setActiveThreadIdAtom,
} from '@/helpers/atoms/Conversation.atom'
import { threadsAtom, setActiveThreadIdAtom } from '@/helpers/atoms/Thread.atom'

export default function useDeleteThread() {
const { activeModel } = useActiveModel()
const [threads, setThreads] = useAtom(threadsAtom)
const setCurrentPrompt = useSetAtom(currentPromptAtom)
const messages = useAtomValue(getCurrentChatMessagesAtom)

const setActiveConvoId = useSetAtom(setActiveThreadIdAtom)
const deleteMessages = useSetAtom(deleteConversationMessage)
const cleanMessages = useSetAtom(cleanConversationMessages)
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
const deleteMessages = useSetAtom(deleteChatMessagesAtom)
const cleanMessages = useSetAtom(cleanChatMessagesAtom)

const cleanThread = async (threadId: string) => {
if (threadId) {
const thread = threads.filter((c) => c.id === threadId)[0]
cleanMessages(threadId)

const cleanThread = async (activeThreadId: string) => {
if (activeThreadId) {
const thread = threads.filter((c) => c.id === activeThreadId)[0]
cleanMessages(activeThreadId)
if (thread)
await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.writeMessages(
activeThreadId,
threadId,
messages.filter((msg) => msg.role === ChatCompletionRole.System)
)
}
Expand All @@ -62,9 +60,9 @@ export default function useDeleteThread() {
description: `Thread with ${activeModel?.name} has been successfully deleted.`,
})
if (availableThreads.length > 0) {
setActiveConvoId(availableThreads[0].id)
setActiveThreadId(availableThreads[0].id)
} else {
setActiveConvoId(undefined)
setActiveThreadId(undefined)
}
} catch (err) {
console.error(err)
Expand Down
34 changes: 24 additions & 10 deletions web/hooks/useGetAllThreads.ts
Original file line number Diff line number Diff line change
@@ -1,35 +1,49 @@
import { ExtensionType, ThreadState } from '@janhq/core'
import { ExtensionType, ModelRuntimeParam, ThreadState } from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useSetAtom } from 'jotai'

import { extensionManager } from '@/extension/ExtensionManager'
import {
threadModelRuntimeParamsAtom,
threadStatesAtom,
threadsAtom,
} from '@/helpers/atoms/Conversation.atom'
} from '@/helpers/atoms/Thread.atom'

const useGetAllThreads = () => {
const setConversationStates = useSetAtom(threadStatesAtom)
const setConversations = useSetAtom(threadsAtom)
const setThreadStates = useSetAtom(threadStatesAtom)
const setThreads = useSetAtom(threadsAtom)
const setThreadModelRuntimeParams = useSetAtom(threadModelRuntimeParamsAtom)

const getAllThreads = async () => {
try {
const threads = await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.getThreads()
const threads =
(await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.getThreads()) ?? []

const threadStates: Record<string, ThreadState> = {}
threads?.forEach((thread) => {
const threadModelParams: Record<string, ModelRuntimeParam> = {}

threads.forEach((thread) => {
if (thread.id != null) {
const lastMessage = (thread.metadata?.lastMessage as string) ?? ''

threadStates[thread.id] = {
hasMore: true,
waitingForResponse: false,
lastMessage,
}

// model params
const modelParams = thread.assistants?.[0]?.model?.parameters
threadModelParams[thread.id] = modelParams
}
})
setConversationStates(threadStates)
setConversations(threads ?? [])

// updating app states
setThreadStates(threadStates)
setThreads(threads)
setThreadModelRuntimeParams(threadModelParams)
} catch (error) {
console.error(error)
}
Expand Down
Loading

0 comments on commit a639d30

Please sign in to comment.