From 0466ad9f7b1411c21cb3df42f6d2b49ad3f67b23 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Tue, 24 Sep 2024 23:44:11 -0700 Subject: [PATCH] Optionally persist intermediates for reduce --- Makefile | 1 + docetl/operations/reduce.py | 54 +++++++++++++++++++++++++++++ docs/operators/reduce.md | 1 + tests/test_reduce.py | 68 +++++++++++++++++++++++++++++++++++++ 4 files changed, 124 insertions(+) diff --git a/Makefile b/Makefile index a8c5c631..a27707fa 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ tests: tests-basic: poetry run pytest tests/test_basic.py + poetry run pytest tests/test_api.py lint: poetry run ruff check docetl/* --fix diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index 56ca6b4c..a9f89ddd 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -60,6 +60,7 @@ def __init__(self, *args, **kwargs): if isinstance(self.config["reduce_key"], str) else self.config["reduce_key"] ) + self.intermediates = {} def syntax_check(self) -> None: """ @@ -351,6 +352,14 @@ def process_group( if output is not None: results.append(output) + if self.config.get("persist_intermediates", False): + for result in results: + key = tuple(result[k] for k in self.config["reduce_key"]) + if key in self.intermediates: + result[f"_{self.config['name']}_intermediates"] = ( + self.intermediates[key] + ) + return results, total_cost def _get_embeddings( @@ -464,6 +473,10 @@ def calculate_num_parallel_folds(): fold_results = [] remaining_items = group_list + if self.config.get("persist_intermediates", False): + self.intermediates[key] = [] + iter_count = 0 + # Parallel folding and merging with ThreadPoolExecutor(max_workers=self.max_threads) as executor: while remaining_items: @@ -485,6 +498,15 @@ def calculate_num_parallel_folds(): total_cost += cost if result is not None: new_fold_results.append(result) + if self.config.get("persist_intermediates", False): + self.intermediates[key].append( + { + "iter": iter_count, + "intermediate": result, + "scratchpad": result["updated_scratchpad"], + } + ) + iter_count += 1 # Update fold_results with new results fold_results = new_fold_results + fold_results[len(new_fold_results) :] @@ -507,6 +529,15 @@ def calculate_num_parallel_folds(): total_cost += cost if result is not None: new_results.append(result) + if self.config.get("persist_intermediates", False): + self.intermediates[key].append( + { + "iter": iter_count, + "intermediate": result, + "scratchpad": None, + } + ) + iter_count += 1 fold_results = new_results @@ -538,6 +569,15 @@ def calculate_num_parallel_folds(): total_cost += cost if result is not None: new_results.append(result) + if self.config.get("persist_intermediates", False): + self.intermediates[key].append( + { + "iter": iter_count, + "intermediate": result, + "scratchpad": None, + } + ) + iter_count += 1 fold_results = new_results @@ -567,6 +607,10 @@ def _incremental_reduce( num_folds = (len(group_list) + fold_batch_size - 1) // fold_batch_size scratchpad = "" + if self.config.get("persist_intermediates", False): + self.intermediates[key] = [] + iter_count = 0 + for i in range(0, len(group_list), fold_batch_size): # Log the current iteration and total number of folds current_fold = i // fold_batch_size + 1 @@ -584,6 +628,16 @@ def _incremental_reduce( if folded_output is None: continue + if self.config.get("persist_intermediates", False): + self.intermediates[key].append( + { + "iter": iter_count, + "intermediate": folded_output, + "scratchpad": folded_output["updated_scratchpad"], + } + ) + iter_count += 1 + # Pop off updated_scratchpad if "updated_scratchpad" in folded_output: scratchpad = folded_output["updated_scratchpad"] diff --git a/docs/operators/reduce.md b/docs/operators/reduce.md index 331deb34..0f782f43 100644 --- a/docs/operators/reduce.md +++ b/docs/operators/reduce.md @@ -60,6 +60,7 @@ This Reduce operation processes customer feedback grouped by department: | `fold_batch_size` | Number of items to process in each fold operation | None | | `value_sampling` | A dictionary specifying the sampling strategy for large groups | None | | `verbose` | If true, enables detailed logging of the reduce operation | false | +| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false | ## Advanced Features diff --git a/tests/test_reduce.py b/tests/test_reduce.py index bb528a5c..c8fdf65b 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -194,3 +194,71 @@ def test_reduce_operation_non_associative(default_model, max_threads): assert combined_result.index("brave princess") < combined_result.index( "dragon" ), "Princess should be mentioned before the dragon in the story" + + +def test_reduce_operation_persist_intermediates(default_model, max_threads): + # Define a config with persist_intermediates enabled + persist_intermediates_config = { + "name": "persist_intermediates_reduce", + "type": "reduce", + "reduce_key": "group", + "persist_intermediates": True, + "prompt": "Summarize the numbers in '{{ inputs }}'.", + "fold_prompt": "Combine summaries: Previous '{{ output }}', New '{{ inputs[0] }}'.", + "fold_batch_size": 2, + "output": {"schema": {"summary": "string"}}, + } + + # Sample data with more items than fold_batch_size + sample_data = [ + {"group": "numbers", "value": "1, 2"}, + {"group": "numbers", "value": "3, 4"}, + {"group": "numbers", "value": "5, 6"}, + {"group": "numbers", "value": "7, 8"}, + {"group": "numbers", "value": "9, 10"}, + ] + + operation = ReduceOperation( + persist_intermediates_config, default_model, max_threads + ) + results, cost = operation.execute(sample_data) + + assert len(results) == 1, "Should have one result for the 'numbers' group" + assert cost > 0, "Cost should be greater than 0" + + result = results[0] + assert "summary" in result, "Result should have a 'summary' key" + + # Check if intermediates were persisted + assert ( + "_persist_intermediates_reduce_intermediates" in result + ), "Result should have '_persist_intermediates_reduce_intermediates' key" + intermediates = result["_persist_intermediates_reduce_intermediates"] + assert isinstance(intermediates, list), "Intermediates should be a list" + assert len(intermediates) > 1, "Should have multiple intermediate results" + + # Check the structure of intermediates + for intermediate in intermediates: + assert "iter" in intermediate, "Each intermediate should have an 'iter' key" + assert ( + "intermediate" in intermediate + ), "Each intermediate should have an 'intermediate' key" + assert ( + "scratchpad" in intermediate + ), "Each intermediate should have a 'scratchpad' key" + + # Verify that the intermediate results are stored in the correct order + for i in range(1, len(intermediates)): + assert ( + intermediates[i]["iter"] > intermediates[i - 1]["iter"] + ), "Intermediate results should be in ascending order of iterations" + + # Check if the intermediate results are accessible via the special key + for result in results: + assert ( + f"_persist_intermediates_reduce_intermediates" in result + ), "Result should contain the special intermediate key" + stored_intermediates = result[f"_persist_intermediates_reduce_intermediates"] + assert ( + stored_intermediates == intermediates + ), "Stored intermediates should match the operation's intermediates"