From adcb7e9538d9f9610d4ac6eaee788af9bd03e164 Mon Sep 17 00:00:00 2001
From: BioBootloader
Date: Tue, 5 Mar 2024 22:35:50 -0800
Subject: [PATCH] fix

---
 mentat/code_file_manager.py     |  2 +-
 mentat/parsers/parser.py        |  3 +-
 mentat/stream_model_response.py | 62 +++++++++++++++++++--------------
 3 files changed, 39 insertions(+), 28 deletions(-)

diff --git a/mentat/code_file_manager.py b/mentat/code_file_manager.py
index 5e591f029..282812a85 100644
--- a/mentat/code_file_manager.py
+++ b/mentat/code_file_manager.py
@@ -156,7 +156,7 @@ async def write_changes_to_files_two_step(
         applied_edits: list[tuple[str, str]] = []
         for abs_path, new_file_str in rewritten_files:
             new_lines = new_file_str.splitlines()
-            self.write_to_file(abs_path, new_lines)
+            self.write_to_file(Path(abs_path), new_lines)
             applied_edits.append((abs_path, new_file_str))

         return applied_edits
diff --git a/mentat/parsers/parser.py b/mentat/parsers/parser.py
index 9db5de9c7..38c4f019b 100644
--- a/mentat/parsers/parser.py
+++ b/mentat/parsers/parser.py
@@ -35,7 +35,7 @@ class ParsedLLMResponse:
     full_response: str = attr.field()
     conversation: str = attr.field()
     file_edits: list[FileEdit] = attr.field()
-    rewritten_files: list[tuple[str, str]] = attr.field(default=[])
+    rewritten_files: list[tuple[str, str]] = attr.field(factory=list)
     interrupted: bool = attr.field(default=False)


@@ -338,6 +338,7 @@ async def stream_and_parse_llm_response(
             message,
             conversation,
             [file_edit for file_edit in file_edits.values()],
+            [],
             interrupted,
         )

diff --git a/mentat/stream_model_response.py b/mentat/stream_model_response.py
index d46b8eb09..bbb87e239 100644
--- a/mentat/stream_model_response.py
+++ b/mentat/stream_model_response.py
@@ -6,7 +6,7 @@
 from difflib import ndiff
 from json import JSONDecodeError
 from pathlib import Path
-from typing import AsyncIterator
+from typing import Any, AsyncIterator

 from openai.types.chat import (
     ChatCompletionChunk,
@@ -104,7 +104,7 @@ def get_two_step_rewrite_file_prompt() -> str:
     return read_prompt(two_step_edit_prompt_rewrite_file_filename)


-def print_colored_diff(str1, str2, stream):
+def print_colored_diff(str1: str, str2: str, stream: Any):
     diff = ndiff(str1.splitlines(), str2.splitlines())

     for line in diff:
@@ -185,19 +185,25 @@ async def stream_model_response_two_step(
         response_format=ResponseFormat(type="json_object"),
     )

-    try:
-        response_json = json.loads(list_files_response.choices[0].message.content)
-    except JSONDecodeError:
-        stream.send("Error processing model response: Invalid JSON", style="error")
-        # TODO: handle error
+    response_json = {}
+    if list_files_response.choices and list_files_response.choices[0].message.content:
+        try:
+            response_json: dict[str, list[str]] = json.loads(
+                list_files_response.choices[0].message.content
+            )
+        except JSONDecodeError:
+            stream.send("Error processing model response: Invalid JSON", style="error")
+            # TODO: handle error

     stream.send(f"\n{response_json}\n")

     # TODO remove line numbers when running two step edit
     # TODO handle creating new files - including update prompt to know that's possible
-    rewritten_files = []
-    for file_path in response_json["files"]:
+    rewritten_files: list[tuple[str, str]] = []
+
+    file_paths: list[str] = response_json.get("files", [])
+    for file_path in file_paths:
         full_path = (cwd / Path(file_path)).resolve()
         code_file_lines = code_file_manager.file_lines.get(full_path, [])
         code_file_string = "\n".join(code_file_lines)

@@ -222,21 +228,26 @@
             model="gpt-3.5-turbo-0125",  # TODO add config for secondary model
             stream=False,
         )
-        rewrite_file_response = rewrite_file_response.choices[0].message.content
-        lines = rewrite_file_response.splitlines()
-        # TODO remove asserts
-        assert "```" in lines[0]
-        assert "```" in lines[-1]
-        lines = lines[1:-1]
-        rewrite_file_response = "\n".join(lines)
-
-        rewritten_files.append((full_path, rewrite_file_response))
-
-        stream.send(f"\n### File Rewrite Response: {file_path} ###\n")
-        # stream.send(rewrite_file_response)
-
-        # TODO stream colored diff, skipping unchanged lines (except some for context)
-        print_colored_diff(code_file_string, rewrite_file_response, stream)
+        if (
+            rewrite_file_response
+            and rewrite_file_response.choices
+            and rewrite_file_response.choices[0].message.content
+        ):
+            rewrite_file_response = rewrite_file_response.choices[0].message.content
+            lines = rewrite_file_response.splitlines()
+            # TODO remove asserts
+            assert "```" in lines[0]
+            assert "```" in lines[-1]
+            lines = lines[1:-1]
+            rewrite_file_response = "\n".join(lines)
+
+            rewritten_files.append((str(full_path), rewrite_file_response))
+
+            stream.send(f"\n### File Rewrite Response: {file_path} ###\n")
+            # stream.send(rewrite_file_response)
+
+            # TODO stream colored diff, skipping unchanged lines (except some for context)
+            print_colored_diff(code_file_string, rewrite_file_response, stream)

     # async with parser.interrupt_catcher():
     #     parsed_llm_response = await parser.stream_and_parse_llm_response(
@@ -297,8 +308,7 @@ async def stream_and_parse_llm_response_two_step(
     # Only finish printing if we don't quit from ctrl-c
     printer.wrap_it_up()
-    if printer_task is not None:
-        await printer_task
+    await printer_task

     logging.debug("LLM Response:")
     logging.debug(message)
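
The recurring change in mentat/stream_model_response.py is guarding every dereference of choices[0].message.content before it is used. A minimal standalone sketch of that guard pattern, assuming the openai v1 client types; the helper name first_message_content is illustrative only and is not part of mentat or the openai package:

    from openai.types.chat import ChatCompletion


    def first_message_content(response: ChatCompletion) -> str | None:
        # Return the first choice's message content, or None when the response
        # has no choices or the content field is empty/None, mirroring the
        # checks added around list_files_response and rewrite_file_response.
        if response.choices and response.choices[0].message.content:
            return response.choices[0].message.content
        return None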