# Extracted from git patch 301c01eb30c960e0211d2aff2e033fd754464038
# ("[PATCH] :seedling: Add MMSeg custom train.py, test.py", Wed, 4 Jan 2023),
# which adds mmsegmentation/tools/test_jwseo.py and
# mmsegmentation/tools/train_jwseo.py.
#
# ---- mmsegmentation/tools/test_jwseo.py ----
# Copyright (c) OpenMMLab. All rights reserved.
"""Command-line entry point that tests (and optionally evaluates) a segmentor.

The script loads a config and a checkpoint, builds the test dataloader and the
model, runs single- or multi-GPU inference and then, depending on the flags,
pickles the raw results, formats them, and/or dumps evaluation metrics as JSON.
"""
import argparse
import os
import os.path as osp
import shutil
import time
import warnings

import mmcv
import torch
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
from mmcv.utils import DictAction

from mmseg import digit_version
from mmseg.apis import multi_gpu_test, single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import build_ddp, build_dp, get_device, setup_multi_processes


def parse_args():
    """Parse and validate the command-line arguments of the test script.

    Side effects:
        Sets ``os.environ['LOCAL_RANK']`` from ``--local_rank`` when the
        variable is not already present in the environment.

    Returns:
        argparse.Namespace: The parsed arguments. When the deprecated
        ``--options`` is used, its value is copied into ``cfg_options``.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='mmseg test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--work-dir',
        # NOTE: the trailing space was missing, which rendered "dumpedinto".
        help=('if specified, the evaluation metric results will be dumped '
              'into the directory as json'))
    parser.add_argument(
        '--aug-test', action='store_true', help='Use Flip and Multi scale aug')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--format-only',
        action='store_true',
        # NOTE: the trailing space was missing, which rendered "It isuseful".
        help='Format the output results without perform evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
        ' for generic datasets, and "cityscapes" for Cityscapes')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--gpu-id',
        type=int,
        default=0,
        help='id of gpu to use '
        '(only applicable to non-distributed testing)')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu_collect is not specified')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        # NOTE: previously read "--cfg_options'" (wrong flag name, stray quote).
        help='--options is deprecated in favor of --cfg-options and it will '
        'not be supported in version v0.22.0. Override some settings in the '
        'used config, the key-value pair in xxx=yyy format will be merged '
        'into config file. If the value to be overwritten is a list, it '
        'should be like key="[a,b]" or key=a,b It also allows nested '
        'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
        'marks are necessary and that no white space is allowed.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options. '
            '--options will not be supported in version v0.22.0.')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options. '
                      '--options will not be supported in version v0.22.0.')
        args.cfg_options = args.options

    return args


def _eval_json_path(work_dir, aug_test):
    """Create ``work_dir`` and return a timestamped metrics JSON path in it.

    Args:
        work_dir (str): Directory that receives the evaluation JSON file.
        aug_test (bool): Whether multi-scale/flip testing is enabled; selects
            the ``multi`` or ``single`` file-name variant.

    Returns:
        str: ``<work_dir>/eval_{multi|single}_scale_<timestamp>.json``.
    """
    mmcv.mkdir_or_exist(osp.abspath(work_dir))
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    scale = 'multi' if aug_test else 'single'
    return osp.join(work_dir, f'eval_{scale}_scale_{timestamp}.json')


def main():
    """Run inference with a trained segmentor and save/evaluate the results.

    Raises:
        ValueError: If no output operation is requested, if ``--eval`` and
            ``--format-only`` are combined, if ``--out`` is not a pickle
            path, or if ``cityscapes`` is mixed with other metrics.
    """
    args = parse_args()
    # Explicit raises instead of ``assert`` so that the validation is not
    # stripped when Python runs with ``-O``.
    if not (args.out or args.eval or args.format_only or args.show
            or args.show_dir):
        raise ValueError(
            'Please specify at least one operation (save/eval/format/show the '
            'results / save the results) with the argument "--out", "--eval"'
            ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set multi-process settings
    setup_multi_processes(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    if args.aug_test:
        # hard code index: pipeline[1] is assumed to be the test-time
        # augmentation step of the config -- TODO confirm per config.
        cfg.data.test.pipeline[1].img_ratios = [
            0.5, 0.75, 1.0, 1.25, 1.5, 1.75
        ]
        cfg.data.test.pipeline[1].flip = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # ``--gpu-id`` always has a value (default 0), so exactly one GPU id is
    # recorded; the former "more than one gpu id" branch could never run.
    cfg.gpu_ids = [args.gpu_id]

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    rank, _ = get_dist_info()
    # Only rank 0 writes metrics, so only rank 0 creates the directory.
    json_file = None
    if rank == 0:
        if args.work_dir is not None:
            work_dir = args.work_dir
        else:
            work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
        json_file = _eval_json_path(work_dir, args.aug_test)

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    # The default loader config
    loader_cfg = dict(
        # cfg.gpus will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        shuffle=False)
    # The overall dataloader settings
    loader_cfg.update({
        k: v
        for k, v in cfg.data.items() if k not in [
            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
            'test_dataloader'
        ]
    })
    test_loader_cfg = {
        **loader_cfg,
        'samples_per_gpu': 1,
        'shuffle': False,  # Not shuffle by default
        **cfg.data.get('test_dataloader', {})
    }
    # build the dataloader
    data_loader = build_dataloader(dataset, **test_loader_cfg)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    # Prefer class names / palette stored in the checkpoint meta; fall back
    # to the ones declared by the dataset.
    checkpoint_meta = checkpoint.get('meta', {})
    if 'CLASSES' in checkpoint_meta:
        model.CLASSES = checkpoint_meta['CLASSES']
    else:
        print('"CLASSES" not found in meta, use dataset.CLASSES instead')
        model.CLASSES = dataset.CLASSES
    if 'PALETTE' in checkpoint_meta:
        model.PALETTE = checkpoint_meta['PALETTE']
    else:
        print('"PALETTE" not found in meta, use dataset.PALETTE instead')
        model.PALETTE = dataset.PALETTE

    # clean gpu memory when starting a new evaluation.
    torch.cuda.empty_cache()
    eval_kwargs = {} if args.eval_options is None else args.eval_options

    # Deprecated
    efficient_test = eval_kwargs.get('efficient_test', False)
    if efficient_test:
        warnings.warn(
            '``efficient_test=True`` does not have effect in tools/test.py, '
            'the evaluation and format results are CPU memory efficient by '
            'default')

    eval_on_format_results = (
        args.eval is not None and 'cityscapes' in args.eval)
    if eval_on_format_results and len(args.eval) != 1:
        raise ValueError('eval on format results is not '
                         'applicable for metrics other than '
                         'cityscapes')
    if args.format_only or eval_on_format_results:
        if 'imgfile_prefix' in eval_kwargs:
            tmpdir = eval_kwargs['imgfile_prefix']
        else:
            tmpdir = '.format_cityscapes'
            eval_kwargs.setdefault('imgfile_prefix', tmpdir)
        mmcv.mkdir_or_exist(tmpdir)
    else:
        tmpdir = None

    pre_eval = args.eval is not None and not eval_on_format_results
    format_only = args.format_only or eval_on_format_results

    cfg.device = get_device()
    if not distributed:
        warnings.warn(
            'SyncBN is only supported with DDP. To be compatible with DP, '
            'we convert SyncBN to BN. Please use dist_train.sh which can '
            'avoid this error.')
        if not torch.cuda.is_available():
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU training!'
        model = revert_sync_batchnorm(model)
        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
        results = single_gpu_test(
            model,
            data_loader,
            args.show,
            args.show_dir,
            False,
            args.opacity,
            pre_eval=pre_eval,
            format_only=format_only,
            format_args=eval_kwargs)
    else:
        model = build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False)
        results = multi_gpu_test(
            model,
            data_loader,
            args.tmpdir,
            args.gpu_collect,
            False,
            pre_eval=pre_eval,
            format_only=format_only,
            format_args=eval_kwargs)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            warnings.warn(
                'The behavior of ``args.out`` has been changed since MMSeg '
                'v0.16, the pickled outputs could be seg map as type of '
                'np.array, pre-eval results or file paths for '
                '``dataset.format_results()``.')
            print(f'\nwriting results to {args.out}')
            mmcv.dump(results, args.out)
        if args.eval:
            eval_kwargs.update(metric=args.eval)
            metric = dataset.evaluate(results, **eval_kwargs)
            metric_dict = dict(config=args.config, metric=metric)
            mmcv.dump(metric_dict, json_file, indent=4)
            if tmpdir is not None and eval_on_format_results:
                # remove tmp dir when cityscapes evaluation
                shutil.rmtree(tmpdir)


if __name__ == '__main__':
    main()

# ---- mmsegmentation/tools/train_jwseo.py ----
# Copyright (c) OpenMMLab. All rights reserved.
# Training entry point for MMSegmentation with project-specific defaults:
# the work directory is derived from the config file name, a checkpoint path
# recorded in ``<work_dir>/latest_checkpoint`` is resumed automatically, and
# the ``MMSegWandbHook`` entry of the config (if any) gets a timestamped run
# name and, when requested, a run id to resume.
import argparse
import copy
import os
import os.path as osp
import time
import warnings
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Union

import mmcv
import pytz
import torch
import torch.distributed as dist
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash
from mmcv.utils.config import ConfigDict
from mmseg import __version__
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_device, get_root_logger, setup_multi_processes
from rich.console import Console

KST_TZ = pytz.timezone("Asia/Seoul")
GPU_ID = 0
SEED = 2022
PROJ_DIR = Path(__file__).parent.parent.parent
WORK_DIRS = PROJ_DIR / "work_dirs"
PTH_PREFIX = "iter_"
TIME_FORMAT = "%m/%d %H:%M"

console = Console(record=True)


class WandbResumeType(str, Enum):
    """Values accepted by ``--resume-wandb-type``.

    The selected value is written to ``init_kwargs.resume`` of the wandb hook
    config by :func:`set_wandb_resume_id`.
    """

    allow = "allow"
    must = "must"
    never = "never"
    auto = "auto"


def get_latest_checkpoint(work_dir: Path) -> Union[str, None]:
    """Return the checkpoint path recorded in ``work_dir/latest_checkpoint``.

    Args:
        work_dir: Directory that may contain a ``latest_checkpoint`` text file.

    Returns:
        The file content with surrounding whitespace removed, or ``None`` when
        the file does not exist or is blank.
    """
    latest_checkpoint_file = work_dir / "latest_checkpoint"
    if not latest_checkpoint_file.exists():
        return None

    # Strip the trailing newline that editors/``echo`` append; otherwise the
    # returned path would not point at an existing file.
    checkpoint_path = latest_checkpoint_file.read_text(encoding="utf8").strip()
    return checkpoint_path or None


def get_wandb_hook_index(cfg) -> Union[int, None]:
    """Find the position of the ``MMSegWandbHook`` entry in the log hooks.

    Args:
        cfg: Loaded config; ``cfg.log_config.hooks`` is searched when present.

    Returns:
        The index of the first hook whose ``type`` is ``"MMSegWandbHook"``, or
        ``None`` when ``log_config``/``hooks`` is missing or no hook matches.
        Note that ``0`` is a valid index, so compare against ``None``.
    """
    log_config = cfg.get("log_config", None) or {}
    hooks = log_config.get("hooks", None) or []
    return next(
        (i for i, hook in enumerate(hooks) if hook.get("type") == "MMSegWandbHook"),
        None,
    )


def set_wandb_name(wandb_hook: ConfigDict, config_path: Path) -> None:
    """Set the wandb run name to ``"<MM/DD HH:MM> <config stem>"``.

    Underscores in the config file stem are replaced with spaces and the
    current time in the ``Asia/Seoul`` timezone is prepended.

    Args:
        wandb_hook: Hook config whose ``init_kwargs.name`` is overwritten.
        config_path: Path of the training config file.
    """
    name = config_path.stem.replace("_", " ")
    name = datetime.now(KST_TZ).strftime(f"{TIME_FORMAT} ") + name
    wandb_hook.init_kwargs.name = name


def set_wandb_resume_id(
    wandb_hook: ConfigDict, resume_wandb_id, resume_type: str = "allow"
) -> None:
    """Configure the wandb hook to resume an existing run.

    Args:
        wandb_hook: Hook config whose ``init_kwargs`` is updated in place.
        resume_wandb_id: Run id stored in ``init_kwargs.id``.
        resume_type: One of the :class:`WandbResumeType` values; stored in
            ``init_kwargs.resume``. Defaults to ``"allow"``, the value that
            was previously hard-coded.

    Raises:
        ValueError: If ``resume_type`` is not a :class:`WandbResumeType` value.
    """
    # Store the plain string (not the enum member) so the config still dumps
    # as ordinary text.
    wandb_hook.init_kwargs.resume = WandbResumeType(resume_type).value
    wandb_hook.init_kwargs.id = resume_wandb_id


def parse_args():
    """Parse and validate the command-line arguments of the training script.

    Side effects:
        Sets ``os.environ["LOCAL_RANK"]`` from ``--local_rank`` when the
        variable is not already present in the environment.

    Returns:
        argparse.Namespace: The parsed arguments. When the deprecated
        ``--options`` is used, its value is copied into ``cfg_options``.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given, or
            if ``--resume-from`` is given without ``--resume-wandb-id``.
    """
    parser = argparse.ArgumentParser(description="Train a segmentor")
    parser.add_argument("config", help="train config file path")
    parser.add_argument("--work-dir", help="the dir to save logs and models")
    parser.add_argument("--load-from", help="the checkpoint file to load weights from")
    parser.add_argument("--resume-from", help="the checkpoint file to resume from")
    parser.add_argument("--resume-wandb-id", help="the id of the wandb run to resume")
    parser.add_argument(
        "--resume-wandb-type",
        default=WandbResumeType.allow.value,
        choices=[resume_type.value for resume_type in WandbResumeType],
        help="wandb resuming behavior",
    )
    parser.add_argument(
        "--no-validate",
        action="store_true",
        help="whether not to evaluate the checkpoint during training",
    )
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        "--gpus",
        type=int,
        help="(Deprecated, please use --gpu-id) number of gpus to use "
        "(only applicable to non-distributed training)",
    )
    group_gpus.add_argument(
        "--gpu-ids",
        type=int,
        nargs="+",
        help="(Deprecated, please use --gpu-id) ids of gpus to use "
        "(only applicable to non-distributed training)",
    )
    group_gpus.add_argument(
        "--gpu-id",
        type=int,
        default=0,
        help="id of gpu to use (only applicable to non-distributed training)",
    )
    parser.add_argument("--seed", type=int, default=None, help="random seed")
    parser.add_argument(
        "--diff_seed",
        action="store_true",
        help="Whether or not set different seeds for different ranks",
    )
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="whether to set deterministic options for CUDNN backend.",
    )
    parser.add_argument(
        "--options",
        nargs="+",
        action=DictAction,
        # NOTE: previously read "--cfg_options'" (wrong flag name, stray quote).
        help="--options is deprecated in favor of --cfg-options and it will "
        "not be supported in version v0.22.0. Override some settings in the "
        "used config, the key-value pair in xxx=yyy format will be merged "
        "into config file. If the value to be overwritten is a list, it "
        'should be like key="[a,b]" or key=a,b It also allows nested '
        'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
        "marks are necessary and that no white space is allowed.",
    )
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    parser.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi"],
        default="none",
        help="job launcher",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--auto-resume",
        action="store_true",
        default=True,
        help="resume from the latest checkpoint automatically (default).",
    )
    # ``--auto-resume`` defaults to True, so without this flag there was no
    # way to turn the behavior off from the command line.
    parser.add_argument(
        "--no-auto-resume",
        dest="auto_resume",
        action="store_false",
        help="do not resume from the latest checkpoint automatically.",
    )
    args = parser.parse_args()
    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            "--options and --cfg-options cannot be both "
            "specified, --options is deprecated in favor of --cfg-options. "
            "--options will not be supported in version v0.22.0."
        )
    if args.options:
        warnings.warn(
            "--options is deprecated in favor of --cfg-options. "
            "--options will not be supported in version v0.22.0."
        )
        args.cfg_options = args.options

    if args.resume_from is not None and args.resume_wandb_id is None:
        raise ValueError("'--resume-from' should be with '--resume-wandb-id'")

    return args


def _apply_resume_settings(cfg, args, work_dir: Path) -> None:
    """Set ``cfg.resume_from`` from ``--resume-from`` or ``latest_checkpoint``.

    An explicit ``--resume-from`` wins; otherwise the path recorded in
    ``work_dir/latest_checkpoint`` is used when there is one.

    Raises:
        ValueError: If a recorded checkpoint is found but ``--resume-wandb-id``
            was not given.
    """
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
        return

    latest_checkpoint = get_latest_checkpoint(work_dir)
    if not latest_checkpoint:
        return

    cfg.resume_from = latest_checkpoint
    if args.resume_wandb_id is None:
        raise ValueError(
            "'latest_checkpoint' is found. Please set '--resume-wandb-id'"
        )


def _apply_wandb_settings(cfg, args, config_filepath: Path) -> None:
    """Name the wandb run and, if requested, configure it to resume a run."""
    wandb_hook_index = get_wandb_hook_index(cfg)
    # Compare against None: the hook may legitimately sit at index 0.
    if wandb_hook_index is None:
        return

    wandb_hook = cfg.log_config.hooks[wandb_hook_index]
    set_wandb_name(wandb_hook, config_filepath)

    if args.resume_wandb_id:
        set_wandb_resume_id(
            wandb_hook, args.resume_wandb_id, args.resume_wandb_type
        )


def _apply_gpu_settings(cfg, args) -> None:
    """Fill ``cfg.gpu_ids`` for non-distributed training from the GPU flags."""
    if args.gpus is not None:
        cfg.gpu_ids = range(1)
        warnings.warn(
            "`--gpus` is deprecated because we only support "
            "single GPU mode in non-distributed training. "
            "Use `gpus=1` now."
        )
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids[0:1]
        warnings.warn(
            "`--gpu-ids` is deprecated, please use `--gpu-id`. "
            "Because we only support single GPU mode in "
            "non-distributed training. Use the first GPU "
            "in `gpu_ids` now."
        )
    if args.gpus is None and args.gpu_ids is None:
        cfg.gpu_ids = [args.gpu_id]


def main():
    """Build the config, model and datasets, then launch training.

    Raises:
        ValueError: If a checkpoint recorded in ``latest_checkpoint`` would be
            resumed but ``--resume-wandb-id`` was not given.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # set cudnn_benchmark
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    config_filepath = Path(args.config)

    # work_dir priority: ``--work-dir`` > ``<project>/work_dirs/<config stem>``.
    # The flag used to be parsed but silently ignored.
    if args.work_dir is not None:
        work_dir = Path(args.work_dir)
    else:
        work_dir = WORK_DIRS / config_filepath.stem
    work_dir.mkdir(parents=True, exist_ok=True)

    _apply_resume_settings(cfg, args, work_dir)
    _apply_wandb_settings(cfg, args, config_filepath)
    _apply_gpu_settings(cfg, args)

    if args.load_from is not None:
        cfg.load_from = args.load_from
    cfg.auto_resume = args.auto_resume

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == "none":
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # gpu_ids is used to calculate iter when resuming checkpoint
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    cfg.work_dir = str(work_dir)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime("%Y-%m-%d %H;%M", time.localtime())
    log_file = work_dir / f"{timestamp}.log"
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # set multi-process settings
    setup_multi_processes(cfg)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    # log env info
    env_info_dict = collect_env()
    env_info = "\n".join([f"{k}: {v}" for k, v in env_info_dict.items()])
    dash_line = "-" * 60 + "\n"
    logger.info("Environment info:\n" + dash_line + env_info + "\n" + dash_line)

    # log some basic info
    logger.info(f"Distributed training: {distributed}")
    logger.info(f"Config:\n{cfg.pretty_text}")

    # set random seeds
    cfg.device = get_device()
    seed = init_random_seed(args.seed, device=cfg.device)
    if args.diff_seed:
        # ``get_dist_info`` is used instead of ``dist.get_rank()``, which
        # raises when no process group exists (``--launcher none``).
        rank, _ = get_dist_info()
        seed = seed + rank
    logger.info(f"Set random seed to {seed}, " f"deterministic: {args.deterministic}")
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta = {
        "env_info": env_info,
        "seed": seed,
        "exp_name": config_filepath.stem,
    }

    model = build_segmentor(
        cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg")
    )
    model.init_weights()

    # SyncBN is not support for DP
    if not distributed:
        warnings.warn(
            "SyncBN is only supported with DDP. To be compatible with DP, "
            "we convert SyncBN to BN. Please use dist_train.sh which can "
            "avoid this error."
        )
        model = revert_sync_batchnorm(model)

    logger.info(model)

    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        # The validation pipeline is not overridden with the training pipeline
        # here (``cfg.data.train`` had no ``pipeline`` attribute); set it in
        # dataset/train.py instead.
        val_dataset = copy.deepcopy(cfg.data.val)
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmseg version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmseg_version=f"{__version__}+{get_git_hash()[:7]}",
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE,
        )
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    # passing checkpoint meta for saving best checkpoint
    meta.update(cfg.checkpoint_config.meta)
    train_segmentor(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta,
    )


if __name__ == "__main__":
    main()