diff --git a/agenta-web/src/components/pages/evaluations/evaluationCompare/EvaluationCompare.tsx b/agenta-web/src/components/pages/evaluations/evaluationCompare/EvaluationCompare.tsx index bfa922fe01..050c6d091d 100644 --- a/agenta-web/src/components/pages/evaluations/evaluationCompare/EvaluationCompare.tsx +++ b/agenta-web/src/components/pages/evaluations/evaluationCompare/EvaluationCompare.tsx @@ -11,12 +11,18 @@ import { import {fetchAllComparisonResults} from "@/services/evaluations" import {ColDef} from "ag-grid-community" import {AgGridReact} from "ag-grid-react" -import {Button, Space, Spin, Switch, Tag, Tooltip, Typography} from "antd" +import {Button, Dropdown, DropdownProps, Space, Spin, Switch, Tag, Tooltip, Typography} from "antd" import React, {useEffect, useMemo, useRef, useState} from "react" import {createUseStyles} from "react-jss" import {getFilterParams, getTypedValue} from "@/lib/helpers/evaluate" import {getColorFromStr, getRandomColors} from "@/lib/helpers/colors" -import {CloseCircleOutlined, DownloadOutlined, UndoOutlined} from "@ant-design/icons" +import { + CheckOutlined, + CloseCircleOutlined, + DownOutlined, + DownloadOutlined, + UndoOutlined, +} from "@ant-design/icons" import {getAppValues} from "@/contexts/app.context" import {useQueryParam} from "@/hooks/useQuery" import {LongTextCellRenderer} from "../cellRenderers/cellRenderers" @@ -49,6 +55,21 @@ const useStyles = createUseStyles((theme: JSSTheme) => ({ }, }, }, + dropdownMenu: { + "&>.ant-dropdown-menu-item": { + "& .anticon-check": { + display: "none", + }, + }, + "&>.ant-dropdown-menu-item-selected": { + "&:not(:hover)": { + backgroundColor: "transparent !important", + }, + "& .anticon-check": { + display: "inline-flex !important", + }, + }, + }, })) interface Props {} @@ -58,7 +79,8 @@ const EvaluationCompareMode: React.FC = () => { const classes = useStyles() const {appTheme} = useAppTheme() const [evaluationIdsStr = ""] = useQueryParam("evaluations") - const [evalIds, setEvalIds] = useState(evaluationIdsStr.split(",").filter((item) => !!item)) + const evaluationIdsArray = evaluationIdsStr.split(",").filter((item) => !!item) + const [evalIds, setEvalIds] = useState(evaluationIdsArray) const [hiddenVariants, setHiddenVariants] = useState([]) const [showDiff, setShowDiff] = useLocalStorage("showDiff", "show") const [fetching, setFetching] = useState(false) @@ -66,6 +88,13 @@ const EvaluationCompareMode: React.FC = () => { const [testset, setTestset] = useState() const [evaluators] = useAtom(evaluatorsAtom) const gridRef = useRef>() + const [filterColsDropdown, setFilterColsDropdown] = useState(false) + + const handleOpenChange: DropdownProps["onOpenChange"] = (nextOpen, info) => { + if (info.source === "trigger" || nextOpen) { + setFilterColsDropdown(nextOpen) + } + } const variants = useMemo(() => { return rows[0]?.variants || [] @@ -116,8 +145,6 @@ const EvaluationCompareMode: React.FC = () => { }) variants.forEach((variant, vi) => { - const isHidden = evalIds.includes(variant.evaluationId) - colDefs.push({ headerComponent: (props: any) => ( @@ -127,11 +154,12 @@ const EvaluationCompareMode: React.FC = () => { ), + headerName: "Output", minWidth: 280, flex: 1, field: `variants.${vi}.output` as any, ...getFilterParams("text"), - hide: !isHidden, + hide: !evalIds.includes(variant.evaluationId) || hiddenVariants.includes("Output"), cellRenderer: (params: any) => { return ( <> @@ -201,11 +229,9 @@ const EvaluationCompareMode: React.FC = () => { Object.entries(confgisMap).forEach(([_, configs]) => { configs.forEach(({config, variant, color}) => { - const isHidden = evalIds.includes(variant.evaluationId) colDefs.push({ flex: 1, minWidth: 200, - headerName: config.name, headerComponent: (props: any) => { const evaluator = evaluators.find( (item) => item.key === config.evaluator_key, @@ -222,9 +248,12 @@ const EvaluationCompareMode: React.FC = () => { ) }, + headerName: config.name, field: "variants.0.evaluatorConfigs.0.result" as any, ...getFilterParams("text"), - hide: !isHidden, + hide: + !evalIds.includes(variant.evaluationId) || + hiddenVariants.includes(config.name), valueGetter: (params) => { return getTypedValue( params.data?.variants @@ -248,7 +277,9 @@ const EvaluationCompareMode: React.FC = () => { ), + hide: !evalIds.includes(variant.evaluationId) || hiddenVariants.includes("Latency"), minWidth: 120, + headerName: "Latency", flex: 1, valueGetter: (params) => { const latency = params.data?.variants.find( @@ -270,7 +301,9 @@ const EvaluationCompareMode: React.FC = () => { ), + headerName: "Cost", minWidth: 120, + hide: !evalIds.includes(variant.evaluationId) || hiddenVariants.includes("Cost"), flex: 1, valueGetter: (params) => { const cost = params.data?.variants.find( @@ -283,7 +316,7 @@ const EvaluationCompareMode: React.FC = () => { }) return colDefs - }, [rows, showDiff, evalIds]) + }, [rows, showDiff, hiddenVariants, evalIds]) const fetcher = () => { setFetching(true) @@ -311,15 +344,25 @@ const EvaluationCompareMode: React.FC = () => { }, [appId, evaluationIdsStr]) const handleToggleVariantVisibility = (evalId: string) => { - if (hiddenVariants.includes(evalId)) { - setHiddenVariants(hiddenVariants.filter((item) => item !== evalId)) - setEvalIds([...evalIds, evalId]) - } else { + if (!hiddenVariants.includes(evalId)) { setHiddenVariants([...hiddenVariants, evalId]) setEvalIds(evalIds.filter((val) => val !== evalId)) + } else { + setHiddenVariants(hiddenVariants.filter((item) => item !== evalId)) + if (evaluationIdsArray.includes(evalId)) { + setEvalIds([...evalIds, evalId]) + } } } + const shownCols = useMemo( + () => + colDefs + .map((item) => item.headerName) + .filter((item) => item !== undefined && !hiddenVariants.includes(item)) as string[], + [colDefs], + ) + const onExport = () => { if (!gridRef.current) return const {currentApp} = getAppValues() @@ -399,6 +442,44 @@ const EvaluationCompareMode: React.FC = () => { onClick={() => setShowDiff(showDiff === "show" ? "hide" : "show")} /> + + !item.headerName?.startsWith("Input") && + !item.headerName?.includes("Expected Output"), + ) + .reduce((acc, curr) => { + if (curr.headerName && !acc.includes(curr.headerName)) { + acc.push(curr.headerName) + } + return acc + }, [] as string[]) + .map((configs) => ({ + key: configs as string, + label: ( + + + <>{configs} + + ), + })), + onClick: ({key}) => { + handleToggleVariantVisibility(key) + setFilterColsDropdown(true) + }, + className: classes.dropdownMenu, + }} + > + + +