[MRG] Code RAG for Chatbot #265

Merged: 3 commits merged on Nov 19, 2024
Changes from 2 commits
37 changes: 25 additions & 12 deletions mle/agents/chat.py
@@ -1,14 +1,12 @@
import sys
import json
from rich.console import Console

from mle.function import *
from mle.utils import get_config, print_in_box, WorkflowCache
from mle.utils import get_config, WorkflowCache


class ChatAgent:

def __init__(self, model, working_dir='.', console=None):
def __init__(self, model, memory=None, working_dir='.', console=None):
"""
ChatAgent assists users with planning and debugging ML projects.

@@ -18,7 +16,10 @@ def __init__(self, model, working_dir='.', console=None):
config_data = get_config()

self.model = model
self.memory = memory
self.chat_history = []
if working_dir == '.':
working_dir = os.getcwd()
self.working_dir = working_dir
self.cache = WorkflowCache(working_dir, 'baseline')

@@ -56,7 +57,9 @@ def __init__(self, model, working_dir='.', console=None):
schema_search_papers_with_code,
schema_web_search,
schema_execute_command,
schema_preview_csv_data
schema_preview_csv_data,
schema_unzip_data,
schema_preview_zip_structure
]

if config_data.get('search_key'):
@@ -69,9 +72,9 @@ def __init__(self, model, working_dir='.', console=None):
advisor_report = self.cache.resume_variable("advisor_report")
self.sys_prompt += f"""
The overall project information: \n
{'Dataset: ' + dataset if dataset else ''} \n
{'Requirement: ' + ml_requirement if ml_requirement else ''} \n
{'Advisor: ' + advisor_report if advisor_report else ''} \n
{'Dataset: ' + str(dataset) if dataset else ''} \n
{'Requirement: ' + str(ml_requirement) if ml_requirement else ''} \n
{'Advisor: ' + str(advisor_report) if advisor_report else ''} \n
"""

self.chat_history.append({"role": 'system', "content": self.sys_prompt})
@@ -84,9 +87,8 @@ def greet(self):
Returns:
str: The generated greeting message.
"""
system_prompt = """
You are a Chatbot designed to collaborate with users on planning and debugging ML projects.
Your goal is to provide concise and friendly greetings within 50 words, including:
greet_prompt = """
Can you provide concise and friendly greetings within 50 words, including:
1. Infer about the project's purpose or objective.
2. Summarize the previous conversations if it existed.
2. Offering a brief overview of the assistance and support you can provide to the user, such as:
Expand All @@ -96,7 +98,7 @@ def greet(self):
- Providing resources and references for further learning.
Make sure your greeting is inviting and sets a positive tone for collaboration.
"""
self.chat_history.append({"role": "system", "content": system_prompt})
self.chat_history.append({"role": "user", "content": greet_prompt})
greets = self.model.query(
self.chat_history,
function_call='auto',
@@ -116,7 +118,18 @@ def chat(self, user_prompt):
user_prompt: the user prompt.
"""
text = ''
if self.memory:
table_name = 'mle_chat_' + self.working_dir.split('/')[-1]
query = self.memory.query([user_prompt], table_name=table_name, n_results=1) # TODO: adjust the n_results.
user_prompt += f"""
\nThese reference files and their snippets may be useful for the question:\n\n
"""

for t in query[0]:
snippet, metadata = t.get('text'), t.get('metadata')
user_prompt += f"**File**: {metadata.get('file')}\n**Snippet**: {snippet}\n"
self.chat_history.append({"role": "user", "content": user_prompt})

for content in self.model.stream(
self.chat_history,
function_call='auto',
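To illustrate the retrieval step this PR adds to `ChatAgent.chat()`, here is a minimal sketch assuming a memory table already populated by `--build_mem`; the question, project path, and result handling are illustrative and mirror the diff above:

```python
# Sketch of the RAG step in ChatAgent.chat(): retrieve the closest code chunk
# for a question and fold it into the user prompt (path/question are hypothetical).
from mle.utils import LanceDBMemory

working_dir = '/home/user/project'                      # hypothetical project path
table_name = 'mle_chat_' + working_dir.split('/')[-1]   # -> 'mle_chat_project'
memory = LanceDBMemory(working_dir)

user_prompt = "Where is the training loop defined?"
results = memory.query([user_prompt], table_name=table_name, n_results=1)

user_prompt += "\nThese reference files and their snippets may be useful for the question:\n\n"
for t in results[0]:
    snippet, metadata = t.get('text'), t.get('metadata')
    user_prompt += f"**File**: {metadata.get('file')}\n**Snippet**: {snippet}\n"
```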
38 changes: 35 additions & 3 deletions mle/cli.py
@@ -7,6 +7,7 @@
import questionary
from pathlib import Path
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, TextColumn, BarColumn

import mle
from mle.server import app
@@ -18,8 +19,11 @@
startup_web,
print_in_box,
)
from mle.utils import LanceDBMemory, list_files, read_file
from mle.utils import CodeChunker

console = Console()
memory = LanceDBMemory(os.getcwd())


@click.group()
@@ -127,7 +131,7 @@ def report_local(ctx, path, email, start_date, end_date):
).ask()

return workflow.report_local(os.getcwd(), path, email, start_date=start_date, end_date=end_date)


@cli.command()
@click.option('--model', default=None, help='The model to use for the chat.')
@@ -187,14 +191,42 @@ def kaggle(

@cli.command()
@click.option('--model', default=None, help='The model to use for the chat.')
def chat(model):
@click.option('--build_mem', is_flag=True, help='Build and enable the local memory for the chat.')
def chat(model, build_mem):
"""
chat: start an interactive chat with LLM to work on your ML project.
"""
if not check_config(console):
return

return workflow.chat(os.getcwd(), model)
if build_mem:
working_dir = os.getcwd()
table_name = 'mle_chat_' + working_dir.split('/')[-1]
source_files = list_files(working_dir, ['*.py']) # TODO: support more file types

chunker = CodeChunker(os.path.join(working_dir, '.mle', 'cache'), 'py')
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
TimeElapsedColumn(),
console=console,
) as progress:
process_task = progress.add_task("Processing files...", total=len(source_files))

for file_path in source_files:
raw_code = read_file(file_path)
progress.update(process_task, advance=1, description=f"Adding memory...")

chunks = chunker.chunk(raw_code, token_limit=100)
memory.add(
texts=list(chunks.values()),
table_name=table_name,
metadata=[{'file': file_path, 'chunk_key': k} for k, _ in chunks.items()]
)

return workflow.chat(os.getcwd(), model=model, memory=memory)


@cli.command()
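The `--build_mem` loop above can also be exercised programmatically for a single file; a condensed sketch under the same assumptions (the file path and token limit are illustrative):

```python
# Sketch: chunk one Python file and index its chunks, condensing the
# --build_mem loop above (file path is illustrative).
import os
from mle.utils import LanceDBMemory, CodeChunker, read_file

working_dir = os.getcwd()
table_name = 'mle_chat_' + working_dir.split('/')[-1]
memory = LanceDBMemory(working_dir)

chunker = CodeChunker(os.path.join(working_dir, '.mle', 'cache'), 'py')
file_path = os.path.join(working_dir, 'train.py')   # hypothetical source file

chunks = chunker.chunk(read_file(file_path), token_limit=100)
memory.add(
    texts=list(chunks.values()),
    table_name=table_name,
    metadata=[{'file': file_path, 'chunk_key': k} for k in chunks],
)
```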
1 change: 1 addition & 0 deletions mle/utils/__init__.py
@@ -2,3 +2,4 @@
from .cache import *
from .memory import *
from .data import *
from .chunk import *
130 changes: 130 additions & 0 deletions mle/utils/chunk.py
@@ -0,0 +1,130 @@
# Source modified from https://github.com/CintraAI/code-chunker/blob/main/Chunker.py
import tiktoken
from .parser import CodeParser
from abc import ABC, abstractmethod


def count_tokens(string: str, encoding_name: str) -> int:
encoding = tiktoken.encoding_for_model(encoding_name)
num_tokens = len(encoding.encode(string))
return num_tokens


class Chunker(ABC):
def __init__(self, encoding_name="gpt-4"):
self.encoding_name = encoding_name

@abstractmethod
def chunk(self, content, token_limit):
pass

@abstractmethod
def get_chunk(self, chunked_content, chunk_number):
pass

@staticmethod
def print_chunks(chunks):
for chunk_number, chunk_code in chunks.items():
print(f"Chunk {chunk_number}:")
print("=" * 40)
print(chunk_code)
print("=" * 40)

@staticmethod
def consolidate_chunks_into_file(chunks):
return "\n".join(chunks.values())

@staticmethod
def count_lines(consolidated_chunks):
lines = consolidated_chunks.split("\n")
return len(lines)


class CodeChunker(Chunker):
def __init__(self, cache_dir, file_extension, encoding_name="gpt-4o-mini"):
super().__init__(encoding_name)
self.file_extension = file_extension
self.cache_dir = cache_dir

def chunk(self, code, token_limit) -> dict:
code_parser = CodeParser(self.cache_dir, self.file_extension)
chunks = {}
token_count = 0
lines = code.split("\n")
i = 0
chunk_number = 1
start_line = 0
breakpoints = sorted(code_parser.get_lines_for_points_of_interest(code, self.file_extension))
comments = sorted(code_parser.get_lines_for_comments(code, self.file_extension))
adjusted_breakpoints = []
for bp in breakpoints:
current_line = bp - 1
highest_comment_line = None # Initialize with None to indicate no comment line has been found yet
while current_line in comments:
highest_comment_line = current_line # Update highest comment line found
current_line -= 1 # Move to the previous line

if highest_comment_line: # If a highest comment line exists, add it
adjusted_breakpoints.append(highest_comment_line)
else:
adjusted_breakpoints.append(
bp) # If no comments were found before the breakpoint, add the original breakpoint

breakpoints = sorted(set(adjusted_breakpoints)) # Ensure breakpoints are unique and sorted

while i < len(lines):
line = lines[i]
new_token_count = count_tokens(line, self.encoding_name)
if token_count + new_token_count > token_limit:

# Set the stop line to the last breakpoint before the current line
if i in breakpoints:
stop_line = i
else:
stop_line = max(max([x for x in breakpoints if x < i], default=start_line), start_line)

# If the stop line is the same as the start line, it means we haven't reached a breakpoint yet, and we need to move to the next line to find one
if stop_line == start_line and i not in breakpoints:
token_count += new_token_count
i += 1

# If the stop line equals both the start line and the current line, there is still no usable breakpoint behind us, so keep accumulating and advance
elif stop_line == start_line and i == stop_line:
token_count += new_token_count
i += 1

# If the stop line is the same as the start line and the current line is a breakpoint, it means we can create a chunk with just the current line
elif stop_line == start_line and i in breakpoints:
current_chunk = "\n".join(lines[start_line:stop_line])
if current_chunk.strip(): # If the current chunk is not just whitespace
chunks[chunk_number] = current_chunk # Using chunk_number as key
chunk_number += 1

token_count = 0
start_line = i
i += 1

# If the stop line is different from the start line, it means we're at the end of a block
else:
current_chunk = "\n".join(lines[start_line:stop_line])
if current_chunk.strip():
chunks[chunk_number] = current_chunk # Using chunk_number as key
chunk_number += 1

i = stop_line
token_count = 0
start_line = stop_line
else:
# If the token count is still within the limit, add the line to the current chunk
token_count += new_token_count
i += 1

# Append remaining code, if any, ensuring it's not empty or whitespace
current_chunk_code = "\n".join(lines[start_line:])
if current_chunk_code.strip(): # Checks if the chunk is not just whitespace
chunks[chunk_number] = current_chunk_code # Using chunk_number as key

return chunks

def get_chunk(self, chunked_codebase, chunk_number):
return chunked_codebase[chunk_number]
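For reference, a small usage sketch of the chunker in isolation; `chunk()` returns a dict keyed by chunk number, which `print_chunks` and `consolidate_chunks_into_file` operate on (the source string and cache directory are illustrative):

```python
# Sketch: chunk a short Python source string (illustrative code and cache dir).
from mle.utils import CodeChunker

source = '''\
import math


def area(r):
    """Circle area."""
    return math.pi * r ** 2


class Greeter:
    def hello(self, name):
        return "hi " + name
'''

chunker = CodeChunker('.mle/cache', 'py')
chunks = chunker.chunk(source, token_limit=20)    # e.g. {1: "import math...", 2: "class Greeter..."}
CodeChunker.print_chunks(chunks)                  # pretty-print each chunk
restored = CodeChunker.consolidate_chunks_into_file(chunks)
```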
28 changes: 28 additions & 0 deletions mle/utils/data.py
@@ -1,6 +1,34 @@
import re
import os
import json
from typing import Dict, Any


def dict_to_markdown(data: Dict[str, Any], file_path: str) -> None:
"""
Write a dictionary to a markdown file.
:param data: the dictionary to write.
:param file_path: the file path to write the dictionary to.
:return:
"""

def write_item(k, v, indent_level=0):
if isinstance(v, dict):
md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
for sub_key, sub_value in v.items():
write_item(sub_key, sub_value, indent_level + 1)
elif isinstance(v, list):
md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
for item in v:
md_file.write(f"{' ' * indent_level}- {item}\n")
else:
md_file.write(f"{'##' * (indent_level + 1)} {k}\n")
md_file.write(f"{' ' * indent_level}{v}\n")

with open(file_path, 'w') as md_file:
for key, value in data.items():
write_item(key, value)
md_file.write("\n")


def is_markdown_file(file_path):
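A short sketch of the new `dict_to_markdown` helper (illustrative data and path); nested dicts become deeper `##` headings and lists become bullet items:

```python
# Sketch: write a nested dict to a markdown file (illustrative content and path).
from mle.utils.data import dict_to_markdown

report = {
    "Project": "churn-prediction",
    "Datasets": ["train.csv", "test.csv"],
    "Advisor": {"model": "xgboost", "metric": "AUC"},
}
dict_to_markdown(report, "report.md")
# report.md then contains (roughly):
#   ## Project
#   churn-prediction
#   ## Datasets
#   - train.csv
#   - test.csv
#   ## Advisor
#   #### model
#   xgboost
#   #### metric
#   AUC
```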
19 changes: 8 additions & 11 deletions mle/utils/memory.py
@@ -160,10 +160,7 @@ def reset(self):

class LanceDBMemory:

def __init__(
self,
project_path: str,
):
def __init__(self, project_path: str):
"""
Memory: A base class for memory and external knowledge management.
Args:
@@ -180,11 +177,11 @@ def __init__(
raise NotImplementedError

def add(
self,
texts: List[str],
metadata: Optional[List[Dict]] = None,
table_name: Optional[str] = None,
ids: Optional[List[str]] = None,
self,
texts: List[str],
metadata: Optional[List[Dict]] = None,
table_name: Optional[str] = None,
ids: Optional[List[str]] = None,
) -> List[str]:
"""
Adds a list of text items to the specified memory table in the database.
@@ -200,12 +197,12 @@
List[str]: A list of IDs associated with the added text items.
"""
if isinstance(texts, str):
texts = (texts, )
texts = (texts,)

if metadata is None:
metadata = [None, ] * len(texts)
elif isinstance(metadata, dict):
metadata = (metadata, )
metadata = (metadata,)
else:
assert len(texts) == len(metadata)

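Finally, a brief sketch of `LanceDBMemory.add()` after the reformatting (table name, texts, and metadata are illustrative); a bare string is coerced to a one-element tuple, and when metadata is given it must match `texts` in length:

```python
# Sketch: add two code snippets with per-item metadata (illustrative values).
from mle.utils import LanceDBMemory

memory = LanceDBMemory('/home/user/project')   # hypothetical project path
ids = memory.add(
    texts=["def load(): ...", "def train(): ..."],
    metadata=[{'file': 'data.py', 'chunk_key': 1},
              {'file': 'train.py', 'chunk_key': 1}],
    table_name='mle_chat_project',
)
# add() returns one ID per inserted text.
```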