Skip to content

Commit

Permalink
feat(#1): issue processor
Browse files Browse the repository at this point in the history
  • Loading branch information
williamfzc committed Dec 13, 2023
1 parent 48b5db9 commit 7821d31
Showing 1 changed file with 67 additions and 4 deletions.
71 changes: 67 additions & 4 deletions srctag/storage.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import re
import typing

import chromadb
Expand All @@ -9,7 +11,7 @@
from pydantic_settings import BaseSettings
from tqdm import tqdm

from srctag.model import FileContext, RuntimeContext
from srctag.model import FileContext, RuntimeContext, SrcTagException


class StorageDoc(BaseModel):
Expand All @@ -21,6 +23,10 @@ class StorageDoc(BaseModel):
class MetadataConstant(object):
    """Names shared by every record written to the storage collection.

    ``KEY_*`` members are keys of a record's metadata dictionary;
    ``DATA_TYPE_*`` members are the values stored under ``KEY_DATA_TYPE``.
    """

    # metadata keys
    KEY_SOURCE = "source"
    KEY_COMMIT_SHA = "commit_sha"
    KEY_DATA_TYPE = "data_type"

    # values for KEY_DATA_TYPE, one per kind of indexed document
    DATA_TYPE_COMMIT_MSG = "commit_msg"
    DATA_TYPE_ISSUE = "issue"


class StorageConfig(BaseSettings):
Expand All @@ -31,6 +37,13 @@ class StorageConfig(BaseSettings):
# Multi langs: paraphrase-multilingual-MiniLM-L12-v2
st_model_name: str = "paraphrase-MiniLM-L6-v2"

# regex used to find issue references in commit messages
# by default it follows the GitHub convention (e.g. "#123")
issue_regex: str = r"(#\d+)"
issue_mapping: typing.Dict[str, str] = dict()

data_types: typing.Set[str] = {MetadataConstant.DATA_TYPE_COMMIT_MSG, MetadataConstant.DATA_TYPE_ISSUE}


class Storage(object):
def __init__(self, config: StorageConfig = None):
Expand Down Expand Up @@ -60,9 +73,8 @@ def init_chroma(self):
metadata={"hnsw:space": "l2"}
)

def process_file_ctx(self, file: FileContext, collection: Collection):
def process_commit_msg(self, file: FileContext, collection: Collection):
""" can be overwritten for custom processing """

targets = []
for each in file.commits:
# keep enough data in metadata to calculate the final score
Expand All @@ -71,10 +83,51 @@ def process_file_ctx(self, file: FileContext, collection: Collection):
metadata={
MetadataConstant.KEY_SOURCE: file.name,
MetadataConstant.KEY_COMMIT_SHA: str(each.hexsha),
MetadataConstant.KEY_DATA_TYPE: MetadataConstant.DATA_TYPE_COMMIT_MSG,
},
id=f"{file.name}|{each.hexsha}|{MetadataConstant.DATA_TYPE_COMMIT_MSG}"
)
targets.append(item)

for each in targets:
collection.add(
documents=[each.document],
metadatas=[each.metadata],
ids=[each.id],
)

def process_issue_id_to_title(self, issue_id: str) -> str:
    """Return the title configured for ``issue_id``.

    Titles are read from ``self.config.issue_mapping`` instead of a remote
    issue tracker: calling a server API for every reference would reach
    the API limit easily, so a plain mapping keeps things simple.

    :param issue_id: issue reference, as matched in a commit message.
    :return: the mapped title, or ``""`` when the id is not in the mapping.
    """
    return self.config.issue_mapping.get(issue_id, "")

def process_issue(self, file: FileContext, collection: Collection):
    """Index the issue titles referenced by the commits of ``file``.

    Every commit message is scanned with ``self.config.issue_regex``; each
    match is translated to a title through ``process_issue_id_to_title``.
    Commits that reference no issue with a known title are skipped. For the
    remaining commits one document is added to ``collection``, tagged with
    the ``DATA_TYPE_ISSUE`` data type.

    :param file: file context whose ``commits`` are scanned.
    :param collection: collection that receives the documents.
    """
    regex = re.compile(self.config.issue_regex)

    targets = []
    for each in file.commits:
        issue_id_list = regex.findall(each.message)
        # keep only the issues that resolve to a non-empty title
        issue_contents = [
            title
            for title in map(self.process_issue_id_to_title, issue_id_list)
            if title
        ]

        if not issue_contents:
            continue
        item = StorageDoc(
            # NOTE(review): titles are joined with os.sep (the path
            # separator), so the stored text differs between platforms;
            # confirm whether os.linesep was intended.
            document=os.sep.join(issue_contents),
            metadata={
                MetadataConstant.KEY_SOURCE: file.name,
                MetadataConstant.KEY_COMMIT_SHA: str(each.hexsha),
                MetadataConstant.KEY_DATA_TYPE: MetadataConstant.DATA_TYPE_ISSUE,
            },
            # the data-type suffix keeps this id distinct from the id of
            # the commit-message document built for the same file/commit
            id=f"{file.name}|{each.hexsha}|{MetadataConstant.DATA_TYPE_ISSUE}"
        )
        targets.append(item)
    # END commit loop

    for each in targets:
        collection.add(
            documents=[each.document],
            metadatas=[each.metadata],
            ids=[each.id],
        )

def process_file_ctx(self, file: FileContext, collection: Collection):
    """Run every processor enabled in ``self.config.data_types`` on ``file``.

    :param file: file context handed to each processor.
    :param collection: collection handed to each processor.
    :raises SrcTagException: if a configured data type has no processor.
    """
    process_dict = {
        MetadataConstant.DATA_TYPE_ISSUE: self.process_issue,
        MetadataConstant.DATA_TYPE_COMMIT_MSG: self.process_commit_msg
    }
    data_types = self.config.data_types

    # Validate the whole configuration before dispatching anything, so an
    # unknown data type cannot raise after another processor has already
    # written documents to the collection.
    for each in data_types:
        if each not in process_dict:
            raise SrcTagException(f"invalid data type: {each}")

    for each in data_types:
        process_dict[each](file, collection)

def embed_file(self, file: FileContext):
if not file.commits:
logger.warning(f"no related commits found: {file.name}")
Expand Down

0 comments on commit 7821d31

Please sign in to comment.