diff --git a/agenta-web/src/components/Playground/AddToTestSetDrawer/AddToTestSetDrawer.tsx b/agenta-web/src/components/Playground/AddToTestSetDrawer/AddToTestSetDrawer.tsx index 740be0e668..6a03a069a8 100644 --- a/agenta-web/src/components/Playground/AddToTestSetDrawer/AddToTestSetDrawer.tsx +++ b/agenta-web/src/components/Playground/AddToTestSetDrawer/AddToTestSetDrawer.tsx @@ -1,11 +1,23 @@ import AlertPopup from "@/components/AlertPopup/AlertPopup" import {useAppTheme} from "../../Layout/ThemeContextProvider" -import {ChatMessage, GenericObject, testset} from "@/lib/Types" +import {ChatMessage, ChatRole, GenericObject, testset} from "@/lib/Types" import {removeKeys, renameVariables} from "@/lib/helpers/utils" import {createNewTestset, loadTestset, updateTestset, useLoadTestsetsList} from "@/lib/services/api" -import {Button, Divider, Drawer, Form, Input, Modal, Select, Typography, message} from "antd" +import { + Button, + Divider, + Drawer, + Form, + Input, + Modal, + Select, + Space, + Switch, + Typography, + message, +} from "antd" import {useRouter} from "next/router" -import React, {useCallback, useRef, useState} from "react" +import React, {useCallback, useLayoutEffect, useRef, useState} from "react" import {createUseStyles} from "react-jss" import {useUpdateEffect} from "usehooks-ts" import ChatInputs from "@/components/ChatInputs/ChatInputs" @@ -32,9 +44,43 @@ const useStyles = createUseStyles({ display: "flex", flexDirection: "column", gap: "0.75rem", + marginBottom: "1rem", }, }) +function flatToTurn({ + chat, + correct_answer, +}: { + chat?: ChatMessage[] + correct_answer?: ChatMessage | string +}) { + const flatChat = _.cloneDeep(chat || []) + if (correct_answer && typeof correct_answer !== "string") + flatChat.push(_.cloneDeep(correct_answer)) + + const turns: {chat: ChatMessage[]; correct_answer: ChatMessage}[] = [] + let currentTurn: ChatMessage[] = [] + flatChat.forEach((item) => { + if (item.role !== ChatRole.User) { + turns.push({ + chat: _.clone(currentTurn || []), + correct_answer: item, + }) + } + currentTurn.push(item) + }) + return turns +} + +function turnToFlat(turns: {chat: ChatMessage[]; correct_answer: ChatMessage}[]) { + const flat = _.cloneDeep(turns.at(-1)) + return { + chat: flat?.chat || [], + correct_answer: flat?.correct_answer || "", + } +} + type Props = React.ComponentProps & { params: GenericObject isChatVariant: boolean @@ -47,15 +93,19 @@ const AddToTestSetDrawer: React.FC = ({params, isChatVariant, ...props}) const [selectedTestset, setSelectedTestset] = useState() const [newTesetModalOpen, setNewTestsetModalOpen] = useState(false) const [loading, setLoading] = useState(false) + const [turnModeChat, setTurnModeChat] = useState< + {chat: ChatMessage[]; correct_answer: ChatMessage}[] | null + >(null) + const [shouldRender, setShouldRender] = useState(false) const dirty = useRef(false) const router = useRouter() const appId = router.query.app_id as string const isNew = selectedTestset === "-1" const {testsets, mutate, isTestsetsLoading, isTestsetsLoadingError} = useLoadTestsetsList(appId) - const chatParams = useRef({ - chat: params.chat || [], - correct_answer: params.correct_answer || "", + const chatParams = useRef<{chat: ChatMessage[]; correct_answer: ChatMessage | string}>({ + chat: [], + correct_answer: "", }).current // reset the form to load latest initialValues on drawer open @@ -65,9 +115,14 @@ const AddToTestSetDrawer: React.FC = ({params, isChatVariant, ...props}) //reset to defaults form.resetFields() - chatParams.chat = params.chat || [] - chatParams.correct_answer = params.correct_answer || "" - } else dirty.current = false + chatParams.chat = _.cloneDeep(params.chat || []) + chatParams.correct_answer = _.cloneDeep(params.correct_answer || "") + setTurnModeChat(null) + setShouldRender(true) + } else { + dirty.current = false + setShouldRender(false) + } }, [props.open]) const onClose = useCallback(() => { @@ -84,27 +139,30 @@ const AddToTestSetDrawer: React.FC = ({params, isChatVariant, ...props}) }, [props.onClose]) const addToTestSet = useCallback( - (name: string, csvdata: Record[], rowData: GenericObject) => { - rowData = {...rowData} - if (isChatVariant) { - rowData.chat = JSON.stringify( - rowData.chat.map((item: ChatMessage) => removeKeys(item, ["id"])), - ) - rowData.correct_answer = JSON.stringify(removeKeys(rowData.correct_answer, ["id"])) - } + (name: string, csvdata: Record[], rows: GenericObject[]) => { + const newRows = rows.map((item) => { + const row = {...item} + if (isChatVariant) { + row.chat = JSON.stringify( + row.chat.map((item: ChatMessage) => removeKeys(item, ["id"])), + ) + row.correct_answer = JSON.stringify(removeKeys(row.correct_answer, ["id"])) + } - setLoading(true) + setLoading(true) - const newRow: (typeof csvdata)[0] = {} - if (!isNew) { - Object.keys(csvdata?.[0] || {}).forEach((col) => { - newRow[col] = rowData[col] || "" - }) - } + const newRow: (typeof csvdata)[0] = {} + if (!isNew) { + Object.keys(csvdata?.[0] || {}).forEach((col) => { + newRow[col] = row[col] || "" + }) + } + return isNew ? row : newRow + }) const promise = isNew - ? createNewTestset(appId, name, [rowData]) - : updateTestset(selectedTestset!, name, [...csvdata, newRow]) + ? createNewTestset(appId, name, newRows) + : updateTestset(selectedTestset!, name, [...csvdata, ...newRows]) promise .then(() => { message.success(`Row added to the "${name}" test set!`) @@ -116,13 +174,13 @@ const AddToTestSetDrawer: React.FC = ({params, isChatVariant, ...props}) ) const onFinish = useCallback( - (values: any) => { + (values: GenericObject[]) => { if (isNew) { setNewTestsetModalOpen(true) } else { loadTestset(selectedTestset!).then((data) => { const testsetCols = Object.keys(data.csvdata?.[0] || {}) - const playgroundCols = Object.keys(values) + const playgroundCols = Object.keys(values[0]) const missingColsTestset = testsetCols.filter( (col) => !playgroundCols.includes(col), ) @@ -187,6 +245,22 @@ const AddToTestSetDrawer: React.FC = ({params, isChatVariant, ...props}) size="large" footer={
+ {isChatVariant && ( + + Turn by Turn: + { + setTurnModeChat(checked ? flatToTurn(chatParams) : null) + if (!checked && Array.isArray(turnModeChat)) { + const {chat, correct_answer} = turnToFlat(turnModeChat) + chatParams.chat = chat + chatParams.correct_answer = correct_answer + } + }} + /> + + )}