Add `jac_chunk_size` keyword argument to `ObjectiveFunction` to reduce memory usage of forward mode Jacobian calculation (#1052)
@@ -0,0 +1,322 @@
"""Utility functions for the ``batched_vectorize`` function."""

import functools
from typing import Callable, Optional

from desc.backend import jax, jnp

if jax.__version_info__ >= (0, 4, 16):
    from jax.extend import linear_util as lu
else:
    from jax import linear_util as lu

from jax._src.numpy.vectorize import (
    _apply_excluded,
    _check_output_dims,
    _parse_gufunc_signature,
    _parse_input_dimensions,
)

# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_chunk_utils.py
#
# The original copyright notice is as follows
# Copyright 2021 The NetKet Authors - All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");


def _treeify(f):
    def _f(x, *args, **kwargs):
        return jax.tree_util.tree_map(lambda y: f(y, *args, **kwargs), x)

    return _f


@_treeify
def _unchunk(x):
    return x.reshape((-1,) + x.shape[2:])


@_treeify
def _chunk(x, chunk_size=None):
    # chunk_size=None -> add just a dummy chunk dimension,
    # same as np.expand_dims(x, 0)
    if x.ndim == 0:
        raise ValueError("x cannot be chunked as it has 0 dimensions.")
    n = x.shape[0]
    if chunk_size is None:
        chunk_size = n

    n_chunks, residual = divmod(n, chunk_size)
    if residual != 0:
        raise ValueError(
            "The first dimension of x must be divisible by chunk_size."
            + f"\n Got x.shape={x.shape} but chunk_size={chunk_size}."
        )
    return x.reshape((n_chunks, chunk_size) + x.shape[1:])
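Taken together, these two helpers split and restore the leading axis on every leaf of a pytree: `_chunk` reshapes the leading dimension into `(n_chunks, chunk_size)` and `_unchunk` flattens it back. A minimal sketch, assuming the helpers above are in scope (shapes here are illustrative only):

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(6, 2)  # leading dimension of size 6
xc = _chunk(x, chunk_size=3)        # shape (2, 3, 2): two chunks of three rows
assert xc.shape == (2, 3, 2)
assert (_unchunk(xc) == x).all()    # _unchunk inverts _chunk
```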
####

# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_scanmap.py


def scan_append(f, x):
    """Evaluate f element by element in x while appending the results.

    Parameters
    ----------
    f: a function that takes elements of the leading dimension of x
    x: a pytree where each leaf array has the same leading dimension

    Returns
    -------
    a (pytree of) array(s) with leading dimension same as x,
    containing the evaluation of f at each element in x
    """
    carry_init = True

    def f_(carry, x):
        return False, f(x)

    _, res_append = jax.lax.scan(f_, carry_init, x, unroll=1)
    return res_append
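Because `lax.scan` evaluates one element at a time, only a single element's intermediates are resident at once, at the cost of sequential execution. A small illustrative sketch:

```python
import jax.numpy as jnp

xs = jnp.arange(6.0).reshape(3, 2)       # three elements of shape (2,)
ys = scan_append(lambda x: x.sum(), xs)  # f applied to each leading element
assert (ys == xs.sum(axis=1)).all()      # same values a vmap would produce
```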
# TODO in_axes a la vmap?
def _scanmap(fun, scan_fun, argnums=0):
    """A helper function to wrap f with a scan_fun."""

    def f_(*args, **kwargs):
        f = lu.wrap_init(fun, kwargs)
        f_partial, dyn_args = jax.api_util.argnums_partial(
            f, argnums, args, require_static_args_hashable=False
        )
        return scan_fun(lambda x: f_partial.call_wrapped(*x), dyn_args)

    return f_
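In other words, `_scanmap` closes over the arguments not listed in `argnums` and scans `fun` over the leading (chunk) axis of the selected ones. An illustrative sketch on a pre-chunked input, assuming `_chunk` and `scan_append` from above are in scope:

```python
import jax.numpy as jnp

xc = _chunk(jnp.arange(8.0), chunk_size=4)        # shape (2, 4): two chunks
f = _scanmap(jnp.sum, scan_append, argnums=(0,))  # sum each chunk sequentially
assert (f(xc) == jnp.array([6.0, 22.0])).all()
```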
# The following section of this code is derived from the NetKet project
# https://github.com/netket/netket/blob/9881c9fb217a2ac4dc9274a054bf6e6a2993c519/
# netket/jax/_vmap_chunked.py


def _eval_fun_in_chunks(vmapped_fun, chunk_size, argnums, *args, **kwargs):
    n_elements = jax.tree_util.tree_leaves(args[argnums[0]])[0].shape[0]
    n_chunks, n_rest = divmod(n_elements, chunk_size)

    if n_chunks == 0 or chunk_size >= n_elements:
        y = vmapped_fun(*args, **kwargs)
    else:
        # split inputs
        def _get_chunks(x):
            x_chunks = jax.tree_util.tree_map(
                lambda x_: x_[: n_elements - n_rest, ...], x
            )
            x_chunks = _chunk(x_chunks, chunk_size)
            return x_chunks

        def _get_rest(x):
            x_rest = jax.tree_util.tree_map(
                lambda x_: x_[n_elements - n_rest :, ...], x
            )
            return x_rest

        args_chunks = [
            _get_chunks(a) if i in argnums else a for i, a in enumerate(args)
        ]
        args_rest = [_get_rest(a) if i in argnums else a for i, a in enumerate(args)]

        y_chunks = _unchunk(
            _scanmap(vmapped_fun, scan_append, argnums)(*args_chunks, **kwargs)
        )

        if n_rest == 0:
            y = y_chunks
        else:
            y_rest = vmapped_fun(*args_rest, **kwargs)
            y = jax.tree_util.tree_map(
                lambda y1, y2: jnp.concatenate((y1, y2)), y_chunks, y_rest
            )
    return y
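The leading dimension is split into `n_chunks` full chunks that go through the sequential scan path, plus a remainder of `n_rest` elements handled by one extra vmapped call and concatenated on. A quick illustrative check (values chosen only for the sketch):

```python
import jax
import jax.numpy as jnp

vf = jax.vmap(lambda x: 2.0 * x)
y = _eval_fun_in_chunks(vf, 4, (0,), jnp.arange(10.0))  # 2 chunks of 4, rest of 2
assert (y == 2.0 * jnp.arange(10.0)).all()
```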
def _chunk_vmapped_function(
    vmapped_fun: Callable,
    chunk_size: Optional[int],
    argnums=0,
) -> Callable:
    """Takes a vmapped function and computes it in chunks."""
    if chunk_size is None:
        return vmapped_fun

    if isinstance(argnums, int):
        argnums = (argnums,)
    return functools.partial(_eval_fun_in_chunks, vmapped_fun, chunk_size, argnums)
def _parse_in_axes(in_axes):
    if isinstance(in_axes, int):
        in_axes = (in_axes,)

    if not set(in_axes).issubset((0, None)):
        raise NotImplementedError("Only in_axes 0/None are currently supported")

    argnums = tuple(
        map(lambda ix: ix[0], filter(lambda ix: ix[1] is not None, enumerate(in_axes)))
    )
    return in_axes, argnums
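So `_parse_in_axes` simply converts a vmap-style `in_axes` spec into the positions of the arguments that carry a batch axis, e.g.:

```python
assert _parse_in_axes((0, None, 0)) == ((0, None, 0), (0, 2))
```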
def vmap_chunked(
    f: Callable,
    in_axes=0,
    *,
    chunk_size: Optional[int],
) -> Callable:
    """Behaves like jax.vmap but uses scan to chunk the computations in smaller chunks.

    Parameters
    ----------
    f: The function to be vectorised.
    in_axes: The axes that should be scanned along. Only supports `0` or `None`.
    chunk_size: The maximum size of the chunks to be used. If it is `None`,
        chunking is disabled.

    Returns
    -------
    f: A vectorised and chunked function
    """
    in_axes, argnums = _parse_in_axes(in_axes)
    vmapped_fun = jax.vmap(f, in_axes=in_axes)
    return _chunk_vmapped_function(vmapped_fun, chunk_size, argnums)
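A usage sketch (the function and sizes are illustrative only): the result matches `jax.vmap`, but peak memory scales with `chunk_size` rather than with the full batch, since only one chunk is vmapped at a time.

```python
import jax
import jax.numpy as jnp

def poly(x):
    return jnp.sum(x**2 - 3.0 * x)

xs = jnp.ones((1000, 50))
f_chunked = vmap_chunked(poly, in_axes=0, chunk_size=100)
ys = f_chunked(xs)                           # 10 chunks of 100 rows each
assert ys.shape == (1000,)
assert jnp.allclose(ys, jax.vmap(poly)(xs))  # same values as plain vmap
```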
def batched_vectorize(pyfunc, *, excluded=frozenset(), signature=None, chunk_size=None):
    """Define a vectorized function with broadcasting and batching.

    :func:`vectorize` is a convenience wrapper for defining vectorized
    functions with broadcasting, in the style of NumPy's
    `generalized universal functions
    <https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html>`_.
    It allows for defining functions that are automatically repeated across
    any leading dimensions, without the implementation of the function needing to
    be concerned about how to handle higher dimensional inputs.

    :func:`jax.numpy.vectorize` has the same interface as
    :class:`numpy.vectorize`, but it is syntactic sugar for an auto-batching
    transformation (:func:`vmap`) rather than a Python loop. This should be
    considerably more efficient, but the implementation must be written in terms
    of functions that act on JAX arrays.

    Parameters
    ----------
    pyfunc: callable, function to vectorize.
    excluded: optional set of integers representing positional arguments for
        which the function will not be vectorized. These will be passed directly
        to ``pyfunc`` unmodified.
    signature: optional generalized universal function signature, e.g.,
        ``(m,n),(n)->(m)`` for vectorized matrix-vector multiplication. If
        provided, ``pyfunc`` will be called with (and expected to return) arrays
        with shapes given by the size of corresponding core dimensions. By
        default, pyfunc is assumed to take scalar arrays as input and output.
    chunk_size: the size of the batches to pass to vmap. If None, defaults to
        the largest possible chunk_size (like the default behavior of
        ``vectorize``).

    Returns
    -------
    Batch-vectorized version of the given function.

    """

[Review comment on the docstring above: "The above docstrings are not too important but this one might be checked more, so maybe make it consistent with our doc format? Again, not too important, we can change it in a later PR."]

    if any(not isinstance(exclude, (str, int)) for exclude in excluded):
        raise TypeError(
            "jax.numpy.vectorize can only exclude integer or string arguments, "
            "but excluded={!r}".format(excluded)
        )
    if any(isinstance(e, int) and e < 0 for e in excluded):
        raise ValueError(f"excluded={excluded!r} contains negative numbers")
    @functools.wraps(pyfunc)
    def wrapped(*args, **kwargs):
        error_context = (
            "on vectorized function with excluded={!r} and "
            "signature={!r}".format(excluded, signature)
        )
        excluded_func, args, kwargs = _apply_excluded(pyfunc, excluded, args, kwargs)

        if signature is not None:
            input_core_dims, output_core_dims = _parse_gufunc_signature(signature)
        else:
            input_core_dims = [()] * len(args)
            output_core_dims = None

        none_args = {i for i, arg in enumerate(args) if arg is None}
        if none_args:  # plain truthiness check; any(none_args) would miss {0}
            if any(input_core_dims[i] != () for i in none_args):
                raise ValueError(
                    f"Cannot pass None at locations {none_args} with {signature=}"
                )
            excluded_func, args, _ = _apply_excluded(excluded_func, none_args, args, {})
            input_core_dims = [
                dim for i, dim in enumerate(input_core_dims) if i not in none_args
            ]

        args = tuple(map(jnp.asarray, args))

        broadcast_shape, dim_sizes = _parse_input_dimensions(
            args, input_core_dims, error_context
        )

        checked_func = _check_output_dims(
            excluded_func, dim_sizes, output_core_dims, error_context
        )

        # Rather than broadcasting all arguments to full broadcast shapes, prefer
        # expanding dimensions using vmap. By pushing broadcasting
        # into vmap, we can make use of more efficient batching rules for
        # primitives where only some arguments are batched (e.g., for
        # lax_linalg.triangular_solve), and avoid instantiating large broadcasted
        # arrays.

[Review comment on the broadcasting note above: "commenting for reference"]
        squeezed_args = []
        rev_filled_shapes = []

        for arg, core_dims in zip(args, input_core_dims):
            noncore_shape = arg.shape[: arg.ndim - len(core_dims)]

            pad_ndim = len(broadcast_shape) - len(noncore_shape)
            filled_shape = pad_ndim * (1,) + noncore_shape
            rev_filled_shapes.append(filled_shape[::-1])

            squeeze_indices = tuple(
                i for i, size in enumerate(noncore_shape) if size == 1
            )
            squeezed_arg = jnp.squeeze(arg, axis=squeeze_indices)
            squeezed_args.append(squeezed_arg)

        vectorized_func = checked_func
        dims_to_expand = []
        for negdim, axis_sizes in enumerate(zip(*rev_filled_shapes)):
            in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)
            if all(axis is None for axis in in_axes):
                dims_to_expand.append(len(broadcast_shape) - 1 - negdim)
            else:
                # change the vmap here to chunked_vmap
                vectorized_func = vmap_chunked(
                    vectorized_func, in_axes, chunk_size=chunk_size
                )
        result = vectorized_func(*squeezed_args)

        if not dims_to_expand:
            return result
        elif isinstance(result, tuple):
            return tuple(jnp.expand_dims(r, axis=dims_to_expand) for r in result)
        else:
            return jnp.expand_dims(result, axis=dims_to_expand)

    return wrapped
[Review comment at the end of the file: "This is not consistent with our docstring format but no problem. Just pointing out."]
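Finally, a hypothetical sketch (not taken from the PR; the function, names, and sizes are illustrative) of why this matters for the new `jac_chunk_size` option: a forward-mode Jacobian can be built by pushing identity basis vectors through `jax.jvp`, and `batched_vectorize` with a finite `chunk_size` evaluates those JVPs a chunk at a time instead of all at once, bounding the peak memory of the Jacobian calculation.

```python
import jax
import jax.numpy as jnp

def fun(x):                     # toy R^n -> R^n map standing in for an objective
    return jnp.sin(x) * jnp.sum(x)

x0 = jnp.linspace(0.0, 1.0, 500)

def jvp_col(v):
    # one JVP against a basis vector gives one column of the Jacobian
    return jax.jvp(fun, (x0,), (v,))[1]

# 500 JVPs evaluated 50 at a time instead of all at once
jac_cols = batched_vectorize(jvp_col, signature="(n)->(n)", chunk_size=50)(
    jnp.eye(x0.size)
)
jacobian = jac_cols.T           # rows of jac_cols are Jacobian columns
assert jacobian.shape == (500, 500)
```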