forked from PredictiveIntelligenceLab/jaxpi
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
jaxpi example unsteady ns over cylinder
- Loading branch information
sifanw
committed
Aug 6, 2023
1 parent
bb42483
commit 2eac62d
Showing
15 changed files
with
1,271 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
[# Navier–Stokes flow around a cylinder | ||
|
||
## Problem Set-up | ||
|
||
We consider a fluid with a density of $\rho=1.0$ and describe its behavior using the time-dependent incompressible Navier-Stokes equations | ||
$$\begin{aligned} | ||
\mathbf{u}_t + \mathbf{u} \nabla \mathbf{u} + \nabla p - \nu \mathbf{u} = 0, \\ | ||
\nabla \cdot \mathbf{u} = 0, | ||
\end{aligned}$$ | ||
|
||
with $\mathbf{u}=(u, v)$ defining the velocity field and $p$ the pressure. The kinematic viscosity is taken as $\nu =0.001$. | ||
|
||
## Results | ||
|
||
![ns_cylinder](/figures/ns_cylinder_u.gif) | ||
|
||
![ns_cylinder](/figures/ns_cylinder_v.gif) | ||
|
||
![ns_cylinder](/figures/ns_cylinder_w.gif)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import ml_collections | ||
|
||
import jax.numpy as jnp | ||
|
||
|
||
def get_config(): | ||
"""Get the default hyperparameter configuration.""" | ||
config = ml_collections.ConfigDict() | ||
|
||
config.mode = "train" | ||
|
||
# Weights & Biases | ||
config.wandb = wandb = ml_collections.ConfigDict() | ||
wandb.project = "PINN-NS_unsteady_cylinder" | ||
wandb.name = "default" | ||
wandb.tag = None | ||
|
||
# Nondimensionalization | ||
config.nondim = True | ||
|
||
# Arch | ||
config.arch = arch = ml_collections.ConfigDict() | ||
arch.arch_name = "ModifiedMlp" | ||
arch.num_layers = 4 | ||
arch.layer_size = 256 | ||
arch.out_dim = 3 | ||
arch.activation = "gelu" # gelu works better than tanh for this problem | ||
arch.periodicity = False | ||
arch.fourier_emb = ml_collections.ConfigDict({"embed_scale": 1.0, "embed_dim": 256}) | ||
arch.reparam = ml_collections.ConfigDict( | ||
{"type": "weight_fact", "mean": 1.0, "stddev": 0.1} | ||
) | ||
|
||
# Optim | ||
config.optim = optim = ml_collections.ConfigDict() | ||
optim.optimizer = "Adam" | ||
optim.beta1 = 0.9 | ||
optim.beta2 = 0.999 | ||
optim.eps = 1e-8 | ||
optim.learning_rate = 1e-3 | ||
optim.decay_rate = 0.9 | ||
optim.decay_steps = 2000 | ||
optim.grad_accum_steps = 0 | ||
|
||
# Training | ||
config.training = training = ml_collections.ConfigDict() | ||
training.max_steps = 200000 | ||
training.num_time_windows = 10 | ||
|
||
training.inflow_batch_size = 2048 | ||
training.outflow_batch_size = 2048 | ||
training.noslip_batch_size = 2048 | ||
training.ic_batch_size = 2048 | ||
training.res_batch_size = 4096 | ||
|
||
# Weighting | ||
config.weighting = weighting = ml_collections.ConfigDict() | ||
weighting.scheme = "grad_norm" | ||
weighting.init_weights = { | ||
"u_ic": 1.0, | ||
"v_ic": 1.0, | ||
"p_ic": 1.0, | ||
"u_in": 1.0, | ||
"v_in": 1.0, | ||
"u_out": 1.0, | ||
"v_out": 1.0, | ||
"u_noslip": 1.0, | ||
"v_noslip": 1.0, | ||
"ru": 1.0, | ||
"rv": 1.0, | ||
"rc": 1.0, | ||
} | ||
|
||
weighting.momentum = 0.9 | ||
weighting.update_every_steps = 1000 # 100 for grad norm and 1000 for ntk | ||
|
||
weighting.use_causal = True | ||
weighting.causal_tol = 1.0 | ||
weighting.num_chunks = 16 | ||
|
||
# Logging | ||
config.logging = logging = ml_collections.ConfigDict() | ||
logging.log_every_steps = 100 | ||
logging.log_errors = True | ||
logging.log_losses = True | ||
logging.log_weights = True | ||
logging.log_grads = False | ||
logging.log_ntk = False | ||
logging.log_preds = False | ||
|
||
# Saving | ||
config.saving = saving = ml_collections.ConfigDict() | ||
saving.save_every_steps = 10000 | ||
saving.num_keep_ckpts = 10 | ||
|
||
# Input shape for initializing Flax models | ||
config.input_dim = 3 | ||
|
||
# Integer for PRNG random seed. | ||
config.seed = 42 | ||
|
||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import ml_collections | ||
|
||
import jax.numpy as jnp | ||
|
||
|
||
def get_config(): | ||
"""Get the default hyperparameter configuration.""" | ||
config = ml_collections.ConfigDict() | ||
|
||
config.mode = "train" | ||
|
||
# Weights & Biases | ||
config.wandb = wandb = ml_collections.ConfigDict() | ||
wandb.project = "PINN-NS_unsteady_cylinder" | ||
wandb.name = "sota" | ||
wandb.tag = None | ||
|
||
# Nondimensionalization | ||
config.nondim = True | ||
|
||
# Arch | ||
config.arch = arch = ml_collections.ConfigDict() | ||
arch.arch_name = "ModifiedMlp" | ||
arch.num_layers = 5 | ||
arch.layer_size = 256 | ||
arch.out_dim = 3 | ||
arch.activation = "gelu" # gelu works better than tanh for this problem | ||
arch.periodicity = False | ||
arch.fourier_emb = ml_collections.ConfigDict({"embed_scale": 1.0, "embed_dim": 256}) | ||
arch.reparam = ml_collections.ConfigDict( | ||
{"type": "weight_fact", "mean": 1.0, "stddev": 0.1} | ||
) | ||
|
||
# Optim | ||
config.optim = optim = ml_collections.ConfigDict() | ||
optim.optimizer = "Adam" | ||
optim.beta1 = 0.9 | ||
optim.beta2 = 0.999 | ||
optim.eps = 1e-8 | ||
optim.learning_rate = 1e-3 | ||
optim.decay_rate = 0.9 | ||
optim.decay_steps = 2000 | ||
optim.grad_accum_steps = 0 | ||
|
||
# Training | ||
config.training = training = ml_collections.ConfigDict() | ||
training.max_steps = 200000 | ||
training.num_time_windows = 10 | ||
|
||
training.inflow_batch_size = 2048 | ||
training.outflow_batch_size = 2048 | ||
training.noslip_batch_size = 2048 | ||
training.ic_batch_size = 2048 | ||
training.res_batch_size = 4096 | ||
|
||
# Weighting | ||
config.weighting = weighting = ml_collections.ConfigDict() | ||
weighting.scheme = "grad_norm" | ||
weighting.init_weights = { | ||
"u_ic": 100.0, | ||
"v_ic": 100.0, | ||
"p_ic": 100.0, | ||
"u_in": 100.0, | ||
"v_in": 100.0, | ||
"u_out": 1.0, | ||
"v_out": 1.0, | ||
"u_noslip": 10.0, | ||
"v_noslip": 10.0, | ||
"ru": 1.0, | ||
"rv": 1.0, | ||
"rc": 1.0, | ||
} | ||
|
||
weighting.momentum = 0.9 | ||
weighting.update_every_steps = 1000 # 100 for grad norm and 1000 for ntk | ||
|
||
weighting.use_causal = True | ||
weighting.causal_tol = 1.0 | ||
weighting.num_chunks = 16 | ||
|
||
# Logging | ||
config.logging = logging = ml_collections.ConfigDict() | ||
logging.log_every_steps = 100 | ||
logging.log_errors = True | ||
logging.log_losses = True | ||
logging.log_weights = True | ||
logging.log_grads = False | ||
logging.log_ntk = False | ||
logging.log_preds = False | ||
|
||
# Saving | ||
config.saving = saving = ml_collections.ConfigDict() | ||
saving.save_every_steps = 10000 | ||
saving.num_keep_ckpts = 10 | ||
|
||
# Input shape for initializing Flax models | ||
config.input_dim = 3 | ||
|
||
# Integer for PRNG random seed. | ||
config.seed = 42 | ||
|
||
return config |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Oops, something went wrong.