Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
biobootloader committed Mar 6, 2024
1 parent a90bdd6 commit adcb7e9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 28 deletions.
2 changes: 1 addition & 1 deletion mentat/code_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def write_changes_to_files_two_step(
applied_edits: list[tuple[str, str]] = []
for abs_path, new_file_str in rewritten_files:
new_lines = new_file_str.splitlines()
self.write_to_file(abs_path, new_lines)
self.write_to_file(Path(abs_path), new_lines)
applied_edits.append((abs_path, new_file_str))

return applied_edits
Expand Down
3 changes: 2 additions & 1 deletion mentat/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ParsedLLMResponse:
full_response: str = attr.field()
conversation: str = attr.field()
file_edits: list[FileEdit] = attr.field()
rewritten_files: list[tuple[str, str]] = attr.field(default=[])
rewritten_files: list[tuple[str, str]] = attr.field(factory=list)
interrupted: bool = attr.field(default=False)


Expand Down Expand Up @@ -338,6 +338,7 @@ async def stream_and_parse_llm_response(
message,
conversation,
[file_edit for file_edit in file_edits.values()],
[],
interrupted,
)

Expand Down
62 changes: 36 additions & 26 deletions mentat/stream_model_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from difflib import ndiff
from json import JSONDecodeError
from pathlib import Path
from typing import AsyncIterator
from typing import Any, AsyncIterator

from openai.types.chat import (
ChatCompletionChunk,
Expand Down Expand Up @@ -104,7 +104,7 @@ def get_two_step_rewrite_file_prompt() -> str:
return read_prompt(two_step_edit_prompt_rewrite_file_filename)


def print_colored_diff(str1, str2, stream):
def print_colored_diff(str1: str, str2: str, stream: Any):
diff = ndiff(str1.splitlines(), str2.splitlines())

for line in diff:
Expand Down Expand Up @@ -185,19 +185,25 @@ async def stream_model_response_two_step(
response_format=ResponseFormat(type="json_object"),
)

try:
response_json = json.loads(list_files_response.choices[0].message.content)
except JSONDecodeError:
stream.send("Error processing model response: Invalid JSON", style="error")
# TODO: handle error
response_json = {}
if list_files_response.choices and list_files_response.choices[0].message.content:
try:
response_json: dict[str, list[str]] = json.loads(
list_files_response.choices[0].message.content
)
except JSONDecodeError:
stream.send("Error processing model response: Invalid JSON", style="error")
# TODO: handle error

stream.send(f"\n{response_json}\n")

# TODO remove line numbers when running two step edit
# TODO handle creating new files - including update prompt to know that's possible

rewritten_files = []
for file_path in response_json["files"]:
rewritten_files: list[tuple[str, str]] = []

file_paths: list[str] = response_json.get("files", [])
for file_path in file_paths:
full_path = (cwd / Path(file_path)).resolve()
code_file_lines = code_file_manager.file_lines.get(full_path, [])
code_file_string = "\n".join(code_file_lines)
Expand All @@ -222,21 +228,26 @@ async def stream_model_response_two_step(
model="gpt-3.5-turbo-0125", # TODO add config for secondary model
stream=False,
)
rewrite_file_response = rewrite_file_response.choices[0].message.content
lines = rewrite_file_response.splitlines()
# TODO remove asserts
assert "```" in lines[0]
assert "```" in lines[-1]
lines = lines[1:-1]
rewrite_file_response = "\n".join(lines)

rewritten_files.append((full_path, rewrite_file_response))

stream.send(f"\n### File Rewrite Response: {file_path} ###\n")
# stream.send(rewrite_file_response)

# TODO stream colored diff, skipping unchanged lines (except some for context)
print_colored_diff(code_file_string, rewrite_file_response, stream)
if (
rewrite_file_response
and rewrite_file_response.choices
and rewrite_file_response.choices[0].message.content
):
rewrite_file_response = rewrite_file_response.choices[0].message.content
lines = rewrite_file_response.splitlines()
# TODO remove asserts
assert "```" in lines[0]
assert "```" in lines[-1]
lines = lines[1:-1]
rewrite_file_response = "\n".join(lines)

rewritten_files.append((str(full_path), rewrite_file_response))

stream.send(f"\n### File Rewrite Response: {file_path} ###\n")
# stream.send(rewrite_file_response)

# TODO stream colored diff, skipping unchanged lines (except some for context)
print_colored_diff(code_file_string, rewrite_file_response, stream)

# async with parser.interrupt_catcher():
# parsed_llm_response = await parser.stream_and_parse_llm_response(
Expand Down Expand Up @@ -297,8 +308,7 @@ async def stream_and_parse_llm_response_two_step(

# Only finish printing if we don't quit from ctrl-c
printer.wrap_it_up()
if printer_task is not None:
await printer_task
await printer_task

logging.debug("LLM Response:")
logging.debug(message)
Expand Down

0 comments on commit adcb7e9

Please sign in to comment.