Skip to content

Commit

Permalink
Bug fix: Check mask threshold on de-dilated mask array
Browse files Browse the repository at this point in the history
  • Loading branch information
grisaitis committed Nov 29, 2016
1 parent f618a99 commit 457d7dc
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions data_io/dataset_reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,11 @@ def get_numpy_dataset(original_dataset, input_slice, output_slice, transform):
if output_slice is not None:
component_erosion_steps = original_dataset.get('component_erosion_steps', 0)
dilation_amount = 1 + component_erosion_steps
dilated_output_slices = tuple([slice(s.start - dilation_amount, s.stop + dilation_amount, s.step) for s in output_slice])
dilated_output_slices = tuple(slice(s.start - dilation_amount, s.stop + dilation_amount, s.step) for s in output_slice)
de_dilation_slices = (Ellipsis,) + tuple(slice(dilation_amount, -dilation_amount) for _ in output_slice)
components, affinities, mask = get_outputs(original_dataset, dilated_output_slices)
mask_threshold = float(original_dataset.get('mask_threshold', 0))
mask_fraction_of_this_batch = np.mean(mask)
mask_fraction_of_this_batch = np.mean(mask[de_dilation_slices])
good_enough = mask_fraction_of_this_batch > mask_threshold
if not good_enough:
return None
Expand All @@ -186,7 +187,6 @@ def get_numpy_dataset(original_dataset, input_slice, output_slice, transform):
affinities = augmented_dilated_dataset["label"]
mask = augmented_dilated_dataset["mask"]
image = augmented_dilated_dataset["data"]
de_dilation_slices = (Ellipsis,) + tuple([slice(dilation_amount, -dilation_amount) for _ in output_slice])
dataset_numpy['components'] = components[de_dilation_slices]
dataset_numpy['label'] = affinities[de_dilation_slices]
dataset_numpy['mask'] = mask[de_dilation_slices]
Expand Down

0 comments on commit 457d7dc

Please sign in to comment.