From 173b99647a34832642c3034d48ce1bf861ad6e08 Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Fri, 20 Sep 2024 11:23:45 -0700 Subject: [PATCH 1/2] feat: add drop_keys parameter --- docetl/operations/map.py | 304 ++++++++++++++++++++++----------------- docs/operators/map.md | 27 ++-- tests/test_basic.py | 80 +++++++++++ 3 files changed, 269 insertions(+), 142 deletions(-) diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 4a37e257..676e60dc 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -1,5 +1,5 @@ """ -The `MapOperation` and `ParallelMapOperation` classes are subclasses of `BaseOperation` that perform mapping operations on input data. They use LLM-based processing to transform input items into output items based on specified prompts and schemas. +The `MapOperation` and `ParallelMapOperation` classes are subclasses of `BaseOperation` that perform mapping operations on input data. They use LLM-based processing to transform input items into output items based on specified prompts and schemas, and can also perform key dropping operations. """ from concurrent.futures import ThreadPoolExecutor @@ -25,66 +25,73 @@ def syntax_check(self) -> None: Checks the configuration of the MapOperation for required keys and valid structure. Raises: - ValueError: If required keys ('prompt' or 'output') are missing in the configuration. - ValueError: If 'schema' is missing in the 'output' configuration. - ValueError: If 'schema' in the 'output' configuration is empty. - ValueError: If the 'prompt' is not a valid Jinja2 template. - TypeError: If 'schema' in the 'output' configuration is not a dictionary. - TypeError: If 'model' is present in the configuration but is not a string. - ValueError: If any gleaning-related configuration is invalid (raised by self.gleaning_check()). + ValueError: If required keys are missing or invalid in the configuration. + TypeError: If configuration values have incorrect types. 
""" - required_keys = ["prompt", "output"] - for key in required_keys: - if key not in self.config: - raise ValueError( - f"Missing required key '{key}' in MapOperation configuration" + if "drop_keys" in self.config: + if not isinstance(self.config["drop_keys"], list): + raise TypeError( + "'drop_keys' in configuration must be a list of strings" ) + for key in self.config["drop_keys"]: + if not isinstance(key, str): + raise TypeError("All items in 'drop_keys' must be strings") - if "schema" not in self.config["output"]: - raise ValueError("Missing 'schema' in 'output' configuration") - - if not isinstance(self.config["output"]["schema"], dict): - raise TypeError("'schema' in 'output' configuration must be a dictionary") + if "prompt" in self.config or "output" in self.config: + required_keys = ["prompt", "output"] + for key in required_keys: + if key not in self.config: + raise ValueError( + f"Missing required key '{key}' in MapOperation configuration" + ) - if not self.config["output"]["schema"]: - raise ValueError("'schema' in 'output' configuration cannot be empty") + if "schema" not in self.config["output"]: + raise ValueError("Missing 'schema' in 'output' configuration") - # Check if the prompt is a valid Jinja2 template - try: - Template(self.config["prompt"]) - except Exception as e: - raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}") + if not isinstance(self.config["output"]["schema"], dict): + raise TypeError( + "'schema' in 'output' configuration must be a dictionary" + ) - # Check if the model is specified (optional) - if "model" in self.config and not isinstance(self.config["model"], str): - raise TypeError("'model' in configuration must be a string") + if not self.config["output"]["schema"]: + raise ValueError("'schema' in 'output' configuration cannot be empty") - # Check if tools are specified and validate their structure - if "tools" in self.config: - if not isinstance(self.config["tools"], list): - raise TypeError("'tools' in configuration must be a list") + # Check if the prompt is a valid Jinja2 template + try: + Template(self.config["prompt"]) + except Exception as e: + raise ValueError(f"Invalid Jinja2 template in 'prompt': {str(e)}") - for i, tool in enumerate(self.config["tools"]): - if not isinstance(tool, dict): - raise TypeError(f"Tool {i} in 'tools' must be a dictionary") + # Check if the model is specified (optional) + if "model" in self.config and not isinstance(self.config["model"], str): + raise TypeError("'model' in configuration must be a string") - if "code" not in tool or "function" not in tool: - raise ValueError( - f"Tool {i} is missing required 'code' or 'function' key" - ) + # Check if tools are specified and validate their structure + if "tools" in self.config: + if not isinstance(self.config["tools"], list): + raise TypeError("'tools' in configuration must be a list") - function = tool.get("function", {}) - if not isinstance(function, dict): - raise TypeError(f"'function' in tool {i} must be a dictionary") + for i, tool in enumerate(self.config["tools"]): + if not isinstance(tool, dict): + raise TypeError(f"Tool {i} in 'tools' must be a dictionary") - required_function_keys = ["name", "description", "parameters"] - for key in required_function_keys: - if key not in function: + if "code" not in tool or "function" not in tool: raise ValueError( - f"Tool {i} is missing required '{key}' in 'function'" + f"Tool {i} is missing required 'code' or 'function' key" ) - self.gleaning_check() + function = tool.get("function", {}) + if not 
isinstance(function, dict): + raise TypeError(f"'function' in tool {i} must be a dictionary") + + required_function_keys = ["name", "description", "parameters"] + for key in required_function_keys: + if key not in function: + raise ValueError( + f"Tool {i} is missing required '{key}' in 'function'" + ) + + self.gleaning_check() def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: """ @@ -97,15 +104,19 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation. This method performs the following steps: - 1. Processes each input item using the specified prompt and LLM model + 1. If a prompt is specified, it processes each input item using the specified prompt and LLM model 2. Applies gleaning if configured 3. Validates the output - 4. Aggregates results and calculates total cost + 4. If drop_keys is specified, it drops the specified keys from each document + 5. Aggregates results and calculates total cost The method uses parallel processing to improve performance. """ def _process_map_item(item: Dict) -> Tuple[Optional[Dict], float]: + if "prompt" not in self.config: + return item, 0.0 + prompt_template = Template(self.config["prompt"]) prompt = prompt_template.render(input=item) @@ -171,6 +182,12 @@ def validation_fn(response: Dict[str, Any]): for i in pbar: result, item_cost = futures[i].result() if result is not None: + if "drop_keys" in self.config: + result = { + k: v + for k, v in result.items() + if k not in self.config["drop_keys"] + } results.append(result) total_cost += item_cost pbar.update(i) @@ -204,67 +221,83 @@ def syntax_check(self) -> None: ValueError: If required keys are missing or if the configuration structure is invalid. TypeError: If the configuration values have incorrect types. 
""" - if "prompts" not in self.config or not isinstance(self.config["prompts"], list): - raise ValueError( - "ParallelMapOperation requires a 'prompts' list in the configuration" - ) + if "drop_keys" in self.config: + if not isinstance(self.config["drop_keys"], list): + raise TypeError( + "'drop_keys' in configuration must be a list of strings" + ) + for key in self.config["drop_keys"]: + if not isinstance(key, str): + raise TypeError("All items in 'drop_keys' must be strings") + + if "prompts" in self.config: + if not isinstance(self.config["prompts"], list): + raise ValueError( + "ParallelMapOperation requires a 'prompts' list in the configuration" + ) - if not self.config["prompts"]: - raise ValueError("The 'prompts' list cannot be empty") + if not self.config["prompts"]: + raise ValueError("The 'prompts' list cannot be empty") - for i, prompt_config in enumerate(self.config["prompts"]): - if not isinstance(prompt_config, dict): - raise TypeError(f"Prompt configuration {i} must be a dictionary") + for i, prompt_config in enumerate(self.config["prompts"]): + if not isinstance(prompt_config, dict): + raise TypeError(f"Prompt configuration {i} must be a dictionary") - required_keys = ["name", "prompt", "output_keys"] - for key in required_keys: - if key not in prompt_config: - raise ValueError( - f"Missing required key '{key}' in prompt configuration {i}" + required_keys = ["name", "prompt", "output_keys"] + for key in required_keys: + if key not in prompt_config: + raise ValueError( + f"Missing required key '{key}' in prompt configuration {i}" + ) + + if not isinstance(prompt_config["name"], str): + raise TypeError( + f"'name' in prompt configuration {i} must be a string" ) - if not isinstance(prompt_config["name"], str): - raise TypeError(f"'name' in prompt configuration {i} must be a string") + if not isinstance(prompt_config["prompt"], str): + raise TypeError( + f"'prompt' in prompt configuration {i} must be a string" + ) - if not isinstance(prompt_config["prompt"], str): - raise TypeError( - f"'prompt' in prompt configuration {i} must be a string" - ) + if not isinstance(prompt_config["output_keys"], list): + raise TypeError( + f"'output_keys' in prompt configuration {i} must be a list" + ) - if not isinstance(prompt_config["output_keys"], list): - raise TypeError( - f"'output_keys' in prompt configuration {i} must be a list" - ) + if not prompt_config["output_keys"]: + raise ValueError( + f"'output_keys' list in prompt configuration {i} cannot be empty" + ) - if not prompt_config["output_keys"]: - raise ValueError( - f"'output_keys' list in prompt configuration {i} cannot be empty" - ) + # Check if the prompt is a valid Jinja2 template + try: + Template(prompt_config["prompt"]) + except Exception as e: + raise ValueError( + f"Invalid Jinja2 template in prompt configuration {i}: {str(e)}" + ) - # Check if the prompt is a valid Jinja2 template - try: - Template(prompt_config["prompt"]) - except Exception as e: + # Check if the model is specified (optional) + if "model" in prompt_config and not isinstance( + prompt_config["model"], str + ): + raise TypeError( + f"'model' in prompt configuration {i} must be a string" + ) + + # Check if all output schema keys are covered by the prompts + output_schema = self.config["output"]["schema"] + output_keys_covered = set() + for prompt_config in self.config["prompts"]: + output_keys_covered.update(prompt_config["output_keys"]) + + missing_keys = set(output_schema.keys()) - output_keys_covered + if missing_keys: raise ValueError( - f"Invalid Jinja2 
template in prompt configuration {i}: {str(e)}" + f"The following output schema keys are not covered by any prompt: {missing_keys}" ) - # Check if the model is specified (optional) - if "model" in prompt_config and not isinstance(prompt_config["model"], str): - raise TypeError(f"'model' in prompt configuration {i} must be a string") - - # Check if all output schema keys are covered by the prompts - output_schema = self.config["output"]["schema"] - output_keys_covered = set() - for prompt_config in self.config["prompts"]: - output_keys_covered.update(prompt_config["output_keys"]) - - missing_keys = set(output_schema.keys()) - output_keys_covered - if missing_keys: - raise ValueError( - f"The following output schema keys are not covered by any prompt: {missing_keys}" - ) - def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: """ Executes the parallel map operation on the provided input data. @@ -276,14 +309,15 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: Tuple[List[Dict], float]: A tuple containing the processed results and the total cost of the operation. This method performs the following steps: - 1. Processes each input item using multiple prompts in parallel + 1. If prompts are specified, it processes each input item using multiple prompts in parallel 2. Aggregates results from different prompts for each input item 3. Validates the combined output for each item - 4. Calculates total cost of the operation + 4. If drop_keys is specified, it drops the specified keys from each document + 5. Calculates total cost of the operation """ results = {} total_cost = 0 - output_schema = self.config["output"]["schema"] + output_schema = self.config.get("output", {}).get("schema", {}) def process_prompt(item, prompt_config): prompt_template = Template(prompt_config["prompt"]) @@ -307,40 +341,50 @@ def process_prompt(item, prompt_config): return output, completion_cost(response) with ThreadPoolExecutor(max_workers=self.max_threads) as executor: - # Create all futures at once - all_futures = [ - executor.submit(process_prompt, item, prompt_config) - for item in input_data - for prompt_config in self.config["prompts"] - ] - - # Process results in order - pbar = RichLoopBar( - range(len(all_futures)), - desc="Processing parallel map items", - console=self.console, - ) - for i in pbar: - future = all_futures[i] - output, cost = future.result() - total_cost += cost + if "prompts" in self.config: + # Create all futures at once + all_futures = [ + executor.submit(process_prompt, item, prompt_config) + for item in input_data + for prompt_config in self.config["prompts"] + ] + + # Process results in order + pbar = RichLoopBar( + range(len(all_futures)), + desc="Processing parallel map items", + console=self.console, + ) + for i in pbar: + future = all_futures[i] + output, cost = future.result() + total_cost += cost - # Determine which item this future corresponds to - item_index = i // len(self.config["prompts"]) - prompt_index = i % len(self.config["prompts"]) + # Determine which item this future corresponds to + item_index = i // len(self.config["prompts"]) + prompt_index = i % len(self.config["prompts"]) - # Initialize or update the item_result - if prompt_index == 0: - item_result = input_data[item_index].copy() - results[item_index] = item_result + # Initialize or update the item_result + if prompt_index == 0: + item_result = input_data[item_index].copy() + results[item_index] = item_result - # Fetch the item_result - item_result = results[item_index] + # Fetch 
the item_result
+                    item_result = results[item_index]
 
-                # Update the item_result with the output
-                item_result.update(output)
+                    # Update the item_result with the output
+                    item_result.update(output)
 
-                pbar.update(i)
+                    pbar.update(i)
+            else:
+                results = {i: item.copy() for i, item in enumerate(input_data)}
+
+            # Apply drop_keys if specified
+            if "drop_keys" in self.config:
+                drop_keys = self.config["drop_keys"]
+                for item in results.values():
+                    for key in drop_keys:
+                        item.pop(key, None)
 
         # Return the results in order
         return [results[i] for i in range(len(input_data)) if i in results], total_cost
diff --git a/docs/operators/map.md b/docs/operators/map.md
index bb197dcb..90afcc4f 100644
--- a/docs/operators/map.md
+++ b/docs/operators/map.md
@@ -126,21 +126,24 @@ This example demonstrates how the Map operation can transform long, unstructured
 
 - `name`: A unique name for the operation.
 - `type`: Must be set to "map".
-- `prompt`: The prompt template to use for the transformation. Access input variables with `input.keyname`.
-- `output`: Schema definition for the output from the LLM.
 
 ## Optional Parameters
 
-| Parameter                         | Description                                                     | Default                       |
-| --------------------------------- | --------------------------------------------------------------- | ----------------------------- |
-| `model`                           | The language model to use                                       | Falls back to `default_model` |
-| `optimize`                        | Flag to enable operation optimization                           | `True`                        |
-| `recursively_optimize`            | Flag to enable recursive optimization                           | `false`                       |
-| `sample_size`                     | Number of samples to use for the operation                      | Processes all data            |
-| `tools`                           | List of tool definitions for LLM use                            | None                          |
-| `validate`                        | List of Python expressions to validate the output               | None                          |
-| `num_retries_on_validate_failure` | Number of retry attempts on validation failure                  | 0                             |
-| `gleaning`                        | Configuration for advanced validation and LLM-based refinement  | None                          |
+| Parameter                         | Description                                                                                       | Default                       |
+| --------------------------------- | ------------------------------------------------------------------------------------------------- | ----------------------------- |
+| `prompt`                          | The prompt template to use for the transformation. Access input variables with `input.keyname`.  | None                          |
+| `output`                          | Schema definition for the output from the LLM.                                                    | None                          |
+| `model`                           | The language model to use                                                                         | Falls back to `default_model` |
+| `optimize`                        | Flag to enable operation optimization                                                             | `True`                        |
+| `recursively_optimize`            | Flag to enable recursive optimization of operators synthesized as part of rewrite rules           | `false`                       |
+| `sample_size`                     | Number of samples to use for the operation                                                        | Processes all data            |
+| `tools`                           | List of tool definitions for LLM use                                                              | None                          |
+| `validate`                        | List of Python expressions to validate the output                                                 | None                          |
+| `num_retries_on_validate_failure` | Number of retry attempts on validation failure                                                    | 0                             |
+| `gleaning`                        | Configuration for advanced validation and LLM-based refinement                                    | None                          |
+| `drop_keys`                       | List of keys to drop from each document after processing                                          | None                          |
+
+Note: If `drop_keys` is specified, `prompt` and `output` become optional parameters.
 
 !!!
info "Validation and Gleaning" diff --git a/tests/test_basic.py b/tests/test_basic.py index 44c359d8..026a231b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -63,6 +63,86 @@ def test_map_operation_empty_input(map_config, default_model, max_threads): assert cost == 0 +@pytest.fixture +def map_config_with_drop_keys(): + return { + "name": "sentiment_analysis_with_drop", + "type": "map", + "prompt": "Analyze the sentiment of the following text: '{{ input.text }}'. Classify it as either positive, negative, or neutral.", + "output": {"schema": {"sentiment": "string"}}, + "model": "gpt-4o-mini", + "drop_keys": ["original_sentiment"], + } + + +@pytest.fixture +def map_config_with_drop_keys_no_prompt(): + return { + "name": "drop_keys_only", + "type": "map", + "drop_keys": ["to_be_dropped"], + "model": "gpt-4o-mini", + } + + +@pytest.fixture +def map_sample_data_with_extra_keys(): + return [ + { + "text": "This is a positive sentence.", + "original_sentiment": "positive", + "to_be_dropped": "extra", + }, + { + "text": "This is a negative sentence.", + "original_sentiment": "negative", + "to_be_dropped": "extra", + }, + { + "text": "This is a neutral sentence.", + "original_sentiment": "neutral", + "to_be_dropped": "extra", + }, + ] + + +def test_map_operation_with_drop_keys( + map_config_with_drop_keys, + default_model, + max_threads, + map_sample_data_with_extra_keys, +): + operation = MapOperation(map_config_with_drop_keys, default_model, max_threads) + results, cost = operation.execute(map_sample_data_with_extra_keys) + + assert len(results) == len(map_sample_data_with_extra_keys) + assert all("sentiment" in result for result in results) + assert all("original_sentiment" not in result for result in results) + assert all("to_be_dropped" in result for result in results) + assert all( + result["sentiment"] in ["positive", "negative", "neutral"] for result in results + ) + assert cost > 0 + + +def test_map_operation_with_drop_keys_no_prompt( + map_config_with_drop_keys_no_prompt, + default_model, + max_threads, + map_sample_data_with_extra_keys, +): + operation = MapOperation( + map_config_with_drop_keys_no_prompt, default_model, max_threads + ) + results, cost = operation.execute(map_sample_data_with_extra_keys) + + assert len(results) == len(map_sample_data_with_extra_keys) + assert all("to_be_dropped" not in result for result in results) + assert all("text" in result for result in results) + assert all("original_sentiment" in result for result in results) + assert cost == 0 # No LLM calls should be made + + # Parallel Map Operation Tests @pytest.fixture def parallel_map_config(): From 9ea55c27448bd9f9845a65bb45a223d3ff2cd08c Mon Sep 17 00:00:00 2001 From: Shreya Shankar Date: Fri, 20 Sep 2024 11:31:02 -0700 Subject: [PATCH 2/2] feat: add drop_keys parameter --- docetl/operations/map.py | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 676e60dc..a4afd7f8 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -36,6 +36,11 @@ def syntax_check(self) -> None: for key in self.config["drop_keys"]: if not isinstance(key, str): raise TypeError("All items in 'drop_keys' must be strings") + else: + if "prompt" not in self.config or "output" not in self.config: + raise ValueError( + "If 'drop_keys' is not specified, both 'prompt' and 'output' must be present in the configuration" + ) if "prompt" in self.config or "output" in self.config: required_keys = ["prompt", 
"output"] @@ -112,11 +117,18 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: The method uses parallel processing to improve performance. """ + # Check if there's no prompt and only drop_keys + if "prompt" not in self.config and "drop_keys" in self.config: + # If only drop_keys is specified, simply drop the keys and return + dropped_results = [] + for item in input_data: + new_item = { + k: v for k, v in item.items() if k not in self.config["drop_keys"] + } + dropped_results.append(new_item) + return dropped_results, 0.0 # Return the modified data with no cost def _process_map_item(item: Dict) -> Tuple[Optional[Dict], float]: - if "prompt" not in self.config: - return item, 0.0 - prompt_template = Template(self.config["prompt"]) prompt = prompt_template.render(input=item) @@ -229,6 +241,11 @@ def syntax_check(self) -> None: for key in self.config["drop_keys"]: if not isinstance(key, str): raise TypeError("All items in 'drop_keys' must be strings") + else: + if "prompts" not in self.config: + raise ValueError( + "If 'drop_keys' is not specified, 'prompts' must be present in the configuration" + ) if "prompts" in self.config: if not isinstance(self.config["prompts"], list): @@ -319,6 +336,17 @@ def execute(self, input_data: List[Dict]) -> Tuple[List[Dict], float]: total_cost = 0 output_schema = self.config.get("output", {}).get("schema", {}) + # Check if there's no prompt and only drop_keys + if "prompts" not in self.config and "drop_keys" in self.config: + # If only drop_keys is specified, simply drop the keys and return + dropped_results = [] + for item in input_data: + new_item = { + k: v for k, v in item.items() if k not in self.config["drop_keys"] + } + dropped_results.append(new_item) + return dropped_results, 0.0 # Return the modified data with no cost + def process_prompt(item, prompt_config): prompt_template = Template(prompt_config["prompt"]) prompt = prompt_template.render(input=item)