Skip to content

Commit

Permalink
Verify transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
aorwall committed Jul 31, 2024
1 parent b750329 commit 2ee9cc9
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 705 deletions.
4 changes: 2 additions & 2 deletions moatless/benchmark/claude_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from moatless.edit.plan import PlanToCode
from moatless.find.decide import DecideRelevance
from moatless.find.identify import IdentifyCode
from moatless.find.search_v2 import SearchCode
from moatless.loop import TransitionRule
from moatless.find.search import SearchCode
from moatless.transition_rules import TransitionRule
from moatless.state import Finished, Rejected
from moatless.transitions import (
search_and_code_transitions,
Expand Down
2 changes: 1 addition & 1 deletion moatless/find/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from moatless.find.search_v2 import SearchCode
from moatless.find.search import SearchCode
from moatless.find.identify import IdentifyCode
from moatless.find.decide import DecideRelevance
190 changes: 67 additions & 123 deletions moatless/find/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from typing import Optional

import instructor
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field

from moatless.file_context import FileContext, RankedFileSpan
from moatless.file_context import RankedFileSpan
from moatless.index.types import SearchCodeHit
from moatless.state import ActionResponse, AgenticState
from moatless.types import (
ActionRequest,
Expand All @@ -32,12 +33,16 @@
3. Consider the Necessary Search Parameters:
Determine if specific file types, directories, function or class names or code patterns are mentioned in the issue.
If you can you should always try to specify the search parameters as accurately as possible.
You can do more than one search request at the same time so you can try different search parameters to cover all possible relevant code.
4. Ensure At Least One Search Parameter:
Make sure that at least one of query, code_snippet, class_name, or function_name is provided in each search request.
Make sure that at least one of query, code_snippet, class_name, or function_name is provided.
5. Formulate the Search function:
Set at least one of the search paramaters `query`, `code_snippet`, `class_name` or `function_name`.
"""


Expand All @@ -63,15 +68,15 @@
AI Assistant:
functions.Search({
class_name: "PaymentProcessor"
class_names: ["PaymentProcessor"]
)
User:
The generate_report function sometimes produces incomplete reports under certain conditions. This function is part of the reporting module. Locate the generate_report function in the reports directory to debug and fix the issue.
AI Assistant:
functions.Search({
function_name: "generate_report",
function_names: ["generate_report"],
file_pattern: "**/reports/**/*.py"
)
Expand All @@ -80,8 +85,8 @@
AI Assistant:
functions.Search({
class_name: "HTMLParser",
function_name: "extract_data"
class_names: ["HTMLParser"],
function_names: ["extract_data"]
)
User:
Expand Down Expand Up @@ -120,23 +125,23 @@
There's a bug in the PaymentProcessor class where transactions sometimes fail to log correctly, resulting in missing transaction records.
Search parameters:
class_name: "PaymentProcessor"
class_names: ["PaymentProcessor"]
User:
The generate_report function sometimes produces incomplete reports under certain conditions. This function is part of the reporting module. Locate the generate_report function in the reports directory to debug and fix the issue.
Search parameters:
function_name: "generate_report",
function_names: ["generate_report"]
file_pattern: "**/reports/**/*.py"
User:
The extract_data function in HTMLParser throws an "AttributeError: 'NoneType' object has no attribute 'find'" error when parsing certain HTML pages.
Search parameters:
class_name: "HTMLParser",
function_name: "extract_data"
class_names: ["HTMLParser"]
function_names: ["extract_data"]
User:
Expand Down Expand Up @@ -225,13 +230,7 @@
)


class Search(ActionRequest):
"""Take action to search for code, identify found and finish up."""

scratch_pad: str = Field(
description="Your thoughts on what search parameters to set."
)

class SearchRequest(BaseModel):
file_pattern: Optional[str] = Field(
default=None,
description="A glob pattern to filter search results to specific file types or directories. ",
Expand All @@ -255,10 +254,6 @@ class Search(ActionRequest):
default=[], description="Specific function names to include in the search."
)

complete: Optional[bool] = Field(
default=False, description="Set to true when the search is complete."
)

def has_search_attributes(self):
return any(
[
Expand All @@ -270,15 +265,27 @@ def has_search_attributes(self):
)


class ActionCallWithContext(BaseModel):
action: ActionRequest
file_context: FileContext
message: Optional[str] = None
class Search(ActionRequest):
"""Take action to search for code, identify found and finish up."""

scratch_pad: str = Field(
description="Scratch pad for the search. Use this to write down your thoughts on how to approach the search."
)

search_requests: list[SearchRequest] = Field(
default=[],
description="List of search requests.",
)

complete: Optional[bool] = Field(
default=False, description="Set to true when the search is complete."
)

model_config = ConfigDict(arbitrary_types_allowed=True)
def has_search_attributes(self):
return all([search.has_search_attributes() for search in self.search_requests])


class LegacySearchCode(AgenticState):
class SearchCode(AgenticState):
message: Optional[str] = Field(
None,
description="Message to the search",
Expand Down Expand Up @@ -306,11 +313,11 @@ def __init__(
message: Optional[str] = None,
max_search_results: int = 25,
max_retries_with_any_file_context: int = 3,
include_message_history: bool = True,
provide_initial_context: bool = True,
initial_context_tokens: int = 4000,
initial_search_results: int = 50,
initial_context_spans_per_file: int = 5,
include_message_history=True,
**data,
):
super().__init__(
Expand All @@ -334,49 +341,38 @@ def handle_action(self, action: Search) -> ActionResponse:
},
)

if not action.has_search_attributes():
return self._retry(
"You must provide at least one the search attributes query, code_snippet, class_name or function_name to search. If you're finished, set finished to true."
)

dup_error = self._duplicate_search(action)
if dup_error:
message = dup_error

if action.file_pattern:
message += f"\n* **File Pattern:** `{action.file_pattern}`"
if action.query:
message += f"\n* **Query:** `{action.query}`"
if action.code_snippet:
message += f"\n* **Code Snippet:** `{action.code_snippet}`"
if action.class_names:
message += f"\n* **Class Name:** `{action.class_names}`"
if action.function_names:
message += f"\n* **Function Name:** `{action.function_names}`"

message += "\n\nPlease provide a new search parameters."
return self._retry(message)
if isinstance(action, Search):
if not action.has_search_attributes():
return self._retry(
"You must provide at least one the search attributes query, code_snippet, class_name or function_name to search. If you're finished, set finished to true."
)

if (
not self.support_test_files
and action.file_pattern
and is_test_pattern(action.file_pattern)
):
return self._retry("It's not possible to search for test files.")

search_result = self.workspace.code_index.search(
file_pattern=action.file_pattern,
query=action.query,
code_snippet=action.code_snippet,
class_names=action.class_names,
function_names=action.function_names,
max_results=self.max_search_results,
)
for request in action.search_requests:
if (
not self.support_test_files
and request.file_pattern
and is_test_pattern(request.file_pattern)
):
return self._retry("It's not possible to search for test files.")

message = ""
search_result: list[SearchCodeHit] = []
for search_request in action.search_requests:
search_response = self.workspace.code_index.search(
file_pattern=search_request.file_pattern,
query=search_request.query,
code_snippet=search_request.code_snippet,
class_names=search_request.class_names,
function_names=search_request.function_names,
max_results=int(self.max_search_results / len(action.search_requests)),
)
search_result.extend(search_response.hits)
message += "\n" + search_response.message

logger.info(f"Found {len(search_result.hits)} hits.")
logger.info(f"Found {len(search_result)} hits.")

ranked_spans = []
for hit in search_result.hits:
for hit in search_result:
for span in hit.spans:
ranked_spans.append(
RankedFileSpan(
Expand All @@ -389,38 +385,17 @@ def handle_action(self, action: Search) -> ActionResponse:

if len(ranked_spans) == 0:
logger.info("No search results found. Will retry.")

message = "I searched using the following parameters:\n"

if action.file_pattern:
message += f"\n* **File Pattern:** `{action.file_pattern}`"
if action.query:
message += f"\n* **Query:** `{action.query}`"
if action.code_snippet:
message += f"\n* **Code Snippet:** `{action.code_snippet}`"
if action.class_names:
message += f"\n* **Class Names:** `{','.join(action.class_names)}`"
if action.function_names:
message += (
f"\n* **Function Names:** `{','.join(action.function_names)}`"
)

message += "\n\nUnfortunately, I didn’t find any relevant results."
message += search_result.message

message = "\n\nUnfortunately, I didn’t find any relevant results."
return self._retry(message)

output = {"ranked_spans": ranked_spans}
output.update(action.dict(exclude={"scratch_pad"}))

return ActionResponse.transition(
trigger="did_search",
output=output,
output={"ranked_spans": ranked_spans},
)

def _retry(self, message: str) -> ActionResponse:
if (
self.retries() >= self.max_retries_with_any_file_context
self.retries() > self.max_retries_with_any_file_context
and self.file_context.files
):
logger.info(
Expand All @@ -430,37 +405,6 @@ def _retry(self, message: str) -> ActionResponse:
else:
return ActionResponse.retry(message)

def _duplicate_search(self, action: Search) -> Optional[str]:
previous_transitions = self.loop.get_previous_transitions(self)
for transition in previous_transitions:
for previous_action in transition.actions:
if isinstance(previous_action.action, Search):
err_message = ""
exclude = {"scratch_pad"}
if action.function_names or action.class_names:
exclude.add("query")

err_message = ""
if (
action.function_names
== previous_action.action.function_names
):
err_message += f"You already searched for the function name: {action.function_names}"
if action.class_names == previous_action.action.class_names:
err_message += f"You already searched for the class name: {action.class_names}"

previous = previous_action.action.model_dump(
exclude={"scratch_pad"}
)
current = action.model_dump(exclude={"scratch_pad"})
if previous == current:
return (
"You already did a search with the same parameters. "
+ err_message
)

return None

def action_type(self) -> type[BaseModel] | None:
return Search

Expand Down
Loading

0 comments on commit 2ee9cc9

Please sign in to comment.