diff --git a/data_io/dataset_reading.py b/data_io/dataset_reading.py index f41d9e2..49b7212 100644 --- a/data_io/dataset_reading.py +++ b/data_io/dataset_reading.py @@ -165,10 +165,11 @@ def get_numpy_dataset(original_dataset, input_slice, output_slice, transform): if output_slice is not None: component_erosion_steps = original_dataset.get('component_erosion_steps', 0) dilation_amount = 1 + component_erosion_steps - dilated_output_slices = tuple([slice(s.start - dilation_amount, s.stop + dilation_amount, s.step) for s in output_slice]) + dilated_output_slices = tuple(slice(s.start - dilation_amount, s.stop + dilation_amount, s.step) for s in output_slice) + de_dilation_slices = (Ellipsis,) + tuple(slice(dilation_amount, -dilation_amount) for _ in output_slice) components, affinities, mask = get_outputs(original_dataset, dilated_output_slices) mask_threshold = float(original_dataset.get('mask_threshold', 0)) - mask_fraction_of_this_batch = np.mean(mask) + mask_fraction_of_this_batch = np.mean(mask[de_dilation_slices]) good_enough = mask_fraction_of_this_batch > mask_threshold if not good_enough: return None @@ -186,7 +187,6 @@ def get_numpy_dataset(original_dataset, input_slice, output_slice, transform): affinities = augmented_dilated_dataset["label"] mask = augmented_dilated_dataset["mask"] image = augmented_dilated_dataset["data"] - de_dilation_slices = (Ellipsis,) + tuple([slice(dilation_amount, -dilation_amount) for _ in output_slice]) dataset_numpy['components'] = components[de_dilation_slices] dataset_numpy['label'] = affinities[de_dilation_slices] dataset_numpy['mask'] = mask[de_dilation_slices]