Skip to content

Commit

Permalink
turn by turn switch while add to teset for chat variants:
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammedMaaz committed Dec 9, 2023
1 parent f180e1c commit e2ab1ae
Showing 1 changed file with 174 additions and 62 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import AlertPopup from "@/components/AlertPopup/AlertPopup"
import {useAppTheme} from "../../Layout/ThemeContextProvider"
import {ChatMessage, GenericObject, testset} from "@/lib/Types"
import {ChatMessage, ChatRole, GenericObject, testset} from "@/lib/Types"
import {removeKeys, renameVariables} from "@/lib/helpers/utils"
import {createNewTestset, loadTestset, updateTestset, useLoadTestsetsList} from "@/lib/services/api"
import {Button, Divider, Drawer, Form, Input, Modal, Select, Typography, message} from "antd"
import {
Button,
Divider,
Drawer,
Form,
Input,
Modal,
Select,
Space,
Switch,
Typography,
message,
} from "antd"
import {useRouter} from "next/router"
import React, {useCallback, useRef, useState} from "react"
import React, {useCallback, useLayoutEffect, useRef, useState} from "react"
import {createUseStyles} from "react-jss"
import {useUpdateEffect} from "usehooks-ts"
import ChatInputs from "@/components/ChatInputs/ChatInputs"
Expand All @@ -32,9 +44,43 @@ const useStyles = createUseStyles({
display: "flex",
flexDirection: "column",
gap: "0.75rem",
marginBottom: "1rem",
},
})

function flatToTurn({
chat,
correct_answer,
}: {
chat?: ChatMessage[]
correct_answer?: ChatMessage | string
}) {
const flatChat = _.cloneDeep(chat || [])
if (correct_answer && typeof correct_answer !== "string")
flatChat.push(_.cloneDeep(correct_answer))

const turns: {chat: ChatMessage[]; correct_answer: ChatMessage}[] = []
let currentTurn: ChatMessage[] = []
flatChat.forEach((item) => {
if (item.role !== ChatRole.User) {
turns.push({
chat: _.clone(currentTurn || []),
correct_answer: item,
})
}
currentTurn.push(item)
})
return turns
}

function turnToFlat(turns: {chat: ChatMessage[]; correct_answer: ChatMessage}[]) {
const flat = _.cloneDeep(turns.at(-1))
return {
chat: flat?.chat || [],
correct_answer: flat?.correct_answer || "",
}
}

type Props = React.ComponentProps<typeof Drawer> & {
params: GenericObject
isChatVariant: boolean
Expand All @@ -47,15 +93,19 @@ const AddToTestSetDrawer: React.FC<Props> = ({params, isChatVariant, ...props})
const [selectedTestset, setSelectedTestset] = useState<string>()
const [newTesetModalOpen, setNewTestsetModalOpen] = useState(false)
const [loading, setLoading] = useState(false)
const [turnModeChat, setTurnModeChat] = useState<
{chat: ChatMessage[]; correct_answer: ChatMessage}[] | null
>(null)
const [shouldRender, setShouldRender] = useState(false)
const dirty = useRef(false)
const router = useRouter()
const appId = router.query.app_id as string
const isNew = selectedTestset === "-1"

const {testsets, mutate, isTestsetsLoading, isTestsetsLoadingError} = useLoadTestsetsList(appId)
const chatParams = useRef({
chat: params.chat || [],
correct_answer: params.correct_answer || "",
const chatParams = useRef<{chat: ChatMessage[]; correct_answer: ChatMessage | string}>({
chat: [],
correct_answer: "",
}).current

// reset the form to load latest initialValues on drawer open
Expand All @@ -65,9 +115,14 @@ const AddToTestSetDrawer: React.FC<Props> = ({params, isChatVariant, ...props})

//reset to defaults
form.resetFields()
chatParams.chat = params.chat || []
chatParams.correct_answer = params.correct_answer || ""
} else dirty.current = false
chatParams.chat = _.cloneDeep(params.chat || [])
chatParams.correct_answer = _.cloneDeep(params.correct_answer || "")
setTurnModeChat(null)
setShouldRender(true)
} else {
dirty.current = false
setShouldRender(false)
}
}, [props.open])

const onClose = useCallback(() => {
Expand All @@ -84,27 +139,30 @@ const AddToTestSetDrawer: React.FC<Props> = ({params, isChatVariant, ...props})
}, [props.onClose])

const addToTestSet = useCallback(
(name: string, csvdata: Record<string, string>[], rowData: GenericObject) => {
rowData = {...rowData}
if (isChatVariant) {
rowData.chat = JSON.stringify(
rowData.chat.map((item: ChatMessage) => removeKeys(item, ["id"])),
)
rowData.correct_answer = JSON.stringify(removeKeys(rowData.correct_answer, ["id"]))
}
(name: string, csvdata: Record<string, string>[], rows: GenericObject[]) => {
const newRows = rows.map((item) => {
const row = {...item}
if (isChatVariant) {
row.chat = JSON.stringify(
row.chat.map((item: ChatMessage) => removeKeys(item, ["id"])),
)
row.correct_answer = JSON.stringify(removeKeys(row.correct_answer, ["id"]))
}

setLoading(true)
setLoading(true)

const newRow: (typeof csvdata)[0] = {}
if (!isNew) {
Object.keys(csvdata?.[0] || {}).forEach((col) => {
newRow[col] = rowData[col] || ""
})
}
const newRow: (typeof csvdata)[0] = {}
if (!isNew) {
Object.keys(csvdata?.[0] || {}).forEach((col) => {
newRow[col] = row[col] || ""
})
}
return isNew ? row : newRow
})

const promise = isNew
? createNewTestset(appId, name, [rowData])
: updateTestset(selectedTestset!, name, [...csvdata, newRow])
? createNewTestset(appId, name, newRows)
: updateTestset(selectedTestset!, name, [...csvdata, ...newRows])
promise
.then(() => {
message.success(`Row added to the "${name}" test set!`)
Expand All @@ -116,13 +174,13 @@ const AddToTestSetDrawer: React.FC<Props> = ({params, isChatVariant, ...props})
)

const onFinish = useCallback(
(values: any) => {
(values: GenericObject[]) => {
if (isNew) {
setNewTestsetModalOpen(true)
} else {
loadTestset(selectedTestset!).then((data) => {
const testsetCols = Object.keys(data.csvdata?.[0] || {})
const playgroundCols = Object.keys(values)
const playgroundCols = Object.keys(values[0])
const missingColsTestset = testsetCols.filter(
(col) => !playgroundCols.includes(col),
)
Expand Down Expand Up @@ -187,6 +245,22 @@ const AddToTestSetDrawer: React.FC<Props> = ({params, isChatVariant, ...props})
size="large"
footer={
<div className={classes.footer}>
{isChatVariant && (
<Space align="center">
<Typography.Text>Turn by Turn:</Typography.Text>
<Switch
checked={Array.isArray(turnModeChat)}
onChange={(checked) => {
setTurnModeChat(checked ? flatToTurn(chatParams) : null)
if (!checked && Array.isArray(turnModeChat)) {
const {chat, correct_answer} = turnToFlat(turnModeChat)
chatParams.chat = chat
chatParams.correct_answer = correct_answer
}
}}
/>
</Space>
)}
<Select
placeholder="Select Test Set"
className={classes.selector}
Expand All @@ -211,7 +285,9 @@ const AddToTestSetDrawer: React.FC<Props> = ({params, isChatVariant, ...props})
onClick={
isChatVariant
? () => {
onFinish(chatParams)
onFinish(
Array.isArray(turnModeChat) ? turnModeChat : [chatParams],
)
}
: form.submit
}
Expand All @@ -223,47 +299,75 @@ const AddToTestSetDrawer: React.FC<Props> = ({params, isChatVariant, ...props})
{...props}
onClose={onClose}
>
{isChatVariant ? (
<div>
<div className={classes.chatContainer}>
<Typography.Text strong>Chat</Typography.Text>
<ChatInputs
defaultValue={
params.chat?.length ? _.cloneDeep(params.chat) : undefined
}
onChange={(val) => {
chatParams.chat = val
dirty.current = true
}}
/>
</div>
{!shouldRender ? null : isChatVariant ? (
Array.isArray(turnModeChat) ? (
turnModeChat.map((turn, index) => (
<div key={index}>
<div className={classes.chatContainer}>
<Typography.Text strong>Chat</Typography.Text>
<ChatInputs
defaultValue={turn.chat}
onChange={(val) => {
turn.chat = val
dirty.current = true
}}
/>
</div>

<div className={classes.chatContainer}>
<Typography.Text strong>Correct Answer</Typography.Text>
<ChatInputs
defaultValue={[turn.correct_answer]}
onChange={(val) => {
turn.correct_answer = val[0]
dirty.current = true
}}
disableAdd
disableRemove
/>
</div>

<Divider />
<Divider />
</div>
))
) : (
<div>
<div className={classes.chatContainer}>
<Typography.Text strong>Chat</Typography.Text>
<ChatInputs
defaultValue={chatParams.chat}
onChange={(val) => {
chatParams.chat = val
dirty.current = true
}}
/>
</div>

<div className={classes.chatContainer}>
<Typography.Text strong>Correct Answer</Typography.Text>
<ChatInputs
defaultValue={
params.correct_answer
? [_.cloneDeep(params.correct_answer)]
: undefined
}
onChange={(val) => {
chatParams.correct_answer = val[0]
dirty.current = true
}}
disableAdd
disableRemove
/>
<div className={classes.chatContainer}>
<Typography.Text strong>Correct Answer</Typography.Text>
<ChatInputs
defaultValue={
chatParams.correct_answer
? [chatParams.correct_answer as ChatMessage]
: undefined
}
onChange={(val) => {
chatParams.correct_answer = val[0]
dirty.current = true
}}
disableAdd
disableRemove
/>
</div>
</div>
</div>
)
) : (
<Form
onValuesChange={() => (dirty.current = true)}
form={form}
initialValues={params}
layout="vertical"
onFinish={onFinish}
onFinish={(values) => onFinish([values])}
>
{Object.keys(params).map((name) => (
<Form.Item key={name} label={renameVariables(name)} name={name}>
Expand All @@ -277,7 +381,15 @@ const AddToTestSetDrawer: React.FC<Props> = ({params, isChatVariant, ...props})
onCancel={() => setNewTestsetModalOpen(false)}
destroyOnClose
onSubmit={(name) =>
addToTestSet(name, [], isChatVariant ? chatParams : form.getFieldsValue())
addToTestSet(
name,
[],
isChatVariant
? Array.isArray(turnModeChat)
? turnModeChat
: [chatParams]
: [form.getFieldsValue()],
)
}
/>
</Drawer>
Expand Down

0 comments on commit e2ab1ae

Please sign in to comment.