-import logging
-import math
-import numpy as np
-import random
-from scipy import ndimage
-
-from .batch_filter import BatchFilter
-from gunpowder.batch_request import BatchRequest
-from gunpowder.coordinate import Coordinate
-from gunpowder.ext import augment
-from gunpowder.roi import Roi
-from gunpowder.array import ArrayKey
-
-logger = logging.getLogger(__name__)
-
-
-[docs]class ElasticAugment(BatchFilter):
-
"""Elasticly deform a batch. Requests larger batches upstream to avoid data
-
loss due to rotation and jitter.
-
-
Args:
-
-
control_point_spacing (``tuple`` of ``int``):
-
-
Distance between control points for the elastic deformation, in
-
voxels per dimension.
-
-
jitter_sigma (``tuple`` of ``float``):
-
-
Standard deviation of control point jitter distribution, in voxels
-
per dimension.
-
-
rotation_interval (``tuple`` of two ``floats``):
-
-
Interval to randomly sample rotation angles from (0, 2PI).
-
-
prob_slip (``float``):
-
-
Probability of a section to "slip", i.e., be independently moved in
-
x-y.
-
-
prob_shift (``float``):
-
-
Probability of a section and all following sections to move in x-y.
-
-
max_misalign (``int``):
-
-
Maximal voxels to shift in x and y. Samples will be drawn
-
uniformly. Used if ``prob_slip + prob_shift`` > 0.
-
-
subsample (``int``):
-
-
Instead of creating an elastic transformation on the full
-
resolution, create one subsampled by the given factor, and linearly
-
interpolate to obtain the full resolution transformation. This can
-
significantly speed up this node, at the expense of having visible
-
piecewise linear deformations for large factors. Usually, a factor
-
of 4 can savely by used without noticable changes. However, the
-
default is 1 (i.e., no subsampling).
-
-
spatial_dims (``int``):
-
-
The number of spatial dimensions in arrays. Spatial dimensions are
-
assumed to be the last ones and cannot be more than 3 (default).
-
Set this value here to avoid treating channels as spacial
-
dimension. If, for example, your array is indexed as ``(c,y,x)``
-
(2D plus channels), you would want to set ``spatial_dims=2`` to
-
perform the elastic deformation only on x and y.
-
-
use_fast_points_transform (``bool``):
-
-
By solving for all of your points simultaneously with the following
-
3 step proceedure:
-
1) Rasterize nodes into numpy array
-
2) Apply elastic transform to array
-
3) Read out nodes via center of mass of transformed points
-
You can gain substantial speed up as opposed to calculating the
-
elastic transform for each point individually. However this may
-
lead to nodes being lost during the transform.
-
-
recompute_missing_points (``bool``):
-
-
Whether or not to compute the elastic transform node wise for nodes
-
that were lossed during the fast elastic transform process.
-
"""
-
-
def __init__(
-
self,
-
control_point_spacing,
-
jitter_sigma,
-
rotation_interval,
-
prob_slip=0,
-
prob_shift=0,
-
max_misalign=0,
-
subsample=1,
-
spatial_dims=3,
-
use_fast_points_transform=False,
-
recompute_missing_points=True,
-
):
-
-
self.control_point_spacing = control_point_spacing
-
self.jitter_sigma = jitter_sigma
-
self.rotation_start = rotation_interval[0]
-
self.rotation_max_amount = rotation_interval[1] - rotation_interval[0]
-
self.prob_slip = prob_slip
-
self.prob_shift = prob_shift
-
self.max_misalign = max_misalign
-
self.subsample = subsample
-
self.spatial_dims = spatial_dims
-
self.use_fast_points_transform = use_fast_points_transform
-
self.recompute_missing_points = recompute_missing_points
-
-
def prepare(self, request):
-
seed = request.random_seed
-
random.seed(seed)
-
# augment uses numpy for its randomness
-
np.random.seed(seed)
-
-
# get the voxel size
-
self.voxel_size = self.__get_common_voxel_size(request)
-
-
# get the total ROI of all requests
-
total_roi = request.get_total_roi()
-
logger.debug("total ROI is %s" % total_roi)
-
-
# First, get the total ROI of the request in spatial dimensions only.
-
# Channels and time don't matter. This is our master ROI.
-
-
# get master ROI
-
master_roi = Roi(
-
total_roi.get_begin()[-self.spatial_dims :],
-
total_roi.get_shape()[-self.spatial_dims :],
-
)
-
self.spatial_dims = master_roi.dims()
-
logger.debug("master ROI is %s" % master_roi)
-
-
# make sure the master ROI aligns with the voxel size
-
master_roi = master_roi.snap_to_grid(self.voxel_size, mode="grow")
-
logger.debug("master ROI aligned with voxel size is %s" % master_roi)
-
-
# get master roi in voxels
-
master_roi_voxels = master_roi / self.voxel_size
-
logger.debug("master ROI in voxels is %s" % master_roi_voxels)
-
-
# Second, create a master transformation. This is a transformation that
-
# covers all voxels of the all requested ROIs. The master transformation
-
# is zero-based.
-
-
# create a transformation with the size of the master ROI in voxels
-
self.master_transformation = self.__create_transformation(
-
master_roi_voxels.get_shape()
-
)
-
-
# Third, crop out parts of the master transformation for each of the
-
# smaller requested ROIs. Since these ROIs now have to align with the
-
# voxel size (which for points does not have to be the case), we also
-
# remember these smaller ROIs as target_rois in global world units.
-
-
# crop the parts corresponding to the requested ROIs
-
self.transformations = {}
-
self.target_rois = {}
-
deps = BatchRequest()
-
for key, spec in request.items():
-
-
spec = spec.copy()
-
-
if spec.roi is None:
-
continue
-
-
target_roi = Roi(
-
spec.roi.get_begin()[-self.spatial_dims :],
-
spec.roi.get_shape()[-self.spatial_dims :],
-
)
-
logger.debug("downstream request spatial ROI for %s is %s", key, target_roi)
-
-
# make sure the target ROI aligns with the voxel grid (which might
-
# not be the case for points)
-
target_roi = target_roi.snap_to_grid(self.voxel_size, mode="grow")
-
logger.debug(
-
"downstream request spatial ROI aligned with voxel grid for %s "
-
"is %s",
-
key,
-
target_roi,
-
)
-
-
# remember target ROI (this is where the transformation will project
-
# to)
-
self.target_rois[key] = target_roi
-
-
# get ROI in voxels
-
target_roi_voxels = target_roi / self.voxel_size
-
-
# get ROI relative to master ROI
-
target_roi_in_master_roi_voxels = (
-
target_roi_voxels - master_roi_voxels.get_begin()
-
)
-
-
# crop out relevant part of transformation for this request
-
transformation = np.copy(
-
self.master_transformation[
-
(slice(None),) + target_roi_in_master_roi_voxels.get_bounding_box()
-
]
-
)
-
self.transformations[key] = transformation
-
-
# get ROI of all voxels necessary to perfrom transformation
-
#
-
# for that we follow the same transformations to get from the
-
# request ROI to the target ROI in master ROI in voxels, just in
-
# reverse
-
source_roi_in_master_roi_voxels = self.__get_source_roi(transformation)
-
source_roi_voxels = (
-
source_roi_in_master_roi_voxels + master_roi_voxels.get_begin()
-
)
-
source_roi = source_roi_voxels * self.voxel_size
-
-
# transformation is still defined on voxels relative to master ROI
-
# in voxels (i.e., lowest source coordinate could be 5, but data
-
# array we get later starts at 0).
-
#
-
# shift transformation to be indexed relative to beginning of
-
# source_roi_voxels
-
self.__shift_transformation(
-
-source_roi_in_master_roi_voxels.get_begin(), transformation
-
)
-
-
# update upstream request
-
spec.roi = Roi(
-
spec.roi.get_begin()[: -self.spatial_dims]
-
+ source_roi.get_begin()[-self.spatial_dims :],
-
spec.roi.get_shape()[: -self.spatial_dims]
-
+ source_roi.get_shape()[-self.spatial_dims :],
-
)
-
-
deps[key] = spec
-
-
logger.debug("upstream request roi for %s = %s" % (key, spec.roi))
-
-
return deps
-
-
def process(self, batch, request):
-
-
for (array_key, array) in batch.arrays.items():
-
-
if array_key not in self.target_rois:
-
continue
-
-
# for arrays, the target ROI and the requested ROI should be the
-
# same in spatial coordinates
-
assert (
-
self.target_rois[array_key].get_begin()
-
== request[array_key].roi.get_begin()[-self.spatial_dims :]
-
), "Target roi offset {} does not match request roi offset {}".format(
-
self.target_rois[array_key].get_begin(),
-
request[array_key].roi.get_begin()[-self.spatial_dims :],
-
)
-
-
assert (
-
self.target_rois[array_key].get_shape()
-
== request[array_key].roi.get_shape()[-self.spatial_dims :]
-
), "Target roi offset {} does not match request roi offset {}".format(
-
self.target_rois[array_key].get_shape(),
-
request[array_key].roi.get_shape()[-self.spatial_dims :],
-
)
-
-
# reshape array data into (channels,) + spatial dims
-
shape = array.data.shape
-
channel_shape = shape[: -self.spatial_dims]
-
data = array.data.reshape((-1,) + shape[-self.spatial_dims :])
-
-
# apply transformation on each channel
-
data = np.array(
-
[
-
augment.apply_transformation(
-
data[c],
-
self.transformations[array_key],
-
interpolate=self.spec[array_key].interpolatable,
-
)
-
for c in range(data.shape[0])
-
]
-
)
-
-
data_roi = request[array_key].roi/self.spec[array_key].voxel_size
-
array.data = data.reshape(channel_shape + data_roi.get_shape()[-self.spatial_dims:])
-
-
# restore original ROIs
-
array.spec.roi = request[array_key].roi
-
-
for (graph_key, graph) in batch.graphs.items():
-
-
nodes = list(graph.nodes)
-
-
if self.use_fast_points_transform:
-
missed_nodes = self.__fast_point_projection(
-
self.transformations[graph_key],
-
nodes,
-
graph.spec.roi,
-
target_roi=self.target_rois[graph_key],
-
)
-
if not self.recompute_missing_points:
-
for node in set(missed_nodes):
-
graph.remove_node(node, retain_connectivity=True)
-
missed_nodes = []
-
else:
-
missed_nodes = nodes
-
-
for node in missed_nodes:
-
# logger.debug("projecting %s", node.location)
-
-
# get location relative to beginning of upstream ROI
-
location = node.location - graph.spec.roi.get_begin()
-
logger.debug("relative to upstream ROI: %s", location)
-
-
# get spatial coordinates of node in voxels
-
location_voxels = location[-self.spatial_dims :] / self.voxel_size
-
-
# get projected location in transformation data space, this
-
# yields voxel coordinates relative to target ROI
-
projected_voxels = self.__project(
-
self.transformations[graph_key], location_voxels
-
)
-
-
logger.debug(
-
"projected in voxels, relative to target ROI: %s", projected_voxels
-
)
-
-
if projected_voxels is None:
-
logger.debug("node outside of target, skipping")
-
graph.remove_node(node, retain_connectivity=True)
-
continue
-
-
# convert to world units (now in float again)
-
projected = projected_voxels * np.array(self.voxel_size)
-
-
logger.debug(
-
"projected in world units, relative to target ROI: %s",
-
projected)
-
-
# get global coordinates
-
projected += np.array(self.target_rois[graph_key].get_begin())
-
-
# update spatial coordinates of node location
-
node.location[-self.spatial_dims:] = projected
-
-
logger.debug("final location: %s", node.location)
-
-
# finally, it can happen that a node no longer is contained in
-
# the requested ROI (because larger ROIs than necessary have
-
# been requested upstream)
-
if not request[graph_key].roi.contains(node.location):
-
logger.debug("node outside of target, skipping")
-
graph.remove_node(node, retain_connectivity=True)
-
continue
-
-
# restore original ROIs
-
graph.spec.roi = request[graph_key].roi
-
-
def __get_common_voxel_size(self, request):
-
-
voxel_size = None
-
prev = None
-
for array_key in request.array_specs.keys():
-
if voxel_size is None:
-
voxel_size = self.spec[array_key].voxel_size[-self.spatial_dims:]
-
elif self.spec[array_key].voxel_size is not None:
-
assert voxel_size == self.spec[array_key].voxel_size[-self.spatial_dims:], \
-
"ElasticAugment can only be used with arrays of same voxel sizes, " \
-
"but %s has %s, and %s has %s."%(
-
array_key, self.spec[array_key].voxel_size,
-
prev, self.spec[prev].voxel_size)
-
prev = array_key
-
-
if voxel_size is None:
-
raise RuntimeError("voxel size must not be None")
-
-
return voxel_size
-
-
def __create_transformation(self, target_shape):
-
-
transformation = augment.create_identity_transformation(
-
target_shape, subsample=self.subsample
-
)
-
if sum(self.jitter_sigma) > 0:
-
transformation += augment.create_elastic_transformation(
-
target_shape,
-
self.control_point_spacing,
-
self.jitter_sigma,
-
subsample=self.subsample,
-
)
-
rotation = random.random() * self.rotation_max_amount + self.rotation_start
-
if rotation != 0:
-
transformation += augment.create_rotation_transformation(
-
target_shape, rotation, subsample=self.subsample
-
)
-
-
if self.subsample > 1:
-
transformation = augment.upscale_transformation(
-
transformation, target_shape
-
)
-
-
if self.prob_slip + self.prob_shift > 0:
-
self.__misalign(transformation)
-
-
return transformation
-
-
def __fast_point_projection(
-
self, transformation, nodes, source_roi, target_roi
-
):
-
if len(nodes) < 1:
-
return []
-
# rasterize the points into an array
-
ids, locs = zip(
-
*[
-
(
-
node.id,
-
(np.floor(node.location).astype(int) - source_roi.get_begin())
-
// self.voxel_size,
-
)
-
for node in nodes
-
if source_roi.contains(node.location)
-
]
-
)
-
ids, locs = np.array(ids), tuple(zip(*locs))
-
points_array = np.zeros(
-
source_roi.get_shape() / self.voxel_size, dtype=np.int64
-
)
-
points_array[locs] = ids
-
-
# reshape array data into (channels,) + spatial dims
-
shape = points_array.shape
-
data = points_array.reshape((-1,) + shape[-self.spatial_dims :])
-
-
# apply transformation on each channel
-
data = np.array(
-
[
-
augment.apply_transformation(
-
data[c], transformation, interpolate="nearest"
-
)
-
for c in range(data.shape[0])
-
]
-
)
-
-
missing_points = []
-
projected_locs = ndimage.measurements.center_of_mass(data > 0, data, ids)
-
projected_locs = [
-
np.array(loc[-self.spatial_dims :]) * self.voxel_size
-
+ target_roi.get_begin()
-
for loc in projected_locs
-
]
-
node_dict = {node.id:node for node in nodes}
-
for point_id, proj_loc in zip(ids, projected_locs):
-
point = node_dict.pop(point_id)
-
if not any([np.isnan(x) for x in proj_loc]):
-
assert (
-
len(proj_loc) == self.spatial_dims
-
), "projected location has wrong number of dimensions: {}, expected: {}".format(
-
len(proj_loc), self.spatial_dims
-
)
-
point.location[-self.spatial_dims :] = proj_loc
-
else:
-
missing_points.append(point)
-
for node in node_dict.values():
-
missing_points.append(point)
-
logger.debug(
-
"{} of {} points lost in fast points projection".format(
-
len(missing_points), len(ids)
-
)
-
)
-
-
return missing_points
-
-
def __project(self, transformation, location):
-
"""Find the projection of location given by transformation. Returns None
-
if projection lies outside of transformation."""
-
-
dims = len(location)
-
-
# subtract location from transformation
-
diff = transformation.copy()
-
for d in range(dims):
-
diff[d] -= location[d]
-
-
# square
-
diff2 = diff * diff
-
-
# sum
-
dist = diff2.sum(axis=0)
-
-
# find grid point closes to location
-
center_grid = Coordinate(np.unravel_index(dist.argmin(), dist.shape))
-
center_source = self.__source_at(transformation, center_grid)
-
-
logger.debug("projecting %s onto grid", location)
-
logger.debug("grid shape: %s", transformation.shape[1:])
-
logger.debug("grid projection: %s", center_grid)
-
logger.debug("dist shape: %s", dist.shape)
-
logger.debug("dist.argmin(): %s", dist.argmin())
-
logger.debug("dist[argmin]: %s", dist[center_grid])
-
logger.debug("transform[argmin]: %s",
-
transformation[(slice(None),) + center_grid])
-
logger.debug("min dist: %s", dist.min())
-
logger.debug("center source: %s", center_source)
-
-
# inspect grid edges incident to center_grid
-
for d in range(dims):
-
-
# nothing to do for dimensions without spatial extent
-
if transformation.shape[1 + d] == 1:
-
continue
-
-
dim_vector = tuple(1 if dd == d else 0 for dd in range(dims))
-
pos_grid = center_grid + dim_vector
-
neg_grid = center_grid - dim_vector
-
logger.debug("interpolating along %s", dim_vector)
-
-
pos_u = -1
-
neg_u = -1
-
-
if pos_grid[d] < transformation.shape[1 + d]:
-
pos_source = self.__source_at(transformation, pos_grid)
-
logger.debug("pos source: %s", pos_source)
-
pos_dist = pos_source[d] - center_source[d]
-
loc_dist = location[d] - center_source[d]
-
if pos_dist != 0:
-
pos_u = loc_dist / pos_dist
-
else:
-
pos_u = 0
-
-
if neg_grid[d] >= 0:
-
neg_source = self.__source_at(transformation, neg_grid)
-
logger.debug("neg source: %s", neg_source)
-
neg_dist = neg_source[d] - center_source[d]
-
loc_dist = location[d] - center_source[d]
-
if neg_dist != 0:
-
neg_u = loc_dist / neg_dist
-
else:
-
neg_u = 0
-
-
logger.debug("pos u/neg u: %s/%s", pos_u, neg_u)
-
-
# if a point only falls behind edges, it lies outside of the grid
-
if pos_u < 0 and neg_u < 0:
-
return None
-
-
return np.array(center_grid, dtype=np.float32)
-
-
def __source_at(self, transformation, index):
-
"""Read the source point of a transformation at index."""
-
-
slices = (slice(None),) + tuple(slice(i, i + 1) for i in index)
-
return transformation[slices].flatten()
-
-
def __get_source_roi(self, transformation):
-
-
dims = transformation.shape[0]
-
-
# get bounding box of needed data for transformation
-
bb_min = Coordinate(
-
int(math.floor(transformation[d].min())) for d in range(dims)
-
)
-
bb_max = Coordinate(
-
int(math.ceil(transformation[d].max())) + 1 for d in range(dims)
-
)
-
-
# create roi sufficiently large to feed transformation
-
source_roi = Roi(bb_min, bb_max - bb_min)
-
-
return source_roi
-
-
def __shift_transformation(self, shift, transformation):
-
-
for d in range(transformation.shape[0]):
-
transformation[d] += shift[d]
-
-
def __misalign(self, transformation):
-
-
assert (
-
transformation.shape[0] == 3
-
), "misalign can only be applied to 3D volumes"
-
-
num_sections = transformation[0].shape[0]
-
-
shifts = [Coordinate((0, 0, 0))] * num_sections
-
for z in range(num_sections):
-
-
r = random.random()
-
-
if r <= self.prob_slip:
-
-
shifts[z] = self.__random_offset()
-
-
elif r <= self.prob_slip + self.prob_shift:
-
-
offset = self.__random_offset()
-
for zp in range(z, num_sections):
-
shifts[zp] += offset
-
-
logger.debug("misaligning sections with " + str(shifts))
-
-
dims = 3
-
bb_min = tuple(int(math.floor(transformation[d].min())) for d in range(dims))
-
bb_max = tuple(int(math.ceil(transformation[d].max())) + 1 for d in range(dims))
-
logger.debug("min/max of transformation: " + str(bb_min) + "/" + str(bb_max))
-
-
for z in range(num_sections):
-
transformation[1][z, :, :] += shifts[z][1]
-
transformation[2][z, :, :] += shifts[z][2]
-
-
bb_min = tuple(int(math.floor(transformation[d].min())) for d in range(dims))
-
bb_max = tuple(int(math.ceil(transformation[d].max())) + 1 for d in range(dims))
-
logger.debug(
-
"min/max of transformation after misalignment: "
-
+ str(bb_min)
-
+ "/"
-
+ str(bb_max)
-
)
-
-
def __random_offset(self):
-
-
return Coordinate(
-
(0,)
-
+ tuple(
-
self.max_misalign - random.randint(0, 2 * int(self.max_misalign))
-
for d in range(2)
-
)
-
)
-