-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the Storåkers hyperelastic foam model (#900)
* Add the hyperelastic foam model (tensortrax) * Update _foam.py * Update _foam.py * Check success of transverse-constraint in `ViewMaterial.plot()` * rename the Foam material to Storakers * Update _storakers.py * Update __init__.py * Update __init__.py * Update _storakers.py * Add the hyperelastic foam model also for JAX * add tests for hyperelastic foam model * Update _storakers.py * Update jax.rst * Update tensortrax.rst * Test the foam model with other parameters which will fail with incompressibility as initial guess in plotting * Update ex08_shear.py * Test value error in `umat.plot()` * Update test_constitution.py * Update _storakers.py
- Loading branch information
Showing
12 changed files
with
245 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,20 @@ | ||
from ._miehe_goektepe_lulei import miehe_goektepe_lulei | ||
from ._mooney_rivlin import mooney_rivlin | ||
from ._storakers import storakers | ||
from ._third_order_deformation import third_order_deformation | ||
from ._yeoh import yeoh | ||
|
||
__all__ = [ | ||
"miehe_goektepe_lulei", | ||
"mooney_rivlin", | ||
"storakers", | ||
"third_order_deformation", | ||
"yeoh", | ||
] | ||
|
||
# default (stable) material parameters | ||
miehe_goektepe_lulei.kwargs = dict(mu=0, N=100, U=0, p=2, q=2) | ||
mooney_rivlin.kwargs = dict(C10=0, C01=0) | ||
storakers.kwargs = dict(mu=[0], alpha=[2], beta=[1]) | ||
third_order_deformation.kwargs = dict(C10=0, C01=0, C11=0, C20=0, C30=0) | ||
yeoh.kwargs = dict(C10=0, C20=0, C30=0) |
36 changes: 36 additions & 0 deletions
36
src/felupe/constitution/jax/models/hyperelastic/_storakers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
This file is part of FElupe. | ||
FElupe is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
FElupe is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with FElupe. If not, see <http://www.gnu.org/licenses/>. | ||
""" | ||
from functools import wraps | ||
|
||
from jax.numpy import array, sqrt | ||
from jax.numpy import sum as asum | ||
from jax.numpy.linalg import eigvalsh | ||
|
||
from ....tensortrax.models.hyperelastic import storakers as storakers_docstring | ||
|
||
|
||
@wraps(storakers_docstring) | ||
def storakers(C, mu, alpha, beta): | ||
λ1, λ2, λ3 = sqrt(eigvalsh(C)) | ||
J = λ1 * λ2 * λ3 | ||
|
||
μ = array(mu) | ||
α = array(alpha) | ||
β = array(beta) | ||
|
||
return asum(2 * μ / α**2 * (λ1**α + λ2**α + λ3**α - 3 + (J ** (-α * β) - 1) / β)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
122 changes: 122 additions & 0 deletions
122
src/felupe/constitution/tensortrax/models/hyperelastic/_storakers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
This file is part of FElupe. | ||
FElupe is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
FElupe is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with FElupe. If not, see <http://www.gnu.org/licenses/>. | ||
""" | ||
|
||
from tensortrax.math import sum as tsum | ||
from tensortrax.math.linalg import det, eigvalsh | ||
|
||
|
||
def storakers(C, mu, alpha, beta): | ||
r"""Strain energy function of the Storåkers isotropic hyperelastic | ||
`foam <https://doi.org/10.1016/0022-5096(86)90033-5>`_ material formulation [1]_. | ||
Parameters | ||
---------- | ||
C : tensortrax.Tensor or jax.Array | ||
Right Cauchy-Green deformation tensor. | ||
mu : list of float | ||
List of moduli. | ||
alpha : list of float | ||
List of stretch exponents. | ||
beta : list of float | ||
List of coefficients for the degree of compressibility. | ||
Notes | ||
----- | ||
The strain energy function is given in Eq. :eq:`psi-foam` | ||
.. math:: | ||
:label: psi-ogden | ||
\psi = \sum_i \frac{2 \mu_i}{\alpha^2_i} \left[ | ||
\hat{\lambda}_1^{\alpha_i} + | ||
\hat{\lambda}_2^{\alpha_i} + | ||
\hat{\lambda}_3^{\alpha_i} - 3 | ||
+ \frac{1}{\beta_i} \left( J^{-\alpha \beta} - 1 \right) | ||
\right] | ||
The sum of the moduli :math:`\mu_i` is equal to the initial shear modulus | ||
:math:`\mu`, see Eq. :eq:`shear-modulus-foam`, | ||
.. math:: | ||
:label: shear-modulus-ogden | ||
\mu = \sum_i \mu_i | ||
and the initial bulk modulus is given in Eq. :eq:`bulk-modulus-foam`. | ||
.. math:: | ||
:label: bulk-modulus-ogden | ||
K = \sum_i 2 \mu_i \left( \frac{1}{3} + \beta_i \right) | ||
Examples | ||
-------- | ||
First, choose the desired automatic differentiation backend | ||
.. pyvista-plot:: | ||
:context: | ||
>>> # import felupe.constitution.jax as mat | ||
>>> import felupe.constitution.tensortrax as mat | ||
and create the hyperelastic material. | ||
.. pyvista-plot:: | ||
:context: | ||
>>> import felupe as fem | ||
>>> | ||
>>> umat = mat.Hyperelastic( | ||
... mat.models.hyperelastic.storakers, | ||
... mu=[4.5 * (1.85 / 2), -4.5 * (-9.2 / 2)], | ||
... alpha=[1.85, -9.2], | ||
... beta=[0.92, 0.92], | ||
... ) | ||
>>> ax = umat.plot( | ||
... ux=fem.math.linsteps([1, 2], 15), | ||
... ps=fem.math.linsteps([1, 1], 15), | ||
... bx=fem.math.linsteps([1, 1], 9), | ||
... ) | ||
.. pyvista-plot:: | ||
:include-source: False | ||
:context: | ||
:force_static: | ||
>>> import pyvista as pv | ||
>>> | ||
>>> fig = ax.get_figure() | ||
>>> chart = pv.ChartMPL(fig) | ||
>>> chart.show() | ||
References | ||
---------- | ||
.. [1] B. Storåkers, "On material representation and constitutive branching in | ||
finite compressible elasticity", Journal of the Mechanics and Physics of Solids, | ||
vol. 34, no. 2. Elsevier BV, pp. 125–145, Jan. 1986. doi: | ||
10.1016/0022-5096(86)90033-5. | ||
""" | ||
|
||
λ2 = eigvalsh(C) | ||
|
||
return tsum( | ||
[ | ||
2 * μ / α**2 * (tsum(λ2 ** (α / 2)) - 3 + (det(C) ** (-α * β / 2) - 1) / β) | ||
for μ, α, β in zip(mu, alpha, beta) | ||
] | ||
) |
Oops, something went wrong.