diff --git a/demo/image_demo.py b/demo/image_demo.py index 231aacb9dd..4da6b02396 100644 --- a/demo/image_demo.py +++ b/demo/image_demo.py @@ -21,6 +21,10 @@ def main(): help='Opacity of painted segmentation map. In (0, 1] range.') parser.add_argument( '--title', default='result', help='The image identifier.') + parser.add_argument( + '--segementation-only', + action='store_true', + help='Only show the segmentation map.') args = parser.parse_args() # build the model from a config file and a checkpoint file @@ -38,7 +42,8 @@ def main(): opacity=args.opacity, draw_gt=False, show=False if args.out_file is not None else True, - out_file=args.out_file) + out_file=args.out_file, + segmentation_only=args.segementation_only) if __name__ == '__main__': diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 57fc5d23dc..7f940f4354 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -128,6 +128,7 @@ def show_result_pyplot(model: BaseSegmentor, wait_time: float = 0, show: bool = True, withLabels: Optional[bool] = True, + segmentation_only: bool = False, save_dir=None, out_file=None): """Visualize the segmentation results on the image. @@ -170,7 +171,8 @@ def show_result_pyplot(model: BaseSegmentor, visualizer = SegLocalVisualizer( vis_backends=[dict(type='LocalVisBackend')], save_dir=save_dir, - alpha=opacity) + alpha=opacity, + segement_only=segmentation_only) visualizer.dataset_meta = dict( classes=model.dataset_meta['classes'], palette=model.dataset_meta['palette']) diff --git a/mmseg/visualization/local_visualizer.py b/mmseg/visualization/local_visualizer.py index 3096e3183b..3790549350 100644 --- a/mmseg/visualization/local_visualizer.py +++ b/mmseg/visualization/local_visualizer.py @@ -73,9 +73,12 @@ def __init__(self, palette: Optional[List] = None, dataset_name: Optional[str] = None, alpha: float = 0.8, + segement_only: bool = False, **kwargs): super().__init__(name, image, vis_backends, save_dir, **kwargs) self.alpha: float = alpha + if segement_only: + self.alpha = 1 self.set_dataset_meta(palette, classes, dataset_name) def _get_center_loc(self, mask: np.ndarray) -> np.ndarray: @@ -139,7 +142,7 @@ def _draw_sem_seg(self, for label, color in zip(labels, colors): mask[sem_seg[0] == label, :] = color - if withLabels: + if withLabels and not self.segement_only: font = cv2.FONT_HERSHEY_SIMPLEX # (0,1] to change the size of the text relative to the image scale = 0.05