From 3fcb80085490c380f035d678f91ac0de996d1137 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Albert=20=C3=96rwall?= Date: Tue, 6 Aug 2024 08:03:53 +0200 Subject: [PATCH] Fix tests --- moatless/find/search.py | 10 +++++--- moatless/repository/git.py | 4 ++- moatless/trajectory.py | 45 ++++++++++++++++++++-------------- tests/find/test_search.py | 20 +++++++-------- tests/test_loop.py | 10 ++++---- tests/test_state.py | 10 +++----- tests/test_transition_rules.py | 30 +++++++++++------------ 7 files changed, 67 insertions(+), 62 deletions(-) diff --git a/moatless/find/search.py b/moatless/find/search.py index f759ede5..49b019db 100644 --- a/moatless/find/search.py +++ b/moatless/find/search.py @@ -265,7 +265,11 @@ def has_search_attributes(self): ] ) - + @model_validator(mode='after') + def validate_search_requests(self): + if not self.has_search_attributes: + raise ValueError("A search request must have at least one attribute set.") + return self class Search(ActionRequest): """Take action to search for code, identify found and finish up.""" @@ -287,9 +291,7 @@ class Search(ActionRequest): def validate_search_requests(self): if not self.complete: if not self.search_requests: - raise ValidationError("If 'complete' is False, at least one search request must exist.") - if not any(request.has_search_attributes() for request in self.search_requests): - raise ValidationError("At least one search request must have at least one attribute set.") + raise ValueError("At least one search request must exist.") return self diff --git a/moatless/repository/git.py b/moatless/repository/git.py index 7f423ac9..96332cf0 100644 --- a/moatless/repository/git.py +++ b/moatless/repository/git.py @@ -70,7 +70,7 @@ def dict(self): "type": "git", "repo_path": self._repo_path, "git_repo_url": self._repo_url, - "commit": self._current_commit, + "commit": self._initial_commit, } def snapshot(self) -> dict: @@ -96,6 +96,8 @@ def commit(self, file_path: str | None = None): self._repo.index.commit(commit_message) self._current_commit = self._repo.head.commit.hexsha + logger.info(f"Committed changes to git with message '{commit_message}' and commit hash '{self._current_commit}'") + def commit_message(self, file_path: str | None = None) -> str: if file_path: diff = self._repo.git.diff("HEAD", file_path) diff --git a/moatless/trajectory.py b/moatless/trajectory.py index 005b2ad0..bdebe4fe 100644 --- a/moatless/trajectory.py +++ b/moatless/trajectory.py @@ -52,15 +52,20 @@ class Trajectory: def __init__( self, name: str, + workspace: Workspace, initial_message: Optional[str] = None, persist_path: Optional[str] = None, - workspace: Optional[Workspace] = None, transition_rules: Optional[TransitionRules] = None, ): self._name = name self._persist_path = persist_path self._initial_message = initial_message self._workspace = workspace + + # Workaround to set to keep the current initial workspace state when loading an existing trajectory. + # TODO: Remove this when we have a better way to handle this. + self._initial_workspace_state = self._workspace.dict() + self._transition_rules = transition_rules self._current_transition_id = 0 @@ -78,12 +83,14 @@ def load(cls, file_path: str): else: transition_rules = None + workspace = Workspace.from_dict(data["workspace"]) trajectory = cls( name=data["name"], initial_message=data["initial_message"], transition_rules=transition_rules, + workspace=workspace ) - trajectory._workspace = Workspace.from_dict(data["workspace"]) + trajectory._transitions = {} trajectory._current_transition_id = data.get("current_transition_id", 0) @@ -93,13 +100,6 @@ def load(cls, file_path: str): state_data["id"] = t["id"] state = state_class.model_validate(state_data) - if t.get("snapshot"): - try: - trajectory.restore_from_snapshot(t["snapshot"]) - except Exception as e: - logger.exception(f"Error restoring from snapshot for state {state.name}") - raise e - state._workspace = trajectory._workspace state._initial_message = trajectory._initial_message state._actions = [] @@ -133,15 +133,19 @@ def load(cls, file_path: str): for t in data["transitions"]: try: current_state = trajectory._transitions[t["id"]].state - if t["previous_state_id"] is not None: + if t.get("previous_state_id") is not None: current_state.previous_state = trajectory._transitions.get(t["previous_state_id"]).state except KeyError as e: - logger.error(f"Missing key {e}, existing keys: {trajectory._transitions.keys()}") + logger.exception(f"Missing key {e}, existing keys: {trajectory._transitions.keys()}") + raise trajectory._info = data.get("info", {}) logger.info(f"Loaded trajectory {trajectory._name} with {len(trajectory._transitions)} transitions") + current_state = trajectory._transitions.get(trajectory._current_transition_id) + trajectory.restore_from_snapshot(current_state) + return trajectory @property @@ -171,15 +175,18 @@ def set_current_state(self, state: AgenticState): def get_current_state(self) -> AgenticState: return self._transitions.get(self._current_transition_id).state - def restore_from_snapshot(self, snapshot: dict): - if not self._workspace: - logger.info("No workspace to restore from snapshot") + def restore_from_snapshot(self, state: TrajectoryState): + if not state.snapshot: + logger.info(f"restore_from_snapshot(state: {state.id}:{state.name}) No snapshot found") + return + + logger.info(f"restore_from_snapshot(starte: {state.id}:{state.name}) Restoring from snapshot") - if "repository" in snapshot: - self._workspace.file_repo.restore_from_snapshot(snapshot["repository"]) + if state.snapshot.get("repository"): + self._workspace.file_repo.restore_from_snapshot(state.snapshot["repository"]) - if "file_context" in snapshot: - self._workspace.file_context.restore_from_snapshot(snapshot["file_context"]) + if state.snapshot.get("file_context"): + self._workspace.file_context.restore_from_snapshot(state.snapshot["file_context"]) def save_state(self, state: AgenticState): if state.id in self._transitions: @@ -226,7 +233,7 @@ def to_dict(self): ) if self._transition_rules else None, - "workspace": self._workspace.dict() if self._workspace else None, + "workspace": self._initial_workspace_state, "initial_message": self._initial_message, "current_transition_id": self._current_transition_id, "transitions": [t.model_dump(exclude_none=True) for t in self.transitions], diff --git a/tests/find/test_search.py b/tests/find/test_search.py index 2ced48b9..e2ff1d18 100644 --- a/tests/find/test_search.py +++ b/tests/find/test_search.py @@ -3,6 +3,7 @@ from moatless.types import ActionResponse from moatless.workspace import Workspace from unittest.mock import Mock, MagicMock +from pydantic import ValidationError class TestSearchCode: @pytest.fixture @@ -30,17 +31,14 @@ def test_execute_action_complete(self, search_code): assert response.trigger == "finish" assert response.output["message"] == "Search complete" - def test_execute_action_without_search_attributes(self, search_code): - action = Search( - scratch_pad="Invalid search", - search_requests=[SearchRequest()] - ) - - response = search_code._execute_action(action) - - assert isinstance(response, ActionResponse) - assert response.trigger == "retry" - assert "You must provide at least one the search attributes" in response.retry_message + def test_validate_search_without_search_attributes(self): + with pytest.raises(ValidationError) as excinfo: + Search( + scratch_pad="Invalid search", + search_requests=[] + ) + + assert "At least one search request must exist." in str(excinfo.value) def test_execute_action_with_search_results(self, search_code): mock_code_index = MagicMock() diff --git a/tests/test_loop.py b/tests/test_loop.py index 20c362a6..0db0d9e7 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -74,9 +74,9 @@ def test_loop_run_until_finished(mock_workspace, test_transition_rules): response = loop.run("initial message") assert response.status == "finished" - assert len(loop._state_history) == 3, f"Expected 3 states, got {[state.name for state in loop._state_history.values()]}" - assert loop._state_history[1].initial_message == "initial message" - assert isinstance(loop._state_history[2], Finished) + assert loop.state_count() == 3, f"Expected 3 states, got {[state.state.name for state in loop._trajectory.transitions()]}" + assert loop._initial_message == "initial message" + assert isinstance(loop._trajectory.transitions[2].state, Finished) def test_loop_run_until_rejected(mock_workspace, test_transition_rules): loop = AgenticLoop(test_transition_rules, mock_workspace) @@ -88,7 +88,7 @@ def mock_next_action() : response = loop.run("initial message") assert response.status == "rejected" - assert len(loop._state_history) == 3 # Pending -> TestState -> Rejected + assert loop.state_count() == 3 # Pending -> TestState -> Rejected def test_loop_max_transitions(mock_workspace, test_transition_rules): loop = AgenticLoop(test_transition_rules, mock_workspace, max_transitions=3) @@ -98,7 +98,7 @@ def test_loop_max_transitions(mock_workspace, test_transition_rules): assert response.status == "rejected" assert response.message == "Max transitions exceeded." - assert len(loop._state_history) == 4, f"Expected 4 states, got {[state.name for state in loop._state_history.values()]}" + assert loop.state_count() == 4, f"Expected 4 states, got {[t.state.name for t in loop._trajectory.transitions]}" @pytest.mark.api_keys_required def test_rerun_save_and_load_trajectory(): diff --git a/tests/test_state.py b/tests/test_state.py index 397cc965..a32c0021 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -41,9 +41,9 @@ def test_agentic_state_create_file_context(test_state): def test_agentic_state_model_dump(test_state): - dump = test_state.model_dump() - assert "name" in dump - assert dump["name"] == "ConcreteAgenticState" + dump = test_state.model_dump(exclude_none=True) + assert dump == {'id': 1, 'include_message_history': False, 'max_tokens': 1000, 'temperature': 0.0} + def test_agentic_state_equality_same_state(): state1 = ConcreteAgenticState(id=1, temperature=0.5, max_tokens=500) @@ -160,6 +160,4 @@ def test_finished_state_creation_and_dump(): assert dumped_state["id"] == 1 assert dumped_state["message"] == message - assert dumped_state["output"] == output - assert dumped_state["name"] == "Finished" - assert dumped_state["previous_state_id"] is None \ No newline at end of file + assert dumped_state["output"] == output \ No newline at end of file diff --git a/tests/test_transition_rules.py b/tests/test_transition_rules.py index 8c32f011..b19e7c67 100644 --- a/tests/test_transition_rules.py +++ b/tests/test_transition_rules.py @@ -90,12 +90,19 @@ def test_transition_rules_serialization_deserialization(): # Check if the internal _source_trigger_index is rebuilt correctly assert deserialized_rules._source_trigger_index == rules._source_trigger_index - json_data = json.dumps( - rules.model_dump(exclude_none=True, exclude_unset=True), indent=2 - ) + data = rules.model_dump(exclude_none=True, exclude_unset=True) + assert ( - json_data - == """{ + data + == { + "global_params": { + "model": "gpt-4o" + }, + "state_params": { + "MockStateB": { + "model": "claude-3.5-sonnet" + } + }, "initial_state": "MockStateA", "transition_rules": [ { @@ -116,17 +123,8 @@ def test_transition_rules_serialization_deserialization(): "source": "MockStateB", "dest": "Rejected" } - ], - "global_params": { - "model": "gpt-4o" - }, - "state_params": { - "MockStateB": { - "model": "claude-3.5-sonnet" - } - } -}""" - ) + ] +}) def test_find_transition_rule():