
Execute initializations on CPU much faster #1056

Merged: 17 commits merged into master from yge/cpu on Jul 24, 2024
Conversation

@YigitElma (Collaborator) commented Jun 14, 2024

Some of the initialization functions use a lot of for loops that are extremely slow on GPUs. If we move the execution to the CPU, we get at least a 5x speedup for the equilibrium and geometry classes. Since the GitHub Actions run on CPU, we won't see any speedup there, but in GPU interactive sessions this should speed up plotting, etc.

Add the

with jax.default_device(jax.devices("cpu")[0]):

context manager for the parts you believe don't need to run on GPUs. @f0uriest @dpanici @rahulgaur104 @ddudt @kianorr @unalmis

This was actually noticed on the previous PR #1053, when the benchmark didn't show a speedup but the GPU equilibrium init got faster by 5 seconds.
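For illustration, here is a minimal sketch (not the actual DESC code) of the pattern this PR adds, assuming a machine where JAX sees both a CPU and a GPU:

import jax
import jax.numpy as jnp

def build_mode_table(M):
    # Python-level for loops dispatch one tiny kernel per iteration,
    # which is far slower on a GPU than on a CPU, so pin this whole
    # block to the CPU device.
    with jax.default_device(jax.devices("cpu")[0]):
        rows = [jnp.arange(-M, M + 1) for _ in range(2 * M + 1)]
        return jnp.stack(rows)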

[screenshots: benchmark timing comparison]

codecov bot commented Jun 14, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 95.24%. Comparing base (e76f80e) to head (d5be1f9).
Report is 1805 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1056   +/-   ##
=======================================
  Coverage   95.24%   95.24%           
=======================================
  Files          87       87           
  Lines       21920    21947   +27     
=======================================
+ Hits        20877    20903   +26     
- Misses       1043     1044    +1     
Files with missing lines                Coverage Δ
desc/backend.py                         90.24% <100.00%> (+0.43%) ⬆️
desc/compute/utils.py                   96.48% <100.00%> (+0.03%) ⬆️
desc/equilibrium/equilibrium.py         96.00% <100.00%> (+0.02%) ⬆️
desc/geometry/surface.py                96.85% <100.00%> (+0.03%) ⬆️
desc/objectives/linear_objectives.py    96.54% <100.00%> (+0.04%) ⬆️
desc/objectives/objective_funs.py       93.95% <100.00%> (+0.01%) ⬆️

... and 1 file with indirect coverage changes

self._spectral_indexing = setdefault(
    spectral_indexing, getattr(surface, "spectral_indexing", "ansi")
)
with jax.default_device(jax.devices("cpu")[0]):
Member commented:

I like the idea of having something in backend that we can just decorate functions with, without having to mess with jax everywhere. Like:

@oncpu
def __init__(self, ...):

@YigitElma (Collaborator, Author) replied:

I actually tried that with

def func(fun, *args, **kwargs):
    with jax.default_device(jax.devices("cpu")[0]):
        fun  # note: fun is referenced here but never called

set_default_device = func

But it complained about keyword arguments like R_lmn=..., and the functions we need to wrap with this don't all have the same input structure.

@YigitElma (Collaborator, Author) replied:

Or maybe I had a bug somewhere else. I would also like to have a decorator instead of this version.

@YigitElma (Collaborator, Author) commented:
Building objectives and transforms might also benefit from running on the CPU.

[screenshots: objective and transform build timings]

The first build is the same, since it is wrapped by the eq init; the others call the same function.

github-actions bot (Contributor) commented Jun 14, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_lowres         |     +0.62 +/- 6.97     | +3.17e-03 +/- 3.57e-02 |  5.16e-01 +/- 3.4e-02  |  5.13e-01 +/- 9.5e-03  |
 test_build_transform_fft_midres         |     +0.10 +/- 1.73     | +5.87e-04 +/- 1.03e-02 |  5.94e-01 +/- 8.2e-03  |  5.94e-01 +/- 6.2e-03  |
 test_build_transform_fft_highres        |     +0.41 +/- 7.84     | +4.01e-03 +/- 7.76e-02 |  9.93e-01 +/- 7.5e-02  |  9.89e-01 +/- 1.8e-02  |
 test_equilibrium_init_lowres            |     +4.17 +/- 5.09     | +1.53e-01 +/- 1.87e-01 |  3.82e+00 +/- 1.9e-01  |  3.67e+00 +/- 1.9e-02  |
 test_equilibrium_init_medres            |     +0.79 +/- 3.65     | +3.27e-02 +/- 1.51e-01 |  4.17e+00 +/- 1.5e-01  |  4.14e+00 +/- 2.1e-02  |
 test_equilibrium_init_highres           |     +0.18 +/- 1.56     | +1.01e-02 +/- 8.67e-02 |  5.55e+00 +/- 8.2e-02  |  5.54e+00 +/- 2.7e-02  |
 test_objective_compile_dshape_current   |     +3.03 +/- 1.11     | +1.15e-01 +/- 4.22e-02 |  3.90e+00 +/- 3.5e-02  |  3.79e+00 +/- 2.4e-02  |
 test_objective_compile_atf              |     +2.79 +/- 2.96     | +2.27e-01 +/- 2.42e-01 |  8.39e+00 +/- 1.0e-01  |  8.17e+00 +/- 2.2e-01  |
 test_objective_compute_dshape_current   |     +0.93 +/- 4.11     | +1.17e-05 +/- 5.13e-05 |  1.26e-03 +/- 2.8e-05  |  1.25e-03 +/- 4.3e-05  |
 test_objective_compute_atf              |     +2.09 +/- 5.85     | +8.83e-05 +/- 2.47e-04 |  4.32e-03 +/- 2.1e-04  |  4.23e-03 +/- 1.3e-04  |
 test_objective_jac_dshape_current       |     +2.28 +/- 9.38     | +8.20e-04 +/- 3.37e-03 |  3.68e-02 +/- 2.7e-03  |  3.59e-02 +/- 2.0e-03  |
 test_objective_jac_atf                  |     +1.56 +/- 3.82     | +2.89e-02 +/- 7.10e-02 |  1.89e+00 +/- 6.1e-02  |  1.86e+00 +/- 3.6e-02  |
-test_perturb_1                          |     +6.29 +/- 1.60     | +8.24e-01 +/- 2.09e-01 |  1.39e+01 +/- 8.6e-02  |  1.31e+01 +/- 1.9e-01  |
-test_perturb_2                          |     +3.89 +/- 0.98     | +7.02e-01 +/- 1.76e-01 |  1.88e+01 +/- 1.6e-01  |  1.80e+01 +/- 7.1e-02  |
 test_proximal_jac_atf                   |     -0.01 +/- 0.92     | -6.09e-04 +/- 6.72e-02 |  7.31e+00 +/- 4.7e-02  |  7.31e+00 +/- 4.8e-02  |
 test_proximal_freeb_compute             |     -0.04 +/- 0.72     | -6.47e-05 +/- 1.27e-03 |  1.77e-01 +/- 9.3e-04  |  1.77e-01 +/- 8.7e-04  |
 test_proximal_freeb_jac                 |     +0.30 +/- 0.75     | +2.18e-02 +/- 5.53e-02 |  7.36e+00 +/- 4.6e-02  |  7.34e+00 +/- 3.0e-02  |
 test_solve_fixed_iter                   |     -0.24 +/- 4.94     | -4.40e-02 +/- 8.94e-01 |  1.80e+01 +/- 6.8e-01  |  1.81e+01 +/- 5.8e-01  |

@YigitElma marked this pull request as draft June 14, 2024 19:46
desc/backend.py Outdated
@@ -73,6 +73,7 @@
vmap = jax.vmap
scan = jax.lax.scan
bincount = jnp.bincount
set_default_cpu = jax.default_device(jax.devices("cpu")[0])
Member commented:

I think we want something like this:

def set_default_cpu(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with jax.default_device(jax.devices("cpu")[0]):
            return func(*args, **kwargs)
    return wrapper
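Hypothetical usage (the method and signature below are illustrative, not the actual DESC code):

@set_default_cpu
def __init__(self, *args, **kwargs):
    # the whole body now runs with CPU as the default device
    ...

Because the context manager is entered inside wrapper, the CPU default applies only while the decorated function runs and reverts on exit.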

@YigitElma marked this pull request as ready for review June 19, 2024 23:52
@YigitElma requested a review from f0uriest June 19, 2024 23:52
dpanici previously approved these changes Jun 20, 2024
@f0uriest (Member) left a comment:

Are you sure we want all the build methods on CPU? I would think that at high resolution, building the transforms would be faster on GPU.

Also see my note about using the context manager; I think what you have now might be setting CPU as the default globally.
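For reference, a sketch of the distinction being discussed, assuming the standard JAX APIs:

import jax

# Scoped: only code inside the block defaults to the CPU.
with jax.default_device(jax.devices("cpu")[0]):
    pass  # CPU is the default device here

# Global: everything afterwards defaults to the CPU until changed back.
jax.config.update("jax_default_device", jax.devices("cpu")[0])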

@YigitElma (Collaborator, Author) replied:

> Are you sure we want all the build methods on CPU? I would think that at high resolution, building the transforms would be faster on GPU.
>
> Also see my note about using the context manager; I think what you have now might be setting CPU as the default globally.

I will make some final checks again tomorrow, but from previous tests I saw that building the transform at LMN=15 is faster on CPU, both pre-JIT and post-JIT. The only build that took longer on CPU before the start of the optimization loop was the LinearConstraintProjection; it was 3 times slower on CPU.

I don't think it defaults everything to CPU, because as I added this decorator to new functions step by step, I noticed timing changes, so I believe it only affects the decorated function. To be sure, I can switch to your method, but the current version is also working as intended.

@YigitElma (Collaborator, Author) commented Jul 18, 2024

I changed the function to what @f0uriest suggested, and it works fine. Here are some profiling results:

With this PR  --  Master

INIT surface: 0.3768799304962158   --   1.4135913848876953
INIT equilibrium: 2.777209520339966   --  9.259985446929932
INIT change resolution: 0.36481261253356934  --  1.5080368518829346
INIT objective: 0.0001895427703857422  --  0.0001552104949951172
INIT constraints: 0.0001468658447265625  --  0.00015878677368164062
INIT optimizer: 2.3365020751953125e-05  --  2.4318695068359375e-05
Building objective: force
Precomputing transforms
Timer: Precomputing transforms = 1.30 sec  --  2.82 sec
Timer: Objective build = 4.28 sec  --  11.5 sec
Timer: Objective build = 2.63 sec  --  3.70 sec
Timer: Linear constraint projection build = 14.0 sec  --  12.2 sec
Timer: Solution time = 53.0 sec  --  52.3 sec
Timer: Avg time per step = 8.83 sec  --  8.73 sec
INIT solve: 80.54720854759216  --  86.62764954566956

Results are from the following script (I edited the time printouts to be more readable):

from desc import set_device

set_device("gpu")

import numpy as np

from desc.equilibrium import Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.objectives import (
    ObjectiveFunction,
    ForceBalance,
    get_fixed_boundary_constraints,
)
from desc.optimize import Optimizer
from desc.plotting import plot_1d, plot_section, plot_surfaces
from desc.profiles import PowerSeriesProfile
import sys
from desc.backend import jax
import time

res = 15

t0 = time.time()
surface_2D = FourierRZToroidalSurface(
    R_lmn=np.array([10, -1]),  # boundary coefficients
    Z_lmn=np.array([1]),
    modes_R=np.array([[0, 0], [1, 0]]),  # [M, N] boundary Fourier modes
    modes_Z=np.array([[-1, 0]]),
    NFP=5,  # number of (toroidal) field periods (does not matter for 2D, but will for 3D solution)
)
print(f"INIT surface: {time.time()-t0}")

t0 = time.time()
eq = Equilibrium(surface=surface_2D, sym=True)
print(f"INIT equilibrium: {time.time()-t0}")

t0 = time.time()
eq.change_resolution(
    L=res,
    M=res,
    N=res,
    L_grid=2 * res,
    M_grid=2 * res,
    N_grid=2 * res,
)
eq.resolution_summary()
print(f"INIT change resolution: {time.time()-t0}")

t0 = time.time()
objective = ObjectiveFunction(ForceBalance(eq=eq))
print(f"INIT objective: {time.time()-t0}")

t0 = time.time()
constraints = get_fixed_boundary_constraints(eq=eq)
print(f"INIT constraints: {time.time()-t0}")

t0 = time.time()
optimizer = Optimizer("lsq-exact")
print(f"INIT optimizer: {time.time()-t0}")


t0 = time.time()
eq, solver_outputs = eq.solve(
        objective=objective,
        constraints=constraints,
        optimizer=optimizer,
        maxiter=5,
        verbose=3,
        #options={"tr_method": "qr"},
    )
print(f"INIT solve: {time.time()-t0}")

@YigitElma requested reviews from f0uriest and dpanici July 18, 2024 13:22
dpanici previously approved these changes Jul 18, 2024
ddudt previously approved these changes Jul 18, 2024
@ddudt (Collaborator) left a comment:

Could we use this decorator on an objective function's compute method? It might be useful for the external objectives, so that everything else can run on GPU except the external codes that need to run on CPU. Or would that break JAX stuff?

@f0uriest (Member) replied:

> Could we use this decorator on an objective function's compute method? It might be useful for the external objectives, so that everything else can run on GPU except the external codes that need to run on CPU. Or would that break JAX stuff?

No, this will only make a difference outside of JIT. Inside JIT everything has to be on a single device (we can't mix CPU and GPU; this is what #763 was aiming to fix, but we haven't been able to get it working).

@f0uriest (Member) left a comment:

We should probably also add this in a number of other places:

  • to the curve/coil init and change resolution methods
  • magnetic field classes init?
  • profiles?
  • all build methods of all objectives?
  • init/build for LinearConstraintProjection and ProximalProjection
  • factorize_linear_constraints

desc/backend.py Outdated
@@ -112,6 +113,28 @@ def put(arr, inds, vals):
return arr
return jnp.asarray(arr).at[inds].set(vals)

def set_default_cpu(func):
Member commented:
I feel like we could come up with a better/more descriptive name, maybe default_device_cpu or something? Not a dealbreaker, though.

@@ -275,6 +275,7 @@ def __init__(
name=name,
)

@set_default_cpu
Member commented:
Do we want to apply this to all the build methods? It seems kind of inconsistent with where it is used now.

@YigitElma (Collaborator, Author) replied:

I've added it to the parent objective class before, and for my test case it didn't give better results, so I picked suitable build methods intuitively.

Also, one important thing to keep in mind: when we build something on CPU, the arrays are probably in CPU memory, so we may need to copy them to GPU memory as we did during the hackathon.
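A minimal sketch of that copy, assuming the standard JAX APIs and an available GPU:

import jax
import jax.numpy as jnp

with jax.default_device(jax.devices("cpu")[0]):
    x = jnp.ones(1000)  # created in host (CPU) memory

# Explicitly move the array to the accelerator afterwards.
x = jax.device_put(x, jax.devices("gpu")[0])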

@YigitElma (Collaborator, Author) replied:

> We should probably also add this in a number of other places:
>
>   • to the curve/coil init and change resolution methods
>   • magnetic field classes init?
>   • profiles?
>   • all build methods of all objectives?
>   • init/build for LinearConstraintProjection and ProximalProjection
>   • factorize_linear_constraints

I have tried this for LinearConstraintProjection, and it was very slow; I don't know about the others. I have tried to put it on build methods that have for loops over modes, etc.

I can dig more into profiling when I get back to Princeton; I cannot use the Princeton VPN properly on Linux, and it causes problems connecting to the visualization node, etc.

@YigitElma (Collaborator, Author) replied:

> Could we use this decorator on an objective function's compute method? It might be useful for the external objectives, so that everything else can run on GPU except the external codes that need to run on CPU. Or would that break JAX stuff?

I don't know, to be honest. It wouldn't hurt to try. The compute functions that I know of don't have unvectorized for loops.

@ddudt (Collaborator) commented Jul 18, 2024

> No, this will only make a difference outside of JIT. Inside JIT everything has to be on a single device (we can't mix CPU and GPU; this is what #763 was aiming to fix, but we haven't been able to get it working).

Hmm OK. So if an external objective has to run on a CPU (like some STELLOPT code), then the whole optimization also has to be on a CPU? That's a bummer, but not the end of the world.

@f0uriest (Member) replied:

> Hmm OK. So if an external objective has to run on a CPU (like some STELLOPT code), then the whole optimization also has to be on a CPU? That's a bummer, but not the end of the world.

No, in that case the jaxify stuff will allow you to run just the STELLOPT part on CPU. But it does that by basically exiting JAX; within JAX you can't mix CPU/GPU.
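For context, a rough sketch of that pattern (the callback and shapes here are illustrative; the actual jaxify implementation in #1028 may differ):

import jax
import jax.numpy as jnp
import numpy as np

def external_code(x):
    # stand-in for a CPU-only external code: plain NumPy on the host
    return np.sin(x) + 1.0

@jax.jit
def objective(x):
    # pure_callback steps outside of JAX, runs the host function, and
    # returns an array matching the declared shape/dtype
    y = jax.pure_callback(
        external_code, jax.ShapeDtypeStruct(x.shape, x.dtype), x
    )
    return jnp.sum(y ** 2)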

@ddudt (Collaborator) commented Jul 19, 2024

> No, in that case the jaxify stuff will allow you to run just the STELLOPT part on CPU. But it does that by basically exiting JAX; within JAX you can't mix CPU/GPU.

OK, maybe we can move this conversation to PR #1028, where I will test this out. I was having trouble running on GPU before, but I think that was related to multiprocessing with JAX. I will test whether this decorator works for the external objective.

@ddudt dismissed stale reviews from dpanici and themself via 5bb27d2 July 23, 2024 19:23
@YigitElma requested reviews from f0uriest, ddudt and dpanici July 24, 2024 15:50
@YigitElma (Collaborator, Author) commented:

I have changed the decorator name from set_default_cpu to execute_on_cpu, which I think is the most descriptive name for this. About the additional places @f0uriest mentioned, can we add them in a separate PR? We can use the faster initialization during Simon's, and do the profiling of the other objectives etc. later.

@YigitElma merged commit b15a24f into master Jul 24, 2024
18 checks passed
@YigitElma deleted the yge/cpu branch July 24, 2024 17:40