Skip to content

Commit

Permalink
refactor: each commit -> each doc
Browse files Browse the repository at this point in the history
  • Loading branch information
williamfzc committed Nov 26, 2023
1 parent d20ed95 commit 194b0a5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 21 deletions.
47 changes: 35 additions & 12 deletions srctag/storage.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
import os
import typing

import chromadb
from chromadb import API
from chromadb.api.models.Collection import Collection
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from loguru import logger
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from tqdm import tqdm
from loguru import logger

from srctag.model import FileContext, RuntimeContext


class StorageDoc(BaseModel):
    """A single document destined for the vector-db collection.

    Built by ``Storage.process_file_ctx``: one instance per commit of a file.
    """

    # raw text to be embedded (a full commit message)
    document: str
    # string-only metadata; keys come from MetadataConstant
    # (source file name, commit timestamp as stringified epoch seconds)
    metadata: typing.Dict[str, str]
    # unique id per (file, commit) pair, formatted "<file name>|<commit hexsha>"
    id: str


class MetadataConstant:
    """Metadata key names attached to each stored document."""

    # name of the source file the commit touched
    KEY_SOURCE = "source"
    # commit timestamp (epoch seconds, stored as a string)
    KEY_COMMIT_TIME = "commit_time"


class StorageConfig(BaseSettings):
db_path: str = ""
collection_name: str = "default_collection"
Expand Down Expand Up @@ -47,24 +58,36 @@ def init_chroma(self):
),
)

def process_file_ctx_to_doc(self, file: FileContext) -> str:
def process_file_ctx(self, file: FileContext, collection: Collection):
""" can be overwritten for custom processing """
sentences = [each.message.split(os.linesep)[0] for each in file.commits]
doc = os.linesep.join(sentences)
return doc

targets = []
for each in file.commits:
# keep enough data in metadata for calc the final score
item = StorageDoc(
document=each.message,
metadata={
MetadataConstant.KEY_SOURCE: file.name,
MetadataConstant.KEY_COMMIT_TIME: str(int(each.committed_datetime.timestamp())),
},
id=f"{file.name}|{each.hexsha}"
)
targets.append(item)

for each in targets:
collection.add(
documents=[each.document],
metadatas=[each.metadata],
ids=[each.id],
)

def embed_file(self, file: FileContext):
if not file.commits:
logger.warning(f"no related commits found: {file.name}")
return

doc = self.process_file_ctx_to_doc(file)
self.init_chroma()
self.chromadb_collection.add(
documents=[doc],
metadatas=[{"source": file.name}],
ids=[file.name],
)
self.process_file_ctx(file, self.chromadb_collection)

def embed_ctx(self, ctx: RuntimeContext):
self.init_chroma()
Expand Down
35 changes: 26 additions & 9 deletions srctag/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tqdm import tqdm
from loguru import logger

from srctag.storage import Storage
from srctag.storage import Storage, MetadataConstant


class TagResult(object):
Expand Down Expand Up @@ -50,11 +50,12 @@ def __init__(self, config: TaggerConfig = None):

def tag(self, storage: Storage) -> TagResult:
storage.init_chroma()
file_count = storage.chromadb_collection.count()
n_results = int(file_count * self.config.n_percent)
doc_count = storage.chromadb_collection.count()
n_results = int(doc_count * self.config.n_percent)

logger.info(f"start tagging source files ...")
ret = dict()

tag_results = []
for each_tag in tqdm(self.config.tags):
query_result: QueryResult = storage.chromadb_collection.query(
query_texts=each_tag,
Expand All @@ -76,13 +77,29 @@ def tag(self, storage: Storage) -> TagResult:
]

for each_metadata, each_score in zip(metadatas, normalized_scores):
each_file_name = each_metadata["source"]
if each_file_name not in ret:
ret[each_file_name] = OrderedDict()
ret[each_file_name][each_tag] = each_score
each_file_name = each_metadata[MetadataConstant.KEY_SOURCE]
tag_results.append((each_tag, each_file_name, each_score))
# END file loop
# END tag loop

logger.info(f"tag finished")
ret = dict()
for each_tag, each_file_name, each_score in tag_results:
if each_file_name not in ret:
# has not been touched by other tags
# the score order is decreasing
ret[each_file_name] = OrderedDict()
each_file_tag_result = ret[each_file_name]

if each_tag not in each_file_tag_result:
each_file_tag_result[each_tag] = each_score
else:
# has been touched by other commits
# merge these scores
each_file_tag_result[each_tag] += each_score
# END tag_results

scores_df = pd.DataFrame.from_dict(ret, orient="index")
# tag level normalization after merge
scores_df = scores_df.apply(lambda x: (x - x.min()) / (x.max() - x.min()), axis=0)
logger.info(f"tag finished")
return TagResult(scores_df=scores_df)

0 comments on commit 194b0a5

Please sign in to comment.