Skip to content

Commit

Permalink
Merge pull request #14 from ucbepic/shreyashankar/blockingerr
Browse files Browse the repository at this point in the history
Optionally persist intermediates for reduce
  • Loading branch information
shreyashankar authored Sep 25, 2024
2 parents c17c6ae + 0466ad9 commit be8a8bb
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 0 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ tests:

tests-basic:
poetry run pytest tests/test_basic.py
poetry run pytest tests/test_api.py

lint:
poetry run ruff check docetl/* --fix
Expand Down
54 changes: 54 additions & 0 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, *args, **kwargs):
if isinstance(self.config["reduce_key"], str)
else self.config["reduce_key"]
)
self.intermediates = {}

def syntax_check(self) -> None:
"""
Expand Down Expand Up @@ -351,6 +352,14 @@ def process_group(
if output is not None:
results.append(output)

if self.config.get("persist_intermediates", False):
for result in results:
key = tuple(result[k] for k in self.config["reduce_key"])
if key in self.intermediates:
result[f"_{self.config['name']}_intermediates"] = (
self.intermediates[key]
)

return results, total_cost

def _get_embeddings(
Expand Down Expand Up @@ -464,6 +473,10 @@ def calculate_num_parallel_folds():
fold_results = []
remaining_items = group_list

if self.config.get("persist_intermediates", False):
self.intermediates[key] = []
iter_count = 0

# Parallel folding and merging
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
while remaining_items:
Expand All @@ -485,6 +498,15 @@ def calculate_num_parallel_folds():
total_cost += cost
if result is not None:
new_fold_results.append(result)
if self.config.get("persist_intermediates", False):
self.intermediates[key].append(
{
"iter": iter_count,
"intermediate": result,
"scratchpad": result["updated_scratchpad"],
}
)
iter_count += 1

# Update fold_results with new results
fold_results = new_fold_results + fold_results[len(new_fold_results) :]
Expand All @@ -507,6 +529,15 @@ def calculate_num_parallel_folds():
total_cost += cost
if result is not None:
new_results.append(result)
if self.config.get("persist_intermediates", False):
self.intermediates[key].append(
{
"iter": iter_count,
"intermediate": result,
"scratchpad": None,
}
)
iter_count += 1

fold_results = new_results

Expand Down Expand Up @@ -538,6 +569,15 @@ def calculate_num_parallel_folds():
total_cost += cost
if result is not None:
new_results.append(result)
if self.config.get("persist_intermediates", False):
self.intermediates[key].append(
{
"iter": iter_count,
"intermediate": result,
"scratchpad": None,
}
)
iter_count += 1

fold_results = new_results

Expand Down Expand Up @@ -567,6 +607,10 @@ def _incremental_reduce(
num_folds = (len(group_list) + fold_batch_size - 1) // fold_batch_size

scratchpad = ""
if self.config.get("persist_intermediates", False):
self.intermediates[key] = []
iter_count = 0

for i in range(0, len(group_list), fold_batch_size):
# Log the current iteration and total number of folds
current_fold = i // fold_batch_size + 1
Expand All @@ -584,6 +628,16 @@ def _incremental_reduce(
if folded_output is None:
continue

if self.config.get("persist_intermediates", False):
self.intermediates[key].append(
{
"iter": iter_count,
"intermediate": folded_output,
"scratchpad": folded_output["updated_scratchpad"],
}
)
iter_count += 1

# Pop off updated_scratchpad
if "updated_scratchpad" in folded_output:
scratchpad = folded_output["updated_scratchpad"]
Expand Down
1 change: 1 addition & 0 deletions docs/operators/reduce.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ This Reduce operation processes customer feedback grouped by department:
| `fold_batch_size` | Number of items to process in each fold operation | None |
| `value_sampling` | A dictionary specifying the sampling strategy for large groups | None |
| `verbose` | If true, enables detailed logging of the reduce operation | false |
| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false |

## Advanced Features

Expand Down
68 changes: 68 additions & 0 deletions tests/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,71 @@ def test_reduce_operation_non_associative(default_model, max_threads):
assert combined_result.index("brave princess") < combined_result.index(
"dragon"
), "Princess should be mentioned before the dragon in the story"


def test_reduce_operation_persist_intermediates(default_model, max_threads):
    """Verify that `persist_intermediates: true` attaches per-iteration fold
    results to each reduce output under `_{operation_name}_intermediates`.

    Uses more items than `fold_batch_size` so the incremental-reduce path
    performs multiple folds, guaranteeing more than one intermediate entry.
    """
    # Config with persist_intermediates enabled; operation name determines
    # the special output key checked below.
    persist_intermediates_config = {
        "name": "persist_intermediates_reduce",
        "type": "reduce",
        "reduce_key": "group",
        "persist_intermediates": True,
        "prompt": "Summarize the numbers in '{{ inputs }}'.",
        "fold_prompt": "Combine summaries: Previous '{{ output }}', New '{{ inputs[0] }}'.",
        "fold_batch_size": 2,
        "output": {"schema": {"summary": "string"}},
    }

    # Five items with fold_batch_size=2 forces at least three fold iterations.
    sample_data = [
        {"group": "numbers", "value": "1, 2"},
        {"group": "numbers", "value": "3, 4"},
        {"group": "numbers", "value": "5, 6"},
        {"group": "numbers", "value": "7, 8"},
        {"group": "numbers", "value": "9, 10"},
    ]

    operation = ReduceOperation(
        persist_intermediates_config, default_model, max_threads
    )
    results, cost = operation.execute(sample_data)

    assert len(results) == 1, "Should have one result for the 'numbers' group"
    assert cost > 0, "Cost should be greater than 0"

    result = results[0]
    assert "summary" in result, "Result should have a 'summary' key"

    # The intermediates key is derived from the operation name:
    # "_{name}_intermediates" (no f-string needed — the name is fixed here).
    intermediates_key = "_persist_intermediates_reduce_intermediates"
    assert (
        intermediates_key in result
    ), f"Result should have '{intermediates_key}' key"
    intermediates = result[intermediates_key]
    assert isinstance(intermediates, list), "Intermediates should be a list"
    assert len(intermediates) > 1, "Should have multiple intermediate results"

    # Each persisted entry records the fold iteration, its output, and the
    # scratchpad state at that point.
    for intermediate in intermediates:
        assert "iter" in intermediate, "Each intermediate should have an 'iter' key"
        assert (
            "intermediate" in intermediate
        ), "Each intermediate should have an 'intermediate' key"
        assert (
            "scratchpad" in intermediate
        ), "Each intermediate should have a 'scratchpad' key"

    # Iteration counters must be strictly increasing, i.e. entries were
    # appended in fold order.
    for i in range(1, len(intermediates)):
        assert (
            intermediates[i]["iter"] > intermediates[i - 1]["iter"]
        ), "Intermediate results should be in ascending order of iterations"

0 comments on commit be8a8bb

Please sign in to comment.