Added support for mrtrix and dipy reconstructed data
clintg6 committed Apr 7, 2018
1 parent 946b64e commit 9c31b3f
Showing 1 changed file with 100 additions and 67 deletions.
mittens/_mittens.py: 167 changes (100 additions, 67 deletions)
@@ -39,7 +39,7 @@
neighborsY = [o[1] for o in opposites] + [o[0] for o in opposites]

class MITTENS(Spatial):
def __init__(self, fibgz_file="", nifti_prefix="",
def __init__(self, odf_file="", odf_array = "", nifti_prefix="",
real_affine_image="", mask_image="",
step_size=np.sqrt(3)/2. , angle_max=35, odf_resolution="odf8",
angle_weights="flat", angle_weighting_power=1.,normalize_doubleODF=True):
@@ -50,8 +50,11 @@ def __init__(self, fibgz_file="", nifti_prefix="",
Parameters:
===========
fibgz_file:str
Path to a dsi studio fib.gz file
odf_file:str
Path to an MRtrix amplitudes nii.gz file or a DSI Studio fib.gz file
odf_array:numpy array
Array of ODF data sampled on a symmetric sphere; the default is 321 directions (odf8).
Requires a real_affine_image path. Assumes RAS orientation.
nifti_prefix:str
Prefix used when calculating singleODF and/or doubleODF transition
probabilities.
@@ -62,8 +65,7 @@ def __init__(self, fibgz_file="", nifti_prefix="",
default affine from DSI Studio will be used.
mask_image:str
Path to a NIfTI file that has nonzero values in voxels that will be used
as nodes in the graph. If none is provided, the default mask estimated by
DSI Studio is used.
as nodes in the graph. If none is provided, the default mask estimated from the ODFs is used.
step_size:float
Step size in voxel units. Used for calculating transition probabilities
angle_max:float
@@ -86,8 +88,8 @@ def __init__(self, fibgz_file="", nifti_prefix="",
If you're unable to initialize a MITTENS object with your desired combination,
try downloading or generating/compiling the necessary Fortran modules.
"""
if fibgz_file == nifti_prefix == "":
raise ValueError("Must provide either a DSI Studio fib file or prefix to "
if odf_file == "" and nifti_prefix == "" and type(odf_array) == str:
raise ValueError("Must provide either a DSI Studio fib file, nifti ODF file, ODF array, or prefix to "
"NIfTI1 images written out by a previous run")
# These will get filled out from loading a fibgz or niftis
self.flat_mask = None
@@ -116,16 +118,20 @@ def __init__(self, fibgz_file="", nifti_prefix="",
get_transition_analysis_matrices(self.odf_resolution, self.angle_max,
self.angle_weights, self.angle_weighting_power)
self.n_unique_vertices = self.odf_vertices.shape[0]//2
if fibgz_file:
logger.info("Loading DSI Studio fib file")
self._load_fibgz(fibgz_file)
self._initialize_nulls()
self._set_real_affine(real_affine_image)

if odf_file:
logger.info("Loading ODF file")
self._load_odf(odf_file)
if isinstance(odf_array, np.ndarray):
if not real_affine_image: raise ValueError("Must specify the path to a NIfTI volume containing the affine")
logger.info("Loading ODF array")
self._load_odf_array(odf_array)
if nifti_prefix:
logger.info("Loading output from pre-existing NIfTIs")
self._load_niftis(nifti_prefix)

self._initialize_nulls()
self._set_real_affine(real_affine_image)

def _initialize_nulls(self):

# Note, only entries for unique vertices are created, but they
@@ -156,75 +162,102 @@ def _initialize_nulls(self):
if self.normalize_doubleODF:
self.doubleODF_null_probs = self.doubleODF_null_probs / self.doubleODF_null_probs.sum()

def _load_fibgz(self, path):
logger.info("Loading %s", path)
f = load_fibgz(path)
logger.info("Loaded %s", path)
self.orientation = "lps"
# Check that this fib file matches what we expect
fib_odf_vertices = f['odf_vertices'].T
matches = np.allclose(self.odf_vertices, fib_odf_vertices)
if not matches:
logger.critical("ODF Angles in fib file do not match %s", self.odf_resolution)
return

# Extract the spacing info from the fib file
self.volume_grid = f['dimension'].squeeze()
def _load_odf_array(self, odf_array):
self.volume_grid = odf_array.shape[:3]
aff = np.ones(4,dtype=np.float)
aff[:3] = f['voxel_size'].squeeze()
# DSI Studio stores data in LPS+
#aff = aff * np.array([-1,-1,1,1])
aff[:3] = self.real_affine[0][0]
self.ras_affine = np.diag(aff)
self.voxel_size = aff[:3]
numSamples = odf_array.shape[-1]//2
odf_array = odf_array[::-1,::-1,:,:numSamples]
odf_array = odf_array.reshape(np.prod(odf_array.shape[:3]),odf_array.shape[-1], order="F")
odf_array[odf_array < 0] = 0
odf_sum = odf_array.sum(1)
odf_sum_mask = odf_sum > 0
if op.exists(self.mask_image):
mask = nib.load(self.mask_image)
self.flat_mask = mask.get_data()
self.flat_mask = self.flat_mask[::-1,::-1,:]
self.flat_mask = self.flat_mask.flatten(order="F") > 0
else:
self.flat_mask = np.ones(self.volume_grid, dtype = np.bool).flatten()
self.flat_mask = odf_sum_mask & self.flat_mask
self.odf_values = odf_array[self.flat_mask,:].astype(np.float64)

# Coordinate mapping information from fib file
self.flat_mask = f["fa0"].squeeze() > 0
self.nvoxels = self.flat_mask.sum()
self.voxel_coords = np.array(np.unravel_index(
np.flatnonzero(self.flat_mask), self.volume_grid, order="F")).T
self.coordinate_lut = dict(
[(tuple(coord), n) for n,coord in enumerate(self.voxel_coords)])

"""
This is not necessary. There will always be voxels without 26
outgoing edges.
norm_factor = self.odf_values.sum(1)
norm_factor[norm_factor == 0] = 1.
self.odf_values = self.odf_values / norm_factor[:,np.newaxis] * 0.5
logger.info("Loaded ODF data: %s",str(self.odf_values.shape))
self.orientation = "lps"

# Which probabilities point to a voxel not in the mask?
logger.info("Checking for voxels with 26 neighbors")
all_neighbors_in_graph = np.ones((self.nvoxels), dtype=np.bool)
for j, starting_voxel in enumerate(self.voxel_coords):
for i, name in enumerate(neighbor_names):
coord = tuple(starting_voxel + lps_neighbor_shifts[name])
if not coord in self.coordinate_lut:
all_neighbors_in_graph[j] = False
break
# Make new masks
flat_output = np.zeros(np.prod(self.volume_grid))
flat_output[self.flat_mask] = all_neighbors_in_graph
self.flat_mask = flat_output > 0
nvoxels = self.flat_mask.sum()
logger.info("Removed %d/%d voxels with incomplete neighbors",
self.nvoxels - nvoxels, self.nvoxels)
self.nvoxels = nvoxels
def _load_odf(self, path):
logger.info("Loading %s", path)
if path.find("nii") > 0:
mrodfs = nib.load(path)
self.volume_grid = mrodfs.shape[:3]
print("Loading mrTRIX FOD file")
if op.exists(self.mask_image):
mrmask = nib.load(self.mask_image)
self.flat_mask = mrmask.get_data()
self.flat_mask = self.flat_mask[::-1,::-1,:]
self.flat_mask = self.flat_mask.flatten(order="F") > 0
else:
self.flat_mask = np.ones(self.volume_grid, dtype = np.bool).flatten()
aff = np.ones(4,dtype=np.float)
aff[:3] = mrodfs.header.get_zooms()[0]
odfs = mrodfs.get_data()
odfs = odfs[::-1,::-1,:,:]
odfs = odfs.reshape(np.prod(odfs.shape[:3]),odfs.shape[-1], order="F")
odf_sum = odfs.sum(1)
odf_sum_mask = odf_sum > 0
self.flat_mask = odf_sum_mask & self.flat_mask
#self.flat_mask[vox_mask_inds[~odf_sum_mask]] = False
self.odf_values = odfs[self.flat_mask,:].astype(np.float64)

else:
f = load_fibgz(path)
self.volume_grid = f['dimension'].squeeze()
aff = np.ones(4,dtype=np.float)
aff[:3] = f['voxel_size'].squeeze()
# Create a contiguous ODF matrix, skipping all zero rows
logger.info("Loading DSI Studio ODF data")
odf_vars = [k for k in f.keys() if re.match("odf\\d+",k)]
valid_odfs = []
self.flat_mask = f["fa0"].squeeze() > 0
for n in range(len(odf_vars)):
varname = "odf%d" % n
odfs = f[varname]
odf_sum = odfs.sum(0)
odf_sum_mask = odf_sum > 0
valid_odfs.append(odfs[:,odf_sum_mask].T)

self.odf_values = np.row_stack(valid_odfs).astype(np.float64)

self.orientation = "lps"
logger.info("Loaded %s", path)
# Check that this fib file matches what we expect
#fib_odf_vertices = f['odf_vertices'].T
#matches = np.allclose(self.odf_vertices, fib_odf_vertices)
#if not matches:
#logger.critical("ODF Angles in fib file do not match %s", self.odf_resolution)
#return
# DSI Studio stores data in LPS+
#aff = aff * np.array([-1,-1,1,1])
self.ras_affine = np.diag(aff)
self.voxel_size = aff[:3]

# Coordinate mapping information from fib file
self.nvoxels = self.flat_mask.sum()
self.voxel_coords = np.array(np.unravel_index(
np.flatnonzero(self.flat_mask), self.volume_grid, order="F")).T
self.coordinate_lut = dict(
[(tuple(coord), n) for n,coord in enumerate(self.voxel_coords)])

"""
# Create a contiguous ODF matrix, skipping all zero rows
logger.info("Loading ODF data")
odf_vars = [k for k in f.keys() if re.match("odf\\d+",k)]
valid_odfs = []
for n in range(len(odf_vars)):
varname = "odf%d" % n
odfs = f[varname]
odf_sum = odfs.sum(0)
odf_sum_mask = odf_sum > 0
valid_odfs.append(odfs[:,odf_sum_mask].T)

self.odf_values = np.row_stack(valid_odfs).astype(np.float64)
norm_factor = self.odf_values.sum(1)
norm_factor[norm_factor == 0] = 1.
self.odf_values = self.odf_values / norm_factor[:,np.newaxis] * 0.5
