Skip to content

Commit

Permalink
Merge branch 'main' into two-step-edits
Browse files Browse the repository at this point in the history
  • Loading branch information
biobootloader committed Mar 6, 2024
2 parents adcb7e9 + fe167fb commit b816358
Show file tree
Hide file tree
Showing 23 changed files with 536 additions and 204 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/lint_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Setup node (for pyright)
uses: actions/setup-node@v4
with:
node-version: "16"
- name: Install universal-ctags (Ubuntu)
if: runner.os == 'Linux'
run: sudo apt update && sudo apt install universal-ctags
Expand Down
10 changes: 10 additions & 0 deletions docs/source/user/commands.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ Show information on available commands.

Add files to context.

/load [context file path]
-------------------------

Load context from a file.

/save [context file path]
-------------------------

Save context to a file.

/redo
-----

Expand Down
15 changes: 4 additions & 11 deletions mentat/agent_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
ChatCompletionSystemMessageParam,
)

from mentat.llm_api_handler import prompt_tokens
from mentat.prompts.prompts import read_prompt
from mentat.session_context import SESSION_CONTEXT
from mentat.session_input import ask_yes_no, collect_user_input
Expand Down Expand Up @@ -42,7 +41,7 @@ async def enable_agent_mode(self):
"Finding files to determine how to test changes...", style="info"
)
features = ctx.code_context.get_all_features(split_intervals=False)
messages: List[ChatCompletionMessageParam] = [
messages: list[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(
role="system", content=self.agent_file_selection_prompt
),
Expand Down Expand Up @@ -85,21 +84,15 @@ async def _determine_commands(self) -> List[str]:
ctx = SESSION_CONTEXT.get()

model = ctx.config.model
messages = [
system_prompt: list[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(
role="system", content=self.agent_command_prompt
),
ChatCompletionSystemMessageParam(
role="system", content=self.agent_file_message
),
] + ctx.conversation.get_messages(include_system_prompt=False)
code_message = await ctx.code_context.get_code_message(
prompt_tokens=prompt_tokens(messages, model)
)
code_message = ChatCompletionSystemMessageParam(
role="system", content=code_message
)
messages.insert(1, code_message)
]
messages = await ctx.conversation.get_messages(system_prompt=system_prompt)

try:
# TODO: Should this even be a separate call or should we collect commands in the edit call?
Expand Down
127 changes: 59 additions & 68 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, TypedDict, Union

from openai.types.chat import ChatCompletionSystemMessageParam

from mentat.code_feature import (
CodeFeature,
get_code_message_from_features,
Expand All @@ -27,12 +25,7 @@
validate_and_format_path,
)
from mentat.interval import parse_intervals, split_intervals_from_path
from mentat.llm_api_handler import (
count_tokens,
get_max_tokens,
prompt_tokens,
raise_if_context_exceeds_max,
)
from mentat.llm_api_handler import count_tokens, get_max_tokens
from mentat.session_context import SESSION_CONTEXT
from mentat.session_stream import SessionStream

Expand All @@ -53,35 +46,28 @@ class CodeContext:
def __init__(
self,
stream: SessionStream,
git_root: Optional[Path] = None,
cwd: Path,
diff: Optional[str] = None,
pr_diff: Optional[str] = None,
ignore_patterns: Iterable[Path | str] = [],
):
self.git_root = git_root
self.diff = diff
self.pr_diff = pr_diff
self.ignore_patterns = set(Path(p) for p in ignore_patterns)

self.diff_context = None
if self.git_root:
self.diff_context = DiffContext(
stream, self.git_root, self.diff, self.pr_diff
)
self.diff_context = DiffContext(stream, cwd, self.diff, self.pr_diff)

self.include_files: Dict[Path, List[CodeFeature]] = {}
self.ignore_files: Set[Path] = set()
self.auto_features: List[CodeFeature] = []

def refresh_context_display(self):
async def refresh_context_display(self):
"""
Sends a message to the client with the code context. It is called in the main loop.
"""
ctx = SESSION_CONTEXT.get()

diff_context_display = None
if self.diff_context and self.diff_context.name:
diff_context_display = self.diff_context.get_display_context()
diff_context_display = self.diff_context.get_display_context()

features = get_consolidated_feature_refs(
[
Expand All @@ -91,30 +77,10 @@ def refresh_context_display(self):
]
)
auto_features = get_consolidated_feature_refs(self.auto_features)
if self.diff_context:
git_diff_paths = [str(p) for p in self.diff_context.diff_files()]
git_untracked_paths = [str(p) for p in self.diff_context.untracked_files()]
else:
git_diff_paths = []
git_untracked_paths = []
messages = ctx.conversation.get_messages()
code_message = get_code_message_from_features(
[
feature
for file_features in self.include_files.values()
for feature in file_features
]
+ self.auto_features
)
total_tokens = prompt_tokens(
messages
+ [
ChatCompletionSystemMessageParam(
role="system", content="\n".join(code_message)
)
],
ctx.config.model,
)
git_diff_paths = [str(p) for p in self.diff_context.diff_files()]
git_untracked_paths = [str(p) for p in self.diff_context.untracked_files()]

total_tokens = await ctx.conversation.count_tokens(include_code_message=True)

total_cost = ctx.cost_tracker.total_cost

Expand All @@ -136,7 +102,6 @@ async def get_code_message(
prompt_tokens: int,
prompt: Optional[str] = None,
expected_edits: Optional[list[str]] = None, # for training/benchmarking
suppress_context_check: bool = False,
) -> str:
"""
Retrieves the current code message.
Expand All @@ -151,42 +116,42 @@ async def get_code_message(

# Setup code message metadata
code_message = list[str]()
if self.diff_context:
# Since there is no way of knowing when the git diff changes,
# we just refresh the cache every time get_code_message is called
self.diff_context.refresh()
if self.diff_context.diff_files():
code_message += [
"Diff References:",
f' "-" = {self.diff_context.name}',
' "+" = Active Changes',
"",
]

# Since there is no way of knowing when the git diff changes,
# we just refresh the cache every time get_code_message is called
self.diff_context.refresh()
if self.diff_context.diff_files():
code_message += [
"Diff References:",
f' "-" = {self.diff_context.name}',
' "+" = Active Changes',
"",
]

code_message += ["Code Files:\n"]
meta_tokens = count_tokens("\n".join(code_message), model, full_message=True)

# Calculate user included features token size
include_features = [
feature
for file_features in self.include_files.values()
for feature in file_features
]
include_files_message = get_code_message_from_features(include_features)
include_files_tokens = count_tokens(
"\n".join(include_files_message), model, full_message=False
)

tokens_used = prompt_tokens + meta_tokens + include_files_tokens
if not suppress_context_check:
raise_if_context_exceeds_max(tokens_used)
auto_tokens = min(
get_max_tokens() - tokens_used - config.token_buffer,
config.auto_context_tokens,
)

# Get auto included features
if config.auto_context_tokens > 0 and prompt:
meta_tokens = count_tokens(
"\n".join(code_message), model, full_message=True
)
include_files_message = get_code_message_from_features(include_features)
include_files_tokens = count_tokens(
"\n".join(include_files_message), model, full_message=False
)

tokens_used = prompt_tokens + meta_tokens + include_files_tokens
auto_tokens = min(
get_max_tokens() - tokens_used - config.token_buffer,
config.auto_context_tokens,
)
features = self.get_all_features()
feature_filter = DefaultFilter(
auto_tokens,
Expand Down Expand Up @@ -444,3 +409,29 @@ async def search(
return all_features_sorted
else:
return all_features_sorted[:max_results]

def to_simple_context_dict(self) -> dict[str, list[str]]:
    """Return a simple dictionary representation of the code context.

    Keys are the absolute paths of the included files; values are the
    string forms of each file's code features.
    """
    return {
        str(path.absolute()): [str(feature) for feature in features]
        for path, features in self.include_files.items()
    }

def from_simple_context_dict(self, simple_dict: dict[str, list[str]]):
    """Load the code context from a simple dictionary representation.

    Each key is a file path; each value is a list of feature reference
    strings that are expanded into CodeFeature objects and recorded in
    ``self.include_files``.
    """
    for path_str, feature_refs in simple_dict.items():
        collected: List[CodeFeature] = []
        for ref in feature_refs:
            # Feature refs are stored as absolute paths, so the cwd
            # passed here does not affect resolution.
            collected.extend(
                get_code_features_for_path(Path(ref), cwd=Path("/"))
            )
        self.include_files[Path(path_str)] = collected
5 changes: 1 addition & 4 deletions mentat/code_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,7 @@ def get_code_message(self, standalone: bool = True) -> list[str]:
if standalone:
code_message.append("")

if (
code_context.diff_context is not None
and self.path in code_context.diff_context.diff_files()
):
if self.path in code_context.diff_context.diff_files():
diff = get_diff_for_file(code_context.diff_context.target, self.path)
diff_annotations = parse_diff(diff)
if self.interval.whole_file():
Expand Down
2 changes: 2 additions & 0 deletions mentat/command/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from .exclude import ExcludeCommand
from .help import HelpCommand
from .include import IncludeCommand
from .load import LoadCommand
from .redo import RedoCommand
from .run import RunCommand
from .sample import SampleCommand
from .save import SaveCommand
from .screenshot import ScreenshotCommand
from .search import SearchCommand
from .talk import TalkCommand
Expand Down
72 changes: 72 additions & 0 deletions mentat/command/commands/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import json
from pathlib import Path
from typing import List

from typing_extensions import override

from mentat.auto_completer import get_command_filename_completions
from mentat.command.command import Command, CommandArgument
from mentat.errors import PathValidationError
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import mentat_dir_path


class LoadCommand(Command, command_name="load"):
    """``/load`` command: merge a previously saved context file into the session."""

    @override
    async def apply(self, *args: str) -> None:
        """Load code context from a JSON context file.

        With no argument, the default ``context.json`` under the mentat
        directory is used. Loaded context is additive: it extends the
        current context rather than replacing it.

        Raises:
            PathValidationError: if the supplied path cannot be resolved.
        """
        session_context = SESSION_CONTEXT.get()
        stream = session_context.stream
        code_context = session_context.code_context
        # Default location used when the user does not supply a path.
        context_file_path = mentat_dir_path / "context.json"

        if len(args) > 1:
            stream.send(
                "Only one context file can be loaded at a time", style="warning"
            )
            return

        if args:
            path_arg = args[0]
            try:
                context_file_path = Path(path_arg).expanduser().resolve()
            except RuntimeError as e:
                # expanduser/resolve raise RuntimeError on e.g. symlink loops
                # or an unresolvable home directory; chain the cause.
                raise PathValidationError(
                    f"Invalid context file path provided: {path_arg}: {e}"
                ) from e

        try:
            with open(context_file_path, "r") as file:
                parsed_include_files = json.load(file)
        except FileNotFoundError:
            stream.send(f"Context file not found at {context_file_path}", style="error")
            return
        except OSError as e:
            # Permission errors, directories, unreadable files, etc. should
            # surface as an error message rather than crash the session.
            stream.send(
                f"Failed to read context file at {context_file_path}: {e}",
                style="error",
            )
            return
        except json.JSONDecodeError as e:
            stream.send(
                f"Failed to parse context file at {context_file_path}: {e}",
                style="error",
            )
            return

        # from_simple_context_dict expects a mapping; reject other JSON roots
        # (lists, strings, numbers) with a clear message.
        if not isinstance(parsed_include_files, dict):
            stream.send(
                f"Context file at {context_file_path} is not a JSON object",
                style="error",
            )
            return

        code_context.from_simple_context_dict(parsed_include_files)

        stream.send(f"Context loaded from {context_file_path}", style="success")

    @override
    @classmethod
    def arguments(cls) -> List[CommandArgument]:
        # Single optional positional argument: the context file path.
        return [CommandArgument("optional", ["path"])]

    @override
    @classmethod
    def argument_autocompletions(
        cls, arguments: list[str], argument_position: int
    ) -> list[str]:
        # Complete file names for whichever argument is being typed.
        return get_command_filename_completions(arguments[-1])

    @override
    @classmethod
    def help_message(cls) -> str:
        return (
            "Loads a context file. Loaded context adds to existing context, it does not"
            " replace it."
        )
Loading

0 comments on commit b816358

Please sign in to comment.