From 6d6e0e4bfc2a973261f09434714607f0a5b75f63 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Sat, 12 Oct 2024 18:06:36 -0400 Subject: [PATCH] feat: add reduce operation lineage --- docetl/operations/reduce.py | 25 ++++++++++++++++++++++++ tests/basic/test_basic_reduce_resolve.py | 21 ++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index 682b5d70..cd3cee78 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -55,6 +55,7 @@ def __init__(self, *args, **kwargs): else self.config["reduce_key"] ) self.intermediates = {} + self.lineage_keys = self.config.get("output", {}).get("lineage", []) def syntax_check(self) -> None: """ @@ -258,6 +259,19 @@ def syntax_check(self) -> None: f"'embedding_keys' is required when using embedding-based sampling in {self.config['name']}" ) + # Check if lineage is a list of strings + if "lineage" in self.config.get("output", {}): + if not isinstance(self.config["output"]["lineage"], list): + raise TypeError( + f"'lineage' in {self.config['name']} 'output' configuration must be a list" + ) + if not all( + isinstance(key, str) for key in self.config["output"]["lineage"] + ): + raise TypeError( + f"All elements in 'lineage' list in {self.config['name']} 'output' configuration must be strings" + ) + self.gleaning_check() def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: @@ -363,6 +377,17 @@ def process_group( if k not in self.config["output"]["schema"] and k not in result: result[k] = v + # Add lineage information + if result is not None and self.lineage_keys: + lineage = [] + for item in group_elems: + lineage_item = { + k: item.get(k) for k in self.lineage_keys if k in item + } + if lineage_item: + lineage.append(lineage_item) + result[f"{self.config['name']}_lineage"] = lineage + return result, total_cost with ThreadPoolExecutor(max_workers=self.max_threads) as executor: diff --git 
a/tests/basic/test_basic_reduce_resolve.py b/tests/basic/test_basic_reduce_resolve.py
index a0a93c2f..5326b5ab 100644
--- a/tests/basic/test_basic_reduce_resolve.py
+++ b/tests/basic/test_basic_reduce_resolve.py
@@ -143,3 +143,24 @@ def test_resolve_operation_empty_input(resolve_config, max_threads, api_wrapper)
 
     assert len(results) == 0
     assert cost == 0
+
+
+def test_reduce_operation_with_lineage(
+    reduce_config, max_threads, reduce_sample_data, api_wrapper
+):
+    # Add lineage configuration to reduce_config
+    reduce_config["output"]["lineage"] = ["name", "email"]
+
+    operation = ReduceOperation(
+        api_wrapper, reduce_config, "text-embedding-3-small", max_threads
+    )
+    results, cost = operation.execute(reduce_sample_data)
+
+    # Check if lineage information is present in the results
+    assert all(f"{reduce_config['name']}_lineage" in result for result in results)
+
+    # Check if lineage contains the specified keys
+    for result in results:
+        lineage = result[f"{reduce_config['name']}_lineage"]
+        assert all(isinstance(item, dict) for item in lineage)
+        assert all("name" in item and "email" in item for item in lineage)