Skip to content

Commit

Permalink
Fix: pacify flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
Julien Marabotto authored and Julien Marabotto committed Jul 15, 2024
1 parent 6a1eb59 commit 2d54dc1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 31 deletions.
24 changes: 12 additions & 12 deletions nitransforms/tests/test_vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
def test_read_path(data_path):
"Check that filepaths are a supported method for loading "
"and reading transforms with PlotDenseField"
PlotDenseField(transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz")
PlotDenseField(transform=data_path/ "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz")


def test_slice_values():
"""Check that ValueError is issued if negative slices are provided"""
with pytest.raises(ValueError):
PlotDenseField(
transform = np.zeros((10, 10, 10, 3)),
transform=np.zeros((10, 10, 10, 3)),
reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None),
).test_slices(
xslice=-1,
Expand All @@ -43,38 +43,38 @@ def test_slice_values():

def test_show_transform(data_path, output_path):
PlotDenseField(
transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
transform=data_path/ "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).show_transform(
xslice=50,
yslice=50,
zslice=50,
)
if output_path is not None:
plt.savefig(output_path/"show_transform.svg", bbox_inches="tight")
plt.savefig(output_path/ "show_transform.svg", bbox_inches="tight")
else:
plt.show()


def test_plot_distortion(data_path, output_path):
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
PlotDenseField(
transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
transform=data_path/ "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).plot_distortion(
axes=axes,
xslice=50,
yslice=50,
zslice=50,
)
if output_path is not None:
plt.savefig(output_path/"plot_distortion.svg", bbox_inches="tight")
plt.savefig(output_path/ "plot_distortion.svg", bbox_inches="tight")
else:
plt.show()


def test_plot_quiverdsm(data_path, output_path):
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
PlotDenseField(
transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
transform=data_path/ "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).plot_quiverdsm(
axes=axes,
xslice=50,
Expand All @@ -83,7 +83,7 @@ def test_plot_quiverdsm(data_path, output_path):
)

if output_path is not None:
plt.savefig(output_path/"plot_quiverdsm.svg", bbox_inches="tight")
plt.savefig(output_path/ "plot_quiverdsm.svg", bbox_inches="tight")
else:
plt.show()

Expand All @@ -93,7 +93,7 @@ def test_3dquiver(data_path, output_path):
fig = plt.figure()
axes = fig.add_subplot(projection='3d')
PlotDenseField(
transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz",
transform=data_path/ "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz",
).plot_quiverdsm(
axes=axes,
xslice=None,
Expand All @@ -103,15 +103,15 @@ def test_3dquiver(data_path, output_path):
)

if output_path is not None:
plt.savefig(output_path/"plot_3dquiver.svg", bbox_inches="tight")
plt.savefig(output_path/ "plot_3dquiver.svg", bbox_inches="tight")
else:
plt.show()


def test_plot_jacobian(data_path, output_path):
fig, axes = plt.subplots(1, 3, figsize=(12, 5))
PlotDenseField(
transform = data_path/"ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
transform=data_path/ "ds-005_sub-01_from-OASIS_to-T1_warp_fsl.nii.gz"
).plot_jacobian(
axes=axes,
xslice=50,
Expand All @@ -120,7 +120,7 @@ def test_plot_jacobian(data_path, output_path):
)

if output_path is not None:
plt.savefig(output_path/"plot_jacobian.svg", bbox_inches="tight")
plt.savefig(output_path/ "plot_jacobian.svg", bbox_inches="tight")
else:
plt.show()

Expand Down
36 changes: 17 additions & 19 deletions nitransforms/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class PlotDenseField():
"""
Vizualisation of a transformation file using nitransform's DenseFielTransform module.
Vizualisation of a transformation file using nitransform's DenseFielTransform module.
Generates four sorts of plots:
i) deformed grid superimposed on the normalised deformation field density map\n
iii) quiver map of the field coloured by its diffusion scalar map\n
Expand Down Expand Up @@ -49,15 +49,14 @@ def __init__(self, transform, is_deltas=True, reference=None):
"""if field provided by numpy array (eg tests)"""
deltas = []
for i in range(self._xfm.ndim):
deltas.append((np.max(self._xfm._field[i]) - np.min(self._xfm._field[i]))
deltas.append((np.max(self._xfm._field[i]) - np.min(self._xfm._field[i]))
/ len(self._xfm._field[i]))

assert np.all(deltas == deltas[0])
assert np.mean(deltas) == deltas[0]

deltas.append(0)
self._voxel_size = deltas

self._voxel_size = deltas

def show_transform(
self,
Expand All @@ -69,7 +68,7 @@ def show_transform(
show_grid=True,
lw=0.1,
save_to_path=None,
):
):
"""
Plot output field from DenseFieldTransform class.
Expand Down Expand Up @@ -171,7 +170,7 @@ def plot_distortion(
show_grid=True,
lw=0.1,
show_titles=True,
):
):
"""
Plot the distortion grid.
Expand Down Expand Up @@ -251,7 +250,7 @@ def plot_distortion(

if show_brain:
axes[index].scatter(dim1, dim2, c=c, cmap='RdPu')

if show_titles:
axes[index].set_title(titles[index], fontsize=14, weight='bold')

Expand All @@ -264,7 +263,7 @@ def plot_quiverdsm(
scaling=1,
three_D=False,
show_titles=True,
):
):
"""
Plot the Diffusion Scalar Map (dsm) as a quiver plot.
Expand Down Expand Up @@ -447,7 +446,6 @@ def plot_jacobian(self, axes, xslice, yslice, zslice, show_titles=True):
ax.set_title(titles[index], fontsize=14, weight='bold')
plt.colorbar(plot, location='bottom', orientation='horizontal', label=str(r'$J$'))


def test_slices(self, xslice, yslice, zslice):
"""Ensure slices are positive and within range of image dimensions"""
xfm = self._xfm._field
Expand Down Expand Up @@ -511,10 +509,10 @@ def get_jacobian(self):
if idx < 3:
dim = zeros[-1,:,:][None,:,:]
ax = 0
elif idx >=3 and idx < 6:
elif idx >= 3 and idx < 6:
dim = zeros[:,-1,:][:,None,:]
ax = 1
elif idx >=6:
elif idx >= 6:
dim = zeros[:,:,-1][:,:,None]
ax = 2

Expand Down Expand Up @@ -554,7 +552,7 @@ def get_planes(self, xslice, yslice, zslice):
s = [slice(None), slice(None), zslice, None] if ind == 2 else s
# For 3d quiver:
if xslice == yslice == zslice is None:
s = [slice(None), slice(None), slice(None), None]
s = [slice(None), slice(None), slice(None), None]

"""computing coordinates within each plane"""
x = x[s[0], s[1], s[2]]
Expand Down Expand Up @@ -586,7 +584,7 @@ def get_planes(self, xslice, yslice, zslice):
def sliders(self, fig, xslice, yslice, zslice):
# This successfully generates a slider, but it cannot be used.
# Currently, slider only acts as a label to show slice values.
# raise NotImplementedError("Slider implementation not finalised.
# raise NotImplementedError("Slider implementation not finalised.
# Static slider can be generated but is not interactive")

xslice, yslice, zslice = self.test_slices(xslice, yslice, zslice)
Expand Down Expand Up @@ -636,9 +634,9 @@ def get_2dcenters(x, y, step=2):


def format_fig(figsize, gs_rows, gs_cols, **kwargs):
params={
params = {
'gs_wspace' : 0,
'gs_hspace': 1 / 8,
'gs_hspace' : 1 / 8,
'suptitle': None,
}
params.update(kwargs)
Expand All @@ -657,15 +655,15 @@ def format_fig(figsize, gs_rows, gs_cols, **kwargs):
hspace=params['gs_hspace']
)

axes=[]
axes = []
for j in range(0, gs_cols):
for i in range(0, gs_rows):
axes.append(fig.add_subplot(gs[i,j]))
return fig, axes


def format_axes(axis, **kwargs):
params={
params = {
'title':None,
'xlabel':"x",
'ylabel':"y",
Expand All @@ -676,7 +674,7 @@ def format_axes(axis, **kwargs):
'rotate_3dlabel':False,
'labelsize':16,
'ticksize':14,
}
}
params.update(kwargs)

'''Format the figure axes. For 2D plots, zlabel and zticks parameters are None.'''
Expand Down

0 comments on commit 2d54dc1

Please sign in to comment.