Skip to content

Commit

Permalink
Merge pull request #251 from ngodber/curvilinear_mesh_improvements
Browse files Browse the repository at this point in the history
refactor closest_points_index to use cKDTree. This yields significant…
  • Loading branch information
jcapriot authored Oct 14, 2021
2 parents 1fb1ee7 + c799f67 commit 5ebeb22
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 17 deletions.
71 changes: 69 additions & 2 deletions discretize/base/base_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import numpy as np
import os
import json

from scipy.spatial import KDTree
from discretize.utils import mkvc, Identity
from discretize.utils.code_utils import deprecate_property, deprecate_method
from discretize.utils.code_utils import deprecate_property, deprecate_method, as_array_n_by_dim
import warnings


Expand Down Expand Up @@ -793,6 +793,73 @@ def copy(self):
items.pop("__class__", None)
return cls(**items)

def closest_points_index(self, locations, grid_loc='CC', discard=False):
"""Find the indicies for the nearest grid location for a set of points.
Parameters
----------
locations : (n, dim) numpy.ndarray
Points to query.
grid_loc : {'CC', 'N', 'Fx', 'Fy', 'Fz', 'Ex', 'Ex', 'Ey', 'Ez'}
Specifies the grid on which points are being moved to.
discard : bool, optional
Whether to discard the intenally created `scipy.spatial.KDTree`.
Returns
-------
(n ) numpy.ndarray of int
Vector of length *n* containing the indicies for the closest
respective cell center, node, face or edge.
Examples
--------
Here we define a set of random (x, y) locations and find the closest
cell centers and nodes on a mesh.
>>> from discretize import TensorMesh
>>> from discretize.utils import closest_points_index
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> h = 2*np.ones(5)
>>> mesh = TensorMesh([h, h], x0='00')
Define some random locations, grid cell centers and grid nodes,
>>> xy_random = np.random.uniform(0, 10, size=(4,2))
>>> xy_centers = mesh.cell_centers
>>> xy_nodes = mesh.nodes
Find indicies of closest cell centers and nodes,
>>> ind_centers = mesh.closest_points_index(xy_random, 'cell_centers')
>>> ind_nodes = mesh.closest_points_index(xy_random, 'nodes')
Plot closest cell centers and nodes
>>> fig = plt.figure(figsize=(5, 5))
>>> ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
>>> mesh.plot_grid(ax=ax)
>>> ax.scatter(xy_random[:, 0], xy_random[:, 1], 50, 'k')
>>> ax.scatter(xy_centers[ind_centers, 0], xy_centers[ind_centers, 1], 50, 'r')
>>> ax.scatter(xy_nodes[ind_nodes, 0], xy_nodes[ind_nodes, 1], 50, 'b')
>>> plt.show()
"""
locations = as_array_n_by_dim(locations, self.dim)

grid_loc = self._parse_location_type(grid_loc)
tree_name = f'_{grid_loc}_tree'

tree = getattr(self, tree_name, None)
if tree is None:
grid = getattr(self, grid_loc)
tree = KDTree(grid)
_, ind = tree.query(locations)

if not discard:
setattr(self, tree_name, tree)

return ind

@property
def reference_is_rotated(self):
"""Indicates whether mesh uses standard coordinate axes
Expand Down
23 changes: 8 additions & 15 deletions discretize/utils/mesh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import scipy.ndimage as ndi
import scipy.sparse as sp

from discretize.utils.code_utils import as_array_n_by_dim, is_scalar
from discretize.utils.code_utils import is_scalar
from scipy.spatial import cKDTree, Delaunay
from scipy import interpolate
import discretize
Expand Down Expand Up @@ -232,20 +232,13 @@ def closest_points_index(mesh, pts, grid_loc="CC", **kwargs):
DeprecationWarning,
)
grid_loc = kwargs["gridLoc"]

pts = as_array_n_by_dim(pts, mesh.dim)
grid = getattr(mesh, "grid" + grid_loc)
nodeInds = np.empty(pts.shape[0], dtype=int)

for i, pt in enumerate(pts):
if mesh.dim == 1:
nodeInds[i] = ((pt - grid) ** 2).argmin()
else:
nodeInds[i] = (
((np.tile(pt, (grid.shape[0], 1)) - grid) ** 2).sum(axis=1).argmin()
)

return nodeInds
warnings.warn(
"The closest_points_index utilty function has been moved to be a method of "
"a class object. Please access it as mesh.closest_points_index(). This will "
"be removed in a future version of discretize",
DeprecationWarning,
)
return mesh.closest_points_index(pts, grid_loc=grid_loc, discard=True)


def extract_core_mesh(xyzlim, mesh, mesh_type="tensor"):
Expand Down

0 comments on commit 5ebeb22

Please sign in to comment.