Skip to content

Commit

Permalink
feat: add reduce operation lineage
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyashankar committed Oct 12, 2024
1 parent 72a33cf commit 6d6e0e4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
25 changes: 25 additions & 0 deletions docetl/operations/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs):
else self.config["reduce_key"]
)
self.intermediates = {}
self.lineage_keys = self.config.get("output", {}).get("lineage", [])

def syntax_check(self) -> None:
"""
Expand Down Expand Up @@ -258,6 +259,19 @@ def syntax_check(self) -> None:
f"'embedding_keys' is required when using embedding-based sampling in {self.config['name']}"
)

# Check if lineage is a list of strings
if "lineage" in self.config.get("output", {}):
if not isinstance(self.config["output"]["lineage"], list):
raise TypeError(
f"'lineage' in {self.config['name']} 'output' configuration must be a list"
)
if not all(
isinstance(key, str) for key in self.config["output"]["lineage"]
):
raise TypeError(
f"All elements in 'lineage' list in {self.config['name']} 'output' configuration must be strings"
)

self.gleaning_check()

def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]:
Expand Down Expand Up @@ -363,6 +377,17 @@ def process_group(
if k not in self.config["output"]["schema"] and k not in result:
result[k] = v

# Add lineage information
if result is not None and self.lineage_keys:
lineage = []
for item in group_elems:
lineage_item = {
k: item.get(k) for k in self.lineage_keys if k in item
}
if lineage_item:
lineage.append(lineage_item)
result[f"{self.config['name']}_lineage"] = lineage

return result, total_cost

with ThreadPoolExecutor(max_workers=self.max_threads) as executor:
Expand Down
21 changes: 21 additions & 0 deletions tests/basic/test_basic_reduce_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,24 @@ def test_resolve_operation_empty_input(resolve_config, max_threads, api_wrapper)

assert len(results) == 0
assert cost == 0


def test_reduce_operation_with_lineage(
reduce_config, max_threads, reduce_sample_data, api_wrapper
):
# Add lineage configuration to reduce_config
reduce_config["output"]["lineage"] = ["name", "email"]

operation = ReduceOperation(
api_wrapper, reduce_config, "text-embedding-3-small", max_threads
)
results, cost = operation.execute(reduce_sample_data)

# Check if lineage information is present in the results
assert all(f"{reduce_config['name']}_lineage" in result for result in results)

# Check if lineage contains the specified keys
for result in results:
lineage = result[f"{reduce_config['name']}_lineage"]
assert all(isinstance(item, dict) for item in lineage)
assert all("name" in item and "email" in item for item in lineage)

0 comments on commit 6d6e0e4

Please sign in to comment.