diff --git a/moatless/benchmark/claude_evaluation.py b/moatless/benchmark/claude_evaluation.py
index 6e674c51..45c99b1c 100644
--- a/moatless/benchmark/claude_evaluation.py
+++ b/moatless/benchmark/claude_evaluation.py
@@ -10,8 +10,8 @@
 from moatless.edit.plan import PlanToCode
 from moatless.find.decide import DecideRelevance
 from moatless.find.identify import IdentifyCode
-from moatless.find.search_v2 import SearchCode
-from moatless.loop import TransitionRule
+from moatless.find.search import SearchCode
+from moatless.transition_rules import TransitionRule
 from moatless.state import Finished, Rejected
 from moatless.transitions import (
     search_and_code_transitions,
diff --git a/moatless/find/__init__.py b/moatless/find/__init__.py
index 5597d5a0..e6b45a9e 100644
--- a/moatless/find/__init__.py
+++ b/moatless/find/__init__.py
@@ -1,3 +1,3 @@
-from moatless.find.search_v2 import SearchCode
+from moatless.find.search import SearchCode
 from moatless.find.identify import IdentifyCode
 from moatless.find.decide import DecideRelevance
diff --git a/moatless/find/search.py b/moatless/find/search.py
index b53e77be..6c169560 100644
--- a/moatless/find/search.py
+++ b/moatless/find/search.py
@@ -3,9 +3,10 @@
 from typing import Optional
 
 import instructor
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field
 
-from moatless.file_context import FileContext, RankedFileSpan
+from moatless.file_context import RankedFileSpan
+from moatless.index.types import SearchCodeHit
 from moatless.state import ActionResponse, AgenticState
 from moatless.types import (
     ActionRequest,
@@ -32,12 +33,16 @@
 3. Consider the Necessary Search Parameters:
 Determine if specific file types, directories, function or class names or code patterns are mentioned in the issue.
 If you can you should always try to specify the search parameters as accurately as possible.
+You can do more than one search request at the same time so you can try different search parameters to cover all possible relevant code.
 
 4. Ensure At Least One Search Parameter:
-Make sure that at least one of query, code_snippet, class_name, or function_name is provided in each search request.
+Make sure that at least one of query, code_snippet, class_name, or function_name is provided.
 
 5. Formulate the Search function:
 Set at least one of the search paramaters `query`, `code_snippet`, `class_name` or `function_name`.
+
+
+
 """
 
 
@@ -63,7 +68,7 @@
 
 AI Assistant:
 functions.Search({
-    class_name: "PaymentProcessor"
+    class_names: ["PaymentProcessor"]
 )
 
 User:
@@ -71,7 +76,7 @@
 
 AI Assistant:
 functions.Search({
-    function_name: "generate_report",
+    function_names: ["generate_report"],
     file_pattern: "**/reports/**/*.py"
 )
 
@@ -80,8 +85,8 @@
 
 AI Assistant:
 functions.Search({
-    class_name: "HTMLParser",
-    function_name: "extract_data"
+    class_names: ["HTMLParser"],
+    function_names: ["extract_data"]
 )
 
 User:
@@ -120,14 +125,14 @@
 There's a bug in the PaymentProcessor class where transactions sometimes fail to log correctly, resulting in missing transaction records.
 
 Search parameters:
-    class_name: "PaymentProcessor"
+    class_names: ["PaymentProcessor"]
 
 
 User:
 The generate_report function sometimes produces incomplete reports under certain conditions. This function is part of the reporting module. Locate the generate_report function in the reports directory to debug and fix the issue.
 
 Search parameters:
-    function_name: "generate_report",
+    function_names: ["generate_report"]
     file_pattern: "**/reports/**/*.py"
 
 
@@ -135,8 +140,8 @@
 The extract_data function in HTMLParser throws an "AttributeError: 'NoneType' object has no attribute 'find'" error when parsing certain HTML pages.
 
 Search parameters:
-    class_name: "HTMLParser",
-    function_name: "extract_data"
+    class_names: ["HTMLParser"]
+    function_names: ["extract_data"]
 
 
 User:
@@ -225,13 +230,7 @@
 )
 
-class Search(ActionRequest):
-    """Take action to search for code, identify found and finish up."""
-
-    scratch_pad: str = Field(
-        description="Your thoughts on what search parameters to set."
-    )
-
+class SearchRequest(BaseModel):
     file_pattern: Optional[str] = Field(
         default=None,
         description="A glob pattern to filter search results to specific file types or directories. ",
     )
@@ -255,10 +254,6 @@ class Search(ActionRequest):
         default=[], description="Specific function names to include in the search."
     )
 
-    complete: Optional[bool] = Field(
-        default=False, description="Set to true when the search is complete."
-    )
-
     def has_search_attributes(self):
         return any(
             [
@@ -270,15 +265,27 @@ def has_search_attributes(self):
         )
 
 
-class ActionCallWithContext(BaseModel):
-    action: ActionRequest
-    file_context: FileContext
-    message: Optional[str] = None
+class Search(ActionRequest):
+    """Take action to search for code, identify found and finish up."""
+
+    scratch_pad: str = Field(
+        description="Scratch pad for the search. Use this to write down your thoughts on how to approach the search."
+    )
+
+    search_requests: list[SearchRequest] = Field(
+        default=[],
+        description="List of search requests.",
+    )
+
+    complete: Optional[bool] = Field(
+        default=False, description="Set to true when the search is complete."
+    )
 
-    model_config = ConfigDict(arbitrary_types_allowed=True)
+    def has_search_attributes(self):
+        return all([search.has_search_attributes() for search in self.search_requests])
 
 
-class LegacySearchCode(AgenticState):
+class SearchCode(AgenticState):
     message: Optional[str] = Field(
         None,
         description="Message to the search",
     )
@@ -306,11 +313,11 @@ def __init__(
         self,
         message: Optional[str] = None,
         max_search_results: int = 25,
         max_retries_with_any_file_context: int = 3,
+        include_message_history: bool = True,
         provide_initial_context: bool = True,
         initial_context_tokens: int = 4000,
         initial_search_results: int = 50,
         initial_context_spans_per_file: int = 5,
-        include_message_history=True,
         **data,
     ):
         super().__init__(
@@ -334,49 +341,38 @@ def handle_action(self, action: Search) -> ActionResponse:
             },
         )
 
-        if not action.has_search_attributes():
-            return self._retry(
-                "You must provide at least one the search attributes query, code_snippet, class_name or function_name to search. If you're finished, set finished to true."
-            )
-
-        dup_error = self._duplicate_search(action)
-        if dup_error:
-            message = dup_error
-
-            if action.file_pattern:
-                message += f"\n* **File Pattern:** `{action.file_pattern}`"
-            if action.query:
-                message += f"\n* **Query:** `{action.query}`"
-            if action.code_snippet:
-                message += f"\n* **Code Snippet:** `{action.code_snippet}`"
-            if action.class_names:
-                message += f"\n* **Class Name:** `{action.class_names}`"
-            if action.function_names:
-                message += f"\n* **Function Name:** `{action.function_names}`"
-
-            message += "\n\nPlease provide a new search parameters."
-            return self._retry(message)
+        if isinstance(action, Search):
+            if not action.has_search_attributes():
+                return self._retry(
+                    "You must provide at least one of the search attributes query, code_snippet, class_name or function_name to search. If you're finished, set finished to true."
+                )
 
-        if (
-            not self.support_test_files
-            and action.file_pattern
-            and is_test_pattern(action.file_pattern)
-        ):
-            return self._retry("It's not possible to search for test files.")
-
-        search_result = self.workspace.code_index.search(
-            file_pattern=action.file_pattern,
-            query=action.query,
-            code_snippet=action.code_snippet,
-            class_names=action.class_names,
-            function_names=action.function_names,
-            max_results=self.max_search_results,
-        )
+            for request in action.search_requests:
+                if (
+                    not self.support_test_files
+                    and request.file_pattern
+                    and is_test_pattern(request.file_pattern)
+                ):
+                    return self._retry("It's not possible to search for test files.")
+
+        message = ""
+        search_result: list[SearchCodeHit] = []
+        for search_request in action.search_requests:
+            search_response = self.workspace.code_index.search(
+                file_pattern=search_request.file_pattern,
+                query=search_request.query,
+                code_snippet=search_request.code_snippet,
+                class_names=search_request.class_names,
+                function_names=search_request.function_names,
+                max_results=int(self.max_search_results / len(action.search_requests)),
+            )
+            search_result.extend(search_response.hits)
+            message += "\n" + search_response.message
 
-        logger.info(f"Found {len(search_result.hits)} hits.")
+        logger.info(f"Found {len(search_result)} hits.")
 
         ranked_spans = []
-        for hit in search_result.hits:
+        for hit in search_result:
             for span in hit.spans:
                 ranked_spans.append(
                     RankedFileSpan(
                         file_path=hit.file_path,
                         span_id=span.span_id,
                         rank=span.rank,
                         tokens=span.tokens,
                     )
                 )
@@ -389,38 +385,17 @@ def handle_action(self, action: Search) -> ActionResponse:
 
         if len(ranked_spans) == 0:
             logger.info("No search results found. Will retry.")
-
-            message = "I searched using the following parameters:\n"
-
-            if action.file_pattern:
-                message += f"\n* **File Pattern:** `{action.file_pattern}`"
-            if action.query:
-                message += f"\n* **Query:** `{action.query}`"
-            if action.code_snippet:
-                message += f"\n* **Code Snippet:** `{action.code_snippet}`"
-            if action.class_names:
-                message += f"\n* **Class Names:** `{','.join(action.class_names)}`"
-            if action.function_names:
-                message += (
-                    f"\n* **Function Names:** `{','.join(action.function_names)}`"
-                )
-
-            message += "\n\nUnfortunately, I didn’t find any relevant results."
-            message += search_result.message
-
+            message = "\n\nUnfortunately, I didn’t find any relevant results."
             return self._retry(message)
 
-        output = {"ranked_spans": ranked_spans}
-        output.update(action.dict(exclude={"scratch_pad"}))
-
         return ActionResponse.transition(
             trigger="did_search",
-            output=output,
+            output={"ranked_spans": ranked_spans},
         )
 
     def _retry(self, message: str) -> ActionResponse:
         if (
-            self.retries() >= self.max_retries_with_any_file_context
+            self.retries() > self.max_retries_with_any_file_context
             and self.file_context.files
         ):
             logger.info(
@@ -430,37 +405,6 @@ def _retry(self, message: str) -> ActionResponse:
         else:
             return ActionResponse.retry(message)
 
-    def _duplicate_search(self, action: Search) -> Optional[str]:
-        previous_transitions = self.loop.get_previous_transitions(self)
-        for transition in previous_transitions:
-            for previous_action in transition.actions:
-                if isinstance(previous_action.action, Search):
-                    err_message = ""
-                    exclude = {"scratch_pad"}
-                    if action.function_names or action.class_names:
-                        exclude.add("query")
-
-                    err_message = ""
-                    if (
-                        action.function_names
-                        == previous_action.action.function_names
-                    ):
-                        err_message += f"You already searched for the function name: {action.function_names}"
-                    if action.class_names == previous_action.action.class_names:
-                        err_message += f"You already searched for the class name: {action.class_names}"
-
-                    previous = previous_action.action.model_dump(
-                        exclude={"scratch_pad"}
-                    )
-                    current = action.model_dump(exclude={"scratch_pad"})
-                    if previous == current:
-                        return (
-                            "You already did a search with the same parameters. "
-                            + err_message
-                        )
-
-        return None
-
     def action_type(self) -> type[BaseModel] | None:
         return Search
diff --git a/moatless/find/search_v2.py b/moatless/find/search_v2.py
deleted file mode 100644
index 9ca771ef..00000000
--- a/moatless/find/search_v2.py
+++ /dev/null
@@ -1,500 +0,0 @@
-import fnmatch
-import logging
-from typing import Optional
-
-import instructor
-from pydantic import BaseModel, Field
-
-from moatless.file_context import RankedFileSpan
-from moatless.index.types import SearchCodeHit
-from moatless.state import ActionResponse, AgenticState
-from moatless.types import (
-    ActionRequest,
-    AssistantMessage,
-    Message,
-    UserMessage,
-)
-
-logger = logging.getLogger(__name__)
-
-
-SEARCH_SYSTEM_PROMPT = """You are an autonomous AI assistant.
-Your task is to locate the code relevant to an issue.
-
-# Instructions:
-
-1. Understand The Issue:
-Read the <issue> tag to understand the issue.
-
-2. Review Current File Context:
-Examine the <file_context> tag to see which files and code spans have already been identified.
-If you believe that all relevant files have been identified, you can finish the search by setting complete to true.
-
-3. Consider the Necessary Search Parameters:
-Determine if specific file types, directories, function or class names or code patterns are mentioned in the issue.
-If you can you should always try to specify the search parameters as accurately as possible.
-You can do more than one search request at the same time so you can try different search parameters to cover all possible relevant code.
-
-4. Ensure At Least One Search Parameter:
-Make sure that at least one of query, code_snippet, class_name, or function_name is provided.
-
-5. Formulate the Search function:
-Set at least one of the search paramaters `query`, `code_snippet`, `class_name` or `function_name`.
-
-
-
-"""
-
-
-SEARCH_FUNCTIONS_FEW_SHOT_OPENAI_FUNC = """
-6. Execute the Search function:
-Use the Search function with the search parameters and your thoughts on how to approach this task.
-
-Think step by step and write out your thoughts in the thoughts field.
-
-Examples:
-
-User:
-The file uploader intermittently fails with "TypeError: cannot unpack non-iterable NoneType object". This issue appears sporadically during high load conditions..
-
-AI Assistant:
-functions.Search({
-    query: "File upload process to fix intermittent 'TypeError: cannot unpack non-iterable NoneType object'",
-    file_pattern: "**/uploader/**/*.py"
-)
-
-User:
-There's a bug in the PaymentProcessor class where transactions sometimes fail to log correctly, resulting in missing transaction records.
-
-AI Assistant:
-functions.Search({
-    class_names: ["PaymentProcessor"]
-)
-
-User:
-The generate_report function sometimes produces incomplete reports under certain conditions. This function is part of the reporting module. Locate the generate_report function in the reports directory to debug and fix the issue.
-
-AI Assistant:
-functions.Search({
-    function_names: ["generate_report"],
-    file_pattern: "**/reports/**/*.py"
-)
-
-User:
-The extract_data function in HTMLParser throws an "AttributeError: 'NoneType' object has no attribute 'find'" error when parsing certain HTML pages.
-
-AI Assistant:
-functions.Search({
-    class_names: ["HTMLParser"],
-    function_names: ["extract_data"]
-)
-
-User:
-The database connection setup is missing SSL configuration, causing insecure connections.
-
-Here’s the stack trace of the error:
-
-File "/opt/app/db_config/database.py", line 45, in setup_connection
-    engine = create_engine(DATABASE_URL)
-File "/opt/app/db_config/database.py", line 50, in <module>
-    connection = setup_connection()
-
-AI Assistant:
-functions.Search({
-    code_snippet: "engine = create_engine(DATABASE_URL)",
-    file_pattern: "db_config/database.py"
-)
-"""
-
-SEARCH_FUNCTIONS_FEW_SHOT = """6. Execute the Search function:
-Use the Search function with the search parameters and your thoughts on how to approach this task.
-
-Think step by step and write out your thoughts in the scratch_pad field.
-
-Examples:
-
-User:
-The file uploader intermittently fails with "TypeError: cannot unpack non-iterable NoneType object". This issue appears sporadically during high load conditions..
-
-Search parameters:
-    query: "File upload process to fix intermittent 'TypeError: cannot unpack non-iterable NoneType object'",
-    file_pattern: "**/uploader/**/*.py"
-
-
-User:
-There's a bug in the PaymentProcessor class where transactions sometimes fail to log correctly, resulting in missing transaction records.
-
-Search parameters:
-    class_names: ["PaymentProcessor"]
-
-
-User:
-The generate_report function sometimes produces incomplete reports under certain conditions. This function is part of the reporting module. Locate the generate_report function in the reports directory to debug and fix the issue.
-
-Search parameters:
-    function_names: ["generate_report"]
-    file_pattern: "**/reports/**/*.py"
-
-
-User:
-The extract_data function in HTMLParser throws an "AttributeError: 'NoneType' object has no attribute 'find'" error when parsing certain HTML pages.
-
-Search parameters:
-    class_names: ["HTMLParser"]
-    function_names: ["extract_data"]
-
-
-User:
-The database connection setup is missing SSL configuration, causing insecure connections.
-
-Here’s the stack trace of the error:
-
-File "/opt/app/db_config/database.py", line 45, in setup_connection
-    engine = create_engine(DATABASE_URL)
-File "/opt/app/db_config/database.py", line 50, in <module>
-    connection = setup_connection()
-
-Search parameters:
-    code_snippet: "engine = create_engine(DATABASE_URL)",
-    file_pattern: "db_config/database.py"
-
-"""
-
-SEARCH_JSON_FEW_SHOT = """6. Execute the Search:
-Execute the search by providing the search parameters and your thoughts on how to approach this task in a JSON object.
-
-Think step by step and write out your thoughts in the scratch_pad field.
-
-Examples:
-
-User:
-The file uploader intermittently fails with "TypeError: cannot unpack non-iterable NoneType object". This issue appears sporadically during high load conditions..
-
-Assistant:
-{
-    "scratch_pad": "The error indicates that a variable expected to be iterable is None, which might be happening due to race conditions or missing checks under high load. Investigate the file upload logic to ensure all necessary checks are in place and improve concurrency handling.",
-    "file_pattern": "**/uploader/**/*.py",
-    "query": "TypeError: cannot unpack non-iterable NoneType object"
-}
-
-User:
-There's a bug in the PaymentProcessor class where transactions sometimes fail to log correctly, resulting in missing transaction records.
-
-Assistant:
-{
-    "scratch_pad": "Missing transaction logs can cause significant issues in tracking payments. The problem may be related to how the logging mechanism handles transaction states or errors. Investigate the PaymentProcessor class, focusing on the transaction logging part.",
-    "class_name": "PaymentProcessor",
-    "query": "transactions fail to log correctly"
-}
-
-User:
-The generate_report function sometimes produces incomplete reports under certain conditions. This function is part of the reporting module. Locate the generate_report function in the reports directory to debug and fix the issue.
-
-Assistant:
-{
-    "scratch_pad": "Incomplete reports suggest that the function might be encountering edge cases or unhandled exceptions that disrupt the report generation. Reviewing the function's logic and error handling in the reporting module is necessary.",
-    "function_name": "generate_report",
-    "file_pattern": "**/reports/**/*.py",
-}
-
-User:
-The extract_data function in HTMLParser throws an "AttributeError: 'NoneType' object has no attribute 'find'" error when parsing certain HTML pages.
-
-Assistant:
-{
-    "scratch_pad": "The error occurs when 'find' is called on a NoneType object, suggesting that the HTML structure might not match expected patterns. ",
-    "class_name": "HTMLParser",
-    "function_name": "extract_data",
-}
-
-
-User:
-The database connection setup is missing SSL configuration, causing insecure connections.
-
-Here’s the stack trace of the error:
-
-File "/opt/app/db_config/database.py", line 45, in setup_connection
-    engine = create_engine(DATABASE_URL)
-File "/opt/app/db_config/database.py", line 50, in <module>
-    connection = setup_connection()
-
-Assistant:
-{
-    "scratch_pad": "The missing SSL configuration poses a security risk by allowing unencrypted connections. Find the code snippet `engine = create_engine(DATABASE_URL)` provided in the issue.",
-    "code_snippet": "engine = create_engine(DATABASE_URL)",
-}
-"""
-
-IGNORE_TEST_PROMPT = (
-    "Test files are not in the search scope. Ignore requests to search for tests. "
" -) - - -class SearchRequest(BaseModel): - file_pattern: Optional[str] = Field( - default=None, - description="A glob pattern to filter search results to specific file types or directories. ", - ) - - query: Optional[str] = Field( - default=None, - description="A semantic similarity search query. Use natural language to describe what you are looking for.", - ) - - code_snippet: Optional[str] = Field( - default=None, - description="Specific code snippet to that should be exactly matched.", - ) - - class_names: list[str] = Field( - default=[], description="Specific class names to include in the search." - ) - - function_names: list[str] = Field( - default=[], description="Specific function names to include in the search." - ) - - def has_search_attributes(self): - return any( - [ - self.query, - self.code_snippet, - self.class_names, - self.function_names, - ] - ) - - -class Search(ActionRequest): - """Take action to search for code, identify found and finish up.""" - - scratch_pad: str = Field( - description="Scratch pad for the search. Use this to write down your thoughts on how to approach the search." - ) - - search_requests: list[SearchRequest] = Field( - default=[], - description="List of search requests.", - ) - - complete: Optional[bool] = Field( - default=False, description="Set to true when the search is complete." - ) - - def has_search_attributes(self): - return all([search.has_search_attributes() for search in self.search_requests]) - - -class SearchCode(AgenticState): - message: Optional[str] = Field( - None, - description="Message to the search", - ) - - max_search_results: int = Field( - 25, - description="The maximum number of search results.", - ) - - max_retries_with_any_file_context: int = Field( - 3, - description="The maximum number of retries when there are identified files in file context.", - ) - - provide_initial_context: bool = True - initial_context_tokens: int = 4000 - initial_search_results: int = 50 - initial_context_spans_per_file: int = 5 - - support_test_files: bool = False - - def __init__( - self, - message: Optional[str] = None, - max_search_results: int = 25, - max_retries_with_any_file_context: int = 3, - include_message_history: bool = True, - provide_initial_context: bool = True, - initial_context_tokens: int = 4000, - initial_search_results: int = 50, - initial_context_spans_per_file: int = 5, - **data, - ): - super().__init__( - message=message, - include_message_history=include_message_history, - provide_initial_context=provide_initial_context, - max_search_results=max_search_results, - max_retries_with_any_file_context=max_retries_with_any_file_context, - initial_context_tokens=initial_context_tokens, - initial_search_results=initial_search_results, - initial_context_spans_per_file=initial_context_spans_per_file, - **data, - ) - - def handle_action(self, action: Search) -> ActionResponse: - if action.complete: - return ActionResponse.transition( - "finish", - output={ - "message": action.scratch_pad, - }, - ) - - if isinstance(action, Search): - if not action.has_search_attributes(): - return self._retry( - "You must provide at least one the search attributes query, code_snippet, class_name or function_name to search. If you're finished, set finished to true." 
- ) - - for request in action.search_requests: - if ( - not self.support_test_files - and request.file_pattern - and is_test_pattern(request.file_pattern) - ): - return self._retry("It's not possible to search for test files.") - - message = "" - search_result: list[SearchCodeHit] = [] - for search_request in action.search_requests: - search_response = self.workspace.code_index.search( - file_pattern=search_request.file_pattern, - query=search_request.query, - code_snippet=search_request.code_snippet, - class_names=search_request.class_names, - function_names=search_request.function_names, - max_results=int(self.max_search_results / len(action.search_requests)), - ) - search_result.extend(search_response.hits) - message += "\n" + search_response.message - - logger.info(f"Found {len(search_result)} hits.") - - ranked_spans = [] - for hit in search_result: - for span in hit.spans: - ranked_spans.append( - RankedFileSpan( - file_path=hit.file_path, - span_id=span.span_id, - rank=span.rank, - tokens=span.tokens, - ) - ) - - if len(ranked_spans) == 0: - logger.info("No search results found. Will retry.") - message = "\n\nUnfortunately, I didn’t find any relevant results." - return self._retry(message) - - output = {"ranked_spans": ranked_spans} - output.update(action.dict(exclude={"scratch_pad"})) - - return ActionResponse.transition( - trigger="did_search", - output=output, - ) - - def _retry(self, message: str) -> ActionResponse: - if ( - self.retries() > self.max_retries_with_any_file_context - and self.file_context.files - ): - logger.info( - "Exceeded max retries, will finish as there are identified files in the file context. Transitioning to finish." - ) - return ActionResponse.transition("finish") - else: - return ActionResponse.retry(message) - - def action_type(self) -> type[BaseModel] | None: - return Search - - def system_prompt(self) -> str: - system_prompt = SEARCH_SYSTEM_PROMPT - - if self.loop.instructor_mode == instructor.Mode.JSON: - system_prompt += SEARCH_JSON_FEW_SHOT - elif self.model.startswith("openai"): - system_prompt += SEARCH_FUNCTIONS_FEW_SHOT_OPENAI_FUNC - else: - system_prompt += SEARCH_FUNCTIONS_FEW_SHOT - - if not self.support_test_files: - system_prompt += IGNORE_TEST_PROMPT - return system_prompt - - def messages(self) -> list[Message]: - messages: list[Message] = [] - - content = f"\n{self.loop.trajectory.initial_message}\n" - - if self.provide_initial_context: - result = self.workspace.code_index.semantic_search( - query=self.loop.trajectory.initial_message, - exact_match_if_possible=False, - max_spans_per_file=5, - max_results=50, - ) - - file_context = self.create_file_context(max_tokens=4000) - - for hit in result.hits: - for span in hit.spans: - file_context.add_span_to_context( - hit.file_path, span.span_id, tokens=1 - ) - - content += "\n\nHere's some files that might be relevant when formulating the search.\n" - content += file_context.create_prompt( - show_span_ids=False, - show_line_numbers=False, - exclude_comments=True, - show_outcommented_code=False, - ) - - previous_transitions = self.loop.get_previous_transitions(self) - for transition in previous_transitions: - if transition.state.message: - content += transition.state.message - messages.append(UserMessage(content=content)) - messages.append( - AssistantMessage( - action=transition.actions[-1].action, - ) - ) - content = "" - - if self.message: - content += f"\n\n{self.message}\n" - - if self.file_context.files: - file_context_str = self.file_context.create_prompt( - exclude_comments=True, - 
-                outcomment_code_comment="... rest of the code",
-            )
-        else:
-            file_context_str = "No files found yet."
-
-        content += f"\n\n<file_context>\n{file_context_str}\n</file_context>"
-
-        messages.append(UserMessage(content=content))
-        messages.extend(self.retry_messages())
-
-        return messages
-
-
-def is_test_pattern(file_pattern: str):
-    test_patterns = ["test_*.py", "/tests/"]
-    for pattern in test_patterns:
-        if pattern in file_pattern:
-            return True
-
-    if file_pattern.startswith("test"):
-        return True
-
-    test_patterns = ["test_*.py"]
-
-    return any(fnmatch.filter([file_pattern], pattern) for pattern in test_patterns)
diff --git a/moatless/loop.py b/moatless/loop.py
index 6190810c..7c9dff95 100644
--- a/moatless/loop.py
+++ b/moatless/loop.py
@@ -45,6 +45,7 @@ def __init__(
         workspace: Workspace,
         trajectory: Trajectory | None = None,
         mocked_actions: list[dict] | None = None,
+        expected_states: list[Type[AgenticState]] | None = None,
         reset_mocks_at_state: Optional[str] = None,
         verify_state_func: Optional[Callable] = None,
         max_cost: float = 0.25,
@@ -79,9 +80,22 @@ def __init__(
         self._prompt_log_dir = prompt_log_dir
 
         self._mocked_actions = mocked_actions
-        self._reset_mocks_at_state = reset_mocks_at_state
+
+        if expected_states and not verify_state_func:
+            def verify_state_func(state: AgenticState):
+                nonlocal expected_states
+                if not expected_states:
+                    raise ValueError(f"No more expected states, but got {state.__class__}")
+                expected_state = expected_states.pop(0)
+                if not (state.name == expected_state or isinstance(state, expected_state)):
+                    raise ValueError(f"Expected state {expected_state} but got {state.__class__.__name__}")
+
+                self.log_info(f"Verified expected next state {expected_state}")
+
         self._verify_state_func = verify_state_func
+        self._reset_mocks_at_state = reset_mocks_at_state
+
         self._max_cost = max_cost
         self._max_message_tokens = max_message_tokens
         self._max_transitions = max_transitions
@@ -489,8 +503,7 @@ def _run(self):
             self._current_transition.actions.append(
                 TrajectoryAction(
                     action=action,
-                    output=response.output,
-                    retry_message=response.retry_message,
+                    trigger=response.trigger,
                     completion_cost=cost,
                     input_tokens=input_tokens,
                     output_tokens=output_tokens,
@@ -500,10 +513,12 @@ def _run(self):
 
         if not response.trigger:
             self.log_info(
-                f"{self.state.name}: No transition found. Staying in the same state."
+                f"{self.state.name}: No trigger in action response. Staying in the same state."
             )
             return
 
+        self.log_info(f"Received response with trigger {response.trigger}")
+
         if response.trigger == "retry":
             self.log_info(f"Retry requested. {response.retry_message}")
{response.retry_message}") return @@ -516,7 +531,7 @@ def _run(self): ) except Exception: logger.exception( - f"Failed to initiate next state with trigger {response.trigger} and output {response.output}" + f"{self.transition_name}: Failed to initiate next state with trigger {response.trigger} and output {response.output}" ) raise @@ -533,7 +548,6 @@ def _run(self): else: self._rejections = 0 - self.log_info(f"Transitioning to {next_state.name}") self.transition_to(next_state) @property diff --git a/moatless/state.py b/moatless/state.py index 29a27b83..e3b4d79b 100644 --- a/moatless/state.py +++ b/moatless/state.py @@ -175,6 +175,7 @@ def get_state_class(name: str) -> type[AgenticState]: ] for module_name in possible_modules: + try: module = importlib.import_module(module_name) if hasattr(module, name): diff --git a/moatless/trajectory.py b/moatless/trajectory.py index 580c7a02..1d0d6e77 100644 --- a/moatless/trajectory.py +++ b/moatless/trajectory.py @@ -16,8 +16,7 @@ class TrajectoryAction(BaseModel): action: ActionRequest - retry_message: Optional[str] = None - output: Optional[dict[str, Any]] = None + trigger: Optional[str] completion_cost: Optional[float] = None input_tokens: Optional[int] = None output_tokens: Optional[int] = None @@ -139,6 +138,8 @@ def save_info(self, info: dict): def get_mocked_actions(self) -> list[dict]: """ Return a list of actions that can be used to mock the trajectory. + + TODO: Provide the end transition and support parent() """ actions = [] for transition in self._transitions: @@ -146,6 +147,19 @@ def get_mocked_actions(self) -> list[dict]: actions.append(action["action"]) return actions + def get_expected_states(self) -> list[str]: + """ + Return a list of expected states in the trajectory to use for verifcation when rerunning the trajectory. 
+
+        TODO: Provide the end transition and support parent()
+        """
+
+        states = []
+        for transition in self._transitions:
+            states.append(transition["state"]["name"])
+        return states
+
+
     def to_dict(self):
         return {
             "name": self._name,
diff --git a/moatless/transitions.py b/moatless/transitions.py
index 650b32bd..b76bbaf8 100644
--- a/moatless/transitions.py
+++ b/moatless/transitions.py
@@ -7,7 +7,7 @@
 from moatless.edit.plan_lines import PlanToCodeWithLines
 from moatless.find.decide import DecideRelevance
 from moatless.find.identify import IdentifyCode
-from moatless.find.search_v2 import SearchCode
+from moatless.find.search import SearchCode
 from moatless.transition_rules import TransitionRule, TransitionRules
 from moatless.state import Finished, Rejected
 
diff --git a/tests/loop/test_loop.py b/tests/loop/test_loop.py
index c4c61949..dab7c33f 100644
--- a/tests/loop/test_loop.py
+++ b/tests/loop/test_loop.py
@@ -16,9 +16,10 @@ def test_rerun_save_and_load_trajectory():
     workspace = create_workspace(instance)
     assert isinstance(workspace.file_repo, GitRepository)
     mocked_actions = trajectory.get_mocked_actions()
+    expected_states = trajectory.get_expected_states()
 
     loop = AgenticLoop(
-        trajectory.transition_rules, workspace=workspace, mocked_actions=mocked_actions
+        trajectory.transition_rules, workspace=workspace, mocked_actions=mocked_actions, expected_states=expected_states
     )
 
     response = loop.run(message=trajectory.initial_message)
diff --git a/tests/test_rerun_trajectories.py b/tests/test_rerun_trajectories.py
deleted file mode 100644
index 8bc5a061..00000000
--- a/tests/test_rerun_trajectories.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import json
-
-import pytest
-
-from moatless.benchmark.swebench import load_instance, create_workspace
-from moatless.benchmark.utils import get_file_spans_from_patch
-from moatless.edit.edit import EditCode
-from moatless.loop import AgenticLoop
-from moatless.state import AgenticState
-from moatless.transitions import code_transitions
-
-
-def read_trajectory(path):
-    with open(path, "r") as f:
-        return json.load(f)
-
-
-def get_actions(trajectory: dict):
-    actions = []
-    for transition in trajectory["transitions"]:
-        for action in transition["actions"]:
-            actions.append(action["action"])
-    return actions
-
-
-@pytest.mark.skip
-def test_two_edits():
-    path = "trajectories/two_edits.json"
-    trajectory = read_trajectory(path)
-    actions = get_actions(trajectory)
-
-    instance = load_instance(trajectory["info"]["instance_id"])
-    workspace = create_workspace(
-        instance,
-        repo_base_dir="/tmp/repos",
-        index_store_dir="/home/albert/20240522-voyage-code-2",
-    )
-
-    spans = get_file_spans_from_patch(workspace.file_repo, instance["patch"])
-    for file_path, span_ids in spans.items():
-        workspace.file_context.add_spans_to_context(
-            file_path=file_path, span_ids=span_ids
-        )
-
-    iteration = 0
-
-    def _verify_state_func(state: AgenticState):
-        nonlocal iteration
-        iteration += 1
-        if isinstance(state, EditCode):
-            assert state.file_path == "src/_pytest/mark/evaluate.py"
-            if iteration == 2:
-                # Remove function
-                assert state.span_id == "cached_eval"
-                assert state.start_line == 21
-                assert state.end_line == 31
-            elif iteration == 4:
-                # Update the other function and expect lines to have been updated
-                assert state.span_id == "MarkEvaluator._istrue"
-                assert state.start_line == 71
-                assert state.end_line == 110
-
-    loop = AgenticLoop(
-        code_transitions(),
-        workspace=workspace,
-        mocked_actions=actions,
-        verify_state_func=_verify_state_func,
-    )
-
-    loop.run(message=trajectory["initial_message"])