diff --git a/agenta-web/cypress/tsconfig.json b/agenta-web/cypress/tsconfig.json index dc618360a0..1235d8c20f 100644 --- a/agenta-web/cypress/tsconfig.json +++ b/agenta-web/cypress/tsconfig.json @@ -2,7 +2,7 @@ "compilerOptions": { "target": "es5", "lib": ["es5", "dom"], - "types": ["cypress", "node"] + "types": ["cypress", "node"], }, - "include": ["**/*.ts"] + "include": ["**/*.ts"], } diff --git a/agenta-web/src/components/EvaluationTable/ABTestingEvaluationTable.tsx b/agenta-web/src/components/EvaluationTable/ABTestingEvaluationTable.tsx index b1610a4aa9..e05d6b8965 100644 --- a/agenta-web/src/components/EvaluationTable/ABTestingEvaluationTable.tsx +++ b/agenta-web/src/components/EvaluationTable/ABTestingEvaluationTable.tsx @@ -1,6 +1,18 @@ -import {useState, useEffect} from "react" +import {useState, useEffect, useCallback} from "react" import type {ColumnType} from "antd/es/table" -import {Button, Card, Col, Radio, Row, Space, Statistic, Table, Typography, message} from "antd" +import { + Button, + Card, + Col, + Input, + Radio, + Row, + Space, + Statistic, + Table, + Typography, + message, +} from "antd" import { updateEvaluationScenario, callVariant, @@ -22,6 +34,7 @@ import EvaluationVotePanel from "../Evaluations/EvaluationCardView/EvaluationVot import VariantAlphabet from "../Evaluations/EvaluationCardView/VariantAlphabet" import {ParamsFormWithRun} from "./SingleModelEvaluationTable" import {PassThrough} from "stream" +import {debounce} from "lodash" const {Title} = Typography @@ -90,6 +103,22 @@ const useStyles = createUseStyles({ top: 36, zIndex: 1, }, + sideBar: { + marginTop: "1rem", + display: "flex", + flexDirection: "column", + gap: "2rem", + border: "1px solid #d9d9d9", + borderRadius: 6, + padding: "1rem", + alignSelf: "flex-start", + "&>h4.ant-typography": { + margin: 0, + }, + flex: 0.35, + minWidth: 240, + maxWidth: 500, + }, }) const ABTestingEvaluationTable: React.FC = ({ @@ -117,6 +146,13 @@ const ABTestingEvaluationTable: React.FC = ({ evaluationResults?.votes_data?.variants_votes_data?.[evaluation.variants[1]?.variantId] ?.number_of_votes || 0 + const depouncedUpdateEvaluationScenario = useCallback( + debounce((data: Partial, scenarioId) => { + updateEvaluationScenarioData(scenarioId, data) + }, 800), + [evaluationScenarios], + ) + useEffect(() => { if (evaluationScenarios) { const obj = [...evaluationScenarios] @@ -297,7 +333,7 @@ const ABTestingEvaluationTable: React.FC = ({ ), dataIndex: columnKey, key: columnKey, - width: "25%", + width: "20%", render: (text: any, record: ABTestingEvaluationTableRow, rowIndex: number) => { if (text) return text if (record.outputs && record.outputs.length > 0) { @@ -345,29 +381,74 @@ const ABTestingEvaluationTable: React.FC = ({ }, }, { - key: "correctAnswer", title: "Expected Output", - dataIndex: "correctAnswer", + dataIndex: "expectedOutput", + key: "expectedOutput", width: "25%", + render: (text: any, record: any, rowIndex: number) => { + let correctAnswer = + record.correctAnswer || evaluation.testset.csvdata[rowIndex].correct_answer + + return ( + <> + + depouncedUpdateEvaluationScenario( + { + correctAnswer: e.target.value, + }, + record.id, + ) + } + key={record.id} + /> + + ) + }, }, ...dynamicColumns, { - title: "Evaluate", - dataIndex: "evaluate", - key: "evaluate", - width: 200, - // fixed: 'right', - render: (text: any, record: any, rowIndex: number) => ( - handleVoteClick(record.id, vote)} - loading={record.vote === "loading"} - vertical - key={record.id} - /> - ), + title: "Score", + dataIndex: "score", + key: "score", + render: (text: any, record: any, rowIndex: number) => { + return ( + <> + { + handleVoteClick(record.id, vote)} + loading={record.vote === "loading"} + vertical + key={record.id} + /> + } + + ) + }, + }, + { + title: "Additional Note", + dataIndex: "additionalNote", + key: "additionalNote", + render: (text: any, record: any, rowIndex: number) => { + return ( + <> + + depouncedUpdateEvaluationScenario({note: e.target.value}, record.id) + } + key={record.id} + /> + + ) + }, }, ] diff --git a/agenta-web/src/components/EvaluationTable/SingleModelEvaluationTable.tsx b/agenta-web/src/components/EvaluationTable/SingleModelEvaluationTable.tsx index 8b731dad74..ae12c869fe 100644 --- a/agenta-web/src/components/EvaluationTable/SingleModelEvaluationTable.tsx +++ b/agenta-web/src/components/EvaluationTable/SingleModelEvaluationTable.tsx @@ -1,4 +1,4 @@ -import {useState, useEffect, useCallback} from "react" +import {useState, useEffect, useCallback, useMemo} from "react" import type {ColumnType} from "antd/es/table" import {CaretRightOutlined} from "@ant-design/icons" import { @@ -6,6 +6,7 @@ import { Card, Col, Form, + Input, Radio, Row, Space, @@ -98,6 +99,22 @@ const useStyles = createUseStyles({ top: 36, zIndex: 1, }, + sideBar: { + marginTop: "1rem", + display: "flex", + flexDirection: "column", + gap: "2rem", + border: "1px solid #d9d9d9", + borderRadius: 6, + padding: "1rem", + alignSelf: "flex-start", + "&>h4.ant-typography": { + margin: 0, + }, + flex: 0.35, + minWidth: 240, + maxWidth: 500, + }, }) export const ParamsFormWithRun = ({ @@ -166,6 +183,13 @@ const SingleModelEvaluationTable: React.FC = ({ const [viewMode, setViewMode] = useQueryParam("viewMode", "card") const [accuracy, setAccuracy] = useState(0) + const depouncedUpdateEvaluationScenario = useCallback( + debounce((data: Partial, scenarioId) => { + updateEvaluationScenarioData(scenarioId, data) + }, 800), + [evaluationScenarios], + ) + useEffect(() => { if (evaluationScenarios) { const obj = [...evaluationScenarios] @@ -403,36 +427,79 @@ const SingleModelEvaluationTable: React.FC = ({ }, }, { - key: "correctAnswer", title: "Expected Output", - dataIndex: "correctAnswer", + dataIndex: "expectedOutput", + key: "expectedOutput", width: "25%", + render: (text: any, record: any, rowIndex: number) => { + let correctAnswer = + record.correctAnswer || evaluation.testset.csvdata[rowIndex].correct_answer + + return ( + <> + + depouncedUpdateEvaluationScenario( + { + correctAnswer: e.target.value, + }, + record.id, + ) + } + key={record.id} + /> + + ) + }, }, ...dynamicColumns, { - title: "Evaluate", - dataIndex: "evaluate", - key: "evaluate", - width: 200, - // fixed: 'right', + title: "Score", + dataIndex: "score", + key: "score", render: (text: any, record: any, rowIndex: number) => { return ( - - depouncedHandleScoreChange(record.id, val[0].score as number) + <> + { + + depouncedHandleScoreChange(record.id, val[0].score as number) + } + loading={record.score === "loading"} + showVariantName={false} + key={record.id} + /> } - loading={record.score === "loading"} - showVariantName={false} - key={record.id} - /> + + ) + }, + }, + { + title: "Additional Note", + dataIndex: "additionalNote", + key: "additionalNote", + render: (text: any, record: any, rowIndex: number) => { + return ( + <> + + depouncedUpdateEvaluationScenario({note: e.target.value}, record.id) + } + key={record.id} + /> + ) }, }, diff --git a/agenta-web/src/lib/helpers/evaluate.ts b/agenta-web/src/lib/helpers/evaluate.ts index d5902915f9..934ee400d1 100644 --- a/agenta-web/src/lib/helpers/evaluate.ts +++ b/agenta-web/src/lib/helpers/evaluate.ts @@ -99,7 +99,7 @@ export const exportABTestingEvaluationData = ( ["Vote"]: evaluation.variants.find((v: Variant) => v.variantId === data.vote)?.variantName || data.vote, - ["Expected answer"]: + ["Expected Output"]: scenarios[ix]?.correctAnswer || evaluation.testset.csvdata[ix].correct_answer, ["Additional notes"]: scenarios[ix]?.note, } @@ -133,7 +133,7 @@ export const exportSingleModelEvaluationData = ( ? data?.columnData0 : data.outputs[0]?.variant_output, ["Score"]: isNaN(numericScore) ? "-" : numericScore, - ["Expected answer"]: + ["Expected Output"]: scenarios[ix]?.correctAnswer || evaluation.testset.csvdata[ix].correct_answer, ["Additional notes"]: scenarios[ix]?.note, } diff --git a/agenta-web/tsconfig.json b/agenta-web/tsconfig.json index 73b324e040..db62be51b5 100644 --- a/agenta-web/tsconfig.json +++ b/agenta-web/tsconfig.json @@ -15,9 +15,9 @@ "jsx": "preserve", "incremental": true, "paths": { - "@/*": ["./src/*"] - } + "@/*": ["./src/*"], + }, }, "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx"], - "exclude": ["node_modules", "cypress.config.ts", "cypress/**/*"] + "exclude": ["node_modules", "cypress.config.ts", "cypress/**/*"], }