Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fix to casting masks in AMG post-processing #780

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions micro_sam/_vendored.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def njit(func):


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""Calculates boxes in XYXY format around masks. Return [0,0,0,0] for an empty mask.
"""Calculates boxes in XYXY format around masks. Return [0, 0, 0, 0] for an empty mask.

For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.

Expand All @@ -38,7 +38,7 @@ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
It further ensures that inputs are boolean tensors, otherwise the function yields wrong results.
See https://github.com/facebookresearch/segment-anything/issues/552 for details.
"""
assert masks.dtype == torch.bool
assert masks.dtype == torch.bool, masks.dtype

# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
Expand Down
2 changes: 1 addition & 1 deletion micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _postprocess_small_regions(self, mask_data, min_area, nms_thresh):

# recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
boxes = batched_mask_to_box(masks.to(torch.bool)) # Casting this to boolean as we work with one-hot labels.
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores, dtype=torch.float),
Expand Down
Loading