Skip to content

Commit

Permalink
Set workspace from snapshot only when running the loop
Browse files Browse the repository at this point in the history
  • Loading branch information
aorwall committed Aug 6, 2024
1 parent cf85f9d commit 52773ed
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
3 changes: 2 additions & 1 deletion moatless/benchmark/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def run_single_instance(
split="test",
) -> dict:
instance = load_instance(instance_id, dataset, split)
return self._evaluate_instance(instance)
trajectory = self._evaluate_instance(instance)
return to_result(instance, trajectory, self.report)

def _evaluate_instance(self, instance: dict, retry: bool = False) -> Trajectory:
instance_id = instance["instance_id"]
Expand Down
14 changes: 3 additions & 11 deletions moatless/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ def run(self, message: Optional[str] = None) -> Response:
self._initial_message = message
self._trajectory._initial_message = message

if not isinstance(self._current_state, Pending):
self._trajectory.update_workspace_to_current_state()

while not self.is_finished():
self._execute_state_until_transition()

Expand Down Expand Up @@ -353,17 +356,6 @@ def _set_current_state(self, state: AgenticState):
self._current_state = state
self._trajectory.set_current_state(state)

def revert_to_state(self, state_id: int) -> AgenticState:
state = self._trajectory.get_state(state_id)
if state:
self.log_info(f"Reverting to state {state_id}")
self._set_current_state(state.state)
self.workspace.restore_from_snapshot(state.snapshot)
return state.state
else:
logger.warning(f"Tried to revert to state {state_id} but it does not exist.")
raise ValueError(f"Could not revert to state {state_id} as it does not exist.")

def transition_to(self, new_state: AgenticState) -> AgenticState:
self.log_info(f"Transitioning from {self.state.name} to {new_state.name}")

Expand Down
8 changes: 4 additions & 4 deletions moatless/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ def load(cls, file_path: str):

logger.info(f"Loaded trajectory {trajectory._name} with {len(trajectory._transitions)} transitions")

current_state = trajectory._transitions.get(trajectory._current_transition_id)
trajectory.restore_from_snapshot(current_state)

return trajectory

@property
Expand All @@ -167,7 +164,7 @@ def transition_rules(self) -> TransitionRules:
return self._transition_rules

@property
def workspace(self) -> dict[str, Any] | None:
def workspace(self) -> Workspace:
return self._workspace

@property
Expand All @@ -181,6 +178,9 @@ def set_current_state(self, state: AgenticState):
def get_current_state(self) -> AgenticState:
return self._transitions.get(self._current_transition_id).state

def update_workspace_to_current_state(self):
self.restore_from_snapshot(self._transitions[self._current_transition_id])

def restore_from_snapshot(self, state: TrajectoryState):
if not state.snapshot:
logger.info(f"restore_from_snapshot(state: {state.id}:{state.name}) No snapshot found")
Expand Down
9 changes: 8 additions & 1 deletion tests/benchmark/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,11 @@ def test_run_single_evaluation_mcts():
detailed_report=True,
)

evaluation.run_single_instance("django__django-16379")
result = evaluation.run_single_instance("django__django-16379")

assert result["instance_id"] == "django__django-16379"
assert result["status"] == "edited"
assert result["edited"]
assert result["identified"]
assert result["found_in_search"]
assert result["file_identified"]

0 comments on commit 52773ed

Please sign in to comment.