Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optionally persist intermediates for reduce #14

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ tests:

tests-basic:
poetry run pytest tests/test_basic.py
poetry run pytest tests/test_api.py

lint:
poetry run ruff check docetl/* --fix
Expand Down
54 changes: 54 additions & 0 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, *args, **kwargs):
if isinstance(self.config["reduce_key"], str)
else self.config["reduce_key"]
)
self.intermediates = {}

def syntax_check(self) -> None:
"""
Expand Down Expand Up @@ -351,6 +352,14 @@ def process_group(
if output is not None:
results.append(output)

if self.config.get("persist_intermediates", False):
for result in results:
key = tuple(result[k] for k in self.config["reduce_key"])
if key in self.intermediates:
result[f"_{self.config['name']}_intermediates"] = (
self.intermediates[key]
)

return results, total_cost

def _get_embeddings(
Expand Down Expand Up @@ -464,6 +473,10 @@ def calculate_num_parallel_folds():
fold_results = []
remaining_items = group_list

if self.config.get("persist_intermediates", False):
self.intermediates[key] = []
iter_count = 0

# Parallel folding and merging
with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
while remaining_items:
Expand All @@ -485,6 +498,15 @@ def calculate_num_parallel_folds():
total_cost += cost
if result is not None:
new_fold_results.append(result)
if self.config.get("persist_intermediates", False):
self.intermediates[key].append(
{
"iter": iter_count,
"intermediate": result,
"scratchpad": result["updated_scratchpad"],
}
)
iter_count += 1

# Update fold_results with new results
fold_results = new_fold_results + fold_results[len(new_fold_results) :]
Expand All @@ -507,6 +529,15 @@ def calculate_num_parallel_folds():
total_cost += cost
if result is not None:
new_results.append(result)
if self.config.get("persist_intermediates", False):
self.intermediates[key].append(
{
"iter": iter_count,
"intermediate": result,
"scratchpad": None,
}
)
iter_count += 1

fold_results = new_results

Expand Down Expand Up @@ -538,6 +569,15 @@ def calculate_num_parallel_folds():
total_cost += cost
if result is not None:
new_results.append(result)
if self.config.get("persist_intermediates", False):
self.intermediates[key].append(
{
"iter": iter_count,
"intermediate": result,
"scratchpad": None,
}
)
iter_count += 1

fold_results = new_results

Expand Down Expand Up @@ -567,6 +607,10 @@ def _incremental_reduce(
num_folds = (len(group_list) + fold_batch_size - 1) // fold_batch_size

scratchpad = ""
if self.config.get("persist_intermediates", False):
self.intermediates[key] = []
iter_count = 0

for i in range(0, len(group_list), fold_batch_size):
# Log the current iteration and total number of folds
current_fold = i // fold_batch_size + 1
Expand All @@ -584,6 +628,16 @@ def _incremental_reduce(
if folded_output is None:
continue

if self.config.get("persist_intermediates", False):
self.intermediates[key].append(
{
"iter": iter_count,
"intermediate": folded_output,
"scratchpad": folded_output["updated_scratchpad"],
}
)
iter_count += 1

# Pop off updated_scratchpad
if "updated_scratchpad" in folded_output:
scratchpad = folded_output["updated_scratchpad"]
Expand Down
1 change: 1 addition & 0 deletions docs/operators/reduce.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ This Reduce operation processes customer feedback grouped by department:
| `fold_batch_size` | Number of items to process in each fold operation | None |
| `value_sampling` | A dictionary specifying the sampling strategy for large groups | None |
| `verbose` | If true, enables detailed logging of the reduce operation | false |
| `persist_intermediates` | If true, persists the intermediate results for each group to the key `_{operation_name}_intermediates` | false |

## Advanced Features

Expand Down
68 changes: 68 additions & 0 deletions tests/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,71 @@ def test_reduce_operation_non_associative(default_model, max_threads):
assert combined_result.index("brave princess") < combined_result.index(
"dragon"
), "Princess should be mentioned before the dragon in the story"


def test_reduce_operation_persist_intermediates(default_model, max_threads):
    """Verify that `persist_intermediates: true` attaches the per-fold
    intermediate results to the group's output under the key
    `_{operation_name}_intermediates`, with each entry carrying the fold
    iteration number, the intermediate output, and the scratchpad state.
    """
    # Config with persist_intermediates enabled; fold_batch_size smaller than
    # the sample size forces multiple fold iterations, so intermediates exist.
    persist_intermediates_config = {
        "name": "persist_intermediates_reduce",
        "type": "reduce",
        "reduce_key": "group",
        "persist_intermediates": True,
        "prompt": "Summarize the numbers in '{{ inputs }}'.",
        "fold_prompt": "Combine summaries: Previous '{{ output }}', New '{{ inputs[0] }}'.",
        "fold_batch_size": 2,
        "output": {"schema": {"summary": "string"}},
    }

    # Sample data with more items than fold_batch_size
    sample_data = [
        {"group": "numbers", "value": "1, 2"},
        {"group": "numbers", "value": "3, 4"},
        {"group": "numbers", "value": "5, 6"},
        {"group": "numbers", "value": "7, 8"},
        {"group": "numbers", "value": "9, 10"},
    ]

    operation = ReduceOperation(
        persist_intermediates_config, default_model, max_threads
    )
    results, cost = operation.execute(sample_data)

    assert len(results) == 1, "Should have one result for the 'numbers' group"
    assert cost > 0, "Cost should be greater than 0"

    result = results[0]
    assert "summary" in result, "Result should have a 'summary' key"

    # The special key is derived from the operation's `name` config entry.
    intermediates_key = "_persist_intermediates_reduce_intermediates"
    assert (
        intermediates_key in result
    ), f"Result should have '{intermediates_key}' key"
    intermediates = result[intermediates_key]
    assert isinstance(intermediates, list), "Intermediates should be a list"
    assert len(intermediates) > 1, "Should have multiple intermediate results"

    # Check the structure of each persisted entry
    for intermediate in intermediates:
        assert "iter" in intermediate, "Each intermediate should have an 'iter' key"
        assert (
            "intermediate" in intermediate
        ), "Each intermediate should have an 'intermediate' key"
        assert (
            "scratchpad" in intermediate
        ), "Each intermediate should have a 'scratchpad' key"

    # Iteration counters must be strictly ascending (recorded in fold order).
    iter_counts = [entry["iter"] for entry in intermediates]
    assert iter_counts == sorted(iter_counts) and len(set(iter_counts)) == len(
        iter_counts
    ), "Intermediate results should be in ascending order of iterations"
Loading