diff --git a/docs/felupe/constitution/autodiff.rst b/docs/felupe/constitution/autodiff.rst index 66b3d1a8..9e79c6b1 100644 --- a/docs/felupe/constitution/autodiff.rst +++ b/docs/felupe/constitution/autodiff.rst @@ -15,12 +15,6 @@ automatic differentiation. The default backend is based on :mod:`tensortrax` whi with FElupe. For more computationally expensive material formulations, :mod:`jax` may be the preferred option. -.. note:: - JAX uses single-precision (32bit) data types by default. This requires to relax the - tolerance of :func:`~felupe.newtonrhapson` to ``tol=1e-4``. If required, JAX may be - enforced to use double-precision at startup with - ``jax.config.update("jax_enable_x64", True)``. - It is straightforward to switch between these backends. .. tab:: tensortrax (default) diff --git a/docs/felupe/constitution/autodiff/jax.rst b/docs/felupe/constitution/autodiff/jax.rst index 3c14a8db..e57b0f48 100644 --- a/docs/felupe/constitution/autodiff/jax.rst +++ b/docs/felupe/constitution/autodiff/jax.rst @@ -3,7 +3,30 @@ Materials with Automatic Differentiation (JAX) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -This page contains material model formulations with automatic differentiation using :mod:`jax`. +This page contains material model formulations with automatic differentiation using +:mod:`jax`. + +.. note:: + + JAX uses single-precision (32bit) data types by default. This requires to relax the + tolerance of :func:`~felupe.newtonrhapson` to ``tol=1e-4``. If required, JAX may be + enforced to use double-precision at startup with + ``jax.config.update("jax_enable_x64", True)``. + +.. note:: + + The number of local XLA devices available must be greater or equal the number of the + parallel-mapped axis, i.e. the number of quadrature points per cell when used in + :class:`~felupe.constitution.jax.Material` and + :class:`~felupe.constitution.jax.Hyperelastic` along with ``parallel=True``. To use + the multiple cores of a CPU device as multiple local XLA devices, the XLA device + count must be defined at startup. + + .. code-block:: python + + import os + + os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4" **Frameworks** diff --git a/src/felupe/constitution/jax/_hyperelastic.py b/src/felupe/constitution/jax/_hyperelastic.py index 984f53f0..f18d83f9 100644 --- a/src/felupe/constitution/jax/_hyperelastic.py +++ b/src/felupe/constitution/jax/_hyperelastic.py @@ -44,8 +44,9 @@ class Hyperelastic(Material): jit : bool, optional A flag to invoke just-in-time compilation (default is True). parallel : bool, optional - A flag to invoke threaded strain energy density function evaluations (default - is False). Not implemented. + A flag to invoke parallel strain energy density function evaluations (default + is False). If True, the quadrature points are executed in parallel. The number + of devices must be greater or equal the number of quadrature points per cell. **kwargs : dict, optional Optional keyword-arguments for the strain energy density function. @@ -170,7 +171,8 @@ def __init__(self, fun, nstatevars=0, jit=True, parallel=False, **kwargs): methods = [jax.vmap, jax.vmap] if parallel: - methods[0] = jax.pmap + methods[0] = jax.pmap # apply on quadrature-points + jit = False # pmap uses jit self._grad = vmap2( jax.grad(self.fun, has_aux=has_aux), diff --git a/src/felupe/constitution/jax/_material.py b/src/felupe/constitution/jax/_material.py index 8b6ba755..44317b3e 100644 --- a/src/felupe/constitution/jax/_material.py +++ b/src/felupe/constitution/jax/_material.py @@ -44,8 +44,9 @@ class Material(MaterialDefault): jit : bool, optional A flag to invoke just-in-time compilation (default is True). parallel : bool, optional - A flag to invoke threaded function evaluations (defaultnis False). Not - implemented. + A flag to invoke parallel function evaluations (default is False). If True, the + quadrature points are executed in parallel. The number of devices must be + greater or equal the number of quadrature points per cell. jacobian : callable or None, optional A callable for the Jacobian. Default is None, where :func:`jax.jacobian` is used. This may be used to switch to forward-mode differentian @@ -107,7 +108,7 @@ def viscoelastic(F, Cin, mu, eta, dtime): S = mu * dev(Cu @ jnp.linalg.inv(Ci)) @ jnp.linalg.inv(C) # first Piola-Kirchhoff stress tensor and state variable - i, j = triu_indices(3) + i, j = jnp.triu_indices(3) to_triu = lambda C: C[i, j] return F @ S, to_triu(Ci) @@ -187,7 +188,8 @@ def __init__( methods = [jax.vmap, jax.vmap] if parallel: - methods[0] = jax.pmap + methods[0] = jax.pmap # apply on quadrature-points + jit = False # pmap uses jit self._grad = vmap2( self.fun, in_axes=in_axes, out_axes=out_axes_grad, methods=methods