diff --git a/desc/backend.py b/desc/backend.py index 6e123f49e9..c26213b045 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -1,5 +1,6 @@ """Backend functions for DESC, with options for JAX or regular numpy.""" +import functools import os import warnings @@ -114,6 +115,28 @@ def put(arr, inds, vals): return arr return jnp.asarray(arr).at[inds].set(vals) + def execute_on_cpu(func): + """Decorator to set default device to CPU for a function. + + Parameters + ---------- + func : callable + Function to decorate + + Returns + ------- + wrapper : callable + Decorated function that will run always on CPU even if + there are available GPUs. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with jax.default_device(jax.devices("cpu")[0]): + return func(*args, **kwargs) + + return wrapper + def sign(x): """Sign function, but returns 1 for x==0. @@ -376,6 +399,7 @@ def tangent_solve(g, y): # for coverage purposes else: # pragma: no cover jit = lambda func, *args, **kwargs: func + execute_on_cpu = lambda func: func import scipy.optimize from scipy.integrate import odeint # noqa: F401 from scipy.linalg import ( # noqa: F401 diff --git a/desc/compute/utils.py b/desc/compute/utils.py index 6ff9b6eee7..4567637be3 100644 --- a/desc/compute/utils.py +++ b/desc/compute/utils.py @@ -6,7 +6,7 @@ import numpy as np from termcolor import colored -from desc.backend import cond, fori_loop, jnp, put +from desc.backend import cond, execute_on_cpu, fori_loop, jnp, put from desc.grid import ConcentricGrid, Grid, LinearGrid from ..utils import errorif, warnif @@ -198,6 +198,7 @@ def _compute( return data +@execute_on_cpu def get_data_deps(keys, obj, has_axis=False, basis="rpz", data=None): """Get list of keys needed to compute ``keys`` given already computed data. @@ -357,6 +358,7 @@ def _grow_seeds(parameterization, seeds, search_space, has_axis=False): return out +@execute_on_cpu def get_derivs(keys, obj, has_axis=False, basis="rpz"): """Get dict of derivative orders needed to compute a given quantity. @@ -446,6 +448,7 @@ def get_profiles(keys, obj, grid=None, has_axis=False, basis="rpz"): return profiles +@execute_on_cpu def get_params(keys, obj, has_axis=False, basis="rpz"): """Get parameters needed to compute a given quantity. @@ -486,6 +489,7 @@ def get_params(keys, obj, has_axis=False, basis="rpz"): return temp_params +@execute_on_cpu def get_transforms( keys, obj, grid, jitable=False, has_axis=False, basis="rpz", **kwargs ): diff --git a/desc/equilibrium/equilibrium.py b/desc/equilibrium/equilibrium.py index e2dbc227fb..d321699331 100644 --- a/desc/equilibrium/equilibrium.py +++ b/desc/equilibrium/equilibrium.py @@ -10,7 +10,7 @@ from scipy.constants import mu_0 from termcolor import colored -from desc.backend import jnp +from desc.backend import execute_on_cpu, jnp from desc.basis import FourierZernikeBasis, fourier, zernike_radial from desc.compat import ensure_positive_jacobian from desc.compute import compute as compute_fun @@ -164,6 +164,7 @@ class Equilibrium(IOAble, Optimizable): "_N_grid", ] + @execute_on_cpu def __init__( self, Psi=1.0, @@ -533,6 +534,7 @@ def copy(self, deepcopy=True): new = copy.copy(self) return new + @execute_on_cpu def change_resolution( self, L=None, @@ -618,6 +620,7 @@ def change_resolution( self._Z_lmn = copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes) self._L_lmn = copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes) + @execute_on_cpu def get_surface_at(self, rho=None, theta=None, zeta=None): """Return a representation for a given coordinate surface. @@ -1256,6 +1259,7 @@ def compute_theta_coords( **kwargs, ) + @execute_on_cpu def is_nested(self, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None): """Check that an equilibrium has properly nested flux surfaces in a plane. diff --git a/desc/geometry/surface.py b/desc/geometry/surface.py index 48747efd9c..79f1b871a9 100644 --- a/desc/geometry/surface.py +++ b/desc/geometry/surface.py @@ -4,7 +4,16 @@ import numpy as np -from desc.backend import block_diag, jit, jnp, put, root_scalar, sign, vmap +from desc.backend import ( + block_diag, + execute_on_cpu, + jit, + jnp, + put, + root_scalar, + sign, + vmap, +) from desc.basis import DoubleFourierSeries, ZernikePolynomial from desc.compute import rpz2xyz_vec, xyz2rpz, xyz2rpz_vec from desc.grid import Grid, LinearGrid @@ -57,6 +66,7 @@ class FourierRZToroidalSurface(Surface): "_rho", ] + @execute_on_cpu def __init__( self, R_lmn=None, @@ -165,6 +175,7 @@ def rho(self): def rho(self, rho): self._rho = rho + @execute_on_cpu def change_resolution(self, *args, **kwargs): """Change the maximum poloidal and toroidal resolution.""" assert ( @@ -801,6 +812,7 @@ class ZernikeRZToroidalSection(Surface): "_zeta", ] + @execute_on_cpu def __init__( self, R_lmn=None, @@ -912,6 +924,7 @@ def zeta(self): def zeta(self, zeta): self._zeta = zeta + @execute_on_cpu def change_resolution(self, *args, **kwargs): """Change the maximum radial and poloidal resolution.""" assert ( diff --git a/desc/objectives/linear_objectives.py b/desc/objectives/linear_objectives.py index 43b8a20feb..f2092d0752 100644 --- a/desc/objectives/linear_objectives.py +++ b/desc/objectives/linear_objectives.py @@ -11,7 +11,7 @@ import numpy as np from termcolor import colored -from desc.backend import jnp, tree_leaves, tree_map, tree_structure +from desc.backend import execute_on_cpu, jnp, tree_leaves, tree_map, tree_structure from desc.basis import zernike_radial, zernike_radial_coeffs from desc.utils import broadcast_tree, errorif, setdefault @@ -275,6 +275,7 @@ def __init__( name=name, ) + @execute_on_cpu def build(self, use_jit=False, verbose=1): """Build constant arrays. @@ -373,6 +374,7 @@ def __init__( name=name, ) + @execute_on_cpu def build(self, use_jit=False, verbose=1): """Build constant arrays. @@ -467,6 +469,7 @@ def __init__( normalize_target=False, ) + @execute_on_cpu def build(self, use_jit=False, verbose=1): """Build constant arrays. @@ -553,6 +556,7 @@ def __init__( normalize_target=False, ) + @execute_on_cpu def build(self, use_jit=False, verbose=1): """Build constant arrays. @@ -1548,6 +1552,7 @@ def __init__( normalize_target=normalize_target, ) + @execute_on_cpu def build(self, use_jit=False, verbose=1): """Build constant arrays. @@ -1714,6 +1719,7 @@ def __init__( name=name, ) + @execute_on_cpu def build(self, use_jit=False, verbose=1): """Build constant arrays. @@ -3029,6 +3035,7 @@ def __init__( name=name, ) + @execute_on_cpu def build(self, use_jit=True, verbose=1): """Build constant arrays. diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index c2ba6387f0..8f7843a2cf 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -5,7 +5,7 @@ import numpy as np -from desc.backend import jit, jnp, tree_flatten, tree_unflatten, use_jax +from desc.backend import execute_on_cpu, jit, jnp, tree_flatten, tree_unflatten, use_jax from desc.derivatives import Derivative from desc.io import IOAble from desc.optimizable import Optimizable @@ -148,6 +148,7 @@ def jit(self): # noqa: C901 if obj._use_jit: obj.jit() + @execute_on_cpu def build(self, use_jit=None, verbose=1): """Build the objective.