Add `jac_chunk_size` keyword argument to `ObjectiveFunction` to reduce memory usage of forward mode Jacobian calculation #1052
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #1052 +/- ##
==========================================
- Coverage 95.30% 92.19% -3.12%
==========================================
Files 95 96 +1
Lines 23944 23560 -384
==========================================
- Hits 22821 21721 -1100
- Misses 1123 1839 +716
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
| test_build_transform_fft_lowres | +3.16 +/- 3.96 | +1.69e-02 +/- 2.12e-02 | 5.53e-01 +/- 2.0e-02 | 5.36e-01 +/- 6.8e-03 |
| test_equilibrium_init_medres | -0.44 +/- 5.40 | -1.94e-02 +/- 2.36e-01 | 4.34e+00 +/- 1.0e-01 | 4.36e+00 +/- 2.1e-01 |
| test_equilibrium_init_highres | -0.74 +/- 2.41 | -4.25e-02 +/- 1.39e-01 | 5.73e+00 +/- 1.2e-01 | 5.77e+00 +/- 6.4e-02 |
| test_objective_compile_dshape_current | -1.42 +/- 1.53 | -5.72e-02 +/- 6.15e-02 | 3.97e+00 +/- 5.0e-02 | 4.03e+00 +/- 3.6e-02 |
| test_objective_compute_dshape_current | -1.97 +/- 3.73 | -7.30e-05 +/- 1.39e-04 | 3.64e-03 +/- 4.4e-05 | 3.71e-03 +/- 1.3e-04 |
| test_objective_jac_dshape_current | -0.67 +/- 4.78 | -2.76e-04 +/- 1.96e-03 | 4.08e-02 +/- 1.4e-03 | 4.11e-02 +/- 1.3e-03 |
| test_perturb_2 | +0.42 +/- 3.47 | +7.51e-02 +/- 6.14e-01 | 1.78e+01 +/- 5.1e-01 | 1.77e+01 +/- 3.5e-01 |
| test_proximal_freeb_jac | -0.28 +/- 1.56 | -2.12e-02 +/- 1.17e-01 | 7.51e+00 +/- 7.8e-02 | 7.53e+00 +/- 8.8e-02 |
| test_solve_fixed_iter | +0.40 +/- 57.53 | +2.00e-02 +/- 2.88e+00 | 5.03e+00 +/- 2.0e+00 | 5.01e+00 +/- 2.1e+00 |
| test_build_transform_fft_midres | -0.76 +/- 5.52 | -4.77e-03 +/- 3.46e-02 | 6.23e-01 +/- 1.1e-02 | 6.28e-01 +/- 3.3e-02 |
| test_build_transform_fft_highres | -0.34 +/- 3.28 | -3.46e-03 +/- 3.36e-02 | 1.02e+00 +/- 9.3e-03 | 1.02e+00 +/- 3.2e-02 |
| test_equilibrium_init_lowres | +1.50 +/- 3.83 | +5.83e-02 +/- 1.49e-01 | 3.95e+00 +/- 1.5e-01 | 3.89e+00 +/- 3.4e-02 |
| test_objective_compile_atf | -0.08 +/- 4.11 | -6.25e-03 +/- 3.25e-01 | 7.90e+00 +/- 2.4e-01 | 7.91e+00 +/- 2.2e-01 |
| test_objective_compute_atf | +2.00 +/- 2.81 | +2.10e-04 +/- 2.97e-04 | 1.07e-02 +/- 2.5e-04 | 1.05e-02 +/- 1.5e-04 |
| test_objective_jac_atf | +1.18 +/- 2.10 | +2.33e-02 +/- 4.16e-02 | 2.00e+00 +/- 3.0e-02 | 1.98e+00 +/- 2.8e-02 |
| test_perturb_1 | +7.72 +/- 3.83 | +9.70e-01 +/- 4.81e-01 | 1.35e+01 +/- 4.2e-01 | 1.26e+01 +/- 2.3e-01 |
| test_proximal_jac_atf | +1.08 +/- 0.76 | +8.87e-02 +/- 6.27e-02 | 8.29e+00 +/- 4.7e-02 | 8.20e+00 +/- 4.1e-02 |
| test_proximal_freeb_compute | +2.84 +/- 1.08 | +5.27e-03 +/- 2.00e-03 | 1.91e-01 +/- 1.8e-03 | 1.86e-01 +/- 9.6e-04 |
You might already be aware, but FYI: jax-ml/jax#19614
If you don't care about jax's native multi-GPU sharding support it should be easy to just vendor our implementation. The former 2 files are on purpose standalone. Only Remove all branches hitting of Also replace
@dpanici jax batched vmap has been merged to master
@dpanici make separate branch with the implementation using JAX's version, and in this PR implement the one based off of
Parameters
----------
f: a function that takes elements of the leading dimension of x
This is not consistent with our docstring format but no problem. Just pointing out.
of functions that act on JAX arrays.

Parameters
----------
The above docstrings are not too important, but this one might be checked more often, so maybe make it consistent with our doc format? Again, not too important; we can change it in a later PR.
@@ -474,8 +553,6 @@ def jac_scaled_error(self, x, constants=None):

        if self._deriv_mode == "batched":
            J = Derivative(self.compute_scaled_error, mode="fwd")(x, constants)
        if self._deriv_mode == "looped":
This is a completely unrelated comment, but I think we should explain the Jacobian, JVP, etc. better in a dev guide. For example, an individual objective's `jac_scaled_error` doesn't use `batched_vmap`; it usually uses `jax.jacfwd`. On the other hand, when we wrap the objective and constraints, the Jacobian is calculated by the corresponding `jvp_` method, not the `jac_` method. Even now it confuses me. In which case do we use the `jac_` method instead of `jvp_`? A good thing to clarify in the long-awaited dev guide.
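As a rough illustration of the general distinction raised in this comment (plain JAX only, not DESC's actual code paths), the sketch below compares a dense forward-mode Jacobian from `jax.jacfwd` with a single Jacobian-vector product from `jax.jvp`. The `residual` function is a hypothetical stand-in for something like `compute_scaled_error`.

```python
import jax
import jax.numpy as jnp


def residual(x):
    # Hypothetical stand-in for an objective's residual function.
    return jnp.array([x[0] * x[1], jnp.sin(x[1]), x[0] ** 2])


x = jnp.array([1.0, 2.0])
v = jnp.array([0.5, -1.0])

# Dense forward-mode Jacobian with shape (dim_f, dim_x) = (3, 2).
J = jax.jacfwd(residual)(x)

# Jacobian-vector product: computes J @ v without materializing J.
_, Jv = jax.jvp(residual, (x,), (v,))
```

A `jac_`-style call corresponds to building all of `J`, while a `jvp_`-style call only pushes chosen tangent directions through the function, which matters when the wrapped problem only needs `J` applied to a few directions.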
# into vmap, we can make use of more efficient batching rules for
# primitives where only some arguments are batched (e.g., for
# lax_linalg.triangular_solve), and avoid instantiating large broadcasted
# arrays.
commenting for reference
Most of the documentation of our objectives has the same parameters, which are inherited from the main `_Objective` class. This PR removes the repeated docstring from each objective and builds it by inheritance instead, which reduces the lines of code and eases maintenance. For example, when we add `jac_chunk_size` in #1052, we have to copy-paste the docs to every single objective, which is tedious. Introduces a `collect_docs` function that creates the docstring for the common parameters; with the overwrite option, the user can give a custom definition for a parameter without changing the order of the docs. Resolves #879
- Changes `jnp.vectorize` calls to instead use `batched_vectorize`, which performs the function vectorization in smaller chunks. This reduces the memory cost of the calculation, at the expense of a longer runtime the smaller the chunk size is.
- Adds `jac_chunk_size` to `ObjectiveFunction` and `_Objective` to control the above chunk size for the `fwd` mode Jacobian calculation:
  - If `None`, the chunk size is equal to `dim_x`, so no chunking is done.
  - If an `int`, this is the chunk size to be used.
  - If `"auto"` for the `ObjectiveFunction`, a heuristic is used for the maximum `jac_chunk_size` needed to fit the Jacobian calculation in the available device memory, according to the formula: `max_jac_chunk_size = (desc_config.get("avail_mem") / estimated_memory_usage - 0.22) / 0.85 * self.dim_x`
- The `ObjectiveFunction` `jac_chunk_size` is used if `deriv_mode="batched"`, and the `_Objective` `jac_chunk_size` is used if `deriv_mode="blocked"`.
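The chunking idea and the `"auto"` heuristic described in the list above can be sketched in plain JAX as follows. This is a minimal illustration under stated assumptions, not the implementation in this PR: `chunked_jacfwd` and `auto_jac_chunk_size` are hypothetical names, the Python loop over chunks stands in for whatever batching primitive is actually used, and clamping the heuristic to `[1, dim_x]` is an added assumption.

```python
import jax
import jax.numpy as jnp


def chunked_jacfwd(fun, x, chunk_size=None):
    """Forward-mode Jacobian of fun at 1D x, built chunk_size columns at a time.

    Hypothetical helper for illustration only.
    """
    dim_x = x.shape[0]
    if chunk_size is None or chunk_size >= dim_x:
        chunk_size = dim_x  # no chunking: all tangents are pushed through at once

    def jvp_column(v):
        # One JVP with a unit tangent vector yields one Jacobian column.
        return jax.jvp(fun, (x,), (v,))[1]

    basis = jnp.eye(dim_x, dtype=x.dtype)  # rows are the unit tangent vectors
    blocks = []
    for start in range(0, dim_x, chunk_size):
        tangents = basis[start:start + chunk_size]
        # vmap only over this chunk, so the batched intermediates
        # scale with chunk_size rather than dim_x.
        blocks.append(jax.vmap(jvp_column)(tangents))
    # Each block has shape (chunk, dim_f); stack, then transpose to (dim_f, dim_x).
    return jnp.concatenate(blocks, axis=0).T


def auto_jac_chunk_size(avail_mem, estimated_memory_usage, dim_x):
    """Heuristic formula quoted above; the clamp to [1, dim_x] is an assumption."""
    raw = (avail_mem / estimated_memory_usage - 0.22) / 0.85 * dim_x
    return max(1, min(dim_x, int(raw)))


def fun(x):
    # Small example function with a dense Jacobian.
    return jnp.sin(x) * x[0] + x ** 2


x = jnp.arange(1.0, 6.0)
J_chunked = chunked_jacfwd(fun, x, chunk_size=2)
J_full = jax.jacfwd(fun)(x)
```

Smaller chunks mean more sequential passes through `fun`, which is the runtime cost traded for the lower peak memory.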
This works well. The memory trace vs. time on GPU for an LMN18 equilibrium solve with a 1.5 oversampled grid and `maxiter=10` shows a 4x memory decrease with negligible runtime increase.

Also, I can do up to an `LMN=20` equilibrium `ForceBalance` objective with the default double grid oversampling, and with the `"auto"` chunk sizing, the Jacobian compiles and computes without going OOM on an 80 GB GPU (on master this would go OOM).

TODO
- `netket`
- `dim_x`)
- `chunk_size` argument to every Objective class
- `LinearObjective` classes, though technically you could
- `"chunked"` as a deriv_mode to `Derivative` (or, just as an argument to `Derivative` to be used when `"batched"` is used) -> I don't remember what this was exactly, I think we can keep just for Objectives
- `chunk_size` to `jacobian_chunk_size` for Objective kwarg

TODO Later
Resolves #826