Skip to content

Commit

Permalink
Add GeneralReports.generate_reports() function
Browse files Browse the repository at this point in the history
Signed-off-by: Konstantin Slavnov <[email protected]>
  • Loading branch information
zurk committed Mar 7, 2019
1 parent ab29f54 commit 07b2cc4
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 129 deletions.
92 changes: 53 additions & 39 deletions lookout/style/format/benchmarks/general_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import json
import logging
import os
from typing import Any, Dict, Iterable, List, Mapping, NamedTuple, Optional, Sequence, Type, Union
from typing import Any, Dict, Iterable, List, Mapping, NamedTuple, Optional, Sequence, Tuple, \
Type, Union

from bblfsh import BblfshClient
from lookout.core.analyzer import ReferencePointer
Expand Down Expand Up @@ -117,7 +118,7 @@ def get_data(self):

def analyze_files(analyzer_type: Type[FormatAnalyzer], config: dict, model_path: str,
language: str, bblfsh: str, input_pattern: str, log: logging.Logger,
) -> List[FileFix]:
) -> List[Comment]:
"""Run the model, record the fixes for each file and return them."""
class FakePointer:
def to_pb(self):
Expand Down Expand Up @@ -216,28 +217,21 @@ class ReportAnalyzer(FormatAnalyzerSpy):
default_config = merge_dicts(FormatAnalyzer.default_config,
{"aggregate": False})

def generate_train_report(self, fixes: Iterable[FileFix]) -> str:
@classmethod
def get_report_names(cls) -> Tuple[str, ...]:
"""
Generate report on the train dataset.
Get all available report names.
:param fixes: fixes with all required information for report generation.
:return: Report.
:return: List of report names.
"""
raise NotImplementedError()

def generate_model_report(self) -> str:
def generate_reports(self, fixes: Iterable[FileFix]) -> Dict[str, str]:
"""
Generate report about the trained model.
General function to generate reports.
:return: Report.
"""
return ""

def generate_test_report(self) -> str:
"""
Generate report on the test dataset.
:return: Report.
:param fixes: List of fixes per file or for all files if config["aggregate"] is True.
:return: Dictionary with report names as keys and report string as values.
"""
raise NotImplementedError()

Expand All @@ -254,29 +248,21 @@ def analyze(self, ptr_from: ReferencePointer, ptr_to: ReferencePointer,
:param data: Contains "files" - the list of changes in the pointed state.
:return: List of comments.
"""
def convert_fixes_to_report_comments(fixes: List[FileFix], filepath: str):
for report in self.generate_reports(fixes=fixes).values():
yield generate_comment(filename=filepath, line=0, confidence=100, text=report)

comments = []
fixes = []
for fix in self.run(ptr_from, data_service):
filepath = fix.head_file.path
if fix.error:
continue
if self.config["aggregate"]:
fixes.append(fix)
else:
report = self.generate_train_report(fixes=[fix])
comments.append(generate_comment(
filename=filepath, line=0, confidence=100, text=report))
if self.config["aggregate"]:
report = self.generate_train_report(fixes=fixes)
comments.append(generate_comment(
filename="", line=0, confidence=100, text=report))
comments.append(generate_comment(
filename="", line=0, confidence=100, text=self.generate_model_report()))
try:
comments.append(generate_comment(
filename="", line=0, confidence=100, text=self.generate_test_report()))
except ValueError:
pass
if not self.config["aggregate"]:
for fix in self.run(ptr_from, data_service):
filepath = fix.head_file.path
if fix.error:
continue
comments.extend(convert_fixes_to_report_comments([fix], filepath))
else:
comments.extend(
convert_fixes_to_report_comments(
[fix for fix in self.run(ptr_from, data_service) if not fix.error], ""))
return comments


Expand Down Expand Up @@ -316,6 +302,34 @@ class QualityReportAnalyzer(ReportAnalyzer):
"train": {"language_defaults": {"test_dataset_ratio": 0.2}},
})

@classmethod
def get_report_names(cls) -> Tuple[str, str, str]:
"""
Get all available report names.
:return: Tuple with report names.
"""
return "model", "train", "test"

def generate_reports(self, fixes: Iterable[FileFix]) -> Dict[str, str]:
"""
Generate model train and test reports.
Model report generated only if config["aggregate"] is True.
:param fixes: List of fixes per file or for all files if config["aggregate"] is True.
:return: Ordered dictionary with report names as keys and report string as values.
"""
reports = OrderedDict() # to keep reports order.
if self.config["aggregate"]:
reports["model"] = self.generate_model_report()
try:
reports["train"] = self.generate_train_report(fixes)
except ValueError as e:
self._log.warning("Train report generation failed. %s", e.args[0])
reports["test"] = self.generate_test_report()
return reports

def generate_model_report(self) -> str:
"""
Generate report about the trained model.
Expand Down
121 changes: 40 additions & 81 deletions lookout/style/format/benchmarks/quality_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import subprocess
import sys
import tempfile
from typing import Iterable, Iterator, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Dict, Iterable, Iterator, Mapping, NamedTuple, Optional, Sequence, Union

from dulwich import porcelain
from lookout.core.helpers.analyzer_context_manager import AnalyzerContextManager
Expand Down Expand Up @@ -55,32 +55,13 @@ def ensure_repo(repository: str, storage_dir: str) -> str:
return git_dir


class QualityReport:
"""
Storage for reports generated by QualityReportAnalyzer.
"""

def __init__(self, train_report: Optional[str] = None, model_report: Optional[str] = None,
test_report: Optional[str] = None) -> None:
"""
Init method.
:param train_report: Train report string generated by generate_train_report.
:param model_report: Model report string generated by generate_model_report.
:param test_report: Test report string generated by generate_test_report.
"""
self.train_report = train_report
self.model_report = model_report
self.test_report = test_report


class RestartReport(ValueError):
"""Exception raises if report collection should be restarted."""


def measure_quality(repository: str, from_commit: str, to_commit: str,
context: AnalyzerContextManager, config: dict, bblfsh: Optional[str],
vnodes_expected_number: Optional[int], restarts: int=3) -> QualityReport:
vnodes_expected_number: Optional[int], restarts: int=3) -> Dict[str, str]:
"""
Generate `QualityReport` for a repository. If it fails it returns empty reports.
Expand All @@ -94,9 +75,8 @@ def measure_quality(repository: str, from_commit: str, to_commit: str,
report collection will be restarted if number of extracted \
vnodes does not match.
:param restarts: Number of restarts if number of extracted vnodes does not match.
:return: Reports.
:return: Dictionary with all QualityReport reports.
"""
report = QualityReport()
log = logging.getLogger("QualityAnalyzer")

# This dirty hack should be removed as soon as
Expand All @@ -119,25 +99,23 @@ def _convert_files_to_xy(self, parsed_files):
log.info("VNodes number match to expected %d. ", vnodes_expected_number)
return _convert_files_to_xy_backup(self, parsed_files)

def capture_report(func, name):
reports = {}

def capture_reports(func):
@functools.wraps(func)
def wrapped_capture_quality_report(*args, **kwargs):
if getattr(report, name) is not None:
raise RuntimeError("%s should be called only one time." % func.__name__)
def wrapped_capture_quality_reports(*args, **kwargs):
nonlocal reports
if reports:
raise RuntimeError("generate_reports should be called only one time.")
result = func(*args, **kwargs)
setattr(report, name, result)
reports = result
return result
wrapped_capture_quality_report.original = func
return wrapped_capture_quality_report
reports = {
"model_report": "generate_model_report",
"train_report": "generate_train_report",
"test_report": "generate_test_report",
}
wrapped_capture_quality_reports.original = func
return wrapped_capture_quality_reports

try:
for name in reports:
setattr(QualityReportAnalyzer, reports[name],
capture_report(getattr(QualityReportAnalyzer, reports[name]), name))
QualityReportAnalyzer.generate_reports = \
capture_reports(QualityReportAnalyzer.generate_reports)
if vnodes_expected_number:
log.info("Vnodes expected number is equal to %d", vnodes_expected_number)
FeatureExtractor._convert_files_to_xy = _convert_files_to_xy
Expand All @@ -158,12 +136,10 @@ def wrapped_capture_quality_report(*args, **kwargs):
context.review(fr=from_commit, to=to_commit, git_dir=git_dir, log_level="warning",
bblfsh=bblfsh, config_json=config)
finally:
for name in reports:
setattr(QualityReportAnalyzer, reports[name],
getattr(QualityReportAnalyzer, reports[name]).original)
QualityReportAnalyzer.generate_reports = QualityReportAnalyzer.generate_reports.original
if vnodes_expected_number:
FeatureExtractor._convert_files_to_xy = _convert_files_to_xy_backup
return report
return reports


def calc_weighted_avg(arr: Union[Sequence[Sequence], numpy.ndarray], col: int,
Expand Down Expand Up @@ -250,21 +226,20 @@ def handle_input_arg(input_arg: str, log: Optional[logging.Logger] = None) -> It
yield line


def _generate_report_summary(reports: Iterable[Tuple[str, QualityReport]], report_name: str,
) -> str:
def _generate_report_summary(reports: Iterable[Mapping[str, str]], report_name: str) -> str:
# precision, recall, f1, support, n_rules, avg_len stats
additional_fields = ("Rules Number", "Average Rule Len")
table = []
fields2id = OrderedDict()
for repo, report in reports:
metrics = _get_metrics(getattr(report, report_name))
metrics = _get_metrics(report[report_name])
if not table:
table.append(("repo",) + metrics._fields + additional_fields)
for i, field in enumerate(table[0]):
fields2id[field] = i
n_rules, avg_len = _get_model_summary(report.model_report)
n_rules, avg_len = _get_model_summary(report["model"])
table.append((get_repo_name(repo),) + metrics + (n_rules, avg_len))
avgvals = tuple(calc_avg(table[1:], fields2id[field]) for field in metrics._fields)
avgvals = tuple(calc_avg(table[1:], fields2id[field]) for field in Metrics._fields)
average = tuple(("%" + FLOAT_PRECISION) % v for v in avgvals[:-2])
average += tuple("%d" % v for v in avgvals[-2:]) # support, full_support
average += tuple(("%d", "%.1f")[i] % calc_avg(table[1:], fields2id[field])
Expand Down Expand Up @@ -346,53 +321,37 @@ def generate_quality_report(input: str, output: str, force: bool, bblfsh: str, c
now + left if left is not None else None, " " * 11,
"=" * 80,
)
report_loc = os.path.join(output, get_repo_name(row["url"]))
train_rep_loc = report_loc + ".train_report.md"
model_rep_loc = report_loc + ".model_report.md"
test_rep_loc = report_loc + ".test_report.md"
# generate or read report
path_tmpl = os.path.join(output, get_repo_name(row["url"])) + ".%s_report.md"
try:
if force or not os.path.exists(train_rep_loc) or \
not os.path.exists(model_rep_loc):
# Skip this step if report was already generated
if force or not any(os.path.exists(path_tmpl % name)
for name in QualityReportAnalyzer.get_report_names()):
vnodes_expected_number = int(row["vnodes_number"]) \
if "vnodes_number" in row else None
report = measure_quality(
row["url"], to_commit=row["to"], from_commit=row["from"],
context=context, config=config, bblfsh=bblfsh,
vnodes_expected_number=vnodes_expected_number)
if report.train_report is not None:
with open(train_rep_loc, "w", encoding="utf-8") as f:
f.write(report.train_report)
if report.model_report is not None:
with open(model_rep_loc, "w", encoding="utf-8") as f:
f.write(report.model_report)
if report.test_report is not None:
with open(test_rep_loc, "w", encoding="utf-8") as f:
f.write(report.test_report)
for report_name in report:
with open(path_tmpl % report_name, "w", encoding="utf-8") as f:
f.write(report[report_name])
else:
report = {}
log.info("Found existing reports for %s in %s", row["url"], output)
report = QualityReport()
with open(train_rep_loc, encoding="utf-8") as f:
report.train_report = f.read()
with open(model_rep_loc, encoding="utf-8") as f:
report.model_report = f.read()
with open(test_rep_loc, encoding="utf-8") as f:
report.test_report = f.read()
if (report.train_report is not None and
report.model_report is not None and
report.test_report is not None):
reports.append((row["url"], report))
else:
log.warning("skipped %s: train_report %s, model_report %s, test_report %s",
row["url"], report.train_report is not None,
report.model_report is not None,
report.test_report is not None)
for report_name in QualityReportAnalyzer.get_report_names():
report_path = path_tmpl % report_name
if not os.path.exists(report_path):
log.warning(
"skipped %s. %s report is missing", row["url"], report_name)
break
with open(path_tmpl % report_name, encoding="utf-8") as f:
report[report_name] = f.read()
else:
reports.append((row["url"], report))
except Exception:
log.exception("-" * 20 + "\nFailed to process %s repo", row["url"])
continue

for report_name in ("train_report", "test_report"):
for report_name in ("train", "test"):
summary = _generate_report_summary(reports, report_name)
log.info("\n%s\n%s", report_name, summary)
summary_loc = os.path.join(output, "summary-%s.md" % report_name)
Expand Down
18 changes: 9 additions & 9 deletions lookout/style/format/tests/test_quality_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def __exit__(self, *args):
class QualityReportTests(PretrainedModelTests):
def test_eval_empty_input(self):
"""Test on empty folder - expect only model and test report."""
config = {"analyze": {"language_defaults": {"uast_break_check": False}}}
config = {"analyze": {"language_defaults": {"uast_break_check": False}},
"aggregate": True}
with tempfile.TemporaryDirectory() as folder:
input_pattern = os.path.join(folder, "**", "*")
with Capturing() as output:
Expand Down Expand Up @@ -113,23 +114,22 @@ def test_eval_empty_input(self):

def test_eval(self):
"""Test on normal input."""
q_report_header = "# Train report for javascript"
q_report_header_train = "# Train report for javascript"
q_report_header_test = "# Test report for javascript"
input_pattern = os.path.join(self.jquery_dir, "**", "*")
with Capturing() as output:
print_reports(input_pattern=input_pattern, bblfsh=self.bblfsh,
language=self.language, model_path=self.model_path,
config={"analyze": {"language_defaults": {"uast_break_check": False}}})
self.assertIn(q_report_header, output[0])
self.assertIn("### Summary", output)
self.assertIn(q_report_header_train, output[0])
self.assertIn("### Classification report", output)
self.assertGreater(len(output), 100)
output = "\n".join(output)
test_report_start = output.find("Test report")
self.assertNotEqual(test_report_start, -1)
output = output[:test_report_start]
self.assertIn("javascript", _get_json_data(output))
self.assertIn("# Model report", output)
qcount = output.count(q_report_header)
qcount = output.count(q_report_header_train)
self.assertEqual(qcount, 14)
qcount = output.count(q_report_header_test)
self.assertEqual(qcount, 14)

def test_eval_aggregate(self):
Expand All @@ -145,7 +145,7 @@ def test_eval_aggregate(self):
output = "\n".join(output)
qcount = output.count(q_report_header)
self.assertEqual(qcount, 1)
output = output[:output.find("# Model report for")]
output = output[output.find("# Train report"):output.find("# Test report")]
metrics = _get_metrics(output)
expected_metrics = (0.9292385057471264, 0.9292385057471264,
0.8507070042749095, 0.9292385057471263,
Expand Down

0 comments on commit 07b2cc4

Please sign in to comment.