diff --git a/nitransforms/surface.py b/nitransforms/surface.py index ce8a8069..b9d97f7a 100644 --- a/nitransforms/surface.py +++ b/nitransforms/surface.py @@ -83,18 +83,18 @@ def __init__(self, reference, moving): Parameters ---------- reference: surface - Surface with the destination coordinates for each index. - moving: surface Surface with the starting coordinates for each index. + moving: surface + Surface with the destination coordinates for each index. """ - super().__init__(reference=reference, moving=moving) + super().__init__(reference=SurfaceMesh(reference), moving=SurfaceMesh(moving)) if np.all(self._reference._triangles != self._moving._triangles): raise ValueError("Both surfaces for an index transform must have corresponding" " vertices.") def map(self, x, inverse=False): - if inverse: + if not inverse: source = self.reference dest = self.moving else: @@ -113,6 +113,77 @@ def __add__(self, other): return self.__class__(self.reference, other.moving) raise NotImplementedError + def _to_hdf5(self, x5_root): + """Write transform to HDF5 file.""" + triangles = x5_root.create_group("Triangles") + coords = x5_root.create_group("Coordinates") + coords.create_dataset("0", data=self.reference._coords) + coords.create_dataset("1", data=self.moving._coords) + triangles.create_dataset("0", data=self.reference._triangles) + xform = x5_root.create_group("Transform") + xform.attrs["Type"] = "SurfaceCoordinateTransform" + reference = xform.create_group("Reference") + reference['Coordinates'] = h5py.SoftLink('/0/Coordinates/0') + reference['Triangles'] = h5py.SoftLink('/0/Triangles/0') + moving = xform.create_group("Moving") + moving['Coordinates'] = h5py.SoftLink('/0/Coordinates/1') + moving['Triangles'] = h5py.SoftLink('/0/Triangles/0') + + def to_filename(self, filename, fmt=None): + """Store the transform.""" + if fmt is None: + fmt = "npz" if filename.endswith(".npz") else "X5" + + if fmt == "npz": + raise NotImplementedError + # sparse.save_npz(filename, self.mat) + # return filename + + with h5py.File(filename, "w") as out_file: + out_file.attrs["Format"] = "X5" + out_file.attrs["Version"] = np.uint16(1) + root = out_file.create_group("/0") + self._to_hdf5(root) + + return filename + + @classmethod + def from_filename(cls, filename=None, reference_path=None, moving_path=None, + fmt=None): + """Load transform from file.""" + if filename is None: + if reference_path is None or moving_path is None: + raise ValueError("You must pass either a X5 file or a pair of reference and moving" + " surfaces.") + return cls(SurfaceMesh(nb.load(reference_path)), + SurfaceMesh(nb.load(moving_path))) + + if fmt is None: + try: + fmt = "npz" if filename.endswith(".npz") else "X5" + except AttributeError: + fmt = "npz" if filename.as_posix().endswith(".npz") else "X5" + + if fmt == "npz": + raise NotImplementedError + # return cls(sparse.load_npz(filename)) + + if fmt != "X5": + raise ValueError("Only npz and X5 formats are supported.") + + with h5py.File(filename, "r") as f: + assert f.attrs["Format"] == "X5" + xform = f["/0/Transform"] + reference = SurfaceMesh.from_arrays( + xform['Reference']['Coordinates'], + xform['Reference']['Triangles'] + ) + + moving = SurfaceMesh.from_arrays( + xform['Moving']['Coordinates'], + xform['Moving']['Triangles'] + ) + return cls(reference, moving) class SurfaceResampler(SurfaceTransformBase): """Represents transformations in which the coordinate space remains the same and the indicies @@ -293,17 +364,26 @@ def apply(self, x, inverse=False, normalize="element"): def _to_hdf5(self, x5_root): """Write transform to HDF5 file.""" + triangles = x5_root.create_group("Triangles") + coords = x5_root.create_group("Coordinates") + coords.create_dataset("0", data=self.reference._coords) + coords.create_dataset("1", data=self.moving._coords) + triangles.create_dataset("0", data=self.reference._triangles) + triangles.create_dataset("1", data=self.moving._triangles) xform = x5_root.create_group("Transform") xform.attrs["Type"] = "SurfaceResampling" - xform.attrs['interpolation_method'] = self.interpolation_method - xform.create_dataset("mat_data", data=self.mat.data) - xform.create_dataset("mat_indices", data=self.mat.indices) - xform.create_dataset("mat_indptr", data=self.mat.indptr) - xform.create_dataset("mat_shape", data=self.mat.shape) - xform.create_dataset("reference_coordinates", data=self.reference._coords) - xform.create_dataset("reference_triangles", data=self.reference._triangles) - xform.create_dataset("moving_coordinates", data=self.moving._coords) - xform.create_dataset("moving_triangles", data=self.moving._triangles) + xform.attrs['InterpolationMethod'] = self.interpolation_method + mat = xform.create_group("IndexWeights") + mat.create_dataset("Data", data=self.mat.data) + mat.create_dataset("Indices", data=self.mat.indices) + mat.create_dataset("Indptr", data=self.mat.indptr) + mat.create_dataset("Shape", data=self.mat.shape) + reference = xform.create_group("Reference") + reference['Coordinates'] = h5py.SoftLink('/0/Coordinates/0') + reference['Triangles'] = h5py.SoftLink('/0/Triangles/0') + moving = xform.create_group("Moving") + moving['Coordinates'] = h5py.SoftLink('/0/Coordinates/1') + moving['Triangles'] = h5py.SoftLink('/0/Triangles/1') def to_filename(self, filename, fmt=None): """Store the transform.""" @@ -338,7 +418,10 @@ def from_filename(cls, filename=None, reference_path=None, moving_path=None, interpolation_method=interpolation_method) if fmt is None: - fmt = "npz" if filename.endswith(".npz") else "X5" + try: + fmt = "npz" if filename.endswith(".npz") else "X5" + except AttributeError: + fmt = "npz" if filename.as_posix().endswith(".npz") else "X5" if fmt == "npz": raise NotImplementedError @@ -350,20 +433,24 @@ def from_filename(cls, filename=None, reference_path=None, moving_path=None, with h5py.File(filename, "r") as f: assert f.attrs["Format"] == "X5" xform = f["/0/Transform"] - mat = sparse.csr_matrix( - (xform["mat_data"][()], xform["mat_indices"][()], xform["mat_indptr"][()]), - shape=xform["mat_shape"][()], - ) + try: + iws = xform['IndexWeights'] + mat = sparse.csr_matrix( + (iws["Data"][()], iws["Indices"][()], iws["Indptr"][()]), + shape=iws["Shape"][()], + ) + except KeyError: + mat=None reference = SurfaceMesh.from_arrays( - xform['reference_coordinates'], - xform['reference_triangles'] + xform['Reference']['Coordinates'], + xform['Reference']['Triangles'] ) moving = SurfaceMesh.from_arrays( - xform['moving_coordinates'], - xform['moving_triangles'] + xform['Moving']['Coordinates'], + xform['Moving']['Triangles'] ) - interpolation_method = xform.attrs['interpolation_method'] + interpolation_method = xform.attrs['InterpolationMethod'] return cls(reference, moving, interpolation_method=interpolation_method, mat=mat) diff --git a/nitransforms/tests/test_surface.py b/nitransforms/tests/test_surface.py index a6a17a62..de046edf 100644 --- a/nitransforms/tests/test_surface.py +++ b/nitransforms/tests/test_surface.py @@ -85,14 +85,15 @@ def test_SurfaceCoordinateTransform(testdata_path): # test loading from filenames sct = SurfaceCoordinateTransform(sphere_reg, pial) - sctf = SurfaceCoordinateTransform.from_filename(sphere_reg_path, pial_path) + sctf = SurfaceCoordinateTransform.from_filename(reference_path=sphere_reg_path, + moving_path=pial_path) assert sct == sctf # test mapping - assert np.all(sct.map(sct.moving._coords[:100]) == sct.reference._coords[:100]) - assert np.all(sct.map(sct.reference._coords[:100], inverse=True) == sct.moving._coords[:100]) + assert np.all(sct.map(sct.moving._coords[:100], inverse=True) == sct.reference._coords[:100]) + assert np.all(sct.map(sct.reference._coords[:100]) == sct.moving._coords[:100]) with pytest.raises(NotImplementedError): - sct.map(sct.reference._coords[0]) + sct.map(sct.moving._coords[0]) # test inversion and addition scti = ~sct @@ -106,6 +107,17 @@ def test_SurfaceCoordinateTransform(testdata_path): assert np.all(scti.reference._triangles == sct.reference._triangles) assert scti == sct +def test_SurfaceCoordinateTransformIO(testdata_path, tmpdir): + sphere_reg_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsLR_desc-reg_sphere.surf.gii" + pial_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_pial.surf.gii" + fslr_sphere_path = testdata_path / "tpl-fsLR_hemi-R_den-32k_sphere.surf.gii" + + sct = SurfaceCoordinateTransform(sphere_reg_path, pial_path) + fn = tempfile.mktemp(suffix=".h5") + sct.to_filename(fn) + sct2 = SurfaceCoordinateTransform.from_filename(fn) + assert sct == sct2 + def test_ProjectUnproject(testdata_path): sphere_reg_path = testdata_path / "sub-sid000005_ses-budapest_acq-MPRAGE_hemi-R_space-fsLR_desc-reg_sphere.surf.gii"