Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jac_chunk_size keyword argument to ObjectiveFunction to reduce memory usage of forward mode Jacobian calculation #1052

Merged
merged 74 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 72 commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
e1ffec8
add netket to use vmap_chunked, change the vectorize calls in objecti…
dpanici Jun 12, 2024
5aef59c
add chunk_size arg
dpanici Jun 12, 2024
67844a7
add chunk_size to objectivefunction as well
dpanici Jun 12, 2024
60e142c
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Jun 13, 2024
c27bf9b
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Jun 14, 2024
f3b6683
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Jun 14, 2024
517a519
set default chunk size to None to disable chunking by default:
dpanici Jun 14, 2024
6b786d0
fix default chunk size in backend
dpanici Jun 14, 2024
135adce
add a comment
dpanici Jun 24, 2024
0366043
add chunk_size arg to solve continuation
dpanici Jul 2, 2024
131ae78
fix error
dpanici Jul 2, 2024
8cfcf8d
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Jul 3, 2024
8f0515c
Merge branch 'dp/jacobian-batched-vmap' of github.com:PlasmaControl/D…
dpanici Jul 15, 2024
c3ba7ac
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Jul 23, 2024
ce5742f
Merge branch 'dp/jacobian-batched-vmap' of github.com:PlasmaControl/D…
dpanici Aug 16, 2024
1f32705
Merge branch 'dp/jacobian-batched-vmap' of github.com:PlasmaControl/D…
dpanici Aug 22, 2024
02ba37e
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Aug 22, 2024
feb006e
remove netket dependence, vendor only needed parts from their vmap_ch…
dpanici Aug 22, 2024
eda29ee
add chunk_size arg to every Obective
dpanici Aug 22, 2024
95e25cc
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Aug 22, 2024
7aca7ad
put chunk size in a few tests
dpanici Aug 22, 2024
fcdf492
Merge branch 'dp/jacobian-batched-vmap' of github.com:PlasmaControl/D…
dpanici Aug 22, 2024
2239746
add chunk size arg to blocked test
dpanici Aug 22, 2024
74e9d0c
add info about chunk_size to docs
dpanici Aug 26, 2024
d9b34d1
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Aug 26, 2024
9ad2688
change attribution statements
dpanici Aug 26, 2024
41171c8
fix typo in attribution statements
dpanici Aug 26, 2024
4dc0c98
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Aug 27, 2024
4b09c02
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Aug 28, 2024
3e38d25
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 3, 2024
0d59bd9
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 6, 2024
5d15de3
add warnings
dpanici Sep 8, 2024
9aca31f
set default to be dim_x/4 for chunk_size
dpanici Sep 8, 2024
6f2bbdb
fix warning causing test fail
dpanici Sep 9, 2024
f3b5f25
change chunk_size to jac_chunk_size, change docstring
dpanici Sep 9, 2024
e4e27a2
remove unused functions
dpanici Sep 9, 2024
51e3e40
simplify chunking fxn
dpanici Sep 9, 2024
da359d5
place batcehd_vectorize utils in own file
dpanici Sep 9, 2024
fbebfb8
add file
dpanici Sep 9, 2024
fc8c98c
remove jac_looped, deprecate 'looped' deriv_mode
dpanici Sep 9, 2024
98c93d5
re-implement scan_append, need to be more careful in replacing it
dpanici Sep 9, 2024
0787b84
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 10, 2024
bc1d498
simplify function slightly
dpanici Sep 10, 2024
1fe12df
Merge branch 'dp/jacobian-batched-vmap' of github.com:PlasmaControl/D…
dpanici Sep 10, 2024
5b245f4
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 10, 2024
f088a06
remove print statements
dpanici Sep 11, 2024
7c804a1
simplify scan further
dpanici Sep 11, 2024
480a4a7
add doc for reducing memory usage
dpanici Sep 11, 2024
15a7f48
change constraint wrapper to use batched
dpanici Sep 11, 2024
ea74b24
address some comments
dpanici Sep 12, 2024
83c39b4
update chunk size to auto as default
dpanici Sep 12, 2024
83e96c0
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 12, 2024
2df20b5
fix test
dpanici Sep 14, 2024
da92365
Merge branch 'dp/jacobian-batched-vmap' of github.com:PlasmaControl/D…
dpanici Sep 14, 2024
e433704
remove errant prints
dpanici Sep 14, 2024
5ce88c6
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 18, 2024
e9efd27
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 19, 2024
60e1c74
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 20, 2024
4137d36
change auto to base chunk size off of device memiry and jacobian size
dpanici Sep 21, 2024
bd34534
update test
dpanici Sep 21, 2024
8b2645d
update changelog
dpanici Sep 22, 2024
1f98694
Merge branch 'dp/jacobian-batched-vmap' of github.com:PlasmaControl/D…
dpanici Sep 22, 2024
1c1fab4
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 24, 2024
3e8e364
remove auto option for _Objective chunk size
dpanici Sep 24, 2024
b0b78fe
rename utils file
dpanici Sep 24, 2024
be57bbc
resolve comments
dpanici Sep 24, 2024
cd865d0
resolve further comment
dpanici Sep 24, 2024
e0e513f
change warnings to errors
dpanici Sep 24, 2024
a6e29d6
add disclaimer to ObjectiveFunction for auto chunk size on HPC CPU
dpanici Sep 24, 2024
6988660
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 24, 2024
f4fbbcf
Merge branch 'master' into dp/jacobian-batched-vmap
YigitElma Sep 25, 2024
9100e7c
update changelog
dpanici Sep 25, 2024
15d95f1
correct changelog
dpanici Sep 26, 2024
3e99510
Merge branch 'master' into dp/jacobian-batched-vmap
dpanici Sep 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,20 @@ New Features
- Changes ``ToroidalFlux`` objective to default using a 1D loop integral of the vector potential
to compute the toroidal flux when possible, as opposed to a 2D surface integral of the magnetic field dotted with ``n_zeta``.
- Allow specification of Nyquist spectrum maximum modenumbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file
- Add ``jac_chunk_size`` to ``ObjectiveFunction`` and ``_Objective`` to control the above chunk size for the ``fwd`` mode Jacobian calculation
- if ``None``, the chunk size is equal to ``dim_x``, so no chunking is done
- if an ``int``, this is the chunk size to be used.
- if ``"auto"`` for the ``ObjectiveFunction``, will use a heuristic for the minimum ``jac_chunk_size`` needed to fit the jacobian calculation on the available device memory, according to the formula: ``min_jac_chunk_size = (desc_config.get("avail_mem") / estimated_memory_usage - 0.22) / 0.85 * self.dim_x`` with ``estimated_memory_usage = 2.4e-7 * self.dim_f * self.dim_x + 1``
dpanici marked this conversation as resolved.
Show resolved Hide resolved
- the ``ObjectiveFunction`` ``jac_chunk_size`` is used if ``deriv_mode="batched"``, and the ``_Objective`` ``jac_chunk_size`` will be used if ``deriv_mode="blocked"``

Bug Fixes

- Fixes bugs that occur when saving asymmetric equilibria as wout files
- Fixes bug that occurs when using ``VMECIO.plot_vmec_comparison`` to compare to an asymmetric wout file

Deprecations

- ``deriv_mode="looped"`` in ``ObjectiveFunction`` is deprecated and will be removed in a future version in favored of ``deriv_mode="batched"`` with ``jac_chunk_size=1``,


v0.12.1
Expand Down
322 changes: 322 additions & 0 deletions desc/batching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
"""Utility functions for the ``batched_vectorize`` function."""

import functools
from typing import Callable, Optional

from desc.backend import jax, jnp

if jax.__version_info__ >= (0, 4, 16):
from jax.extend import linear_util as lu
else:
from jax import linear_util as lu

Check warning on line 11 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L11

Added line #L11 was not covered by tests

from jax._src.numpy.vectorize import (
_apply_excluded,
_check_output_dims,
_parse_gufunc_signature,
_parse_input_dimensions,
)

# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_chunk_utils.py
#
# The original copyright notice is as follows
# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");


def _treeify(f):
def _f(x, *args, **kwargs):
return jax.tree_util.tree_map(lambda y: f(y, *args, **kwargs), x)

return _f


@_treeify
def _unchunk(x):
return x.reshape((-1,) + x.shape[2:])


@_treeify
def _chunk(x, chunk_size=None):
# chunk_size=None -> add just a dummy chunk dimension,
# same as np.expand_dims(x, 0)
if x.ndim == 0:
raise ValueError("x cannot be chunked as it has 0 dimensions.")

Check warning on line 46 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L46

Added line #L46 was not covered by tests
n = x.shape[0]
if chunk_size is None:
chunk_size = n

Check warning on line 49 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L49

Added line #L49 was not covered by tests

n_chunks, residual = divmod(n, chunk_size)
if residual != 0:
raise ValueError(

Check warning on line 53 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L53

Added line #L53 was not covered by tests
"The first dimension of x must be divisible by chunk_size."
+ f"\n Got x.shape={x.shape} but chunk_size={chunk_size}."
)
return x.reshape((n_chunks, chunk_size) + x.shape[1:])


####

# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_scanmap.py


def scan_append(f, x):
"""Evaluate f element by element in x while appending the results.

Parameters
----------
f: a function that takes elements of the leading dimension of x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not consistent with our docstring format but no problem. Just pointing out.

x: a pytree where each leaf array has the same leading dimension

Returns
-------
a (pytree of) array(s) with leading dimension same as x,
containing the evaluation of f at each element in x
"""
carry_init = True

def f_(carry, x):
return False, f(x)

_, res_append = jax.lax.scan(f_, carry_init, x, unroll=1)
return res_append


# TODO in_axes a la vmap?
def _scanmap(fun, scan_fun, argnums=0):
"""A helper function to wrap f with a scan_fun."""

def f_(*args, **kwargs):
f = lu.wrap_init(fun, kwargs)
f_partial, dyn_args = jax.api_util.argnums_partial(
f, argnums, args, require_static_args_hashable=False
)
return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)

return f_


# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_vmap_chunked.py


def _eval_fun_in_chunks(vmapped_fun, chunk_size, argnums, *args, **kwargs):
n_elements = jax.tree_util.tree_leaves(args[argnums[0]])[0].shape[0]
n_chunks, n_rest = divmod(n_elements, chunk_size)

if n_chunks == 0 or chunk_size >= n_elements:
y = vmapped_fun(*args, **kwargs)
else:
# split inputs
def _get_chunks(x):
x_chunks = jax.tree_util.tree_map(
lambda x_: x_[: n_elements - n_rest, ...], x
)
x_chunks = _chunk(x_chunks, chunk_size)
return x_chunks

def _get_rest(x):
x_rest = jax.tree_util.tree_map(
lambda x_: x_[n_elements - n_rest :, ...], x
)
return x_rest

args_chunks = [
_get_chunks(a) if i in argnums else a for i, a in enumerate(args)
]
args_rest = [_get_rest(a) if i in argnums else a for i, a in enumerate(args)]

y_chunks = _unchunk(
_scanmap(vmapped_fun, scan_append, argnums)(*args_chunks, **kwargs)
)

if n_rest == 0:
y = y_chunks
else:
y_rest = vmapped_fun(*args_rest, **kwargs)
y = jax.tree_util.tree_map(
lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest
)
return y


def _chunk_vmapped_function(
vmapped_fun: Callable,
chunk_size: Optional[int],
argnums=0,
) -> Callable:
"""Takes a vmapped function and computes it in chunks."""
if chunk_size is None:
return vmapped_fun

Check warning on line 155 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L155

Added line #L155 was not covered by tests

if isinstance(argnums, int):
argnums = (argnums,)

Check warning on line 158 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L158

Added line #L158 was not covered by tests
return functools.partial(_eval_fun_in_chunks, vmapped_fun, chunk_size, argnums)


def _parse_in_axes(in_axes):
if isinstance(in_axes, int):
in_axes = (in_axes,)

Check warning on line 164 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L164

Added line #L164 was not covered by tests

if not set(in_axes).issubset((0, None)):
raise NotImplementedError("Only in_axes 0/None are currently supported")

Check warning on line 167 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L167

Added line #L167 was not covered by tests

argnums = tuple(
map(lambda ix: ix[0], filter(lambda ix: ix[1] is not None, enumerate(in_axes)))
)
return in_axes, argnums


def vmap_chunked(
f: Callable,
in_axes=0,
*,
chunk_size: Optional[int],
) -> Callable:
"""Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks.

Parameters
----------
f: The function to be vectorised.
in_axes: The axes that should be scanned along. Only supports `0` or `None`
chunk_size: The maximum size of the chunks to be used. If it is `None`,
chunking is disabled


Returns
-------
f: A vectorised and chunked function
"""
in_axes, argnums = _parse_in_axes(in_axes)
vmapped_fun = jax.vmap(f, in_axes=in_axes)
return _chunk_vmapped_function(vmapped_fun, chunk_size, argnums)


def batched_vectorize(pyfunc, *, excluded=frozenset(), signature=None, chunk_size=None):
"""Define a vectorized function with broadcasting and batching.

:func:`vectorize` is a convenience wrapper for defining vectorized
functions with broadcasting, in the style of NumPy's
`generalized universal functions
<https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html>`_.
It allows for defining functions that are automatically repeated across
any leading dimensions, without the implementation of the function needing to
be concerned about how to handle higher dimensional inputs.

:func:`jax.numpy.vectorize` has the same interface as
:class:`numpy.vectorize`, but it is syntactic sugar for an auto-batching
transformation (:func:`vmap`) rather than a Python loop. This should be
considerably more efficient, but the implementation must be written in terms
of functions that act on JAX arrays.

Parameters
----------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above docstrings are not too important but this one might be checked more, so maybe make it consistent with out doc format? Again, not too important, we can change it in a later PR.

pyfunc: function to vectorize.
excluded: optional set of integers representing positional arguments for
which the function will not be vectorized. These will be passed directly
to ``pyfunc`` unmodified.
signature: optional generalized universal function signature, e.g.,
``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If
provided, ``pyfunc`` will be called with (and expected to return) arrays
with shapes given by the size of corresponding core dimensions. By
default, pyfunc is assumed to take scalars arrays as input and output.
chunk_size: the size of the batches to pass to vmap. If None, defaults to
the largest possible chunk_size (like the default behavior of ``vectorize11)

Returns
-------
Batch-vectorized version of the given function.

"""
if any(not isinstance(exclude, (str, int)) for exclude in excluded):
raise TypeError(

Check warning on line 237 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L237

Added line #L237 was not covered by tests
"jax.numpy.vectorize can only exclude integer or string arguments, "
"but excluded={!r}".format(excluded)
)
if any(isinstance(e, int) and e < 0 for e in excluded):
raise ValueError(f"excluded={excluded!r} contains negative numbers")

Check warning on line 242 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L242

Added line #L242 was not covered by tests

@functools.wraps(pyfunc)
def wrapped(*args, **kwargs):
error_context = (
"on vectorized function with excluded={!r} and "
"signature={!r}".format(excluded, signature)
)
excluded_func, args, kwargs = _apply_excluded(pyfunc, excluded, args, kwargs)

if signature is not None:
input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
else:
input_core_dims = [()] * len(args)
output_core_dims = None

Check warning on line 256 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L255-L256

Added lines #L255 - L256 were not covered by tests

none_args = {i for i, arg in enumerate(args) if arg is None}
if any(none_args):
if any(input_core_dims[i] != () for i in none_args):
raise ValueError(

Check warning on line 261 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L260-L261

Added lines #L260 - L261 were not covered by tests
f"Cannot pass None at locations {none_args} with {signature=}"
)
excluded_func, args, _ = _apply_excluded(excluded_func, none_args, args, {})
input_core_dims = [

Check warning on line 265 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L264-L265

Added lines #L264 - L265 were not covered by tests
dim for i, dim in enumerate(input_core_dims) if i not in none_args
]

args = tuple(map(jnp.asarray, args))

broadcast_shape, dim_sizes = _parse_input_dimensions(
args, input_core_dims, error_context
)

checked_func = _check_output_dims(
excluded_func, dim_sizes, output_core_dims, error_context
)

# Rather than broadcasting all arguments to full broadcast shapes, prefer
# expanding dimensions using vmap. By pushing broadcasting
# into vmap, we can make use of more efficient batching rules for
# primitives where only some arguments are batched (e.g., for
# lax_linalg.triangular_solve), and avoid instantiating large broadcasted
# arrays.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commenting for reference


squeezed_args = []
rev_filled_shapes = []

for arg, core_dims in zip(args, input_core_dims):
noncore_shape = arg.shape[: arg.ndim - len(core_dims)]

pad_ndim = len(broadcast_shape) - len(noncore_shape)
filled_shape = pad_ndim * (1,) + noncore_shape
rev_filled_shapes.append(filled_shape[::-1])

squeeze_indices = tuple(
i for i, size in enumerate(noncore_shape) if size == 1
)
squeezed_arg = jnp.squeeze(arg, axis=squeeze_indices)
squeezed_args.append(squeezed_arg)

vectorized_func = checked_func
dims_to_expand = []
for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)):
in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)
if all(axis is None for axis in in_axes):
dims_to_expand.append(len(broadcast_shape) - 1 - negdim)
else:
# change the vmap here to chunked_vmap
vectorized_func = vmap_chunked(
vectorized_func, in_axes, chunk_size=chunk_size
)
result = vectorized_func(*squeezed_args)

if not dims_to_expand:
return result
elif isinstance(result, tuple):
return tuple(jnp.expand_dims(r, axis=dims_to_expand) for r in result)

Check warning on line 318 in desc/batching.py

View check run for this annotation

Codecov / codecov/patch

desc/batching.py#L318

Added line #L318 was not covered by tests
else:
return jnp.expand_dims(result, axis=dims_to_expand)

return wrapped
Loading