Skip to content

Commit

Permalink
Merge pull request #4159 from janhq/fix/correct-token-speed-calculating
Browse files Browse the repository at this point in the history
fix: token speed should not be calculated based on state updates
  • Loading branch information
louis-jan authored Nov 29, 2024
2 parents a60022c + 7b1a1be commit 0f834a6
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 31 deletions.
27 changes: 27 additions & 0 deletions web/containers/Providers/ModelHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import {
getCurrentChatMessagesAtom,
addNewMessageAtom,
updateMessageAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import {
Expand Down Expand Up @@ -62,6 +63,7 @@ export default function ModelHandler() {
const activeModelRef = useRef(activeModel)
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const activeModelParamsRef = useRef(activeModelParams)
const setTokenSpeed = useSetAtom(tokenSpeedAtom)

useEffect(() => {
threadsRef.current = threads
Expand Down Expand Up @@ -179,6 +181,31 @@ export default function ModelHandler() {
if (message.content.length) {
setIsGeneratingResponse(false)
}

setTokenSpeed((prev) => {
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
if (!prev) {
// If this is the first update, just set the lastTimestamp and return
return {
lastTimestamp: currentTimestamp,
tokenSpeed: 0,
tokenCount: 1,
message: message.id,
}
}

const timeDiffInSeconds =
(currentTimestamp - prev.lastTimestamp) / 1000 // Time difference in seconds
const totalTokenCount = prev.tokenCount + 1
const averageTokenSpeed =
totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed
return {
...prev,
tokenSpeed: averageTokenSpeed,
tokenCount: totalTokenCount,
message: message.id,
}
})
return
} else if (
message.status === MessageStatus.Error &&
Expand Down
9 changes: 9 additions & 0 deletions web/helpers/atoms/ChatMessage.atom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,22 @@ import {
updateThreadStateLastMessageAtom,
} from './Thread.atom'

import { TokenSpeed } from '@/types/token'

/**
* Stores all chat messages for all threads
*/
export const chatMessages = atom<Record<string, ThreadMessage[]>>({})

/**
* Stores the status of the messages load for each thread
*/
export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({})

/**
* Store the token speed for current message
*/
export const tokenSpeedAtom = atom<TokenSpeed | undefined>(undefined)
/**
* Return the chat messages for the current active conversation
*/
Expand Down
7 changes: 3 additions & 4 deletions web/hooks/useSendChatMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {
addNewMessageAtom,
deleteMessageAtom,
getCurrentChatMessagesAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
Expand All @@ -45,7 +46,6 @@ import {
updateThreadWaitingForResponseAtom,
} from '@/helpers/atoms/Thread.atom'

export const queuedMessageAtom = atom(false)
export const reloadModelAtom = atom(false)

export default function useSendChatMessage() {
Expand All @@ -70,7 +70,7 @@ export default function useSendChatMessage() {
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>()
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const setTokenSpeed = useSetAtom(tokenSpeedAtom)

const selectedModelRef = useRef<Model | undefined>()

Expand Down Expand Up @@ -147,6 +147,7 @@ export default function useSendChatMessage() {
}

if (engineParamsUpdate) setReloadModel(true)
setTokenSpeed(undefined)

const runtimeParams = extractInferenceParams(activeModelParams)
const settingParams = extractModelLoadParams(activeModelParams)
Expand Down Expand Up @@ -231,9 +232,7 @@ export default function useSendChatMessage() {
}

if (modelRef.current?.id !== modelId) {
setQueuedMessage(true)
const error = await startModel(modelId).catch((error: Error) => error)
setQueuedMessage(false)
if (error) {
updateThreadWaiting(activeThreadRef.current.id, false)
return
Expand Down
33 changes: 6 additions & 27 deletions web/screens/Thread/ThreadCenterPanel/SimpleTextMessage/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import { RelativeImage } from './RelativeImage'
import {
editMessageAtom,
getCurrentChatMessagesAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'

Expand Down Expand Up @@ -233,32 +234,9 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
}

const { onViewFile, onViewFileContainer } = usePath()
const [tokenCount, setTokenCount] = useState(0)
const [lastTimestamp, setLastTimestamp] = useState<number | undefined>()
const [tokenSpeed, setTokenSpeed] = useState(0)
const tokenSpeed = useAtomValue(tokenSpeedAtom)
const messages = useAtomValue(getCurrentChatMessagesAtom)

useEffect(() => {
if (props.status !== MessageStatus.Pending) {
return
}
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
if (!lastTimestamp) {
// If this is the first update, just set the lastTimestamp and return
if (props.content[0]?.text?.value !== '')
setLastTimestamp(currentTimestamp)
return
}

const timeDiffInSeconds = (currentTimestamp - lastTimestamp) / 1000 // Time difference in seconds
const totalTokenCount = tokenCount + 1
const averageTokenSpeed = totalTokenCount / timeDiffInSeconds // Calculate average token speed

setTokenSpeed(averageTokenSpeed)
setTokenCount(totalTokenCount)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [props.content])

return (
<div className="group relative mx-auto max-w-[700px] p-4">
<div
Expand Down Expand Up @@ -308,10 +286,11 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
>
<MessageToolbar message={props} />
</div>
{messages[messages.length - 1]?.id === props.id &&
(props.status === MessageStatus.Pending || tokenSpeed > 0) && (
{tokenSpeed &&
tokenSpeed.message === props.id &&
tokenSpeed.tokenSpeed > 0 && (
<p className="absolute right-8 text-xs font-medium text-[hsla(var(--text-secondary))]">
Token Speed: {Number(tokenSpeed).toFixed(2)}t/s
Token Speed: {Number(tokenSpeed.tokenSpeed).toFixed(2)}t/s
</p>
)}
</div>
Expand Down
6 changes: 6 additions & 0 deletions web/types/token.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
export type TokenSpeed = {
message: string
tokenSpeed: number
tokenCount: number
lastTimestamp: number
}

0 comments on commit 0f834a6

Please sign in to comment.