diff --git a/OmeSliCC/ImageGenerator.py b/OmeSliCC/ImageGenerator.py index 78cc7c8..9e25532 100644 --- a/OmeSliCC/ImageGenerator.py +++ b/OmeSliCC/ImageGenerator.py @@ -108,6 +108,9 @@ def __init__(self, size, tile_size, dtype=np.uint8, seed=None): self.tile_size = tile_size self.dtype = dtype + ranges = np.flip(np.ceil(np.divide(self.size, self.tile_size)).astype(int)) + self.tile_indices = list(np.ndindex(tuple(ranges))) + if seed is not None: np.random.seed(seed) @@ -133,9 +136,8 @@ def get_tiles(self): max_val = 2 ** (8 * np.dtype(dtype).itemsize) - 1 else: max_val = 1 - ranges = np.flip(np.ceil(np.divide(self.size, self.tile_size)).astype(int)) # flip: cycle over indices in x, y, z order using range = [z, y, x] - for indices in tqdm(list(np.ndindex(tuple(ranges)))): + for indices in tqdm(self.tile_indices): self.range0 = np.flip(indices) * tile_size self.range1 = np.min([self.range0 + self.tile_size, self.size], 0) shape = list(reversed(self.range1 - self.range0)) @@ -165,23 +167,35 @@ def save_tiff(filename, data, shape=None, dtype=None, tile_size=None, bigtiff=No compression=compression) -def render_image(data, shape, dtype): - #for slice1, tile in tqdm(data): - # image[slice1] = tile - - image = np.zeros(shape, dtype=dtype) - position = np.array([0] * len(shape)) - for tile in data: - image[position: position + tile.shape] = tile - position += tile.shape - # TODO: increase/carry-over (x,) y, z - image = image.reshape(shape) +def render_image(image_generator): + data = image_generator.get_tiles() + dtype = image_generator.dtype + tile_indices = image_generator.tile_indices + shape = np.flip(image_generator.size) + multi_dimensional = (len(shape) > 2) - if len(shape) <= 3 and shape[-1] <= 4: - show_image(image) + if multi_dimensional: + shape1 = list(shape)[1:] + [3] else: - i = shape[3] // 2 + 2 - show_image(image[:, :, i, :]) + shape1 = list(shape) + [3] + image = np.zeros(shape1, dtype=dtype) + first = True + for indices, tile in zip(tile_indices, data): + tile_shape = tile.shape[:-1] + range0 = np.multiply(indices, tile_shape) + range1 = np.min([range0 + tile_shape, shape], 0) + + # break after first z + if not first and np.all(range0[-2:] == [0, 0]): + break + + if multi_dimensional: + image[range0[1]: range1[1], range0[2]: range1[2], :] = tile[0, ...] + else: + image[range0[0]: range1[0], range0[1]: range1[1], :] = tile + first = False + + show_image(image) def show_image(image): @@ -210,5 +224,4 @@ def show_image(image): save_tiff(path, data, shape, dtype, tile_size=tile_shape, ome=True) print('save done') - data = image_generator.get_tiles() - render_image(data, shape, dtype) + render_image(image_generator)