diff --git a/docetl/dataset.py b/docetl/dataset.py
index 9d53753c..c90fda4e 100644
--- a/docetl/dataset.py
+++ b/docetl/dataset.py
@@ -146,17 +146,12 @@ def _validate_parsing(
             return []
 
         for tool in parsing_tools:
-            if (
-                not isinstance(tool, dict)
-                or "function" not in tool
-            ):
+            if not isinstance(tool, dict) or "function" not in tool:
                 raise ValueError(
                     "Each parsing tool must be a dictionary with a 'function' key and any arguments required by that function"
                 )
             if not isinstance(tool["function"], str):
-                raise ValueError(
-                    "'function' in parsing tools must be a string"
-                )
+                raise ValueError("'function' in parsing tools must be a string")
             if "function_kwargs" in tool and not isinstance(
                 tool["function_kwargs"], dict
             ):
@@ -212,7 +207,7 @@ def _process_item(
     ):
         result = func(item, **function_kwargs)
         return [item.copy() | res for res in result]
-
+
     def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
         """
         Apply parsing tools to the data.
@@ -233,7 +228,7 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
             # with the existing yaml format...
             if "function_kwargs" in function_kwargs:
                 function_kwargs.update(function_kwargs.pop("function_kwargs"))
-
+
             try:
                 func = get_parser(tool["function"])
             except KeyError:
@@ -243,7 +238,8 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
                 ):
                     # Define the custom function in the current scope
                     exec(
-                        self.user_defined_parsing_tool_map[
+                        "from typing import List, Dict\n"
+                        + self.user_defined_parsing_tool_map[
                             tool["function"]
                         ].function_code
                     )
diff --git a/docetl/operations/utils.py b/docetl/operations/utils.py
index f7313690..a5807168 100644
--- a/docetl/operations/utils.py
+++ b/docetl/operations/utils.py
@@ -470,7 +470,9 @@ def _cached_call_llm(
         if gleaning_config:
             # Retry gleaning prompt + regular LLM
             num_gleaning_rounds = gleaning_config.get("num_rounds", 2)
-            validator_prompt_template = Template(gleaning_config["prompt"])
+            validator_prompt_template = Template(
+                gleaning_config["validation_prompt"]
+            )
 
             parsed_output = self.parse_llm_response(
                 response, output_schema, tools
@@ -484,9 +486,7 @@ def _cached_call_llm(
                     }
                 ]
                 + messages
-                + [
-                    {"role": "assistant", "content": json.dumps(parsed_output)},
-                ]
+                + [{"role": "assistant", "content": json.dumps(parsed_output)}]
             )
 
             for rnd in range(num_gleaning_rounds):
@@ -551,9 +551,10 @@ def _cached_call_llm(
                 parsed_output = self.parse_llm_response(
                     response, output_schema, tools
                 )[0]
-                validator_messages[-1] = [
-                    {"role": "assistant", "content": json.dumps(parsed_output)},
-                ]
+                validator_messages[-1] = {
+                    "role": "assistant",
+                    "content": json.dumps(parsed_output),
+                }
 
                 total_cost += completion_cost(response)
 
diff --git a/tests/basic/test_basic_map.py b/tests/basic/test_basic_map.py
index 21a21d95..357a90b0 100644
--- a/tests/basic/test_basic_map.py
+++ b/tests/basic/test_basic_map.py
@@ -191,9 +191,10 @@ def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wra
     map_config_with_gleaning = {
         **simple_map_config,
         "gleaning": {
-            "num_rounds": 1,
+            "num_rounds": 2,
             "validation_prompt": "Review the sentiment analysis. Is it accurate? If not, suggest improvements.",
         },
+        "bypass_cache": True,
     }
 
     operation = MapOperation(api_wrapper, map_config_with_gleaning, "gpt-4o-mini", 4)
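
Notes on this patch (illustrative sketches, not part of the diff):

The gleaning fix in docetl/operations/utils.py reads the validator prompt from gleaning_config["validation_prompt"]; the old lookup of a "prompt" key raised KeyError, since gleaning configs (like the one in the updated test) define the prompt under "validation_prompt". A minimal sketch of the config shape this assumes, taking Template to be jinja2's as used elsewhere in docetl:

    from jinja2 import Template

    gleaning_config = {
        "num_rounds": 2,
        "validation_prompt": (
            "Review the sentiment analysis. Is it accurate? "
            "If not, suggest improvements."
        ),
    }

    num_gleaning_rounds = gleaning_config.get("num_rounds", 2)  # default: 2
    validator_prompt_template = Template(gleaning_config["validation_prompt"])
    print(validator_prompt_template.render())

The test change bumps num_rounds to 2 and adds "bypass_cache": True, presumably so repeated runs hit the LLM instead of a cached response and actually exercise the multi-round gleaning loop.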
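
The exec change in docetl/dataset.py prepends a typing import to user-defined parsing-tool source before exec'ing it. User parser code is typically annotated with List[Dict], and on Python versions that evaluate annotations eagerly (no PEP 563/649 deferral) the def statement itself raises NameError if those names are absent from the exec namespace. A standalone sketch of the failure and the fix (my_parser and the scope dict are hypothetical, not docetl API):

    # Stand-in for a user-defined parsing tool's function_code.
    parser_source = (
        "def my_parser(item) -> List[Dict]:\n"
        "    return [{'text': item['path'].upper()}]\n"
    )

    scope = {}
    try:
        exec(parser_source, scope)  # return annotation evaluated at def time
    except NameError as e:
        print("without import:", e)  # name 'List' is not defined

    # Patched behavior: prepend the import so the annotations resolve.
    exec("from typing import List, Dict\n" + parser_source, scope)
    print(scope["my_parser"]({"path": "a.txt"}))  # [{'text': 'A.TXT'}]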
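
The validator_messages fix corrects a message-shape bug: the old code assigned a one-element list as the final message, nesting a list inside the messages list, so the payload was no longer the flat sequence of role/content dicts that chat-completion APIs expect. A sketch of the difference (the surrounding messages are made up for illustration):

    import json

    parsed_output = {"sentiment": "positive"}
    validator_messages = [
        {"role": "system", "content": "You validate sentiment labels."},
        {"role": "user", "content": "Label: 'great product!'"},
        {"role": "assistant", "content": "{}"},  # placeholder to overwrite
    ]

    # Buggy: the last element becomes a nested list, not a message dict.
    validator_messages[-1] = [
        {"role": "assistant", "content": json.dumps(parsed_output)},
    ]
    assert isinstance(validator_messages[-1], list)  # malformed payload

    # Fixed: replace the last message with a single dict.
    validator_messages[-1] = {
        "role": "assistant",
        "content": json.dumps(parsed_output),
    }
    assert all(isinstance(m, dict) for m in validator_messages)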