diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index cf3afae3..a5086606 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -253,19 +253,23 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: reduce_keys = [reduce_keys] input_schema = self.config.get("input", {}).get("schema", {}) - # Group the input data by the reduce key(s) while maintaining original order - def get_group_key(item): - return tuple(item[key] for key in reduce_keys) - - grouped_data = {} - for item in input_data: - key = get_group_key(item) - if key not in grouped_data: - grouped_data[key] = [] - grouped_data[key].append(item) - - # Convert the grouped data to a list of tuples - grouped_data = list(grouped_data.items()) + # Check if we need to group everything into one group + if reduce_keys == ["_all"] or reduce_keys == "_all": + grouped_data = [("_all", input_data)] + else: + # Group the input data by the reduce key(s) while maintaining original order + def get_group_key(item): + return tuple(item[key] for key in reduce_keys) + + grouped_data = {} + for item in input_data: + key = get_group_key(item) + if key not in grouped_data: + grouped_data[key] = [] + grouped_data[key].append(item) + + # Convert the grouped data to a list of tuples + grouped_data = list(grouped_data.items()) def process_group( key: Tuple, group_elems: List[Dict] diff --git a/tests/test_basic.py b/tests/test_basic.py index ed5b216b..7d1bf6b4 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -488,6 +488,17 @@ def test_reduce_operation( assert cost > 0 +def test_reduce_operation_with_all_key( + reduce_config, default_model, max_threads, reduce_sample_data +): + reduce_config["reduce_key"] = "_all" + operation = ReduceOperation(reduce_config, default_model, max_threads) + results, cost = operation.execute(reduce_sample_data) + + assert len(results) == 1 + assert cost > 0 + + def test_reduce_operation_with_list_key( reduce_config, default_model, max_threads, reduce_sample_data_with_list_key ):