Skip to content

Commit

Permalink
Add all key to reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashankar committed Sep 14, 2024
1 parent 44159fd commit 49699e4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
30 changes: 17 additions & 13 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,23 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
reduce_keys = [reduce_keys]
input_schema = self.config.get("input", {}).get("schema", {})

# Group the input data by the reduce key(s) while maintaining original order
def get_group_key(item):
return tuple(item[key] for key in reduce_keys)

grouped_data = {}
for item in input_data:
key = get_group_key(item)
if key not in grouped_data:
grouped_data[key] = []
grouped_data[key].append(item)

# Convert the grouped data to a list of tuples
grouped_data = list(grouped_data.items())
# Check if we need to group everything into one group
if reduce_keys == ["_all"] or reduce_keys == "_all":
grouped_data = [("_all", input_data)]
else:
# Group the input data by the reduce key(s) while maintaining original order
def get_group_key(item):
return tuple(item[key] for key in reduce_keys)

grouped_data = {}
for item in input_data:
key = get_group_key(item)
if key not in grouped_data:
grouped_data[key] = []
grouped_data[key].append(item)

# Convert the grouped data to a list of tuples
grouped_data = list(grouped_data.items())

def process_group(
key: Tuple, group_elems: List[Dict]
Expand Down
11 changes: 11 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,17 @@ def test_reduce_operation(
assert cost > 0


def test_reduce_operation_with_all_key(
reduce_config, default_model, max_threads, reduce_sample_data
):
reduce_config["reduce_key"] = "_all"
operation = ReduceOperation(reduce_config, default_model, max_threads)
results, cost = operation.execute(reduce_sample_data)

assert len(results) == 1
assert cost > 0


def test_reduce_operation_with_list_key(
reduce_config, default_model, max_threads, reduce_sample_data_with_list_key
):
Expand Down

0 comments on commit 49699e4

Please sign in to comment.