From a50e3ef9da4e73e916e71294649f42038b9df47b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Albert=20=C3=96rwall?= Date: Tue, 6 Aug 2024 09:21:54 +0200 Subject: [PATCH] Refactor state (#31) * Refactor and add more tests * Adjust report_v2 to new trajectory format --- .gitignore | 4 +- moatless/benchmark/claude_evaluation.py | 6 +- moatless/benchmark/evaluation.py | 60 +- moatless/benchmark/report_v2.py | 324 ++++---- moatless/edit/clarify.py | 18 +- moatless/edit/edit.py | 60 +- moatless/edit/plan.py | 75 +- moatless/edit/plan_lines.py | 33 +- moatless/edit/review.py | 37 +- moatless/find/decide.py | 57 +- moatless/find/identify.py | 22 +- moatless/find/search.py | 75 +- moatless/loop.py | 762 ++++++++----------- moatless/repository/git.py | 36 +- moatless/state.py | 167 ++-- moatless/trajectory.py | 213 ++++-- moatless/transition_rules.py | 59 +- moatless/transitions.py | 4 +- moatless/types.py | 65 +- moatless/utils/llm_utils.py | 18 + moatless/workspace.py | 6 +- notebooks/swebench/01_evaluate_search.ipynb | 111 +-- poetry.lock | 86 ++- pyproject.toml | 1 + tests/benchmark/test_evaluation.py | 82 ++ tests/benchmark/test_report_v2.py | 39 + tests/edit/test_clarify.py | 236 +++--- tests/edit/test_edit.py | 173 +++-- tests/edit/test_plan.py | 124 ++- tests/find/test_decide.py | 115 +++ tests/find/test_identify.py | 118 ++- tests/find/test_search.py | 68 ++ tests/loop/test_loop.py | 69 -- tests/test_loop.py | 137 ++++ tests/test_state.py | 215 +++--- tests/test_trajectory.py | 52 ++ tests/test_transition_rules.py | 101 +-- tests/trajectories/django__django_16379.json | 512 +++++++------ 38 files changed, 2552 insertions(+), 1788 deletions(-) create mode 100644 moatless/utils/llm_utils.py create mode 100644 tests/benchmark/test_evaluation.py create mode 100644 tests/benchmark/test_report_v2.py create mode 100644 tests/find/test_decide.py create mode 100644 tests/find/test_search.py delete mode 100644 tests/loop/test_loop.py create mode 100644 tests/test_loop.py create mode 100644 tests/test_trajectory.py diff --git a/.gitignore b/.gitignore index 99142214..d2803afb 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,6 @@ cython_debug/ notebooks/.ipynb_checkpoints/ notebooks/local_experiments.ipynb -playground \ No newline at end of file +playground +logs +Pipfile diff --git a/moatless/benchmark/claude_evaluation.py b/moatless/benchmark/claude_evaluation.py index 45c99b1c..6b44915f 100644 --- a/moatless/benchmark/claude_evaluation.py +++ b/moatless/benchmark/claude_evaluation.py @@ -4,7 +4,7 @@ import instructor -from moatless import Transitions +from moatless.transition_rules import TransitionRules from moatless.benchmark.evaluation import create_evaluation_name, Evaluation from moatless.edit.edit import EditCode from moatless.edit.plan import PlanToCode @@ -170,7 +170,7 @@ def run_evaluation(): def evaluate_search(): - transitions = Transitions( + transitions = TransitionRules( global_params=global_params, state_params={ SearchCode: {"max_search_results": 50, "provide_initial_context": True}, @@ -280,7 +280,7 @@ def evaluate_coding(): def evaluate_plan(previous_trajectory_dir: Optional[str] = None): - transitions = Transitions( + transitions = TransitionRules( global_params=global_params, state_params={ SearchCode: { diff --git a/moatless/benchmark/evaluation.py b/moatless/benchmark/evaluation.py index 78a73a18..1d0b9980 100644 --- a/moatless/benchmark/evaluation.py +++ b/moatless/benchmark/evaluation.py @@ -7,7 +7,7 @@ import traceback from collections import defaultdict from datetime import datetime, timezone -from typing import Optional +from typing import Optional, Tuple import instructor import litellm @@ -15,6 +15,7 @@ from tqdm.auto import tqdm from moatless.benchmark.report_v2 import to_result, generate_md_report +from moatless.trajectory import Trajectory from moatless.transition_rules import TransitionRules from moatless.benchmark.swebench import ( found_in_alternative_spans, @@ -82,6 +83,7 @@ def __init__( max_transitions: int = 25, max_expansions: int = 2, max_file_context_tokens: int = 16000, + markdown_report: bool = False, litellm_callback: Optional[str] = None, previous_trajectory_dir: Optional[str] = None, retry_state: Optional[str] = None, @@ -93,6 +95,7 @@ def __init__( self.evaluations_dir = evaluations_dir self.num_workers = num_workers self.detailed_report = detailed_report + self.markdown_report = markdown_report self.evaluation_name = evaluation_name self.max_file_context_tokens = max_file_context_tokens @@ -193,11 +196,12 @@ def run_single_instance( instance_id: str, dataset: str = "princeton-nlp/SWE-bench_Lite", split="test", - ): + ) -> dict: instance = load_instance(instance_id, dataset, split) - return self._evaluate_instance(instance) + trajectory = self._evaluate_instance(instance) + return to_result(instance, trajectory, self.report) - def _evaluate_instance(self, instance: dict, retry: bool = False) -> dict: + def _evaluate_instance(self, instance: dict, retry: bool = False) -> Trajectory: instance_id = instance["instance_id"] trajectory_path = os.path.join(self.trajectory_dir, f"{instance_id}.json") prompt_log_dir = os.path.join(self.logs_dir, f"{instance_id}") @@ -205,10 +209,8 @@ def _evaluate_instance(self, instance: dict, retry: bool = False) -> dict: os.makedirs(prompt_log_dir) if os.path.exists(trajectory_path) and not retry: - with open(trajectory_path) as file: - trajectory = json.load(file) - if trajectory["info"].get("status") or trajectory["info"].get("error"): - return trajectory + # TODO: Retry when failed or not finished? + return Trajectory.load(trajectory_path) repo_dir = setup_swebench_repo(instance) persist_dir = os.path.join(self.index_store_dir, get_repo_dir_name(instance_id)) @@ -284,31 +286,30 @@ def _evaluate_instance(self, instance: dict, retry: bool = False) -> dict: info["submission"] = diff loop.trajectory.save_info(info) - return loop.trajectory.to_dict() + return loop.trajectory - def _process_instance(self, instance): + def _process_instance(self, instance) -> Tuple[dict, str]: trajectory = self._evaluate_instance(instance) - if not trajectory: - return None, None, None - result, transition_result = to_result(instance, trajectory, self.report) - submission = trajectory["info"].get("submission", "") + result = to_result(instance, trajectory, self.report) + submission = trajectory.info.get("submission", "") - try: - md_report = generate_md_report(trajectory, instance) - if not os.path.exists(f"{self.evaluation_dir}/reports"): - os.makedirs(f"{self.evaluation_dir}/reports") - with open( - f"{self.evaluation_dir}/reports/{instance['instance_id']}.md", - "w", - ) as file: - file.write(md_report) - except Exception: - logging.exception( - f"Error in generating report for {instance['instance_id']} " - ) + if self.markdown_report: + try: + md_report = generate_md_report(trajectory, instance) + if not os.path.exists(f"{self.evaluation_dir}/reports"): + os.makedirs(f"{self.evaluation_dir}/reports") + with open( + f"{self.evaluation_dir}/reports/{instance['instance_id']}.md", + "w", + ) as file: + file.write(md_report) + except Exception: + logging.exception( + f"Error in generating report for {instance['instance_id']} " + ) - return result, transition_result, submission + return result, submission def _process_repo_group(self, repo, instances): results = [] @@ -322,9 +323,8 @@ def _process_repo_group(self, repo, instances): if not trajectory: return None, None - result, transition_result = to_result(instance, trajectory, report=self.report) + result = to_result(instance, trajectory, report=self.report) results.append(result) - transition_results.extend(transition_result) try: md_report = generate_md_report(trajectory, instance) diff --git a/moatless/benchmark/report_v2.py b/moatless/benchmark/report_v2.py index a98ce537..b8430cce 100644 --- a/moatless/benchmark/report_v2.py +++ b/moatless/benchmark/report_v2.py @@ -3,13 +3,37 @@ from moatless import FileRepository from moatless.benchmark.swebench import found_in_expected_spans, found_in_alternative_spans, setup_swebench_repo from moatless.benchmark.utils import get_missing_files +from moatless.edit.plan import ApplyChange from moatless.file_context import FileContext +from moatless.find.search import SearchRequest logger = logging.getLogger(__name__) +import logging + +from moatless import FileRepository +from moatless.benchmark.swebench import found_in_expected_spans, found_in_alternative_spans, setup_swebench_repo +from moatless.benchmark.utils import get_missing_files +from moatless.file_context import FileContext + +logger = logging.getLogger(__name__) + +import logging +from typing import Dict, List, Tuple, Optional + +from moatless import FileRepository +from moatless.benchmark.swebench import found_in_expected_spans, found_in_alternative_spans, setup_swebench_repo +from moatless.benchmark.utils import get_missing_files +from moatless.file_context import FileContext +from moatless.trajectory import Trajectory +from moatless.types import ActionTransaction, Usage, Content +from moatless.state import AgenticState + +logger = logging.getLogger(__name__) -def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[dict, list]: - info = trajectory["info"] + +def to_result(instance: Dict, trajectory: Trajectory, report: Optional[Dict] = None) -> Dict: + info = trajectory._info if report and "resolved_ids" in report and instance["instance_id"] in report["resolved_ids"]: result_status = "resolved" @@ -19,7 +43,6 @@ def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[di resolved = result_status == "resolved" try: - transitions = [] result = { "instance_id": instance["instance_id"], "duration": info.get("duration", 0), @@ -27,7 +50,7 @@ def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[di "resolved_by": (len(instance.get("resolved_by", []))), "status": None, "result_status": result_status, - "transitions": len(trajectory["transitions"]), + "transitions": len(trajectory.transitions), "edited": False, "planned": False, "identified": None, @@ -49,34 +72,29 @@ def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[di } lint_codes = set() - search_results_spans = {} - identified_spans = {} - planned_spans = {} - edited_spans = {} + search_results_spans: Dict[str, List[str]] = {} + identified_spans: Dict[str, List[str]] = {} + planned_spans: Dict[str, List[str]] = {} + edited_spans: Dict[str, List[str]] = {} id_iterations = 0 search_iterations = 0 selected_transition_ids = [] - if "current_transition_id" in trajectory: - transitions_map = {t["id"]: t for t in trajectory["transitions"]} - - transition = transitions_map.get(trajectory["current_transition_id"]) - while transition: - selected_transition_ids.append(transition["id"]) - if "parent_id" in transition: - transition = transitions_map.get(transition["parent_id"]) - else: - break + current_state = trajectory.get_current_state() + while current_state: + selected_transition_ids.append(current_state.id) + current_state = current_state.previous_state logger.info(f"Selected transitions: {selected_transition_ids}") if instance.get("expected_spans"): - for transition in trajectory["transitions"]: - if selected_transition_ids and transition["id"] not in selected_transition_ids: + for transition in trajectory.transitions: + if selected_transition_ids and transition.id not in selected_transition_ids: continue - state_name = transition["state"]["name"] + state: AgenticState = transition.state + state_name = state.name if state_name not in result: result[state_name] = 0 @@ -88,76 +106,42 @@ def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[di for file_path, span_ids in instance["expected_spans"].items(): expected_span_str += f"{file_path}: {span_ids} " - transition_result = { - "instance_id": instance["instance_id"], - "resolved": resolved, - "name": state_name, - "cost": 0, - "expected_spans": expected_span_str, - "actual_spans": "", - } - - if not transition["actions"]: + if not state._actions: continue - for traj_action in transition["actions"]: - result[f"{state_name}_cost"] += traj_action.get( - "completion_cost", 0 - ) - transition_result["cost"] += traj_action.get( - "completion_cost", 0 - ) + for action in state._actions: + result[f"{state_name}_cost"] += action.usage.completion_cost if action.usage else 0 if state_name == "SearchCode": search_iterations += 1 - action = transition["actions"][-1] + action = state._actions[-1] - if "search_requests" in action["action"]: - for search_request in action["action"]["search_requests"]: - if search_request.get("query"): + if isinstance(action.request, SearchRequest): + for search_request in action.request.search_requests: + if search_request.query: result["p_query"] += 1 - - if search_request.get("file_pattern"): + if search_request.file_pattern: result["p_file"] += 1 - - if search_request.get("code_snippet"): + if search_request.code_snippet: result["p_code"] += 1 - - if search_request.get( - "class_name" - ) or search_request.get("class_names"): + if search_request.class_name or search_request.class_names: result["p_class"] += 1 - - if search_request.get( - "function_name" - ) or search_request.get("function_names"): + if search_request.function_name or search_request.function_names: result["p_function"] += 1 if state_name == "IdentifyCode": id_iterations += 1 - state = transition["state"] - if state.get("ranked_spans"): - for ranked_span in state["ranked_spans"]: - if ( - ranked_span["file_path"] - not in search_results_spans - ): - search_results_spans[ - ranked_span["file_path"] - ] = [] - search_results_spans[ - ranked_span["file_path"] - ].append(ranked_span["span_id"]) + if state.ranked_spans: + for ranked_span in state.ranked_spans: + if ranked_span.file_path not in search_results_spans: + search_results_spans[ranked_span.file_path] = [] + search_results_spans[ranked_span.file_path].append(ranked_span.span_id) if not result["found_in_search"] and ( - found_in_expected_spans( - instance, search_results_spans - ) - or found_in_alternative_spans( - instance, search_results_spans - ) + found_in_expected_spans(instance, search_results_spans) + or found_in_alternative_spans(instance, search_results_spans) ): result["found_in_search"] = search_iterations @@ -169,24 +153,17 @@ def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[di if not missing_files: result["file_in_search"] = search_iterations - action = transition["actions"][-1] - if action.get("action"): + if state._actions: + action = state._actions[-1] identified_str = "" - if action["action"].get("identified_spans"): - for span in action["action"]["identified_spans"]: - identified_str += ( - f"{span['file_path']}: {span['span_ids']} " - ) - if span["file_path"] not in identified_spans: - identified_spans[span["file_path"]] = [] - - transition_result["actual_spans"] += ( - f"{span['file_path']}: {','.join(span['span_ids'])} " - ) - for span_id in span["span_ids"]: - identified_spans[span["file_path"]].append( - span_id - ) + if action.request.identified_spans: + for span in action.request.identified_spans: + identified_str += f"{span.file_path}: {span.span_ids} " + if span.file_path not in identified_spans: + identified_spans[span.file_path] = [] + + for span_id in span.span_ids: + identified_spans[span.file_path].append(span_id) result["identified_spans"] = identified_str if not result["file_identified"]: @@ -197,92 +174,62 @@ def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[di if not missing_files: result["file_identified"] = id_iterations - if result[ - "expected_identified" - ] is None and found_in_expected_spans( - instance, identified_spans - ): + if result["expected_identified"] is None and found_in_expected_spans(instance, identified_spans): result["expected_identified"] = id_iterations - if result[ - "alt_identified" - ] is None and found_in_alternative_spans( - instance, identified_spans - ): + if result["alt_identified"] is None and found_in_alternative_spans(instance, identified_spans): result["alt_identified"] = id_iterations - if result.get("alt_identified") or result.get( - "expected_identified" - ): + if result.get("alt_identified") or result.get("expected_identified"): result["identified"] = min( result.get("alt_identified") or 1000, result.get("expected_identified") or 1000, ) if state_name == "PlanToCode": - action = transition["actions"][-1]["action"] - if action.get("action") == "review": + action = state._actions[-1] + + if action.request.action == "review": result["review"] = True - if "file_path" in action: - if "span_id" not in action: - logger.warning( - f"Span id missing in planning action in {instance['instance_id']}" - ) - else: - file_path = action["file_path"] - if file_path not in planned_spans: - planned_spans[file_path] = [] - planned_spans[file_path].append(action["span_id"]) - transition_result["actual_spans"] = ( - f"{file_path}: {action['span_id']} " - ) + if action.request.file_path: + file_path = action.request.file_path + if file_path not in planned_spans: + planned_spans[file_path] = [] + planned_spans[file_path].append(action.request.span_id) if not result.get("planned") and ( - found_in_expected_spans( - instance, - planned_spans, - ) + found_in_expected_spans(instance, planned_spans) or found_in_alternative_spans(instance, planned_spans) ): result["planned"] = True if state_name == "EditCode": - result["edit_retries"] = len(transition["actions"]) - 1 + result["edit_retries"] = len(state._actions) - 1 - action = transition["actions"][-1] - edited = action.get("trigger") == "finish" + action = state._actions[-1] + edited = action.response and action.response.trigger == "finish" - if edited and "file_path" in transition["state"]: - file_path = transition["state"]["file_path"] + if edited and hasattr(state, 'file_path'): + file_path = state.file_path if file_path not in edited_spans: edited_spans[file_path] = [] - edited_spans[file_path].append( - transition["state"]["span_id"] - ) - transition_result["actual_spans"] = ( - f"{file_path}: {transition['state']['span_id']} " - ) + edited_spans[file_path].append(state.span_id) if not result.get("edited") and ( - found_in_expected_spans( - instance, - edited_spans, - ) + found_in_expected_spans(instance, edited_spans) or found_in_alternative_spans(instance, edited_spans) ): result["edited"] = True - output = action.get("output", {}) - if output: + if action.response and action.response.output: + output = action.response.output if edited: result["has_diff"] = True for lint in output.get("verification_errors", []): lint_codes.add(lint["code"]) - transitions.append(transition_result) - if result.get("alt_identified") or result.get("expected_identified"): result["identified"] = min( result.get("alt_identified") or 1000, @@ -291,9 +238,7 @@ def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[di result["expected_files"] = list(instance["expected_spans"].keys()) result["edited_files"] = list(edited_spans.keys()) - result["identified_spans"] = sum( - [len(v) for v in identified_spans.values()] - ) + result["identified_spans"] = sum(len(v) for v in identified_spans.values()) result["lints"] = ",".join(lint_codes) @@ -316,7 +261,96 @@ def to_result(instance: dict, trajectory: dict, report: dict | None) -> tuple[di except Exception as e: raise e - return result, transitions + return result + +def generate_md_report(trajectory: Trajectory, instance: Dict) -> str: + info = trajectory._info + markdown = f"# {instance['instance_id']}\n" + + markdown += "\n## Problem statement\n" + markdown += f"```\n{instance['problem_statement']}\n```\n" + + if "error" in trajectory._info: + markdown += "\n## Error\n" + markdown += f"```\n{trajectory._info['error']}\n```\n" + else: + markdown += "\n## Prediction\n" + markdown += f"```diff\n{info['submission']}\n```\n" + + markdown += "\n## Golden patch\n" + markdown += f"```diff\n{instance['golden_patch']}\n```\n" + + markdown += "\n## Trajectory\n" + + repo_dir = setup_swebench_repo(instance) + file_repo = FileRepository(repo_dir) + + for j, transition in enumerate(trajectory.transitions): + state = transition.state + for i, action in enumerate(state._actions): + markdown += f"### {j+1} {state.name} ({i+1})\n\n" + + if state.name == "PlanToCode": + if action.request.file_path: + if action.request.instructions: + markdown += f"\n\n * {action.request.instructions}" + markdown += f"\n * {action.request.file_path}" + markdown += f"\n * {action.request.span_id}" + + markdown += "\n\n#### File context \n\n" + try: + file_context = FileContext(file_repo) + file_context.add_span_to_context( + action.request.file_path, + action.request.span_id, + ) + markdown += file_context.create_prompt( + show_outcommented_code=True + ) + except Exception as e: + logger.error(e) + + if state.name == "EditCode": + markdown += "#### LLM Response\n\n" + markdown += f"```\n{action.request.content if isinstance(action.request, Content) else ''}\n```\n" + + if action.response and action.response.output: + output = action.response.output + if output.get("diff"): + markdown += "#### Diff\n\n" + markdown += f"```diff\n{output['diff']}\n```\n" + + if output.get("errors"): + markdown += "#### Errors\n\n" + markdown += f"{output['errors']}\n\n" + + if output.get("message"): + markdown += "#### Message\n\n" + markdown += f"{output['message']}\n\n" + + if state.name == "ClarifyCodeChange": + + if action.request.scratch_pad: + markdown += f"*{action.request.scratch_pad}*" + + if action.response and action.response.output: + output = action.response.output + if output.get("start_line"): + markdown += f"\n* Start Line: {output['start_line']}\n" + markdown += f"\n* End Line: {output['end_line']}\n" + + if state.name == "Finished": + markdown += f"*{action.request.thoughts}*\n" + + if state.name == "Rejected": + markdown += f"*{action.request.thoughts}*\n" + + markdown += "## Alternative patches\n" + for alternative in instance["resolved_by"]: + markdown += f"### {alternative['name']}\n" + markdown += f"```diff\n{alternative['patch']}\n```\n" + + return markdown def generate_md_report(trajectory: dict, instance: dict): info = trajectory["info"] markdown = f"# {instance['instance_id']}\n" diff --git a/moatless/edit/clarify.py b/moatless/edit/clarify.py index 5614b190..9cc3398a 100644 --- a/moatless/edit/clarify.py +++ b/moatless/edit/clarify.py @@ -31,12 +31,12 @@ class LineNumberClarification(ActionRequest): class ClarifyCodeChange(AgenticState): - instructions: str - file_path: str - span_id: str + instructions: str = Field(..., description="The instructions for the code change.") + file_path: str = Field(..., description="The path to the file to be updated.") + span_id: str = Field(..., description="The ID of the span to be updated.") - start_line: Optional[int] = None - end_line: Optional[int] = None + start_line: Optional[int] = Field(None, description="The start line of the code to be updated.") + end_line: Optional[int] = Field(None, description="The end line of the code to be updated.") max_tokens_in_edit_prompt: int = Field( 500, @@ -47,11 +47,6 @@ class ClarifyCodeChange(AgenticState): _span: BlockSpan | None = PrivateAttr(None) _file_context_str: Optional[str] = PrivateAttr(None) - def __init__(self, instructions: str, file_path: str, span_id: str, **data): - super().__init__( - instructions=instructions, file_path=file_path, span_id=span_id, **data - ) - def init(self): self._file = self.file_repo.get_file(self.file_path) self._span = self._file.module.find_span_by_id(self.span_id) @@ -188,6 +183,9 @@ def system_prompt(self) -> str: return CLARIFY_CHANGE_SYSTEM_PROMPT def messages(self) -> list[Message]: + if not self._file_context_str: + self.init() + messages = [ Message( role="user", diff --git a/moatless/edit/edit.py b/moatless/edit/edit.py index bff74ac5..776eb3e6 100644 --- a/moatless/edit/edit.py +++ b/moatless/edit/edit.py @@ -18,7 +18,7 @@ ROLE_PROMPT = "You are autonomous AI assisistant with superior programming skills." -MAIN_OBJECTIVE_PROMPT = "The main objective is to solve a bigger task specfied by the user, this is wrapped in a tag." +MAIN_OBJECTIVE_PROMPT = "The main objective is to solve a bigger task specified by the user, this is wrapped in a tag." SEARCH_REPLACE_PROMPT = """Your task is to solve a smaller task within the main objective. This task is wrapped in a tag. @@ -86,16 +86,16 @@ class CodeChange(ActionRequest): class EditCode(AgenticState): - instructions: str - file_path: str - span_id: Optional[str] = None - start_line: int - end_line: int + instructions: str = Field(..., description="The instructions for the code change.") + file_path: str = Field(..., description="The path to the file to be updated.") + span_id: Optional[str] = Field(None, description="The ID of the span to be updated.") + start_line: int = Field(..., description="The start line of the code to be updated.") + end_line: int = Field(..., description="The end line of the code to be updated.") - show_initial_message: bool = True - show_file_context: bool = True - verify: bool = True - chain_of_thought: bool = False + show_initial_message: bool = Field(True, description="Whether to show the initial message.") + show_file_context: bool = Field(True, description="Whether to show the file context.") + verify: bool = Field(True, description="Whether to verify the code change.") + chain_of_thought: bool = Field(False, description="Whether to use chain of thought reasoning.") max_prompt_file_tokens: int = Field( 4000, @@ -106,39 +106,6 @@ class EditCode(AgenticState): _retry: int = PrivateAttr(default=0) _messages: list[Message] = PrivateAttr(default_factory=list) - def __init__( - self, - instructions: str, - file_path: str, - span_id: Optional[str] = None, - start_line: Optional[int] = None, - end_line: Optional[int] = None, - show_initial_message: bool = True, - max_iterations: int = 8, - show_file_context: bool = True, - verify: bool = True, - include_message_history=True, - chain_of_thought: bool = False, - max_prompt_file_tokens: int = 4000, - **data, - ): - assert "model" in data - super().__init__( - include_message_history=include_message_history, - show_initial_message=show_initial_message, - max_iterations=max_iterations, - show_file_context=show_file_context, - max_prompt_file_tokens=max_prompt_file_tokens, - verify=verify, - chain_of_thought=chain_of_thought, - instructions=instructions, - file_path=file_path, - span_id=span_id, - start_line=start_line, - end_line=end_line, - **data, - ) - def init(self): file = self.file_context.get_file(self.file_path) if not file: @@ -319,9 +286,12 @@ def system_prompt(self) -> str: return system_prompt def messages(self) -> list[Message]: + if not self._code_to_replace: + self.init() + content = "" if self.show_initial_message: - content = f"\n{self.loop.trajectory.initial_message}\n\n\n" + content = f"\n{self.initial_message}\n\n\n" content += f"\n{self.instructions}\n\n" @@ -366,4 +336,4 @@ def action_type(self) -> type[BaseModel] | None: return None def stop_words(self): - return [""] + return [""] \ No newline at end of file diff --git a/moatless/edit/plan.py b/moatless/edit/plan.py index c300f4a6..93f84e5c 100644 --- a/moatless/edit/plan.py +++ b/moatless/edit/plan.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, PrivateAttr from moatless.codeblocks import CodeBlockType from moatless.edit.clarify import _get_post_end_line_index, _get_pre_start_line @@ -99,45 +99,26 @@ class PlanToCode(AgenticState): False, description="Whether to finish the task if a review is requested." ) - def __init__( - self, - message: Optional[str] = None, - diff: Optional[str] = None, - lint_messages: list[VerificationError] | None = None, - include_message_history=True, - max_prompt_file_tokens: int = 4000, - max_tokens_in_edit_prompt: int = 500, - max_iterations: int = 8, - allow_hallucinated_spans: bool = False, - expand_context_with_related_spans: bool = True, - finish_on_review: bool = False, - **data, - ): - super().__init__( - message=message, - diff=diff, - lint_messages=lint_messages, - include_message_history=include_message_history, - max_prompt_file_tokens=max_prompt_file_tokens, - max_tokens_in_edit_prompt=max_tokens_in_edit_prompt, - max_iterations=max_iterations, - allow_hallucinated_spans=allow_hallucinated_spans, - expand_context_with_related_spans=expand_context_with_related_spans, - finish_on_review=finish_on_review, - **data, - ) + include_message_history: bool = Field( + True, + description="Whether to include the message history in the prompt.", + ) + + _expanded_context: bool = PrivateAttr(False) def init(self): - self.file_context.expand_context_with_init_spans() - - if ( - self.expand_context_with_related_spans - and self.loop.transition_count(self) == 0 - ): - self.file_context.expand_context_with_related_spans( - max_tokens=self.max_prompt_file_tokens - ) - self.file_context.expand_small_classes(max_tokens=1000) + if not self._expanded_context: + self.file_context.expand_context_with_init_spans() + + if ( + self.expand_context_with_related_spans + and len(self.get_previous_states(self)) == 0 + ): + self.file_context.expand_context_with_related_spans( + max_tokens=self.max_prompt_file_tokens + ) + self.file_context.expand_small_classes(max_tokens=1000) + self._expanded_context = True def _execute_action(self, action: ApplyChange) -> ActionResponse: if action.action == "review": @@ -151,9 +132,7 @@ def _execute_action(self, action: ApplyChange) -> ActionResponse: "Review isn't possible. If the change is done you can finish or reject the task." ) - if action.finish: - self.file_context.save() - + if action.action == "finish": return ActionResponse.transition( trigger="finish", output={"message": action.finish} ) @@ -313,17 +292,19 @@ def to_message(self) -> str: return response_msg def messages(self) -> list[Message]: + self.init() + messages: list[Message] = [] - if self.loop.trajectory.initial_message: - content = f"\n{self.loop.trajectory.initial_message}\n\n" + if self.initial_message: + content = f"\n{self.initial_message}\n\n" else: content = "" - previous_transitions = self.loop.get_previous_transitions(self) + previous_states = self.get_previous_states(self) - for transition in previous_transitions: - new_message = transition.state.to_message() + for previous_state in previous_states: + new_message = previous_state.to_message() if new_message and not content: content = new_message elif new_message: @@ -332,7 +313,7 @@ def messages(self) -> list[Message]: messages.append(UserMessage(content=content)) messages.append( AssistantMessage( - action=transition.actions[-1].action, + action=previous_state.last_action.request, ) ) content = "" diff --git a/moatless/edit/plan_lines.py b/moatless/edit/plan_lines.py index 5a3c4271..42e43bb0 100644 --- a/moatless/edit/plan_lines.py +++ b/moatless/edit/plan_lines.py @@ -84,23 +84,10 @@ class PlanToCodeWithLines(AgenticState): description="Whether to expand the context with related spans.", ) - def __init__( - self, - message: Optional[str] = None, - diff: Optional[str] = None, - lint_messages: list[VerificationError] | None = None, - max_iterations: int = 5, - include_message_history=True, - **data, - ): - super().__init__( - message=message, - diff=diff, - lint_messages=lint_messages, - include_message_history=include_message_history, - max_iterations=max_iterations, - **data, - ) + include_message_history: bool = Field( + True, + description="Whether to include the message history in the prompt.", + ) def init(self): # TODO: Make addition to context customizable?? @@ -114,7 +101,7 @@ def init(self): if ( self.expand_context_with_related_spans - and self.loop.transition_count(self) == 0 + and len(self.get_previous_states(self)) == 0 ): self.file_context.expand_context_with_related_spans(max_tokens=4000) @@ -260,12 +247,12 @@ def to_message(self) -> str: def messages(self) -> list[Message]: messages: list[Message] = [] - content = self.loop.trajectory.initial_message or "" + content = self.initial_message or "" - previous_transitions = self.loop.get_previous_transitions(self) + previous_states = self.get_previous_states(self) - for transition in previous_transitions: - new_message = transition.state.to_message() + for previous_state in previous_states: + new_message = previous_state.to_message() if new_message and not content: content = new_message elif new_message: @@ -274,7 +261,7 @@ def messages(self) -> list[Message]: messages.append(UserMessage(content=content)) messages.append( AssistantMessage( - action=transition.actions[-1].action, + action=previous_state.last_action.request, ) ) content = "" diff --git a/moatless/edit/review.py b/moatless/edit/review.py index cdf5a9ca..73b14091 100644 --- a/moatless/edit/review.py +++ b/moatless/edit/review.py @@ -132,27 +132,12 @@ class ReviewCode(AgenticState): description="Whether to finish the task if no verification errors are found.", ) - _verification_errors: List[VerificationError] = PrivateAttr(default_factory=list) + include_message_history: bool = Field( + True, + description="Whether to include the message history in the prompt.", + ) - def __init__( - self, - message: Optional[str] = None, - diff: Optional[str] = None, - max_prompt_file_tokens: int = 4000, - max_tokens_in_edit_prompt: int = 500, - max_iterations: int = 8, - include_message_history=True, - **data, - ): - super().__init__( - message=message, - diff=diff, - include_message_history=include_message_history, - max_prompt_file_tokens=max_prompt_file_tokens, - max_tokens_in_edit_prompt=max_tokens_in_edit_prompt, - max_iterations=max_iterations, - **data, - ) + _verification_errors: List[VerificationError] = PrivateAttr(default_factory=list) def init(self) -> Optional[ActionResponse]: self._verification_errors = self.workspace.verify() @@ -399,15 +384,15 @@ def to_message(self) -> str: def messages(self) -> list[Message]: messages: list[Message] = [] - if self.loop.trajectory.initial_message: - content = f"\n{self.loop.trajectory.initial_message}\n" + if self.initial_message: + content = f"\n{self.initial_message}\n" else: content = "" - previous_transitions = self.loop.get_previous_transitions(self) + previous_states = self.get_previous_states(self) - for transition in previous_transitions: - new_message = transition.state.to_message() + for previous_state in previous_states: + new_message = previous_state.to_message() if new_message and not content: content = new_message elif new_message: @@ -416,7 +401,7 @@ def messages(self) -> list[Message]: messages.append(UserMessage(content=content)) messages.append( AssistantMessage( - action=transition.actions[-1].action, + action=previous_state.last_action.request, ) ) content = "" diff --git a/moatless/find/decide.py b/moatless/find/decide.py index 61bf0fbd..bd7c1a4a 100644 --- a/moatless/find/decide.py +++ b/moatless/find/decide.py @@ -69,28 +69,18 @@ class Decision(ActionRequest): class DecideRelevance(AgenticState): - expand_context: bool + expand_context: bool = Field( + False, + description="If true, the file context will be expanded with additional context.", + ) finish_after_relevant_count: int = Field( 2, description="Finish the task after this many relevant decisions have been made but not complete.", ) - max_prompt_file_tokens: int = 4000 - - def __init__( - self, - expand_context: bool = True, - include_message_history=False, - finish_after_relevant_count: int = 2, - max_prompt_file_tokens: int = 4000, - **data, - ): - super().__init__( - expand_context=expand_context, - finish_after_relevant_count=finish_after_relevant_count, - max_prompt_file_tokens=max_prompt_file_tokens, - include_message_history=include_message_history, - **data, - ) + max_prompt_file_tokens: int = Field( + 4000, + description="The maximum number of tokens to include in the file context prompt.", + ) def _execute_action(self, action: Decision) -> ActionResponse: if action.complete and action.relevant: @@ -108,15 +98,17 @@ def _execute_action(self, action: Decision) -> ActionResponse: ) def _relevant_count(self) -> int: + """ + Count the number of times a decision was made that the file context was relevant. + """ relevant_count = 0 - previous_transitions = self.loop.get_previous_transitions(self) - for transition in previous_transitions: - for previous_action in transition.actions: - if ( - isinstance(previous_action.action, Decision) - and previous_action.action.relevant - ): - relevant_count += 1 + previous_states = self.get_previous_states(self) + for previous_state in previous_states: + if ( + previous_state.last_action + and previous_state.last_action.request.relevant + ): + relevant_count += 1 return relevant_count def action_type(self) -> type[BaseModel] | None: @@ -126,11 +118,10 @@ def system_prompt(self) -> str: return MAYBE_FINISH_SYSTEM_PROMPT def _last_scratch_pad(self): - previous_searches = self.loop.get_previous_transitions(SearchCode) - logger.info(f"Previous searches: {len(previous_searches)}") - if previous_searches and previous_searches[-1].actions: - last_search = previous_searches[-1].actions[-1].action - return last_search.scratch_pad + previous_states = self.get_previous_states() + if previous_states and previous_states[-1].last_action: + last_action = previous_states[-1].last_action + return last_action.request.scratch_pad else: return None @@ -155,7 +146,7 @@ def messages(self) -> list[Message]: ) content = f""" -{self.loop.trajectory.initial_message} +{self.initial_message} """ @@ -172,4 +163,4 @@ def messages(self) -> list[Message]: """ messages.append(UserMessage(content=content)) - return messages + return messages \ No newline at end of file diff --git a/moatless/find/identify.py b/moatless/find/identify.py index df3dff15..479dbd21 100644 --- a/moatless/find/identify.py +++ b/moatless/find/identify.py @@ -71,7 +71,7 @@ class IdentifyCode(AgenticState): expand_context: bool = Field( default=False, - description="Whether to expand the search result with relevant code spans .", + description="Whether to expand the search result with relevant code spans.", ) max_prompt_file_tokens: int = Field( @@ -79,22 +79,6 @@ class IdentifyCode(AgenticState): description="The maximum number of tokens to include in the prompt.", ) - def __init__( - self, - ranked_spans: list[RankedFileSpan], - expand_context: bool = True, - include_message_history: bool = False, - max_prompt_file_tokens: int = 4000, - **data, - ): - super().__init__( - ranked_spans=ranked_spans, - include_message_history=include_message_history, - expand_context=expand_context, - max_prompt_file_tokens=max_prompt_file_tokens, - **data, - ) - def model_dump(self, **kwargs): return super().model_dump(**kwargs) @@ -111,7 +95,7 @@ def _execute_action(self, action: Identify) -> ActionResponse: else: logger.info("No spans identified.") - message = f"The search returned {len(self.ranked_spans)} results. But unfortunately, I didn’t find any of the search results relevant to the query." + message = f"The search returned {len(self.ranked_spans)} results. But unfortunately, I didn't find any of the search results relevant to the query." message += "\n\n" message += action.scratch_pad @@ -166,7 +150,7 @@ def messages(self) -> list[Message]: file_context_str = "No relevant code identified yet." content = f""" -{self.loop.trajectory.initial_message} +{self.initial_message} diff --git a/moatless/find/search.py b/moatless/find/search.py index fdc54c5d..49b019db 100644 --- a/moatless/find/search.py +++ b/moatless/find/search.py @@ -3,7 +3,7 @@ from typing import Optional import instructor -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator, ValidationError from moatless.file_context import RankedFileSpan from moatless.index.types import SearchCodeHit @@ -14,6 +14,7 @@ Message, UserMessage, ) +from moatless.utils.llm_utils import instructor_mode_by_model logger = logging.getLogger(__name__) @@ -92,7 +93,7 @@ User: The database connection setup is missing SSL configuration, causing insecure connections. -Here’s the stack trace of the error: +Here's the stack trace of the error: File "/opt/app/db_config/database.py", line 45, in setup_connection engine = create_engine(DATABASE_URL) @@ -147,7 +148,7 @@ User: The database connection setup is missing SSL configuration, causing insecure connections. -Here’s the stack trace of the error: +Here's the stack trace of the error: File "/opt/app/db_config/database.py", line 45, in setup_connection engine = create_engine(DATABASE_URL) @@ -211,7 +212,7 @@ User: The database connection setup is missing SSL configuration, causing insecure connections. -Here’s the stack trace of the error: +Here's the stack trace of the error: File "/opt/app/db_config/database.py", line 45, in setup_connection engine = create_engine(DATABASE_URL) @@ -264,6 +265,11 @@ def has_search_attributes(self): ] ) + @model_validator(mode='after') + def validate_search_requests(self): + if not self.has_search_attributes: + raise ValueError("A search request must have at least one attribute set.") + return self class Search(ActionRequest): """Take action to search for code, identify found and finish up.""" @@ -281,8 +287,12 @@ class Search(ActionRequest): default=False, description="Set to true when the search is complete." ) - def has_search_attributes(self): - return all([search.has_search_attributes() for search in self.search_requests]) + @model_validator(mode='after') + def validate_search_requests(self): + if not self.complete: + if not self.search_requests: + raise ValueError("At least one search request must exist.") + return self class SearchCode(AgenticState): @@ -301,6 +311,11 @@ class SearchCode(AgenticState): description="The maximum number of retries when there are identified files in file context.", ) + include_message_history: bool = Field( + True, + description="Include message history from previous iterations", + ) + provide_initial_context: bool = True initial_context_tokens: int = 4000 initial_search_results: int = 50 @@ -308,30 +323,6 @@ class SearchCode(AgenticState): support_test_files: bool = False - def __init__( - self, - message: Optional[str] = None, - max_search_results: int = 25, - max_retries_with_any_file_context: int = 3, - include_message_history: bool = True, - provide_initial_context: bool = True, - initial_context_tokens: int = 4000, - initial_search_results: int = 50, - initial_context_spans_per_file: int = 5, - **data, - ): - super().__init__( - message=message, - include_message_history=include_message_history, - provide_initial_context=provide_initial_context, - max_search_results=max_search_results, - max_retries_with_any_file_context=max_retries_with_any_file_context, - initial_context_tokens=initial_context_tokens, - initial_search_results=initial_search_results, - initial_context_spans_per_file=initial_context_spans_per_file, - **data, - ) - def _execute_action(self, action: Search) -> ActionResponse: if action.complete: return ActionResponse.transition( @@ -342,11 +333,6 @@ def _execute_action(self, action: Search) -> ActionResponse: ) if isinstance(action, Search): - if not action.has_search_attributes(): - return self._retry( - "You must provide at least one the search attributes query, code_snippet, class_name or function_name to search. If you're finished, set finished to true." - ) - for request in action.search_requests: if ( not self.support_test_files @@ -385,7 +371,7 @@ def _execute_action(self, action: Search) -> ActionResponse: if len(ranked_spans) == 0: logger.info("No search results found. Will retry.") - message = "\n\nUnfortunately, I didn’t find any relevant results." + message = "\n\nUnfortunately, I didn't find any relevant results." return self._retry(message) return ActionResponse.transition( @@ -411,7 +397,8 @@ def action_type(self) -> type[BaseModel] | None: def system_prompt(self) -> str: system_prompt = SEARCH_SYSTEM_PROMPT - if self.loop.instructor_mode == instructor.Mode.JSON: + instructor_mode = instructor_mode_by_model(self.model) + if instructor_mode == instructor.Mode.JSON: system_prompt += SEARCH_JSON_FEW_SHOT elif self.model.startswith("openai"): system_prompt += SEARCH_FUNCTIONS_FEW_SHOT_OPENAI_FUNC @@ -425,12 +412,12 @@ def system_prompt(self) -> str: def messages(self) -> list[Message]: messages: list[Message] = [] - content = f"\n{self.loop.trajectory.initial_message}\n" + content = f"\n{self.initial_message}\n" if self.provide_initial_context: logger.info("Search for initial context to provide in the prompt") result = self.workspace.code_index.semantic_search( - query=self.loop.trajectory.initial_message, + query=self.initial_message, exact_match_if_possible=False, max_spans_per_file=5, max_results=100, @@ -452,14 +439,14 @@ def messages(self) -> list[Message]: show_outcommented_code=False, ) - previous_transitions = self.loop.get_previous_transitions(self) - for transition in previous_transitions: - if transition.state.message: - content += transition.state.message + previous_states = self.get_previous_states(self) + for previous_state in previous_states: + if previous_state.message: + content += previous_state.message messages.append(UserMessage(content=content)) messages.append( AssistantMessage( - action=transition.actions[-1].action, + action=previous_state.last_action.request, ) ) content = "" diff --git a/moatless/loop.py b/moatless/loop.py index a6880d07..5d72753c 100644 --- a/moatless/loop.py +++ b/moatless/loop.py @@ -14,7 +14,7 @@ import litellm from anthropic import Anthropic from litellm import completion_cost, cost_per_token, token_counter -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import BaseModel, Field, PrivateAttr, ConfigDict from moatless.repository import GitRepository from moatless.state import ( @@ -25,16 +25,18 @@ Rejected, get_state_class, ) -from moatless.trajectory import Trajectory, TrajectoryTransition, TrajectoryAction -from moatless.transition_rules import TransitionRules +from moatless.trajectory import Trajectory +from moatless.transition_rules import TransitionRule, TransitionRules from moatless.types import ( ActionRequest, AssistantMessage, Content, Message, Response, + Usage, UserMessage, ) +from moatless.utils.llm_utils import instructor_mode_by_model from moatless.workspace import Workspace logger = logging.getLogger("Loop") @@ -46,6 +48,7 @@ def __init__( transition_rules: TransitionRules, workspace: Workspace, input_data: dict[str, Any] | None = None, + initial_message: str | None = None, trajectory: Trajectory | None = None, mocked_actions: list[dict] | None = None, expected_states: list[Type[AgenticState]] | None = None, @@ -69,7 +72,6 @@ def __init__( Args: """ - self._trajectory = trajectory self._workspace = workspace @@ -81,12 +83,27 @@ def __init__( os.makedirs(parent_dir) self._trajectory_path = trajectory_path + if not trajectory: + self._trajectory = Trajectory( + "MoatlessTools", + initial_message=initial_message, + persist_path=self._trajectory_path, + workspace=self._workspace, + transition_rules=transition_rules, + ) + pending_state = Pending() + self._trajectory.save_state(pending_state) + self._set_current_state(pending_state) + else: + self._trajectory = trajectory + self._current_state = trajectory.get_current_state() + + self._initial_message = initial_message + if prompt_log_dir and not os.path.exists(prompt_log_dir): os.makedirs(prompt_log_dir) self._prompt_log_dir = prompt_log_dir - self._mocked_actions = mocked_actions - if expected_states and not verify_state_func: def verify_state_func(state: AgenticState): @@ -96,9 +113,12 @@ def verify_state_func(state: AgenticState): f"No more expected states, but got {state.__class__}" ) expected_state = expected_states.pop(0) - if not ( - state.name == expected_state or isinstance(state, expected_state) - ): + if isinstance(expected_state, str): + if state.name != expected_state: + raise ValueError( + f"Expected state {expected_state} but got {state.__class__.__name__}" + ) + elif isinstance(expected_state, AgenticState) and not isinstance(state, expected_state): raise ValueError( f"Expected state {expected_state} but got {state.__class__.__name__}" ) @@ -106,7 +126,7 @@ def verify_state_func(state: AgenticState): self.log_info(f"Verified expected next state {expected_state}") self._verify_state_func = verify_state_func - + self._mocked_actions = mocked_actions self._reset_mocks_at_state = reset_mocks_at_state self._max_cost = max_cost @@ -121,91 +141,60 @@ def verify_state_func(state: AgenticState): self._rejections = 0 self._transition_rules = transition_rules - - self._initial_message = "" - self._transitions: dict[int, TrajectoryTransition] = {} - self._current_transition: TrajectoryTransition | None = None - self._metadata = metadata - self._type = "standard" - - for k, v in kwargs.items(): - setattr(self, k, v) - @classmethod def from_trajectory_file(cls, trajectory_path: str, **kwargs): trajectory = Trajectory.load(trajectory_path) - transitions = trajectory.transitions - workspace = Workspace.from_dict(trajectory.workspace) - return cls( - transition_rules=transitions, + transition_rules=trajectory.transitions, trajectory=trajectory, - workspace=workspace, + workspace=trajectory.workspace, **kwargs, ) def persist(self, trajectory_path: str): self.trajectory.persist(trajectory_path) - def retry_from_transition( - self, - transition_id: int, - state_params: dict[Type[AgenticState], Any] = None, - ): - self.clone_transition(transition_id) - # TODO: I'm using only state params as an easy way test out changes. Need to think about a better way to do this. - self._transition_rules.state_params.update(state_params) - - while not self.is_finished(): - self.run_until_transition() - - if isinstance(self.state, Finished): - return Response(status="finished", message=self.state.message or "") - elif isinstance(self.state, Rejected): - return Response(status="rejected", message=self.state.message or "") - - raise RuntimeError(f"Loop exited with unknown state {self.state.name}.") + def run(self, message: Optional[str] = None) -> Response: + """ + Executes the entire loop until completion or termination. - def initialize_or_load_trajectory(self, message: Optional[str] = None) -> None: - if not self._trajectory: - self._trajectory = Trajectory( - "MoatlessTools", - initial_message=message, - persist_path=self._trajectory_path, - workspace=self._workspace, - transition_rules=self._transition_rules, - ) - pending_transition = self._create_transition( - state=Pending(), - snapshot=self._workspace.snapshot() - ) - self._set_current_transition(pending_transition) - else: - for transition in self._trajectory.transitions: - self.set_current_transition_from_dict(transition) - self.workspace.restore_from_snapshot(transition.get("snapshot")) + This method initializes the loop if it hasn't started, and then repeatedly + calls run_until_transition() until the loop is finished. It handles the + overall flow of the loop, including initialization and final state processing. - for transition_data in self._trajectory.transitions: - transition = self._transitions[transition_data["id"]] - if transition_data.get("parent_id"): - parent = self._transitions[transition_data["parent_id"]] - transition.parent = parent - parent.children.append(transition) + Args: + message (Optional[str]): An optional initial message to start the loop with. - def run(self, message: Optional[str] = None) -> Response: - """ - Run the loop and handle exceptions and cost checking. + Returns: + Response: An object containing the final status and message of the loop. + The status will be either "finished" or "rejected". + + Raises: + RuntimeError: If an unexpected state or condition occurs during execution. + This includes cases where the loop is already running, exits with an + unknown state, or encounters other unexpected runtime conditions. + + Note: + This method will continue running until a Finished or Rejected state is reached, + or until an exception occurs. It's designed to be the main entry point for + executing the entire loop process. """ - if self.is_running(): - raise Exception("Loop is already running.") + raise RuntimeError("Loop is already running.") + + # TODO: Move to always set this when the Loop is created instead + if message: + logger.warning("Setting initial message in run is deprecated. Set in contructor.") + self._initial_message = message + self._trajectory._initial_message = message - self.initialize_or_load_trajectory(message) + if not isinstance(self._current_state, Pending): + self._trajectory.update_workspace_to_current_state() while not self.is_finished(): - self.run_until_transition() + self._execute_state_until_transition() if isinstance(self.state, Finished): return Response(status="finished", message=self.state.message or "") @@ -214,48 +203,147 @@ def run(self, message: Optional[str] = None) -> Response: raise RuntimeError(f"Loop exited with unknown state {self.state.name}.") - def run_until_transition(self) -> TrajectoryTransition: - while not self.is_finished(): + def _execute_state_until_transition(self) -> AgenticState | None: + """ + Executes the state until a transition to a new state occurs. + + This method executes the state, processing actions and handling + state changes until one of the following conditions is met: + 1. A transition to a new state occurs + 2. Maximum cost, retries, or transitions are exceeded + + Returns: + AgenticState: The new state after a transition occurs + + Raises: + RuntimeError: If the loop exits without a transition or if the maximum cost is exceeded + ValueError: If the maximum number of retries is reached + """ + while not self.state.executed: total_cost = self.total_cost() if total_cost > self._max_cost: - logger.warning( - f"{self.transition_name}: Max cost reached ({total_cost} > {self._max_cost}). Exiting." - ) + self.log_info(f"Max cost reached ({total_cost} > {self._max_cost}). Exiting.") self.trajectory.save_info({"error": "Max cost reached."}) - raise RuntimeError( - "The loop was aborted because the cost exceeded the limit.", - ) - else: - self.log_info( - f"Running transition {len(self._transitions)}. Current total cost: {total_cost}" - ) + raise RuntimeError("The loop was aborted because the cost exceeded the limit.") + + self.log_info(f"Running transition {len(self._trajectory.states)}. Current total cost: {total_cost}") try: - transition = self._run() - if transition: - return transition + state = self._execute_state() + if state: + return state except Exception as e: - logger.warning( - f"{self.transition_name}: Failed to run loop. Error: {e}" - ) + self.log_info(f"Failed to run loop. Error: {e}") raise - if self.retries() > self._max_retries: - logger.warning( - f"{self.transition_name}: Max retries reached ({self._max_retries}). Exiting." - ) + if self.state.retries() > self._max_retries: + self.log_info(f"Max retries reached ({self._max_retries}). Exiting.") self.trajectory.save_info({"error": "Max retries reached."}) return self.transition_to(Rejected(message="Max retries reached.")) raise RuntimeError("Loop exited without a transition.") + def _execute_state(self) -> AgenticState | None: + """ + Execute one iteration of the current state and handle potential transitions. + + Processes the next action, updates the trajectory, and determines if a state + transition should occur based on the action's response. + + Returns: + AgenticState | None: The next state if transitioning, or None if remaining in the current state. + + Raises: + ValueError: + """ + if self.state.executed: + raise ValueError("Tried to execute already executed state.") + + if isinstance(self.state, Pending): + logger.info("Initializing first state.") + trigger = "init" + output = {} + + else: + action, usage = self._next_action() + + self.log_info(f"Received new action {action.action_name}.") + response = self.state.handle_action(action, usage) + + if not response.trigger: + self.log_info( + f"{self.state.name}: No trigger in action response. Staying in the same state." + ) + return None + + self.log_info(f"Received response with trigger {response.trigger}") + + if response.trigger == "retry": + self.log_info(f"Retry requested. {response.retry_message}") + return None + + trigger = response.trigger + output = response.output + + transition_rule = self._transition_rules.get_next_rule( + self.state, + trigger, + output, + ) + if not transition_rule: + raise RuntimeError( + f"No transition rule found for {self.state.name} with trigger {response.trigger} and output {response.output}" + ) + + next_state = self._create_state(transition_rule, output) + return self.transition_to(next_state) + + def _create_state(self, transition_rule: TransitionRule, output: dict) -> AgenticState: + params = {} + params.update(self._transition_rules.params(transition_rule)) + + for k, v in output.items(): + if transition_rule.excluded_fields and k in transition_rule.excluded_fields: + continue + + params[k] = v + + params["id"] = self.state_count() + + next_state_type = transition_rule.dest + if next_state_type not in [Finished, Rejected]: + + if self.state_count() >= self._max_transitions: + self.log_info(f"Max transitions exceeded ({self._max_transitions}). Transitioning to Rejected.") + next_state_type = Rejected + params["message"] = "Max transitions exceeded." + if ( + params.get("max_iterations") + and self.state_count(next_state_type) >= params["max_iterations"] + ): + self.log_info(f"Max iterations exceeded ({params['max_iterations']}). Transitioning to Rejected.") + next_state_type = Rejected + params["message"] = f"Max iterations exceeded ({params['max_iterations']})." + + self.log_info(f"Creating state {next_state_type.__name__} with params {params}") + + try: + next_state = next_state_type.model_validate(params) + next_state.previous_state = self._current_state + next_state._workspace = self._workspace + next_state._initial_message = self._initial_message + except Exception as e: + logger.error(f"Failed to create state {next_state_type.__name__} with params {params}") + raise e + + self._trajectory.save_state(next_state) + self._current_state.next_states.append(next_state) + return next_state + def total_cost(self): total_cost = 0 - for step in self._transitions.values(): - for action in step.actions: - if action.completion_cost: - total_cost += action.completion_cost - + for state in self._trajectory.transitions: + total_cost += state.state.total_cost() return total_cost def is_running(self) -> bool: @@ -264,169 +352,156 @@ def is_running(self) -> bool: def is_finished(self) -> bool: return isinstance(self.state, (Finished, Rejected)) - def _set_state_loop(self, state: AgenticState): - state._set_loop(self) + def _set_current_state(self, state: AgenticState): + self._current_state = state + self._trajectory.set_current_state(state) - def retries(self) -> int: - retries = 0 - for action in reversed(self._current_transition.actions): - if action.trigger == "retry": - retries += 1 - else: - return retries + def transition_to(self, new_state: AgenticState) -> AgenticState: + self.log_info(f"Transitioning from {self.state.name} to {new_state.name}") - return retries + self._trajectory.save_state(new_state) + self._set_current_state(new_state) - def retry_messages(self, state: AgenticState) -> list[Message]: - messages: list[Message] = [] + return new_state - if self._current_transition.name != state.name: - return messages + def _next_action( + self, + ) -> Tuple[ActionRequest, Usage | None]: + messages = self._to_completion_messages() + self.log_info(f"Create completion with {len(messages)} messages") - for action in self._current_transition.actions: - if action.trigger == "retry": - if isinstance(action.action, Content): - messages.append( - AssistantMessage( - content=action.action.content, - ) - ) - else: - messages.append(AssistantMessage(action=action.action)) + if self._verify_state_func: + self._verify_state_func(self.state) + + mocked_action = self._next_mock_action() + if mocked_action: + return mocked_action, None + + metadata = {} + if self._metadata: + metadata.update(self._metadata) + metadata["generation_name"] = self.state.name + + tokens = token_counter(messages=messages[-1:]) + if self._max_message_tokens and tokens > self._max_message_tokens: + raise ValueError(f"Too many tokens in the new message: {tokens}") + + self.log_info(f"Do completion request to {self.state.model}") - messages.append( - UserMessage( - content=action.retry_message, + if self.state.model.startswith("claude") and self.state.action_type(): + try: + anthropic_client = instructor.from_anthropic( + Anthropic(), + mode=self.instructor_mode, + ) + + action_request, completion_response = ( + anthropic_client.chat.completions.create_with_completion( + model=self.state.model, + max_tokens=self.state.max_tokens, + temperature=self.state.temperature, + # stop=self.state.stop_words(), + response_model=self.state.action_type(), + messages=messages, ) ) - return messages + self.log_info( + f"Input tokens: {completion_response.usage.input_tokens}, Output tokens: {completion_response.usage.output_tokens}" + ) + ( + prompt_tokens_cost_usd_dollar, + completion_tokens_cost_usd_dollar, + ) = cost_per_token( + model=self.state.model, + prompt_tokens=completion_response.usage.input_tokens, + completion_tokens=completion_response.usage.output_tokens, + ) + _final_cost = ( + prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar + ) + except Exception as e: + self._log_prompt(messages, error=traceback.format_exc()) + raise e - def _set_current_transition(self, transition: TrajectoryTransition): - self._current_transition = transition - self._transitions[transition.id] = transition - self._trajectory.set_current_transition_id(transition.id) + + self._log_prompt(messages, completion_response.content) - def set_current_transition_from_dict(self, transition_data: dict): - state_data = transition_data.get("state", {}) - name = state_data.get("name") - try: - state_class = get_state_class(name) - state = state_class(**state_data) - - transition = TrajectoryTransition( - id=transition_data["id"], - state=state, - snapshot=transition_data.get("snapshot"), - actions=[ - TrajectoryAction(**action) for action in transition_data["actions"] - ], - timestamp=datetime.fromisoformat(transition_data["timestamp"]), + usage = Usage( + completion_cost=_final_cost, + completion_tokens=completion_response.usage.output_tokens, + prompt_tokens=completion_response.usage.input_tokens, ) - self._set_current_transition(transition) - self._set_state_loop(state) - state.init() - - except Exception as e: - logger.exception(f"Failed to load state {name}") - raise e - - def set_current_transition(self, transition: TrajectoryTransition): - self._set_current_transition(transition) + return action_request, usage - def revert_to_transition(self, transition_id: int) -> TrajectoryTransition: - transition = self._transitions.get(transition_id) - if transition: - self.log_info(f"Reverting to transition {transition_id}") - self._set_current_transition(transition) - self.workspace.restore_from_snapshot(transition.snapshot) - return transition - else: - logger.warning( - f"Tried to revert to transition {transition_id} but it does not exist. Existing transition ids: {self._transitions.keys()}" + if self.state.action_type() is None: + completion_response = litellm.completion( + model=self.state.model, + max_tokens=self.state.max_tokens, + temperature=self.state.temperature, + stop=self.state.stop_words(), + metadata=metadata, + messages=messages, ) - raise ValueError( - f"Could not revert to transition {transition_id} as it does not exist." + action_request = Content( + content=completion_response.choices[0].message.content ) - - def _create_transition( - self, - state: AgenticState, - snapshot: dict | None = None, - parent: TrajectoryTransition | None = None, - ): - transition = TrajectoryTransition( - id=len(self._transitions) + 1, state=state, snapshot=snapshot, parent=parent - ) - self.trajectory.create_transition(transition) - self._transitions[transition.id] = transition - return transition - - def clone_current_transition(self): - cloned_state = self.state.clone() - cloned_transition = self._create_transition( - state=cloned_state, - snapshot=self._current_transition.snapshot, - parent=self._current_transition.parent, + else: + client = instructor.from_litellm( + litellm.completion, mode=self.instructor_mode ) - self._set_current_transition(cloned_transition) - return cloned_transition - - def transition_to(self, new_state: AgenticState) -> TrajectoryTransition: - self.log_info(f"Transitioning from {self.state.name} to {new_state.name}") - if self.transition_count() > self._max_transitions: - new_state = Rejected(message="Max transitions exceeded.") + try: + action_request, completion_response = ( + client.chat.completions.create_with_completion( + model=self.state.model, + max_tokens=self.state.max_tokens, + temperature=self.state.temperature, + stop=self.state.stop_words(), + response_model=self.state.action_type(), + metadata=metadata, + messages=messages, + ) + ) + except Exception as e: + self._log_prompt(messages, error=traceback.format_exc()) + raise e - if ( - new_state.max_iterations - and self.transition_count(new_state) > new_state.max_iterations - ): - new_state = Rejected( - message=f"Max transitions exceeded for state {new_state.name}." + try: + cost = completion_cost( + completion_response=completion_response, + model=self.state.model, ) + except Exception as e: + self.log_info(f"Error calculating completion cost: {e}") + cost = 0 - transition = self._create_transition( - state=new_state, - snapshot=self.workspace.snapshot(), - parent=self._current_transition, + self._log_prompt( + messages, [completion_response.choices[0].message.model_dump()], error=None ) - - if self._current_transition: - self._current_transition.children.append(transition) - - self._set_current_transition(transition) - self._set_state_loop(new_state) - - return transition - - def transition_count(self, state: AgenticState | None = None) -> int: + prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0) + completion_tokens = completion_response.get("usage", {}).get( + "completion_tokens", 0 + ) + usage = Usage( + completion_cost=cost, + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + ) + return action_request, usage + + def state_count(self, state: AgenticState | None = None) -> int: if not state: - return len(self._transitions) + return len(self._trajectory.transitions) return len( - [t for t in self._transitions.values() if t.state.name == state.name] + [s for s in self._trajectory.transitions if s.state.name == state.name] ) - def get_previous_transitions(self, state: AgenticState | None): - previous_transitions = [] - parent_transition = self._current_transition.parent - while parent_transition: - if not state or parent_transition.state.name == state.name: - previous_transitions.insert(0, parent_transition) - - parent_transition = parent_transition.parent - - self.log_info( - f"Found {len(previous_transitions)} previous transitions for {state.name if state else 'all states'}" - ) - - return previous_transitions - @property def state(self): - return self._current_transition.state if self._current_transition else Pending() + return self._current_state @property def workspace(self) -> Workspace: @@ -523,95 +598,12 @@ def _to_completion_messages(self) -> list[dict]: return messages - def _run(self) -> TrajectoryTransition | None: - """ - Run the loop for one iteration. - - Returns: - - """ - if self.is_finished(): - self.log_info("Loop already finished.") - return None - - if isinstance(self.state, Pending): - logger.info("Initializing first state.") - initial_state = self._transition_rules.create_initial_state( - **(self._input_data or {}) - ) - return self.transition_to(initial_state) - - action, cost, input_tokens, output_tokens = self._next_action() - - self.log_info(f"Received new action {action.action_name}.") - response = self.state.handle_action(action) - - self._current_transition.actions.append( - TrajectoryAction( - action=action, - trigger=response.trigger, - retry_message=response.retry_message, - completion_cost=cost, - input_tokens=input_tokens, - output_tokens=output_tokens, - ) - ) - self.trajectory.update_transition(self._current_transition) - - if not response.trigger: - self.log_info( - f"{self.state.name}: No trigger in action response. Staying in the same state." - ) - return None - - self.log_info(f"Received response with trigger {response.trigger}") - - if response.trigger == "retry": - self.log_info(f"Retry requested. {response.retry_message}") - return None - - try: - next_state = self._transition_rules.next_state( - source=self.state, - trigger=response.trigger, - data=response.output, - ) - except Exception: - logger.exception( - f"{self.transition_name}: Failed to initiate next state with trigger {response.trigger} and output {response.output}" - ) - raise - - if not next_state: - raise ValueError( - f"No transition found for {self.state.name} with trigger {response.trigger}" - ) - - if response.trigger == "rejected" and next_state.__class__ != Rejected: - self._rejections += 1 - next_state = Rejected( - message=f"Got {self._rejections} rejections, aborting." - ) - else: - self._rejections = 0 - - return self.transition_to(next_state) - @property def instructor_mode(self): if self._instructor_mode: return self._instructor_mode - if "gpt" in self.state.model: - return instructor.Mode.TOOLS - - if self.state.model.startswith("claude"): - return instructor.Mode.ANTHROPIC_TOOLS - - if self.state.model.startswith("openrouter/anthropic/claude"): - return instructor.Mode.TOOLS - - return instructor.Mode.JSON + return instructor_mode_by_model(self.state.model) def _next_mock_action( self, @@ -645,124 +637,6 @@ def _next_mock_action( else: raise ValueError(f"Mocked action {action} does not have 'content' field.") - def _next_action( - self, - ) -> tuple[ActionRequest, Optional[float], Optional[int], Optional[int]]: - messages = self._to_completion_messages() - self.log_info(f"Create completion with {len(messages)} messages") - - if self._verify_state_func: - self._verify_state_func(self.state) - - mocked_action = self._next_mock_action() - if mocked_action: - return mocked_action, None, None, None - - metadata = {} - if self._metadata: - metadata.update(self._metadata) - metadata["generation_name"] = self.state.name - - tokens = token_counter(messages=messages[-1:]) - if self._max_message_tokens and tokens > self._max_message_tokens: - raise ValueError(f"Too many tokens in the new message: {tokens}") - - self.log_info(f"Do completion request to {self.state.model}") - - if self.state.model.startswith("claude") and self.state.action_type(): - try: - anthropic_client = instructor.from_anthropic( - Anthropic(), - mode=self.instructor_mode, - ) - - action_request, completion_response = ( - anthropic_client.chat.completions.create_with_completion( - model=self.state.model, - max_tokens=self.state.max_tokens, - temperature=self.state.temperature, - # stop=self.state.stop_words(), - response_model=self.state.action_type(), - messages=messages, - ) - ) - - self.log_info( - f"Input tokens: {completion_response.usage.input_tokens}, Output tokens: {completion_response.usage.output_tokens}" - ) - ( - prompt_tokens_cost_usd_dollar, - completion_tokens_cost_usd_dollar, - ) = cost_per_token( - model=self.state.model, - prompt_tokens=completion_response.usage.input_tokens, - completion_tokens=completion_response.usage.output_tokens, - ) - _final_cost = ( - prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar - ) - except Exception as e: - self._log_prompt(messages, error=traceback.format_exc()) - raise e - - self._log_prompt(messages, completion_response.content) - return ( - action_request, - _final_cost, - completion_response.usage.input_tokens, - completion_response.usage.output_tokens, - ) - - if self.state.action_type() is None: - completion_response = litellm.completion( - model=self.state.model, - max_tokens=self.state.max_tokens, - temperature=self.state.temperature, - stop=self.state.stop_words(), - metadata=metadata, - messages=messages, - ) - action_request = Content( - content=completion_response.choices[0].message.content - ) - else: - client = instructor.from_litellm( - litellm.completion, mode=self.instructor_mode - ) - - try: - action_request, completion_response = ( - client.chat.completions.create_with_completion( - model=self.state.model, - max_tokens=self.state.max_tokens, - temperature=self.state.temperature, - stop=self.state.stop_words(), - response_model=self.state.action_type(), - metadata=metadata, - messages=messages, - ) - ) - except Exception as e: - self._log_prompt(messages, error=traceback.format_exc()) - raise e - - try: - cost = completion_cost( - completion_response=completion_response, - model=self.state.model, - ) - except Exception as e: - self.log_info(f"Error calculating completion cost: {e}") - cost = 0 - - self._log_prompt( - messages, [completion_response.choices[0].message.model_dump()], error=None - ) - prompt_tokens = completion_response.get("usage", {}).get("prompt_tokens", 0) - completion_tokens = completion_response.get("usage", {}).get( - "completion_tokens", 0 - ) - return action_request, cost, prompt_tokens, completion_tokens def _log_prompt( self, @@ -776,10 +650,10 @@ def _log_prompt( time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") prompt_path = ( - f"{self._prompt_log_dir}/{self._current_transition.id}_{self.state.name}" + f"{self._prompt_log_dir}/{self._current_state.id}_{self._current_state.name}" ) - if self.retries() > 0: - prompt_path += f"_retry_{self.retries()}" + if self.state.retries() > 0: + prompt_path += f"_retry_{self.state.retries()}" prompt_path += f"_{time_str}.md" @@ -837,12 +711,10 @@ def log_info(self, message: str): @property def transition_name(self): - if self._current_transition: - return ( - f"{self._current_transition.state.name}:{self._current_transition.id}" - ) + if self._current_state: + return f"{self._current_state.name}:{self._current_state.id}" else: - return "No transition" + return "No state" def generate_call_id(): diff --git a/moatless/repository/git.py b/moatless/repository/git.py index 488f4512..96332cf0 100644 --- a/moatless/repository/git.py +++ b/moatless/repository/git.py @@ -70,7 +70,7 @@ def dict(self): "type": "git", "repo_path": self._repo_path, "git_repo_url": self._repo_url, - "commit": self._current_commit, + "commit": self._initial_commit, } def snapshot(self) -> dict: @@ -78,23 +78,40 @@ def snapshot(self) -> dict: "commit": self._current_commit, } + def save_file(self, file_path: str, updated_content: Optional[str] = None): + super().save_file(file_path, updated_content) + self.commit(file_path) + def save(self): super().save() - commit_message = self.commit_message() - self._repo.index.add("*") + self.commit() + + def commit(self, file_path: str | None = None): + commit_message = self.commit_message(file_path) + + if file_path: + self._repo.index.add(file_path) + else: + self._repo.index.add("*") self._repo.index.commit(commit_message) self._current_commit = self._repo.head.commit.hexsha - def diff(self): - return self._repo.git.diff(self._initial_commit, self._current_commit) + logger.info(f"Committed changes to git with message '{commit_message}' and commit hash '{self._current_commit}'") + + def commit_message(self, file_path: str | None = None) -> str: + if file_path: + diff = self._repo.git.diff("HEAD", file_path) + else: + diff = self._repo.git.diff("HEAD") - def commit_message(self) -> str: - diff = self._repo.git.diff(None) if not diff: return "No changes." if Settings.cheap_model: - prompt = f"Generate a concise commit message for the following git diff:\n\n{diff}\n\nCommit message:" + prompt = f"Generate a concise commit message for the following git diff" + if file_path: + prompt += f" of file {file_path}" + prompt += f":\n\n{diff}\n\nCommit message:" try: response = litellm.completion( @@ -107,3 +124,6 @@ def commit_message(self) -> str: logging.error(f"Error generating commit message: {e}") return "Automated commit by Moatless Tools" + + def diff(self): + return self._repo.git.diff(self._initial_commit, self._current_commit) \ No newline at end of file diff --git a/moatless/state.py b/moatless/state.py index 1d477e83..cffa6093 100644 --- a/moatless/state.py +++ b/moatless/state.py @@ -2,18 +2,20 @@ import sys import importlib from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any, Optional, List from copy import deepcopy -from pydantic import BaseModel, Field, PrivateAttr, ConfigDict +from pydantic import BaseModel, Field, PrivateAttr, ConfigDict, model_validator from moatless.file_context import FileContext from moatless.repository import FileRepository from moatless.types import ( ActionRequest, ActionResponse, + ActionTransaction, FileWithSpans, - Message, + Message, Content, AssistantMessage, + Usage, UserMessage, ) from moatless.workspace import Workspace @@ -21,6 +23,13 @@ class AgenticState(ABC, BaseModel): + id: int = Field(..., description="The unique identifier of the state") + previous_state: Optional["AgenticState"] = Field( + default=None, description="The state that led to this state" + ) + next_states: List["AgenticState"] = Field( + default_factory=list, description="The states this state transitioned to" + ) model: Optional[str] = Field( default=None, description="The model to use for completion" ) @@ -36,28 +45,31 @@ class AgenticState(ABC, BaseModel): None, description="The maximum number of transitions to this state." ) - _loop: Optional["AgenticLoop"] = PrivateAttr(None) # noqa: F821 + _workspace: Optional[Workspace] = PrivateAttr(None) + _initial_message: Optional[str] = PrivateAttr(None) _executed: bool = PrivateAttr(False) - _last_action: Optional[ActionRequest] = PrivateAttr(None) - _response: Optional[ActionResponse] = PrivateAttr(None) + _actions: List[ActionTransaction] = PrivateAttr(default_factory=list) - # model_config = ConfigDict(extra='allow') + model_config = ConfigDict( + arbitrary_types_allowed=True, + exclude={"previous_state", "next_states"} + ) def __init__(self, **data): super().__init__(**data) - self._loop = None + self._workspace = data.get('_workspace') + self._initial_message = data.get('_initial_message') - def handle_action(self, action: ActionRequest) -> ActionResponse: + def handle_action(self, action: ActionRequest, usage: Usage | None) -> ActionResponse: if self._executed: raise ValueError(f"State has already been executed") - self._last_action = action response = self._execute_action(action) + self._actions.append(ActionTransaction(request=action, response=response, usage=usage)) if response.trigger and response.trigger != "retry": self._executed = True - self._response = response return response @@ -65,10 +77,6 @@ def handle_action(self, action: ActionRequest) -> ActionResponse: def _execute_action(self, action: ActionRequest) -> ActionResponse: raise NotImplementedError - def _set_loop(self, loop: "AgenticLoop"): # noqa: F821 - self._loop = loop - self.init() - @property def name(self): return self.__class__.__name__ @@ -78,29 +86,28 @@ def executed(self): return self._executed @property - def last_action(self) -> Optional[ActionRequest]: - return self._last_action + def last_action(self) -> Optional[ActionTransaction]: + return self._actions[-1] if self._actions else None @property def response(self) -> Optional[ActionResponse]: - return self._response - - @property - def loop(self) -> "AgenticLoop": # noqa: F821 - assert self._loop is not None, "Loop has not been set" - return self._loop + return self._actions[-1].response if self._actions else None @property def workspace(self) -> Workspace: - return self.loop.workspace + return self._workspace @property def file_repo(self) -> FileRepository: - return self.workspace.file_repo + return self._workspace.file_repo @property def file_context(self) -> FileContext: - return self.workspace.file_context + return self._workspace.file_context + + @property + def initial_message(self) -> str: + return self._initial_message def create_file_context( self, files: list[FileWithSpans] = None, **kwargs @@ -113,9 +120,6 @@ def init(self): """Initialization logic for the state.""" pass - def transition_to(self, new_state: "AgenticState"): - self.loop.transition_to(new_state) - def finish(self, message: str): # TODO!! logger.info(message) @@ -127,18 +131,62 @@ def messages(self) -> list[Message]: def required_fields(cls) -> set[str]: return set() + def get_previous_states(self, state: Optional["AgenticState"] = None) -> list["AgenticState"]: + """ + Retrieves previous states of the same type as the given state. + If no state is provided, it returns all previous states. + + Args: + state (AgenticState | None): The state to filter by. If None, all previous states are returned. + + Returns: + list: A list of previous states, filtered by type if a state is provided. + """ + previous_states = [] + current_state = self + + while current_state and current_state.previous_state: + current_state = current_state.previous_state + if not state or isinstance(current_state, type(state)): + previous_states.insert(0, current_state) + + logger.debug( + f"Found {len(previous_states)} previous states of type {state.__class__.__name__ if state else 'all types'}" + ) + + return previous_states + def retries(self) -> int: retries = 0 - for action in reversed(self.loop._current_transition.actions): - if action.trigger == "retry": + for action in reversed(self._actions): + if action.response.trigger == "retry": retries += 1 else: return retries return retries - def retry_messages(self): - return self.loop.retry_messages(self) + def retry_messages(self) -> list[Message]: + messages: list[Message] = [] + + for action in self._actions: + if isinstance(action.request, Content): + messages.append( + AssistantMessage( + content=action.request.content, + ) + ) + else: + messages.append(AssistantMessage(action=action.request)) + + if action.response.retry_message: + messages.append( + UserMessage( + content=action.response.retry_message, + ) + ) + + return messages def system_prompt(self) -> str: return "" @@ -154,58 +202,61 @@ def stop_words(self) -> list[str] | None: return None def model_dump(self, **kwargs): + if 'exclude' not in kwargs: + kwargs['exclude'] = {"previous_state", "next_states"} + data = super().model_dump(**kwargs) - return {"name": self.name, **data} + return data + + @classmethod + @model_validator(mode="before") + def validate_previous_state(cls, values): + if isinstance(obj, dict) and "previous_state_id" in obj: + obj = obj.copy() + obj["previous_state"] = None + return super().model_validate(obj) def clone(self) -> "AgenticState": - data = self.model_dump(exclude={"_executed", "_last_action", "_response"}) - new_state = self.__class__(**data) - new_state._loop = self._loop + new_state = self.__class__(**self.model_dump()) + if hasattr(self, '_workspace'): + new_state._workspace = self._workspace return new_state + def total_cost(self): + total_cost = 0 + for action in self._actions: + if action.usage: + total_cost += action.usage.completion_cost + + return total_cost + def __eq__(self, other): if not isinstance(other, AgenticState): return NotImplemented - if self.model_dump() != other.model_dump(): return False - - if self._loop and other._loop: - self_context = self._loop.workspace.file_context - other_context = other._loop.workspace.file_context - - return self_context.model_dump() == other_context.model_dump() - return True class NoopState(AgenticState): - def __init__(self, **data): - super().__init__(**data) def _execute_action(self, action: ActionRequest): raise ValueError("NoopState cannot handle actions") class Finished(NoopState): - message: Optional[str] - + message: Optional[str] = None output: dict[str, Any] | None = None - def __init__(self, message: Optional[str] = None, **kwargs): - super().__init__(message=message) - self.output = kwargs - class Rejected(NoopState): - message: str - - def __init__(self, message: str, **kwargs): - super().__init__(message=message) + message: Optional[str] = None class Pending(NoopState): def __init__(self, **data): + if 'id' not in data: + data['id'] = 0 super().__init__(**data) @@ -243,4 +294,4 @@ def get_state_class(name: str) -> type[AgenticState]: if isinstance(cls, type) and issubclass(cls, AgenticState): return cls - raise ValueError(f"State {name} not found") + raise ValueError(f"State {name} not found") \ No newline at end of file diff --git a/moatless/trajectory.py b/moatless/trajectory.py index 971a3336..59ab437b 100644 --- a/moatless/trajectory.py +++ b/moatless/trajectory.py @@ -1,41 +1,25 @@ import json import logging from datetime import datetime -from typing import Any, Optional +from typing import Any, Optional, List from pydantic import BaseModel, Field from pydantic_core import to_jsonable_python from moatless.workspace import Workspace from moatless.transition_rules import TransitionRules -from moatless.state import AgenticState -from moatless.types import ActionRequest +from moatless.state import AgenticState, get_state_class +from moatless.types import ActionRequest, ActionTransaction, ActionResponse, Usage, Content logger = logging.getLogger(__name__) -class TrajectoryAction(BaseModel): - action: ActionRequest - trigger: Optional[str] = None - retry_message: Optional[str] = None - completion_cost: Optional[float] = None - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - def model_dump(self, **kwargs): - data = super().model_dump(**kwargs) - data["action"] = self.action.model_dump(**kwargs) - return data - - -class TrajectoryTransition(BaseModel): +class TrajectoryState(BaseModel): id: int - parent: Optional["TrajectoryTransition"] = None - children: list["TrajectoryTransition"] = Field(default_factory=list) - state: AgenticState - snapshot: Optional[dict] = None - actions: list[TrajectoryAction] = [] timestamp: datetime = Field(default_factory=datetime.now) + snapshot: Optional[dict] = None + state: AgenticState @property def name(self): @@ -44,38 +28,48 @@ def name(self): def model_dump(self, **kwargs): data = { "id": self.id, + "name": self.state.name, "timestamp": self.timestamp, - "parent_id": self.parent.id if self.parent else None, - "state": self.state.model_dump(**kwargs) if self.state else None, - "snapshot": self.snapshot, - "actions": [action.model_dump(**kwargs) for action in self.actions], } - if kwargs.get("exclude_none", False): - data = {k: v for k, v in data.items() if v is not None} + if self.snapshot: + data["snapshot"] = self.snapshot - return data + if self.state.previous_state: + data["previous_state_id"] = self.state.previous_state.id + + properties = self.state.model_dump(exclude={"previous_state", "next_states", "id"}, **kwargs) if self.state else None + if properties: + data["properties"] = properties + if self.state._actions: + data["actions"] = [a.model_dump(**kwargs) for a in self.state._actions] + + return data class Trajectory: def __init__( self, name: str, + workspace: Workspace, initial_message: Optional[str] = None, persist_path: Optional[str] = None, - workspace: Optional[Workspace] = None, transition_rules: Optional[TransitionRules] = None, ): self._name = name self._persist_path = persist_path self._initial_message = initial_message - self._workspace = workspace.dict() if workspace else None - self._transition_rules = transition_rules + self._workspace = workspace - self._transitions: list[dict[str, Any]] = [] + # Workaround to set to keep the current initial workspace state when loading an existing trajectory. + # TODO: Remove this when we have a better way to handle this. + self._initial_workspace_state = self._workspace.dict() + + self._transition_rules = transition_rules self._current_transition_id = 0 + self._transitions: dict[int, TrajectoryState] = {} self._info: dict[str, Any] = {} @@ -89,14 +83,67 @@ def load(cls, file_path: str): else: transition_rules = None + workspace = Workspace.from_dict(data["workspace"]) trajectory = cls( name=data["name"], initial_message=data["initial_message"], transition_rules=transition_rules, + workspace=workspace ) - trajectory._workspace = data["workspace"] - trajectory._transitions = data["transitions"] - trajectory._info = data["info"] + + trajectory._info = data.get("info", {}) + + trajectory._transitions = {} + trajectory._current_transition_id = data.get("current_transition_id", 0) + + for t in data["transitions"]: + state_class = get_state_class(t["name"]) + state_data = t["properties"] + state_data["id"] = t["id"] + state = state_class.model_validate(state_data) + + state._workspace = trajectory._workspace + state._initial_message = trajectory._initial_message + state._actions = [] + if "actions" in t: + for a in t["actions"]: + try: + if state.action_type() is None: + request = Content.model_validate(a["request"]) + else: + request = state.action_type().model_validate(a["request"]) + response = ActionResponse.model_validate(a.get("response")) + if a.get("usage"): + usage = Usage.model_validate(a.get("usage")) + else: + usage = None + state._actions.append(ActionTransaction(request=request, response=response, usage=usage)) + except Exception as e: + logger.exception(f"Error loading action for state {state.name}: {a}") + raise e + + trajectory_state = TrajectoryState( + id=t["id"], + timestamp=datetime.fromisoformat(t["timestamp"]), + snapshot=t.get("snapshot"), + state=state + ) + + trajectory._transitions[t["id"]] = trajectory_state + + # Set previous_state and next_states + for t in data["transitions"]: + try: + current_state = trajectory._transitions[t["id"]].state + if t.get("previous_state_id") is not None: + current_state.previous_state = trajectory._transitions.get(t["previous_state_id"]).state + except KeyError as e: + logger.exception(f"Missing key {e}, existing keys: {trajectory._transitions.keys()}") + raise + + trajectory._info = data.get("info", {}) + + logger.info(f"Loaded trajectory {trajectory._name} with {len(trajectory._transitions)} transitions") return trajectory @@ -105,79 +152,97 @@ def initial_message(self): return self._initial_message @property - def transitions(self) -> list[dict]: - return sorted(self._transitions, key=lambda x: x["timestamp"]) + def info(self): + return self._info + + @property + def states(self) -> List[dict]: + return [t.state.model_dump() for t in self.transitions] @property def transition_rules(self) -> TransitionRules: return self._transition_rules @property - def workspace(self) -> dict[str, Any] | None: + def workspace(self) -> Workspace: return self._workspace - def create_transition(self, transition: TrajectoryTransition): - self._transitions.append( - transition.model_dump(exclude_none=True, exclude_unset=True) - ) + @property + def transitions(self) -> List[TrajectoryState]: + return sorted(self._transitions.values(), key=lambda x: x.id) + + def set_current_state(self, state: AgenticState): + self._current_transition_id = state.id self._maybe_persist() - return transition - def update_transition(self, transition: TrajectoryTransition): - for i, t in enumerate(self._transitions): - if t["id"] == transition.id: - self._transitions[i] = transition.model_dump( - exclude_none=True, exclude_unset=True - ) - self._maybe_persist() - return + def get_current_state(self) -> AgenticState: + return self._transitions.get(self._current_transition_id).state + + def update_workspace_to_current_state(self): + self.restore_from_snapshot(self._transitions[self._current_transition_id]) + + def restore_from_snapshot(self, state: TrajectoryState): + if not state.snapshot: + logger.info(f"restore_from_snapshot(state: {state.id}:{state.name}) No snapshot found") + return + + logger.info(f"restore_from_snapshot(starte: {state.id}:{state.name}) Restoring from snapshot") - raise ValueError(f"Transition with id {transition.id} not found") + if state.snapshot.get("repository"): + self._workspace.file_repo.restore_from_snapshot(state.snapshot["repository"]) + + if state.snapshot.get("file_context"): + self._workspace.file_context.restore_from_snapshot(state.snapshot["file_context"]) + + def save_state(self, state: AgenticState): + if state.id in self._transitions: + self._transitions[state.id].state = state + else: + transition = TrajectoryState( + id=state.id, + state=state, + snapshot=state.workspace.snapshot() if state.workspace else None, + ) + self._transitions[state.id] = transition - def set_current_transition_id(self, transition_id: int): - self._current_transition_id = transition_id self._maybe_persist() + def get_state(self, state_id: int) -> TrajectoryState | None: + return self._transitions.get(state_id) + def save_info(self, info: dict): self._info = info self._maybe_persist() - def get_mocked_actions(self) -> list[dict]: + def get_mocked_actions(self) -> List[dict]: """ Return a list of actions that can be used to mock the trajectory. - - TODO: Provide the end transition and support parent() """ actions = [] - for transition in self._transitions: - for action in transition["actions"]: - actions.append(action["action"]) + + for transition in self.transitions: + for action in transition.state._actions: + actions.append(action.request.model_dump()) return actions - def get_expected_states(self) -> list[str]: + def get_expected_states(self) -> List[str]: """ - Return a list of expected states in the trajectory to use for verifcation when rerunning the trajectory. - - TODO: Provide the end transition and support parent() + Return a list of expected states in the trajectory to use for verification when rerunning the trajectory. """ - - states = [] - for transition in self._transitions: - states.append(transition["state"]["name"]) - return states + return [transition.state.name for transition in self.transitions[1:]] def to_dict(self): return { "name": self._name, "transition_rules": self._transition_rules.model_dump( - exclude_none=True, exclude_unset=True + exclude_none=True ) if self._transition_rules else None, - "workspace": self._workspace, + "workspace": self._initial_workspace_state, "initial_message": self._initial_message, "current_transition_id": self._current_transition_id, - "transitions": self._transitions, + "transitions": [t.model_dump(exclude_none=True) for t in self.transitions], "info": self._info, } @@ -193,4 +258,4 @@ def persist(self, file_path: str): indent=2, default=to_jsonable_python, ) - ) + ) \ No newline at end of file diff --git a/moatless/transition_rules.py b/moatless/transition_rules.py index 5b836f56..4317d9ab 100644 --- a/moatless/transition_rules.py +++ b/moatless/transition_rules.py @@ -5,6 +5,7 @@ from moatless.settings import Settings from moatless.state import AgenticState, get_state_class +from moatless.workspace import Workspace logger = logging.getLogger(__name__) @@ -54,12 +55,18 @@ def validate_state_classes(cls, data: Any) -> Any: data["source"] = get_state_class(data["source"]) if isinstance(data.get("dest"), str): data["dest"] = get_state_class(data["dest"]) + + if data["source"] == data["dest"]: + raise ValueError("Source and destination states cannot be the same.") + return data class TransitionRules(BaseModel): - initial_state: type[AgenticState] = Field( - ..., description="The initial state of the loop." + initial_state: type[AgenticState] | None = Field( + default=None, + description="The initial state for the loop.", + deprecated="Initial state should be set in transition_rules instead." ) transition_rules: list[TransitionRule] = Field( ..., description="The transition rules for the loop." @@ -80,15 +87,19 @@ def __init__(self, **data): self._build_source_trigger_index() def model_dump(self, **kwargs): - return { - "initial_state": self.initial_state.__name__, + data = { + "global_params": self.global_params, + "state_params": {k.__name__: v for k, v in self.state_params.items()}, "transition_rules": [ rule.model_dump(**kwargs) for rule in self.transition_rules ], - "global_params": self.global_params, - "state_params": {k.__name__: v for k, v in self.state_params.items()}, } + if self.initial_state: + data["initial_state"] = self.initial_state.__name__ + + return data + @model_validator(mode="before") @classmethod def validate_before_init(cls, data: Any) -> Any: @@ -123,17 +134,24 @@ def find_transition_rule_by_source_and_trigger( ) -> list[TransitionRule]: return self._source_trigger_index.get((source, trigger), []) - def create_initial_state(self, **data) -> AgenticState: + def params(self, rule: TransitionRule) -> dict[str, Any]: params = {} params.update(self.global_params) - params.update(self.state_params.get(self.initial_state, {})) - params.update(data) - print(f"initial_state,{params}") - return self.initial_state(**params) + params.update(self.state_params.get(rule.dest, {})) + return params - def next_state( + def get_next_rule( self, source: AgenticState, trigger: str, data: dict[str, Any] - ) -> AgenticState | None: + ) -> TransitionRule | None: + + if trigger == "init" and self.initial_state: + logger.warning("Using deprecated 'initial_state'. Set initial state in transition_rules instead.") + return TransitionRule( + trigger="init", + source=source.__class__, + dest=self.initial_state, + ) + transition_rules = self.find_transition_rule_by_source_and_trigger( source.__class__, trigger ) @@ -145,17 +163,6 @@ def next_state( logger.info(f"Missing required fields for transition {transition_rule}") continue - params = {} - params.update(self.global_params) - params.update(self.state_params.get(transition_rule.dest, {})) - - if transition_rule.excluded_fields: - data = { - k: v - for k, v in data.items() - if k not in transition_rule.excluded_fields - } + return transition_rule - params.update(data) - return transition_rule.dest(**params) - return None + return None \ No newline at end of file diff --git a/moatless/transitions.py b/moatless/transitions.py index b76bbaf8..3479ab66 100644 --- a/moatless/transitions.py +++ b/moatless/transitions.py @@ -9,7 +9,7 @@ from moatless.find.identify import IdentifyCode from moatless.find.search import SearchCode from moatless.transition_rules import TransitionRule, TransitionRules -from moatless.state import Finished, Rejected +from moatless.state import Finished, Rejected, Pending CODE_TRANSITIONS = [ TransitionRule( @@ -194,8 +194,8 @@ def search_and_code_transitions( return TransitionRules( global_params=global_params, state_params=state_params, - initial_state=SearchCode, transition_rules=[ + TransitionRule(source=Pending, dest=SearchCode, trigger="init"), TransitionRule(source=SearchCode, dest=IdentifyCode, trigger="did_search"), TransitionRule(source=SearchCode, dest=PlanToCode, trigger="finish"), TransitionRule(source=IdentifyCode, dest=SearchCode, trigger="search"), diff --git a/moatless/types.py b/moatless/types.py index aa0bd005..5546f140 100644 --- a/moatless/types.py +++ b/moatless/types.py @@ -20,7 +20,6 @@ def add_span_ids(self, span_ids: list[str]): for span_id in span_ids: self.add_span_id(span_id) - class ActionRequest(BaseModel): pass @@ -28,6 +27,51 @@ class ActionRequest(BaseModel): def action_name(self): return self.__class__.__name__ +class ActionResponse(BaseModel): + trigger: Optional[str] = Field( + default=None, + description="Trigger to transition to the next state. If None, no transition is made.", + ) + output: Optional[dict[str, Any]] = Field( + default=None, + description="Output data to be passed to the next state.", + ) + + retry_message: Optional[str] = Field( + default=None, + description="Message to use in retry." + ) + + @classmethod + def retry(cls, retry_message: str): + return cls(trigger="retry", retry_message=retry_message) + + @classmethod + def transition(cls, trigger: str, output: dict[str, Any] | None = None): + output = output or {} + return cls(trigger=trigger, output=output) + + @classmethod + def no_transition(cls, output: dict[str, Any]): + return cls(output=output) + +class Usage(BaseModel): + completion_cost: float + completion_tokens: int + prompt_tokens: int + + +class ActionTransaction(BaseModel): + request: ActionRequest + response: Optional[ActionResponse] = None + usage: Optional[Usage] = None + + def model_dump(self, **kwargs): + data = super().model_dump(**kwargs) + data["request"] = self.request.model_dump(**kwargs) + data["response"] = self.response.model_dump(**kwargs) if self.response else None + return data + class EmptyRequest(ActionRequest): pass @@ -62,25 +106,6 @@ class UserMessage(Message): content: Optional[str] = None -class ActionResponse(BaseModel): - trigger: Optional[str] = None - output: Optional[dict[str, Any]] = None - retry_message: Optional[str] = None - - @classmethod - def retry(cls, retry_message: str): - return cls(trigger="retry", retry_message=retry_message) - - @classmethod - def transition(cls, trigger: str, output: dict[str, Any] | None = None): - output = output or {} - return cls(trigger=trigger, output=output) - - @classmethod - def no_transition(cls, output: dict[str, Any]): - return cls(output=output) - - class Response(BaseModel): status: str message: str diff --git a/moatless/utils/llm_utils.py b/moatless/utils/llm_utils.py new file mode 100644 index 00000000..daff1ad1 --- /dev/null +++ b/moatless/utils/llm_utils.py @@ -0,0 +1,18 @@ + +import instructor + + +def instructor_mode_by_model(model: str) -> instructor.Mode | None: + if "gpt" in model: + return instructor.Mode.TOOLS + + if "claude" in model: + return instructor.Mode.TOOLS + + if model.startswith("claude"): + return instructor.Mode.ANTHROPIC_TOOLS + + if model.startswith("openrouter/anthropic/claude"): + return instructor.Mode.TOOLS + + return instructor.Mode.JSON diff --git a/moatless/workspace.py b/moatless/workspace.py index 0b64445b..0dbaae78 100644 --- a/moatless/workspace.py +++ b/moatless/workspace.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Any, Optional, Dict from moatless.codeblocks.parser.python import PythonParser from moatless.file_context import FileContext @@ -135,7 +135,7 @@ def dict(self): "code_index": self.code_index.dict() if self.code_index else None, } - def snapshot(self): + def snapshot(self) -> Dict[str, Any]: return { "repository": self.file_repo.snapshot(), "file_context": self.file_context.snapshot(), @@ -168,4 +168,4 @@ def verify(self, file: CodeFile | None = None) -> list[VerificationError]: return self.verifier.verify(file) logger.info("No verifier configured.") - return [] + return [] \ No newline at end of file diff --git a/notebooks/swebench/01_evaluate_search.ipynb b/notebooks/swebench/01_evaluate_search.ipynb index 7211d216..a3639df5 100644 --- a/notebooks/swebench/01_evaluate_search.ipynb +++ b/notebooks/swebench/01_evaluate_search.ipynb @@ -1,79 +1,83 @@ { "cells": [ { + "cell_type": "code", + "execution_count": 1, + "id": "a068fdc332c9a4d8", "metadata": { "ExecuteTime": { "end_time": "2024-06-15T08:35:43.806834Z", "start_time": "2024-06-15T08:35:43.010364Z" } }, - "cell_type": "code", + "outputs": [], "source": [ "import datetime\n", - "import litellm\n", "import os\n", "\n", "index_store_dir = f\"/home/albert/index_store\"\n", "repo_base_dir = f\"/tmp/repos\"\n", "\n", - "# model = \"gpt-4o-2024-05-13\"\n", - "model = \"mistral/open-mixtral-8x22b\"\n", - "\n", - "date_str = datetime.datetime.now().strftime(\"%Y%m%d\")\n", - "model_file_name = f\"{model.replace('/', '_')}\"\n", - "\n", "evaluations_dir = \"/home/albert/repos/albert/moatless/evaluations\"\n", - "evaluation_name = f\"{date_str}_moatless_search_{model_file_name}\"\n", + "evaluation_name = f\"\"\n", "evaluation_dir = f\"{evaluations_dir}/{evaluation_name}\"\n", "trajectory_dir = f\"{evaluations_dir}/{evaluation_name}/trajs\"\n", "\n", "if not os.path.exists(trajectory_dir):\n", - " os.makedirs(trajectory_dir)\n", - " \n", - "litellm.success_callback = [\"langfuse\"]\n", - "litellm.failure_callback = [\"langfuse\"]" - ], - "id": "a068fdc332c9a4d8", - "execution_count": 1, - "outputs": [] + " os.makedirs(trajectory_dir)\n" + ] }, { + "cell_type": "code", + "execution_count": 2, + "id": "66d6eea0cb4abe8d", "metadata": { "ExecuteTime": { "end_time": "2024-06-15T08:35:45.468948Z", "start_time": "2024-06-15T08:35:43.808190Z" } }, - "cell_type": "code", + "outputs": [], "source": [ - "from moatless import Workspace\n", - "from moatless.benchmark.swebench import verify_search_trajectory\n", - "\n", - "def to_result(instance: dict, trajectory: dict, workspace: Workspace):\n", - " result = {\n", - " \"instance_id\": instance[\"instance_id\"],\n", - " \"duration\": trajectory[\"info\"][\"duration\"],\n", - " \"total_cost\": trajectory[\"info\"][\"total_cost\"],\n", - " \"resolved_by\": len(instance[\"resolved_by\"])\n", - " }\n", + "from moatless.find.identify import IdentifyCode\n", + "from moatless.find.search import SearchCode\n", + "from moatless.transitions import search_transitions\n", "\n", - " result.update(verify_search_trajectory(trajectory, instance, workspace))\n", - " \n", - " return result\n", - " " - ], - "id": "66d6eea0cb4abe8d", - "execution_count": 2, - "outputs": [] + "global_params = {\n", + " \"model\": \"gpt-4o-2024-05-13\",\n", + " \"temperature\": 0.2,\n", + " \"max_tokens\": 2000,\n", + " \"max_prompt_file_tokens\": 8000,\n", + "}\n", + "\n", + "state_params = {\n", + " SearchCode: {\n", + " \"provide_initial_context\": True,\n", + " \"max_search_results\": 75,\n", + " \"initial_context_tokens\": 6000,\n", + " \"initial_search_results\": 100,\n", + " \"initial_context_spans_per_file\": 5,\n", + " },\n", + " IdentifyCode: {\"expand_context\": True},\n", + "}\n", + "\n", + "transitions = search_transitions(\n", + " global_params=global_params,\n", + " state_params=state_params,\n", + ")" + ] }, { + "cell_type": "code", + "execution_count": 3, + "id": "initial_id", "metadata": { "ExecuteTime": { "end_time": "2024-06-15T08:35:45.785579Z", "start_time": "2024-06-15T08:35:45.469888Z" } }, - "cell_type": "code", + "outputs": [], "source": [ "import logging\n", "import traceback\n", @@ -129,19 +133,19 @@ " search_loop.trajectory.save_info(info)\n", " \n", " return to_result(instance, search_loop.trajectory.to_dict(), workspace)" - ], - "id": "initial_id", - "execution_count": 3, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": 4, + "id": "398309b4de44e80d", "metadata": { "ExecuteTime": { "end_time": "2024-06-15T08:35:56.745639Z", "start_time": "2024-06-15T08:35:45.786823Z" } }, - "cell_type": "code", + "outputs": [], "source": [ "from pandas import DataFrame\n", "import pandas as pd\n", @@ -200,34 +204,31 @@ " return pd.DataFrame(results)\n", "\n", "df = run_evaluation(\"/home/albert/repos/albert/moatless/datasets/swebench_lite_all_evaluations.json\")" - ], - "id": "398309b4de44e80d", - "execution_count": 4, - "outputs": [] + ] }, { + "cell_type": "code", + "execution_count": null, + "id": "f6f979631e1caf4", "metadata": { "ExecuteTime": { "end_time": "2024-06-15T08:35:56.747100Z", "start_time": "2024-06-15T08:35:56.746889Z" } }, - "cell_type": "code", + "outputs": [], "source": [ "df.to_csv(f\"{evaluation_dir}/result.csv\", index=False, sep=';', decimal=',')\n", "df" - ], - "id": "f6f979631e1caf4", - "execution_count": null, - "outputs": [] + ] }, { - "metadata": {}, "cell_type": "code", - "source": "", - "id": "f8e931eb46f23d03", "execution_count": null, - "outputs": [] + "id": "f8e931eb46f23d03", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/poetry.lock b/poetry.lock index 29ee6018..3dc68df9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -379,6 +379,90 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coverage" +version = "7.6.1" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61c0abb4c85b095a784ef23fdd4aede7a2628478e7baba7c5e3deba61070a02"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd21f6ae3f08b41004dfb433fa895d858f3f5979e7762d052b12aef444e29afc"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f59d57baca39b32db42b83b2a7ba6f47ad9c394ec2076b084c3f029b7afca23"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a1ac0ae2b8bd743b88ed0502544847c3053d7171a3cff9228af618a068ed9c34"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e6a08c0be454c3b3beb105c0596ebdc2371fab6bb90c0c0297f4e58fd7e1012c"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f5796e664fe802da4f57a168c85359a8fbf3eab5e55cd4e4569fbacecc903959"}, + {file = "coverage-7.6.1-cp310-cp310-win32.whl", hash = "sha256:7bb65125fcbef8d989fa1dd0e8a060999497629ca5b0efbca209588a73356232"}, + {file = "coverage-7.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:3115a95daa9bdba70aea750db7b96b37259a81a709223c8448fa97727d546fe0"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7dea0889685db8550f839fa202744652e87c60015029ce3f60e006f8c4462c93"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed37bd3c3b063412f7620464a9ac1314d33100329f39799255fb8d3027da50d3"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d85f5e9a5f8b73e2350097c3756ef7e785f55bd71205defa0bfdaf96c31616ff"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bc572be474cafb617672c43fe989d6e48d3c83af02ce8de73fff1c6bb3c198d"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c0420b573964c760df9e9e86d1a9a622d0d27f417e1a949a8a66dd7bcee7bc6"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f4aa8219db826ce6be7099d559f8ec311549bfc4046f7f9fe9b5cea5c581c56"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:fc5a77d0c516700ebad189b587de289a20a78324bc54baee03dd486f0855d234"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b48f312cca9621272ae49008c7f613337c53fadca647d6384cc129d2996d1133"}, + {file = "coverage-7.6.1-cp311-cp311-win32.whl", hash = "sha256:1125ca0e5fd475cbbba3bb67ae20bd2c23a98fac4e32412883f9bcbaa81c314c"}, + {file = "coverage-7.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:8ae539519c4c040c5ffd0632784e21b2f03fc1340752af711f33e5be83a9d6c6"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, + {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, + {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a4acd025ecc06185ba2b801f2de85546e0b8ac787cf9d3b06e7e2a69f925b106"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a6d3adcf24b624a7b778533480e32434a39ad8fa30c315208f6d3e5542aeb6e9"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0c212c49b6c10e6951362f7c6df3329f04c2b1c28499563d4035d964ab8e08c"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e81d7a3e58882450ec4186ca59a3f20a5d4440f25b1cff6f0902ad890e6748a"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b260de9790fd81e69401c2dc8b17da47c8038176a79092a89cb2b7d945d060"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a78d169acd38300060b28d600344a803628c3fd585c912cacc9ea8790fe96862"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2c09f4ce52cb99dd7505cd0fc8e0e37c77b87f46bc9c1eb03fe3bc9991085388"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6878ef48d4227aace338d88c48738a4258213cd7b74fd9a3d4d7582bb1d8a155"}, + {file = "coverage-7.6.1-cp313-cp313-win32.whl", hash = "sha256:44df346d5215a8c0e360307d46ffaabe0f5d3502c8a1cefd700b34baf31d411a"}, + {file = "coverage-7.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:8284cf8c0dd272a247bc154eb6c95548722dce90d098c17a883ed36e67cdb129"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d3296782ca4eab572a1a4eca686d8bfb00226300dcefdf43faa25b5242ab8a3e"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:502753043567491d3ff6d08629270127e0c31d4184c4c8d98f92c26f65019962"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a89ecca80709d4076b95f89f308544ec8f7b4727e8a547913a35f16717856cb"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a318d68e92e80af8b00fa99609796fdbcdfef3629c77c6283566c6f02c6d6704"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13b0a73a0896988f053e4fbb7de6d93388e6dd292b0d87ee51d106f2c11b465b"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4421712dbfc5562150f7554f13dde997a2e932a6b5f352edcce948a815efee6f"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:166811d20dfea725e2e4baa71fffd6c968a958577848d2131f39b60043400223"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:225667980479a17db1048cb2bf8bfb39b8e5be8f164b8f6628b64f78a72cf9d3"}, + {file = "coverage-7.6.1-cp313-cp313t-win32.whl", hash = "sha256:170d444ab405852903b7d04ea9ae9b98f98ab6d7e63e1115e82620807519797f"}, + {file = "coverage-7.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b9f222de8cded79c49bf184bdbc06630d4c58eec9459b939b4a690c82ed05657"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6db04803b6c7291985a761004e9060b2bca08da6d04f26a7f2294b8623a0c1a0"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f1adfc8ac319e1a348af294106bc6a8458a0f1633cc62a1446aebc30c5fa186a"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a95324a9de9650a729239daea117df21f4b9868ce32e63f8b650ebe6cef5595b"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b43c03669dc4618ec25270b06ecd3ee4fa94c7f9b3c14bae6571ca00ef98b0d3"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8929543a7192c13d177b770008bc4e8119f2e1f881d563fc6b6305d2d0ebe9de"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a09ece4a69cf399510c8ab25e0950d9cf2b42f7b3cb0374f95d2e2ff594478a6"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9054a0754de38d9dbd01a46621636689124d666bad1936d76c0341f7d71bf569"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0dbde0f4aa9a16fa4d754356a8f2e36296ff4d83994b2c9d8398aa32f222f989"}, + {file = "coverage-7.6.1-cp38-cp38-win32.whl", hash = "sha256:da511e6ad4f7323ee5702e6633085fb76c2f893aaf8ce4c51a0ba4fc07580ea7"}, + {file = "coverage-7.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:3f1156e3e8f2872197af3840d8ad307a9dd18e615dc64d9ee41696f287c57ad8"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abd5fd0db5f4dc9289408aaf34908072f805ff7792632250dcb36dc591d24255"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:547f45fa1a93154bd82050a7f3cddbc1a7a4dd2a9bf5cb7d06f4ae29fe94eaf8"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645786266c8f18a931b65bfcefdbf6952dd0dea98feee39bd188607a9d307ed2"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e0b2df163b8ed01d515807af24f63de04bebcecbd6c3bfeff88385789fdf75a"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609b06f178fe8e9f89ef676532760ec0b4deea15e9969bf754b37f7c40326dbc"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:702855feff378050ae4f741045e19a32d57d19f3e0676d589df0575008ea5004"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2bdb062ea438f22d99cba0d7829c2ef0af1d768d1e4a4f528087224c90b132cb"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9c56863d44bd1c4fe2abb8a4d6f5371d197f1ac0ebdee542f07f35895fc07f36"}, + {file = "coverage-7.6.1-cp39-cp39-win32.whl", hash = "sha256:6e2cd258d7d927d09493c8df1ce9174ad01b381d4729a9d8d4e38670ca24774c"}, + {file = "coverage-7.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:06a737c882bd26d0d6ee7269b20b12f14a8704807a01056c80bb881a4b2ce6ca"}, + {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, + {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, +] + +[package.extras] +toml = ["tomli"] + [[package]] name = "dataclasses-json" version = "0.6.7" @@ -3780,4 +3864,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<=3.13" -content-hash = "f15aca2a3254e6744bc65a474c0fedf3f23a48867d3badcc6e2578cdc5d96a1a" +content-hash = "b9c679d3b471808285e7c116c35976c18c36f74ab5cd7fc60fc2f4abbb4c14c5" diff --git a/pyproject.toml b/pyproject.toml index f93af2be..908def18 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ gitpython = "3.1.43" pyarrow = "17.0.0" requests = "2.32.3" pandas = "2.2.2" +coverage = "^7.6.1" [tool.ruff.lint] select = [ "B", "DTZ", "E", "F", "I", "LOG", "N", "PLE", "SIM", "T20", "UP",] diff --git a/tests/benchmark/test_evaluation.py b/tests/benchmark/test_evaluation.py new file mode 100644 index 00000000..7c87d0f4 --- /dev/null +++ b/tests/benchmark/test_evaluation.py @@ -0,0 +1,82 @@ +import os +from datetime import datetime + +import pytest +from dotenv import load_dotenv + +from moatless.benchmark.evaluation import Evaluation +from moatless.edit import PlanToCode, EditCode +from moatless.find import SearchCode, IdentifyCode, DecideRelevance +from moatless.transitions import search_and_code_transitions + +load_dotenv() +moatless_dir = os.getenv("MOATLESS_DIR", "/tmp/moatless") +index_store_dir = os.getenv("INDEX_STORE_DIR", "/tmp/index_store") +repo_dir = os.getenv("REPO_DIR", "/tmp/repo") + +global_params = { + "model": "gpt-4o-mini-2024-07-18", # "azure/gpt-4o", + "temperature": 0.5, + "max_tokens": 2000, + "max_prompt_file_tokens": 8000, +} + +state_params = { + SearchCode: { + "provide_initial_context": True, + "max_search_results": 75, + "initial_context_tokens": 6000, + "initial_search_results": 100, + "initial_context_spans_per_file": 5, + }, + IdentifyCode: {"expand_context": True}, + DecideRelevance: { + "finish_after_relevant_count": 1, + }, + PlanToCode: { + "max_tokens_in_edit_prompt": 750, + "expand_context_with_related_spans": False, + "finish_on_review": True, + }, + EditCode: { + "chain_of_thought": False, + "show_file_context": False, + "max_prompt_file_tokens": 8000, + }, +} + +search_and_code = search_and_code_transitions( + global_params=global_params, state_params=state_params +) + +pytest.mark.llm_integration = pytest.mark.skipif( + "not config.getoption('--run-llm-integration')", + reason="need --run-llm-integration option to run tests that call LLMs", +) + + +@pytest.mark.llm_integration +def test_run_single_evaluation_mcts(): + datestr = datetime.now().strftime("%Y%m%d-%H%M%S") + dir = f"{moatless_dir}/eval_test" + evaluation_name = f"{datestr}_mcts" + + evaluation = Evaluation( + transitions=search_and_code, + evaluations_dir=dir, + evaluation_name=evaluation_name, + index_store_dir=index_store_dir, + repo_base_dir=repo_dir, + max_file_context_tokens=16000, + num_workers=1, + detailed_report=True, + ) + + result = evaluation.run_single_instance("django__django-16379") + + assert result["instance_id"] == "django__django-16379" + assert result["status"] == "edited" + assert result["edited"] + assert result["identified"] + assert result["found_in_search"] + assert result["file_identified"] diff --git a/tests/benchmark/test_report_v2.py b/tests/benchmark/test_report_v2.py new file mode 100644 index 00000000..e35c4771 --- /dev/null +++ b/tests/benchmark/test_report_v2.py @@ -0,0 +1,39 @@ +import json +from pathlib import Path + +import pytest + +from moatless.benchmark.report_v2 import to_result +from moatless.trajectory import Trajectory + + +@pytest.fixture +def django_trajectory(): + file_path = Path("tests/trajectories/django__django_16379.json") + return Trajectory.load(str(file_path)) + + +@pytest.fixture +def dataset(): + with open("moatless/benchmark/swebench_lite_all_evaluations.json") as f: + return json.load(f) + +@pytest.fixture +def django_instance(dataset): + for instance in dataset: + if instance["instance_id"] == "django__django-16379": + return instance + + return None + + +def test_to_result(django_trajectory, django_instance): + result = to_result(django_instance, django_trajectory) + + assert result["instance_id"] == "django__django-16379" + assert result["status"] == "edited" + assert result["transitions"] == len(django_trajectory.transitions) + assert result["edited"] + assert result["identified"] + assert result["found_in_search"] + assert result["file_identified"] diff --git a/tests/edit/test_clarify.py b/tests/edit/test_clarify.py index eaba3a74..8a14b308 100644 --- a/tests/edit/test_clarify.py +++ b/tests/edit/test_clarify.py @@ -1,105 +1,135 @@ import pytest - -from moatless.benchmark.swebench import load_instance, create_workspace +from unittest.mock import Mock, patch from moatless.edit.clarify import ClarifyCodeChange, LineNumberClarification -from moatless.loop import AgenticLoop -from moatless.types import ActionResponse - - -def create_clarify( - mocker, instance_id: str, file_path: str, span_id: str, instructions: str = "" -): - clarify = ClarifyCodeChange( - instructions=instructions, file_path=file_path, span_id=span_id - ) - mock_loop = mocker.create_autospec(AgenticLoop) - instance = load_instance(instance_id) - mock_loop.workspace = create_workspace(instance) - mock_loop.workspace.file_context.add_span_to_context( - file_path=file_path, span_id=span_id - ) - clarify._set_loop(mock_loop) - return clarify, mock_loop - - -@pytest.mark.skip -def test_line_span_in_end_of_class(mocker): - instance_id = "scikit-learn__scikit-learn-13439" - file_path = "sklearn/pipeline.py" - span_id = "Pipeline" - - coding, mock_loop = create_clarify( - mocker, instance_id=instance_id, file_path=file_path, span_id=span_id - ) - - request = LineNumberClarification(start_line=562, end_line=563, thoughts="") - - response = coding._execute_action(request) - assert response == ActionResponse( - trigger="edit_code", - output={ - "instructions": "", - "file_path": "sklearn/pipeline.py", - "span_id": "Pipeline", - "start_line": 559, - "end_line": 562, - }, - ) - - -@pytest.mark.skip -def test_impl_span(mocker): - instance_id = "django__django-10914" - file_path = "django/conf/global_settings.py" - span_id = "impl:105" - start_line = 307 - end_line = 307 - - coding, mock_loop = create_clarify( - mocker, instance_id=instance_id, file_path=file_path, span_id=span_id - ) - - request = LineNumberClarification( - start_line=start_line, end_line=end_line, thoughts="" - ) - - response = coding._execute_action(request) - assert response == ActionResponse( - trigger="edit_code", - output={ - "instructions": "", - "file_path": "django/conf/global_settings.py", - "span_id": "impl:105", - "start_line": 303, - "end_line": 311, - }, - ) - - -@pytest.mark.skip -def test_line_span_in_class(mocker): - instance_id = "sympy__sympy-13177" - file_path = "requests/models.py" - span_id = "Request" - start_line = 151 - end_line = 153 - - coding, mock_loop = create_clarify( - mocker, instance_id=instance_id, file_path=file_path, span_id=span_id - ) - - request = LineNumberClarification( - start_line=start_line, end_line=end_line, thoughts="" - ) - - response = coding._execute_action(request) - assert response == ActionResponse( - trigger="edit_code", - output={ - "instructions": "", - "file_path": "requests/models.py", - "span_id": "Request", - "start_line": 147, - "end_line": 157, - }, - ) +from moatless.types import ActionResponse, FileWithSpans +from moatless.workspace import Workspace +from moatless.file_context import FileContext +from moatless.repository import CodeFile +from moatless.codeblocks.codeblocks import BlockSpan, CodeBlock +from moatless.codeblocks import CodeBlockType + +class TestClarifyCodeChange: + @pytest.fixture + def clarify_code_change(self): + mock_file_repo = Mock() + mock_workspace = Workspace(file_repo=mock_file_repo) + mock_file_context = Mock(spec=FileContext) + + return ClarifyCodeChange( + id=2, + instructions="Update function", + file_path="test.py", + span_id="span1", + _workspace=mock_workspace, + file_context=mock_file_context + ) + + def test_action_type(self, clarify_code_change: ClarifyCodeChange): + assert clarify_code_change.action_type() == LineNumberClarification + + @patch('moatless.edit.clarify.ClarifyCodeChange._verify_line_numbers') + def test_execute_action_reject(self, mock_verify, clarify_code_change: ClarifyCodeChange): + action = LineNumberClarification( + scratch_pad="Cannot complete the task", + start_line=1, + end_line=5, + reject=True + ) + + response = clarify_code_change._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "reject" + assert response.output["message"] == "Cannot complete the task" + + @patch('moatless.edit.clarify.ClarifyCodeChange._verify_line_numbers') + @patch('moatless.edit.clarify.ClarifyCodeChange.get_line_span') + def test_execute_action_edit_code(self, mock_get_line_span, mock_verify, clarify_code_change: ClarifyCodeChange): + action = LineNumberClarification( + scratch_pad="Updating lines", + start_line=2, + end_line=4 + ) + + mock_verify.return_value = None + mock_get_line_span.return_value = (2, 4) + + response = clarify_code_change._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "edit_code" + assert response.output["instructions"] == "Update function\n\nUpdating lines" + assert response.output["file_path"] == "test.py" + assert response.output["span_id"] == "span1" + assert response.output["start_line"] == 2 + assert response.output["end_line"] == 4 + + @patch('moatless.edit.clarify.ClarifyCodeChange._verify_line_numbers') + def test_execute_action_retry(self, mock_verify, clarify_code_change: ClarifyCodeChange): + action = LineNumberClarification( + scratch_pad="Retry needed", + start_line=1, + end_line=10 + ) + + mock_verify.return_value = "Invalid line numbers" + + response = clarify_code_change._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "retry" + assert response.retry_message == "Invalid line numbers" + + def test_required_fields(self, clarify_code_change: ClarifyCodeChange): + assert clarify_code_change.required_fields() == {"instructions", "file_path", "span_id"} + + def test_messages(self, clarify_code_change: ClarifyCodeChange): + # TODO: Test init() properly + clarify_code_change._file_context_str = "Mock file context" + messages = clarify_code_change.messages() + + assert len(messages) == 1 + assert "" in messages[0].content + assert "Update function" in messages[0].content + assert "" in messages[0].content + assert "Mock file context" in messages[0].content + + + @patch('moatless.repository.CodeFile') + @patch('moatless.codeblocks.codeblocks.BlockSpan') + def test_verify_line_numbers_valid(self, mock_span, mock_file, clarify_code_change: ClarifyCodeChange): + mock_file.content = "line1\nline2\nline3\nline4\nline5" + mock_span.start_line = 1 + mock_span.end_line = 5 + clarify_code_change._file = mock_file + clarify_code_change._span = mock_span + + action = LineNumberClarification( + scratch_pad="Valid lines", + start_line=2, + end_line=4 + ) + + result = clarify_code_change._verify_line_numbers(action) + + assert result is None + + @patch('moatless.repository.CodeFile') + @patch('moatless.codeblocks.codeblocks.BlockSpan') + def test_verify_line_numbers_invalid(self, mock_span, mock_file, clarify_code_change: ClarifyCodeChange): + mock_file.content = "line1\nline2\nline3\nline4\nline5" + mock_span.start_line = 1 + mock_span.end_line = 5 + clarify_code_change._file = mock_file + clarify_code_change._span = mock_span + + action = LineNumberClarification( + scratch_pad="Invalid lines", + start_line=1, + end_line=5 + ) + + result = clarify_code_change._verify_line_numbers(action) + + assert result is not None + assert "covers the whole code span" in result diff --git a/tests/edit/test_edit.py b/tests/edit/test_edit.py index ec22a2a5..e608dd41 100644 --- a/tests/edit/test_edit.py +++ b/tests/edit/test_edit.py @@ -1,59 +1,118 @@ import pytest - -from moatless.benchmark.swebench import create_workspace, load_instance +from unittest.mock import Mock, patch from moatless.edit.edit import EditCode -from moatless.loop import AgenticLoop - - -def create_clarify( - mocker, - instance_id: str, - file_path: str, - span_id: str, - start_line: int, - end_line: int, - instructions: str = "", -): - clarify = EditCode( - model="gpt-4o", - instructions=instructions, - file_path=file_path, - span_id=span_id, - start_line=start_line, - end_line=end_line, - ) - mock_loop = mocker.create_autospec(AgenticLoop) - instance = load_instance(instance_id) - mock_loop.workspace = create_workspace(instance) - mock_loop.workspace.file_context.add_span_to_context( - file_path=file_path, span_id=span_id - ) - clarify._set_loop(mock_loop) - return clarify, mock_loop - - -@pytest.mark.skip -def test_search_block(mocker): - instance_id = "scikit-learn__scikit-learn-10297" - file_path = "sklearn/linear_model/ridge.py" - span_id = "RidgeClassifierCV.fit" - start_line = 1342 - end_line = 1377 - - coding, mock_loop = create_clarify( - mocker, - instance_id=instance_id, - file_path=file_path, - span_id=span_id, - start_line=start_line, - end_line=end_line, - ) - - # Assert that the first line is correct in the search block - found_search = False - for line in coding.messages()[0].content.split("\n"): - if found_search: - assert line == " def fit(self, X, y, sample_weight=None):" - break - if "" in line: - found_search = True +from moatless.repository.file import UpdateResult +from moatless.types import ActionResponse, Content +from moatless.workspace import Workspace +from moatless.file_context import FileContext +from moatless.repository import CodeFile + +class TestEditCode: + + @pytest.fixture + def edit_code(self): + mock_file_repo = Mock() + mock_workspace = Workspace(file_repo=mock_file_repo) + + return EditCode( + id=1, + instructions="Update function", + file_path="test.py", + span_id="span1", + verify=False, + start_line=1, + end_line=5, + _workspace=mock_workspace, + model="gpt-3.5-turbo" + ) + + def test_required_fields(self, edit_code: EditCode): + assert edit_code.required_fields() == {"instructions", "file_path", "span_id", "start_line", "end_line"} + + @patch('moatless.edit.edit.EditCode.file_context') + def test_init(self, mock_file_context, edit_code: EditCode): + mock_file = Mock(spec=CodeFile) + mock_file.content = "line1\nline2\nline3\nline4\nline5" + mock_file_wrapper = Mock(file=mock_file) + mock_file_context.get_file.return_value = mock_file_wrapper + + edit_code.init() + + assert edit_code._code_to_replace == "line1\nline2\nline3\nline4\nline5" + + @patch('moatless.edit.edit.EditCode.file_context') + def test_execute_action_reject(self, mock_file_context, edit_code: EditCode): + content = Content(content="Cannot complete the task") + + response = edit_code._execute_action(content) + + assert isinstance(response, ActionResponse) + assert response.trigger == "reject" + assert response.output["message"] == "Cannot complete the task" + + @patch('moatless.edit.edit.EditCode.file_context') + def test_execute_action_edit_code(self, mock_file_context, edit_code: EditCode): + update_result = UpdateResult(diff="diff", updated=True, file_path="test.py") + + mock_file = Mock(spec=CodeFile) + mock_file.update_content_by_line_numbers.return_value = update_result + + mock_context_file = Mock() + mock_context_file.file = mock_file + mock_context_file.update_content_by_line_numbers.return_value = update_result + + mock_file_context.get_file.return_value = mock_context_file + + content = Content(content="updated code") + + response = edit_code._execute_action(content) + + assert isinstance(response, ActionResponse) + assert response.trigger == "finish" + assert "Applied the change to test.py." in response.output["message"] + assert response.output["diff"] == "diff" + + mock_context_file.update_content_by_line_numbers.assert_called_once() + + @patch('moatless.edit.edit.EditCode.file_context') + def test_execute_action_retry(self, mock_file_context, edit_code: EditCode): + mock_file = Mock(spec=CodeFile) + mock_file.update_content_by_line_numbers.return_value = Mock(diff=None, updated=False) + mock_context_file = Mock() + mock_context_file.file = mock_file + mock_context_file.update_content_by_line_numbers.return_value = Mock(diff=None, updated=False) + mock_file_context.get_file.return_value = mock_context_file + + content = Content(content="unchanged code") + + response = edit_code._execute_action(content) + + assert isinstance(response, ActionResponse) + assert response.trigger == "retry" + assert "The code in the replace tag is the same as in the search" in response.retry_message + + def test_system_prompt(self, edit_code: EditCode): + system_prompt = edit_code.system_prompt() + + assert "You are autonomous AI assisistant with superior programming skills." in system_prompt + + @patch('moatless.edit.edit.EditCode.file_context') + def test_messages(self, mock_file_context, edit_code: EditCode): + mock_file_context.create_prompt.return_value = "Mock file context" + edit_code._code_to_replace = "code to replace" + + messages = edit_code.messages() + + assert len(messages) == 1 + assert "" in messages[0].content + assert "Update function" in messages[0].content + assert "" in messages[0].content + assert "Mock file context" in messages[0].content + assert "" in messages[0].content + assert "code to replace" in messages[0].content + + def test_action_type(self, edit_code: EditCode): + assert edit_code.action_type() is None + + def test_stop_words(self, edit_code: EditCode): + assert edit_code.stop_words() == [""] \ No newline at end of file diff --git a/tests/edit/test_plan.py b/tests/edit/test_plan.py index 96932cf7..a6d9fe54 100644 --- a/tests/edit/test_plan.py +++ b/tests/edit/test_plan.py @@ -1,19 +1,115 @@ -from moatless.edit.plan import ApplyChange +import pytest +from unittest.mock import Mock, patch +from moatless.edit.plan import PlanToCode, ApplyChange +from moatless.types import ActionResponse, ActionTransaction +from moatless.workspace import Workspace +from moatless.file_context import FileContext +class TestPlanToCode: + @pytest.fixture + def plan_to_code(self): + mock_file_repo = Mock() + mock_workspace = Workspace(file_repo=mock_file_repo) + mock_file_context = Mock(spec=FileContext) + + return PlanToCode( + id=1, + _workspace=mock_workspace, + _initial_message="Test initial message", + file_context=mock_file_context + ) -def test_deserialize_action(): - data = { - "scratch_pad": "To fix the race condition in the has_key method, we need to handle the case where the file might be deleted between the os.path.exists check and the open call. We can do this by wrapping the open call in a try-except block to catch the FileNotFoundError exception and return False if the file is not found.", - "action": "modify", - "instructions": "Wrap the open call in the has_key method in a try-except block to catch FileNotFoundError and return False if the file is not found.", - "file_path": "django/core/cache/backends/filebased.py", - "span_id": "FileBasedCache.has_key", - } + def test_action_type(self, plan_to_code): + assert plan_to_code.action_type() == ApplyChange - action = ApplyChange.model_validate(data) + def test_execute_action_finish(self, plan_to_code): + action = ApplyChange( + scratch_pad="Finished", + action="finish", + finish="Task completed successfully" + ) - assert action.scratch_pad == data["scratch_pad"] - assert action.action == data["action"] - assert action.instructions == data["instructions"] - assert action.file_path == data["file_path"] + response = plan_to_code._execute_action(action) + assert isinstance(response, ActionResponse) + assert response.trigger == "finish" + assert response.output["message"] == "Task completed successfully" + + def test_execute_action_reject(self, plan_to_code): + action = ApplyChange( + scratch_pad="Rejected", + action="reject", + reject="Cannot complete the task" + ) + + response = plan_to_code._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "reject" + assert response.output["message"] == "Cannot complete the task" + + def test_execute_action_review(self, plan_to_code): + action = ApplyChange( + scratch_pad="Review needed", + action="review" + ) + + response = plan_to_code._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "retry" + assert "Review isn't possible" in response.retry_message + + @patch('moatless.edit.plan.PlanToCode._request_for_change') + def test_execute_action_apply_change(self, mock_request_for_change, plan_to_code): + action = ApplyChange( + scratch_pad="Applying change", + action="modify", + file_path="test.py", + span_id="span1", + instructions="Update function" + ) + + mock_request_for_change.return_value = ActionResponse(trigger="edit_code") + + response = plan_to_code._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "edit_code" + mock_request_for_change.assert_called_once_with(action) + + @patch('moatless.file_context.FileContext.create_prompt') + def test_messages(self, mock_create_prompt, plan_to_code): + mock_create_prompt.return_value = "Mock file context" + + messages = plan_to_code.messages() + + assert len(messages) == 1 + assert "" in messages[0].content + assert "Test initial message" in messages[0].content + assert "" in messages[0].content + assert "Mock file context" in messages[0].content + + mock_create_prompt.assert_called_once() + + @patch('moatless.file_context.FileContext.get_file') + @patch('moatless.file_context.FileContext.get_spans') + def test_request_for_change_file_not_found(self, mock_get_spans, mock_get_file, plan_to_code): + mock_get_file.return_value = None + mock_get_spans.return_value = [] + + action = ApplyChange( + scratch_pad="Change request", + action="modify", + file_path="nonexistent.py", + span_id="span1", + instructions="Update function" + ) + + response = plan_to_code._request_for_change(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "retry" + assert "File nonexistent.py is not found in the file context" in response.retry_message + + # Add more tests for other scenarios in _request_for_change method \ No newline at end of file diff --git a/tests/find/test_decide.py b/tests/find/test_decide.py new file mode 100644 index 00000000..d588fe43 --- /dev/null +++ b/tests/find/test_decide.py @@ -0,0 +1,115 @@ +import pytest +from moatless.find.decide import DecideRelevance, Decision +from moatless.find.identify import Identify, IdentifyCode +from moatless.types import ActionResponse, ActionTransaction +from moatless.workspace import Workspace +from moatless.file_context import FileContext +from unittest.mock import Mock, MagicMock, patch + +class TestDecideRelevance: + @pytest.fixture + def decide_relevance(self): + mock_file_repo = Mock() + mock_workspace = Workspace(file_repo=mock_file_repo) + mock_file_context = Mock(spec=FileContext) + + return DecideRelevance( + id=1, + _workspace=mock_workspace, + _initial_message="Test initial message", + expand_context=False, + file_context=mock_file_context + ) + + def test_action_type(self, decide_relevance): + assert decide_relevance.action_type() == Decision + + def test_execute_action_complete_and_relevant(self, decide_relevance): + action = Decision( + scratch_pad="Complete and relevant", + relevant=True, + complete=True + ) + + response = decide_relevance._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "finish" + + def test_execute_action_relevant_but_not_complete(self, decide_relevance): + decide_relevance.finish_after_relevant_count = 1 + decide_relevance._relevant_count = Mock(return_value=1) + action = Decision( + scratch_pad="Relevant but not complete", + relevant=True, + complete=False + ) + + response = decide_relevance._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "finish" + + def test_execute_action_not_relevant_not_complete(self, decide_relevance): + action = Decision( + scratch_pad="Not relevant, not complete", + relevant=False, + complete=False, + search_suggestions="Try searching for X" + ) + + response = decide_relevance._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "search" + assert response.output["message"] == "Try searching for X" + + def test_relevant_count(self, decide_relevance: DecideRelevance): + state3 = DecideRelevance(id=3, expand_context=False, file_context=Mock()) + state3._actions = [ActionTransaction(request=Decision(scratch_pad="Test", relevant=True), response=ActionResponse(trigger="finish"))] + state2 = DecideRelevance(id=2, expand_context=False, file_context=Mock()) + state2._actions = [ActionTransaction(request=Decision(scratch_pad="Test", relevant=False), response=ActionResponse(trigger="finish"))] + state2.previous_state = state3 + state1 = DecideRelevance(id=1, expand_context=False, file_context=Mock()) + state1._actions = [ActionTransaction(request=Decision(scratch_pad="Test", relevant=True), response=ActionResponse(trigger="finish"))] + state1.previous_state = state2 + + decide_relevance.previous_state = state1 + assert len(decide_relevance.get_previous_states(decide_relevance)) == 3 + assert decide_relevance._relevant_count() == 2 + + @patch('moatless.file_context.FileContext.create_prompt') + def test_messages(self, mock_create_prompt, decide_relevance): + mock_create_prompt.return_value = "Mock file context" + + messages = decide_relevance.messages() + + assert len(messages) == 1 + assert "" in messages[0].content + assert "Test initial message" in messages[0].content + assert "" in messages[0].content + assert "Mock file context" in messages[0].content + + mock_create_prompt.assert_called_once() + + @patch('moatless.file_context.FileContext.create_prompt') + def test_messages_with_last_scratch_pad(self, mock_create_prompt, decide_relevance): + mock_create_prompt.return_value = "Mock file context" + + previous_state = IdentifyCode(id=3) + previous_state._actions = [ActionTransaction(request=Identify(scratch_pad="Previous scratch pad", relevant=True))] + decide_relevance.previous_state = previous_state + + messages = decide_relevance.messages() + + assert len(messages) == 1 + assert "" in messages[0].content + assert "Previous scratch pad" in messages[0].content + + def test_system_prompt(self, decide_relevance): + system_prompt = decide_relevance.system_prompt() + + assert "You will be provided a reported issue and the file context" in system_prompt + assert "Analyze the Issue:" in system_prompt + assert "Analyze File Context:" in system_prompt + assert "Make a Decision:" in system_prompt \ No newline at end of file diff --git a/tests/find/test_identify.py b/tests/find/test_identify.py index a7bd5a76..10c2d998 100644 --- a/tests/find/test_identify.py +++ b/tests/find/test_identify.py @@ -1,30 +1,90 @@ +import pytest +from moatless.codeblocks.codeblocks import BlockSpan, SpanType +from moatless.find.identify import IdentifyCode, Identify, is_test_pattern from moatless.file_context import RankedFileSpan -from moatless.find import IdentifyCode - - -def test_model_dump(): - identify = IdentifyCode( - ranked_spans=[ - RankedFileSpan( - file_path="file1.py", - span_id="span1", - rank=1, - ), - RankedFileSpan(file_path="file2.py", span_id="span2", rank=2, tokens=50), - ] - ) - - assert identify.model_dump() == { - "include_message_history": False, - "model": None, - "temperature": 0.0, - "max_tokens": 1000, - "max_iterations": None, - "ranked_spans": [ - {"file_path": "file1.py", "span_id": "span1", "rank": 1, "tokens": 0}, - {"file_path": "file2.py", "span_id": "span2", "rank": 2, "tokens": 50}, - ], - "expand_context": True, - "max_prompt_file_tokens": 4000, - "name": "IdentifyCode", - } +from moatless.repository.file import CodeFile +from moatless.types import FileWithSpans, ActionResponse +from moatless.workspace import Workspace +from unittest.mock import Mock, MagicMock + +class TestIdentifyCode: + @pytest.fixture + def identify_code(self): + mock_file_repo = Mock() + mock_workspace = Workspace(file_repo=mock_file_repo) + + mock_module = Mock() + mock_module.find_span_by_id.side_effect = lambda span_id: BlockSpan( + span_type=SpanType.IMPLEMENTATION, + span_id=span_id, + start_line=0, + end_line=10, + ) + + mock_code_file = MagicMock(spec=CodeFile) + mock_code_file.content = "Mock file content" + mock_code_file.file_path = "test.py" + mock_code_file.module = mock_module + + mock_file_repo.get_file.side_effect = lambda path: mock_code_file + + return IdentifyCode(id=1, _workspace=mock_workspace, _initial_message="Test initial message") + + def test_action_type(self, identify_code): + assert identify_code.action_type() == Identify + + def test_system_prompt(self, identify_code): + assert isinstance(identify_code.system_prompt(), str) + assert "You are an autonomous AI assistant" in identify_code.system_prompt() + + def test_execute_action_with_identified_spans(self, identify_code): + action = Identify( + scratch_pad="Test scratch pad", + identified_spans=[ + FileWithSpans(file_path="test.py", span_ids=["span1", "span2"]) + ] + ) + + response = identify_code._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "finish" + + # Verify that the file was added to the file context + assert "test.py" in identify_code.file_context._file_context + context_file = identify_code.file_context._file_context["test.py"] + assert context_file.file_path == "test.py" + assert set(context_file.span_ids) == {"span1", "span2"} + + def test_execute_action_without_identified_spans(self, identify_code): + identify_code.ranked_spans = [RankedFileSpan(file_path="test.py", span_id="span1", rank=1)] + action = Identify(scratch_pad="No relevant spans found") + + response = identify_code._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "search" + assert "The search returned 1 results" in response.output["message"] + + def test_messages(self, identify_code): + messages = identify_code.messages() + + assert len(messages) == 1 + assert "" in messages[0].content + assert "Test initial message" in messages[0].content + assert "" in messages[0].content + assert "" in messages[0].content + + def test_initial_message(self, identify_code): + assert identify_code.initial_message == "Test initial message" + + def test_workspace_initialization(self, identify_code): + assert identify_code._workspace is not None + assert isinstance(identify_code._workspace, Workspace) + +def test_is_test_pattern(): + assert is_test_pattern("test_file.py") == True + assert is_test_pattern("file_test.py") == False + assert is_test_pattern("/tests/some_file.py") == True + assert is_test_pattern("src/main.py") == False + assert is_test_pattern("test_utils/helper.py") == True \ No newline at end of file diff --git a/tests/find/test_search.py b/tests/find/test_search.py new file mode 100644 index 00000000..e2ff1d18 --- /dev/null +++ b/tests/find/test_search.py @@ -0,0 +1,68 @@ +import pytest +from moatless.find.search import SearchCode, Search, SearchRequest +from moatless.types import ActionResponse +from moatless.workspace import Workspace +from unittest.mock import Mock, MagicMock +from pydantic import ValidationError + +class TestSearchCode: + @pytest.fixture + def search_code(self): + mock_file_repo = Mock() + mock_workspace = Workspace(file_repo=mock_file_repo) + mock_code_index = MagicMock() + mock_workspace.code_index = mock_code_index + + return SearchCode(id=1, _workspace=mock_workspace, _initial_message="Test initial message") + + def test_action_type(self, search_code): + assert search_code.action_type() == Search + + def test_execute_action_complete(self, search_code): + action = Search( + scratch_pad="Search complete", + search_requests=[], + complete=True + ) + + response = search_code._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "finish" + assert response.output["message"] == "Search complete" + + def test_validate_search_without_search_attributes(self): + with pytest.raises(ValidationError) as excinfo: + Search( + scratch_pad="Invalid search", + search_requests=[] + ) + + assert "At least one search request must exist." in str(excinfo.value) + + def test_execute_action_with_search_results(self, search_code): + mock_code_index = MagicMock() + mock_code_index.search.return_value.hits = [ + MagicMock(file_path="test.py", spans=[MagicMock(span_id="span1", rank=1, tokens=10)]) + ] + search_code.workspace.code_index = mock_code_index + + action = Search( + scratch_pad="Valid search", + search_requests=[SearchRequest(query="test query")] + ) + + response = search_code._execute_action(action) + + assert isinstance(response, ActionResponse) + assert response.trigger == "did_search" + assert "ranked_spans" in response.output + assert len(response.output["ranked_spans"]) == 1 + + def test_messages(self, search_code): + messages = search_code.messages() + + assert len(messages) == 1 + assert "" in messages[0].content + assert "Test initial message" in messages[0].content + assert "" in messages[0].content \ No newline at end of file diff --git a/tests/loop/test_loop.py b/tests/loop/test_loop.py deleted file mode 100644 index beb871ac..00000000 --- a/tests/loop/test_loop.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import tempfile - -import pytest - -from moatless import AgenticLoop -from moatless.benchmark.swebench import create_workspace, load_instance -from moatless.repository import GitRepository -from moatless.settings import Settings -from moatless.trajectory import Trajectory - -pytest.mark.api_keys_required = pytest.mark.skipif( - "VOYAGE_API_KEY" not in os.environ or os.environ["VOYAGE_API_KEY"] == "", - reason="VOYAGE_API_KEY environment variable is required" -) - - -@pytest.mark.api_keys_required -def test_rerun_save_and_load_trajectory(): - trajectory = Trajectory.load("tests/trajectories/django__django_16379.json") - Settings.cheap_model = None # To not use an LLM when generating commit messages - - # Start by running the trajectory again with mocked action requests - instance = load_instance("django__django-16379") - workspace = create_workspace(instance) - assert isinstance(workspace.file_repo, GitRepository) - mocked_actions = trajectory.get_mocked_actions() - expected_states = trajectory.get_expected_states() - - loop = AgenticLoop( - trajectory.transition_rules, workspace=workspace, mocked_actions=mocked_actions, expected_states=expected_states - ) - response = loop.run(message=trajectory.initial_message) - - assert workspace.file_context.has_span( - "django/core/cache/backends/filebased.py", "FileBasedCache.has_key" - ) - assert loop.workspace.file_repo._initial_commit != loop.workspace.file_repo._current_commit - diff = loop.workspace.file_repo.diff() - assert diff == """diff --git a/django/core/cache/backends/filebased.py b/django/core/cache/backends/filebased.py -index 631da49444..f980d8d6ac 100644 ---- a/django/core/cache/backends/filebased.py -+++ b/django/core/cache/backends/filebased.py -@@ -91,8 +91,11 @@ class FileBasedCache(BaseCache): - def has_key(self, key, version=None): - fname = self._key_to_file(key, version) - if os.path.exists(fname): -- with open(fname, "rb") as f: -- return not self._is_expired(f) -+ try: -+ with open(fname, "rb") as f: -+ return not self._is_expired(f) -+ except FileNotFoundError: -+ return False - return False - - def _cull(self):""" - - with tempfile.NamedTemporaryFile(delete=False) as tmp_file: - loop.persist(tmp_file.name) - - # Verify that the loop can be iniatied from the saved trajectory - saved_loop = AgenticLoop.from_trajectory_file(tmp_file.name) - - saved_response = saved_loop.run(message=trajectory.initial_message) - assert saved_response == response - assert saved_loop.workspace.file_repo._initial_commit == loop.workspace.file_repo._initial_commit - assert saved_loop.workspace.file_repo._current_commit == loop.workspace.file_repo._current_commit - assert saved_loop.workspace.file_repo.diff() == loop.workspace.file_repo.diff() diff --git a/tests/test_loop.py b/tests/test_loop.py new file mode 100644 index 00000000..0db0d9e7 --- /dev/null +++ b/tests/test_loop.py @@ -0,0 +1,137 @@ +import os +import tempfile + +import pytest +from unittest.mock import MagicMock, patch +from moatless.loop import AgenticLoop +from moatless.state import AgenticState, Finished, Rejected, Pending +from moatless.transition_rules import TransitionRules, TransitionRule +from moatless.workspace import Workspace +from moatless.types import ActionRequest, ActionResponse, Content + +from moatless.benchmark.swebench import create_workspace, load_instance +from moatless.repository import GitRepository +from moatless.settings import Settings +from moatless.trajectory import Trajectory + +pytest.mark.api_keys_required = pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ or os.environ["VOYAGE_API_KEY"] == "", + reason="VOYAGE_API_KEY environment variable is required" +) + +class TestState(AgenticState): + def _execute_action(self, action: ActionRequest) -> ActionResponse: + if action.content == "reject": + return ActionResponse(trigger="reject", output={"message": "Rejected"}) + elif action.content == "continue": + return ActionResponse(trigger="continue", output={"message": "Continue"}) + return ActionResponse(trigger="finish", output={"message": "Finished"}) + +class TestState2(AgenticState): + def _execute_action(self, action: ActionRequest) -> ActionResponse: + if action.content == "reject": + return ActionResponse(trigger="reject", output={"message": "Rejected"}) + elif action.content == "continue": + return ActionResponse(trigger="continue", output={"message": "Continue"}) + return ActionResponse(trigger="finish", output={"message": "Finished"}) + +class TestTransitionRules(TransitionRules): + def __init__(self, rules): + super().__init__(transition_rules=rules) + + def get_next_rule(self, source, trigger, data): + for rule in self.transition_rules: + if rule.source == source.__class__ and rule.trigger == trigger: + return rule + return None + +@pytest.fixture +def mock_workspace(): + workspace = MagicMock(spec=Workspace) + workspace.snapshot.return_value = {} + return workspace + +@pytest.fixture +def test_transition_rules(): + rules = [ + TransitionRule(trigger="init", source=Pending, dest=TestState), + TransitionRule(trigger="finish", source=TestState, dest=Finished), + TransitionRule(trigger="continue", source=TestState, dest=TestState2), + TransitionRule(trigger="continue", source=TestState2, dest=TestState), + TransitionRule(trigger="reject", source=TestState, dest=Rejected), + ] + return TestTransitionRules(rules) + +def test_loop_initialization(mock_workspace, test_transition_rules): + loop = AgenticLoop(test_transition_rules, mock_workspace) + assert loop.workspace == mock_workspace + assert loop._transition_rules == test_transition_rules + +def test_loop_run_until_finished(mock_workspace, test_transition_rules): + loop = AgenticLoop(test_transition_rules, mock_workspace) + + with patch.object(AgenticLoop, '_next_action', return_value=(Content(content="test"), None)): + response = loop.run("initial message") + + assert response.status == "finished" + assert loop.state_count() == 3, f"Expected 3 states, got {[state.state.name for state in loop._trajectory.transitions()]}" + assert loop._initial_message == "initial message" + assert isinstance(loop._trajectory.transitions[2].state, Finished) + +def test_loop_run_until_rejected(mock_workspace, test_transition_rules): + loop = AgenticLoop(test_transition_rules, mock_workspace) + + def mock_next_action() : + return Content(content="reject"), None + + with patch.object(AgenticLoop, '_next_action', side_effect=mock_next_action): + response = loop.run("initial message") + + assert response.status == "rejected" + assert loop.state_count() == 3 # Pending -> TestState -> Rejected + +def test_loop_max_transitions(mock_workspace, test_transition_rules): + loop = AgenticLoop(test_transition_rules, mock_workspace, max_transitions=3) + + with patch.object(AgenticLoop, '_next_action', return_value=(Content(content="continue"), None)): + response = loop.run("initial message") + + assert response.status == "rejected" + assert response.message == "Max transitions exceeded." + assert loop.state_count() == 4, f"Expected 4 states, got {[t.state.name for t in loop._trajectory.transitions]}" + +@pytest.mark.api_keys_required +def test_rerun_save_and_load_trajectory(): + trajectory = Trajectory.load("tests/trajectories/django__django_16379.json") + Settings.cheap_model = None # To not use an LLM when generating commit messages + + # Start by running the trajectory again with mocked action requests + instance = load_instance("django__django-16379") + workspace = create_workspace(instance) + assert isinstance(workspace.file_repo, GitRepository) + mocked_actions = trajectory.get_mocked_actions() + expected_states = trajectory.get_expected_states() + + loop = AgenticLoop( + trajectory.transition_rules, workspace=workspace, mocked_actions=mocked_actions, expected_states=expected_states + ) + response = loop.run(message=trajectory.initial_message) + + assert workspace.file_context.has_span( + "django/core/cache/backends/filebased.py", "FileBasedCache.has_key" + ) + assert loop.workspace.file_repo._initial_commit != loop.workspace.file_repo._current_commit + diff = loop.workspace.file_repo.diff() + # TODO: assert diff + + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + loop.persist(tmp_file.name) + + # Verify that the loop can be iniatied from the saved trajectory + saved_loop = AgenticLoop.from_trajectory_file(tmp_file.name) + + saved_response = saved_loop.run(message=trajectory.initial_message) + assert saved_response == response + assert saved_loop.workspace.file_repo._initial_commit == loop.workspace.file_repo._initial_commit + assert saved_loop.workspace.file_repo._current_commit == loop.workspace.file_repo._current_commit + assert saved_loop.workspace.file_repo.diff() == loop.workspace.file_repo.diff() \ No newline at end of file diff --git a/tests/test_state.py b/tests/test_state.py index 62cd3d5e..ef45cedd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,23 +1,24 @@ import pytest from unittest.mock import MagicMock -from moatless.state import AgenticState, NoopState +from moatless.state import AgenticState, NoopState, Finished from moatless.workspace import Workspace from moatless.repository import FileRepository from moatless.file_context import FileContext -from moatless.types import ActionRequest, ActionResponse, FileWithSpans +from moatless.types import ActionRequest, ActionResponse, Content, FileWithSpans, Usage class ConcreteAgenticState(AgenticState): def _execute_action(self, action: ActionRequest) -> ActionResponse: - return ActionResponse(content="Test response") + return ActionResponse(output={"content": "Test response"}) @pytest.fixture def test_state(): - return ConcreteAgenticState() + return ConcreteAgenticState(id=1) def test_agentic_state_initialization(test_state): + assert test_state.id == 1 assert test_state.include_message_history == False assert test_state.model is None assert test_state.temperature == 0.0 @@ -29,35 +30,9 @@ def test_agentic_state_name(test_state): assert test_state.name == "ConcreteAgenticState" -def test_agentic_state_set_loop(test_state): - mock_loop = MagicMock() - test_state._set_loop(mock_loop) - assert test_state.loop == mock_loop - - -def test_agentic_state_workspace_properties(test_state): - mock_loop = MagicMock() - mock_workspace = MagicMock(spec=Workspace) - mock_file_repo = MagicMock(spec=FileRepository) - mock_file_context = MagicMock(spec=FileContext) - - mock_loop.workspace = mock_workspace - mock_workspace.file_repo = mock_file_repo - mock_workspace.file_context = mock_file_context - - test_state._set_loop(mock_loop) - - assert test_state.workspace == mock_workspace - assert test_state.file_repo == mock_file_repo - assert test_state.file_context == mock_file_context - - def test_agentic_state_create_file_context(test_state): - mock_loop = MagicMock() mock_workspace = MagicMock(spec=Workspace) - mock_loop.workspace = mock_workspace - - test_state._set_loop(mock_loop) + test_state._workspace = mock_workspace files = [FileWithSpans(file_path="test.py", content="print('hello')", spans=[])] test_state.create_file_context(files) @@ -65,98 +40,124 @@ def test_agentic_state_create_file_context(test_state): mock_workspace.create_file_context.assert_called_once_with(files) -def test_agentic_state_transition_to(test_state): - mock_loop = MagicMock() - test_state._set_loop(mock_loop) - - new_state = NoopState() - test_state.transition_to(new_state) - - mock_loop.transition_to.assert_called_once_with(new_state) - - def test_agentic_state_model_dump(test_state): - dump = test_state.model_dump() - assert "name" in dump - assert dump["name"] == "ConcreteAgenticState" + dump = test_state.model_dump(exclude_none=True) + assert dump == {'id': 1, 'include_message_history': False, 'max_tokens': 1000, 'temperature': 0.0} + def test_agentic_state_equality_same_state(): - state1 = ConcreteAgenticState(temperature=0.5, max_tokens=500) - state2 = ConcreteAgenticState(temperature=0.5, max_tokens=500) + state1 = ConcreteAgenticState(id=1, temperature=0.5, max_tokens=500) + state2 = ConcreteAgenticState(id=1, temperature=0.5, max_tokens=500) assert state1 == state2 def test_agentic_state_equality_different_state(): - state1 = ConcreteAgenticState(temperature=0.5, max_tokens=500) - state2 = ConcreteAgenticState(temperature=0.7, max_tokens=500) + state1 = ConcreteAgenticState(id=1, temperature=0.5, max_tokens=500) + state2 = ConcreteAgenticState(id=2, temperature=0.7, max_tokens=500) assert state1 != state2 def test_agentic_state_equality_different_types(): - state1 = ConcreteAgenticState() - state2 = NoopState() + state1 = ConcreteAgenticState(id=1) + state2 = NoopState(id=2) assert state1 != state2 -def test_agentic_state_equality_with_file_context(): - mock_loop1 = MagicMock() - mock_workspace1 = MagicMock(spec=Workspace) - mock_file_context1 = MagicMock(spec=FileContext) - mock_loop1.workspace = mock_workspace1 - mock_workspace1.file_context = mock_file_context1 - mock_file_context1.model_dump.return_value = {"files": [{"file_path": "test.py", "spans": []}]} - - mock_loop2 = MagicMock() - mock_workspace2 = MagicMock(spec=Workspace) - mock_file_context2 = MagicMock(spec=FileContext) - mock_loop2.workspace = mock_workspace2 - mock_workspace2.file_context = mock_file_context2 - mock_file_context2.model_dump.return_value = {"files": [{"file_path": "test.py", "spans": []}]} - - state1 = ConcreteAgenticState() - state2 = ConcreteAgenticState() +def test_handle_action(test_state): + action = ActionRequest(content="Test action") + usage = Usage(prompt_tokens=10, completion_tokens=20, completion_cost=0.2) + response = test_state.handle_action(action, usage) - state1._set_loop(mock_loop1) - state2._set_loop(mock_loop2) - - assert state1 == state2 - -def test_agentic_state_inequality_with_different_file_context(): - mock_loop1 = MagicMock() - mock_workspace1 = MagicMock(spec=Workspace) - mock_file_context1 = MagicMock(spec=FileContext) - mock_loop1.workspace = mock_workspace1 - mock_workspace1.file_context = mock_file_context1 - mock_file_context1.model_dump.return_value = {"files": [{"file_path": "test1.py", "spans": ["foo"]}]} - - mock_loop2 = MagicMock() - mock_workspace2 = MagicMock(spec=Workspace) - mock_file_context2 = MagicMock(spec=FileContext) - mock_loop2.workspace = mock_workspace2 - mock_workspace2.file_context = mock_file_context2 - mock_file_context2.model_dump.return_value = {"files": [{"file_path": "test1.py", "spans": ["bar"]}]} - - state1 = ConcreteAgenticState() - state2 = ConcreteAgenticState() + assert isinstance(response, ActionResponse) + assert response.output == {"content": "Test response"} + assert len(test_state._actions) == 1 + assert test_state._actions[0].request == action + assert test_state._actions[0].response == response + assert test_state._actions[0].usage == usage + +def test_handle_action_executed_state(): + state = ConcreteAgenticState(id=1) + state._executed = True - state1._set_loop(mock_loop1) - state2._set_loop(mock_loop2) + with pytest.raises(ValueError, match="State has already been executed"): + state.handle_action(ActionRequest(content="Test"), None) - assert state1 != state2 +def test_last_action(test_state): + assert test_state.last_action is None + + action = Content(content="Test action") + test_state.handle_action(action, None) + + assert test_state.last_action is not None + assert test_state.last_action.request == action -def test_agentic_state_equality_without_loop(): - state1 = ConcreteAgenticState(temperature=0.5, max_tokens=500) - state2 = ConcreteAgenticState(temperature=0.5, max_tokens=500) - assert state1 == state2 +def test_response(test_state): + assert test_state.response is None + + action = Content(content="Test action") + response = test_state.handle_action(action, None) + + assert test_state.response == response -def test_agentic_state_equality_one_with_loop(): - mock_loop = MagicMock() - mock_workspace = MagicMock(spec=Workspace) - mock_file_context = MagicMock(spec=FileContext) - mock_loop.workspace = mock_workspace - mock_workspace.file_context = mock_file_context - mock_file_context.model_dump.return_value = {"files": [{"file_path": "test.py", "spans": []}]} +def test_retries(test_state): + assert test_state.retries() == 0 + + test_state.handle_action(Content(content="Test 1"), None) + assert test_state.retries() == 0 + + test_state.handle_action(Content(content="Test 2"), None) + test_state._actions[-1].response.trigger = "retry" + assert test_state.retries() == 1 + + test_state.handle_action(Content(content="Test 3"), None) + test_state._actions[-1].response.trigger = "retry" + assert test_state.retries() == 2 - state1 = ConcreteAgenticState(temperature=0.5, max_tokens=500) - state2 = ConcreteAgenticState(temperature=0.5, max_tokens=500) +def test_retry_messages(test_state): + test_state.handle_action(Content(content="Test 1"), None) + test_state._actions[-1].response.retry_message = "Retry 1" + + test_state.handle_action(Content(content="Test 2"), None) + test_state._actions[-1].response.retry_message = "Retry 2" + + messages = test_state.retry_messages() + assert len(messages) == 4 + assert messages[0].role == "assistant" + assert messages[0].content == "Test 1" + assert messages[1].role == "user" + assert messages[1].content == "Retry 1" + assert messages[2].role == "assistant" + assert messages[2].content == "Test 2" + assert messages[3].role == "user" + assert messages[3].content == "Retry 2" + +def test_clone(test_state): + test_state.handle_action(Content(content="Test"), None) + cloned_state = test_state.clone() + + assert cloned_state.id == test_state.id + assert cloned_state.model_dump() == test_state.model_dump() + assert len(cloned_state._actions) == 0 # Actions should not be cloned + assert cloned_state._executed == False + +def test_total_cost(test_state): + usage1 = Usage(prompt_tokens=10, completion_tokens=20, completion_cost=0.2) + usage2 = Usage(prompt_tokens=15, completion_tokens=25, completion_cost=0.25) + + test_state.handle_action(Content(content="Test 1"), usage1) + test_state.handle_action(Content(content="Test 2"), usage2) - state1._set_loop(mock_loop) + assert test_state.total_cost() == pytest.approx(0.45) - assert state1 == state2 \ No newline at end of file +def test_finished_state_creation_and_dump(): + message = "Task completed successfully" + output = {"result": "success", "data": [1, 2, 3]} + + finished_state = Finished.model_validate({"id": 1, "message": message, "output": output}) + + assert finished_state.id == 1 + assert finished_state.message == message + assert finished_state.output == output + + dumped_state = finished_state.model_dump() + + assert dumped_state["id"] == 1 + assert dumped_state["message"] == message + assert dumped_state["output"] == output diff --git a/tests/test_trajectory.py b/tests/test_trajectory.py new file mode 100644 index 00000000..04ac3162 --- /dev/null +++ b/tests/test_trajectory.py @@ -0,0 +1,52 @@ +import os +import json +from datetime import datetime + +from moatless.repository import GitRepository +from moatless.trajectory import Trajectory + +def test_load_django_trajectory(): + file_path = "tests/trajectories/django__django_16379.json" + trajectory = Trajectory.load(file_path) + + with open(file_path, 'r') as f: + original_data = json.load(f) + + assert trajectory._name == original_data["name"] + assert trajectory._initial_message == original_data["initial_message"] + assert isinstance(trajectory._workspace.file_repo, GitRepository) + assert trajectory._workspace.file_repo.path == original_data["workspace"]["repository"]["repo_path"] + assert trajectory._workspace.file_context._max_tokens == original_data["workspace"]["file_context"]["max_tokens"] + + assert trajectory._current_transition_id == original_data["current_transition_id"] + + assert len(trajectory._transitions) == len(original_data["transitions"]) + + for loaded_transition, original_transition in zip(trajectory.transitions, original_data["transitions"]): + loaded_state = loaded_transition.state + assert loaded_state.id == original_transition["id"] + assert loaded_state.name == original_transition["name"] + assert loaded_transition.timestamp == datetime.fromisoformat(original_transition["timestamp"]) + assert loaded_transition.snapshot == original_transition.get("snapshot") + + original_properties = original_transition["properties"] + assert loaded_state.model_dump(exclude_none=True, exclude={"id", "previous_state", "next_states"}) == original_properties + + if "actions" in original_transition: + assert len(loaded_state._actions) == len(original_transition["actions"]) + for loaded_action, original_action in zip(loaded_state._actions, original_transition["actions"]): + assert loaded_action.request.__class__.__name__ == loaded_state.action_type().__name__ if loaded_state.action_type() else "Content" + if loaded_action.response: + assert loaded_action.response.trigger == original_action["response"]["trigger"] + if loaded_action.usage: + assert loaded_action.usage.completion_cost == original_action["usage"]["completion_cost"] + assert loaded_action.usage.completion_tokens == original_action["usage"]["completion_tokens"] + assert loaded_action.usage.prompt_tokens == original_action["usage"]["prompt_tokens"] + + for loaded_transition, original_transition in zip(trajectory.transitions, original_data["transitions"]): + if original_transition.get("previous_state_id") is not None: + assert loaded_transition.state.previous_state.id == original_transition["previous_state_id"] + else: + assert loaded_transition.state.previous_state is None + + assert trajectory._info == original_data.get("info", {}) diff --git a/tests/test_transition_rules.py b/tests/test_transition_rules.py index a8af1359..b19e7c67 100644 --- a/tests/test_transition_rules.py +++ b/tests/test_transition_rules.py @@ -90,12 +90,19 @@ def test_transition_rules_serialization_deserialization(): # Check if the internal _source_trigger_index is rebuilt correctly assert deserialized_rules._source_trigger_index == rules._source_trigger_index - json_data = json.dumps( - rules.model_dump(exclude_none=True, exclude_unset=True), indent=2 - ) + data = rules.model_dump(exclude_none=True, exclude_unset=True) + assert ( - json_data - == """{ + data + == { + "global_params": { + "model": "gpt-4o" + }, + "state_params": { + "MockStateB": { + "model": "claude-3.5-sonnet" + } + }, "initial_state": "MockStateA", "transition_rules": [ { @@ -116,17 +123,8 @@ def test_transition_rules_serialization_deserialization(): "source": "MockStateB", "dest": "Rejected" } - ], - "global_params": { - "model": "gpt-4o" - }, - "state_params": { - "MockStateB": { - "model": "claude-3.5-sonnet" - } - } -}""" - ) + ] +}) def test_find_transition_rule(): @@ -156,10 +154,14 @@ def test_find_transition_rule(): assert len(not_found_rules) == 0 -def test_next_state(): +def test_next_transition_rule(): rules = TransitionRules( - initial_state=MockStateA, transition_rules=[ + TransitionRule( + source=Pending, + dest=MockStateA, + trigger="init", + ), TransitionRule( source=MockStateA, dest=MockStateB, @@ -176,53 +178,34 @@ def test_next_state(): ) # Test successful transition - source_state = MockStateA(value=5) - action_response = source_state._execute_action("to_b") - next_state = rules.next_state(source_state, action_response.trigger, {"value": 5}) - assert isinstance(next_state, MockStateB) - assert next_state.name == "MockStateB" - assert next_state.model == "claude-3.5-sonnet" + source_state = MockStateA(id=1, value=5) + next_transition_rule = rules.get_next_rule(source_state, "to_b", {"value": 5}) + assert isinstance(next_transition_rule, TransitionRule) + assert next_transition_rule.source == MockStateA + assert next_transition_rule.dest == MockStateB + assert next_transition_rule.trigger == "to_b" + assert next_transition_rule.required_fields == {"value"} # Test transition with missing required fields - action_response = source_state._execute_action("to_b") - next_state = rules.next_state(source_state, action_response.trigger, {}) - assert next_state is None + next_transition_rule = rules.get_next_rule(source_state, "to_b", {}) + assert next_transition_rule is None # Test transition to Finished state - source_state = MockStateB(default_name="TestB") - action_response = source_state._execute_action("finish") - next_state = rules.next_state(source_state, action_response.trigger, {}) - assert isinstance(next_state, Finished) + source_state = MockStateB(id=2, default_name="TestB") + next_transition_rule = rules.get_next_rule(source_state, "finish", {}) + assert next_transition_rule is not None + assert next_transition_rule.source == MockStateB + assert next_transition_rule.dest == Finished + assert next_transition_rule.trigger == "finish" # Test transition to Rejected state - action_response = source_state._execute_action("reject") - next_state = rules.next_state( - source_state, action_response.trigger, {"message": "Custom rejection message"} - ) - assert isinstance(next_state, Rejected) - assert next_state.message == "Custom rejection message" - - -def test_initial_state_creation(): - rules = TransitionRules( - initial_state=MockStateA, - transition_rules=[], - global_params={"model": "gpt-4o"}, - state_params={ - MockStateB: {"model": "claude-3.5-sonnet"}, - }, - ) - - initial_state = rules.create_initial_state() - print(initial_state) - assert isinstance(initial_state, MockStateA) - assert initial_state.model == "gpt-4o" - - # Test overriding with custom data - custom_initial_state = rules.create_initial_state(value=20, model="custom-model") - assert isinstance(custom_initial_state, MockStateA) - assert custom_initial_state.model == "custom-model" - assert custom_initial_state.value == 20 + next_transition_rule = rules.get_next_rule( + source_state, "reject", {} + ) + assert next_transition_rule is not None + assert next_transition_rule.source == MockStateB + assert next_transition_rule.dest == Rejected + assert next_transition_rule.trigger == "reject" if __name__ == "__main__": diff --git a/tests/trajectories/django__django_16379.json b/tests/trajectories/django__django_16379.json index 2526c4b2..e8cf4cda 100644 --- a/tests/trajectories/django__django_16379.json +++ b/tests/trajectories/django__django_16379.json @@ -1,8 +1,23 @@ { "name": "MoatlessTools", "transition_rules": { - "initial_state": "SearchCode", + "global_params": { + "model": "gpt-4o-mini-2024-07-18", + "temperature": 0.5, + "max_tokens": 2000, + "max_prompt_file_tokens": 8000 + }, + "state_params": { + "PlanToCode": { + "max_tokens_in_edit_prompt": 500 + } + }, "transition_rules": [ + { + "trigger": "init", + "source": "Pending", + "dest": "SearchCode" + }, { "trigger": "did_search", "source": "SearchCode", @@ -39,10 +54,10 @@ "dest": "EditCode", "required_fields": [ "span_id", - "start_line", - "instructions", + "end_line", "file_path", - "end_line" + "instructions", + "start_line" ] }, { @@ -50,9 +65,9 @@ "source": "PlanToCode", "dest": "ClarifyCodeChange", "required_fields": [ + "file_path", "instructions", - "span_id", - "file_path" + "span_id" ] }, { @@ -71,10 +86,10 @@ "dest": "EditCode", "required_fields": [ "span_id", - "start_line", - "instructions", + "end_line", "file_path", - "end_line" + "instructions", + "start_line" ] }, { @@ -92,132 +107,145 @@ "source": "EditCode", "dest": "PlanToCode" } - ], - "global_params": { - "model": "no_llm", - "temperature": 0.2, - "max_tokens": 2000, - "max_prompt_file_tokens": 8000 - }, - "state_params": { - "PlanToCode": { - "max_tokens_in_edit_prompt": 500 - } - } + ] }, "workspace": { "repository": { "type": "git", "repo_path": "/tmp/repos/swe-bench_django__django", "git_repo_url": "https://github.com/swe-bench/django__django.git", - "branch": null, "commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57" }, "file_context": { "max_tokens": 4000, - "files": [] + "files": [ + { + "spans": [ + { + "span_id": "FileBasedCache.has_key" + }, + { + "span_id": "FileBasedCache._is_expired" + }, + { + "span_id": "FileBasedCache.__init__" + }, + { + "span_id": "FileBasedCache" + } + ], + "file_path": "django/core/cache/backends/filebased.py" + } + ] }, "code_index": { - "index_name": "django__django_16379" + "index_name": null } }, "initial_message": "FileBasedCache has_key is susceptible to race conditions\nDescription\n\t \n\t\t(last modified by Marti Raudsepp)\n\t \nI received the exception from Django's cache framework:\nFileNotFoundError: [Errno 2] No such file or directory: '/app/var/cache/d729e4cf4ba88cba5a0f48e0396ec48a.djcache'\n[...]\n File \"django/core/cache/backends/base.py\", line 229, in get_or_set\n\tself.add(key, default, timeout=timeout, version=version)\n File \"django/core/cache/backends/filebased.py\", line 26, in add\n\tif self.has_key(key, version):\n File \"django/core/cache/backends/filebased.py\", line 94, in has_key\n\twith open(fname, \"rb\") as f:\nThe code is:\n\tdef has_key(self, key, version=None):\n\t\tfname = self._key_to_file(key, version)\n\t\tif os.path.exists(fname):\n\t\t\twith open(fname, \"rb\") as f:\n\t\t\t\treturn not self._is_expired(f)\n\t\treturn False\nBetween the exists() check and open(), it's possible for the file to be deleted. In fact, the _is_expired() method itself deletes the file if it finds it to be expired. So if many threads race to read an expired cache at once, it's not that unlikely to hit this window.\n", + "current_transition_id": 7, "transitions": [ + { + "id": 0, + "name": "Pending", + "timestamp": "2024-08-05T17:51:25.951848", + "properties": { + "temperature": 0.0, + "max_tokens": 1000, + "include_message_history": false + } + }, { "id": 1, - "timestamp": "2024-07-31T10:35:01.498016", - "state": { - "name": "SearchCode", - "model": "no_llm", - "temperature": 0.2, - "max_tokens": 2000, - "include_message_history": true, - "max_search_results": 25, - "max_retries_with_any_file_context": 3, - "provide_initial_context": true, - "initial_context_tokens": 4000, - "initial_search_results": 50, - "initial_context_spans_per_file": 5 - }, + "name": "SearchCode", + "timestamp": "2024-08-05T17:51:25.952973", "snapshot": { "repository": { - "branch": null, "commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57" }, "file_context": { "files": [] } }, + "previous_state_id": 0, + "properties": { + "model": "gpt-4o-mini-2024-07-18", + "temperature": 0.5, + "max_tokens": 2000, + "include_message_history": true, + "max_search_results": 25, + "max_retries_with_any_file_context": 3, + "provide_initial_context": true, + "initial_context_tokens": 4000, + "initial_search_results": 50, + "initial_context_spans_per_file": 5, + "support_test_files": false + }, "actions": [ { - "action": { - "scratch_pad": "The issue is with the `has_key` method in the `FileBasedCache` class, which is susceptible to race conditions. Specifically, the problem occurs between the `os.path.exists(fname)` check and the `open(fname, \"rb\")` statement. The `_is_expired()` method, which is called within `has_key`, can delete the file, leading to a `FileNotFoundError`.\n\nRelevant files to search for the `has_key` method include:\n- django/core/cache/backends/filebased.py\n- django/core/cache/backends/base.py\n\nI will search for the `has_key` method in these files to locate the relevant code.", + "request": { + "scratch_pad": "The issue is related to a race condition in the `has_key` method of the `FileBasedCache` class. Specifically, the problem arises between the `os.path.exists()` check and the subsequent `open()` call, where the file could be deleted by another thread. I need to search for the `has_key` method in the `filebased.py` file to understand its implementation and see if there are any potential fixes or related methods that might also contribute to the race condition. Additionally, I will also search for the `_is_expired` method to see how it might be interacting with the cache and potentially causing file deletions. \n\nI will start by searching for the `has_key` method and the `_is_expired` method in the `django/core/cache/backends/filebased.py` file. \n\n", "search_requests": [ { "file_pattern": "django/core/cache/backends/filebased.py", - "function_names": [ - "has_key" - ] - }, - { - "file_pattern": "django/core/cache/backends/base.py", - "function_names": [ - "has_key" - ] - } - ] - }, - "output": { - "ranked_spans": [ - { - "file_path": "django/core/cache/backends/filebased.py", - "span_id": "FileBasedCache.has_key", - "rank": 0, - "tokens": 47 - }, - { - "file_path": "django/core/cache/backends/base.py", - "span_id": "BaseCache.has_key", - "rank": 0, - "tokens": 53 - } - ], - "search_requests": [ - { - "file_pattern": "django/core/cache/backends/filebased.py", - "query": null, - "code_snippet": null, "class_names": [], "function_names": [ "has_key" ] }, { - "file_pattern": "django/core/cache/backends/base.py", - "query": null, - "code_snippet": null, + "file_pattern": "django/core/cache/backends/filebased.py", "class_names": [], "function_names": [ - "has_key" + "_is_expired" ] } ], "complete": false }, - "completion_cost": 0.008787, - "input_tokens": 2009, - "output_tokens": 184 + "response": { + "trigger": "did_search", + "output": { + "ranked_spans": [ + { + "file_path": "django/core/cache/backends/filebased.py", + "span_id": "FileBasedCache.has_key", + "rank": 0, + "tokens": 47 + }, + { + "file_path": "django/core/cache/backends/filebased.py", + "span_id": "FileBasedCache._is_expired", + "rank": 0, + "tokens": 86 + } + ] + } + }, + "usage": { + "completion_cost": 0.0005272499999999999, + "completion_tokens": 224, + "prompt_tokens": 2619 + } } ] }, { "id": 2, - "timestamp": "2024-07-31T10:35:07.200342", - "parent_id": 1, - "state": { - "name": "IdentifyCode", - "model": "no_llm", - "temperature": 0.2, + "name": "IdentifyCode", + "timestamp": "2024-08-05T17:51:31.792308", + "snapshot": { + "repository": { + "commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57" + }, + "file_context": { + "files": [] + } + }, + "previous_state_id": 1, + "properties": { + "model": "gpt-4o-mini-2024-07-18", + "temperature": 0.5, "max_tokens": 2000, "include_message_history": false, "ranked_spans": [ @@ -228,61 +256,52 @@ "tokens": 47 }, { - "file_path": "django/core/cache/backends/base.py", - "span_id": "BaseCache.has_key", + "file_path": "django/core/cache/backends/filebased.py", + "span_id": "FileBasedCache._is_expired", "rank": 0, - "tokens": 53 + "tokens": 86 } ], - "expand_context": true, + "expand_context": false, "max_prompt_file_tokens": 8000 }, - "snapshot": { - "repository": { - "branch": null, - "commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57" - }, - "file_context": { - "files": [] - } - }, "actions": [ { - "action": { - "scratch_pad": "The issue reported is a race condition in the `FileBasedCache.has_key` method in the `django/core/cache/backends/filebased.py` file. The problem occurs between the `os.path.exists(fname)` check and the `open(fname, \"rb\")` operation, where the file might be deleted by another thread, leading to a `FileNotFoundError`.\n\nThe search results provided include the `FileBasedCache` class and its `has_key` method, which is directly relevant to the issue. The relevant code span for `has_key` in `django/core/cache/backends/filebased.py` is:\n\n```python\ndef has_key(self, key, version=None):\n fname = self._key_to_file(key, version)\n if os.path.exists(fname):\n with open(fname, \"rb\") as f:\n return not self._is_expired(f)\n return False\n```\n\nThis code span matches the code mentioned in the issue report and is the source of the race condition. Therefore, the search task is complete, and the identified code span is relevant to the reported issue.", + "request": { + "scratch_pad": "The reported issue indicates that the `has_key` method in the `FileBasedCache` class is susceptible to race conditions due to the possibility of a file being deleted between the `os.path.exists(fname)` check and the subsequent `open(fname, 'rb')` call. This can lead to a `FileNotFoundError` when multiple threads attempt to access expired cache entries simultaneously.\n\nThe `has_key` method is directly relevant as it is the source of the issue. The `_is_expired` method is also relevant because it can delete the file if it finds it to be expired, which exacerbates the race condition problem. \n\nBoth methods are in the same file, `django/core/cache/backends/filebased.py`, and are critical to understanding and addressing the race condition issue. \n\nThe identified spans are:\n1. `FileBasedCache.has_key` - This span contains the implementation of the `has_key` method, which is at the center of the issue.\n2. `FileBasedCache._is_expired` - This span contains the implementation of the `_is_expired` method, which is invoked within `has_key` and can lead to file deletion, thus contributing to the race condition.\n\nSince both spans are relevant and provide necessary context for addressing the issue, the search task can be considered complete.", "identified_spans": [ { "file_path": "django/core/cache/backends/filebased.py", "span_ids": [ "FileBasedCache.has_key" ] + }, + { + "file_path": "django/core/cache/backends/filebased.py", + "span_ids": [ + "FileBasedCache._is_expired" + ] } ] }, - "output": {}, - "completion_cost": 0.007863, - "input_tokens": 1276, - "output_tokens": 269 + "response": { + "trigger": "finish", + "output": {} + }, + "usage": { + "completion_cost": 0.00035715, + "completion_tokens": 332, + "prompt_tokens": 1053 + } } ] }, { "id": 3, - "timestamp": "2024-07-31T10:35:12.007797", - "parent_id": 2, - "state": { - "name": "DecideRelevance", - "model": "no_llm", - "temperature": 0.2, - "max_tokens": 2000, - "include_message_history": false, - "expand_context": true, - "finish_after_relevant_count": 2, - "max_prompt_file_tokens": 8000 - }, + "name": "DecideRelevance", + "timestamp": "2024-08-05T17:51:38.745281", "snapshot": { "repository": { - "branch": null, "commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57" }, "file_context": { @@ -291,6 +310,9 @@ "spans": [ { "span_id": "FileBasedCache.has_key" + }, + { + "span_id": "FileBasedCache._is_expired" } ], "show_all_spans": false, @@ -299,40 +321,41 @@ ] } }, + "previous_state_id": 2, + "properties": { + "model": "gpt-4o-mini-2024-07-18", + "temperature": 0.5, + "max_tokens": 2000, + "include_message_history": false, + "expand_context": false, + "finish_after_relevant_count": 2, + "max_prompt_file_tokens": 8000 + }, "actions": [ { - "action": { - "scratch_pad": "The issue is about a race condition in the `has_key` method of the `FileBasedCache` class. The provided file context includes the `has_key` method and its implementation. This method is indeed the one that needs to be modified to fix the race condition issue.", + "request": { + "scratch_pad": "The reported issue indicates that the `has_key` method in the `FileBasedCache` class is susceptible to race conditions due to the possibility of a file being deleted between the `os.path.exists(fname)` check and the subsequent `open(fname, 'rb')` call. This can lead to a `FileNotFoundError` when multiple threads attempt to access expired cache entries simultaneously.\n\nThe `has_key` method is directly relevant as it is the source of the issue. The `_is_expired` method is also relevant because it can delete the file if it finds it to be expired, which exacerbates the race condition problem. \n\nBoth methods are in the same file, `django/core/cache/backends/filebased.py`, and are critical to understanding and addressing the race condition issue. \n\nThe identified spans are:\n1. `FileBasedCache.has_key` - This span contains the implementation of the `has_key` method, which is at the center of the issue.\n2. `FileBasedCache._is_expired` - This span contains the implementation of the `_is_expired` method, which is invoked within `has_key` and can lead to file deletion, thus contributing to the race condition.\n\nSince both spans are relevant and provide necessary context for addressing the issue, the search task can be considered complete.", "relevant": true, "complete": true }, - "output": {}, - "completion_cost": 0.00387, - "input_tokens": 945, - "output_tokens": 69 + "response": { + "trigger": "finish", + "output": {} + }, + "usage": { + "completion_cost": 0.00035894999999999997, + "completion_tokens": 289, + "prompt_tokens": 1237 + } } ] }, { "id": 4, - "timestamp": "2024-07-31T10:35:13.543776", - "parent_id": 3, - "state": { - "name": "PlanToCode", - "model": "no_llm", - "temperature": 0.2, - "max_tokens": 2000, - "include_message_history": true, - "max_iterations": 8, - "max_prompt_file_tokens": 8000, - "max_tokens_in_edit_prompt": 500, - "expand_context_with_related_spans": true, - "allow_hallucinated_spans": false, - "finish_on_review": false - }, + "name": "PlanToCode", + "timestamp": "2024-08-05T17:51:45.173389", "snapshot": { "repository": { - "branch": null, "commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57" }, "file_context": { @@ -343,10 +366,7 @@ "span_id": "FileBasedCache.has_key" }, { - "span_id": "FileBasedCache" - }, - { - "span_id": "FileBasedCache.__init__" + "span_id": "FileBasedCache._is_expired" } ], "show_all_spans": false, @@ -355,53 +375,51 @@ ] } }, + "previous_state_id": 3, + "properties": { + "model": "gpt-4o-mini-2024-07-18", + "temperature": 0.5, + "max_tokens": 2000, + "include_message_history": true, + "max_prompt_file_tokens": 8000, + "max_tokens_in_edit_prompt": 500, + "expand_context_with_related_spans": true, + "allow_hallucinated_spans": false, + "finish_on_review": false + }, "actions": [ { - "action": { - "scratch_pad": "To fix the race condition in the has_key method, we need to handle the case where the file might be deleted between the os.path.exists check and the open call. We can do this by wrapping the open call in a try-except block to catch the FileNotFoundError exception and return False if the file is not found.", + "request": { + "scratch_pad": "To prevent race conditions in the `has_key` method, I will modify the code to use a try-except block while attempting to open the file after checking its existence. This will ensure that if the file is deleted between the existence check and the open call, it will handle the exception gracefully instead of raising a `FileNotFoundError`.", "action": "modify", - "instructions": "Wrap the open call in the has_key method in a try-except block to catch FileNotFoundError and return False if the file is not found.", + "instructions": "Update the `has_key` method to handle potential race conditions when checking for file existence and opening the file.", "file_path": "django/core/cache/backends/filebased.py", "span_id": "FileBasedCache.has_key" }, - "output": { - "instructions": "Wrap the open call in the has_key method in a try-except block to catch FileNotFoundError and return False if the file is not found.", - "file_path": "django/core/cache/backends/filebased.py", - "span_id": "FileBasedCache.has_key", - "start_line": 91, - "end_line": 96 + "response": { + "trigger": "edit_code", + "output": { + "instructions": "Update the `has_key` method to handle potential race conditions when checking for file existence and opening the file.", + "file_path": "django/core/cache/backends/filebased.py", + "span_id": "FileBasedCache.has_key", + "start_line": 91, + "end_line": 96 + } }, - "completion_cost": 0.005079, - "input_tokens": 1048, - "output_tokens": 129 + "usage": { + "completion_cost": 0.0002433, + "completion_tokens": 123, + "prompt_tokens": 1130 + } } ] }, { "id": 5, - "timestamp": "2024-07-31T10:35:16.400579", - "parent_id": 4, - "state": { - "name": "EditCode", - "model": "no_llm", - "temperature": 0.2, - "max_tokens": 2000, - "include_message_history": true, - "max_iterations": 8, - "instructions": "Wrap the open call in the has_key method in a try-except block to catch FileNotFoundError and return False if the file is not found.", - "file_path": "django/core/cache/backends/filebased.py", - "span_id": "FileBasedCache.has_key", - "start_line": 91, - "end_line": 96, - "show_initial_message": true, - "show_file_context": true, - "verify": true, - "chain_of_thought": false, - "max_prompt_file_tokens": 8000 - }, + "name": "EditCode", + "timestamp": "2024-08-05T17:51:47.828433", "snapshot": { "repository": { - "branch": null, "commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57" }, "file_context": { @@ -412,10 +430,13 @@ "span_id": "FileBasedCache.has_key" }, { - "span_id": "FileBasedCache" + "span_id": "FileBasedCache._is_expired" }, { "span_id": "FileBasedCache.__init__" + }, + { + "span_id": "FileBasedCache" } ], "show_all_spans": false, @@ -424,46 +445,51 @@ ] } }, + "previous_state_id": 4, + "properties": { + "model": "gpt-4o-mini-2024-07-18", + "temperature": 0.5, + "max_tokens": 2000, + "include_message_history": false, + "instructions": "Update the `has_key` method to handle potential race conditions when checking for file existence and opening the file.", + "file_path": "django/core/cache/backends/filebased.py", + "span_id": "FileBasedCache.has_key", + "start_line": 91, + "end_line": 96, + "show_initial_message": true, + "show_file_context": true, + "verify": true, + "chain_of_thought": false, + "max_prompt_file_tokens": 8000 + }, "actions": [ { - "action": { - "content": "\n def has_key(self, key, version=None):\n fname = self._key_to_file(key, version)\n if os.path.exists(fname):\n try:\n with open(fname, \"rb\") as f:\n return not self._is_expired(f)\n except FileNotFoundError:\n return False\n return False\n" + "request": { + "content": "\n def has_key(self, key, version=None):\n fname = self._key_to_file(key, version)\n try:\n with open(fname, \"rb\") as f:\n return not self._is_expired(f)\n except FileNotFoundError:\n return False\n" }, - "output": { - "message": "Applied the change to django/core/cache/backends/filebased.py.", - "diff": "--- django/core/cache/backends/filebased.py\n+++ django/core/cache/backends/filebased.py\n@@ -91,8 +91,11 @@\n def has_key(self, key, version=None):\n fname = self._key_to_file(key, version)\n if os.path.exists(fname):\n- with open(fname, \"rb\") as f:\n- return not self._is_expired(f)\n+ try:\n+ with open(fname, \"rb\") as f:\n+ return not self._is_expired(f)\n+ except FileNotFoundError:\n+ return False\n return False\n \n def _cull(self):\n", - "verification_errors": [] + "response": { + "trigger": "finish", + "output": { + "message": "Applied the change to django/core/cache/backends/filebased.py.", + "diff": "--- django/core/cache/backends/filebased.py\n+++ django/core/cache/backends/filebased.py\n@@ -90,10 +90,11 @@\n \n def has_key(self, key, version=None):\n fname = self._key_to_file(key, version)\n- if os.path.exists(fname):\n+ try:\n with open(fname, \"rb\") as f:\n return not self._is_expired(f)\n- return False\n+ except FileNotFoundError:\n+ return False\n \n def _cull(self):\n \"\"\"\n", + "verification_errors": [] + } }, - "completion_cost": 0.0041730000000000005, - "input_tokens": 1031, - "output_tokens": 72 + "usage": { + "completion_cost": 0.00020625, + "completion_tokens": 61, + "prompt_tokens": 1131 + } } ] }, { "id": 6, - "timestamp": "2024-07-31T10:35:19.745656", - "parent_id": 5, - "state": { - "name": "PlanToCode", - "model": "no_llm", - "temperature": 0.2, - "max_tokens": 2000, - "include_message_history": true, - "max_iterations": 8, - "message": "Applied the change to django/core/cache/backends/filebased.py.", - "diff": "--- django/core/cache/backends/filebased.py\n+++ django/core/cache/backends/filebased.py\n@@ -91,8 +91,11 @@\n def has_key(self, key, version=None):\n fname = self._key_to_file(key, version)\n if os.path.exists(fname):\n- with open(fname, \"rb\") as f:\n- return not self._is_expired(f)\n+ try:\n+ with open(fname, \"rb\") as f:\n+ return not self._is_expired(f)\n+ except FileNotFoundError:\n+ return False\n return False\n \n def _cull(self):\n", - "verification_errors": [], - "max_prompt_file_tokens": 8000, - "max_tokens_in_edit_prompt": 500, - "expand_context_with_related_spans": true, - "allow_hallucinated_spans": false, - "finish_on_review": false - }, + "name": "PlanToCode", + "timestamp": "2024-08-05T17:51:53.177732", "snapshot": { "repository": { - "branch": null, - "commit": "1d0fa848e084cad62d0bb6bde3b51e4862558e57" + "commit": "8c1c528fea1644fc1cf597288d047cd110b3febd" }, "file_context": { "files": [ @@ -473,10 +499,13 @@ "span_id": "FileBasedCache.has_key" }, { - "span_id": "FileBasedCache" + "span_id": "FileBasedCache._is_expired" }, { "span_id": "FileBasedCache.__init__" + }, + { + "span_id": "FileBasedCache" } ], "show_all_spans": false, @@ -485,40 +514,49 @@ ] } }, + "previous_state_id": 5, + "properties": { + "model": "gpt-4o-mini-2024-07-18", + "temperature": 0.5, + "max_tokens": 2000, + "include_message_history": true, + "message": "Applied the change to django/core/cache/backends/filebased.py.", + "diff": "--- django/core/cache/backends/filebased.py\n+++ django/core/cache/backends/filebased.py\n@@ -90,10 +90,11 @@\n \n def has_key(self, key, version=None):\n fname = self._key_to_file(key, version)\n- if os.path.exists(fname):\n+ try:\n with open(fname, \"rb\") as f:\n return not self._is_expired(f)\n- return False\n+ except FileNotFoundError:\n+ return False\n \n def _cull(self):\n \"\"\"\n", + "verification_errors": [], + "max_prompt_file_tokens": 8000, + "max_tokens_in_edit_prompt": 500, + "expand_context_with_related_spans": true, + "allow_hallucinated_spans": false, + "finish_on_review": false + }, "actions": [ { - "action": { - "scratch_pad": "The race condition in the has_key method has been fixed by wrapping the open call in a try-except block to catch FileNotFoundError and return False if the file is not found. This ensures that if the file is deleted between the os.path.exists check and the open call, the method will handle it gracefully.", + "request": { + "scratch_pad": "The change has been applied successfully to the `has_key` method. The method now handles the potential `FileNotFoundError` exception when trying to open the file, thus preventing race conditions. I will now finish the task as the issue is resolved.", "action": "finish", - "finish": "The race condition in the has_key method has been fixed successfully." + "finish": "The `has_key` method has been updated to handle race conditions by catching the `FileNotFoundError` when attempting to open a file that may have been deleted after the existence check." }, - "output": { - "message": "The race condition in the has_key method has been fixed successfully." + "response": { + "trigger": "finish", + "output": { + "message": "The `has_key` method has been updated to handle race conditions by catching the `FileNotFoundError` when attempting to open a file that may have been deleted after the existence check." + } }, - "completion_cost": 0.005406, - "input_tokens": 1362, - "output_tokens": 88 + "usage": { + "completion_cost": 0.00027105, + "completion_tokens": 100, + "prompt_tokens": 1407 + } } ] }, { "id": 7, - "timestamp": "2024-07-31T10:35:24.038496", - "parent_id": 6, - "state": { - "name": "Finished", - "message": "The race condition in the has_key method has been fixed successfully.", - "output": { - "model": "no_llm", - "temperature": 0.2, - "max_tokens": 2000, - "max_prompt_file_tokens": 8000 - } - }, + "name": "Finished", + "timestamp": "2024-08-05T17:51:55.454796", "snapshot": { "repository": { - "branch": null, - "commit": "b6ee6f08d54caff6fe24cf4f0bd9420db61faf92" + "commit": "8c1c528fea1644fc1cf597288d047cd110b3febd" }, "file_context": { "files": [ @@ -528,10 +566,13 @@ "span_id": "FileBasedCache.has_key" }, { - "span_id": "FileBasedCache" + "span_id": "FileBasedCache._is_expired" }, { "span_id": "FileBasedCache.__init__" + }, + { + "span_id": "FileBasedCache" } ], "show_all_spans": false, @@ -540,7 +581,14 @@ ] } }, - "actions": [] + "previous_state_id": 6, + "properties": { + "model": "gpt-4o-mini-2024-07-18", + "temperature": 0.5, + "max_tokens": 2000, + "include_message_history": false, + "message": "The `has_key` method has been updated to handle race conditions by catching the `FileNotFoundError` when attempting to open a file that may have been deleted after the existence check." + } } ], "info": {}