Execute initializations on CPU much faster #1056
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@ Coverage Diff @@
##           master    #1056   +/-   ##
=======================================
  Coverage   95.24%   95.24%
=======================================
  Files          87       87
  Lines       21920    21947   +27
=======================================
+ Hits        20877    20903   +26
- Misses       1043     1044    +1
desc/equilibrium/equilibrium.py (outdated diff)

         self._spectral_indexing = setdefault(
             spectral_indexing, getattr(surface, "spectral_indexing", "ansi")
         )
+        with jax.default_device(jax.devices("cpu")[0]):
I like the idea of having something in backend that we can just decorate functions with, without having to mess with jax everywhere.
like
@oncpu
def __init__(self, ...):
I actually tried that with

def func(fun, *args, **kwargs):
    with jax.default_device(jax.devices("cpu")[0]):
        fun

set_default_device = func

But it complained about keyword arguments like R_lmn=..., and the functions where we need this decorator don't share the same input structure.
Or maybe I had a bug somewhere else. I would also like to have a decorator instead of this version.
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | +0.62 +/- 6.97 | +3.17e-03 +/- 3.57e-02 | 5.16e-01 +/- 3.4e-02 | 5.13e-01 +/- 9.5e-03 |
| test_build_transform_fft_midres | +0.10 +/- 1.73 | +5.87e-04 +/- 1.03e-02 | 5.94e-01 +/- 8.2e-03 | 5.94e-01 +/- 6.2e-03 |
| test_build_transform_fft_highres | +0.41 +/- 7.84 | +4.01e-03 +/- 7.76e-02 | 9.93e-01 +/- 7.5e-02 | 9.89e-01 +/- 1.8e-02 |
| test_equilibrium_init_lowres | +4.17 +/- 5.09 | +1.53e-01 +/- 1.87e-01 | 3.82e+00 +/- 1.9e-01 | 3.67e+00 +/- 1.9e-02 |
| test_equilibrium_init_medres | +0.79 +/- 3.65 | +3.27e-02 +/- 1.51e-01 | 4.17e+00 +/- 1.5e-01 | 4.14e+00 +/- 2.1e-02 |
| test_equilibrium_init_highres | +0.18 +/- 1.56 | +1.01e-02 +/- 8.67e-02 | 5.55e+00 +/- 8.2e-02 | 5.54e+00 +/- 2.7e-02 |
| test_objective_compile_dshape_current | +3.03 +/- 1.11 | +1.15e-01 +/- 4.22e-02 | 3.90e+00 +/- 3.5e-02 | 3.79e+00 +/- 2.4e-02 |
| test_objective_compile_atf | +2.79 +/- 2.96 | +2.27e-01 +/- 2.42e-01 | 8.39e+00 +/- 1.0e-01 | 8.17e+00 +/- 2.2e-01 |
| test_objective_compute_dshape_current | +0.93 +/- 4.11 | +1.17e-05 +/- 5.13e-05 | 1.26e-03 +/- 2.8e-05 | 1.25e-03 +/- 4.3e-05 |
| test_objective_compute_atf | +2.09 +/- 5.85 | +8.83e-05 +/- 2.47e-04 | 4.32e-03 +/- 2.1e-04 | 4.23e-03 +/- 1.3e-04 |
| test_objective_jac_dshape_current | +2.28 +/- 9.38 | +8.20e-04 +/- 3.37e-03 | 3.68e-02 +/- 2.7e-03 | 3.59e-02 +/- 2.0e-03 |
| test_objective_jac_atf | +1.56 +/- 3.82 | +2.89e-02 +/- 7.10e-02 | 1.89e+00 +/- 6.1e-02 | 1.86e+00 +/- 3.6e-02 |
| -test_perturb_1 | +6.29 +/- 1.60 | +8.24e-01 +/- 2.09e-01 | 1.39e+01 +/- 8.6e-02 | 1.31e+01 +/- 1.9e-01 |
| -test_perturb_2 | +3.89 +/- 0.98 | +7.02e-01 +/- 1.76e-01 | 1.88e+01 +/- 1.6e-01 | 1.80e+01 +/- 7.1e-02 |
| test_proximal_jac_atf | -0.01 +/- 0.92 | -6.09e-04 +/- 6.72e-02 | 7.31e+00 +/- 4.7e-02 | 7.31e+00 +/- 4.8e-02 |
| test_proximal_freeb_compute | -0.04 +/- 0.72 | -6.47e-05 +/- 1.27e-03 | 1.77e-01 +/- 9.3e-04 | 1.77e-01 +/- 8.7e-04 |
| test_proximal_freeb_jac | +0.30 +/- 0.75 | +2.18e-02 +/- 5.53e-02 | 7.36e+00 +/- 4.6e-02 | 7.34e+00 +/- 3.0e-02 |
| test_solve_fixed_iter | -0.24 +/- 4.94 | -4.40e-02 +/- 8.94e-01 | 1.80e+01 +/- 6.8e-01 | 1.81e+01 +/- 5.8e-01 |
desc/backend.py (outdated diff)

@@ -73,6 +73,7 @@
 vmap = jax.vmap
 scan = jax.lax.scan
 bincount = jnp.bincount
+set_default_cpu = jax.default_device(jax.devices("cpu")[0])
I think we want something like this:

def set_default_cpu(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with jax.default_device(jax.devices("cpu")[0]):
            return func(*args, **kwargs)
    return wrapper
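
A minimal usage sketch of that decorator, assuming it is exported from desc.backend as in this PR's diff (make_modes is a made-up helper, not DESC code):

import jax.numpy as jnp
from desc.backend import set_default_cpu

@set_default_cpu
def make_modes(M):
    # the Python loop and the arrays it builds stay on CPU,
    # even if set_device("gpu") was called earlier
    return jnp.asarray([[m, 0] for m in range(M + 1)])

modes = make_modes(4)
print(modes.devices())  # expected: {CpuDevice(id=0)}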
Are you sure we want all the build methods on CPU? I would think that at high resolution, building the transforms would be faster on GPU.
Also, see my note about using the context manager; I think what you have now might be setting CPU as the default globally.
I will make some final checks tomorrow, but from previous tests I saw that building the transform at LMN=15 is faster on CPU, both the pre-JIT and post-JIT versions. The only build that took longer on CPU before the start of the optimization loop was the LinearConstraintProjection; it was 3 times slower on CPU. I don't think it defaults everything to CPU, because as I added these decorators to functions step by step, I noticed the timings change. So, I believe it only affects the decorated function. To be sure, I can switch to your method, but the current version is also working as intended.
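
For clarity, here is the difference between the scoped and global ways of setting the default device (a sketch; both are existing JAX APIs):

import jax

# scoped: only arrays created inside this block default to CPU
with jax.default_device(jax.devices("cpu")[0]):
    on_cpu = jax.numpy.arange(10)

# global: everything afterwards defaults to CPU until changed back
jax.config.update("jax_default_device", jax.devices("cpu")[0])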
I changed the function to what @f0uriest suggested. It works fine. Here are some profiling results (each line shows: with this PR -- master):

INIT surface: 0.3768799304962158 -- 1.4135913848876953
INIT equilibrium: 2.777209520339966 -- 9.259985446929932
INIT change resolution: 0.36481261253356934 -- 1.5080368518829346
INIT objective: 0.0001895427703857422 -- 0.0001552104949951172
INIT constraints: 0.0001468658447265625 -- 0.00015878677368164062
INIT optimizer: 2.3365020751953125e-05 -- 2.4318695068359375e-05
Building objective: force
Precomputing transforms
Timer: Precomputing transforms = 1.30 sec -- 2.82 sec
Timer: Objective build = 4.28 sec -- 11.5 sec
Timer: Objective build = 2.63 sec -- 3.70 sec
Timer: Linear constraint projection build = 14.0 sec -- 12.2 sec
Timer: Solution time = 53.0 sec -- 52.3 sec
Timer: Avg time per step = 8.83 sec -- 8.73 sec
INIT solve: 80.54720854759216 -- 86.62764954566956

The results are from the following script (I edited the printed output to pair up the times for readability):

from desc import set_device
set_device("gpu")
import numpy as np
from desc.equilibrium import Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.objectives import (
    ObjectiveFunction,
    ForceBalance,
    get_fixed_boundary_constraints,
)
from desc.optimize import Optimizer
from desc.plotting import plot_1d, plot_section, plot_surfaces
from desc.profiles import PowerSeriesProfile
import sys
from desc.backend import jax
import time
res = 15
t0 = time.time()
surface_2D = FourierRZToroidalSurface(
    R_lmn=np.array([10, -1]),  # boundary coefficients
    Z_lmn=np.array([1]),
    modes_R=np.array([[0, 0], [1, 0]]),  # [M, N] boundary Fourier modes
    modes_Z=np.array([[-1, 0]]),
    NFP=5,  # number of (toroidal) field periods (does not matter for 2D, but will for 3D solution)
)
print(f"INIT surface: {time.time()-t0}")
t0 = time.time()
eq = Equilibrium(surface=surface_2D, sym=True)
print(f"INIT equilibrium: {time.time()-t0}")
t0 = time.time()
eq.change_resolution(
    L=res,
    M=res,
    N=res,
    L_grid=2 * res,
    M_grid=2 * res,
    N_grid=2 * res,
)
eq.resolution_summary()
print(f"INIT change resolution: {time.time()-t0}")
t0 = time.time()
objective = ObjectiveFunction(ForceBalance(eq=eq))
print(f"INIT objective: {time.time()-t0}")
t0 = time.time()
constraints = get_fixed_boundary_constraints(eq=eq)
print(f"INIT constraints: {time.time()-t0}")
t0 = time.time()
optimizer = Optimizer("lsq-exact")
print(f"INIT optimizer: {time.time()-t0}")
t0 = time.time()
eq, solver_outputs = eq.solve(
    objective=objective,
    constraints=constraints,
    optimizer=optimizer,
    maxiter=5,
    verbose=3,
    # options={"tr_method": "qr"},
)
print(f"INIT solve: {time.time()-t0}")
Could we use this decorator on an objective function's compute method? It might be useful for the external objectives, so that everything else can run on GPU except the external codes that need to run on CPU. Or would that break JAX stuff?
No, this will only make a difference outside of JIT. Inside JIT everything has to be on a single device (you can't mix CPU and GPU; this is what #763 was aiming to fix, but we haven't been able to get it working).
We should probably also add this in a number of other places (a quick sketch follows this list):
- the curve/coil init and change resolution methods
- magnetic field classes init?
- profiles?
- all build methods of all objectives?
- init/build for LinearConstraintProjection and ProximalProjection
- factorize_linear_constraints
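
A sketch of what that could look like for one of the classes above (the class and method bodies here are placeholders, not DESC's actual code):

from desc.backend import set_default_cpu

class SomeCurve:  # placeholder for the curve/coil/profile classes listed above
    @set_default_cpu
    def __init__(self, N=10):
        # mode/index bookkeeping loops would run on CPU
        self.N = N

    @set_default_cpu
    def change_resolution(self, N):
        self.N = N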
desc/backend.py (outdated diff)

@@ -112,6 +113,28 @@ def put(arr, inds, vals):
         return arr
     return jnp.asarray(arr).at[inds].set(vals)

+def set_default_cpu(func):
I feel like we could come up with a better/more descriptive name, maybe default_device_cpu or something? But not a dealbreaker.
desc/objectives/linear_objectives.py (outdated diff)

@@ -275,6 +275,7 @@ def __init__(
             name=name,
         )

+    @set_default_cpu
Do we want to apply this to all the build methods? It seems kind of inconsistent with where it is used now.
I've added it to the parent objective class before, and for my test case it didn't give better results, so I picked suitable build methods intuitively.
Also, one important thing to keep in mind: when we build something on CPU, the arrays are probably in CPU memory, and we may need to copy them to GPU memory as we did during the hackathon.
I have tried this for LinearConstraintProjection and it was very slow. I don't know about the others. I tried to apply it to build methods that have for loops over modes etc. I can dig more into profiling when I return to Princeton; I cannot use the Princeton VPN properly on Linux, and it causes problems connecting to the visualization node etc.
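
A hedged sketch of that copy step, using JAX's standard device transfer (jax.device_put); the array name is made up, and running it requires a machine with a GPU:

import jax
import jax.numpy as jnp

with jax.default_device(jax.devices("cpu")[0]):
    A = jnp.eye(1000)  # built in CPU memory

# explicitly move the array to GPU memory before the hot loop
A_gpu = jax.device_put(A, jax.devices("gpu")[0])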
I don't know tbh. It wouldn't hurt to try. The compute functions that I know of don't have non-vectorized for loops.
Hmm OK. So if an external objective has to run on a CPU (like some STELLOPT code), then the whole optimization also has to be on a CPU? That's a bummer but not the end of the world.
No, in that case the jaxify stuff will allow you to run just the STELLOPT part on CPU. But it does that by basically exiting JAX. Within JAX you can't mix CPU/GPU.
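
For context, the usual escape hatch here is a host callback, which steps outside the traced computation (a generic sketch, not DESC's jaxify implementation; external_code stands in for e.g. a STELLOPT call):

import jax
import jax.numpy as jnp
import numpy as np

def external_code(x):
    # plain NumPy running on the host, invisible to JAX tracing
    return np.sin(np.asarray(x))

@jax.jit
def objective(x):
    # pure_callback exits JIT, runs external_code on the host, and returns
    y = jax.pure_callback(external_code, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return jnp.sum(y**2)

print(objective(jnp.arange(3.0)))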
OK, maybe we can move this conversation to PR #1028 where I will test this out. I was having trouble running on GPU before, but I think that was related to multiprocessing with JAX. I will test whether this decorator works for the external objective.
I have changed the decorator name from …
Some of the initialization functions use a lot of for loops that are extremely slow on GPUs. If we move their execution to CPU, we get at least a 5x speedup for the equilibrium and geometry classes. Since the GitHub Actions runners use CPU, we won't see any speedup there, but in GPU interactive sessions this should speed up plotting etc. (A micro-benchmark sketch of the effect follows below.)
Add the decorator for the parts you believe don't need to run on GPUs @f0uriest @dpanici @rahulgaur104 @ddudt @kianorr @unalmis
This was actually noticed in the previous PR #1053, when the benchmark showed no speedup but the GPU equilibrium init got faster by 5 seconds.
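
A rough micro-benchmark sketch of the loop effect described above (build_modes is a made-up helper; timings vary by machine, and the "gpu" case requires a GPU):

import time
import jax
import jax.numpy as jnp

def build_modes(M):
    # many small array creations in a Python loop,
    # the pattern that is slow when each one dispatches to the GPU
    return [jnp.array([m, 0]) for m in range(M)]

for dev in ["cpu", "gpu"]:
    with jax.default_device(jax.devices(dev)[0]):
        t0 = time.time()
        build_modes(2000)
        print(dev, time.time() - t0)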