diff --git a/mle/agents/chat.py b/mle/agents/chat.py index 5d546f7..9f21571 100644 --- a/mle/agents/chat.py +++ b/mle/agents/chat.py @@ -1,14 +1,12 @@ -import sys -import json from rich.console import Console from mle.function import * -from mle.utils import get_config, print_in_box, WorkflowCache +from mle.utils import get_config, WorkflowCache class ChatAgent: - def __init__(self, model, working_dir='.', console=None): + def __init__(self, model, memory=None, working_dir='.', console=None): """ ChatAgent assists users with planning and debugging ML projects. @@ -18,7 +16,10 @@ def __init__(self, model, working_dir='.', console=None): config_data = get_config() self.model = model + self.memory = memory self.chat_history = [] + if working_dir == '.': + working_dir = os.getcwd() self.working_dir = working_dir self.cache = WorkflowCache(working_dir, 'baseline') @@ -56,7 +57,9 @@ def __init__(self, model, working_dir='.', console=None): schema_search_papers_with_code, schema_web_search, schema_execute_command, - schema_preview_csv_data + schema_preview_csv_data, + schema_unzip_data, + schema_preview_zip_structure ] if config_data.get('search_key'): @@ -69,9 +72,9 @@ def __init__(self, model, working_dir='.', console=None): advisor_report = self.cache.resume_variable("advisor_report") self.sys_prompt += f""" The overall project information: \n - {'Dataset: ' + dataset if dataset else ''} \n - {'Requirement: ' + ml_requirement if ml_requirement else ''} \n - {'Advisor: ' + advisor_report if advisor_report else ''} \n + {'Dataset: ' + str(dataset) if dataset else ''} \n + {'Requirement: ' + str(ml_requirement) if ml_requirement else ''} \n + {'Advisor: ' + str(advisor_report) if advisor_report else ''} \n """ self.chat_history.append({"role": 'system', "content": self.sys_prompt}) @@ -84,9 +87,8 @@ def greet(self): Returns: str: The generated greeting message. """ - system_prompt = """ - You are a Chatbot designed to collaborate with users on planning and debugging ML projects. - Your goal is to provide concise and friendly greetings within 50 words, including: + greet_prompt = """ + Can you provide concise and friendly greetings within 50 words, including: 1. Infer about the project's purpose or objective. 2. Summarize the previous conversations if it existed. 2. Offering a brief overview of the assistance and support you can provide to the user, such as: @@ -96,7 +98,7 @@ def greet(self): - Providing resources and references for further learning. Make sure your greeting is inviting and sets a positive tone for collaboration. """ - self.chat_history.append({"role": "system", "content": system_prompt}) + self.chat_history.append({"role": "user", "content": greet_prompt}) greets = self.model.query( self.chat_history, function_call='auto', @@ -116,7 +118,18 @@ def chat(self, user_prompt): user_prompt: the user prompt. """ text = '' + if self.memory: + table_name = 'mle_chat_' + self.working_dir.split('/')[-1] + query = self.memory.query([user_prompt], table_name=table_name, n_results=1) # TODO: adjust the n_results. + user_prompt += f""" + \nThese reference files and their snippets may be useful for the question:\n\n + """ + + for t in query[0]: + snippet, metadata = t.get('text'), t.get('metadata') + user_prompt += f"**File**: {metadata.get('file')}\n**Snippet**: {snippet}\n" self.chat_history.append({"role": "user", "content": user_prompt}) + for content in self.model.stream( self.chat_history, function_call='auto', diff --git a/mle/cli.py b/mle/cli.py index 8f640b3..32990b4 100644 --- a/mle/cli.py +++ b/mle/cli.py @@ -7,6 +7,7 @@ import questionary from pathlib import Path from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, TextColumn, BarColumn import mle from mle.server import app @@ -18,8 +19,11 @@ startup_web, print_in_box, ) +from mle.utils import LanceDBMemory, list_files, read_file +from mle.utils import CodeChunker console = Console() +memory = LanceDBMemory(os.getcwd()) @click.group() @@ -127,7 +131,7 @@ def report_local(ctx, path, email, start_date, end_date): ).ask() return workflow.report_local(os.getcwd(), path, email, start_date=start_date, end_date=end_date) - + @cli.command() @click.option('--model', default=None, help='The model to use for the chat.') @@ -187,14 +191,46 @@ def kaggle( @cli.command() @click.option('--model', default=None, help='The model to use for the chat.') -def chat(model): +@click.option('--build_mem', is_flag=True, help='Build and enable the local memory for the chat.') +def chat(model, build_mem): """ chat: start an interactive chat with LLM to work on your ML project. """ if not check_config(console): return - return workflow.chat(os.getcwd(), model) + if build_mem: + working_dir = os.getcwd() + table_name = 'mle_chat_' + working_dir.split('/')[-1] + source_files = list_files(working_dir, ['*.py']) # TODO: support more file types + + chunker = CodeChunker(os.path.join(working_dir, '.mle', 'cache'), 'py') + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + console=console, + ) as progress: + process_task = progress.add_task("Processing files...", total=len(source_files)) + + for file_path in source_files: + raw_code = read_file(file_path) + progress.update( + process_task, + advance=1, + description=f"Adding {os.path.basename(file_path)} to memory..." + ) + + chunks = chunker.chunk(raw_code, token_limit=100) + memory.add( + texts=list(chunks.values()), + table_name=table_name, + metadata=[{'file': file_path, 'chunk_key': k} for k, _ in chunks.items()] + ) + + return workflow.chat(os.getcwd(), model=model, memory=memory) @cli.command() diff --git a/mle/utils/__init__.py b/mle/utils/__init__.py index 29d042e..e6a6d84 100644 --- a/mle/utils/__init__.py +++ b/mle/utils/__init__.py @@ -2,3 +2,4 @@ from .cache import * from .memory import * from .data import * +from .chunk import * diff --git a/mle/utils/chunk.py b/mle/utils/chunk.py new file mode 100644 index 0000000..d29a1ad --- /dev/null +++ b/mle/utils/chunk.py @@ -0,0 +1,130 @@ +# Source modified from https://github.com/CintraAI/code-chunker/blob/main/Chunker.py +import tiktoken +from .parser import CodeParser +from abc import ABC, abstractmethod + + +def count_tokens(string: str, encoding_name: str) -> int: + encoding = tiktoken.encoding_for_model(encoding_name) + num_tokens = len(encoding.encode(string)) + return num_tokens + + +class Chunker(ABC): + def __init__(self, encoding_name="gpt-4"): + self.encoding_name = encoding_name + + @abstractmethod + def chunk(self, content, token_limit): + pass + + @abstractmethod + def get_chunk(self, chunked_content, chunk_number): + pass + + @staticmethod + def print_chunks(chunks): + for chunk_number, chunk_code in chunks.items(): + print(f"Chunk {chunk_number}:") + print("=" * 40) + print(chunk_code) + print("=" * 40) + + @staticmethod + def consolidate_chunks_into_file(chunks): + return "\n".join(chunks.values()) + + @staticmethod + def count_lines(consolidated_chunks): + lines = consolidated_chunks.split("\n") + return len(lines) + + +class CodeChunker(Chunker): + def __init__(self, cache_dir, file_extension, encoding_name="gpt-4o-mini"): + super().__init__(encoding_name) + self.file_extension = file_extension + self.cache_dir = cache_dir + + def chunk(self, code, token_limit) -> dict: + code_parser = CodeParser(self.cache_dir, self.file_extension) + chunks = {} + token_count = 0 + lines = code.split("\n") + i = 0 + chunk_number = 1 + start_line = 0 + breakpoints = sorted(code_parser.get_lines_for_points_of_interest(code, self.file_extension)) + comments = sorted(code_parser.get_lines_for_comments(code, self.file_extension)) + adjusted_breakpoints = [] + for bp in breakpoints: + current_line = bp - 1 + highest_comment_line = None # Initialize with None to indicate no comment line has been found yet + while current_line in comments: + highest_comment_line = current_line # Update highest comment line found + current_line -= 1 # Move to the previous line + + if highest_comment_line: # If a highest comment line exists, add it + adjusted_breakpoints.append(highest_comment_line) + else: + adjusted_breakpoints.append( + bp) # If no comments were found before the breakpoint, add the original breakpoint + + breakpoints = sorted(set(adjusted_breakpoints)) # Ensure breakpoints are unique and sorted + + while i < len(lines): + line = lines[i] + new_token_count = count_tokens(line, self.encoding_name) + if token_count + new_token_count > token_limit: + + # Set the stop line to the last breakpoint before the current line + if i in breakpoints: + stop_line = i + else: + stop_line = max(max([x for x in breakpoints if x < i], default=start_line), start_line) + + # If the stop line is the same as the start line, it means we haven't reached a breakpoint yet, and we need to move to the next line to find one + if stop_line == start_line and i not in breakpoints: + token_count += new_token_count + i += 1 + + # If the stop line is the same as the start line and the current line is a breakpoint, it means we can create a chunk with just the current line + elif stop_line == start_line and i == stop_line: + token_count += new_token_count + i += 1 + + # If the stop line is the same as the start line and the current line is a breakpoint, it means we can create a chunk with just the current line + elif stop_line == start_line and i in breakpoints: + current_chunk = "\n".join(lines[start_line:stop_line]) + if current_chunk.strip(): # If the current chunk is not just whitespace + chunks[chunk_number] = current_chunk # Using chunk_number as key + chunk_number += 1 + + token_count = 0 + start_line = i + i += 1 + + # If the stop line is different from the start line, it means we're at the end of a block + else: + current_chunk = "\n".join(lines[start_line:stop_line]) + if current_chunk.strip(): + chunks[chunk_number] = current_chunk # Using chunk_number as key + chunk_number += 1 + + i = stop_line + token_count = 0 + start_line = stop_line + else: + # If the token count is still within the limit, add the line to the current chunk + token_count += new_token_count + i += 1 + + # Append remaining code, if any, ensuring it's not empty or whitespace + current_chunk_code = "\n".join(lines[start_line:]) + if current_chunk_code.strip(): # Checks if the chunk is not just whitespace + chunks[chunk_number] = current_chunk_code # Using chunk_number as key + + return chunks + + def get_chunk(self, chunked_codebase, chunk_number): + return chunked_codebase[chunk_number] diff --git a/mle/utils/data.py b/mle/utils/data.py index 0ea9c14..4cb60b9 100644 --- a/mle/utils/data.py +++ b/mle/utils/data.py @@ -1,6 +1,34 @@ import re import os import json +from typing import Dict, Any + + +def dict_to_markdown(data: Dict[str, Any], file_path: str) -> None: + """ + Write a dictionary to a markdown file. + :param data: the dictionary to write. + :param file_path: the file path to write the dictionary to. + :return: + """ + + def write_item(k, v, indent_level=0): + if isinstance(v, dict): + md_file.write(f"{'##' * (indent_level + 1)} {k}\n") + for sub_key, sub_value in v.items(): + write_item(sub_key, sub_value, indent_level + 1) + elif isinstance(v, list): + md_file.write(f"{'##' * (indent_level + 1)} {k}\n") + for item in v: + md_file.write(f"{' ' * indent_level}- {item}\n") + else: + md_file.write(f"{'##' * (indent_level + 1)} {k}\n") + md_file.write(f"{' ' * indent_level}{v}\n") + + with open(file_path, 'w') as md_file: + for key, value in data.items(): + write_item(key, value) + md_file.write("\n") def is_markdown_file(file_path): diff --git a/mle/utils/memory.py b/mle/utils/memory.py index f1ae313..b536ba6 100644 --- a/mle/utils/memory.py +++ b/mle/utils/memory.py @@ -160,10 +160,7 @@ def reset(self): class LanceDBMemory: - def __init__( - self, - project_path: str, - ): + def __init__(self, project_path: str): """ Memory: A base class for memory and external knowledge management. Args: @@ -180,11 +177,11 @@ def __init__( raise NotImplementedError def add( - self, - texts: List[str], - metadata: Optional[List[Dict]] = None, - table_name: Optional[str] = None, - ids: Optional[List[str]] = None, + self, + texts: List[str], + metadata: Optional[List[Dict]] = None, + table_name: Optional[str] = None, + ids: Optional[List[str]] = None, ) -> List[str]: """ Adds a list of text items to the specified memory table in the database. @@ -200,12 +197,12 @@ def add( List[str]: A list of IDs associated with the added text items. """ if isinstance(texts, str): - texts = (texts, ) + texts = (texts,) if metadata is None: metadata = [None, ] * len(texts) elif isinstance(metadata, dict): - metadata = (metadata, ) + metadata = (metadata,) else: assert len(texts) == len(metadata) diff --git a/mle/utils/parser.py b/mle/utils/parser.py new file mode 100644 index 0000000..51df53e --- /dev/null +++ b/mle/utils/parser.py @@ -0,0 +1,351 @@ +# Source modified from https://github.com/CintraAI/code-chunker/blob/main/CodeParser.py +import os +import subprocess +from typing import Dict, Tuple, Union, List +from tree_sitter import Language, Parser, Node + +import warnings + +warnings.simplefilter(action='ignore', category=FutureWarning) + + +def return_simple_line_numbers_with_code(code: str) -> str: + code_lines = code.split('\n') + code_with_line_numbers = [f"Line {i + 1}: {line}" for i, line in enumerate(code_lines)] + joined_lines = "\n".join(code_with_line_numbers) + return joined_lines + + +class CodeParser: + def __init__(self, cache_dir: str, file_extensions: Union[None, List[str], str] = None): + """ + Initialize the code parser. + + :param file_extensions: + """ + if isinstance(file_extensions, str): + file_extensions = [file_extensions] + + if cache_dir is None: + cache_dir = os.path.join(os.getcwd(), '.mle', 'parsers') + + # make dir if not exists + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + self.cache_dir = cache_dir + self.language_extension_map = { + "py": "python", + "js": "javascript", + "jsx": "javascript", + "css": "css", + "ts": "typescript", + "tsx": "typescript", + "php": "php", + "rb": "ruby" + } + if file_extensions is None: + self.language_names = [] + else: + self.language_names = [self.language_extension_map.get(ext) for ext in file_extensions if + ext in self.language_extension_map] + self.languages = {} + self._install_parsers() + + def _install_parsers(self): + + try: + # Ensure cache directory exists + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir) + + for language in self.language_names: + repo_path = os.path.join(self.cache_dir, f"tree-sitter-{language}") + + # Check if the repository exists and contains necessary files + if not os.path.exists(repo_path) or not self._is_repo_valid(repo_path, language): + try: + if os.path.exists(repo_path): + update_command = f"cd {repo_path} && git pull" + subprocess.run(update_command, shell=True, check=True) + else: + clone_command = f"git clone https://github.com/tree-sitter/tree-sitter-{language} {repo_path}" + subprocess.run(clone_command, shell=True, check=True) + except subprocess.CalledProcessError as e: + print(f"Failed to clone/update repository for {language}. Error: {e}") + continue + + try: + build_path = os.path.join(self.cache_dir, f"build/{language}.so") + + # Special handling for TypeScript + if language == 'typescript': + ts_dir = os.path.join(repo_path, 'typescript') + tsx_dir = os.path.join(repo_path, 'tsx') + if os.path.exists(ts_dir) and os.path.exists(tsx_dir): + Language.build_library(build_path, [ts_dir, tsx_dir]) + else: + raise FileNotFoundError(f"TypeScript or TSX directory not found in {repo_path}") + if language == 'php': + php_dir = os.path.join(repo_path, 'php') + Language.build_library(build_path, [php_dir]) + else: + Language.build_library(build_path, [repo_path]) + + self.languages[language] = Language(build_path, language) + # logging.info(f"Successfully built and loaded {language} parser") + except Exception as e: + print(f"Failed to build or load language {language}. Error: {str(e)}") + + except Exception as e: + print(f"An unexpected error occurred during parser installation: {str(e)}") + + def _is_repo_valid(self, repo_path: str, language: str) -> bool: + """Check if the repository contains necessary files.""" + if language == 'typescript': + return (os.path.exists(os.path.join(repo_path, 'typescript', 'src', 'parser.c')) and + os.path.exists(os.path.join(repo_path, 'tsx', 'src', 'parser.c'))) + elif language == 'php': + return os.path.exists(os.path.join(repo_path, 'php', 'src', 'parser.c')) + else: + return os.path.exists(os.path.join(repo_path, 'src', 'parser.c')) + + def parse_code(self, code: str, file_extension: str) -> Union[None, Node]: + language_name = self.language_extension_map.get(file_extension) + if language_name is None: + print(f"Unsupported file type: {file_extension}") + return None + + language = self.languages.get(language_name) + if language is None: + print("Language parser not found") + return None + + parser = Parser() + parser.set_language(language) + tree = parser.parse(bytes(code, "utf8")) + + if tree is None: + print("Failed to parse the code") + return None + + return tree.root_node + + def extract_points_of_interest(self, node: Node, file_extension: str) -> List[Tuple[Node, str]]: + node_types_of_interest = self._get_node_types_of_interest(file_extension) + + points_of_interest = [] + if node.type in node_types_of_interest.keys(): + points_of_interest.append((node, node_types_of_interest[node.type])) + + for child in node.children: + points_of_interest.extend(self.extract_points_of_interest(child, file_extension)) + + return points_of_interest + + def _get_node_types_of_interest(self, file_extension: str) -> Dict[str, str]: + node_types = { + 'py': { + 'import_statement': 'Import', + 'export_statement': 'Export', + 'class_definition': 'Class', + 'function_definition': 'Function', + }, + 'css': { + 'tag_name': 'Tag', + '@media': 'Media Query', + }, + 'js': { + 'import_statement': 'Import', + 'export_statement': 'Export', + 'class_declaration': 'Class', + 'function_declaration': 'Function', + 'arrow_function': 'Arrow Function', + 'statement_block': 'Block', + }, + 'ts': { + 'import_statement': 'Import', + 'export_statement': 'Export', + 'class_declaration': 'Class', + 'function_declaration': 'Function', + 'arrow_function': 'Arrow Function', + 'statement_block': 'Block', + 'interface_declaration': 'Interface', + 'type_alias_declaration': 'Type Alias', + }, + 'php': { + 'namespace_definition': 'Namespace', + 'class_declaration': 'Class', + 'method_declaration': 'Method', + 'function_definition': 'Function', + 'interface_declaration': 'Interface', + 'trait_declaration': 'Trait', + }, + 'rb': { + 'class': 'Class', + 'method': 'Method', + 'module': 'Module', + 'singleton_class': 'Singleton Class', + 'begin': 'Begin Block', + } + } + + if file_extension in node_types.keys(): + return node_types[file_extension] + elif file_extension == "jsx": + return node_types["js"] + elif file_extension == "tsx": + return node_types["ts"] + else: + raise ValueError("Unsupported file type") + + def _get_nodes_for_comments(self, file_extension: str) -> Dict[str, str]: + node_types = { + 'py': { + 'comment': 'Comment', + 'decorator': 'Decorator', # Broadened category + }, + 'css': { + 'comment': 'Comment' + }, + 'js': { + 'comment': 'Comment', + 'decorator': 'Decorator', # Broadened category + }, + 'ts': { + 'comment': 'Comment', + 'decorator': 'Decorator', + }, + 'php': { + 'comment': 'Comment', + 'attribute': 'Attribute', + }, + 'rb': { + 'comment': 'Comment', + } + } + + if file_extension in node_types.keys(): + return node_types[file_extension] + elif file_extension == "jsx": + return node_types["js"] + else: + raise ValueError("Unsupported file type") + + def extract_comments(self, node: Node, file_extension: str) -> List[Tuple[Node, str]]: + node_types_of_interest = self._get_nodes_for_comments(file_extension) + + comments = [] + if node.type in node_types_of_interest: + comments.append((node, node_types_of_interest[node.type])) + + for child in node.children: + comments.extend(self.extract_comments(child, file_extension)) + + return comments + + def get_lines_for_points_of_interest(self, code: str, file_extension: str) -> List[int]: + language_name = self.language_extension_map.get(file_extension) + if language_name is None: + raise ValueError("Unsupported file type") + + language = self.languages.get(language_name) + if language is None: + raise ValueError("Language parser not found") + + parser = Parser() + parser.set_language(language) + + tree = parser.parse(bytes(code, "utf8")) + + root_node = tree.root_node + points_of_interest = self.extract_points_of_interest(root_node, file_extension) + + line_numbers_with_type_of_interest = {} + + for node, type_of_interest in points_of_interest: + start_line = node.start_point[0] + if type_of_interest not in line_numbers_with_type_of_interest: + line_numbers_with_type_of_interest[type_of_interest] = [] + + if start_line not in line_numbers_with_type_of_interest[type_of_interest]: + line_numbers_with_type_of_interest[type_of_interest].append(start_line) + + lines_of_interest = [] + for _, line_numbers in line_numbers_with_type_of_interest.items(): + lines_of_interest.extend(line_numbers) + + return lines_of_interest + + def get_lines_for_comments(self, code: str, file_extension: str) -> List[int]: + language_name = self.language_extension_map.get(file_extension) + if language_name is None: + raise ValueError("Unsupported file type") + + language = self.languages.get(language_name) + if language is None: + raise ValueError("Language parser not found") + + parser = Parser() + parser.set_language(language) + + tree = parser.parse(bytes(code, "utf8")) + + root_node = tree.root_node + comments = self.extract_comments(root_node, file_extension) + + line_numbers_with_comments = {} + + for node, type_of_interest in comments: + start_line = node.start_point[0] + if type_of_interest not in line_numbers_with_comments: + line_numbers_with_comments[type_of_interest] = [] + + if start_line not in line_numbers_with_comments[type_of_interest]: + line_numbers_with_comments[type_of_interest].append(start_line) + + lines_of_interest = [] + for _, line_numbers in line_numbers_with_comments.items(): + lines_of_interest.extend(line_numbers) + + return lines_of_interest + + def print_all_line_types(self, code: str, file_extension: str): + language_name = self.language_extension_map.get(file_extension) + if language_name is None: + print(f"Unsupported file type: {file_extension}") + return + + language = self.languages.get(language_name) + if language is None: + print("Language parser not found") + return + + parser = Parser() + parser.set_language(language) + tree = parser.parse(bytes(code, "utf8")) + + root_node = tree.root_node + line_to_node_type = self.map_line_to_node_type(root_node) + + code_lines = code.split('\n') + + for line_num, node_types in line_to_node_type.items(): + line_content = code_lines[line_num - 1] # Adjusting index for zero-based indexing + print(f"line {line_num}: {', '.join(node_types)} | Code: {line_content}") + + def map_line_to_node_type(self, node, line_to_node_type=None, depth=0): + if line_to_node_type is None: + line_to_node_type = {} + + start_line = node.start_point[0] + 1 # Tree-sitter lines are 0-indexed; converting to 1-indexed + + # Only add the node type if it's the start line of the node + if start_line not in line_to_node_type: + line_to_node_type[start_line] = [] + line_to_node_type[start_line].append(node.type) + + for child in node.children: + self.map_line_to_node_type(child, line_to_node_type, depth + 1) + + return line_to_node_type diff --git a/mle/utils/system.py b/mle/utils/system.py index bbaf8dd..535c6c3 100644 --- a/mle/utils/system.py +++ b/mle/utils/system.py @@ -4,41 +4,15 @@ import yaml import base64 import shutil +import fnmatch import requests import platform import subprocess import importlib.util -from typing import Dict, Any, Optional, Callable from rich.panel import Panel from rich.prompt import Prompt from rich.console import Console - - -def dict_to_markdown(data: Dict[str, Any], file_path: str) -> None: - """ - Write a dictionary to a markdown file. - :param data: the dictionary to write. - :param file_path: the file path to write the dictionary to. - :return: - """ - - def write_item(k, v, indent_level=0): - if isinstance(v, dict): - md_file.write(f"{'##' * (indent_level + 1)} {k}\n") - for sub_key, sub_value in v.items(): - write_item(sub_key, sub_value, indent_level + 1) - elif isinstance(v, list): - md_file.write(f"{'##' * (indent_level + 1)} {k}\n") - for item in v: - md_file.write(f"{' ' * indent_level}- {item}\n") - else: - md_file.write(f"{'##' * (indent_level + 1)} {k}\n") - md_file.write(f"{' ' * indent_level}{v}\n") - - with open(file_path, 'w') as md_file: - for key, value in data.items(): - write_item(key, value) - md_file.write("\n") +from typing import Dict, Any, Optional, Callable, List, Union def print_in_box(text: str, console: Optional[Console] = None, title: str = "", color: str = "white") -> None: @@ -175,6 +149,85 @@ def extract_file_name(text: str) -> Optional[str]: return None +def read_file(filepath: str, limit: Optional[int] = None) -> Optional[str]: + """ + Read and return file contents as string, with optional length limit. + + Args: + filepath (str): Path to the file to read + limit (Optional[int]): Maximum number of characters to read + + Returns: + Optional[str]: File contents or None if file is invalid + """ + if not os.path.isfile(filepath): + return None + + try: + with open(filepath, 'r', encoding='utf-8') as f: + if limit: + return f.read(limit) + return f.read() + except (IOError, UnicodeDecodeError): + return None + + +def list_files( + directory: str, + patterns: Optional[Union[str, List[str]]] = None, + include_hidden: bool = False +) -> List[str]: + """ + List all files in a directory, optionally filtered by wildcard patterns and visibility. + Files in hidden directories are considered hidden. + + Args: + directory (str): Path to the directory to search + patterns (Optional[Union[str, List[str]]]): Single pattern or list of wildcard patterns + include_hidden (bool): Whether to include hidden files and files in hidden directories + + Returns: + List[str]: List of absolute file paths + + Example: + list_files("/path/to/dir", ["*.txt", "*.pdf"], include_hidden=False) + """ + if not os.path.isdir(directory): + raise ValueError(f"Directory not found: {directory}") + + if isinstance(patterns, str): + patterns = [patterns] + + def is_hidden_path(path: str) -> bool: + """Check if the path or any of its parents are hidden (start with '.')""" + parts = os.path.abspath(path).split(os.sep) + # Skip the first empty part for absolute paths and drive letter for Windows + start_idx = 1 if parts[0] == '' else 0 + return any(part.startswith('.') for part in parts[start_idx:]) + + result = [] + + for root, _, files in os.walk(directory): + # Skip this directory if it's hidden and we're not including hidden files + if not include_hidden and is_hidden_path(root): + continue + + for file in files: + filepath = os.path.abspath(os.path.join(root, file)) + + # Skip hidden files if include_hidden is False + if not include_hidden and (file.startswith('.') or is_hidden_path(filepath)): + continue + + if patterns: + if any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns): + result.append(filepath) + else: + result.append(filepath) + + return result + + def list_dir_structure(start_path: str) -> str: """ List all files and directories under the given path. @@ -272,8 +325,6 @@ def get_user_id(): Get the unique user id of the current machine. """ system = platform.system() - username = None - hostname = None if system == "Windows": username = os.getenv('USERNAME', 'root') @@ -302,16 +353,18 @@ def get_session_id(): def get_langfuse_observer( - secret_key: Optional[str] = None, - public_key: Optional[str] = None, - user_id: Optional[str] = None, - session_id: Optional[str] = None, - host: Optional[str] = None, + secret_key: Optional[str] = None, + public_key: Optional[str] = None, + user_id: Optional[str] = None, + session_id: Optional[str] = None, + host: Optional[str] = None, ): """ Get the Langfuse observer. :param secret_key: Langfuse secret key. :param public_key: Langfuse public key. + :param user_id: Optional user id, defaulting to the unique user id of the current machine. + :param session_id: Optional session id, defaulting to the session id of the current process. :param host: Optional host address, defaulting to 'https://us.cloud.langfuse.com'. """ spec = importlib.util.find_spec("langfuse") @@ -366,6 +419,7 @@ def query(*args, **kwargs): session_id=session_id, ) return _fn(*args, **kwargs) + return query return _observe diff --git a/mle/workflow/__init__.py b/mle/workflow/__init__.py index 478a0b0..07c672d 100644 --- a/mle/workflow/__init__.py +++ b/mle/workflow/__init__.py @@ -1,5 +1,4 @@ +from .chat import chat from .baseline import baseline -from .report import report -from .report_local import report_local +from .report import report, report_local from .kaggle import kaggle, auto_kaggle -from .chat import chat diff --git a/mle/workflow/chat.py b/mle/workflow/chat.py index fb78043..291dc4f 100644 --- a/mle/workflow/chat.py +++ b/mle/workflow/chat.py @@ -1,7 +1,6 @@ """ Chat Mode: the mode to have an interactive chat with LLM to work on ML project. """ -import os import questionary from rich.live import Live from rich.panel import Panel @@ -12,11 +11,11 @@ from mle.agents import ChatAgent -def chat(work_dir: str, model=None): +def chat(work_dir: str, memory=None, model=None): console = Console() cache = WorkflowCache(work_dir, 'chat') model = load_model(work_dir, model) - chatbot = ChatAgent(model) + chatbot = ChatAgent(model, memory=memory) if not cache.is_empty(): if questionary.confirm(f"Would you like to continue the previous conversation?\n").ask(): diff --git a/mle/workflow/report.py b/mle/workflow/report.py index 5dff74e..11c932d 100644 --- a/mle/workflow/report.py +++ b/mle/workflow/report.py @@ -5,9 +5,9 @@ import pickle from rich.console import Console from mle.model import load_model -from mle.agents import GitHubSummaryAgent, ReportAgent from mle.utils.system import get_config, write_config, check_config from mle.integration import GoogleCalendarIntegration, github_login +from mle.agents import GitHubSummaryAgent, ReportAgent, GitSummaryAgent def ask_data(data_str: str): @@ -70,3 +70,40 @@ def report( github_summary = summarizer.summarize() return reporter.gen_report(github_summary, events, okr=okr_str) + + +def report_local( + work_dir: str, + git_path: str, + email: str, + okr_str: str = None, + start_date: str = None, + end_date: str = None, + model=None +): + """ + The workflow of the baseline mode. + :param work_dir: the working directory. + :param git_path: the path to the local Git repository. + :param email: the email address. + :param okr_str: the OKR string. + :param start_date: the start date. + :param end_date: the end date. + :param model: the model to use. + :return: + """ + + console = Console() + model = load_model(work_dir, model) + + events = None + + summarizer = GitSummaryAgent( + model, + git_path=git_path, + git_email=email, + ) + reporter = ReportAgent(model, console) + + git_summary = summarizer.summarize(start_date=start_date, end_date=end_date) + return reporter.gen_report(git_summary, events, okr=okr_str) diff --git a/mle/workflow/report_local.py b/mle/workflow/report_local.py deleted file mode 100644 index 772308b..0000000 --- a/mle/workflow/report_local.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -Local Report Mode: the mode to generate the AI report based on users' local github repo -""" -from rich.console import Console -from mle.model import load_model -from mle.agents import ReportAgent, GitSummaryAgent - - -def report_local( - work_dir: str, - git_path: str, - email: str, - okr_str: str = None, - start_date: str = None, - end_date: str = None, - model=None -): - """ - The workflow of the baseline mode. - :param work_dir: the working directory. - :param git_path: the path to the local Git repository. - :param email: the email address. - :param okr_str: the OKR string. - :param start_date: the start date. - :param end_date: the end date. - :param model: the model to use. - :return: - """ - - console = Console() - model = load_model(work_dir, model) - - events = None - - summarizer = GitSummaryAgent( - model, - git_path=git_path, - git_email=email, - ) - reporter = ReportAgent(model, console) - - git_summary = summarizer.summarize(start_date=start_date, end_date=end_date) - return reporter.gen_report(git_summary, events, okr=okr_str) diff --git a/requirements.txt b/requirements.txt index d0a5d1f..0a8bf0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ uvicorn requests chromadb GitPython +tree-sitter==0.21.3 onnxruntime questionary pandas~=2.2.2