Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aorwall committed Aug 6, 2024
1 parent cc53c96 commit 3fcb800
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 62 deletions.
10 changes: 6 additions & 4 deletions moatless/find/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,11 @@ def has_search_attributes(self):
]
)


@model_validator(mode='after')
def validate_search_requests(self):
if not self.has_search_attributes:
raise ValueError("A search request must have at least one attribute set.")
return self

class Search(ActionRequest):
"""Take action to search for code, identify found and finish up."""
Expand All @@ -287,9 +291,7 @@ class Search(ActionRequest):
def validate_search_requests(self):
if not self.complete:
if not self.search_requests:
raise ValidationError("If 'complete' is False, at least one search request must exist.")
if not any(request.has_search_attributes() for request in self.search_requests):
raise ValidationError("At least one search request must have at least one attribute set.")
raise ValueError("At least one search request must exist.")
return self


Expand Down
4 changes: 3 additions & 1 deletion moatless/repository/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def dict(self):
"type": "git",
"repo_path": self._repo_path,
"git_repo_url": self._repo_url,
"commit": self._current_commit,
"commit": self._initial_commit,
}

def snapshot(self) -> dict:
Expand All @@ -96,6 +96,8 @@ def commit(self, file_path: str | None = None):
self._repo.index.commit(commit_message)
self._current_commit = self._repo.head.commit.hexsha

logger.info(f"Committed changes to git with message '{commit_message}' and commit hash '{self._current_commit}'")

def commit_message(self, file_path: str | None = None) -> str:
if file_path:
diff = self._repo.git.diff("HEAD", file_path)
Expand Down
45 changes: 26 additions & 19 deletions moatless/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,20 @@ class Trajectory:
def __init__(
self,
name: str,
workspace: Workspace,
initial_message: Optional[str] = None,
persist_path: Optional[str] = None,
workspace: Optional[Workspace] = None,
transition_rules: Optional[TransitionRules] = None,
):
self._name = name
self._persist_path = persist_path
self._initial_message = initial_message
self._workspace = workspace

# Workaround to set to keep the current initial workspace state when loading an existing trajectory.
# TODO: Remove this when we have a better way to handle this.
self._initial_workspace_state = self._workspace.dict()

self._transition_rules = transition_rules

self._current_transition_id = 0
Expand All @@ -78,12 +83,14 @@ def load(cls, file_path: str):
else:
transition_rules = None

workspace = Workspace.from_dict(data["workspace"])
trajectory = cls(
name=data["name"],
initial_message=data["initial_message"],
transition_rules=transition_rules,
workspace=workspace
)
trajectory._workspace = Workspace.from_dict(data["workspace"])

trajectory._transitions = {}
trajectory._current_transition_id = data.get("current_transition_id", 0)

Expand All @@ -93,13 +100,6 @@ def load(cls, file_path: str):
state_data["id"] = t["id"]
state = state_class.model_validate(state_data)

if t.get("snapshot"):
try:
trajectory.restore_from_snapshot(t["snapshot"])
except Exception as e:
logger.exception(f"Error restoring from snapshot for state {state.name}")
raise e

state._workspace = trajectory._workspace
state._initial_message = trajectory._initial_message
state._actions = []
Expand Down Expand Up @@ -133,15 +133,19 @@ def load(cls, file_path: str):
for t in data["transitions"]:
try:
current_state = trajectory._transitions[t["id"]].state
if t["previous_state_id"] is not None:
if t.get("previous_state_id") is not None:
current_state.previous_state = trajectory._transitions.get(t["previous_state_id"]).state
except KeyError as e:
logger.error(f"Missing key {e}, existing keys: {trajectory._transitions.keys()}")
logger.exception(f"Missing key {e}, existing keys: {trajectory._transitions.keys()}")
raise

trajectory._info = data.get("info", {})

logger.info(f"Loaded trajectory {trajectory._name} with {len(trajectory._transitions)} transitions")

current_state = trajectory._transitions.get(trajectory._current_transition_id)
trajectory.restore_from_snapshot(current_state)

return trajectory

@property
Expand Down Expand Up @@ -171,15 +175,18 @@ def set_current_state(self, state: AgenticState):
def get_current_state(self) -> AgenticState:
return self._transitions.get(self._current_transition_id).state

def restore_from_snapshot(self, snapshot: dict):
if not self._workspace:
logger.info("No workspace to restore from snapshot")
def restore_from_snapshot(self, state: TrajectoryState):
if not state.snapshot:
logger.info(f"restore_from_snapshot(state: {state.id}:{state.name}) No snapshot found")
return

logger.info(f"restore_from_snapshot(starte: {state.id}:{state.name}) Restoring from snapshot")

if "repository" in snapshot:
self._workspace.file_repo.restore_from_snapshot(snapshot["repository"])
if state.snapshot.get("repository"):
self._workspace.file_repo.restore_from_snapshot(state.snapshot["repository"])

if "file_context" in snapshot:
self._workspace.file_context.restore_from_snapshot(snapshot["file_context"])
if state.snapshot.get("file_context"):
self._workspace.file_context.restore_from_snapshot(state.snapshot["file_context"])

def save_state(self, state: AgenticState):
if state.id in self._transitions:
Expand Down Expand Up @@ -226,7 +233,7 @@ def to_dict(self):
)
if self._transition_rules
else None,
"workspace": self._workspace.dict() if self._workspace else None,
"workspace": self._initial_workspace_state,
"initial_message": self._initial_message,
"current_transition_id": self._current_transition_id,
"transitions": [t.model_dump(exclude_none=True) for t in self.transitions],
Expand Down
20 changes: 9 additions & 11 deletions tests/find/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from moatless.types import ActionResponse
from moatless.workspace import Workspace
from unittest.mock import Mock, MagicMock
from pydantic import ValidationError

class TestSearchCode:
@pytest.fixture
Expand Down Expand Up @@ -30,17 +31,14 @@ def test_execute_action_complete(self, search_code):
assert response.trigger == "finish"
assert response.output["message"] == "Search complete"

def test_execute_action_without_search_attributes(self, search_code):
action = Search(
scratch_pad="Invalid search",
search_requests=[SearchRequest()]
)

response = search_code._execute_action(action)

assert isinstance(response, ActionResponse)
assert response.trigger == "retry"
assert "You must provide at least one the search attributes" in response.retry_message
def test_validate_search_without_search_attributes(self):
with pytest.raises(ValidationError) as excinfo:
Search(
scratch_pad="Invalid search",
search_requests=[]
)

assert "At least one search request must exist." in str(excinfo.value)

def test_execute_action_with_search_results(self, search_code):
mock_code_index = MagicMock()
Expand Down
10 changes: 5 additions & 5 deletions tests/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def test_loop_run_until_finished(mock_workspace, test_transition_rules):
response = loop.run("initial message")

assert response.status == "finished"
assert len(loop._state_history) == 3, f"Expected 3 states, got {[state.name for state in loop._state_history.values()]}"
assert loop._state_history[1].initial_message == "initial message"
assert isinstance(loop._state_history[2], Finished)
assert loop.state_count() == 3, f"Expected 3 states, got {[state.state.name for state in loop._trajectory.transitions()]}"
assert loop._initial_message == "initial message"
assert isinstance(loop._trajectory.transitions[2].state, Finished)

def test_loop_run_until_rejected(mock_workspace, test_transition_rules):
loop = AgenticLoop(test_transition_rules, mock_workspace)
Expand All @@ -88,7 +88,7 @@ def mock_next_action() :
response = loop.run("initial message")

assert response.status == "rejected"
assert len(loop._state_history) == 3 # Pending -> TestState -> Rejected
assert loop.state_count() == 3 # Pending -> TestState -> Rejected

def test_loop_max_transitions(mock_workspace, test_transition_rules):
loop = AgenticLoop(test_transition_rules, mock_workspace, max_transitions=3)
Expand All @@ -98,7 +98,7 @@ def test_loop_max_transitions(mock_workspace, test_transition_rules):

assert response.status == "rejected"
assert response.message == "Max transitions exceeded."
assert len(loop._state_history) == 4, f"Expected 4 states, got {[state.name for state in loop._state_history.values()]}"
assert loop.state_count() == 4, f"Expected 4 states, got {[t.state.name for t in loop._trajectory.transitions]}"

@pytest.mark.api_keys_required
def test_rerun_save_and_load_trajectory():
Expand Down
10 changes: 4 additions & 6 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def test_agentic_state_create_file_context(test_state):


def test_agentic_state_model_dump(test_state):
dump = test_state.model_dump()
assert "name" in dump
assert dump["name"] == "ConcreteAgenticState"
dump = test_state.model_dump(exclude_none=True)
assert dump == {'id': 1, 'include_message_history': False, 'max_tokens': 1000, 'temperature': 0.0}


def test_agentic_state_equality_same_state():
state1 = ConcreteAgenticState(id=1, temperature=0.5, max_tokens=500)
Expand Down Expand Up @@ -160,6 +160,4 @@ def test_finished_state_creation_and_dump():

assert dumped_state["id"] == 1
assert dumped_state["message"] == message
assert dumped_state["output"] == output
assert dumped_state["name"] == "Finished"
assert dumped_state["previous_state_id"] is None
assert dumped_state["output"] == output
30 changes: 14 additions & 16 deletions tests/test_transition_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,19 @@ def test_transition_rules_serialization_deserialization():
# Check if the internal _source_trigger_index is rebuilt correctly
assert deserialized_rules._source_trigger_index == rules._source_trigger_index

json_data = json.dumps(
rules.model_dump(exclude_none=True, exclude_unset=True), indent=2
)
data = rules.model_dump(exclude_none=True, exclude_unset=True)

assert (
json_data
== """{
data
== {
"global_params": {
"model": "gpt-4o"
},
"state_params": {
"MockStateB": {
"model": "claude-3.5-sonnet"
}
},
"initial_state": "MockStateA",
"transition_rules": [
{
Expand All @@ -116,17 +123,8 @@ def test_transition_rules_serialization_deserialization():
"source": "MockStateB",
"dest": "Rejected"
}
],
"global_params": {
"model": "gpt-4o"
},
"state_params": {
"MockStateB": {
"model": "claude-3.5-sonnet"
}
}
}"""
)
]
})


def test_find_transition_rule():
Expand Down

0 comments on commit 3fcb800

Please sign in to comment.