diff --git a/mittens/_mittens.py b/mittens/_mittens.py index 2461ef8..57a710f 100644 --- a/mittens/_mittens.py +++ b/mittens/_mittens.py @@ -39,7 +39,7 @@ neighborsY = [o[1] for o in opposites] + [o[0] for o in opposites] class MITTENS(Spatial): - def __init__(self, fibgz_file="", nifti_prefix="", + def __init__(self, odf_file="", odf_array = "", nifti_prefix="", real_affine_image="", mask_image="", step_size=np.sqrt(3)/2. , angle_max=35, odf_resolution="odf8", angle_weights="flat", angle_weighting_power=1.,normalize_doubleODF=True): @@ -50,8 +50,11 @@ def __init__(self, fibgz_file="", nifti_prefix="", Parameters: =========== - fibgz_file:str - Path to a dsi studio fib.gz file + odf_file:str + Path to a mrtrix amplitudes nii.gz file or dsi studio fib.gz file + odf_array:numpy array + Array of odf data sampled from a symmetric sphere, default is 321 directions (odf8). + Requires real_affine_image path. Assumes RAS orientation. nifti_prefix:str Prefix used when calculating singleODF and/or doubleODF transition probabilities. @@ -62,8 +65,7 @@ def __init__(self, fibgz_file="", nifti_prefix="", default affine from DSI Studio will be used. mask_image:str Path to a NIfTI file that has nonzero values in voxels that will be used - as nodes in the graph. If none is provided, the default mask estimated by - DSI Studio is used. + as nodes in the graph. If none is provided, the default mask estimated from the ODFs is used. step_size:float Step size in voxel units. Used for calculating transition probabilities angle_max:float @@ -86,8 +88,8 @@ def __init__(self, fibgz_file="", nifti_prefix="", If you're unable to initialize a MITTENS object with your desired combination, try downloading or generating/compiling the necessary Fortran modules. """ - if fibgz_file == nifti_prefix == "": - raise ValueError("Must provide either a DSI Studio fib file or prefix to " + if odf_file == "" and nifti_prefix == "" and type(odf_array) == str: + raise ValueError("Must provide either a DSI Studio fib file, nifti ODF file, ODF array, or prefix to " "NIfTI1 images written out by a previous run") # These will get filled out from loading a fibgz or niftis self.flat_mask = None @@ -116,16 +118,20 @@ def __init__(self, fibgz_file="", nifti_prefix="", get_transition_analysis_matrices(self.odf_resolution, self.angle_max, self.angle_weights, self.angle_weighting_power) self.n_unique_vertices = self.odf_vertices.shape[0]//2 - if fibgz_file: - logger.info("Loading DSI Studio fib file") - self._load_fibgz(fibgz_file) + self._initialize_nulls() + self._set_real_affine(real_affine_image) + + if odf_file: + logger.info("Loading ODF file") + self._load_odf(odf_file) + if np.ndarray == type(odf_array): + if not real_affine_image: raise ValueError("Must specify path to the NIFTI volume containing the affine") + logger.info("Loading ODF array") + self._load_odf_array(odf_array) if nifti_prefix: logger.info("Loading output from pre-existing NIfTIs") self._load_niftis(nifti_prefix) - self._initialize_nulls() - self._set_real_affine(real_affine_image) - def _initialize_nulls(self): # Note, only entries for unique vertices are created, but they @@ -156,75 +162,102 @@ def _initialize_nulls(self): if self.normalize_doubleODF: self.doubleODF_null_probs = self.doubleODF_null_probs / self.doubleODF_null_probs.sum() - def _load_fibgz(self, path): - logger.info("Loading %s", path) - f = load_fibgz(path) - logger.info("Loaded %s", path) - self.orientation = "lps" - # Check that this fib file matches what we expect - fib_odf_vertices = f['odf_vertices'].T - matches = np.allclose(self.odf_vertices, fib_odf_vertices) - if not matches: - logger.critical("ODF Angles in fib file do not match %s", self.odf_resolution) - return - - # Extract the spacing info from the fib file - self.volume_grid = f['dimension'].squeeze() + def _load_odf_array(self, odf_array): + self.volume_grid = odf_array.shape[:3] aff = np.ones(4,dtype=np.float) - aff[:3] = f['voxel_size'].squeeze() - # DSI Studio stores data in LPS+ - #aff = aff * np.array([-1,-1,1,1]) + aff[:3] = self.real_affine[0][0] self.ras_affine = np.diag(aff) - self.voxel_size = aff[:3] + numSamples = odf_array.shape[-1]//2 + odf_array = odf_array[::-1,::-1,:,:numSamples] + odf_array = odf_array.reshape(np.prod(odf_array.shape[:3]),odf_array.shape[-1], order="F") + odf_array[odf_array < 0] = 0 + odf_sum = odf_array.sum(1) + odf_sum_mask = odf_sum > 0 + if op.exists(self.mask_image): + mask = nib.load(self.mask_image) + self.flat_mask = mask.get_data() + self.flat_mask = self.flat_mask[::-1,::-1,:] + self.flat_mask = self.flat_mask.flatten(order="F") > 0 + else: + self.flat_mask = np.ones(self.volume_grid, dtype = np.bool).flatten() + self.flat_mask = odf_sum_mask & self.flat_mask + self.odf_values = odf_array[self.flat_mask,:].astype(np.float64) - # Coordinate mapping information from fib file - self.flat_mask = f["fa0"].squeeze() > 0 self.nvoxels = self.flat_mask.sum() self.voxel_coords = np.array(np.unravel_index( np.flatnonzero(self.flat_mask), self.volume_grid, order="F")).T self.coordinate_lut = dict( [(tuple(coord), n) for n,coord in enumerate(self.voxel_coords)]) - """ - This is not necessary. There will always be voxels without 26 - outgoing edges. + norm_factor = self.odf_values.sum(1) + norm_factor[norm_factor == 0] = 1. + self.odf_values = self.odf_values / norm_factor[:,np.newaxis] * 0.5 + logger.info("Loaded ODF data: %s",str(self.odf_values.shape)) + self.orientation = "lps" - # Which probabilities point to a voxel not in the mask? - logger.info("Checking for voxels with 26 neighbors") - all_neighbors_in_graph = np.ones((self.nvoxels), dtype=np.bool) - for j, starting_voxel in enumerate(self.voxel_coords): - for i, name in enumerate(neighbor_names): - coord = tuple(starting_voxel + lps_neighbor_shifts[name]) - if not coord in self.coordinate_lut: - all_neighbors_in_graph[j] = False - break - - # Make new masks - flat_output = np.zeros(np.prod(self.volume_grid)) - flat_output[self.flat_mask] = all_neighbors_in_graph - self.flat_mask = flat_output > 0 - nvoxels = self.flat_mask.sum() - logger.info("Removed %d/%d voxels with incomplete neighbors", - self.nvoxels - nvoxels, self.nvoxels) - self.nvoxels = nvoxels + def _load_odf(self, path): + logger.info("Loading %s", path) + if path.find("nii") > 0: + mrodfs = nib.load(path) + self.volume_grid = mrodfs.shape[:3] + print("Loading mrTRIX FOD file") + if op.exists(self.mask_image): + mrmask = nib.load(self.mask_image) + self.flat_mask = mrmask.get_data() + self.flat_mask = self.flat_mask[::-1,::-1,:] + self.flat_mask = self.flat_mask.flatten(order="F") > 0 + else: + self.flat_mask = np.ones(self.volume_grid, dtype = np.bool).flatten() + aff = np.ones(4,dtype=np.float) + aff[:3] = mrodfs.header.get_zooms()[0] + odfs = mrodfs.get_data() + odfs = odfs[::-1,::-1,:,:] + odfs = odfs.reshape(np.prod(odfs.shape[:3]),odfs.shape[-1], order="F") + odf_sum = odfs.sum(1) + odf_sum_mask = odf_sum > 0 + self.flat_mask = odf_sum_mask & self.flat_mask + #self.flat_mask[vox_mask_inds[~odf_sum_mask]] = False + self.odf_values = odfs[self.flat_mask,:].astype(np.float64) + + else: + f = load_fibgz(path) + self.volume_grid = f['dimension'].squeeze() + aff = np.ones(4,dtype=np.float) + aff[:3] = f['voxel_size'].squeeze() + # Create a contiguous ODF matrix, skipping all zero rows + logger.info("Loading DSI Studio ODF data") + odf_vars = [k for k in f.keys() if re.match("odf\\d+",k)] + valid_odfs = [] + self.flat_mask = f["fa0"].squeeze() > 0 + for n in range(len(odf_vars)): + varname = "odf%d" % n + odfs = f[varname] + odf_sum = odfs.sum(0) + odf_sum_mask = odf_sum > 0 + valid_odfs.append(odfs[:,odf_sum_mask].T) + + self.odf_values = np.row_stack(valid_odfs).astype(np.float64) + + self.orientation = "lps" + logger.info("Loaded %s", path) + # Check that this fib file matches what we expect + #fib_odf_vertices = f['odf_vertices'].T + #matches = np.allclose(self.odf_vertices, fib_odf_vertices) + #if not matches: + #logger.critical("ODF Angles in fib file do not match %s", self.odf_resolution) + #return + # DSI Studio stores data in LPS+ + #aff = aff * np.array([-1,-1,1,1]) + self.ras_affine = np.diag(aff) + self.voxel_size = aff[:3] + + # Coordinate mapping information from fib file + self.nvoxels = self.flat_mask.sum() self.voxel_coords = np.array(np.unravel_index( np.flatnonzero(self.flat_mask), self.volume_grid, order="F")).T self.coordinate_lut = dict( [(tuple(coord), n) for n,coord in enumerate(self.voxel_coords)]) - """ - # Create a contiguous ODF matrix, skipping all zero rows - logger.info("Loading ODF data") - odf_vars = [k for k in f.keys() if re.match("odf\\d+",k)] - valid_odfs = [] - for n in range(len(odf_vars)): - varname = "odf%d" % n - odfs = f[varname] - odf_sum = odfs.sum(0) - odf_sum_mask = odf_sum > 0 - valid_odfs.append(odfs[:,odf_sum_mask].T) - - self.odf_values = np.row_stack(valid_odfs).astype(np.float64) norm_factor = self.odf_values.sum(1) norm_factor[norm_factor == 0] = 1. self.odf_values = self.odf_values / norm_factor[:,np.newaxis] * 0.5