From 4494e50ba37fd6b8ede48a752c53132ec56fe646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Albert=20=C3=96rwall?= Date: Wed, 14 Aug 2024 16:27:20 +0200 Subject: [PATCH] Improve code index performance --- moatless/codeblocks/codeblocks.py | 174 +++++++++-------------- moatless/codeblocks/module.py | 30 ++-- moatless/codeblocks/parser/parser.py | 41 ++++-- moatless/index/code_index.py | 5 +- moatless/index/code_node.py | 11 +- moatless/index/embed_model.py | 6 +- moatless/index/epic_split.py | 39 ++--- moatless/index/retry_voyage_embedding.py | 36 +++++ moatless/repository/file.py | 34 +++-- 9 files changed, 213 insertions(+), 163 deletions(-) create mode 100644 moatless/index/retry_voyage_embedding.py diff --git a/moatless/codeblocks/codeblocks.py b/moatless/codeblocks/codeblocks.py index e7a6bacf..51267394 100644 --- a/moatless/codeblocks/codeblocks.py +++ b/moatless/codeblocks/codeblocks.py @@ -1,8 +1,8 @@ import re +from dataclasses import dataclass, field from enum import Enum -from typing import Optional +from typing import Dict, List, Optional, Set -from pydantic import BaseModel, ConfigDict, Field, model_validator, field_validator from typing_extensions import deprecated from moatless.codeblocks.parser.comment import get_comment_symbol @@ -118,9 +118,10 @@ def from_string(cls, tag: str) -> Optional["CodeBlockType"]: @deprecated("Use BlockSpans to define code block visibility instead") -class PathTree(BaseModel): - show: bool = Field(default=False, description="Show the block and all sub blocks.") - tree: dict[str, "PathTree"] = Field(default_factory=dict) +@dataclass +class PathTree: + show: bool = False + tree: dict[str, "PathTree"] = field(default_factory=dict) @staticmethod def from_block_paths(block_paths: list[BlockPath]) -> "PathTree": @@ -189,31 +190,18 @@ class RelationshipType(str, Enum): DEPENDENCY = "dependency" TYPE = "type" +@dataclass +class Relationship: + scope: ReferenceScope + external_path: list[str] = field(default_factory=list) + resolved_path: list[str] = field(default_factory=list) + path: list[str] = field(default_factory=list) + type: RelationshipType = RelationshipType.USES + identifier: Optional[str] = None -class Relationship(BaseModel): - scope: ReferenceScope = Field(description="The scope of the reference.") - identifier: Optional[str] = Field(default=None, description="ID") - type: RelationshipType = Field( - default=RelationshipType.USES, description="The type of the reference." - ) - external_path: list[str] = Field( - default=[], description="The path to the referenced parent code block." - ) - resolved_path: list[str] = Field( - default=[], description="The path to the file with the referenced code block." - ) - path: list[str] = Field( - default=[], description="The path to the referenced code block." - ) - - @classmethod - @model_validator(mode="before") - def validate_path(cls, values): - external_path = values.get("external_path") - path = values.get("path") - if not external_path and not path: + def __post_init__(self): + if not self.external_path and not self.path: raise ValueError("Cannot create Reference without external_path or path.") - return values def __hash__(self): return hash((self.scope, tuple(self.path))) @@ -237,10 +225,10 @@ def __str__(self): return f"({start_node})-[:{self.type.name} {{scope: {self.scope.value}}}]->({end_node})" - -class Parameter(BaseModel): - identifier: str = Field(description="The identifier of the parameter.") - type: Optional[str] = Field(description="The type of the parameter.") +@dataclass +class Parameter: + identifier: str + type: Optional[str] = None class SpanType(str, Enum): @@ -248,43 +236,23 @@ class SpanType(str, Enum): DOCUMENTATION = "docs" IMPLEMENTATION = "impl" - -class BlockSpan(BaseModel): - span_id: str = Field() - span_type: SpanType = Field(description="Type of span.") - start_line: int = Field(description="Start line of the span.") - end_line: int = Field(description="End line of the span.") - - initiating_block: "CodeBlock" = Field( - default=None, - description="The block that initiated the span.", - ) +@dataclass +class BlockSpan: + span_id: str + span_type: SpanType + start_line: int + end_line: int + block_paths: list[BlockPath] = field(default_factory=list) + initiating_block: Optional['CodeBlock'] = None + visible: bool = True + index: int = 0 + parent_block_path: Optional[BlockPath] = None + is_partial: bool = False + tokens: int = 0 @property def block_type(self): - return self.initiating_block.type - - # TODO: Remove - visible: bool = Field(default=True, description="If the span should be visible.") - - index: int = 0 - - parent_block_path: BlockPath = Field( - default=None, - description="Path to the parent block of the span.", - ) - - is_partial: bool = Field( - default=False, - description="If the span is covering a partial part of the parent block.", - ) - - block_paths: list[BlockPath] = Field( - default=[], - description="Block paths that should be shown when the span is shown.", - ) - - tokens: int = Field(default=0, description="Number of tokens in the span.") + return self.initiating_block.type if self.initiating_block else None def __str__(self): return f"{self.span_id} ({self.span_type.value}, {self.tokens} tokens)" @@ -296,64 +264,60 @@ def get_first_child_block_path(self): return block_path -class ValidationError(BaseModel): +@dataclass +class ValidationError: error: str -class CodeBlock(BaseModel): - content: str +@dataclass(eq=False, repr=False, slots=True) +class CodeBlock: type: CodeBlockType + content: str identifier: Optional[str] = None - parameters: list[Parameter] = [] # TODO: Move to Function sub class - relationships: list[Relationship] = [] - span_ids: set[str] = set() - belongs_to_span: BlockSpan | None = None - content_lines: list[str] = [] + parameters: List['Parameter'] = field(default_factory=list) + relationships: List['Relationship'] = field(default_factory=list) + span_ids: Set[str] = field(default_factory=set) + belongs_to_span: Optional['BlockSpan'] = None start_line: int = 0 end_line: int = 0 - properties: dict = {} + properties: Dict = field(default_factory=dict) pre_code: str = "" pre_lines: int = 0 indentation: str = "" tokens: int = 0 - children: list["CodeBlock"] = [] - validation_errors: list[ValidationError] = [] - parent: Optional["CodeBlock"] = None - previous: Optional["CodeBlock"] = None - next: Optional["CodeBlock"] = None + children: List['CodeBlock'] = field(default_factory=list) + validation_errors: List['ValidationError'] = field(default_factory=list) + parent: Optional['CodeBlock'] = None + previous: Optional['CodeBlock'] = None + next: Optional['CodeBlock'] = None - model_config = ConfigDict(arbitrary_types_allowed=True) + _content_lines: Optional[List[str]] = field(default=None, init=False) - @classmethod - @field_validator("type", mode="before") - def validate_type(cls, v): - if v is None: - raise ValueError("Cannot create CodeBlock without type.") - return v - - def __init__(self, **data): - super().__init__(**data) - for child in self.children: - child.parent = self + def __post_init__(self): + self._content_lines = None + + if self.children: + for child in self.children: + child.parent = self + + if self.pre_code and not self.indentation and not self.pre_lines: + pre_code_lines = self.pre_code.split("\n") + self.pre_lines = len(pre_code_lines) - 1 + self.indentation = pre_code_lines[-1] if self.pre_lines > 0 else self.pre_code + + @property + def content_lines(self): + if self._content_lines is None: + self._content_lines = self.content.split("\n") + return self._content_lines + def validate_pre_code(self): if self.pre_code and not re.match(r"^[ \n\\]*$", self.pre_code): raise ValueError( f"Failed to parse code block with type {self.type} and content `{self.content}`. " f"Expected pre_code to only contain spaces and line breaks. Got `{self.pre_code}`" ) - if self.pre_code and not self.indentation and not self.pre_lines: - pre_code_lines = self.pre_code.split("\n") - self.pre_lines = len(pre_code_lines) - 1 - if self.pre_lines > 0: - self.indentation = pre_code_lines[-1] - else: - self.indentation = self.pre_code - - self.content_lines = self.content.split("\n") - # if self.indentation and self.pre_lines: - # self.content_lines[1:] = [line[len(self.indentation):] for line in self.content_lines[1:]] - def last(self): if self.next: return self.next.last() @@ -894,7 +858,7 @@ def full_path(self): def module(self) -> "Module": # noqa: F821 if self.parent: return self.parent.module - return self + return None @deprecated("Use codeblock.module") def root(self) -> "Module": # noqa: F821 diff --git a/moatless/codeblocks/module.py b/moatless/codeblocks/module.py index 4aef5661..7e90d148 100644 --- a/moatless/codeblocks/module.py +++ b/moatless/codeblocks/module.py @@ -1,10 +1,8 @@ import logging -from typing import Optional +from dataclasses import field, dataclass +from typing import Optional, Dict from networkx import DiGraph -from pydantic import ( - ConfigDict, -) from moatless.codeblocks import CodeBlock, CodeBlockType from moatless.codeblocks.codeblocks import BlockSpan, SpanType @@ -12,20 +10,26 @@ logger = logging.getLogger(__name__) +@dataclass class Module(CodeBlock): - model_config = ConfigDict(arbitrary_types_allowed=True) - file_path: Optional[str] = None - content: str = None - spans_by_id: dict[str, BlockSpan] = {} + content: str = "" + spans_by_id: Dict[str, BlockSpan] = field(default_factory=dict) language: Optional[str] = None - parent: CodeBlock | None = None + code_block: CodeBlock = field(default_factory=lambda: CodeBlock(content="", type=CodeBlockType.MODULE)) + _graph: DiGraph = field(default_factory=DiGraph, init=False) # TODO: Move to central CodeGraph + + def __post_init__(self): + if not self.code_block.type == CodeBlockType.MODULE: + self.code_block.type = CodeBlockType.MODULE - _graph: DiGraph = None # TODO: Move to central CodeGraph + # Delegate other methods to self.code_block as needed + def __getattr__(self, name): + return getattr(self.code_block, name) - def __init__(self, **data): - data.setdefault("type", CodeBlockType.MODULE) - super().__init__(**data) + @property + def module(self) -> "Module": # noqa: F821 + return self def find_span_by_id(self, span_id: str) -> BlockSpan | None: return self.spans_by_id.get(span_id) diff --git a/moatless/codeblocks/parser/parser.py b/moatless/codeblocks/parser/parser.py index e9f519a3..ee935587 100644 --- a/moatless/codeblocks/parser/parser.py +++ b/moatless/codeblocks/parser/parser.py @@ -75,6 +75,8 @@ def __init__( encoding: str = "utf8", max_tokens_in_span: int = 500, min_tokens_for_docs_span: int = 100, + min_lines_to_parse_block: Optional[int] = None, # If this is set code will just be parsed if they have more line than this + enable_code_graph: bool = True, index_callback: Callable[[CodeBlock], None] | None = None, tokenizer: Callable[[str], list] | None = None, apply_gpt_tweaks: bool = False, @@ -101,11 +103,13 @@ def __init__( self._previous_block = None # TODO: Move this to CodeGraph + self._enable_code_graph = enable_code_graph self._graph = None self.tokenizer = tokenizer or get_tokenizer() self._max_tokens_in_span = max_tokens_in_span self._min_tokens_for_docs_span = min_tokens_for_docs_span + self._min_lines_to_parse_block = min_lines_to_parse_block @property def language(self): @@ -163,6 +167,13 @@ def parse_code( pre_code = content_bytes[start_byte : node.start_byte].decode(self.encoding) end_line = node.end_point[0] + # Skip parsing of non structure blocks if they have less lines than min_lines_to_parse_implementation + # But still parse classes and modules + if (node_match.first_child and self._min_lines_to_parse_block + and node_match.block_type not in [CodeBlockType.MODULE, CodeBlockType.CLASS, CodeBlockType.TEST_SUITE] + and (node.end_point[0] - node.start_point[0]) < self._min_lines_to_parse_block): + node_match.first_child = None + if node_match.first_child: end_byte = self.get_previous(node_match.first_child, node) else: @@ -177,10 +188,14 @@ def parse_code( else: identifier = None - relationships = self.create_references( - code, content_bytes, identifier, node_match - ) - parameters = self.create_parameters(content_bytes, node_match, relationships) + if self._enable_code_graph: + relationships = self.create_references( + code, content_bytes, identifier, node_match + ) + parameters = self.create_parameters(content_bytes, node_match, relationships) + else: + relationships = [] + parameters = [] if parent_block: code_block = CodeBlock( @@ -195,7 +210,6 @@ def parse_code( end_line=end_line + 1, pre_code=pre_code, content=code, - language=self.language, tokens=self._count_tokens(code), children=[], properties={ @@ -262,12 +276,13 @@ def parse_code( self.comments_with_no_span = [] - self._graph.add_node(code_block.path_string(), block=code_block) + if self._enable_code_graph: + self._graph.add_node(code_block.path_string(), block=code_block) - for relationship in relationships: - self._graph.add_edge( - code_block.path_string(), ".".join(relationship.path) - ) + for relationship in relationships: + self._graph.add_edge( + code_block.path_string(), ".".join(relationship.path) + ) else: current_span = None @@ -428,10 +443,13 @@ def find_match_with_gpt_tweaks(self, node: Node) -> NodeMatch | None: def find_match(self, node: Node) -> NodeMatch | None: self.debug_log(f"find_match() node type {node.type}") + + queries = 0 for label, node_type, query in self.queries: if node_type and node.type != node_type and node_type != "_": continue match = self._find_match(node, query, label) + queries += 1 if match: self.debug_log( f"find_match() Found match on node {node.type} with query {label}" @@ -683,7 +701,8 @@ def parse(self, content, file_path: Optional[str] = None) -> Module: self._span_counter = {} # TODO: Should me moved to a central CodeGraph - self._graph = nx.DiGraph() + if self._enable_code_graph: + self._graph = nx.DiGraph() tree = self.tree_parser.parse(content_in_bytes) module, _, _ = self.parse_code( diff --git a/moatless/index/code_index.py b/moatless/index/code_index.py index 63afdd60..cc20cc0e 100644 --- a/moatless/index/code_index.py +++ b/moatless/index/code_index.py @@ -553,6 +553,7 @@ def _vector_search( category: str = "implementation", file_pattern: Optional[str] = None, exact_content_match: Optional[str] = None, + top_k: int = 500 ): if file_pattern: query += f" file:{file_pattern}" @@ -578,7 +579,7 @@ def _vector_search( query_bundle = VectorStoreQuery( query_str=query, query_embedding=query_embedding, - similarity_top_k=500, # TODO: Fix paging? + similarity_top_k=top_k, # TODO: Fix paging? filters=filters, ) @@ -763,7 +764,7 @@ def index_callback(codeblock: CodeBlock): ] ) logger.info( - f"Prepared {len(prepared_nodes)} nodes and {prepared_tokens} tokens" + f"Run embed pipeline with {len(prepared_nodes)} nodes and {prepared_tokens} tokens" ) embedded_nodes = embed_pipeline.run( diff --git a/moatless/index/code_node.py b/moatless/index/code_node.py index 129d09b6..00f59796 100644 --- a/moatless/index/code_node.py +++ b/moatless/index/code_node.py @@ -1,4 +1,5 @@ from hashlib import sha256 +import re from llama_index.core.schema import TextNode @@ -10,5 +11,13 @@ def hash(self): metadata = self.metadata.copy() metadata.pop("start_line", None) metadata.pop("end_line", None) - doc_identity = str(self.text) + str(metadata) + metadata.pop("tokens", None) + cleaned_text = self._clean_text(self.text) + doc_identity = cleaned_text + str(metadata) return str(sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest()) + + def _clean_text(self, text): + """ + Remove all whitespace and convert to lowercase to reduce the number of changes in hashes. + """ + return ''.join(text.split()).lower() \ No newline at end of file diff --git a/moatless/index/embed_model.py b/moatless/index/embed_model.py index 71783183..75417f4f 100644 --- a/moatless/index/embed_model.py +++ b/moatless/index/embed_model.py @@ -2,6 +2,8 @@ from llama_index.core.base.embeddings.base import BaseEmbedding +from moatless.index.retry_voyage_embedding import VoyageEmbeddingWithRetry + def get_embed_model(model_name: str) -> BaseEmbedding: if model_name.startswith("voyage"): @@ -17,11 +19,11 @@ def get_embed_model(model_name: str) -> BaseEmbedding: "VOYAGE_API_KEY environment variable is not set. Please set it to your Voyage API key." ) - return VoyageEmbedding( + return VoyageEmbeddingWithRetry( model_name=model_name, voyage_api_key=os.environ.get("VOYAGE_API_KEY"), truncation=True, - embed_batch_size=50, + embed_batch_size=80, ) else: # Assumes OpenAI otherwise diff --git a/moatless/index/epic_split.py b/moatless/index/epic_split.py index ce41d4c7..a912ed64 100644 --- a/moatless/index/epic_split.py +++ b/moatless/index/epic_split.py @@ -3,16 +3,15 @@ from collections.abc import Callable, Sequence from typing import Any, Optional -from llama_index.core.bridge.pydantic import Field +from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.callbacks import CallbackManager from llama_index.core.node_parser import NodeParser, TextSplitter, TokenTextSplitter from llama_index.core.node_parser.node_utils import logger from llama_index.core.schema import BaseNode, TextNode from llama_index.core.utils import get_tokenizer, get_tqdm_iterable -from moatless.codeblocks import create_parser +from moatless.codeblocks import create_parser, CodeParser from moatless.codeblocks.codeblocks import CodeBlock, CodeBlockType, PathTree -from moatless.codeblocks.parser.python import PythonParser from moatless.index.code_node import CodeNode from moatless.index.settings import CommentStrategy @@ -83,6 +82,7 @@ class EpicSplitter(NodeParser): default=None, description="Callback to call when indexing a code block." ) + _parser: CodeParser = PrivateAttr() # _fallback_code_splitter: Optional[TextSplitter] = PrivateAttr() TODO: Implement fallback when tree sitter fails def __init__( @@ -99,6 +99,7 @@ def __init__( index_callback: Optional[Callable[[CodeBlock], None]] = None, repo_path: Optional[str] = None, comment_strategy: CommentStrategy = CommentStrategy.ASSOCIATE, + min_lines_to_parse_block: int = 25, # fallback_code_splitter: Optional[TextSplitter] = None, include_non_code_files: bool = True, tokenizer: Optional[Callable] = None, @@ -109,9 +110,15 @@ def __init__( non_code_file_extensions = ["md", "txt"] callback_manager = callback_manager or CallbackManager([]) + self._parser = create_parser( + language=language, + index_callback=index_callback, + min_lines_to_parse_block=min_lines_to_parse_block, + enable_code_graph=False) # self._fallback_code_splitter = fallback_code_splitter super().__init__( + # _parser=parser, language=language, chunk_size=chunk_size, chunk_overlap=0, @@ -152,8 +159,7 @@ def _parse_nodes( starttime = time.time_ns() # TODO: Derive language from file extension - parser = create_parser(language=self.language, index_callback=self.index_callback) - codeblock = parser.parse(content, file_path=file_path) + codeblock = self._parser.parse(content, file_path=file_path) parse_time = time.time_ns() - starttime if parse_time > 1e9: @@ -252,14 +258,6 @@ def _chunk_block( comment_chunk.append(child) continue else: - if child.tokens > self.max_chunk_size: - start_content = child.content[:100] - logger.warning( - f"Skipping code block {child.path_string()} in {file_path} as it has {child.tokens} tokens which is" - f" more than chunk size {self.chunk_size}. Content: {start_content}..." - ) - continue - ignoring_comment = False if ( @@ -488,14 +486,23 @@ def _create_node( if block.belongs_to_span ] ) - metadata["span_ids"] = list(span_ids) + metadata["span_ids"] = list(sorted(span_ids)) node_id += f"_{chunk[0].path_string()}_{chunk[-1].path_string()}" content = content.strip("\n") - tokens = get_tokenizer()(content) - metadata["tokens"] = len(tokens) + tokens = count_chunk_tokens(chunk) + + # Truncate large chunks + if tokens > self.hard_token_limit: + content = content[:self.hard_token_limit] + logger.debug( + f"Truncating chunk {node_id} in {metadata['file_path']} as it has {tokens} tokens which is" + f" more than chunk size {self.chunk_size}." + ) + + metadata["tokens"] = tokens excluded_embed_metadata_keys = node.excluded_embed_metadata_keys.copy() excluded_embed_metadata_keys.extend(["start_line", "end_line", "tokens"]) diff --git a/moatless/index/retry_voyage_embedding.py b/moatless/index/retry_voyage_embedding.py new file mode 100644 index 00000000..2d5ca36f --- /dev/null +++ b/moatless/index/retry_voyage_embedding.py @@ -0,0 +1,36 @@ +import logging +from typing import List + +from llama_index.embeddings.voyageai import VoyageEmbedding +from tenacity import retry, wait_random_exponential, stop_after_attempt +from voyageai.error import InvalidRequestError + +logger = logging.getLogger(__name__) + +class VoyageEmbeddingWithRetry(VoyageEmbedding): + + @retry(wait=wait_random_exponential(multiplier=1, max=60), stop=stop_after_attempt(6)) + def _get_embedding(self, texts: List[str], input_type: str) -> List[List[float]]: + try: + return self._client.embed( + texts, + model=self.model_name, + input_type=input_type, + truncation=self.truncation, + ).embeddings + except InvalidRequestError as e: + if "Please lower the number of tokens in the batch" in str(e): + if len(texts) < 10: + raise # If batch size is already less than 10 we expect batchs to be abnormaly large and raise the error + + mid = len(texts) // 2 + first_half = texts[:mid] + second_half = texts[mid:] + + logger.info(f"Splitting batch of {len(texts)} texts into two halves of {len(first_half)} and {len(second_half)} texts.") + + embeddings_first = self._get_embedding(first_half, input_type) + embeddings_second = self._get_embedding(second_half, input_type) + + return embeddings_first + embeddings_second + raise diff --git a/moatless/repository/file.py b/moatless/repository/file.py index 5460281a..3db5a4b2 100644 --- a/moatless/repository/file.py +++ b/moatless/repository/file.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from moatless.codeblocks import get_parser_by_path from moatless.codeblocks.codeblocks import CodeBlockType, CodeBlockTypeGroup @@ -25,12 +25,10 @@ class UpdateResult: class CodeFile(BaseModel): - file_path: str - content: str - module: Module | None = None - dirty: bool = False - - model_config = ConfigDict(exclude={"module", "dirty"}) + file_path: str = Field(..., description="The path to the file") + content: str = Field(..., description="The content of the file") + _module: Module | None = PrivateAttr(None) + _dirty: bool = PrivateAttr(False) @classmethod def from_file(cls, repo_path: str, file_path: str): @@ -53,6 +51,16 @@ def from_content(cls, file_path: str, content: str): def supports_codeblocks(self): return self.module is not None + @property + def module(self) -> Module: + if not self._module: + return None + return self._module + + @property + def dirty(self) -> bool: + return self._dirty + def update_content_by_line_numbers( self, start_line_index: int, end_line_index: int, replacement_content: str ) -> UpdateResult: @@ -154,11 +162,11 @@ def update_content(self, updated_content: str) -> UpdateResult: logger.info( f"Updated content for {self.file_path} with {len(new_span_ids)} new span ids." ) - self.module = module + self._module = module else: new_span_ids = [] - self.dirty = True + self._dirty = True self.content = updated_content return UpdateResult( @@ -230,8 +238,8 @@ def get_file( self._files[file_path] = existing_file elif refresh or not from_origin: existing_file.content = found_file.content - existing_file.module = found_file.module - existing_file.dirty = False + existing_file._module = found_file.module + existing_file._dirty = False return existing_file @@ -242,11 +250,11 @@ def save_file(self, file_path: str, updated_content: Optional[str] = None): updated_content = updated_content or file.module.to_string() f.write(updated_content) - file.dirty = False + file._dirty = False def save(self): for file in self._files.values(): - if file.dirty: + if file._dirty: self.save_file(file.file_path, file.content) def matching_files(self, file_pattern: str):