Skip to content

Commit

Permalink
Add test for stitching labels
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Nov 20, 2024
1 parent 4d07cc1 commit 5dad99c
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion test/segmentation/test_stitching.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,48 @@
import unittest

from elf.evaluation import rand_index
import numpy as np
from skimage.data import binary_blobs
from skimage.measure import label

from elf.evaluation import rand_index


class TestStitching(unittest.TestCase):
def get_data(self, size=1024, ndim=2):
data = binary_blobs(size, blob_size_fraction=0.1, volume_fraction=0.2, n_dim=ndim)
return data

def get_tiled_data(self, size=1024, ndim=2, tile_shape=(512, 512)):
data = self.get_data(size=size, ndim=ndim)
data = label(data) # Ensure all inputs are instances (the blobs are semantic labels)

# Create tiles out of the data.
# Ensure offset for objects per tile to get individual ids per object per tile.
import nifty.tools as nt
blocking = nt.blocking([0] * ndim, data.shape, tile_shape)
n_blocks = blocking.numberOfBlocks

offset = 0
bb_tiles, tiles = [], []
for tile_id in range(n_blocks):
block = blocking.getBlock(tile_id)
bb = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))

tile = data[bb]
tile = label(tile)
tile[tile != 0] += offset
offset = tile.max()

tiles.append(tile)
bb_tiles.append(bb)

# Finally, let's stitch back the individual tiles.
labels = np.zeros(data.shape)
for tile, loc in zip(tiles, bb_tiles):
labels[loc] = tile

return labels, data # returns the stitched labels and original labels

def test_stitch_segmentation(self):
from elf.segmentation.stitching import stitch_segmentation

Expand Down Expand Up @@ -43,6 +76,18 @@ def _segment(input_, block_id=None):
are, _ = rand_index(segmentation, expected_segmentation)
self.assertTrue(are < 0.05)

def test_stitch_tiled_segmentation(self):
from elf.segmentation.stitching import stitch_tiled_segmentation

tile_shapes = [(224, 224), (256, 256), (512, 512)]
for tile_shape in tile_shapes:
# Get the tiled segmentation with unmerged instances at tile interfaces.
labels, original_labels = self.get_tiled_data()
stitched_labels = stitch_tiled_segmentation(segmentation=labels, tile_shape=tile_shape)
self.assertEqual(labels.shape, stitched_labels.shape)
# self.assertEqual(len(np.unique(original_labels)), len(np.unique(stitched_labels)))
print(len(np.unique(original_labels)), len(np.unique(stitched_labels)))


if __name__ == "__main__":
unittest.main()

0 comments on commit 5dad99c

Please sign in to comment.