Skip to content

Commit

Permalink
Enhance the docs on parallel=True with JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
adtzlr committed Nov 10, 2024
1 parent 0d5bdfb commit 9e5fdc9
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 14 deletions.
6 changes: 0 additions & 6 deletions docs/felupe/constitution/autodiff.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ automatic differentiation. The default backend is based on :mod:`tensortrax` whi
with FElupe. For more computationally expensive material formulations, :mod:`jax` may
be the preferred option.

.. note::
JAX uses single-precision (32bit) data types by default. This requires to relax the
tolerance of :func:`~felupe.newtonrhapson` to ``tol=1e-4``. If required, JAX may be
enforced to use double-precision at startup with
``jax.config.update("jax_enable_x64", True)``.

It is straightforward to switch between these backends.

.. tab:: tensortrax (default)
Expand Down
25 changes: 24 additions & 1 deletion docs/felupe/constitution/autodiff/jax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,30 @@
Materials with Automatic Differentiation (JAX)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This page contains material model formulations with automatic differentiation using :mod:`jax`.
This page contains material model formulations with automatic differentiation using
:mod:`jax`.

.. note::

JAX uses single-precision (32bit) data types by default. This requires to relax the
tolerance of :func:`~felupe.newtonrhapson` to ``tol=1e-4``. If required, JAX may be
enforced to use double-precision at startup with
``jax.config.update("jax_enable_x64", True)``.

.. note::

The number of local XLA devices available must be greater or equal the number of the
parallel-mapped axis, i.e. the number of quadrature points per cell when used in
:class:`~felupe.constitution.jax.Material` and
:class:`~felupe.constitution.jax.Hyperelastic` along with ``parallel=True``. To use
the multiple cores of a CPU device as multiple local XLA devices, the XLA device
count must be defined at startup.

.. code-block:: python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
**Frameworks**

Expand Down
8 changes: 5 additions & 3 deletions src/felupe/constitution/jax/_hyperelastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class Hyperelastic(Material):
jit : bool, optional
A flag to invoke just-in-time compilation (default is True).
parallel : bool, optional
A flag to invoke threaded strain energy density function evaluations (default
is False). Not implemented.
A flag to invoke parallel strain energy density function evaluations (default
is False). If True, the quadrature points are executed in parallel. The number
of devices must be greater or equal the number of quadrature points per cell.
**kwargs : dict, optional
Optional keyword-arguments for the strain energy density function.
Expand Down Expand Up @@ -170,7 +171,8 @@ def __init__(self, fun, nstatevars=0, jit=True, parallel=False, **kwargs):

methods = [jax.vmap, jax.vmap]
if parallel:
methods[0] = jax.pmap
methods[0] = jax.pmap # apply on quadrature-points
jit = False # pmap uses jit

self._grad = vmap2(
jax.grad(self.fun, has_aux=has_aux),
Expand Down
10 changes: 6 additions & 4 deletions src/felupe/constitution/jax/_material.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class Material(MaterialDefault):
jit : bool, optional
A flag to invoke just-in-time compilation (default is True).
parallel : bool, optional
A flag to invoke threaded function evaluations (defaultnis False). Not
implemented.
A flag to invoke parallel function evaluations (default is False). If True, the
quadrature points are executed in parallel. The number of devices must be
greater or equal the number of quadrature points per cell.
jacobian : callable or None, optional
A callable for the Jacobian. Default is None, where :func:`jax.jacobian` is
used. This may be used to switch to forward-mode differentian
Expand Down Expand Up @@ -107,7 +108,7 @@ def viscoelastic(F, Cin, mu, eta, dtime):
S = mu * dev(Cu @ jnp.linalg.inv(Ci)) @ jnp.linalg.inv(C)
# first Piola-Kirchhoff stress tensor and state variable
i, j = triu_indices(3)
i, j = jnp.triu_indices(3)
to_triu = lambda C: C[i, j]
return F @ S, to_triu(Ci)
Expand Down Expand Up @@ -187,7 +188,8 @@ def __init__(

methods = [jax.vmap, jax.vmap]
if parallel:
methods[0] = jax.pmap
methods[0] = jax.pmap # apply on quadrature-points
jit = False # pmap uses jit

self._grad = vmap2(
self.fun, in_axes=in_axes, out_axes=out_axes_grad, methods=methods
Expand Down

0 comments on commit 9e5fdc9

Please sign in to comment.