Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed May 8, 2023
1 parent 8f2a3d3 commit 0776b5c
Show file tree
Hide file tree
Showing 1,012 changed files with 19,244 additions and 14,743 deletions.
82 changes: 43 additions & 39 deletions mmsegmentation/.dev/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,62 +13,61 @@
from mmseg.utils import get_root_logger

# ignore warnings when segmentors inference
warnings.filterwarnings('ignore')
warnings.filterwarnings("ignore")


def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
"""Download checkpoint and check if hash code is true."""
url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}' # noqa
url = f"https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}" # noqa

r = requests.get(url)
assert r.status_code != 403, f'{url} Access denied.'
assert r.status_code != 403, f"{url} Access denied."

with open(osp.join(collect_dir, checkpoint_name), 'wb') as code:
with open(osp.join(collect_dir, checkpoint_name), "wb") as code:
code.write(r.content)

true_hash_code = osp.splitext(checkpoint_name)[0].split('-')[1]
true_hash_code = osp.splitext(checkpoint_name)[0].split("-")[1]

# check hash code
with open(osp.join(collect_dir, checkpoint_name), 'rb') as fp:
with open(osp.join(collect_dir, checkpoint_name), "rb") as fp:
sha256_cal = hashlib.sha256()
sha256_cal.update(fp.read())
cur_hash_code = sha256_cal.hexdigest()[:8]

assert true_hash_code == cur_hash_code, f'{url} download failed, '
'incomplete downloaded file or url invalid.'
assert true_hash_code == cur_hash_code, f"{url} download failed, "
"incomplete downloaded file or url invalid."

if cur_hash_code != true_hash_code:
os.remove(osp.join(collect_dir, checkpoint_name))


def parse_args():
parser = ArgumentParser()
parser.add_argument('config', help='test config file path')
parser.add_argument('checkpoint_root', help='Checkpoint file root path')
parser.add_argument("config", help="test config file path")
parser.add_argument("checkpoint_root", help="Checkpoint file root path")
parser.add_argument("-i", "--img", default="demo/demo.png", help="Image file")
parser.add_argument("-a", "--aug", action="store_true", help="aug test")
parser.add_argument("-m", "--model-name", help="model name to inference")
parser.add_argument("-s", "--show", action="store_true", help="show results")
parser.add_argument(
'-i', '--img', default='demo/demo.png', help='Image file')
parser.add_argument('-a', '--aug', action='store_true', help='aug test')
parser.add_argument('-m', '--model-name', help='model name to inference')
parser.add_argument(
'-s', '--show', action='store_true', help='show results')
parser.add_argument(
'-d', '--device', default='cuda:0', help='Device used for inference')
"-d", "--device", default="cuda:0", help="Device used for inference"
)
return parser.parse_args()


def inference_model(config_name, checkpoint, args, logger=None):
cfg = Config.fromfile(config_name)
if args.aug:
if 'flip' in cfg.data.test.pipeline[
1] and 'img_scale' in cfg.data.test.pipeline[1]:
cfg.data.test.pipeline[1].img_ratios = [
0.5, 0.75, 1.0, 1.25, 1.5, 1.75
]
if (
"flip" in cfg.data.test.pipeline[1]
and "img_scale" in cfg.data.test.pipeline[1]
):
cfg.data.test.pipeline[1].img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
cfg.data.test.pipeline[1].flip = True
elif logger is None:
print(f'{config_name}: unable to start aug test', flush=True)
print(f"{config_name}: unable to start aug test", flush=True)
else:
logger.error(f'{config_name}: unable to start aug test')
logger.error(f"{config_name}: unable to start aug test")

model = init_segmentor(cfg, checkpoint, device=args.device)
# test a single image
Expand All @@ -94,44 +93,49 @@ def main(args):
if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
config_name = model_info['config'].strip()
print(f'processing: {config_name}', flush=True)
checkpoint = osp.join(args.checkpoint_root,
model_info['checkpoint'].strip())
config_name = model_info["config"].strip()
print(f"processing: {config_name}", flush=True)
checkpoint = osp.join(
args.checkpoint_root, model_info["checkpoint"].strip()
)
try:
# build the model from a config file and a checkpoint file
inference_model(config_name, checkpoint, args)
except Exception:
print(f'{config_name} test failed!')
print(f"{config_name} test failed!")
continue
return
else:
raise RuntimeError('model name input error.')
raise RuntimeError("model name input error.")

# test all model
logger = get_root_logger(
log_file='benchmark_inference_image.log', log_level=logging.ERROR)
log_file="benchmark_inference_image.log", log_level=logging.ERROR
)

for model_name in config:
model_infos = config[model_name]

if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
print('processing: ', model_info['config'], flush=True)
config_path = model_info['config'].strip()
print("processing: ", model_info["config"], flush=True)
config_path = model_info["config"].strip()
config_name = osp.splitext(osp.basename(config_path))[0]
checkpoint_name = model_info['checkpoint'].strip()
checkpoint_name = model_info["checkpoint"].strip()
checkpoint = osp.join(args.checkpoint_root, checkpoint_name)

# ensure checkpoint exists
try:
if not osp.exists(checkpoint):
download_checkpoint(checkpoint_name, model_name,
config_name.rstrip('.py'),
args.checkpoint_root)
download_checkpoint(
checkpoint_name,
model_name,
config_name.rstrip(".py"),
args.checkpoint_root,
)
except Exception:
logger.error(f'{checkpoint_name} download error')
logger.error(f"{checkpoint_name} download error")
continue

# test model inference with checkpoint
Expand All @@ -142,6 +146,6 @@ def main(args):
logger.error(f'{config_path} " : {repr(e)}')


if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()
main(args)
65 changes: 35 additions & 30 deletions mmsegmentation/.dev/check_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ def check_url(url):


def parse_args():
parser = ArgumentParser('url valid check.')
parser = ArgumentParser("url valid check.")
parser.add_argument(
'-m',
'--model-name',
type=str,
help='Select the model needed to check')
"-m", "--model-name", type=str, help="Select the model needed to check"
)

return parser.parse_args()

Expand All @@ -42,56 +40,63 @@ def main():
# yml path generate.
# If model_name is not set, script will check all of the models.
if model_name is not None:
yml_list = [(model_name, f'configs/{model_name}/{model_name}.yml')]
yml_list = [(model_name, f"configs/{model_name}/{model_name}.yml")]
else:
# check all
yml_list = [(x, f'configs/{x}/{x}.yml') for x in os.listdir('configs/')
if x != '_base_']
yml_list = [
(x, f"configs/{x}/{x}.yml") for x in os.listdir("configs/") if x != "_base_"
]

logger = get_root_logger(log_file='url_check.log', log_level=logging.ERROR)
logger = get_root_logger(log_file="url_check.log", log_level=logging.ERROR)

for model_name, yml_path in yml_list:
# Default yaml loader unsafe.
model_infos = yml.load(
open(yml_path, 'r'), Loader=yml.CLoader)['Models']
model_infos = yml.load(open(yml_path), Loader=yml.CLoader)["Models"]
for model_info in model_infos:
config_name = model_info['Name']
checkpoint_url = model_info['Weights']
config_name = model_info["Name"]
checkpoint_url = model_info["Weights"]
# checkpoint url check
status_code, flag = check_url(checkpoint_url)
if flag:
logger.info(f'checkpoint | {config_name} | {checkpoint_url} | '
f'{status_code} valid')
logger.info(
f"checkpoint | {config_name} | {checkpoint_url} | "
f"{status_code} valid"
)
else:
logger.error(
f'checkpoint | {config_name} | {checkpoint_url} | '
f'{status_code} | error')
f"checkpoint | {config_name} | {checkpoint_url} | "
f"{status_code} | error"
)
# log_json check
checkpoint_name = checkpoint_url.split('/')[-1]
model_time = '-'.join(checkpoint_name.split('-')[:-1]).replace(
f'{config_name}_', '')
checkpoint_name = checkpoint_url.split("/")[-1]
model_time = "-".join(checkpoint_name.split("-")[:-1]).replace(
f"{config_name}_", ""
)
# two style of log_json name
# use '_' to link model_time (will be deprecated)
log_json_url_1 = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}_{model_time}.log.json' # noqa
log_json_url_1 = f"https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}_{model_time}.log.json" # noqa
status_code_1, flag_1 = check_url(log_json_url_1)
# use '-' to link model_time
log_json_url_2 = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}-{model_time}.log.json' # noqa
log_json_url_2 = f"https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{config_name}-{model_time}.log.json" # noqa
status_code_2, flag_2 = check_url(log_json_url_2)
if flag_1 or flag_2:
if flag_1:
logger.info(
f'log.json | {config_name} | {log_json_url_1} | '
f'{status_code_1} | valid')
f"log.json | {config_name} | {log_json_url_1} | "
f"{status_code_1} | valid"
)
else:
logger.info(
f'log.json | {config_name} | {log_json_url_2} | '
f'{status_code_2} | valid')
f"log.json | {config_name} | {log_json_url_2} | "
f"{status_code_2} | valid"
)
else:
logger.error(
f'log.json | {config_name} | {log_json_url_1} & '
f'{log_json_url_2} | {status_code_1} & {status_code_2} | '
'error')
f"log.json | {config_name} | {log_json_url_1} & "
f"{log_json_url_2} | {status_code_1} & {status_code_2} | "
"error"
)


if __name__ == '__main__':
if __name__ == "__main__":
main()
56 changes: 29 additions & 27 deletions mmsegmentation/.dev/gather_benchmark_evaluation_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@

def parse_args():
parser = argparse.ArgumentParser(
description='Gather benchmarked model evaluation results')
parser.add_argument('config', help='test config file path')
description="Gather benchmarked model evaluation results"
)
parser.add_argument("config", help="test config file path")
parser.add_argument(
'root',
type=str,
help='root path of benchmarked models to be gathered')
"root", type=str, help="root path of benchmarked models to be gathered"
)
parser.add_argument(
'--out',
"--out",
type=str,
default='benchmark_evaluation_info.json',
help='output path of gathered metrics and compared '
'results to be stored')
default="benchmark_evaluation_info.json",
help="output path of gathered metrics and compared " "results to be stored",
)

args = parser.parse_args()
return args


if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()

root_path = args.root
Expand All @@ -40,52 +40,54 @@ def parse_args():
if not isinstance(model_infos, list):
model_infos = [model_infos]
for model_info in model_infos:
previous_metrics = model_info['metric']
config = model_info['config'].strip()
previous_metrics = model_info["metric"]
config = model_info["config"].strip()
fname, _ = osp.splitext(osp.basename(config))

# Load benchmark evaluation json
metric_json_dir = osp.join(root_path, fname)
if not osp.exists(metric_json_dir):
print(f'{metric_json_dir} not existed.')
print(f"{metric_json_dir} not existed.")
continue

json_list = glob.glob(osp.join(metric_json_dir, '*.json'))
json_list = glob.glob(osp.join(metric_json_dir, "*.json"))
if len(json_list) == 0:
print(f'There is no eval json in {metric_json_dir}.')
print(f"There is no eval json in {metric_json_dir}.")
continue

log_json_path = list(sorted(json_list))[-1]
metric = mmcv.load(log_json_path)
if config not in metric.get('config', {}):
print(f'{config} not included in {log_json_path}')
if config not in metric.get("config", {}):
print(f"{config} not included in {log_json_path}")
continue

# Compare between new benchmark results and previous metrics
differential_results = {}
new_metrics = {}
for record_metric_key in previous_metrics:
if record_metric_key not in metric['metric']:
raise KeyError('record_metric_key not exist, please '
'check your config')
if record_metric_key not in metric["metric"]:
raise KeyError(
"record_metric_key not exist, please " "check your config"
)
old_metric = previous_metrics[record_metric_key]
new_metric = round(metric['metric'][record_metric_key] * 100,
2)
new_metric = round(metric["metric"][record_metric_key] * 100, 2)

differential = new_metric - old_metric
flag = '+' if differential > 0 else '-'
flag = "+" if differential > 0 else "-"
differential_results[
record_metric_key] = f'{flag}{abs(differential):.2f}'
record_metric_key
] = f"{flag}{abs(differential):.2f}"
new_metrics[record_metric_key] = new_metric

result_dict[config] = dict(
differential=differential_results,
previous=previous_metrics,
new=new_metrics)
new=new_metrics,
)

if metrics_out:
mmcv.dump(result_dict, metrics_out, indent=4)
print('===================================')
print("===================================")
for config_name, metrics in result_dict.items():
print(config_name, metrics)
print('===================================')
print("===================================")
Loading

0 comments on commit 0776b5c

Please sign in to comment.