Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Execute initializations on CPU much faster #1056

Merged
merged 17 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
vmap = jax.vmap
scan = jax.lax.scan
bincount = jnp.bincount
set_default_cpu = jax.default_device(jax.devices("cpu")[0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want something like this:

def set_default_cpu(func):
    """Decorate ``func`` so that it runs with a CPU device as the JAX default.

    Each call to the decorated function is executed inside a
    ``jax.default_device`` context whose device is the first entry of
    ``jax.devices("cpu")``.

    Parameters
    ----------
    func : callable
        Function to wrap.

    Returns
    -------
    callable
        Wrapper that forwards all positional and keyword arguments to ``func``
        and returns its result. Metadata of ``func`` is preserved through
        ``functools.wraps``.

    """

    @functools.wraps(func)
    def _run_on_cpu(*args, **kwargs):
        # Look up the device at call time, not at decoration time.
        cpu_device = jax.devices("cpu")[0]
        with jax.default_device(cpu_device):
            result = func(*args, **kwargs)
        return result

    return _run_on_cpu

from jax import custom_jvp
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
Expand Down Expand Up @@ -374,6 +375,7 @@ def tangent_solve(g, y):
# for coverage purposes
else: # pragma: no cover
jit = lambda func, *args, **kwargs: func
set_default_cpu = lambda func: func
import scipy.optimize
from scipy.integrate import odeint # noqa: F401
from scipy.linalg import ( # noqa: F401
Expand Down
6 changes: 5 additions & 1 deletion desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from termcolor import colored

from desc.backend import cond, fori_loop, jnp, put
from desc.backend import cond, fori_loop, jnp, put, set_default_cpu
from desc.grid import ConcentricGrid, Grid, LinearGrid

from .data_index import allowed_kwargs, data_index
Expand Down Expand Up @@ -144,6 +144,7 @@ def _compute(
return data


@set_default_cpu
def get_data_deps(keys, obj, has_axis=False):
"""Get list of data keys needed to compute a given quantity.

Expand Down Expand Up @@ -189,6 +190,7 @@ def _get_deps_1_key(key):
return sorted(list(set(out)))


@set_default_cpu
def get_derivs(keys, obj, has_axis=False):
"""Get dict of derivative orders needed to compute a given quantity.

Expand Down Expand Up @@ -274,6 +276,7 @@ def get_profiles(keys, obj, grid=None, has_axis=False, jitable=False, **kwargs):
return profiles


@set_default_cpu
def get_params(keys, obj, has_axis=False, **kwargs):
"""Get parameters needed to compute a given quantity.

Expand Down Expand Up @@ -311,6 +314,7 @@ def get_params(keys, obj, has_axis=False, **kwargs):
return temp_params


@set_default_cpu
def get_transforms(keys, obj, grid, jitable=False, **kwargs):
"""Get transforms needed to compute a given quantity on a given grid.

Expand Down
6 changes: 5 additions & 1 deletion desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scipy.constants import mu_0
from termcolor import colored

from desc.backend import jnp
from desc.backend import jnp, set_default_cpu
from desc.basis import FourierZernikeBasis, fourier, zernike_radial
from desc.compat import ensure_positive_jacobian
from desc.compute import compute as compute_fun
Expand Down Expand Up @@ -156,6 +156,7 @@ class Equilibrium(IOAble, Optimizable):
"_N_grid",
]

@set_default_cpu
def __init__(
self,
Psi=1.0,
Expand Down Expand Up @@ -525,6 +526,7 @@ def copy(self, deepcopy=True):
new = copy.copy(self)
return new

@set_default_cpu
def change_resolution(
self,
L=None,
Expand Down Expand Up @@ -604,6 +606,7 @@ def change_resolution(
self._Z_lmn = copy_coeffs(self.Z_lmn, old_modes_Z, self.Z_basis.modes)
self._L_lmn = copy_coeffs(self.L_lmn, old_modes_L, self.L_basis.modes)

@set_default_cpu
def get_surface_at(self, rho=None, theta=None, zeta=None):
"""Return a representation for a given coordinate surface.

Expand Down Expand Up @@ -1103,6 +1106,7 @@ def compute_theta_coords(
**kwargs,
)

@set_default_cpu
def is_nested(self, grid=None, R_lmn=None, Z_lmn=None, L_lmn=None, msg=None):
"""Check that an equilibrium has properly nested flux surfaces in a plane.

Expand Down
15 changes: 14 additions & 1 deletion desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@

import numpy as np

from desc.backend import block_diag, jit, jnp, put, root_scalar, sign, vmap
from desc.backend import (
block_diag,
jit,
jnp,
put,
root_scalar,
set_default_cpu,
sign,
vmap,
)
from desc.basis import DoubleFourierSeries, ZernikePolynomial
from desc.compute import rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from desc.grid import Grid, LinearGrid
Expand Down Expand Up @@ -57,6 +66,7 @@ class FourierRZToroidalSurface(Surface):
"_rho",
]

@set_default_cpu
def __init__(
self,
R_lmn=None,
Expand Down Expand Up @@ -165,6 +175,7 @@ def rho(self):
def rho(self, rho):
self._rho = rho

@set_default_cpu
def change_resolution(self, *args, **kwargs):
"""Change the maximum poloidal and toroidal resolution."""
assert (
Expand Down Expand Up @@ -799,6 +810,7 @@ class ZernikeRZToroidalSection(Surface):
"_zeta",
]

@set_default_cpu
def __init__(
self,
R_lmn=None,
Expand Down Expand Up @@ -910,6 +922,7 @@ def zeta(self):
def zeta(self, zeta):
self._zeta = zeta

@set_default_cpu
def change_resolution(self, *args, **kwargs):
"""Change the maximum radial and poloidal resolution."""
assert (
Expand Down
9 changes: 8 additions & 1 deletion desc/objectives/linear_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from termcolor import colored

from desc.backend import jnp, tree_leaves, tree_map, tree_structure
from desc.backend import jnp, set_default_cpu, tree_leaves, tree_map, tree_structure
from desc.basis import zernike_radial, zernike_radial_coeffs
from desc.utils import broadcast_tree, errorif, setdefault

Expand Down Expand Up @@ -275,6 +275,7 @@ def __init__(
name=name,
)

@set_default_cpu
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to apply this to all the build methods? It seems kind of inconsistent with where it is used now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added it to the parent objective class before and for my test case it didn't give better results, so I picked suitable build methods intuitively.

Also, one important thing to keep in mind: when we build something on the CPU, the resulting arrays probably live in CPU memory, so we may need to copy them to GPU memory, as we did during the hackathon.

def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -373,6 +374,7 @@ def __init__(
name=name,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -467,6 +469,7 @@ def __init__(
normalize_target=False,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -553,6 +556,7 @@ def __init__(
normalize_target=False,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -1548,6 +1552,7 @@ def __init__(
normalize_target=normalize_target,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -1714,6 +1719,7 @@ def __init__(
name=name,
)

@set_default_cpu
def build(self, use_jit=False, verbose=1):
"""Build constant arrays.

Expand Down Expand Up @@ -2880,6 +2886,7 @@ def __init__(
name=name,
)

@set_default_cpu
def build(self, use_jit=True, verbose=1):
"""Build constant arrays.

Expand Down
14 changes: 8 additions & 6 deletions desc/optimize/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from termcolor import colored

from desc.backend import jax
from desc.io import IOAble
from desc.objectives import (
FixCurrent,
Expand Down Expand Up @@ -213,12 +214,13 @@ def optimize( # noqa: C901 - FIXME: simplify this
nonlinear_constraint = _combine_constraints(nonlinear_constraints)

# make sure everything is built
if objective is not None and not objective.built:
objective.build(verbose=verbose)
if linear_constraint is not None and not linear_constraint.built:
linear_constraint.build(verbose=verbose)
if nonlinear_constraint is not None and not nonlinear_constraint.built:
nonlinear_constraint.build(verbose=verbose)
with jax.default_device(jax.devices("cpu")[0]):
if objective is not None and not objective.built:
objective.build(verbose=verbose)
if linear_constraint is not None and not linear_constraint.built:
linear_constraint.build(verbose=verbose)
if nonlinear_constraint is not None and not nonlinear_constraint.built:
nonlinear_constraint.build(verbose=verbose)

# combine arguments from all three objective functions
if linear_constraint is not None and nonlinear_constraint is not None:
Expand Down
Loading