From 0a4bcc729bc543c839d784881af2e469a90c30f3 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Fri, 22 Nov 2024 10:22:38 -0800 Subject: [PATCH] feat: add basic llm call observability to the UI (#209) * feat: add basic llm call observability to the UI * test all observability for ops * rebase with main * edit readme --- README.md | 25 +-- docetl/operations/map.py | 14 +- docetl/operations/reduce.py | 72 +++++---- docetl/operations/resolve.py | 50 +++++- docetl/optimizers/join_optimizer.py | 2 +- website/src/app/api/utils.ts | 1 + website/src/components/Output.tsx | 1 + website/src/components/ResizableDataTable.tsx | 143 ++++++++++++++---- 8 files changed, 224 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index c6845d4d..97c3ea32 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,19 @@ DocETL is the ideal choice when you're looking to maximize correctness and outpu - You're working with long documents that don't fit into a single prompt - You have validation criteria and want tasks to automatically retry when validation fails +### Community Projects + +- [Conversation Generator](https://github.com/PassionFruits-net/docetl-conversation) +- [Text-to-speech](https://github.com/PassionFruits-net/docetl-speaker) +- [YouTube Transcript Topic Extraction](https://github.com/rajib76/docetl_examples) + +### Educational Resources + +- [UI/UX Thoughts](https://x.com/sh_reya/status/1846235904664273201) +- [Using Gleaning to Improve Output Quality](https://x.com/sh_reya/status/1843354256335876262) +- [Deep Dive on Resolve Operator](https://x.com/sh_reya/status/1840796824636121288) + + ## Getting Started There are two main ways to use DocETL: @@ -161,15 +174,3 @@ make tests-basic # Runs basic test suite (costs < $0.01 with OpenAI) ``` For detailed documentation and tutorials, visit our [documentation](https://ucbepic.github.io/docetl). 
- -## Community Projects - -- [Conversation Generator](https://github.com/PassionFruits-net/docetl-conversation) -- [Text-to-speech](https://github.com/PassionFruits-net/docetl-speaker) -- [YouTube Transcript Topic Extraction](https://github.com/rajib76/docetl_examples) - -## Educational Resources - -- [UI/UX Thoughts](https://x.com/sh_reya/status/1846235904664273201) -- [Using Gleaning to Improve Output Quality](https://x.com/sh_reya/status/1843354256335876262) -- [Deep Dive on Resolve Operator](https://x.com/sh_reya/status/1840796824636121288) diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 7b697c08..e4995fec 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -46,6 +46,7 @@ class schema(BaseOperation.schema): gleaning: Optional[Dict[str, Any]] = None drop_keys: Optional[List[str]] = None timeout: Optional[int] = None + enable_observability: bool = False batch_size: Optional[int] = None clustering_method: Optional[str] = None batch_prompt: Optional[str] = None @@ -228,9 +229,12 @@ def validation_fn(response: Union[Dict[str, Any], ModelResponse]): )[0] else: output = llm_result.response + # Augment the output with the original item output = {**item, **output} + if self.config.get("enable_observability", False): + output[f"_observability_{self.config['name']}"] = {"prompt": prompt} return output, llm_result.total_cost return None, llm_result.total_cost @@ -317,6 +321,7 @@ class schema(BaseOperation.schema): type: str = "parallel_map" prompts: List[Dict[str, Any]] output: Dict[str, Any] + enable_observability: bool = False def __init__( self, @@ -471,7 +476,7 @@ def process_prompt(item, prompt_config): tools=prompt_config.get("tools", None), manually_fix_errors=self.manually_fix_errors, )[0] - return output, response.total_cost + return output, prompt, response.total_cost with ThreadPoolExecutor(max_workers=self.max_threads) as executor: if "prompts" in self.config: @@ -488,7 +493,7 @@ def process_prompt(item, prompt_config): desc="Processing parallel map items", ): future = all_futures[i] - output, cost = future.result() + output, prompt, cost = future.result() total_cost += cost # Determine which item this future corresponds to @@ -503,6 +508,11 @@ def process_prompt(item, prompt_config): # Fetch the item_result item_result = results[item_index] + if self.config.get("enable_observability", False): + if f"_observability_{self.config['name']}" not in item_result: + item_result[f"_observability_{self.config['name']}"] = {} + item_result[f"_observability_{self.config['name']}"].update({f"prompt_{prompt_index}": prompt}) + # Update the item_result with the output item_result.update(output) diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index d326d302..254323a6 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -53,6 +53,7 @@ class schema(BaseOperation.schema): verbose: Optional[bool] = None timeout: Optional[int] = None litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict) + enable_observability: bool = False def __init__(self, *args, **kwargs): """ @@ -386,24 +387,32 @@ def process_group( # Only execute merge-based plans if associative = True if "merge_prompt" in self.config and self.config.get("associative", True): - result, cost = self._parallel_fold_and_merge(key, group_list) + result, prompts, cost = self._parallel_fold_and_merge(key, group_list) elif ( self.config.get("fold_batch_size", None) and self.config.get("fold_batch_size") >= len(group_list) ): # If the fold batch size 
is greater than or equal to the number of items in the group, # we can just run a single fold operation - result, cost = self._batch_reduce(key, group_list) + result, prompt, cost = self._batch_reduce(key, group_list) + prompts = [prompt] elif "fold_prompt" in self.config: - result, cost = self._incremental_reduce(key, group_list) + result, prompts, cost = self._incremental_reduce(key, group_list) else: - result, cost = self._batch_reduce(key, group_list) + result, prompt, cost = self._batch_reduce(key, group_list) + prompts = [prompt] total_cost += cost # Add the counts of items in the group to the result result[f"_counts_prereduce_{self.config['name']}"] = len(group_elems) + if self.config.get("enable_observability", False): + # Add the _observability_{self.config['name']} key to the result + result[f"_observability_{self.config['name']}"] = { + "prompts": prompts + } + # Apply pass-through at the group level if ( result is not None @@ -548,7 +557,7 @@ def _parallel_fold_and_merge( fold_batch_size = self.config["fold_batch_size"] merge_batch_size = self.config["merge_batch_size"] total_cost = 0 - + prompts = [] def calculate_num_parallel_folds(): fold_time, fold_default = self.get_fold_time() merge_time, merge_default = self.get_merge_time() @@ -589,8 +598,9 @@ def calculate_num_parallel_folds(): new_fold_results = [] for future in as_completed(fold_futures): - result, cost = future.result() + result, prompt, cost = future.result() total_cost += cost + prompts.append(prompt) if result is not None: new_fold_results.append(result) if self.config.get("persist_intermediates", False): @@ -620,8 +630,9 @@ def calculate_num_parallel_folds(): new_results = [] for future in as_completed(merge_futures): - result, cost = future.result() + result, prompt, cost = future.result() total_cost += cost + prompts.append(prompt) if result is not None: new_results.append(result) if self.config.get("persist_intermediates", False): @@ -660,8 +671,9 @@ def calculate_num_parallel_folds(): new_results = [] for future in as_completed(merge_futures): - result, cost = future.result() + result, prompt, cost = future.result() total_cost += cost + prompts.append(prompt) if result is not None: new_results.append(result) if self.config.get("persist_intermediates", False): @@ -676,11 +688,11 @@ def calculate_num_parallel_folds(): fold_results = new_results - return (fold_results[0], total_cost) if fold_results else (None, total_cost) + return (fold_results[0], prompts, total_cost) if fold_results else (None, prompts, total_cost) def _incremental_reduce( self, key: Tuple, group_list: List[Dict] - ) -> Tuple[Optional[Dict], float]: + ) -> Tuple[Optional[Dict], List[str], float]: """ Perform an incremental reduce operation on a group of items. @@ -691,12 +703,13 @@ def _incremental_reduce( group_list (List[Dict]): The list of items in the group to be processed. Returns: - Tuple[Optional[Dict], float]: A tuple containing the final reduced result (or None if processing failed) - and the total cost of the operation. + Tuple[Optional[Dict], List[str], float]: A tuple containing the final reduced result (or None if processing failed), + the list of prompts used, and the total cost of the operation. 
""" fold_batch_size = self.config["fold_batch_size"] total_cost = 0 current_output = None + prompts = [] # Calculate and log the number of folds to be performed num_folds = (len(group_list) + fold_batch_size - 1) // fold_batch_size @@ -715,10 +728,11 @@ def _incremental_reduce( ) batch = group_list[i : i + fold_batch_size] - folded_output, fold_cost = self._increment_fold( + folded_output, prompt, fold_cost = self._increment_fold( key, batch, current_output, scratchpad ) total_cost += fold_cost + prompts.append(prompt) if folded_output is None: continue @@ -744,7 +758,7 @@ def _incremental_reduce( current_output = folded_output - return current_output, total_cost + return current_output, prompts, total_cost def validation_fn(self, response: Dict[str, Any]): output = self.runner.api.parse_llm_response( @@ -761,7 +775,7 @@ def _increment_fold( batch: List[Dict], current_output: Optional[Dict], scratchpad: Optional[str] = None, - ) -> Tuple[Optional[Dict], float]: + ) -> Tuple[Optional[Dict], str, float]: """ Perform an incremental fold operation on a batch of items. @@ -773,8 +787,8 @@ def _increment_fold( current_output (Optional[Dict]): The current accumulated output, if any. scratchpad (Optional[str]): The scratchpad to use for the fold operation. Returns: - Tuple[Optional[Dict], float]: A tuple containing the folded output (or None if processing failed) - and the cost of the fold operation. + Tuple[Optional[Dict], str, float]: A tuple containing the folded output (or None if processing failed), + the prompt used, and the cost of the fold operation. """ if current_output is None: return self._batch_reduce(key, batch, scratchpad) @@ -822,13 +836,13 @@ def _increment_fold( folded_output.update(dict(zip(self.config["reduce_key"], key))) fold_cost = response.total_cost - return folded_output, fold_cost + return folded_output, fold_prompt, fold_cost - return None, fold_cost + return None, fold_prompt, fold_cost def _merge_results( self, key: Tuple, outputs: List[Dict] - ) -> Tuple[Optional[Dict], float]: + ) -> Tuple[Optional[Dict], str, float]: """ Merge multiple outputs into a single result. @@ -839,8 +853,8 @@ def _merge_results( outputs (List[Dict]): The list of outputs to be merged. Returns: - Tuple[Optional[Dict], float]: A tuple containing the merged output (or None if processing failed) - and the cost of the merge operation. + Tuple[Optional[Dict], str, float]: A tuple containing the merged output (or None if processing failed), + the prompt used, and the cost of the merge operation. """ start_time = time.time() merge_prompt_template = Template(self.config["merge_prompt"]) @@ -879,9 +893,9 @@ def _merge_results( )[0] merged_output.update(dict(zip(self.config["reduce_key"], key))) merge_cost = response.total_cost - return merged_output, merge_cost + return merged_output, merge_prompt, merge_cost - return None, merge_cost + return None, merge_prompt, merge_cost def get_fold_time(self) -> Tuple[float, bool]: """ @@ -935,7 +949,7 @@ def _update_merge_time(self, time: float) -> None: def _batch_reduce( self, key: Tuple, group_list: List[Dict], scratchpad: Optional[str] = None - ) -> Tuple[Optional[Dict], float]: + ) -> Tuple[Optional[Dict], str, float]: """ Perform a batch reduce operation on a group of items. @@ -946,8 +960,8 @@ def _batch_reduce( group_list (List[Dict]): The list of items to be reduced. scratchpad (Optional[str]): The scratchpad to use for the reduce operation. 
Returns: - Tuple[Optional[Dict], float]: A tuple containing the reduced output (or None if processing failed) - and the cost of the reduce operation. + Tuple[Optional[Dict], str, float]: A tuple containing the reduced output (or None if processing failed), + the prompt used, and the cost of the reduce operation. """ prompt_template = Template(self.config["prompt"]) prompt = prompt_template.render( @@ -988,5 +1002,5 @@ def _batch_reduce( )[0] output.update(dict(zip(self.config["reduce_key"], key))) - return output, item_cost - return None, item_cost + return output, prompt, item_cost + return None, prompt, item_cost diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index 88dfbbfd..bcfec266 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -5,7 +5,9 @@ import random import time from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Tuple, Optional, Union +import json +from datetime import datetime import jinja2 from jinja2 import Template @@ -45,6 +47,7 @@ class schema(BaseOperation.schema): optimize: Optional[bool] = None timeout: Optional[int] = None litellm_completion_kwargs: Dict[str, Any] = Field(default_factory=dict) + enable_observability: bool = False def compare_pair( self, @@ -55,7 +58,7 @@ def compare_pair( blocking_keys: List[str] = [], timeout_seconds: int = 120, max_retries_per_timeout: int = 2, - ) -> Tuple[bool, float]: + ) -> Tuple[bool, float, str]: """ Compares two items using an LLM model to determine if they match. @@ -66,7 +69,7 @@ def compare_pair( item2 (Dict): The second item to compare. Returns: - Tuple[bool, float]: A tuple containing a boolean indicating whether the items match and the cost of the comparison. + Tuple[bool, float, str]: A tuple containing a boolean indicating whether the items match, the cost of the comparison, and the prompt. 
""" if blocking_keys: if all( @@ -75,7 +78,7 @@ def compare_pair( and str(item1[key]).lower() == str(item2[key]).lower() for key in blocking_keys ): - return True, 0 + return True, 0, "" prompt_template = Template(comparison_prompt) prompt = prompt_template.render(input1=item1, input2=item2) @@ -93,7 +96,8 @@ def compare_pair( response.response, {"is_match": "bool"}, )[0] - return output["is_match"], response.total_cost + + return output["is_match"], response.total_cost, prompt def syntax_check(self) -> None: """ @@ -232,6 +236,16 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: if len(input_data) == 0: return [], 0 + # Initialize observability data for all items at the start + if self.config.get("enable_observability", False): + observability_key = f"_observability_{self.config['name']}" + for item in input_data: + if observability_key not in item: + item[observability_key] = { + "comparison_prompts": [], + "resolution_prompt": None + } + blocking_keys = self.config.get("blocking_keys", []) blocking_threshold = self.config.get("blocking_threshold") blocking_conditions = self.config.get("blocking_conditions", []) @@ -336,7 +350,7 @@ def meets_blocking_conditions(pair: Tuple[int, int]) -> bool: is_match(input_data[i], input_data[j]) if blocking_conditions else False ) - blocked_pairs = list(filter(meets_blocking_conditions, comparison_pairs)) + blocked_pairs = list(filter(meets_blocking_conditions, comparison_pairs)) if blocking_conditions else comparison_pairs # Apply limit_comparisons to blocked pairs if limit_comparisons is not None and len(blocked_pairs) > limit_comparisons: @@ -502,10 +516,20 @@ def auto_batch() -> int: for future in as_completed(future_to_pair): pair = future_to_pair[future] - is_match_result, cost = future.result() + is_match_result, cost, prompt = future.result() pair_costs += cost if is_match_result: merge_clusters(pair[0], pair[1]) + + if self.config.get("enable_observability", False): + observability_key = f"_observability_{self.config['name']}" + for idx in (pair[0], pair[1]): + if observability_key not in input_data[idx]: + input_data[idx][observability_key] = { + "comparison_prompts": [], + "resolution_prompt": None + } + input_data[idx][observability_key]["comparison_prompts"].append(prompt) pbar.update(last_processed//batch_size) total_cost += pair_costs @@ -549,6 +573,16 @@ def process_cluster(cluster): ) reduction_cost = reduction_response.total_cost + if self.config.get("enable_observability", False): + for item in [input_data[i] for i in cluster]: + observability_key = f"_observability_{self.config['name']}" + if observability_key not in item: + item[observability_key] = { + "comparison_prompts": [], + "resolution_prompt": None + } + item[observability_key]["resolution_prompt"] = resolution_prompt + if reduction_response.validated: reduction_output = self.runner.api.parse_llm_response( reduction_response.response, @@ -617,6 +651,8 @@ def process_cluster(cluster): for output_key, compare_key in key_mapping.items(): if compare_key in input_data[list(cluster)[0]]: result[output_key] = input_data[list(cluster)[0]][compare_key] + elif output_key in input_data[list(cluster)[0]]: + result[output_key] = input_data[list(cluster)[0]][output_key] else: result[output_key] = None # or some default value diff --git a/docetl/optimizers/join_optimizer.py b/docetl/optimizers/join_optimizer.py index cb29404c..42d51871 100644 --- a/docetl/optimizers/join_optimizer.py +++ b/docetl/optimizers/join_optimizer.py @@ -1119,7 +1119,7 @@ def 
_perform_comparisons_resolve( for i, j in pairs ] for future, (i, j) in zip(futures, pairs): - is_match, cost = future.result() + is_match, cost, _ = future.result() comparisons.append((i, j, is_match)) total_cost += cost diff --git a/website/src/app/api/utils.ts b/website/src/app/api/utils.ts index 150d05e8..b3274cc0 100644 --- a/website/src/app/api/utils.ts +++ b/website/src/app/api/utils.ts @@ -124,6 +124,7 @@ export function generatePipelineConfig( return { ...newOp, + enable_observability: true, output: { schema: op.output.schema.reduce( (acc: Record, item: SchemaItem) => { diff --git a/website/src/components/Output.tsx b/website/src/components/Output.tsx index 1cb1f87b..5cbc4edf 100644 --- a/website/src/components/Output.tsx +++ b/website/src/components/Output.tsx @@ -310,6 +310,7 @@ export const Output: React.FC = () => { : [] } startingRowHeight={180} + currentOperation={opName} /> ) : ( diff --git a/website/src/components/ResizableDataTable.tsx b/website/src/components/ResizableDataTable.tsx index 26e50c73..59d263a4 100644 --- a/website/src/components/ResizableDataTable.tsx +++ b/website/src/components/ResizableDataTable.tsx @@ -31,7 +31,13 @@ import { TableRow, } from "@/components/ui/table"; import { Button } from "@/components/ui/button"; -import { ChevronLeft, ChevronRight, ChevronDown, Search } from "lucide-react"; +import { + ChevronLeft, + ChevronRight, + ChevronDown, + Search, + Eye, +} from "lucide-react"; import { DropdownMenu, DropdownMenuCheckboxItem, @@ -43,6 +49,11 @@ import ReactMarkdown from "react-markdown"; import debounce from "lodash/debounce"; import { BarChart, Bar, XAxis, Tooltip, ResponsiveContainer } from "recharts"; import { Input } from "@/components/ui/input"; +import { + HoverCard, + HoverCardContent, + HoverCardTrigger, +} from "@/components/ui/hover-card"; export type DataType = Record; export type ColumnType = ColumnDef & { @@ -446,6 +457,8 @@ const ColumnHeader = React.memo( ? " items" : stats.type === "string-words" ? " words" + : stats.type === "string-chars" + ? " chars" : ""} avg: {Math.round(stats.avg)} @@ -455,6 +468,8 @@ const ColumnHeader = React.memo( ? " items" : stats.type === "string-words" ? " words" + : stats.type === "string-chars" + ? " chars" : ""} @@ -537,6 +552,7 @@ interface ResizableDataTableProps { columns: ColumnType[]; boldedColumns: string[]; startingRowHeight?: number; + currentOperation: string; } interface MarkdownCellProps { @@ -786,11 +802,61 @@ const SearchableCell = React.memo( ); SearchableCell.displayName = "SearchableCell"; +interface ObservabilityIndicatorProps { + row: Record; + currentOperation: string; +} + +const ObservabilityIndicator = React.memo( + ({ row, currentOperation }: ObservabilityIndicatorProps) => { + // Only show observability data for the current operation + const observabilityEntries = Object.entries(row).filter( + ([key]) => key === `_observability_${currentOperation}` + ); + + if (observabilityEntries.length === 0) return null; + + return ( + + +
+      <HoverCard>
+        <HoverCardTrigger asChild>
+          <div>
+            <Eye size={16} />
+          </div>
+        </HoverCardTrigger>
+        <HoverCardContent>
+          <h4>
+            LLM Call(s) for {currentOperation}
+          </h4>
+          <div>
+            {observabilityEntries.map(([key, value]) => (
+              <div key={key}>
+                <pre>
+                  {typeof value === "object"
+                    ? JSON.stringify(value, null, 2)
+                    : String(value)}
+                </pre>
+              </div>
+            ))}
+          </div>
+        </HoverCardContent>
+      </HoverCard>
+ ); + } +); +ObservabilityIndicator.displayName = "ObservabilityIndicator"; + function ResizableDataTable({ data, columns, boldedColumns, startingRowHeight = 60, + currentOperation, }: ResizableDataTableProps) { const [columnSizing, setColumnSizing] = useState(() => { const savedSettings = localStorage.getItem(TABLE_SETTINGS_KEY); @@ -892,34 +958,39 @@ function ResizableDataTable({ const table = useReactTable({ data, - columns: sortedColumns.map((col) => ({ - ...col, - enableSorting: true, - filterFn: fuzzyFilter, - sortingFn: (rowA: Row, rowB: Row) => { - const accessor = col.accessorKey; - if (!accessor) return 0; - - const a = rowA.getValue(accessor); - const b = rowB.getValue(accessor); - - // Handle null/undefined values - if (a == null) return -1; - if (b == null) return 1; - - // Sort based on type - if (typeof a === "number" && typeof b === "number") { - return a - b; - } - - if (Array.isArray(a) && Array.isArray(b)) { - return a.length - b.length; - } - - // For strings, do alphabetical comparison - return String(a).localeCompare(String(b)); - }, - })), + columns: sortedColumns + .filter((col) => { + const columnId = col.accessorKey || col.id; + return !columnId?.startsWith("_observability_"); + }) + .map((col) => ({ + ...col, + enableSorting: true, + filterFn: fuzzyFilter, + sortingFn: (rowA: Row, rowB: Row) => { + const accessor = col.accessorKey; + if (!accessor) return 0; + + const a = rowA.getValue(accessor); + const b = rowB.getValue(accessor); + + // Handle null/undefined values + if (a == null) return -1; + if (b == null) return 1; + + // Sort based on type + if (typeof a === "number" && typeof b === "number") { + return a - b; + } + + if (Array.isArray(a) && Array.isArray(b)) { + return a.length - b.length; + } + + // For strings, do alphabetical comparison + return String(a).localeCompare(String(b)); + }, + })), columnResizeMode: "onChange" as ColumnResizeMode, getCoreRowModel: getCoreRowModel(), getPaginationRowModel: getPaginationRowModel(), @@ -1065,7 +1136,7 @@ function ResizableDataTable({
@@ -1126,9 +1197,15 @@ function ResizableDataTable({
                     textAlign: "center",
                   }}
                 >
-                  <div className="text-xs text-gray-500">
-                    {row.index + 1}
-                  </div>
+                  <div className="flex flex-col items-center gap-1">
+                    <span className="text-xs text-gray-500">
+                      {row.index + 1}
+                    </span>
+                    <ObservabilityIndicator
+                      row={row.original}
+                      currentOperation={currentOperation}
+                    />
+                  </div>
{row.getVisibleCells().map((cell) => (
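
---

For reviewers, here is a minimal sketch of what `enable_observability` produces. The `run_map_op` helper, the op name `extract_themes`, and the sample document are illustrative stand-ins rather than code from this patch; only the `enable_observability` flag and the `_observability_{op name}` key shape mirror the changes above.

```python
# Sketch only (not from this patch): mimics the observability pattern the
# map operation uses. "extract_themes" and the document are hypothetical.

def run_map_op(item: dict, config: dict) -> dict:
    """Render a prompt, fake an LLM call, and attach observability data."""
    prompt = config["prompt"].format(**item)  # stand-in for Jinja2 rendering
    output = {"themes": ["placeholder LLM answer"]}  # stand-in for the LLM call
    result = {**item, **output}
    if config.get("enable_observability", False):
        # Same key shape the patch writes: _observability_{op name}
        result[f"_observability_{config['name']}"] = {"prompt": prompt}
    return result

config = {
    "name": "extract_themes",
    "enable_observability": True,
    "prompt": "List the main themes in: {text}",
}

row = run_map_op({"text": "DocETL is a system for LLM-powered data processing."}, config)
print(row["_observability_extract_themes"])  # {'prompt': 'List the main themes in: ...'}
```

Reduce ops store a `prompts` list instead of a single `prompt`, and resolve ops store `comparison_prompts` plus a `resolution_prompt`, per the hunks above. On the UI side, `generatePipelineConfig` turns the flag on for every generated operation, `ResizableDataTable` filters `_observability_*` columns out of the grid, and the `ObservabilityIndicator` hover card next to each row number surfaces the recorded prompts for the current operation.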