Skip to content

Commit

Permalink
refactor(#2): collect issues in collector
Browse files Browse the repository at this point in the history
not in storage
  • Loading branch information
williamfzc committed Dec 18, 2023
1 parent 8df0588 commit a062f5e
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 40 deletions.
25 changes: 24 additions & 1 deletion srctag/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import git
from git import Repo, Commit
from loguru import logger

from pydantic_settings import BaseSettings
from tqdm import tqdm

from srctag.model import FileContext, RuntimeContext, SrcTagException
from srctag.storage import MetadataConstant


class FileLevelEnum(str, Enum):
Expand Down Expand Up @@ -41,6 +41,10 @@ class CollectorConfig(BaseSettings):
# BFS: walk the commits and get each diff files
scan_rule: ScanRuleEnum = ScanRuleEnum.DFS

# issue regex for matching issue grammar
# by default, we use GitHub standard
issue_regex: str = r"(#\d+)"


class Collector(object):
def __init__(self, config: CollectorConfig = None):
Expand All @@ -62,9 +66,28 @@ def collect_metadata(self) -> RuntimeContext:
else:
self._collect_histories_globally(ctx)

# issue processing and network building
self._process_relations(ctx)

logger.info("metadata ready")
return ctx

def _process_relations(self, ctx: RuntimeContext):
regex = re.compile(self.config.issue_regex)

for each_file in ctx.files.values():
for each_commit in each_file.commits:
issue_id_list = regex.findall(each_commit.message)

ctx.relations.add_node(each_file, node_type=MetadataConstant.KEY_SOURCE)
for each_issue in issue_id_list:
ctx.relations.add_node(each_issue, node_type=MetadataConstant.KEY_ISSUE_ID)
ctx.relations.add_edge(each_issue, each_file)
# todo: missing the commit nodes
# END loop issue
# END loop commit
# END loop file

def _check_env(self) -> typing.Optional[BaseException]:
try:
repo = git.Repo(self.config.repo_root, search_parent_directories=True)
Expand Down
3 changes: 3 additions & 0 deletions srctag/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing

import networkx
from git import Commit


Expand All @@ -10,8 +11,10 @@ def __init__(self, name: str):


class RuntimeContext(object):
""" shared data between components """
def __init__(self):
self.files: typing.Dict[str, FileContext] = dict()
self.relations = networkx.Graph()


class SrcTagException(BaseException):
Expand Down
59 changes: 26 additions & 33 deletions srctag/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class MetadataConstant(object):
KEY_ISSUE_ID = "issue_id"
KEY_TAG = "tag"

# use in chroma
DATA_TYPE_COMMIT_MSG = "commit_msg"
DATA_TYPE_ISSUE = "issue"

Expand All @@ -41,9 +42,6 @@ class StorageConfig(BaseSettings):
# Multi langs: paraphrase-multilingual-MiniLM-L12-v2
st_model_name: str = "paraphrase-MiniLM-L6-v2"

# issue regex for matching issue grammar
# by default, we use GitHub standard
issue_regex: str = r"(#\d+)"
# content mapping for avoiding too much I/O
# "#11" -> "content for #11"
issue_mapping: typing.Dict[str, str] = dict()
Expand All @@ -69,7 +67,7 @@ def __init__(self, config: StorageConfig = None):

self.chromadb: typing.Optional[API] = None
self.chromadb_collection: typing.Optional[Collection] = None
self.relation_graph: Graph = nx.Graph()
self.relations: Graph = nx.Graph()

def init_chroma(self):
if self.chromadb and self.chromadb_collection:
Expand All @@ -90,7 +88,7 @@ def init_chroma(self):
metadata={"hnsw:space": "l2"}
)

def process_commit_msg(self, file: FileContext, collection: Collection):
def process_commit_msg(self, file: FileContext, collection: Collection, _: RuntimeContext):
""" can be overwritten for custom processing """
targets = []
for each in file.commits:
Expand Down Expand Up @@ -118,31 +116,25 @@ def process_issue_id_to_title(self, issue_id: str) -> str:
# so we use issue_mapping, keep it simple
return self.config.issue_mapping.get(issue_id, "")

def process_issue(self, file: FileContext, collection: Collection):
regex = re.compile(self.config.issue_regex)
def process_issue(self, _: FileContext, collection: Collection, ctx: RuntimeContext):
issue_id_list = [x for x, y in ctx.relations.nodes(data=True) if
y["node_type"] == MetadataConstant.KEY_ISSUE_ID]

targets = []
for each in file.commits:
issue_id_list = regex.findall(each.message)
for each_issue_id in issue_id_list:
each_issue_content = self.process_issue_id_to_title(each_issue_id)
if not each_issue_content:
continue

item = StorageDoc(
document=each_issue_content,
metadata={
MetadataConstant.KEY_ISSUE_ID: each_issue_id,
MetadataConstant.KEY_DATA_TYPE: MetadataConstant.DATA_TYPE_ISSUE,
},
id=f"{MetadataConstant.DATA_TYPE_ISSUE}|{each_issue_id}"
)
targets.append(item)

# save to graph
self.relation_graph.add_node(each_issue_id, node_type=MetadataConstant.KEY_ISSUE_ID)
self.relation_graph.add_node(file.name, node_type=MetadataConstant.KEY_SOURCE)
self.relation_graph.add_edge(each_issue_id, file.name)
for each_issue_id in issue_id_list:
each_issue_content = self.process_issue_id_to_title(each_issue_id)
if not each_issue_content:
continue

item = StorageDoc(
document=each_issue_content,
metadata={
MetadataConstant.KEY_ISSUE_ID: each_issue_id,
MetadataConstant.KEY_DATA_TYPE: MetadataConstant.DATA_TYPE_ISSUE,
},
id=f"{MetadataConstant.DATA_TYPE_ISSUE}|{each_issue_id}"
)
targets.append(item)

# END issue loop
# END commit loop
Expand All @@ -154,27 +146,28 @@ def process_issue(self, file: FileContext, collection: Collection):
ids=[each.id],
)

def process_file_ctx(self, file: FileContext, collection: Collection):
def process_file_ctx(self, file: FileContext, collection: Collection, ctx: RuntimeContext):
process_dict = {
MetadataConstant.DATA_TYPE_ISSUE: self.process_issue,
MetadataConstant.DATA_TYPE_COMMIT_MSG: self.process_commit_msg
}
for each in self.config.data_types:
if each not in process_dict:
raise SrcTagException(f"invalid data type: {each}")
process_dict[each](file, collection)
process_dict[each](file, collection, ctx)

def embed_file(self, file: FileContext):
def embed_file(self, file: FileContext, ctx: RuntimeContext):
if not file.commits:
logger.warning(f"no related commits found: {file.name}")
return

self.init_chroma()
self.process_file_ctx(file, self.chromadb_collection)
self.process_file_ctx(file, self.chromadb_collection, ctx)

def embed_ctx(self, ctx: RuntimeContext):
self.init_chroma()
self.relations = ctx.relations
logger.info("start embedding source files")
for each_file in tqdm(ctx.files.values()):
self.embed_file(each_file)
self.embed_file(each_file, ctx)
logger.info("embedding finished")
12 changes: 6 additions & 6 deletions srctag/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def tag_with_commit(self, storage: Storage) -> TagResult:
n_results = int(doc_count * self.config.n_percent)

tag_results = []
relation_graph = storage.relation_graph.copy()
relation_graph = storage.relations.copy()
for each_tag in tqdm(self.config.tags):
query_result: QueryResult = storage.chromadb_collection.query(
query_texts=each_tag,
Expand Down Expand Up @@ -156,7 +156,7 @@ def tag_with_commit(self, storage: Storage) -> TagResult:

logger.info(f"tag finished")
# update relation graph in storage
storage.relation_graph = relation_graph
storage.relations = relation_graph

return TagResult(scores_df=scores_df)

Expand All @@ -165,7 +165,7 @@ def tag_with_issue(self, storage: Storage) -> TagResult:
n_results = int(doc_count * self.config.n_percent)

tag_results = []
relation_graph = storage.relation_graph.copy()
relation_graph = storage.relations.copy()
for each_tag in tqdm(self.config.tags):
query_result: QueryResult = storage.chromadb_collection.query(
query_texts=each_tag,
Expand All @@ -190,7 +190,7 @@ def tag_with_issue(self, storage: Storage) -> TagResult:

ret = dict()
for each_tag, each_issue_id, each_score in tag_results:
files = storage.relation_graph.neighbors(each_issue_id)
files = storage.relations.neighbors(each_issue_id)
for each_file in files:
if each_file not in ret:
# has not been touched by other tags
Expand Down Expand Up @@ -223,14 +223,14 @@ def tag_with_issue(self, storage: Storage) -> TagResult:

logger.info(f"tag finished")
# update relation graph in storage
storage.relation_graph = relation_graph
storage.relations = relation_graph
return TagResult(scores_df=scores_df)

def tag(self, storage: Storage) -> TagResult:
logger.info(f"start tagging source files ...")
storage.init_chroma()

if storage.relation_graph.number_of_nodes():
if storage.config.issue_mapping:
logger.info("tag with issue")
return self.tag_with_issue(storage)
else:
Expand Down

0 comments on commit a062f5e

Please sign in to comment.