diff --git a/agenta-web/src/components/Playground/Views/TestView.tsx b/agenta-web/src/components/Playground/Views/TestView.tsx index 31e7fc4b77..0e63e18b99 100644 --- a/agenta-web/src/components/Playground/Views/TestView.tsx +++ b/agenta-web/src/components/Playground/Views/TestView.tsx @@ -1,6 +1,6 @@ -import React, {useContext, useEffect, useState} from "react" +import React, {useContext, useEffect, useRef, useState} from "react" import {Button, Input, Card, Row, Col, Space, Form} from "antd" -import {CaretRightOutlined, PlusOutlined} from "@ant-design/icons" +import {CaretRightOutlined, CloseCircleOutlined, PlusOutlined} from "@ant-design/icons" import {callVariant} from "@/lib/services/api" import {ChatMessage, ChatRole, GenericObject, Parameter, Variant} from "@/lib/Types" import {batchExecute, randString, removeKeys} from "@/lib/helpers/utils" @@ -109,6 +109,7 @@ interface BoxComponentProps { onDelete?: () => void isChatVariant?: boolean variant: Variant + onCancel: () => void } const BoxComponent: React.FC = ({ @@ -122,6 +123,7 @@ const BoxComponent: React.FC = ({ onDelete, isChatVariant = false, variant, + onCancel, }) => { const {appTheme} = useAppTheme() const classes = useStylesBox() @@ -208,17 +210,29 @@ const BoxComponent: React.FC = ({ disabled={loading || !result} shape="round" /> - + {loading ? ( + + ) : ( + + )} {!isChatVariant && ( @@ -276,6 +290,15 @@ const App: React.FC = ({ }> >(testList.map(() => ({cost: null, latency: null, usage: null}))) + const abortControllersRef = useRef([]) + const [isRunningAll, setIsRunningAll] = useState(false) + + useEffect(() => { + return () => { + abortControllersRef.current.forEach((controller) => controller.abort()) + } + }, []) + useEffect(() => { setResultsList((prevResultsList) => { const newResultsList = testList.map((_, index) => { @@ -326,25 +349,30 @@ const App: React.FC = ({ } const handleRun = async (index: number) => { + const controller = new AbortController() + abortControllersRef.current[index] = controller try { const testItem = testList[index] if (compareMode && !isRunning[index]) { - setIsRunning( - (prevState) => { - const newState = [...prevState] - newState[index] = true - return newState - }, - () => { - document - .querySelectorAll(`.testview-run-button-${testItem._id}`) - .forEach((btn) => { - if (btn.parentElement?.id !== variant.variantId) { - ;(btn as HTMLButtonElement).click() - } - }) - }, - ) + let called = false + const callback = () => { + if (called) return + called = true + document + .querySelectorAll(`.testview-run-button-${testItem._id}`) + .forEach((btn) => { + if (btn.parentElement?.id !== variant.variantId) { + ;(btn as HTMLButtonElement).click() + } + }) + } + + setIsRunning((prevState) => { + const newState = [...prevState] + newState[index] = true + return newState + }, callback) + setTimeout(callback, 300) } setResultForIndex(LOADING_TEXT, index) @@ -355,7 +383,10 @@ const App: React.FC = ({ appId || "", variant.baseId || "", isChatVariant ? testItem.chat : [], + controller.signal, + true, ) + // check if res is an object or string if (typeof res === "string") { setResultForIndex(res, index) @@ -368,7 +399,16 @@ const App: React.FC = ({ }) } } catch (e) { - setResultForIndex(`❌ ${getErrorMessage(e)}`, index) + if (!controller.signal.aborted) { + setResultForIndex(`❌ ${getErrorMessage(e)}`, index) + } else { + setResultForIndex("", index) + setAdditionalDataList((prev) => { + const newDataList = [...prev] + newDataList[index] = {cost: null, latency: null, usage: null} + return newDataList + }) + } } finally { setIsRunning((prevState) => { const newState = [...prevState] @@ -378,13 +418,38 @@ const App: React.FC = ({ } } - const handleRunAll = () => { + const handleCancel = (index: number) => { + if (abortControllersRef.current[index]) { + abortControllersRef.current[index].abort() + } + if (compareMode && isRunning[index]) { + const testItem = testList[index] + + document.querySelectorAll(`.testview-cancel-button-${testItem._id}`).forEach((btn) => { + if (btn.parentElement?.id !== variant.variantId) { + ;(btn as HTMLButtonElement).click() + } + }) + } + } + + const handleCancelAll = () => { + const funcs: Function[] = [] + rootRef.current + ?.querySelectorAll("[class*=testview-cancel-button-]") + .forEach((btn) => funcs.push(() => (btn as HTMLButtonElement).click())) + batchExecute(funcs) + } + + const handleRunAll = async () => { const funcs: Function[] = [] rootRef.current ?.querySelectorAll("[data-cy=testview-input-parameters-run-button]") .forEach((btn) => funcs.push(() => (btn as HTMLButtonElement).click())) - batchExecute(funcs) + setIsRunningAll(true) + await batchExecute(funcs) + setIsRunningAll(false) } const handleAddRow = () => { @@ -432,14 +497,25 @@ const App: React.FC = ({ - + {!isRunningAll ? ( + + ) : ( + + )} @@ -463,6 +539,7 @@ const App: React.FC = ({ onDelete={testList.length >= 2 ? () => handleDeleteRow(index) : undefined} isChatVariant={isChatVariant} variant={variant} + onCancel={() => handleCancel(index)} /> ))}