From c52db12181ce70dc4f4c46aba8fdb86efc54ac04 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Thu, 14 Nov 2024 11:16:48 -0800 Subject: [PATCH] feat: add code operations to the ui (#169) --- website/package-lock.json | 40 +++ website/package.json | 1 + website/src/app/types.ts | 5 +- website/src/components/AIChatPanel.tsx | 8 +- website/src/components/PipelineGui.tsx | 35 +++ website/src/components/ResizableDataTable.tsx | 238 ++++++++++++------ website/src/components/operations/args.tsx | 107 ++++++++ .../src/components/operations/components.tsx | 188 +++++++++++++- 8 files changed, 541 insertions(+), 81 deletions(-) diff --git a/website/package-lock.json b/website/package-lock.json index 61ac8ede..87950fde 100644 --- a/website/package-lock.json +++ b/website/package-lock.json @@ -11,6 +11,7 @@ "@agbishop/react-ansi-18": "^4.0.6", "@ai-sdk/openai": "^0.0.70", "@hookform/resolvers": "^3.9.0", + "@monaco-editor/react": "^4.6.0", "@next/third-parties": "^14.2.11", "@radix-ui/react-accordion": "^1.2.0", "@radix-ui/react-alert-dialog": "^1.1.2", @@ -1366,6 +1367,32 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + "node_modules/@monaco-editor/loader": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@monaco-editor/loader/-/loader-1.4.0.tgz", + "integrity": "sha512-00ioBig0x642hytVspPl7DbQyaSWRaolYie/UFNjoTdvoKPzo6xrXLhTk9ixgIKcLH5b5vDOjVNiGyY+uDCUlg==", + "license": "MIT", + "dependencies": { + "state-local": "^1.0.6" + }, + "peerDependencies": { + "monaco-editor": ">= 0.21.0 < 1" + } + }, + "node_modules/@monaco-editor/react": { + "version": "4.6.0", + "resolved": "https://registry.npmjs.org/@monaco-editor/react/-/react-4.6.0.tgz", + "integrity": "sha512-RFkU9/i7cN2bsq/iTkurMWOEErmYcY6JiQI3Jn+WeR/FGISH8JbHERjpS9oRuSOPvDMJI0Z8nJeKkbOs9sBYQw==", + "license": "MIT", + "dependencies": { + "@monaco-editor/loader": "^1.4.0" + }, + "peerDependencies": { + "monaco-editor": ">= 0.25.0 < 1", + "react": "^16.8.0 || ^17.0.0 || ^18.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0" + } + }, "node_modules/@mui/core-downloads-tracker": { "version": "5.16.7", "resolved": "https://registry.npmjs.org/@mui/core-downloads-tracker/-/core-downloads-tracker-5.16.7.tgz", @@ -9251,6 +9278,13 @@ "node": ">=16 || 14 >=14.17" } }, + "node_modules/monaco-editor": { + "version": "0.52.0", + "resolved": "https://registry.npmjs.org/monaco-editor/-/monaco-editor-0.52.0.tgz", + "integrity": "sha512-OeWhNpABLCeTqubfqLMXGsqf6OmPU6pHM85kF3dhy6kq5hnhuVS1p3VrEW/XhWHc71P2tHyS5JFySD8mgs1crw==", + "license": "MIT", + "peer": true + }, "node_modules/ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", @@ -10801,6 +10835,12 @@ "svelte": "^4.0.0 || ^5.0.0-next.0" } }, + "node_modules/state-local": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/state-local/-/state-local-1.0.7.tgz", + "integrity": "sha512-HTEHMNieakEnoe33shBYcZ7NX83ACUjCu8c40iOGEZsngj9zRnkqS9j1pqQPXwobB0ZcVTk27REb7COQ0UR59w==", + "license": "MIT" + }, "node_modules/stop-iteration-iterator": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/stop-iteration-iterator/-/stop-iteration-iterator-1.0.0.tgz", diff --git a/website/package.json b/website/package.json index 497d4e4c..79434f1b 100644 --- a/website/package.json +++ b/website/package.json @@ -12,6 +12,7 @@ "@agbishop/react-ansi-18": "^4.0.6", "@ai-sdk/openai": "^0.0.70", "@hookform/resolvers": "^3.9.0", + "@monaco-editor/react": "^4.6.0", "@next/third-parties": "^14.2.11", "@radix-ui/react-accordion": "^1.2.0", "@radix-ui/react-alert-dialog": "^1.1.2", diff --git a/website/src/app/types.ts b/website/src/app/types.ts index 994b3013..eb56b6df 100644 --- a/website/src/app/types.ts +++ b/website/src/app/types.ts @@ -15,7 +15,10 @@ export type Operation = { | "unnest" | "split" | "gather" - | "sample"; + | "sample" + | "code_map" + | "code_reduce" + | "code_filter"; name: string; prompt?: string; output?: { schema: SchemaItem[] }; diff --git a/website/src/components/AIChatPanel.tsx b/website/src/components/AIChatPanel.tsx index 8f8a79b6..a1c24adf 100644 --- a/website/src/components/AIChatPanel.tsx +++ b/website/src/components/AIChatPanel.tsx @@ -106,7 +106,7 @@ const AIChatPanel: React.FC = ({ onClose }) => { Core Capabilities: - DocETL enables users to create sophisticated data processing workflows with LLM calls, like crowdsourcing pipelines - Each pipeline processes documents through a sequence of operations -- Operations can be LLM-based (map, reduce, resolve, filter) or utility-based (unnest, split, gather, sample) +- Operations can be LLM-based (map, reduce, resolve, filter) or utility-based (unnest, split, gather, sample) or code-based (python for map, reduce, and filter) Operation Details: - Every LLM operation has: @@ -116,6 +116,12 @@ Operation Details: - Map/Filter: Access current doc with '{{ input.keyname }}' - Reduce: Loop through docs with '{% for doc in inputs %}...{% endfor %}' - Resolve: Compare docs with '{{ input1 }}/{{ input2 }}' and canonicalize with '{{ inputs }}' +- Code-based operations: + - Map: Define a transform function (def transform(doc: dict) -> dict), where the returned dict will have key-value pairs that will be added to the output document + - Filter: Define a transform function (def transform(doc: dict) -> bool), where the function should return true if the document should be included in the output + - Reduce: Define a transform function (def transform(docs: list[dict]) -> dict), where the returned dict will have key-value pairs that will *be* the output document (unless "pass_through" is set to true, then the first original doc for every group will also be returned) + - Only do imports of common libraries, inside the function definition + - Only suggest code-based operations if the task is one that is easily expressed in code, and LLMs or crowd workers are incapable of doing it correctly (e.g., word count, simple regex, etc.) Your Role: - Help users optimize pipelines and overcome challenges diff --git a/website/src/components/PipelineGui.tsx b/website/src/components/PipelineGui.tsx index 8af768d9..f9c89e45 100644 --- a/website/src/components/PipelineGui.tsx +++ b/website/src/components/PipelineGui.tsx @@ -735,6 +735,41 @@ const PipelineGUI: React.FC = () => { > Sample + + Code Operations + + handleAddOperation( + "non-LLM", + "code_map", + "Untitled Code Map" + ) + } + > + Code Map + + + handleAddOperation( + "non-LLM", + "code_reduce", + "Untitled Code Reduce" + ) + } + > + Code Reduce + + + handleAddOperation( + "non-LLM", + "code_filter", + "Untitled Code Filter" + ) + } + > + Code Filter +
diff --git a/website/src/components/ResizableDataTable.tsx b/website/src/components/ResizableDataTable.tsx index 73561406..fccc4f0f 100644 --- a/website/src/components/ResizableDataTable.tsx +++ b/website/src/components/ResizableDataTable.tsx @@ -1,10 +1,4 @@ -import React, { - useState, - useEffect, - useCallback, - useRef, - useMemo, -} from "react"; +import React, { useState, useEffect, useCallback, useMemo } from "react"; import { flexRender, getCoreRowModel, @@ -16,6 +10,8 @@ import { Row, getPaginationRowModel, VisibilityState, + SortingState, + getSortedRowModel, } from "@tanstack/react-table"; import { Table, @@ -26,7 +22,14 @@ import { TableRow, } from "@/components/ui/table"; import { Button } from "@/components/ui/button"; -import { ChevronLeft, ChevronRight, ChevronDown } from "lucide-react"; +import { + ChevronLeft, + ChevronRight, + ChevronDown, + ArrowUpDown, + ArrowUp, + ArrowDown, +} from "lucide-react"; import { DropdownMenu, DropdownMenuCheckboxItem, @@ -36,12 +39,12 @@ import { import { TABLE_SETTINGS_KEY } from "@/app/localStorageKeys"; import ReactMarkdown from "react-markdown"; import debounce from "lodash/debounce"; -import { Progress } from "@/components/ui/progress"; import { BarChart, Bar, XAxis, Tooltip, ResponsiveContainer } from "recharts"; export type DataType = Record; export type ColumnType = ColumnDef & { initialWidth?: number; + accessorKey?: string; }; interface ColumnStats { @@ -95,6 +98,18 @@ function calculateColumnStats( const max = Math.max(...values); const avg = values.reduce((sum, val) => sum + val, 0) / values.length; + // Special handling for single distinct value + if (min === max) { + return { + min, + max, + avg, + distribution: [values.length], // Put all values in a single bucket + bucketSize: 1, + type: type, + }; + } + // For numbers, use more precise bucketing const bucketSize = type === "number" ? (max - min) / 7 : Math.ceil((max - min) / 7); @@ -124,57 +139,64 @@ const WordCountHistogram = React.memo( histogramData, }: { histogramData: { range: string; count: number; fullRange: string }[]; - }) => ( - - - - [value.toLocaleString(), "Count"]} - labelFormatter={(label: string) => label} - contentStyle={{ - backgroundColor: "hsl(var(--popover))", - border: "1px solid hsl(var(--border))", - borderRadius: "var(--radius)", - color: "hsl(var(--popover-foreground))", - padding: "8px 12px", - boxShadow: "0 2px 4px rgba(0,0,0,0.1)", - }} - /> - - - - ) + }) => { + // Calculate total count for fractions + const totalCount = useMemo( + () => histogramData.reduce((sum, item) => sum + item.count, 0), + [histogramData] + ); + + return ( + + + + [ + `${value.toLocaleString()} (${( + (value / totalCount) * + 100 + ).toFixed(1)}%)`, + "Count", + ]} + labelFormatter={(label: string) => label} + contentStyle={{ + backgroundColor: "hsl(var(--popover))", + border: "1px solid hsl(var(--border))", + borderRadius: "var(--radius)", + color: "hsl(var(--popover-foreground))", + padding: "8px 12px", + boxShadow: "0 2px 4px rgba(0,0,0,0.1)", + }} + /> + + + + ); + } ); WordCountHistogram.displayName = "WordCountHistogram"; -interface ColumnStats { - min: number; - max: number; - avg: number; - distribution: number[]; - bucketSize: number; - type: "number" | "string" | "array"; -} - interface ColumnHeaderProps { header: string; stats: ColumnStats | null; isBold: boolean; + onSort: () => void; + sortDirection: "asc" | "desc" | false; } const ColumnHeader = React.memo( - ({ header, stats, isBold }: ColumnHeaderProps) => { + ({ header, stats, isBold, onSort, sortDirection }: ColumnHeaderProps) => { const histogramData = useMemo(() => { if (!stats) return []; @@ -189,6 +211,17 @@ const ColumnHeader = React.memo( } }; + // Special handling for single distinct value + if (stats.min === stats.max) { + return [ + { + range: `${Math.round(stats.min)}`, + count: stats.distribution[0], + fullRange: `${Math.round(stats.min)}${getUnitLabel()}`, + }, + ]; + } + return stats.distribution.map((count, i) => ({ range: `${Math.round(stats.min + i * stats.bucketSize)}`, count, @@ -202,27 +235,53 @@ const ColumnHeader = React.memo( return (
-
{header}
+
+ {header} + +
{stats && (
- - {stats.min} - {stats.type === "array" - ? " items" - : stats.type === "string" - ? " words" - : ""} - - avg: {Math.round(stats.avg)} - - {stats.max} - {stats.type === "array" - ? " items" - : stats.type === "string" - ? " words" - : ""} - + {stats.min === stats.max ? ( + + Single value: {stats.min} + {stats.type === "array" + ? " items" + : stats.type === "string" + ? " words" + : ""} + + ) : ( + <> + + {stats.min} + {stats.type === "array" + ? " items" + : stats.type === "string" + ? " words" + : ""} + + avg: {Math.round(stats.avg)} + + {stats.max} + {stats.type === "array" + ? " items" + : stats.type === "string" + ? " words" + : ""} + + + )}
@@ -418,6 +477,7 @@ function ResizableDataTable({ return {}; }); const [columnVisibility, setColumnVisibility] = useState({}); + const [sorting, setSorting] = useState([]); const saveSettings = useCallback(() => { localStorage.setItem( @@ -473,11 +533,37 @@ function ResizableDataTable({ const table = useReactTable({ data, - // Replace columns with sortedColumns here - columns: sortedColumns, + columns: sortedColumns.map((col) => ({ + ...col, + enableSorting: true, + sortingFn: (rowA: any, rowB: any) => { + const accessor = col.accessorKey; + if (!accessor) return 0; + + const a = rowA.getValue(accessor); + const b = rowB.getValue(accessor); + + // Handle null/undefined values + if (a == null) return -1; + if (b == null) return 1; + + // Sort based on type + if (typeof a === "number" && typeof b === "number") { + return a - b; + } + + if (Array.isArray(a) && Array.isArray(b)) { + return a.length - b.length; + } + + // For strings, do alphabetical comparison + return String(a).localeCompare(String(b)); + }, + })), columnResizeMode: "onChange" as ColumnResizeMode, getCoreRowModel: getCoreRowModel(), getPaginationRowModel: getPaginationRowModel(), + getSortedRowModel: getSortedRowModel(), onColumnSizingChange: (newColumnSizing) => { setColumnSizing(newColumnSizing); setIsResizing(true); @@ -485,10 +571,13 @@ function ResizableDataTable({ saveSettings(); }, onColumnVisibilityChange: setColumnVisibility, + onSortingChange: setSorting, state: { columnSizing, columnVisibility, + sorting, }, + enableSorting: true, enableColumnResizing: true, defaultColumn: { minSize: 30, @@ -585,6 +674,10 @@ function ResizableDataTable({ isBold={boldedColumns.includes( header.column.columnDef.header as string )} + onSort={() => header.column.toggleSorting()} + sortDirection={ + header.column.getIsSorted() as false | "asc" | "desc" + } /> )} @@ -605,10 +698,7 @@ function ResizableDataTable({ }} > - {table.getState().pagination.pageIndex * - table.getState().pagination.pageSize + - index + - 1} + {row.index + 1} {row.getVisibleCells().map((cell) => ( diff --git a/website/src/components/operations/args.tsx b/website/src/components/operations/args.tsx index 659e8196..fd59d708 100644 --- a/website/src/components/operations/args.tsx +++ b/website/src/components/operations/args.tsx @@ -19,6 +19,7 @@ import { } from "../ui/tooltip"; import { Switch } from "../ui/switch"; import { Label } from "../ui/label"; +import Editor from "@monaco-editor/react"; interface PromptInputProps { prompt: string; @@ -398,3 +399,109 @@ export const Guardrails: React.FC = React.memo( ); Guardrails.displayName = "Guardrails"; + +interface CodeInputProps { + code: string; + operationType: "code_map" | "code_reduce" | "code_filter"; + onChange: (value: string) => void; +} + +export const CodeInput: React.FC = React.memo( + ({ code, operationType, onChange }) => { + const getPlaceholder = () => { + switch (operationType) { + case "code_map": + return `def transform(doc: dict) -> dict: + # Transform a single document + # Return a dictionary with new key-value pairs + return { + 'new_key': process(doc['existing_key']) + }`; + case "code_filter": + return `def transform(doc: dict) -> bool: + # Return True to keep the document, False to filter it out + return doc['score'] >= 0.5`; + case "code_reduce": + return `def transform(items: list) -> dict: + # Aggregate multiple items into a single result + # Return a dictionary with aggregated values + return { + 'total': sum(item['value'] for item in items), + 'count': len(items) + }`; + } + }; + + const validatePythonCode = (value: string) => { + return value.includes("def transform") && value.includes("return"); + }; + + const getTooltipContent = () => { + switch (operationType) { + case "code_map": + return "Transform each document independently using Python code. The transform function takes a single document as input and returns a dictionary with new key-value pairs."; + case "code_filter": + return "Filter documents using Python code. The transform function takes a document as input and returns True to keep it or False to filter it out."; + case "code_reduce": + return "Aggregate multiple documents using Python code. The transform function takes a list of documents as input and returns a single aggregated result."; + } + }; + + return ( +
+
+ + + + + + + +

{getTooltipContent()}

+

+ Code operations allow you to use Python for: +

    +
  • Deterministic processing
  • +
  • Complex calculations
  • +
  • Integration with Python libraries
  • +
  • Structured data transformations
  • +
+

+
+
+
+
+
+ onChange(value || "")} + options={{ + minimap: { enabled: false }, + lineNumbers: "on", + scrollBeyondLastLine: false, + wordWrap: "on", + wrappingIndent: "indent", + automaticLayout: true, + tabSize: 4, + fontSize: 14, + fontFamily: "monospace", + suggest: { + showKeywords: true, + showSnippets: true, + }, + }} + /> +
+ {!validatePythonCode(code) && ( +
+ Code must define a transform function with a return statement +
+ )} +
+ ); + } +); + +CodeInput.displayName = "CodeInput"; diff --git a/website/src/components/operations/components.tsx b/website/src/components/operations/components.tsx index 666bc17a..64082b8d 100644 --- a/website/src/components/operations/components.tsx +++ b/website/src/components/operations/components.tsx @@ -1,5 +1,6 @@ +import React from "react"; import { Operation, SchemaItem } from "@/app/types"; -import { OutputSchema, PromptInput } from "./args"; +import { OutputSchema, PromptInput, CodeInput } from "./args"; import { useMemo } from "react"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; @@ -21,6 +22,67 @@ import { import { Checkbox } from "@/components/ui/checkbox"; import { Textarea } from "../ui/textarea"; +interface PromptConfig { + prompt: string; + output_keys?: string[]; + model?: string; +} + +interface MethodKwargs { + delimiter?: string; + num_tokens?: number; + stratify_key?: string; + embedding_keys?: string[]; + std?: number; + keep?: boolean; +} + +interface PeripheralChunkConfig { + content_key?: string; + count?: number; +} + +interface PeripheralChunksSection { + head?: PeripheralChunkConfig; + middle?: PeripheralChunkConfig; + tail?: PeripheralChunkConfig; +} + +interface PeripheralChunks { + previous?: PeripheralChunksSection; + next?: PeripheralChunksSection; +} + +interface OtherKwargs { + prompts?: PromptConfig[]; + method?: string; + method_kwargs?: MethodKwargs; + reduce_key?: string[]; + comparison_prompt?: string; + resolution_prompt?: string; + blocking_threshold?: number; + blocking_keys?: string[]; + split_key?: string; + unnest_key?: string; + recursive?: boolean; + depth?: number; + content_key?: string; + doc_id_key?: string; + order_key?: string; + peripheral_chunks?: PeripheralChunks; + samples?: string | number; + code?: string; +} + +interface Operation { + type: string; + prompt?: string; + output?: { + schema?: SchemaItem[]; + }; + otherKwargs?: OtherKwargs; +} + interface OperationComponentProps { operation: Operation; isSchemaExpanded: boolean; @@ -961,7 +1023,7 @@ export const ParallelMapOperationComponent: React.FC< return (
{(operation.otherKwargs?.prompts || []).map( - (prompt: any, index: number) => ( + (prompt: PromptConfig, index: number) => (
@@ -1059,7 +1121,7 @@ export const SampleOperationComponent: React.FC = ({ isSchemaExpanded, onToggleSchema, }) => { - const handleChange = (field: string, value: any) => { + const handleChange = (field: string, value: string | number | boolean) => { onUpdate({ ...operation, otherKwargs: { @@ -1069,7 +1131,10 @@ export const SampleOperationComponent: React.FC = ({ }); }; - const handleMethodKwargsChange = (field: string, value: any) => { + const handleMethodKwargsChange = ( + field: string, + value: string | number | boolean | string[] + ) => { onUpdate({ ...operation, otherKwargs: { @@ -1212,6 +1277,109 @@ export const SampleOperationComponent: React.FC = ({ ); }; +export const CodeOperationComponent: React.FC = ({ + operation, + onUpdate, +}) => { + const handleCodeChange = (newCode: string) => { + onUpdate({ + ...operation, + otherKwargs: { + ...operation.otherKwargs, + code: newCode, + }, + }); + }; + + const handleReduceKeysChange = (newReduceKeys: string[]) => { + onUpdate({ + ...operation, + otherKwargs: { + ...operation.otherKwargs, + reduce_key: newReduceKeys, + }, + }); + }; + + return ( +
+ {operation.type === "code_reduce" && ( +
+
+ +
+
+ {(operation.otherKwargs?.reduce_key || [""]).map( + (key: string, index: number) => ( +
+ { + const newKeys = [ + ...(operation.otherKwargs?.reduce_key || [""]), + ]; + newKeys[index] = e.target.value; + handleReduceKeysChange(newKeys); + }} + placeholder="Enter reduce key" + className="w-full pr-8" + /> + +
+ ) + )} + +
+
+
+
+ )} +
+ +
+
+ ); +}; + export default function createOperationComponent( operation: Operation, onUpdate: (updatedOperation: Operation) => void, @@ -1292,7 +1460,17 @@ export default function createOperationComponent( onToggleSchema={onToggleSchema} /> ); - + case "code_map": + case "code_reduce": + case "code_filter": + return ( + + ); default: console.warn(`Unsupported operation type: ${operation.type}`); return null;