Skip to content

Commit

Permalink
fix_rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Shotgunosine committed Jun 22, 2024
1 parent 79b7b50 commit 57222fd
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
43 changes: 42 additions & 1 deletion nitransforms/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from nitransforms.base import (
SurfaceMesh
)
import nibabel as nb
from scipy.spatial import KDTree
from scipy.spatial.distance import cdist


class SurfaceTransformBase():
Expand Down Expand Up @@ -113,6 +116,7 @@ def __add__(self, other):
return self.__class__(self.reference, other.moving)
raise NotImplementedError


def _to_hdf5(self, x5_root):
"""Write transform to HDF5 file."""
triangles = x5_root.create_group("Triangles")
Expand Down Expand Up @@ -211,6 +215,7 @@ def __init__(self, reference, moving, interpolation_method='barycentric', mat=No
interpolation_method : str
Only barycentric is currently implemented
"""

super().__init__(SurfaceMesh(reference), SurfaceMesh(moving), spherical=True)

self.reference.set_radius()
Expand All @@ -226,6 +231,39 @@ def __init__(self, reference, moving, interpolation_method='barycentric', mat=No
# transform
if mat is None:
self.__calculate_mat()
r_tree = KDTree(self.reference._coords)
m_tree = KDTree(self.moving._coords)
kmr_dists, kmr_closest = m_tree.query(self.reference._coords, k=10)

# invert the triangles to generate a lookup table from vertices to triangle index
tri_lut = dict()
for i, idxs in enumerate(self.moving._triangles):
for x in idxs:
if not x in tri_lut:
tri_lut[x] = [i]
else:
tri_lut[x].append(i)

# calculate the barycentric interpolation weights
bc_weights = []
enclosing = []
for sidx, (point, kmrv) in enumerate(zip(self.reference._coords, kmr_closest)):
close_tris = _find_close_tris(kmrv, tri_lut, self.moving)
ww, ee = _find_weights(point, close_tris, m_tree)
bc_weights.append(ww)
enclosing.append(ee)

# build sparse matrix
# commenting out code for barycentric nearest neighbor
#bary_nearest = []
mat = sparse.lil_array((self.reference._npoints, self.moving._npoints))
for s_ix, dd in enumerate(bc_weights):
for k, v in dd.items():
mat[s_ix, k] = v
# bary_nearest.append(np.array(list(dd.keys()))[np.array(list(dd.values())).argmax()])
# bary_nearest = np.array(bary_nearest)
# transpose so that number of out vertices is columns
self.mat = sparse.csr_array(mat.T)
else:
if isinstance(mat, sparse.csr_array):
self.mat = mat
Expand Down Expand Up @@ -283,7 +321,6 @@ def map(self, x):
return x

def __add__(self, other):

if (isinstance(other, SurfaceResampler)
and (other.interpolation_method == self.interpolation_method)):
return self.__class__(
Expand Down Expand Up @@ -455,6 +492,7 @@ def from_filename(cls, filename=None, reference_path=None, moving_path=None,


def _points_to_triangles(points, triangles):

"""Implementation that vectorizes project of a point to a set of triangles.
from: https://stackoverflow.com/a/32529589
"""
Expand Down Expand Up @@ -495,6 +533,7 @@ def _points_to_triangles(points, triangles):
m2 = v < 0
m3 = d < 0
m4 = a + d > b + e

m5 = ce > bd

t0 = m0 & m1 & m2 & m3
Expand Down Expand Up @@ -588,6 +627,7 @@ def _find_close_tris(kdsv, tri_lut, surface):
def _find_weights(point, close_tris, d_tree):
point = point[np.newaxis, :]
tri_dists = cdist(point, _points_to_triangles(point, close_tris).squeeze())

closest_tri = close_tris[(tri_dists == tri_dists.min()).squeeze()]
# make sure a single closest triangle was found
if closest_tri.shape[0] != 1:
Expand All @@ -599,6 +639,7 @@ def _find_weights(point, close_tris, d_tree):
# Make sure point is actually inside triangle
enclosing = True
if np.all((point > closest_tri).sum(0) != 3):

enclosing = False
_, ct_idxs = d_tree.query(closest_tri)
a = closest_tri[0]
Expand Down
1 change: 1 addition & 0 deletions nitransforms/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import h5py


from ..base import (
SpatialReference,
SampledSpatialData,
Expand Down
6 changes: 6 additions & 0 deletions nitransforms/tests/test_surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
SurfaceResampler
)

from nitransforms.base import SurfaceMesh
from nitransforms.surface import SurfaceCoordinateTransform, SurfaceResampler


# def test_surface_transform_npz():
# mat = sparse.random(10, 10, density=0.5)
# xfm = SurfaceCoordinateTransform(mat)
Expand Down Expand Up @@ -42,6 +46,7 @@
# y_none = xfm.apply(x, normalize="none")
# assert y_none.sum() != y_element.sum()
# assert y_none.sum() != y_sum.sum()

def test_SurfaceTransformBase(testdata_path):
# note these transformations are a bit of a weird use of surface transformation, but I'm
# just testing the base class and the io
Expand Down Expand Up @@ -205,3 +210,4 @@ def test_SurfaceResampler(testdata_path, tmpdir):
assert resampling3 == resampling
resampled_thickness3 = resampling3.apply(subj_thickness.agg_data(), normalize='element')
assert np.all(resampled_thickness3 == resampled_thickness)

0 comments on commit 57222fd

Please sign in to comment.