Skip to content

Commit

Permalink
Merge pull request #102 from ucbepic/lineage
Browse files (browse the repository at this point in the history)
fix: change gleaning prompt to validation_prompt
  • Loading branch information
shreyashankar authored Oct 13, 2024
2 parents: d279bb8 + 7b1e04d; commit f65543e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
16 changes: 6 additions & 10 deletions docetl/dataset.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -146,17 +146,12 @@ def _validate_parsing(
return []

for tool in parsing_tools:
if (
not isinstance(tool, dict)
or "function" not in tool
):
if not isinstance(tool, dict) or "function" not in tool:
raise ValueError(
"Each parsing tool must be a dictionary with a 'function' key and any arguments required by that function"
)
if not isinstance(tool["function"], str):
raise ValueError(
"'function' in parsing tools must be a string"
)
raise ValueError("'function' in parsing tools must be a string")
if "function_kwargs" in tool and not isinstance(
tool["function_kwargs"], dict
):
Expand Down Expand Up @@ -212,7 +207,7 @@ def _process_item(
):
result = func(item, **function_kwargs)
return [item.copy() | res for res in result]

def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
"""
Apply parsing tools to the data.
Expand All @@ -233,7 +228,7 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
# with the existing yaml format...
if "function_kwargs" in function_kwargs:
function_kwargs.update(function_kwargs.pop("function_kwargs"))

try:
func = get_parser(tool["function"])
except KeyError:
Expand All @@ -243,7 +238,8 @@ def _apply_parsing_tools(self, data: List[Dict]) -> List[Dict]:
):
# Define the custom function in the current scope
exec(
self.user_defined_parsing_tool_map[
"from typing import List, Dict\n"
+ self.user_defined_parsing_tool_map[
tool["function"]
].function_code
)
Expand Down
15 changes: 8 additions & 7 deletions docetl/operations/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -470,7 +470,9 @@ def _cached_call_llm(
if gleaning_config:
# Retry gleaning prompt + regular LLM
num_gleaning_rounds = gleaning_config.get("num_rounds", 2)
validator_prompt_template = Template(gleaning_config["prompt"])
validator_prompt_template = Template(
gleaning_config["validation_prompt"]
)

parsed_output = self.parse_llm_response(
response, output_schema, tools
Expand All @@ -484,9 +486,7 @@ def _cached_call_llm(
}
]
+ messages
+ [
{"role": "assistant", "content": json.dumps(parsed_output)},
]
+ [{"role": "assistant", "content": json.dumps(parsed_output)}]
)

for rnd in range(num_gleaning_rounds):
Expand Down Expand Up @@ -551,9 +551,10 @@ def _cached_call_llm(
parsed_output = self.parse_llm_response(
response, output_schema, tools
)[0]
validator_messages[-1] = [
{"role": "assistant", "content": json.dumps(parsed_output)},
]
validator_messages[-1] = {
"role": "assistant",
"content": json.dumps(parsed_output),
}

total_cost += completion_cost(response)

Expand Down
3 changes: 2 additions & 1 deletion tests/basic/test_basic_map.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -191,9 +191,10 @@ def test_map_operation_with_gleaning(simple_map_config, map_sample_data, api_wra
map_config_with_gleaning = {
**simple_map_config,
"gleaning": {
"num_rounds": 1,
"num_rounds": 2,
"validation_prompt": "Review the sentiment analysis. Is it accurate? If not, suggest improvements.",
},
"bypass_cache": True,
}

operation = MapOperation(api_wrapper, map_config_with_gleaning, "gpt-4o-mini", 4)
Expand Down

0 comments on commit f65543e

Please sign in to comment.