Skip to content

Commit

Permalink
Fix choice of automatic segmentation mode in CLI (#788)
Browse files Browse the repository at this point in the history
* Fix automatic segmentation cli
  • Loading branch information
anwai98 authored Nov 20, 2024
1 parent bb548f9 commit 41de117
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 26 deletions.
10 changes: 7 additions & 3 deletions doc/cli_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ The supported CLIs can be used by
- Running `$ micro_sam.image_series_annotator` for starting the image series annotator.
- Running `$ micro_sam.automatic_segmentation` for automatic instance segmentation.
- We support all post-processing parameters for automatic instance segmentation (for both AMG and AIS).
- The automatic segmentation mode can be controlled by: `--mode <MODE_NAME>`, where the available choice for `MODE_NAME` is `amg` / `ais`.
- AMG is supported by both default Segment Anything models and `micro-sam` models / finetuned models.
- AIS is supported by `micro-sam` models (or finetuned models; subjected to they are trained with the additional instance segmentation decoder)
- If these parameters are not provided by the user, `micro-sam` makes use of the best post-processing parameters (depending on the choice of model).
- The post-processing parameters can be changed by parsing the parameters via the CLI using `--<PARAMETER_NAME> <VALUE>.` For example, one can update the parameter values (eg. `pred_iou_thresh`, `stability_iou_thresh`, etc. - supported by AMG) using
```bash
$ micro_sam.automatic_segmentation ... --pred_iou_thresh 0.6 --stability_iou_thresh 0.6 ...
```
```bash
$ micro_sam.automatic_segmentation ... --pred_iou_thresh 0.6 --stability_iou_thresh 0.6 ...
```
- Remember to specify the automatic segmentation mode using `--mode <MODE_NAME>` when using additional post-processing parameters.
- You can check details for supported parameters and their respective default values at `micro_sam/instance_segmentation.py` under the `generate` method for `AutomaticMaskGenerator` and `InstanceSegmentationWithDecoder` class.

NOTE: For all CLIs above, you can find more details by adding the argument `-h` to the CLI script (eg. `$ micro_sam.annotator_2d -h`).
41 changes: 31 additions & 10 deletions micro_sam/automatic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
import numpy as np
import imageio.v3 as imageio

from torch_em.data.datasets.util import split_kwargs

from . import util
from .instance_segmentation import (
get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder, AMGBase
get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
)
from .multi_dimensional_segmentation import automatic_3d_segmentation

Expand All @@ -30,7 +33,7 @@ def get_predictor_and_segmenter(
Otherwise AIS will be used, which requires a special segmentation decoder.
If not specified AIS will be used if it is available and otherwise AMG will be used.
is_tiled: Whether to return segmenter for performing segmentation in tiling window style.
kwargs: Keyword arguments for the automatic instance segmentation class.
kwargs: Keyword arguments for the automatic mask generation class.
Returns:
The Segment Anything model.
Expand All @@ -46,17 +49,16 @@ def get_predictor_and_segmenter(

if amg is None:
amg = "decoder_state" not in state

if amg:
decoder = None
else:
if "decoder_state" not in state:
raise RuntimeError("You have passed amg=False, but your model does not contain a segmentation decoder.")
raise RuntimeError("You have passed 'amg=False', but your model does not contain a segmentation decoder.")
decoder_state = state["decoder_state"]
decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device)

segmenter = get_amg(
predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs
)
segmenter = get_amg(predictor=predictor, is_tiled=is_tiled, decoder=decoder, **kwargs)

return predictor, segmenter

Expand Down Expand Up @@ -132,6 +134,7 @@ def automatic_instance_segmentation(
instances = np.zeros(this_shape, dtype="uint32")
else:
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)

else:
if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
Expand Down Expand Up @@ -189,7 +192,7 @@ def main():
)
parser.add_argument(
"-c", "--checkpoint", default=None,
help="Checkpoint from which the SAM model will be loaded loaded."
help="Checkpoint from which the SAM model will be loaded."
)
parser.add_argument(
"--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
Expand All @@ -202,7 +205,8 @@ def main():
help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
)
parser.add_argument(
"--amg", action="store_true", help="Whether to use automatic mask generation with the model."
"--mode", type=str, default=None,
help="The choice of automatic segmentation with the Segment Anything models. Either 'amg' or 'ais'."
)
parser.add_argument(
"-d", "--device", default=None,
Expand All @@ -222,16 +226,33 @@ def _convert_argval(value):

# NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
# the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS)
generate_kwargs = {
extra_kwargs = {
parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
}

# Separate extra arguments as per where they should be passed in the automatic segmentation class.
# This is done to ensure the extra arguments are allocated to the desired location.
# eg. for AMG, 'points_per_side' is expected by '__init__',
# and 'stability_score_thresh' is expected in 'generate' method.
amg_class = AutomaticMaskGenerator if args.tile_shape is None else TiledAutomaticMaskGenerator
amg_kwargs, generate_kwargs = split_kwargs(amg_class, **extra_kwargs)

# Validate for the expected automatic segmentation mode.
# By default, it is set to 'None', i.e. searches for the decoder state to prioritize AIS for finetuned models.
# Otherwise, runs AMG for all models in any case.
amg = None
if args.mode is not None:
assert args.mode in ["ais", "amg"], \
f"'{args.mode}' is not a valid automatic segmentation mode. Please choose either 'amg' or 'ais'."
amg = (args.mode == "amg")

predictor, segmenter = get_predictor_and_segmenter(
model_type=args.model_type,
checkpoint=args.checkpoint,
device=args.device,
amg=args.amg,
amg=amg,
is_tiled=args.tile_shape is not None,
**amg_kwargs,
)

automatic_instance_segmentation(
Expand Down
14 changes: 6 additions & 8 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,10 +1122,7 @@ def initialize(


def get_amg(
predictor: SamPredictor,
is_tiled: bool,
decoder: Optional[torch.nn.Module] = None,
**kwargs,
predictor: SamPredictor, is_tiled: bool, decoder: Optional[torch.nn.Module] = None, **kwargs,
) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
"""Get the automatic mask generator class.
Expand All @@ -1139,9 +1136,10 @@ def get_amg(
The automatic mask generator.
"""
if decoder is None:
segmenter = TiledAutomaticMaskGenerator(predictor, **kwargs) if is_tiled else\
AutomaticMaskGenerator(predictor, **kwargs)
segmenter_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
segmenter = segmenter_class(predictor, **kwargs)
else:
segmenter = TiledInstanceSegmentationWithDecoder(predictor, decoder, **kwargs) if is_tiled else\
InstanceSegmentationWithDecoder(predictor, decoder, **kwargs)
segmenter_class = TiledInstanceSegmentationWithDecoder if is_tiled else InstanceSegmentationWithDecoder
segmenter = segmenter_class(predictor, decoder, **kwargs)

return segmenter
49 changes: 44 additions & 5 deletions test/test_sam_annotator/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import os
import platform
import unittest
from shutil import which, rmtree
from subprocess import run
from shutil import which, rmtree

import imageio.v3 as imageio
import micro_sam.util as util
import pytest
import zarr
import pytest
import imageio.v3 as imageio
from skimage.data import binary_blobs

import micro_sam.util as util


class TestCLI(unittest.TestCase):
model_type = "vit_t_lm" if util.VIT_T_SUPPORT else "vit_b_lm"
default_model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b"
tmp_folder = "tmp-files"

def setUp(self):
Expand All @@ -36,7 +38,7 @@ def test_annotator_tracking(self):
def test_image_series_annotator(self):
self._test_command("micro_sam.image_series_annotator")

@pytest.mark.skipif(platform.system() == "Windows", reason="Gui test is not working on windows.")
@pytest.mark.skipif(platform.system() == "Windows", reason="CLI test is not working on windows.")
def test_precompute_embeddings(self):
self._test_command("micro_sam.precompute_embeddings")

Expand Down Expand Up @@ -83,9 +85,46 @@ def test_precompute_embeddings(self):
ais_path = os.path.join(emb_path3, f"image-{i}.zarr", "is_state.h5")
self.assertTrue(os.path.exists(ais_path))

@pytest.mark.skipif(platform.system() == "Windows", reason="CLI test is not working on windows.")
def test_automatic_segmentation(self):
self._test_command("micro_sam.automatic_segmentation")

# Create 1 image as testdata.
im_path = os.path.join(self.tmp_folder, "image.tif")
image_data = binary_blobs(512).astype("uint8") * 255
imageio.imwrite(im_path, image_data)

# Path to save automatic segmentation outputs.
out_path = "output.tif"

# Test AMG with default model in default mode.
run(["micro_sam.automatic_segmentation", "-i", im_path, "-o", out_path,
"-m", self.default_model_type, "--points_per_side", "4"])
self.assertTrue(os.path.exists(out_path))
os.remove(out_path)

# Test AMG with default model exclusively in AMG mode.
run(["micro_sam.automatic_segmentation", "-i", im_path, "-o", out_path,
"-m", self.default_model_type, "--mode", "amg", "--points_per_side", "4"])
self.assertTrue(os.path.exists(out_path))
os.remove(out_path)

# Test AIS with 'micro-sam' model in default mode.
run(["micro_sam.automatic_segmentation", "-i", im_path, "-o", out_path, "-m", self.model_type])
self.assertTrue(os.path.exists(out_path))
os.remove(out_path)

# Test AIS with 'micro-sam' model exclusively in AMG mode.
run(["micro_sam.automatic_segmentation", "-i", im_path, "-o", out_path,
"-m", self.model_type, "--mode", "amg", "--points_per_side", "4"])
self.assertTrue(os.path.exists(out_path))
os.remove(out_path)

# Test AIS with 'micro-sam' model exclusively in AIS mode.
run(["micro_sam.automatic_segmentation", "-i", im_path, "-o", out_path, "-m", self.model_type, "--mode", "ais"])
self.assertTrue(os.path.exists(out_path))
os.remove(out_path)


if __name__ == "__main__":
unittest.main()

0 comments on commit 41de117

Please sign in to comment.