diff --git a/desc/grid.py b/desc/grid.py index b3757ea37..e6ebbb683 100644 --- a/desc/grid.py +++ b/desc/grid.py @@ -706,8 +706,8 @@ class Grid(_Grid): nodes.reshape((num_poloidal, num_radial, num_toroidal, 3), order="F"). jitable : bool Whether to skip certain checks and conditionals that don't work under jit. - Allows grid to be created on the fly with custom nodes, but weights, symmetry - etc. may be wrong if grid contains duplicate nodes. + Allows grid to be created on the fly with custom nodes, but weights, + symmetry etc. may be wrong if grid contains duplicate nodes. """ def __init__( @@ -793,6 +793,7 @@ def create_meshgrid( coordinates="rtz", period=(np.inf, 2 * np.pi, 2 * np.pi), NFP=1, + jitable=True, **kwargs, ): """Create a tensor-product grid from the given coordinates in a jitable manner. @@ -819,6 +820,10 @@ def create_meshgrid( Only makes sense to change from 1 if last coordinate is periodic with some constant divided by ``NFP`` and the nodes are placed within one field period. + jitable : bool + Whether to skip certain checks and conditionals that don't work under jit. + Allows grid to be created on the fly with custom nodes, but weights, + symmetry etc. may be wrong if grid contains duplicate nodes. Returns ------- @@ -861,10 +866,7 @@ def create_meshgrid( repeat(unique_a_idx // b.size, b.size, total_repeat_length=a.size * b.size), c.size, ) - inverse_b_idx = jnp.tile( - unique_b_idx, - a.size * c.size, - ) + inverse_b_idx = jnp.tile(unique_b_idx, a.size * c.size) inverse_c_idx = repeat(unique_c_idx // (a.size * b.size), (a.size * b.size)) return Grid( nodes=nodes, @@ -875,7 +877,7 @@ def create_meshgrid( NFP=NFP, sort=False, is_meshgrid=True, - jitable=True, + jitable=jitable, _unique_rho_idx=unique_a_idx, _unique_poloidal_idx=unique_b_idx, _unique_zeta_idx=unique_c_idx,