diff --git a/moatless/benchmark/evaluation.py b/moatless/benchmark/evaluation.py index c13ba71d..1d0b9980 100644 --- a/moatless/benchmark/evaluation.py +++ b/moatless/benchmark/evaluation.py @@ -198,7 +198,8 @@ def run_single_instance( split="test", ) -> dict: instance = load_instance(instance_id, dataset, split) - return self._evaluate_instance(instance) + trajectory = self._evaluate_instance(instance) + return to_result(instance, trajectory, self.report) def _evaluate_instance(self, instance: dict, retry: bool = False) -> Trajectory: instance_id = instance["instance_id"] diff --git a/moatless/loop.py b/moatless/loop.py index ccc0d459..5d72753c 100644 --- a/moatless/loop.py +++ b/moatless/loop.py @@ -190,6 +190,9 @@ def run(self, message: Optional[str] = None) -> Response: self._initial_message = message self._trajectory._initial_message = message + if not isinstance(self._current_state, Pending): + self._trajectory.update_workspace_to_current_state() + while not self.is_finished(): self._execute_state_until_transition() @@ -353,17 +356,6 @@ def _set_current_state(self, state: AgenticState): self._current_state = state self._trajectory.set_current_state(state) - def revert_to_state(self, state_id: int) -> AgenticState: - state = self._trajectory.get_state(state_id) - if state: - self.log_info(f"Reverting to state {state_id}") - self._set_current_state(state.state) - self.workspace.restore_from_snapshot(state.snapshot) - return state.state - else: - logger.warning(f"Tried to revert to state {state_id} but it does not exist.") - raise ValueError(f"Could not revert to state {state_id} as it does not exist.") - def transition_to(self, new_state: AgenticState) -> AgenticState: self.log_info(f"Transitioning from {self.state.name} to {new_state.name}") diff --git a/moatless/trajectory.py b/moatless/trajectory.py index 8b5ac962..59ab437b 100644 --- a/moatless/trajectory.py +++ b/moatless/trajectory.py @@ -145,9 +145,6 @@ def load(cls, file_path: str): logger.info(f"Loaded trajectory {trajectory._name} with {len(trajectory._transitions)} transitions") - current_state = trajectory._transitions.get(trajectory._current_transition_id) - trajectory.restore_from_snapshot(current_state) - return trajectory @property @@ -167,7 +164,7 @@ def transition_rules(self) -> TransitionRules: return self._transition_rules @property - def workspace(self) -> dict[str, Any] | None: + def workspace(self) -> Workspace: return self._workspace @property @@ -181,6 +178,9 @@ def set_current_state(self, state: AgenticState): def get_current_state(self) -> AgenticState: return self._transitions.get(self._current_transition_id).state + def update_workspace_to_current_state(self): + self.restore_from_snapshot(self._transitions[self._current_transition_id]) + def restore_from_snapshot(self, state: TrajectoryState): if not state.snapshot: logger.info(f"restore_from_snapshot(state: {state.id}:{state.name}) No snapshot found") diff --git a/tests/benchmark/test_evaluation.py b/tests/benchmark/test_evaluation.py index 28dea29c..7c87d0f4 100644 --- a/tests/benchmark/test_evaluation.py +++ b/tests/benchmark/test_evaluation.py @@ -72,4 +72,11 @@ def test_run_single_evaluation_mcts(): detailed_report=True, ) - evaluation.run_single_instance("django__django-16379") \ No newline at end of file + result = evaluation.run_single_instance("django__django-16379") + + assert result["instance_id"] == "django__django-16379" + assert result["status"] == "edited" + assert result["edited"] + assert result["identified"] + assert result["found_in_search"] + assert result["file_identified"]