diff --git a/discretize/base/base_mesh.py b/discretize/base/base_mesh.py index ce4a2223a..2e8bbd524 100644 --- a/discretize/base/base_mesh.py +++ b/discretize/base/base_mesh.py @@ -5,9 +5,9 @@ import numpy as np import os import json - +from scipy.spatial import KDTree from discretize.utils import mkvc, Identity -from discretize.utils.code_utils import deprecate_property, deprecate_method +from discretize.utils.code_utils import deprecate_property, deprecate_method, as_array_n_by_dim import warnings @@ -793,6 +793,73 @@ def copy(self): items.pop("__class__", None) return cls(**items) + def closest_points_index(self, locations, grid_loc='CC', discard=False): + """Find the indicies for the nearest grid location for a set of points. + + Parameters + ---------- + locations : (n, dim) numpy.ndarray + Points to query. + grid_loc : {'CC', 'N', 'Fx', 'Fy', 'Fz', 'Ex', 'Ex', 'Ey', 'Ez'} + Specifies the grid on which points are being moved to. + discard : bool, optional + Whether to discard the intenally created `scipy.spatial.KDTree`. + + Returns + ------- + (n ) numpy.ndarray of int + Vector of length *n* containing the indicies for the closest + respective cell center, node, face or edge. + + Examples + -------- + Here we define a set of random (x, y) locations and find the closest + cell centers and nodes on a mesh. + + >>> from discretize import TensorMesh + >>> from discretize.utils import closest_points_index + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> h = 2*np.ones(5) + >>> mesh = TensorMesh([h, h], x0='00') + + Define some random locations, grid cell centers and grid nodes, + + >>> xy_random = np.random.uniform(0, 10, size=(4,2)) + >>> xy_centers = mesh.cell_centers + >>> xy_nodes = mesh.nodes + + Find indicies of closest cell centers and nodes, + + >>> ind_centers = mesh.closest_points_index(xy_random, 'cell_centers') + >>> ind_nodes = mesh.closest_points_index(xy_random, 'nodes') + + Plot closest cell centers and nodes + + >>> fig = plt.figure(figsize=(5, 5)) + >>> ax = fig.add_axes([0.1, 0.1, 0.8, 0.8]) + >>> mesh.plot_grid(ax=ax) + >>> ax.scatter(xy_random[:, 0], xy_random[:, 1], 50, 'k') + >>> ax.scatter(xy_centers[ind_centers, 0], xy_centers[ind_centers, 1], 50, 'r') + >>> ax.scatter(xy_nodes[ind_nodes, 0], xy_nodes[ind_nodes, 1], 50, 'b') + >>> plt.show() + """ + locations = as_array_n_by_dim(locations, self.dim) + + grid_loc = self._parse_location_type(grid_loc) + tree_name = f'_{grid_loc}_tree' + + tree = getattr(self, tree_name, None) + if tree is None: + grid = getattr(self, grid_loc) + tree = KDTree(grid) + _, ind = tree.query(locations) + + if not discard: + setattr(self, tree_name, tree) + + return ind + @property def reference_is_rotated(self): """Indicates whether mesh uses standard coordinate axes diff --git a/discretize/utils/mesh_utils.py b/discretize/utils/mesh_utils.py index 18a19ce32..dcc367888 100644 --- a/discretize/utils/mesh_utils.py +++ b/discretize/utils/mesh_utils.py @@ -2,7 +2,7 @@ import scipy.ndimage as ndi import scipy.sparse as sp -from discretize.utils.code_utils import as_array_n_by_dim, is_scalar +from discretize.utils.code_utils import is_scalar from scipy.spatial import cKDTree, Delaunay from scipy import interpolate import discretize @@ -232,20 +232,13 @@ def closest_points_index(mesh, pts, grid_loc="CC", **kwargs): DeprecationWarning, ) grid_loc = kwargs["gridLoc"] - - pts = as_array_n_by_dim(pts, mesh.dim) - grid = getattr(mesh, "grid" + grid_loc) - nodeInds = np.empty(pts.shape[0], dtype=int) - - for i, pt in enumerate(pts): - if mesh.dim == 1: - nodeInds[i] = ((pt - grid) ** 2).argmin() - else: - nodeInds[i] = ( - ((np.tile(pt, (grid.shape[0], 1)) - grid) ** 2).sum(axis=1).argmin() - ) - - return nodeInds + warnings.warn( + "The closest_points_index utilty function has been moved to be a method of " + "a class object. Please access it as mesh.closest_points_index(). This will " + "be removed in a future version of discretize", + DeprecationWarning, + ) + return mesh.closest_points_index(pts, grid_loc=grid_loc, discard=True) def extract_core_mesh(xyzlim, mesh, mesh_type="tensor"):