Two Step Edits #530
Draft · wants to merge 17 commits into main
44 changes: 44 additions & 0 deletions mentat/code_edit_feedback.py
@@ -66,6 +66,50 @@ async def get_user_feedback_on_edits(
return edits_to_apply, need_user_request


async def get_user_feedback_on_edits_two_step(
rewritten_files: list[tuple[str, str]],
) -> tuple[list[tuple[str, str]], bool]:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
conversation = session_context.conversation

stream.send(
"Apply these changes? 'Y/n' or provide feedback.",
style="input",
)
user_response_message = await collect_user_input()
user_response = user_response_message.data

need_user_request = True
match user_response.lower():
case "y" | "":
rewritten_files_to_apply = rewritten_files
conversation.add_message(
ChatCompletionSystemMessageParam(
role="system", content="User chose to apply all your changes."
)
)
case "n":
rewritten_files_to_apply = []
conversation.add_message(
ChatCompletionSystemMessageParam(
role="system",
content="User chose not to apply any of your changes.",
)
)
        case _:
            # Anything else is treated as feedback: apply no edits, and skip
            # asking for a new user request so the model can respond to the
            # feedback (appended to the conversation below) directly.
            need_user_request = False
            rewritten_files_to_apply = []
conversation.add_message(
ChatCompletionSystemMessageParam(
role="system",
content="User chose not to apply any of your changes.",
)
)
conversation.add_user_message(user_response)
return rewritten_files_to_apply, need_user_request


async def _user_filter_changes(file_edits: list[FileEdit]) -> list[FileEdit]:
new_edits = list[FileEdit]()
for file_edit in file_edits:
12 changes: 12 additions & 0 deletions mentat/code_file_manager.py
@@ -149,6 +149,18 @@ async def write_changes_to_files(
self.history.push_edits()
return applied_edits

# TODO handle creation, deletion, rename, undo/redo, check if file was modified, etc.
async def write_changes_to_files_two_step(
self, rewritten_files: list[tuple[str, str]]
) -> list[tuple[str, str]]:
applied_edits: list[tuple[str, str]] = []
for abs_path, new_file_str in rewritten_files:
new_lines = new_file_str.splitlines()
self.write_to_file(Path(abs_path), new_lines)
applied_edits.append((abs_path, new_file_str))

return applied_edits

def get_file_checksum(self, path: Path, interval: Interval | None = None) -> str:
if path.is_dir():
return "" # TODO: Build and maintain a hash tree for git_root
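To make the TODO's "check if file was modified" item concrete, here is a minimal sketch that reuses the existing get_file_checksum; the _checksums_at_read mapping is hypothetical and not part of this PR.

```python
# Sketch only: refuse to clobber files that changed on disk after being read.
# `self._checksums_at_read` is a hypothetical map captured when files are read;
# this assumes the surrounding CodeFileManager class shown above.
async def write_changes_to_files_two_step(
    self, rewritten_files: list[tuple[str, str]]
) -> list[tuple[str, str]]:
    applied_edits: list[tuple[str, str]] = []
    for abs_path, new_file_str in rewritten_files:
        path = Path(abs_path)
        if self.get_file_checksum(path) != self._checksums_at_read.get(path):
            continue  # modified externally since read; skip rather than overwrite
        self.write_to_file(path, new_file_str.splitlines())
        applied_edits.append((abs_path, new_file_str))
    return applied_edits
```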
10 changes: 10 additions & 0 deletions mentat/config.py
@@ -98,6 +98,16 @@ class Config:
},
converter=converters.optional(converters.to_bool),
)
two_step_edits: bool = attr.field(
default=False,
metadata={
"description": (
"Experimental feature that uses multiple LLM calls to make and parse"
" edits"
),
"auto_completions": bool_autocomplete,
},
)
revisor: bool = attr.field(
default=False,
metadata={
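For reference, the new flag can presumably be toggled like any other Config field; a trivial hedged example, assuming the attrs-based Config shown above:

```python
# Hypothetical usage: flip the experimental flag programmatically. Most users
# would set it in their Mentat config file instead.
from mentat.config import Config

config = Config(two_step_edits=True)
assert config.two_step_edits is True
```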
114 changes: 32 additions & 82 deletions mentat/conversation.py
@@ -16,7 +16,6 @@
)

from mentat.llm_api_handler import (
TOKEN_COUNT_WARNING,
count_tokens,
get_max_tokens,
prompt_tokens,
@@ -25,8 +24,12 @@
from mentat.parsers.file_edit import FileEdit
from mentat.parsers.parser import ParsedLLMResponse
from mentat.session_context import SESSION_CONTEXT
from mentat.stream_model_response import (
get_two_step_system_prompt,
stream_model_response,
stream_model_response_two_step,
)
from mentat.transcripts import ModelMessage, TranscriptMessage, UserMessage
from mentat.utils import add_newline


class MentatAssistantMessageParam(ChatCompletionAssistantMessageParam):
@@ -181,92 +184,28 @@ async def get_messages(
if ctx.config.no_parser_prompt:
system_prompt = []
else:
parser = ctx.config.parser
system_prompt = [
ChatCompletionSystemMessageParam(
role="system",
content=parser.get_system_prompt(),
)
]
if ctx.config.two_step_edits:
system_prompt = [
ChatCompletionSystemMessageParam(
role="system",
content=get_two_step_system_prompt(),
)
]
else:
parser = ctx.config.parser
system_prompt = [
ChatCompletionSystemMessageParam(
role="system",
content=parser.get_system_prompt(),
)
]

return system_prompt + _messages

def clear_messages(self) -> None:
"""Clears the messages in the conversation"""
self._messages = list[ChatCompletionMessageParam]()

async def _stream_model_response(
self,
messages: list[ChatCompletionMessageParam],
) -> ParsedLLMResponse:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_file_manager = session_context.code_file_manager
config = session_context.config
parser = config.parser
llm_api_handler = session_context.llm_api_handler
cost_tracker = session_context.cost_tracker

stream.send(
None,
channel="loading",
)
response = await llm_api_handler.call_llm_api(
messages,
config.model,
stream=True,
response_format=parser.response_format(),
)
stream.send(
None,
channel="loading",
terminate=True,
)

num_prompt_tokens = prompt_tokens(messages, config.model)
stream.send(f"Total token count: {num_prompt_tokens}", style="info")
if num_prompt_tokens > TOKEN_COUNT_WARNING:
stream.send(
"Warning: LLM performance drops off rapidly at large context sizes. Use"
" /clear to clear context or use /exclude to exclude any uneccessary"
" files.",
style="warning",
)

stream.send("Streaming... use control-c to interrupt the model at any point\n")
async with parser.interrupt_catcher():
parsed_llm_response = await parser.stream_and_parse_llm_response(
add_newline(response)
)
# Sampler and History require previous_file_lines
for file_edit in parsed_llm_response.file_edits:
file_edit.previous_file_lines = code_file_manager.file_lines.get(
file_edit.file_path, []
)
if not parsed_llm_response.interrupted:
cost_tracker.display_last_api_call()
else:
# Generator doesn't log the api call if we interrupt it
cost_tracker.log_api_call_stats(
num_prompt_tokens,
count_tokens(
parsed_llm_response.full_response, config.model, full_message=False
),
config.model,
display=True,
)

messages.append(
ChatCompletionAssistantMessageParam(
role="assistant", content=parsed_llm_response.full_response
)
)
self.add_model_message(
parsed_llm_response.full_response, messages, parsed_llm_response
)

return parsed_llm_response

async def get_model_response(self) -> ParsedLLMResponse:
[Contributor review comment]
The get_model_response method has been significantly simplified by moving the streaming logic into stream_model_response.py. This is a good refactor: it reduces the complexity of the Conversation class and adheres to the single-responsibility principle. However, all functionality related to model-response handling should be thoroughly tested in its new location to prevent regressions.

session_context = SESSION_CONTEXT.get()
stream = session_context.stream
@@ -277,7 +216,10 @@ async def get_model_response(self) -> ParsedLLMResponse:
raise_if_context_exceeds_max(tokens_used)

try:
response = await self._stream_model_response(messages_snapshot)
if session_context.config.two_step_edits:
response = await stream_model_response_two_step(messages_snapshot)
else:
response = await stream_model_response(messages_snapshot)
except RateLimitError:
stream.send(
"Rate limit error received from OpenAI's servers using model"
@@ -286,6 +228,14 @@
style="error",
)
return ParsedLLMResponse("", "", list[FileEdit]())

messages_snapshot.append(
ChatCompletionAssistantMessageParam(
role="assistant", content=response.full_response
)
)
self.add_model_message(response.full_response, messages_snapshot, response)

return response

async def remaining_context(self) -> int | None:
4 changes: 3 additions & 1 deletion mentat/parsers/git_parser.py
@@ -143,7 +143,9 @@ def parse_llm_response(self, content: str) -> ParsedLLMResponse:
file_edits.append(file_edit)

return ParsedLLMResponse(
f"{conversation}\n\n{git_diff}", conversation, file_edits
f"{conversation}\n\n{git_diff}",
conversation,
file_edits,
)

def file_edit_to_git_diff(self, file_edit: FileEdit) -> str:
2 changes: 2 additions & 0 deletions mentat/parsers/parser.py
@@ -35,6 +35,7 @@ class ParsedLLMResponse:
full_response: str = attr.field()
conversation: str = attr.field()
file_edits: list[FileEdit] = attr.field()
rewritten_files: list[tuple[str, str]] = attr.field(factory=list)
interrupted: bool = attr.field(default=False)


@@ -337,6 +338,7 @@ async def stream_and_parse_llm_response(
message,
conversation,
[file_edit for file_edit in file_edits.values()],
[],
interrupted,
)

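Since rewritten_files defaults to an empty list, existing parsers are unchanged; presumably the two-step path fills it instead of file_edits. A hedged illustration with hypothetical values:

```python
# Illustration (not from the PR): a two-step result carries whole-file
# rewrites in `rewritten_files` and leaves `file_edits` empty.
from mentat.parsers.parser import ParsedLLMResponse

full_response = "Rename the helper in mentat/example.py to `fetch_data`."
new_file_contents = "def fetch_data():\n    return 42\n"

response = ParsedLLMResponse(
    full_response,  # raw model text, kept for transcripts
    full_response,  # two-step replies contain no edit markup to strip
    [],             # structured FileEdits are unused on this path
    rewritten_files=[("mentat/example.py", new_file_contents)],
)
```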
14 changes: 14 additions & 0 deletions mentat/resources/prompts/two_step_edit_prompt.txt
@@ -0,0 +1,14 @@
**You are now operating within an advanced AI coding system designed to assist with code modifications and enhancements.**

Upon receiving context, which may range from specific code snippets to entire repositories, you will be tasked with addressing coding requests or answering questions.

**For your responses:**

- **Directly address the request or question:** Provide concise instructions for any code modifications, clearly stating what changes need to be made.
- **Specify modifications without reiterating existing code:** Guide the user on where and how to make modifications, e.g., "insert the new code block above the last function in the file" or "replace the existing loop condition with the provided snippet." Ensure instructions are clear without displaying how the entire file looks post-modification.
- **Use the full file path at least once per file with edits:** When mentioning a file for the first time, use its full path. You can refer to it by a shorter name afterward if it remains clear which file you're discussing.
- **Avoid suggesting non-actionable edits:** Do not recommend commenting out or non-specific removals. Be explicit about what to delete or change, referring to code blocks or functions by name and avoiding extensive verbatim rewrites.
- **Minimize the inclusion of unchanged code:** Focus on the new or altered lines rather than embedding them within large blocks of unchanged code. Your guidance should be clear enough for an intelligent actor to implement with just the changes specified.
- **Emphasize brevity and clarity:** Once you've provided detailed instructions for the edits, there's no need for further elaboration. Avoid concluding with summaries of how the code will look after the edits.

**Your guidance should empower users to confidently implement the suggested changes with minimal and precise directions, fostering an efficient and clear modification process.**
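Taken together with the two helper prompts below, this prompt drives a three-call pipeline. A minimal sketch, assuming an OpenAI-style async client; `two_step_edit`, `read_file`, and the constants are illustrative names, not the PR's actual implementation (see stream_model_response_two_step above):

```python
# A minimal sketch of the two-step flow these prompts drive.
import json
from pathlib import Path

from openai import AsyncOpenAI

PROMPTS = Path("mentat/resources/prompts")
TWO_STEP_PROMPT = (PROMPTS / "two_step_edit_prompt.txt").read_text()
LIST_FILES_PROMPT = (PROMPTS / "two_step_edit_prompt_list_files.txt").read_text()
REWRITE_FILE_PROMPT = (PROMPTS / "two_step_edit_prompt_rewrite_file.txt").read_text()

client = AsyncOpenAI()
MODEL = "gpt-4"  # assumed; the real flow takes the model from config


async def two_step_edit(
    user_messages: list[dict[str, str]],
    read_file,  # callable taking a path, returning the file's contents
) -> list[tuple[str, str]]:
    # Call 1: free-form edit instructions (two_step_edit_prompt.txt).
    answer = (await client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "system", "content": TWO_STEP_PROMPT}, *user_messages],
    )).choices[0].message.content or ""

    # Call 2: extract the files needing edits as JSON
    # (two_step_edit_prompt_list_files.txt).
    listing = (await client.chat.completions.create(
        model=MODEL,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": LIST_FILES_PROMPT},
            {"role": "user", "content": answer},
        ],
    )).choices[0].message.content or "{}"
    files = json.loads(listing).get("files", [])

    # Call 3: rewrite each listed file in full
    # (two_step_edit_prompt_rewrite_file.txt).
    rewritten: list[tuple[str, str]] = []
    for path in files:
        reply = (await client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": REWRITE_FILE_PROMPT},
                {"role": "user", "content": read_file(path)},
                {"role": "user", "content": answer},
            ],
        )).choices[0].message.content or ""
        rewritten.append((path, reply))  # reply is still wrapped in ``` here
    return rewritten
```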
10 changes: 10 additions & 0 deletions mentat/resources/prompts/two_step_edit_prompt_list_files.txt
@@ -0,0 +1,10 @@
You are part of an expert AI coding system.

The next message will be an answer to a user's question or request. It may include suggested edits to code files. Your job is simply to extract the names of files that edits need to be made to, according to that message.

In your response:
- respond in json, with a single key "files" and a value that is an array of strings
- return an empty array if no files have suggested edits, e.g. {"files":[]}
- the message may mention files without suggesting edits to them; do not include these. Only include files that have suggested edits
- if a file is meant to be created, include it in the list of files to edit
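A tolerant parse of this response might look like the following sketch (hypothetical helper, not the PR's code):

```python
import json


def parse_file_list(response: str) -> list[str]:
    # Expect {"files": ["path/one.py", ...]}; fall back to no files on
    # malformed output rather than crashing the edit flow.
    try:
        files = json.loads(response).get("files", [])
    except (json.JSONDecodeError, AttributeError):
        return []
    return [str(f) for f in files]
```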

11 changes: 11 additions & 0 deletions mentat/resources/prompts/two_step_edit_prompt_rewrite_file.txt
@@ -0,0 +1,11 @@
You are part of an expert AI coding system.

In the next message you will be given the contents of a code file. The user will then specify some edits to be made to the file.

Your response should:
- rewrite the entire file, including all the requested edits
- wrap your entire response in ```
- do not include anything else in your response other than the code
- do not make any other changes to the code other than the requested edits, even to standardize formatting
- even formatting changes should not be made unless explicitly requested by the user
- if a change is not fully specified, do your best to follow the spirit of what was asked
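Since the reply arrives wrapped in ```, the caller has to unwrap it before writing the file; a minimal sketch, assuming the model follows the prompt (hypothetical helper, not the PR's parser):

```python
def strip_code_fence(response: str) -> str:
    # Unwrap "```...\n<code>\n```", tolerating a language tag after the opening
    # fence; pass the text through unchanged if the model skipped the fence.
    text = response.strip()
    if text.startswith("```") and text.endswith("```") and "\n" in text:
        body = text[text.index("\n") + 1 : -3]
        return body.rstrip("\n") + "\n"
    return text
```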