diff --git a/nitransforms/tests/test_vis.py b/nitransforms/tests/test_vis.py index 7240455f..227378d5 100644 --- a/nitransforms/tests/test_vis.py +++ b/nitransforms/tests/test_vis.py @@ -8,8 +8,9 @@ def test_read_path(data_path): - """Check that filepaths are a supported method for loading and reading transforms with PlotDenseField""" - PlotDenseField(transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz") + "Check that filepaths are a supported method for loading " + "and reading transforms with PlotDenseField" + PlotDenseField(transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz") def test_slice_values(): @@ -24,7 +25,7 @@ def test_slice_values(): zslice=-1, ) - """Check that IndexError is issued if provided slices are beyond range of transform dimensions""" + "Check that IndexError is issued if provided slices are beyond range of transform dimensions" with pytest.raises(IndexError): xfm = DenseFieldTransform( field=np.zeros((10, 10, 10, 3)), @@ -34,22 +35,22 @@ def test_slice_values(): transform=xfm._field, reference=xfm._reference, ).test_slices( - xslice=xfm._field.shape[0]+1, - yslice=xfm._field.shape[1]+1, - zslice=xfm._field.shape[2]+1, + xslice=xfm._field.shape[0] + 1, + yslice=xfm._field.shape[1] + 1, + zslice=xfm._field.shape[2] + 1, ) def test_show_transform(data_path, output_path): PlotDenseField( - transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" ).show_transform( xslice=50, yslice=50, zslice=50, ) if output_path is not None: - plt.savefig(output_path / "show_transform.svg", bbox_inches="tight") + plt.savefig(output_path/"show_transform.svg", bbox_inches="tight") else: plt.show() @@ -57,7 +58,7 @@ def test_show_transform(data_path, output_path): def test_plot_distortion(data_path, output_path): fig, axes = plt.subplots(1, 3, figsize=(12, 4)) PlotDenseField( - transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" ).plot_distortion( axes=axes, xslice=50, @@ -65,7 +66,7 @@ def test_plot_distortion(data_path, output_path): zslice=50, ) if output_path is not None: - plt.savefig(output_path / "plot_distortion.svg", bbox_inches="tight") + plt.savefig(output_path/"plot_distortion.svg", bbox_inches="tight") else: plt.show() @@ -73,7 +74,7 @@ def test_plot_distortion(data_path, output_path): def test_plot_quiverdsm(data_path, output_path): fig, axes = plt.subplots(1, 3, figsize=(12, 4)) PlotDenseField( - transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" ).plot_quiverdsm( axes=axes, xslice=50, @@ -82,7 +83,7 @@ def test_plot_quiverdsm(data_path, output_path): ) if output_path is not None: - plt.savefig(output_path / "plot_quiverdsm.svg", bbox_inches="tight") + plt.savefig(output_path/"plot_quiverdsm.svg", bbox_inches="tight") else: plt.show() @@ -92,7 +93,7 @@ def test_3dquiver(data_path, output_path): fig = plt.figure() axes = fig.add_subplot(projection='3d') PlotDenseField( - transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz", + transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz", ).plot_quiverdsm( axes=axes, xslice=None, @@ -102,15 +103,15 @@ def test_3dquiver(data_path, output_path): ) if output_path is not None: - plt.savefig(output_path / "plot_3dquiver.svg", bbox_inches="tight") + plt.savefig(output_path/"plot_3dquiver.svg", bbox_inches="tight") else: plt.show() - + def test_plot_jacobian(data_path, output_path): fig, axes = plt.subplots(1, 3, figsize=(12, 5)) PlotDenseField( - transform = data_path / "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" + transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz" ).plot_jacobian( axes=axes, xslice=50, @@ -119,9 +120,10 @@ def test_plot_jacobian(data_path, output_path): ) if output_path is not None: - plt.savefig(output_path / "plot_jacobian.svg", bbox_inches="tight") + plt.savefig(output_path/"plot_jacobian.svg", bbox_inches="tight") else: plt.show() + if __name__ == "__main__": pytest.main([__file__]) diff --git a/nitransforms/vis.py b/nitransforms/vis.py index 0712ee01..2935959b 100644 --- a/nitransforms/vis.py +++ b/nitransforms/vis.py @@ -11,12 +11,14 @@ from nitransforms.nonlinear import DenseFieldTransform + class PlotDenseField(): """ - Vizualisation of a transformation file using nitransform's DenseFielTransform module. Generates four sorts of plots: - i) the deformed grid superimposed on the normalised deformation field density map\n - iii) the quiver map of the field, coloured according to its diffusion scalar map\n - iv) the quiver map of the field, coloured according to the jacobian of the coordinate matrices\n + Vizualisation of a transformation file using nitransform's DenseFielTransform module. + Generates four sorts of plots: + i) deformed grid superimposed on the normalised deformation field density map\n + iii) quiver map of the field coloured by its diffusion scalar map\n + iv) quiver map of the field coloured by the jacobian of the coordinate matrices\n for 3 image projections: i) axial (fixed z slice)\n ii) saggital (fixed y slice)\n @@ -43,20 +45,31 @@ def __init__(self, transform, is_deltas=True, reference=None): try: """if field provided by path""" self._voxel_size = nb.load(transform).header.get_zooms() - except: + except TypeError: """if field provided by numpy array (eg tests)""" deltas = [] for i in range(self._xfm.ndim): - deltas.append((np.max(self._xfm._field[i]) - np.min(self._xfm._field[i])) / len(self._xfm._field[i])) + deltas.append((np.max(self._xfm._field[i]) - np.min(self._xfm._field[i])) + / len(self._xfm._field[i])) assert np.all(deltas == deltas[0]) assert np.mean(deltas) == deltas[0] - + deltas.append(0) self._voxel_size = deltas - def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, show_grid=True, lw=0.1, save_to_path=None): + def show_transform( + self, + xslice, + yslice, + zslice, + scaling=1, + show_brain=True, + show_grid=True, + lw=0.1, + save_to_path=None, + ): """ Plot output field from DenseFieldTransform class. @@ -104,7 +117,7 @@ def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, sho ) fig.subplots_adjust(bottom=0.15) - projections=["Axial", "Coronal", "Sagittal"] + projections = ["Axial", "Coronal", "Sagittal"] for i, ax in enumerate(axes): if i < 3: xlabel = None @@ -113,28 +126,60 @@ def show_transform(self, xslice, yslice, zslice, scaling=1, show_brain=True, sho xlabel = ylabel = None format_axes(ax, xlabel=xlabel, ylabel=ylabel, labelsize=16) - self.plot_distortion((axes[2], axes[1], axes[0]), xslice, yslice, zslice, show_grid=show_grid, show_brain=show_brain, lw=lw, show_titles=False) - self.plot_quiverdsm((axes[5], axes[4], axes[3]), xslice, yslice, zslice, scaling=scaling, show_titles=False) - self.plot_jacobian((axes[8], axes[7], axes[6]), xslice, yslice, zslice, show_titles=False) + self.plot_distortion( + (axes[2], axes[1], axes[0]), + xslice, + yslice, + zslice, + show_grid=show_grid, + show_brain=show_brain, + lw=lw, + show_titles=False, + ) + self.plot_quiverdsm( + (axes[5], axes[4], axes[3]), + xslice, + yslice, + zslice, + scaling=scaling, + show_titles=False, + ) + self.plot_jacobian( + (axes[8],axes[7], axes[6]), + xslice, + yslice, + zslice, + show_titles=False, + ) - sliders = self.sliders(fig, xslice, yslice, zslice) - #NotImplemented: Interactive slider update here: + # sliders = self.sliders(fig, xslice, yslice, zslice) + # NotImplemented: Interactive slider update here: if save_to_path is not None: assert os.path.isdir(os.path.dirname(save_to_path)) plt.savefig(str(save_to_path), dpi=300) else: pass - - - def plot_distortion(self, axes, xslice, yslice, zslice, show_brain=True, show_grid=True, lw=0.1, show_titles=True): + + def plot_distortion( + self, + axes, + xslice, + yslice, + zslice, + show_brain=True, + show_grid=True, + lw=0.1, + show_titles=True, + ): """ - Plot the distortion grid. + Plot the distortion grid. Parameters ---------- axis :obj:`tuple` - Axes on which the grid should be plotted. Requires 3 axes to illustrate all projections (eg ax1: Axial, ax2: coronal, ax3: Sagittal) + Axes on which the grid should be plotted. Requires 3 axes to illustrate + all projections (eg ax1: Axial, ax2: coronal, ax3: Sagittal) xslice: :obj:`int` x plane to select for axial projection of the transform. yslice: :obj:`int` @@ -142,7 +187,7 @@ def plot_distortion(self, axes, xslice, yslice, zslice, show_brain=True, show_gr zslice: :obj:`int` z plane to select for sagittal prjection of the transform. show_brain: :obj:`bool` - Whether the normalised density map of the distortions should be plotted (Default: True). + Whether the normalised density map of the distortion should be plotted (Default: True). show_grid: :obj:`bool` Whether the distorted grid lines should be plotted (Default: True). lw: :obj:`float` @@ -183,32 +228,51 @@ def plot_distortion(self, axes, xslice, yslice, zslice, show_brain=True, show_gr len1, len2 = shape[0], shape[1] c = np.sqrt(vec1**2 + vec2**2) - c = c/c.max() + c = c / c.max() - if show_grid==True: - x_moved = dim1+vec1 - y_moved = dim2+vec2 + if show_grid: + x_moved = dim1 + vec1 + y_moved = dim2 + vec2 for idx in range(0, len1, 1): - axes[index].plot(x_moved[idx*len2:(idx+1)*len2], y_moved[idx*len2:(idx+1)*len2], c='k', lw=lw) + axes[index].plot( + x_moved[idx * len2:(idx + 1) * len2], + y_moved[idx * len2:(idx + 1) * len2], + c='k', + lw=lw, + ) for idx in range(0, len2, 1): - axes[index].plot(x_moved[idx::len2], y_moved[idx::len2], c='k', lw=lw) - - if show_brain==True: + axes[index].plot( + x_moved[idx::len2], + y_moved[idx::len2], + c='k', + lw=lw, + ) + + if show_brain: axes[index].scatter(dim1, dim2, c=c, cmap='RdPu') - if show_titles==True: + if show_titles: axes[index].set_title(titles[index], fontsize=14, weight='bold') - - def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False, show_titles=True): + def plot_quiverdsm( + self, + axes, + xslice, + yslice, + zslice, + scaling=1, + three_D=False, + show_titles=True, + ): """ Plot the Diffusion Scalar Map (dsm) as a quiver plot. Parameters ---------- axis :obj:`tuple` - Axes on which the quiver should be plotted. Requires 3 axes to illustrate the dsm mapped as a quiver plot for each projection. + Axes on which the quiver should be plotted. Requires 3 axes to illustrate + the dsm mapped as a quiver plot for each projection. xslice: :obj:`int` x plane to select for axial projection of the transform. yslice: :obj:`int` @@ -219,7 +283,7 @@ def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False, Fraction by which the quiver plot arrows are to be scaled (default: 1). three_D: :obj:`bool` Whether the quiver plot is to be projected onto a 3D axis (default: False) - + Example: fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True) PlotDenseField( @@ -246,7 +310,6 @@ def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False, scaling=10, three_D=True, ) - format_axes(ax) #, xticks=[-250, -200, -150, -100, -50, 0], yticks=[-200, -150, -100, -50, 0], zticks=[-200, -150, -100, -50, 0] plt.show() """ xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) @@ -254,21 +317,21 @@ def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False, if three_D is not False: raise NotImplementedError("3d Quiver plot not finalised.") - - #finalise 3d quiver below: + + # finalise 3d quiver below: for i, j in enumerate(planes): x, y, z, u, v, w = j magnitude = np.sqrt(u**2 + v**2 + w**2) - clr3d = plt.cm.viridis(magnitude/magnitude.max()) - xyz = axes.quiver(x, y, z, u, v, w, colors=clr3d, length=1/scaling) + clr3d = plt.cm.viridis(magnitude / magnitude.max()) + xyz = axes.quiver(x, y, z, u, v, w, colors=clr3d, length=1 / scaling) plt.colorbar(xyz) else: for index, plane in enumerate(planes): x, y, z, u, v, w = plane c_reds, c_greens, c_blues, zeros = [], [], [], [] - ##Optimise here, matrix operations + # Optimise here, matrix operations for ind, (i, j, k, l, m, n) in enumerate(zip(x, y, z, u, v, w)): if np.abs(u[ind]) > [np.abs(v[ind]) and np.abs(w[ind])]: c_reds.append((i, j, k, l, m, n, u[ind])) @@ -292,22 +355,43 @@ def plot_quiverdsm(self, axes, xslice, yslice, zslice, scaling=1, three_D=False, elif index == 2: dim1, dim2, vec1, vec2 = 0, 1, 3, 4 - axes[index].quiver(c_reds[:, dim1], c_reds[:, dim2], c_reds[:, vec1], c_reds[:, vec2], c_reds[:, -1], cmap='Reds') - axes[index].quiver(c_greens[:, dim1], c_greens[:, dim2], c_greens[:, vec1], c_greens[:, vec2], c_greens[:, -1], cmap='Greens') - axes[index].quiver(c_blues[:, dim1], c_blues[:, dim2], c_blues[:, vec1], c_blues[:, vec2], c_blues[:, -1], cmap='Blues') + axes[index].quiver( + c_reds[:, dim1], + c_reds[:, dim2], + c_reds[:, vec1], + c_reds[:, vec2], + c_reds[:, -1], + cmap='Reds', + ) + axes[index].quiver( + c_greens[:, dim1], + c_greens[:, dim2], + c_greens[:, vec1], + c_greens[:, vec2], + c_greens[:, -1], + cmap='Greens', + ) + axes[index].quiver( + c_blues[:, dim1], + c_blues[:, dim2], + c_blues[:, vec1], + c_blues[:, vec2], + c_blues[:, -1], + cmap='Blues', + ) - if show_titles==True: + if show_titles: axes[index].set_title(titles[index], fontsize=14, weight='bold') - def plot_jacobian(self, axes, xslice, yslice, zslice, show_titles=True): """ Map the divergence of the transformation field using a quiver plot. - + Parameters ---------- axis :obj:`tuple` - Axes on which the quiver should be plotted. Requires 3 axes to illustrate each projection (eg ax1: Axial, ax2: coronal, ax3: Sagittal) + Axes on which the quiver should be plotted. Requires 3 axes to illustrate + each projection (eg ax1: Axial, ax2: coronal, ax3: Sagittal) xslice: :obj:`int` x plane to select for axial projection of the transform. yslice: :obj:`int` @@ -338,7 +422,7 @@ def plot_jacobian(self, axes, xslice, yslice, zslice, show_titles=True): ] jacobians = np.zeros((3), dtype=np.ndarray) - #iterating through the three chosen planes to calculate corresponding coordinates + # iterating through the three chosen planes to calculate corresponding coordinates for ind, s in enumerate(slices): s = [xslice, slice(None), slice(None), None] if ind == 0 else s s = [slice(None), yslice, slice(None), None] if ind == 1 else s @@ -359,7 +443,7 @@ def plot_jacobian(self, axes, xslice, yslice, zslice, show_titles=True): c = jacobians[index] plot = ax.scatter(dim1, dim2, c=c, norm=mpl.colors.CenteredNorm(), cmap='seismic') - if show_titles==True: + if show_titles: ax.set_title(titles[index], fontsize=14, weight='bold') plt.colorbar(plot, location='bottom', orientation='horizontal', label=str(r'$J$')) @@ -373,31 +457,32 @@ def test_slices(self, xslice, yslice, zslice): raise ValueError("Slice values must be positive integers") if int(xslice) > xfm.shape[0]: - raise IndexError(f"x-slice {xslice} out of range for transform object with x-dimension of length {xfm.shape[0]}") + raise IndexError(f"x-slice {xslice} out of range for transform object " + f"with x-dimension of length {xfm.shape[0]}") if int(yslice) > xfm.shape[1]: - raise IndexError(f"y-slice {yslice} out of range for transform object with y-dimension of length {xfm.shape[1]}") + raise IndexError(f"y-slice {yslice} out of range for transform object " + f"with y-dimension of length {xfm.shape[1]}") if int(zslice) > xfm.shape[2]: - raise IndexError(f"z-slice {zslice} out of range for transform object with z-dimension of length {xfm.shape[2]}") + raise IndexError(f"z-slice {zslice} out of range for transform object " + f"with z-dimension of length {xfm.shape[2]}") - return(int(xslice), int(yslice), int(zslice)) + return (int(xslice), int(yslice), int(zslice)) except TypeError as e: """exception for case of 3d quiver plot""" assert str(e) == "'<' not supported between instances of 'NoneType' and 'int'" - return(xslice, yslice, zslice) - + return (xslice, yslice, zslice) def get_coords(self): - """Calculate vector components of the field using the reference coordinates""" - x = self._xfm.reference.ndcoords[0].reshape(np.shape(self._xfm._field[...,-1])) - y = self._xfm.reference.ndcoords[1].reshape(np.shape(self._xfm._field[...,-1])) - z = self._xfm.reference.ndcoords[2].reshape(np.shape(self._xfm._field[...,-1])) - u = self._xfm._field[..., 0] - x - v = self._xfm._field[..., 1] - y - w = self._xfm._field[..., 2] - z - return x, y, z, u, v, w - + """Calculate vector components of the field using the reference coordinates""" + x = self._xfm.reference.ndcoords[0].reshape(np.shape(self._xfm._field[...,-1])) + y = self._xfm.reference.ndcoords[1].reshape(np.shape(self._xfm._field[...,-1])) + z = self._xfm.reference.ndcoords[2].reshape(np.shape(self._xfm._field[...,-1])) + u = self._xfm._field[..., 0] - x + v = self._xfm._field[..., 1] - y + w = self._xfm._field[..., 2] - z + return x, y, z, u, v, w def get_jacobian(self): """Calculate the Jacobian matrix of the field""" @@ -425,13 +510,13 @@ def get_jacobian(self): for idx, j in enumerate(partials): if idx < 3: dim = zeros[-1,:,:][None,:,:] - ax=0 + ax = 0 elif idx >=3 and idx < 6: dim = zeros[:,-1,:][:,None,:] - ax=1 + ax = 1 elif idx >=6: dim = zeros[:,:,-1][:,:,None] - ax=2 + ax = 2 partials[idx] = np.append(j, dim, axis=ax).flatten() @@ -449,16 +534,15 @@ def get_jacobian(self): ) return jacobians - def get_planes(self, xslice, yslice, zslice): """Define slice selection for visualisation""" xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) titles = ["Sagittal", "Coronal", "Axial"] - planes = [0]*3 + planes = [0] * 3 slices = [ [False, False, False, False], # [:,:,index] [False, False, False, False], # [:,index,:] - [False, False, False, False], # [index,:,:] + [False, False, False, False], # [index,:,:] ] for ind, s in enumerate(slices): @@ -468,8 +552,9 @@ def get_planes(self, xslice, yslice, zslice): s = [xslice, slice(None), slice(None), None] if ind == 0 else s s = [slice(None), yslice, slice(None), None] if ind == 1 else s s = [slice(None), slice(None), zslice, None] if ind == 2 else s - #For 3d quiver: - s = [slice(None), slice(None), slice(None), None] if xslice == yslice == zslice == None else s + # For 3d quiver: + if xslice == yslice == zslice is None: + s = [slice(None), slice(None), slice(None), None] """computing coordinates within each plane""" x = x[s[0], s[1], s[2]] @@ -487,34 +572,34 @@ def get_planes(self, xslice, yslice, zslice): w = w.flatten() """check indexing has retrieved correct dimensions""" - if ind==0 and xslice!=None: + if ind == 0 and xslice is not None: assert x.shape == u.shape == np.shape(self._xfm._field[-1,...,-1].flatten()) - elif ind==1 and yslice!=None: + elif ind == 1 and yslice is not None: assert y.shape == v.shape == np.shape(self._xfm._field[:,-1,:,-1].flatten()) - elif ind==2 and zslice!=None: + elif ind == 2 and zslice is not None: assert z.shape == w.shape == np.shape(self._xfm._field[...,-1,-1].flatten()) """store 3 slices of datapoints, with overall shape [3 x [6 x [data]]]""" planes[ind] = [x, y, z, u, v, w] return planes, titles - def sliders(self, fig, xslice, yslice, zslice): - #This successfully generates a slider, but it cannot be used. - #Currently, slider only acts as a label to show slice values. - #raise NotImplementedError("Slider implementation not finalised. Static slider can be generated but is not interactive") + # This successfully generates a slider, but it cannot be used. + # Currently, slider only acts as a label to show slice values. + # raise NotImplementedError("Slider implementation not finalised. + # Static slider can be generated but is not interactive") xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice) slices = [ [zslice, len(self._xfm._field[0][0]), "zslice"], [yslice, len(self._xfm._field[0]), "yslice"], [xslice, len(self._xfm._field), "xslice"], - ] + ] axes = [ - [1/7, 0.1, 1/7, 0.025], - [3/7, 0.1, 1/7, 0.025], - [5/7, 0.1, 1/7, 0.025], - ] + [1 / 7, 0.1, 1 / 7, 0.025], + [3 / 7, 0.1, 1 / 7, 0.025], + [5 / 7, 0.1, 1 / 7, 0.025], + ] sliders = [] for index, slider_axis in enumerate(axes): @@ -534,28 +619,28 @@ def sliders(self, fig, xslice, yslice, zslice): return sliders - def update_sliders(self, slider): raise NotImplementedError("Interactive sliders not implemented.") - + new_slider = slider.val return new_slider def get_2dcenters(x, y, step=2): - samples_x = np.arange(x.min(), x.max(), step=step).astype(int) - samples_y = np.arange(y.min(), y.max(), step=step).astype(int) + samples_x = np.arange(x.min(), x.max(), step=step).astype(int) + samples_y = np.arange(y.min(), y.max(), step=step).astype(int) - lenx = len(samples_x) - leny = len(samples_y) - return zip(*product(samples_x, samples_y)), lenx, leny + lenx = len(samples_x) + leny = len(samples_y) + return zip(*product(samples_x, samples_y)), lenx, leny def format_fig(figsize, gs_rows, gs_cols, **kwargs): - params={'gs_wspace':0, - 'gs_hspace':1/8, - 'suptitle':None, - } + params={ + 'gs_wspace' : 0, + 'gs_hspace': 1 / 8, + 'suptitle': None, + } params.update(kwargs) fig = plt.figure(figsize=figsize) @@ -608,6 +693,6 @@ def format_axes(axis, **kwargs): axis.xaxis.set_rotate_label(params['rotate_3dlabel']) axis.yaxis.set_rotate_label(params['rotate_3dlabel']) axis.zaxis.set_rotate_label(params['rotate_3dlabel']) - except: + except AttributeError: pass return