Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Execute initializations on CPU much faster #1056

Merged
merged 17 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions desc/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Backend functions for DESC, with options for JAX or regular numpy."""

import functools
import os
import warnings

Expand Down Expand Up @@ -112,6 +113,28 @@ def put(arr, inds, vals):
return arr
return jnp.asarray(arr).at[inds].set(vals)

def set_default_cpu(func):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we could come up with a better/more descriptive name, default_device_cpu? or something? but not a dealbreaker

"""Decorator to set default device to CPU for a function.

Parameters
----------
func : callable
Function to decorate

Returns
-------
wrapper : callable
Decorated function that will run always on CPU even if
there are available GPUs.
"""

@functools.wraps(func)
def wrapper(*args, **kwargs):
with jax.default_device(jax.devices("cpu")[0]):
return func(*args, **kwargs)

return wrapper

def sign(x):
"""Sign function, but returns 1 for x==0.

Expand Down Expand Up @@ -374,6 +397,7 @@ def tangent_solve(g, y):
# for coverage purposes
else: # pragma: no cover
jit = lambda func, *args, **kwargs: func
set_default_cpu = lambda func: func
import scipy.optimize
from scipy.integrate import odeint # noqa: F401
from scipy.linalg import ( # noqa: F401
Expand Down
6 changes: 5 additions & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from termcolor import colored

from desc.backend import cond, fori_loop, jnp, put
from desc.backend import cond, fori_loop, jnp, put, set_default_cpu
from desc.grid import ConcentricGrid, Grid, LinearGrid
from desc.utils import errorif

Expand Down Expand Up @@ -177,6 +177,7 @@ def _compute(
return data


@set_default_cpu
def get_data_deps(keys, obj, has_axis=False, basis="rpz"):
"""Get list of data keys needed to compute a given quantity.

Expand Down Expand Up @@ -227,6 +228,7 @@ def _get_deps_1_key(key):
return sorted(list(set(out)))


@set_default_cpu
def get_derivs(keys, obj, has_axis=False, basis="rpz"):
"""Get dict of derivative orders needed to compute a given quantity.

Expand Down Expand Up @@ -316,6 +318,7 @@ def get_profiles(keys, obj, grid=None, has_axis=False, basis="rpz"):
return profiles


@set_default_cpu
def get_params(keys, obj, has_axis=False, basis="rpz"):
"""Get parameters needed to compute a given quantity.

Expand Down Expand Up @@ -356,6 +359,7 @@ def get_params(keys, obj, has_axis=False, basis="rpz"):
return temp_params


@set_default_cpu
def get_transforms(
keys, obj, grid, jitable=False, has_axis=False, basis="rpz", **kwargs
):
Expand Down
6 changes: 5 additions & 1 deletion desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scipy.constants import mu_0
from termcolor import colored

from desc.backend import jnp
from desc.backend import jnp, set_default_cpu
from desc.basis import FourierZernikeBasis, fourier, zernike_radial
from desc.compat import ensure_positive_jacobian
from desc.compute import compute as compute_fun
Expand Down Expand Up @@ -156,6 +156,7 @@ class Equilibrium(IOAble, Optimizable):
"_N_grid",
]

@set_default_cpu
def __init__(
self,
Psi=1.0,
Expand Down Expand Up @@ -525,6 +526,7 @@ def copy(self, deepcopy=True):
new = copy.copy(self)
return new

@set_default_cpu
def change_resolution(
self,
L=None,
Expand Down Expand Up @@ -610,6 +612,7 @@ def change_resolution(
self._Z_lmn = copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes)
self._L_lmn = copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes)

@set_default_cpu
def get_surface_at(self, rho=None, theta=None, zeta=None):
"""Return a representation for a given coordinate surface.

Expand Down Expand Up @@ -1150,6 +1153,7 @@ def compute_theta_coords(
**kwargs,
)

@set_default_cpu
def is_nested(self, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None):
"""Check that an equilibrium has properly nested flux surfaces in a plane.

Expand Down
15 changes: 14 additions & 1 deletion desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@

import numpy as np

from desc.backend import block_diag, jit, jnp, put, root_scalar, sign, vmap
from desc.backend import (
block_diag,
jit,
jnp,
put,
root_scalar,
set_default_cpu,
sign,
vmap,
)
from desc.basis import DoubleFourierSeries, ZernikePolynomial
from desc.compute import rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from desc.grid import Grid, LinearGrid
Expand Down Expand Up @@ -57,6 +66,7 @@ class FourierRZToroidalSurface(Surface):
"_rho",
]

@set_default_cpu
def __init__(
self,
R_lmn=None,
Expand Down Expand Up @@ -165,6 +175,7 @@ def rho(self):
def rho(self, rho):
self._rho = rho

@set_default_cpu
def change_resolution(self, *args, **kwargs):
"""Change the maximum poloidal and toroidal resolution."""
assert (
Expand Down Expand Up @@ -801,6 +812,7 @@ class ZernikeRZToroidalSection(Surface):
"_zeta",
]

@set_default_cpu
def __init__(
self,
R_lmn=None,
Expand Down Expand Up @@ -912,6 +924,7 @@ def zeta(self):
def zeta(self, zeta):
self._zeta = zeta

@set_default_cpu
def change_resolution(self, *args, **kwargs):
"""Change the maximum radial and poloidal resolution."""
assert (
Expand Down
9 changes: 8 additions & 1 deletion desc/objectives/linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from termcolor import colored

from desc.backend import jnp, tree_leaves, tree_map, tree_structure
from desc.backend import jnp, set_default_cpu, tree_leaves, tree_map, tree_structure
from desc.basis import zernike_radial, zernike_radial_coeffs
from desc.utils import broadcast_tree, errorif, setdefault

Expand Down Expand Up @@ -275,6 +275,7 @@ def __init__(
name=name,
)

@set_default_cpu
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to apply this to all the build methods? It seems kind of inconsistent with where it is used now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added it to the parent objective class before and for my test case it didn't give better results, so I picked suitable build methods intuitively.

Also, maybe one important thing to keep in mind is when we build smt on cpu, probably the arrays are on cpu memory, we may need to copy them to gpu memory as we did during hackathon.

def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -373,6 +374,7 @@ def __init__(
name=name,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -467,6 +469,7 @@ def __init__(
normalize_target=False,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -553,6 +556,7 @@ def __init__(
normalize_target=False,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -1548,6 +1552,7 @@ def __init__(
normalize_target=normalize_target,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -1714,6 +1719,7 @@ def __init__(
name=name,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -3029,6 +3035,7 @@ def __init__(
name=name,
)

@set_default_cpu
def build(self, use_jit=True, verbose=1):
"""Build constant arrays.

Expand Down
10 changes: 9 additions & 1 deletion desc/objectives/objective_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@

import numpy as np

from desc.backend import jit, jnp, tree_flatten, tree_unflatten, use_jax
from desc.backend import (
jit,
jnp,
set_default_cpu,
tree_flatten,
tree_unflatten,
use_jax,
)
from desc.derivatives import Derivative
from desc.io import IOAble
from desc.optimizable import Optimizable
Expand Down Expand Up @@ -148,6 +155,7 @@ def jit(self): # noqa: C901
if obj._use_jit:
obj.jit()

@set_default_cpu
def build(self, use_jit=None, verbose=1):
"""Build the objective.

Expand Down
Loading