Support saving and reading trajectories in a tree structure
aorwall committed Jul 30, 2024
1 parent 580b6f8 commit bd95dc8
Showing 32 changed files with 771 additions and 408 deletions.
2 changes: 1 addition & 1 deletion moatless/__init__.py
@@ -1,3 +1,3 @@
from moatless.loop import AgenticLoop, Transitions
from moatless.loop import AgenticLoop, TransitionRules
from moatless.repository import FileRepository
from moatless.workspace import Workspace
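
For downstream code the rename is a one-line change; a minimal sketch of the updated import, assuming the package re-export shown above (hypothetical usage, not part of this diff):

    from moatless import AgenticLoop, TransitionRules  # `Transitions` is now `TransitionRules`
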
22 changes: 11 additions & 11 deletions moatless/benchmark/claude_evaluation.py
@@ -11,7 +11,7 @@
from moatless.find.decide import DecideRelevance
from moatless.find.identify import IdentifyCode
from moatless.find.search_v2 import SearchCode
from moatless.loop import Transition
from moatless.loop import TransitionRule
from moatless.state import Finished, Rejected
from moatless.transitions import (
search_and_code_transitions,
@@ -177,8 +177,8 @@ def evaluate_search():
},
initial_state=SearchCode,
transitions=[
Transition(source=SearchCode, dest=Finished, trigger="did_search"),
Transition(source=SearchCode, dest=Finished, trigger="finish"),
TransitionRule(source=SearchCode, dest=Finished, trigger="did_search"),
TransitionRule(source=SearchCode, dest=Finished, trigger="finish"),
],
)

@@ -298,19 +298,19 @@ def evaluate_plan(previous_trajectory_dir: Optional[str] = None):
},
initial_state=SearchCode,
transitions=[
Transition(source=SearchCode, dest=IdentifyCode, trigger="did_search"),
Transition(source=IdentifyCode, dest=SearchCode, trigger="search"),
Transition(source=IdentifyCode, dest=DecideRelevance, trigger="finish"),
Transition(source=DecideRelevance, dest=SearchCode, trigger="search"),
Transition(
TransitionRule(source=SearchCode, dest=IdentifyCode, trigger="did_search"),
TransitionRule(source=IdentifyCode, dest=SearchCode, trigger="search"),
TransitionRule(source=IdentifyCode, dest=DecideRelevance, trigger="finish"),
TransitionRule(source=DecideRelevance, dest=SearchCode, trigger="search"),
TransitionRule(
source=DecideRelevance,
dest=PlanToCode,
trigger="finish",
exclude_fields={"message"},
),
Transition(source=PlanToCode, dest=Finished, trigger="edit_code"),
Transition(source=PlanToCode, dest=Rejected, trigger="finish"),
Transition(source=PlanToCode, dest=Rejected, trigger="reject"),
TransitionRule(source=PlanToCode, dest=Finished, trigger="edit_code"),
TransitionRule(source=PlanToCode, dest=Rejected, trigger="finish"),
TransitionRule(source=PlanToCode, dest=Rejected, trigger="reject"),
],
)

6 changes: 3 additions & 3 deletions moatless/benchmark/evaluation.py
@@ -205,7 +205,7 @@ def _evaluate_instance(self, instance: dict, retry: bool = False) -> dict:
repo_dir = setup_swebench_repo(instance)
persist_dir = os.path.join(self.index_store_dir, get_repo_dir_name(instance_id))
workspace = Workspace.from_dirs(
repo_dir=repo_dir, index_dir=persist_dir, max_file_context_tokens=16000
repo_path=repo_dir, index_dir=persist_dir, max_file_context_tokens=16000
)

problem_statement = instance["problem_statement"]
@@ -226,7 +226,7 @@ def _evaluate_instance(self, instance: dict, retry: bool = False) -> dict:
)

loop = AgenticLoop(
transitions=self.transitions,
transition_rules=self.transitions,
workspace=workspace,
metadata=metadata,
mocked_actions=previous_actions,
@@ -252,7 +252,7 @@ def _evaluate_instance(self, instance: dict, retry: bool = False) -> dict:
logging.exception(f"Error in evaluation of {instance['instance_id']} ")

info["duration"] = time.time() - start_time
info["total_cost"] = loop.trajectory.total_cost()
info["total_cost"] = loop.total_cost()

workspace.save()
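
Taken together with the constructor hunk above, the evaluation harness now hands its rules to the loop via transition_rules and reads the accumulated cost from the loop itself; a condensed sketch of the updated call sites, reusing the variables from _evaluate_instance (only the constructor argument and total_cost are confirmed by this diff):

    loop = AgenticLoop(
        transition_rules=self.transitions,  # renamed from `transitions=`
        workspace=workspace,
        metadata=metadata,
        mocked_actions=previous_actions,
    )
    # ... once the loop has finished running ...
    info["total_cost"] = loop.total_cost()  # cost is now read from the loop, not loop.trajectory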

7 changes: 3 additions & 4 deletions moatless/benchmark/swebench/utils.py
@@ -329,9 +329,8 @@ def create_workspace(
index_store_dir, get_repo_dir_name(instance["instance_id"])
)
return Workspace.from_dirs(
# TODO: Enable this to use GitRepository
# git_repo_url=repo_url,
# commit=instance["base_commit"],
repo_dir=repo_dir,
git_repo_url=repo_url,
commit=instance["base_commit"],
repo_path=repo_dir,
index_dir=persist_dir,
)
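
With the old TODO resolved, the repository URL and base commit are passed straight through to the workspace; a usage sketch reusing the variables from create_workspace (only the parameter names are confirmed by this diff):

    workspace = Workspace.from_dirs(
        git_repo_url=repo_url,            # remote repository for the instance
        commit=instance["base_commit"],   # pin the checkout to the evaluated base commit
        repo_path=repo_dir,               # `repo_dir` argument renamed to `repo_path`
        index_dir=persist_dir,
    )
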
3 changes: 2 additions & 1 deletion moatless/edit/edit.py
@@ -117,13 +117,14 @@ def __init__(
max_iterations: int = 8,
show_file_context: bool = True,
verify: bool = True,
include_message_history=True,
chain_of_thought: bool = False,
max_prompt_file_tokens: int = 4000,
**data,
):
assert "model" in data
super().__init__(
include_message_history=True,
include_message_history=include_message_history,
show_initial_message=show_initial_message,
max_iterations=max_iterations,
show_file_context=show_file_context,
7 changes: 4 additions & 3 deletions moatless/edit/plan.py
@@ -104,6 +104,7 @@ def __init__(
message: Optional[str] = None,
diff: Optional[str] = None,
lint_messages: list[VerificationError] | None = None,
include_message_history=True,
max_prompt_file_tokens: int = 4000,
max_tokens_in_edit_prompt: int = 500,
max_iterations: int = 8,
@@ -116,7 +117,7 @@ def __init__(
message=message,
diff=diff,
lint_messages=lint_messages,
include_message_history=True,
include_message_history=include_message_history,
max_prompt_file_tokens=max_prompt_file_tokens,
max_tokens_in_edit_prompt=max_tokens_in_edit_prompt,
max_iterations=max_iterations,
@@ -131,7 +132,7 @@ def init(self):

if (
self.expand_context_with_related_spans
and len(self.loop.trajectory.get_transitions(self.name)) == 0
and len(self.loop.get_transitions(self.name)) == 0
):
self.file_context.expand_context_with_related_spans(
max_tokens=self.max_prompt_file_tokens
@@ -314,7 +315,7 @@ def messages(self) -> list[Message]:
else:
content = ""

previous_transitions = self.loop.trajectory.get_transitions(str(self))
previous_transitions = self.loop.get_transitions(str(self))

for transition in previous_transitions:
new_message = transition.state.to_message()
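
The same shift from loop.trajectory to the loop itself recurs in PlanLines, Review and DecideRelevance below; the shared pattern, condensed into a sketch with the surrounding state logic omitted:

    # Inside an AgenticState subclass such as PlanToCode:
    previous_transitions = self.loop.get_transitions(self.name)  # was self.loop.trajectory.get_transitions(...)
    for transition in previous_transitions:
        new_message = transition.state.to_message()  # each recorded state renders its own message
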
7 changes: 4 additions & 3 deletions moatless/edit/plan_lines.py
@@ -90,13 +90,14 @@ def __init__(
diff: Optional[str] = None,
lint_messages: list[VerificationError] | None = None,
max_iterations: int = 5,
include_message_history=True,
**data,
):
super().__init__(
message=message,
diff=diff,
lint_messages=lint_messages,
include_message_history=True,
include_message_history=include_message_history,
max_iterations=max_iterations,
**data,
)
@@ -113,7 +114,7 @@ def init(self):

if (
self.expand_context_with_related_spans
and len(self.loop.trajectory.get_transitions(self.name)) == 0
and len(self.loop.get_transitions(self.name)) == 0
):
self.file_context.expand_context_with_related_spans(max_tokens=4000)

@@ -261,7 +262,7 @@ def messages(self) -> list[Message]:

content = self.loop.trajectory.initial_message or ""

previous_transitions = self.loop.trajectory.get_transitions(str(self))
previous_transitions = self.loop.get_transitions(str(self))

for transition in previous_transitions:
new_message = transition.state.to_message()
11 changes: 4 additions & 7 deletions moatless/edit/review.py
@@ -9,7 +9,6 @@
CODER_SYSTEM_PROMPT,
SELECT_SPAN_SYSTEM_PROMPT,
CODER_FINAL_SYSTEM_PROMPT,
INCLUDE_SPAN_SYSTEM_PROMPT,
)
from moatless.state import AgenticState
from moatless.types import (
@@ -142,12 +141,13 @@ def __init__(
max_prompt_file_tokens: int = 4000,
max_tokens_in_edit_prompt: int = 500,
max_iterations: int = 8,
include_message_history=True,
**data,
):
super().__init__(
message=message,
diff=diff,
include_message_history=True,
include_message_history=include_message_history,
max_prompt_file_tokens=max_prompt_file_tokens,
max_tokens_in_edit_prompt=max_tokens_in_edit_prompt,
max_iterations=max_iterations,
@@ -373,10 +373,7 @@ def _request_for_change(self, rfc: ApplyChange) -> ActionResponse:

def system_prompt(self) -> str:
return (
CODER_SYSTEM_PROMPT
+ SELECT_SPAN_SYSTEM_PROMPT
+ INCLUDE_SPAN_SYSTEM_PROMPT
+ CODER_FINAL_SYSTEM_PROMPT
CODER_SYSTEM_PROMPT + SELECT_SPAN_SYSTEM_PROMPT + CODER_FINAL_SYSTEM_PROMPT
)

def to_message(self) -> str:
@@ -407,7 +404,7 @@ def messages(self) -> list[Message]:
else:
content = ""

previous_transitions = self.loop.trajectory.get_transitions(str(self))
previous_transitions = self.loop.get_transitions(str(self))

for transition in previous_transitions:
new_message = transition.state.to_message()
16 changes: 11 additions & 5 deletions moatless/file_context.py
@@ -434,18 +434,20 @@ def from_json(cls, repo_dir: str, json_data: str):
def from_dict(cls, repo_dir: str, data: Dict):
repo = FileRepository(repo_dir)
instance = cls(max_tokens=data.get("max_tokens", 4000), repo=repo)
for file_data in data.get("files", []):
instance.load_files_from_dict(data.get("files", []))
return instance

def load_files_from_dict(self, files: list[dict]):
for file_data in files:
file_path = file_data["file_path"]
show_all_spans = file_data.get("show_all_spans", False)
spans = [ContextSpan(**span) for span in file_data.get("spans", [])]
instance._file_context[file_path] = ContextFile(
file=instance._repo.get_file(file_path),
self._file_context[file_path] = ContextFile(
file=self._repo.get_file(file_path),
spans=spans,
show_all_spans=show_all_spans,
)

return instance

def model_dump(self, **kwargs):
if "exclude_none" not in kwargs:
kwargs["exclude_none"] = True
@@ -461,6 +463,10 @@ def snapshot(self):
del dict["max_tokens"]
return dict

def restore_from_snapshot(self, snapshot: dict):
self._file_context = {}
self.load_files_from_dict(snapshot.get("files", []))

def to_files_with_spans(self) -> List[FileWithSpans]:
return [
FileWithSpans(file_path=file_path, span_ids=list(file.span_ids))
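
The new load_files_from_dict and restore_from_snapshot helpers presumably let the file context be rewound to an earlier node when a trajectory tree is replayed; a minimal round-trip sketch, assuming the class is FileContext (as the module name suggests) and a repo_dir checkout is available:

    # Hypothetical snapshot/restore round-trip using the methods added above.
    file_context = FileContext.from_dict(repo_dir, {"max_tokens": 8000, "files": []})
    saved = file_context.snapshot()            # plain dict; "max_tokens" is stripped before returning
    # ... explore one branch, adding files and spans to the context ...
    file_context.restore_from_snapshot(saved)  # discard the current files and reload the saved ones
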
7 changes: 4 additions & 3 deletions moatless/find/decide.py
@@ -78,6 +78,7 @@ class DecideRelevance(AgenticState):
def __init__(
self,
expand_context: bool = True,
include_message_history=False,
finish_after_relevant_count: int = 2,
max_prompt_file_tokens: int = 4000,
**data,
@@ -86,7 +87,7 @@ def __init__(
expand_context=expand_context,
finish_after_relevant_count=finish_after_relevant_count,
max_prompt_file_tokens=max_prompt_file_tokens,
include_message_history=False,
include_message_history=include_message_history,
**data,
)

@@ -107,7 +108,7 @@ def handle_action(self, action: Decision) -> ActionResponse:

def _relevant_count(self) -> int:
relevant_count = 0
previous_transitions = self.loop.trajectory.get_transitions(str(self))
previous_transitions = self.loop.get_transitions(str(self))
for transition in previous_transitions:
for previous_action in transition.actions:
if (
@@ -124,7 +125,7 @@ def system_prompt(self) -> str:
return MAYBE_FINISH_SYSTEM_PROMPT

def _last_scratch_pad(self):
previous_searches = self.loop.trajectory.get_transitions("SearchCode")
previous_searches = self.loop.get_transitions("SearchCode")
logger.info(f"Previous searches: {len(previous_searches)}")
if previous_searches and previous_searches[-1].actions:
last_search = previous_searches[-1].actions[-1].action
48 changes: 16 additions & 32 deletions moatless/find/identify.py
@@ -65,42 +65,39 @@ class Identify(ActionRequest):


class IdentifyCode(AgenticState):
file_pattern: Optional[str]
query: Optional[str]
code_snippet: Optional[str]
class_name: Optional[str]
function_name: Optional[str]
ranked_spans: Optional[list[RankedFileSpan]] = Field(
default=None, description="Ranked file spans from the search results."
)

ranked_spans: Optional[list[RankedFileSpan]]
expand_context: bool = Field(
default=False,
description="Whether to expand the search result with relevant code spans .",
)

expand_context: bool
max_prompt_file_tokens: int = 4000
max_prompt_file_tokens: int = Field(
default=4000,
description="The maximum number of tokens to include in the prompt.",
)

def __init__(
self,
ranked_spans: list[RankedFileSpan],
file_pattern: Optional[str] = None,
query: Optional[str] = None,
code_snippet: Optional[str] = None,
class_name: Optional[str] = None,
function_name: Optional[str] = None,
expand_context: bool = True,
include_message_history: bool = False,
max_prompt_file_tokens: int = 4000,
**data,
):
super().__init__(
file_pattern=file_pattern,
query=query,
code_snippet=code_snippet,
class_name=class_name,
function_name=function_name,
ranked_spans=ranked_spans,
include_message_history=False,
include_message_history=include_message_history,
expand_context=expand_context,
max_prompt_file_tokens=max_prompt_file_tokens,
**data,
)

def model_dump(self, **kwargs):
return super().model_dump(**kwargs)

def handle_action(self, action: Identify) -> ActionResponse:
if action.identified_spans:
self.file_context.add_files_with_spans(action.identified_spans)
@@ -114,19 +111,6 @@ def handle_action(self, action: Identify) -> ActionResponse:
else:
logger.info("No spans identified.")

message = "I searched using the following parameters:\n"

if self.file_pattern:
message += f"\n* **File Pattern:** `{self.file_pattern}`"
if self.query:
message += f"\n* **Query:** `{self.query}`"
if self.code_snippet:
message += f"\n* **Code Snippet:** `{self.code_snippet}`"
if self.class_name:
message += f"\n* **Class Name:** `{self.class_name}`"
if self.function_name:
message += f"\n* **Function Name:** `{self.function_name}`"

message = f"The search returned {len(self.ranked_spans)} results. But unfortunately, I didn’t find any of the search results relevant to the query."

message += "\n\n"