diff --git a/examples/ns_unsteady_cylinder/README.md b/examples/ns_unsteady_cylinder/README.md new file mode 100644 index 00000000..4311fdbf --- /dev/null +++ b/examples/ns_unsteady_cylinder/README.md @@ -0,0 +1,19 @@ +[# Navier–Stokes flow around a cylinder + +## Problem Set-up + +We consider a fluid with a density of $\rho=1.0$ and describe its behavior using the time-dependent incompressible Navier-Stokes equations +$$\begin{aligned} + \mathbf{u}_t + \mathbf{u} \nabla \mathbf{u} + \nabla p - \nu \mathbf{u} = 0, \\ + \nabla \cdot \mathbf{u} = 0, +\end{aligned}$$ + +with $\mathbf{u}=(u, v)$ defining the velocity field and $p$ the pressure. The kinematic viscosity is taken as $\nu =0.001$. + +## Results + +![ns_cylinder](/figures/ns_cylinder_u.gif) + +![ns_cylinder](/figures/ns_cylinder_v.gif) + +![ns_cylinder](/figures/ns_cylinder_w.gif)] \ No newline at end of file diff --git a/examples/ns_unsteady_cylinder/configs/default.py b/examples/ns_unsteady_cylinder/configs/default.py new file mode 100644 index 00000000..2aa85f72 --- /dev/null +++ b/examples/ns_unsteady_cylinder/configs/default.py @@ -0,0 +1,102 @@ +import ml_collections + +import jax.numpy as jnp + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.mode = "train" + + # Weights & Biases + config.wandb = wandb = ml_collections.ConfigDict() + wandb.project = "PINN-NS_unsteady_cylinder" + wandb.name = "default" + wandb.tag = None + + # Nondimensionalization + config.nondim = True + + # Arch + config.arch = arch = ml_collections.ConfigDict() + arch.arch_name = "ModifiedMlp" + arch.num_layers = 4 + arch.layer_size = 256 + arch.out_dim = 3 + arch.activation = "gelu" # gelu works better than tanh for this problem + arch.periodicity = False + arch.fourier_emb = ml_collections.ConfigDict({"embed_scale": 1.0, "embed_dim": 256}) + arch.reparam = ml_collections.ConfigDict( + {"type": "weight_fact", "mean": 1.0, "stddev": 0.1} + ) + + # Optim + config.optim = optim = ml_collections.ConfigDict() + optim.optimizer = "Adam" + optim.beta1 = 0.9 + optim.beta2 = 0.999 + optim.eps = 1e-8 + optim.learning_rate = 1e-3 + optim.decay_rate = 0.9 + optim.decay_steps = 2000 + optim.grad_accum_steps = 0 + + # Training + config.training = training = ml_collections.ConfigDict() + training.max_steps = 200000 + training.num_time_windows = 10 + + training.inflow_batch_size = 2048 + training.outflow_batch_size = 2048 + training.noslip_batch_size = 2048 + training.ic_batch_size = 2048 + training.res_batch_size = 4096 + + # Weighting + config.weighting = weighting = ml_collections.ConfigDict() + weighting.scheme = "grad_norm" + weighting.init_weights = { + "u_ic": 1.0, + "v_ic": 1.0, + "p_ic": 1.0, + "u_in": 1.0, + "v_in": 1.0, + "u_out": 1.0, + "v_out": 1.0, + "u_noslip": 1.0, + "v_noslip": 1.0, + "ru": 1.0, + "rv": 1.0, + "rc": 1.0, + } + + weighting.momentum = 0.9 + weighting.update_every_steps = 1000 # 100 for grad norm and 1000 for ntk + + weighting.use_causal = True + weighting.causal_tol = 1.0 + weighting.num_chunks = 16 + + # Logging + config.logging = logging = ml_collections.ConfigDict() + logging.log_every_steps = 100 + logging.log_errors = True + logging.log_losses = True + logging.log_weights = True + logging.log_grads = False + logging.log_ntk = False + logging.log_preds = False + + # Saving + config.saving = saving = ml_collections.ConfigDict() + saving.save_every_steps = 10000 + saving.num_keep_ckpts = 10 + + # Input shape for initializing Flax models + config.input_dim = 3 + + # Integer for PRNG random seed. + config.seed = 42 + + return config diff --git a/examples/ns_unsteady_cylinder/configs/sota.py b/examples/ns_unsteady_cylinder/configs/sota.py new file mode 100644 index 00000000..618a89f8 --- /dev/null +++ b/examples/ns_unsteady_cylinder/configs/sota.py @@ -0,0 +1,102 @@ +import ml_collections + +import jax.numpy as jnp + + +def get_config(): + """Get the default hyperparameter configuration.""" + config = ml_collections.ConfigDict() + + config.mode = "train" + + # Weights & Biases + config.wandb = wandb = ml_collections.ConfigDict() + wandb.project = "PINN-NS_unsteady_cylinder" + wandb.name = "sota" + wandb.tag = None + + # Nondimensionalization + config.nondim = True + + # Arch + config.arch = arch = ml_collections.ConfigDict() + arch.arch_name = "ModifiedMlp" + arch.num_layers = 5 + arch.layer_size = 256 + arch.out_dim = 3 + arch.activation = "gelu" # gelu works better than tanh for this problem + arch.periodicity = False + arch.fourier_emb = ml_collections.ConfigDict({"embed_scale": 1.0, "embed_dim": 256}) + arch.reparam = ml_collections.ConfigDict( + {"type": "weight_fact", "mean": 1.0, "stddev": 0.1} + ) + + # Optim + config.optim = optim = ml_collections.ConfigDict() + optim.optimizer = "Adam" + optim.beta1 = 0.9 + optim.beta2 = 0.999 + optim.eps = 1e-8 + optim.learning_rate = 1e-3 + optim.decay_rate = 0.9 + optim.decay_steps = 2000 + optim.grad_accum_steps = 0 + + # Training + config.training = training = ml_collections.ConfigDict() + training.max_steps = 200000 + training.num_time_windows = 10 + + training.inflow_batch_size = 2048 + training.outflow_batch_size = 2048 + training.noslip_batch_size = 2048 + training.ic_batch_size = 2048 + training.res_batch_size = 4096 + + # Weighting + config.weighting = weighting = ml_collections.ConfigDict() + weighting.scheme = "grad_norm" + weighting.init_weights = { + "u_ic": 100.0, + "v_ic": 100.0, + "p_ic": 100.0, + "u_in": 100.0, + "v_in": 100.0, + "u_out": 1.0, + "v_out": 1.0, + "u_noslip": 10.0, + "v_noslip": 10.0, + "ru": 1.0, + "rv": 1.0, + "rc": 1.0, + } + + weighting.momentum = 0.9 + weighting.update_every_steps = 1000 # 100 for grad norm and 1000 for ntk + + weighting.use_causal = True + weighting.causal_tol = 1.0 + weighting.num_chunks = 16 + + # Logging + config.logging = logging = ml_collections.ConfigDict() + logging.log_every_steps = 100 + logging.log_errors = True + logging.log_losses = True + logging.log_weights = True + logging.log_grads = False + logging.log_ntk = False + logging.log_preds = False + + # Saving + config.saving = saving = ml_collections.ConfigDict() + saving.save_every_steps = 10000 + saving.num_keep_ckpts = 10 + + # Input shape for initializing Flax models + config.input_dim = 3 + + # Integer for PRNG random seed. + config.seed = 42 + + return config diff --git a/examples/ns_unsteady_cylinder/data/fine_mesh.npy b/examples/ns_unsteady_cylinder/data/fine_mesh.npy new file mode 100644 index 00000000..819842e3 Binary files /dev/null and b/examples/ns_unsteady_cylinder/data/fine_mesh.npy differ diff --git a/examples/ns_unsteady_cylinder/data/fine_mesh_near_cylinder.npy b/examples/ns_unsteady_cylinder/data/fine_mesh_near_cylinder.npy new file mode 100644 index 00000000..1da3fd9d Binary files /dev/null and b/examples/ns_unsteady_cylinder/data/fine_mesh_near_cylinder.npy differ diff --git a/examples/ns_unsteady_cylinder/data/ns_unsteady.npy b/examples/ns_unsteady_cylinder/data/ns_unsteady.npy new file mode 100644 index 00000000..903c52dc Binary files /dev/null and b/examples/ns_unsteady_cylinder/data/ns_unsteady.npy differ diff --git a/examples/ns_unsteady_cylinder/eval.py b/examples/ns_unsteady_cylinder/eval.py new file mode 100644 index 00000000..db79c67f --- /dev/null +++ b/examples/ns_unsteady_cylinder/eval.py @@ -0,0 +1,182 @@ +from functools import partial +import time +import os + +from absl import logging + +from flax.training import checkpoints + +import jax +import jax.numpy as jnp +from jax import random, jit, vmap, pmap +from jax.tree_util import tree_map + +import scipy.io +import ml_collections + +import wandb + +import models + +from jaxpi.utils import restore_checkpoint + +import matplotlib.pyplot as plt +import matplotlib.tri as tri + + +def parabolic_inflow(y, U_max): + u = 4 * U_max * y * (0.41 - y) / (0.41**2) + v = jnp.zeros_like(y) + return u, v + + +def evaluate(config: ml_collections.ConfigDict, workdir: str): + # Load dataset + ( + u_ref, + v_ref, + p_ref, + coords, + inflow_coords, + outflow_coords, + wall_coords, + cylinder_coords, + nu, + ) = get_dataset() + + U_max = 0.3 # maximum velocity + u_inflow, _ = parabolic_inflow(inflow_coords[:, 1], U_max) + + # Nondimensionalization + if config.nondim == True: + # Nondimensionalization parameters + U_star = 0.2 # characteristic velocity + L_star = 0.1 # characteristic length + Re = U_star * L_star / nu + + # Nondimensionalize coordinates and inflow velocity + inflow_coords = inflow_coords / L_star + outflow_coords = outflow_coords / L_star + wall_coords = wall_coords / L_star + cylinder_coords = cylinder_coords / L_star + coords = coords / L_star + + # Nondimensionalize flow field + u_inflow = u_inflow / U_star + u_ref = u_ref / U_star + v_ref = v_ref / U_star + + else: + Re = nu + + # Initialize model + model = models.NavierStokes2D( + config, + u_inflow, + inflow_coords, + outflow_coords, + wall_coords, + cylinder_coords, + Re, + ) + + # Restore checkpoint + ckpt_path = os.path.join(".", "ckpt", config.wandb.name) + model.state = restore_checkpoint(model.state, ckpt_path) + params = model.state.params + + # Predict + u_pred = model.u_pred_fn(params, coords[:, 0], coords[:, 1]) + v_pred = model.v_pred_fn(params, coords[:, 0], coords[:, 1]) + + u_error = jnp.sqrt(jnp.mean((u_ref - u_pred) ** 2)) / jnp.sqrt(jnp.mean(u_ref**2)) + v_error = jnp.sqrt(jnp.mean((v_ref - v_pred) ** 2)) / jnp.sqrt(jnp.mean(v_ref**2)) + + print("l2_error of u: {:.4e}".format(u_error)) + print("l2_error of v: {:.4e}".format(v_error)) + + # Plot + # Save dir + save_dir = os.path.join(workdir, "figures", config.wandb.name) + if not os.path.isdir(save_dir): + os.makedirs(save_dir) + + if config.nondim == True: + # Dimensionalize coordinates and flow field + coords = coords * L_star + + u_ref = u_ref * U_star + v_ref = v_ref * U_star + + u_pred = u_pred * U_star + v_pred = v_pred * U_star + + # Triangulation + x = coords[:, 0] + y = coords[:, 1] + triang = tri.Triangulation(x, y) + + # Mask the triangles inside the cylinder + center = (0.2, 0.2) + radius = 0.05 + + x_tri = x[triang.triangles].mean(axis=1) + y_tri = y[triang.triangles].mean(axis=1) + dist_from_center = jnp.sqrt((x_tri - center[0]) ** 2 + (y_tri - center[1]) ** 2) + triang.set_mask(dist_from_center < radius) + + fig1 = plt.figure(figsize=(18, 12)) + plt.subplot(3, 1, 1) + plt.tricontourf(triang, u_ref, cmap="jet", levels=100) + plt.colorbar() + plt.xlabel("x") + plt.ylabel("y") + plt.title("Exact") + plt.tight_layout() + + plt.subplot(3, 1, 2) + plt.tricontourf(triang, u_pred, cmap="jet", levels=100) + plt.colorbar() + plt.xlabel("x") + plt.ylabel("y") + plt.title("Predicted u(x, y)") + plt.tight_layout() + + plt.subplot(3, 1, 3) + plt.tricontourf(triang, jnp.abs(u_ref - u_pred), cmap="jet", levels=100) + plt.colorbar() + plt.xlabel("x") + plt.ylabel("y") + plt.title("Absolute error") + plt.tight_layout() + + save_path = os.path.join(save_dir, "ns_steady_u.pdf") + fig1.savefig(save_path, bbox_inches="tight", dpi=300) + + fig2 = plt.figure(figsize=(18, 12)) + plt.subplot(3, 1, 1) + plt.tricontourf(triang, v_ref, cmap="jet", levels=100) + plt.colorbar() + plt.xlabel("x") + plt.ylabel("y") + plt.title("Exact") + plt.tight_layout() + + plt.subplot(3, 1, 2) + plt.tricontourf(triang, v_pred, cmap="jet", levels=100) + plt.colorbar() + plt.xlabel("x") + plt.ylabel("y") + plt.title("Predicted u(x, y)") + plt.tight_layout() + + plt.subplot(3, 1, 3) + plt.tricontourf(triang, jnp.abs(v_ref - v_pred), cmap="jet", levels=100) + plt.colorbar() + plt.xlabel("x") + plt.ylabel("y") + plt.title("Absolute error") + plt.tight_layout() + + save_path = os.path.join(save_dir, "ns_steady_v.pdf") + fig2.savefig(save_path, bbox_inches="tight", dpi=300) diff --git a/examples/ns_unsteady_cylinder/figures/ns_cylinder_pred.png b/examples/ns_unsteady_cylinder/figures/ns_cylinder_pred.png new file mode 100644 index 00000000..77c5296b Binary files /dev/null and b/examples/ns_unsteady_cylinder/figures/ns_cylinder_pred.png differ diff --git a/examples/ns_unsteady_cylinder/figures/ns_cylinder_u.gif b/examples/ns_unsteady_cylinder/figures/ns_cylinder_u.gif new file mode 100644 index 00000000..7b9a2d9d Binary files /dev/null and b/examples/ns_unsteady_cylinder/figures/ns_cylinder_u.gif differ diff --git a/examples/ns_unsteady_cylinder/figures/ns_cylinder_v.gif b/examples/ns_unsteady_cylinder/figures/ns_cylinder_v.gif new file mode 100644 index 00000000..4e19e3de Binary files /dev/null and b/examples/ns_unsteady_cylinder/figures/ns_cylinder_v.gif differ diff --git a/examples/ns_unsteady_cylinder/figures/ns_cylinder_w.gif b/examples/ns_unsteady_cylinder/figures/ns_cylinder_w.gif new file mode 100644 index 00000000..8f882e23 Binary files /dev/null and b/examples/ns_unsteady_cylinder/figures/ns_cylinder_w.gif differ diff --git a/examples/ns_unsteady_cylinder/main.py b/examples/ns_unsteady_cylinder/main.py new file mode 100644 index 00000000..f81c5a73 --- /dev/null +++ b/examples/ns_unsteady_cylinder/main.py @@ -0,0 +1,39 @@ +# DETERMINISTIC +import os + +# os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_reductions --xla_gpu_autotune_level=0" +# os.environ['TF_CUDNN_DETERMINISTIC'] = '1' + +from absl import app +from absl import flags +from absl import logging + +import jax +from ml_collections import config_flags + +import train +import eval + +FLAGS = flags.FLAGS + +flags.DEFINE_string("workdir", ".", "Directory to store model data.") + +config_flags.DEFINE_config_file( + "config", + "./configs/default.py", + "File path to the training hyperparameter configuration.", + lock_config=True, +) + + +def main(argv): + if FLAGS.config.mode == "train": + train.train_and_evaluate(FLAGS.config, FLAGS.workdir) + + elif FLAGS.config.mode == "eval": + eval.evaluate(FLAGS.config, FLAGS.workdir) + + +if __name__ == "__main__": + flags.mark_flags_as_required(["config", "workdir"]) + app.run(main) diff --git a/examples/ns_unsteady_cylinder/models.py b/examples/ns_unsteady_cylinder/models.py new file mode 100644 index 00000000..425d6f59 --- /dev/null +++ b/examples/ns_unsteady_cylinder/models.py @@ -0,0 +1,479 @@ +from functools import partial + +import jax +import jax.numpy as jnp +from jax import lax, jit, grad, vmap +from jax.tree_util import tree_map + +import optax + +from jaxpi import archs +from jaxpi.models import ForwardBVP, ForwardIVP +from jaxpi.utils import ntk_fn +from jaxpi.evaluator import BaseEvaluator + + +class NavierStokes2D(ForwardIVP): + def __init__(self, config, inflow_fn, temporal_dom, coords, Re): + super().__init__(config) + + self.inflow_fn = inflow_fn + self.temporal_dom = temporal_dom + self.coords = coords + self.Re = Re # Reynolds number + + # Non-dimensionalized domain length and width + self.L, self.W = self.coords.max(axis=0) - self.coords.min(axis=0) + + if config.nondim == True: + self.U_star = 1.0 + self.L_star = 0.1 + else: + self.U_star = 1.0 + self.L_star = 1.0 + + # Predict functions over batch + self.u0_pred_fn = vmap(self.u_net, (None, None, 0, 0)) + self.v0_pred_fn = vmap(self.v_net, (None, None, 0, 0)) + self.p0_pred_fn = vmap(self.p_net, (None, None, 0, 0)) + + self.u_pred_fn = vmap(self.u_net, (None, 0, 0, 0)) + self.v_pred_fn = vmap(self.v_net, (None, 0, 0, 0)) + self.p_pred_fn = vmap(self.p_net, (None, 0, 0, 0)) + self.w_pred_fn = vmap(self.w_net, (None, 0, 0, 0)) + self.r_pred_fn = vmap(self.r_net, (None, 0, 0, 0)) + + def neural_net(self, params, t, x, y): + t = t / self.temporal_dom[1] # rescale t into [0, 1] + x = x / self.L # rescale x into [0, 1] + y = y / self.W # rescale y into [0, 1] + inputs = jnp.stack([t, x, y]) + outputs = self.state.apply_fn(params, inputs) + + # Start with an initial state of the channel flow + y_hat = y * self.L_star * self.W + u = outputs[0] + 4 * 1.5 * y_hat * (0.41 - y_hat) / (0.41**2) + v = outputs[1] + p = outputs[2] + return u, v, p + + def u_net(self, params, t, x, y): + u, _, _ = self.neural_net(params, t, x, y) + return u + + def v_net(self, params, t, x, y): + _, v, _ = self.neural_net(params, t, x, y) + return v + + def p_net(self, params, t, x, y): + _, _, p = self.neural_net(params, t, x, y) + return p + + def w_net(self, params, t, x, y): + u, v, _ = self.neural_net(params, t, x, y) + u_y = grad(self.u_net, argnums=3)(params, t, x, y) + v_x = grad(self.v_net, argnums=2)(params, t, x, y) + w = v_x - u_y + return w + + def r_net(self, params, t, x, y): + u, v, p = self.neural_net(params, t, x, y) + + u_t = grad(self.u_net, argnums=1)(params, t, x, y) + v_t = grad(self.v_net, argnums=1)(params, t, x, y) + + u_x = grad(self.u_net, argnums=2)(params, t, x, y) + v_x = grad(self.v_net, argnums=2)(params, t, x, y) + p_x = grad(self.p_net, argnums=2)(params, t, x, y) + + u_y = grad(self.u_net, argnums=3)(params, t, x, y) + v_y = grad(self.v_net, argnums=3)(params, t, x, y) + p_y = grad(self.p_net, argnums=3)(params, t, x, y) + + u_xx = grad(grad(self.u_net, argnums=2), argnums=2)(params, t, x, y) + v_xx = grad(grad(self.v_net, argnums=2), argnums=2)(params, t, x, y) + + u_yy = grad(grad(self.u_net, argnums=3), argnums=3)(params, t, x, y) + v_yy = grad(grad(self.v_net, argnums=3), argnums=3)(params, t, x, y) + + # PDE residual + ru = u_t + u * u_x + v * u_y + p_x - (u_xx + u_yy) / self.Re + rv = v_t + u * v_x + v * v_y + p_y - (v_xx + v_yy) / self.Re + rc = u_x + v_y + + # outflow boundary residual + u_out = u_x / self.Re - p + v_out = v_x + + return ru, rv, rc, u_out, v_out + + def ru_net(self, params, t, x, y): + ru, _, _, _, _ = self.r_net(params, t, x, y) + return ru + + def rv_net(self, params, t, x, y): + _, rv, _, _, _ = self.r_net(params, t, x, y) + return rv + + def rc_net(self, params, t, x, y): + _, _, rc, _, _ = self.r_net(params, t, x, y) + return rc + + def u_out_net(self, params, t, x, y): + _, _, _, u_out, _ = self.r_net(params, t, x, y) + return u_out + + def v_out_net(self, params, t, x, y): + _, _, _, _, v_out = self.r_net(params, t, x, y) + return v_out + + @partial(jit, static_argnums=(0,)) + def res_and_w(self, params, batch): + # Sort temporal coordinates + t_sorted = batch[:, 0].sort() + ru_pred, rv_pred, rc_pred, _, _ = self.r_pred_fn( + params, t_sorted, batch[:, 1], batch[:, 2] + ) + + ru_pred = ru_pred.reshape(self.num_chunks, -1) + rv_pred = rv_pred.reshape(self.num_chunks, -1) + rc_pred = rc_pred.reshape(self.num_chunks, -1) + + ru_l = jnp.mean(ru_pred**2, axis=1) + rv_l = jnp.mean(rv_pred**2, axis=1) + rc_l = jnp.mean(rc_pred**2, axis=1) + + ru_gamma = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ ru_l))) + rv_gamma = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ rv_l))) + rc_gamma = lax.stop_gradient(jnp.exp(-self.tol * (self.M @ rc_l))) + + # Take minimum of the causal weights + gamma = jnp.vstack([ru_gamma, rv_gamma, rc_gamma]) + gamma = gamma.min(0) + + return ru_l, rv_l, rc_l, gamma + + @partial(jit, static_argnums=(0,)) + def compute_diag_ntk(self, params, batch): + # Unpack batch + ic_batch = batch["ic"] + inflow_batch = batch["inflow"] + outflow_batch = batch["outflow"] + noslip_batch = batch["noslip"] + res_batch = batch["res"] + + coords_batch, u_batch, v_batch, p_batch = ic_batch + + u_ic_ntk = vmap(ntk_fn, (None, None, None, 0, 0))( + self.u_net, params, 0.0, coords_batch[:, 0], coords_batch[:, 1] + ) + v_ic_ntk = vmap(ntk_fn, (None, None, None, 0, 0))( + self.v_net, params, 0.0, coords_batch[:, 0], coords_batch[:, 1] + ) + p_ic_ntk = vmap(ntk_fn, (None, None, None, 0, 0))( + self.p_net, params, 0.0, coords_batch[:, 0], coords_batch[:, 1] + ) + + u_in_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.u_net, + params, + inflow_batch[:, 0], + inflow_batch[:, 1], + inflow_batch[:, 2], + ) + v_in_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.v_net, + params, + inflow_batch[:, 0], + inflow_batch[:, 1], + inflow_batch[:, 2], + ) + + u_out_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.u_out_net, + params, + outflow_batch[:, 0], + outflow_batch[:, 1], + outflow_batch[:, 2], + ) + v_out_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.v_out_net, + params, + outflow_batch[:, 0], + outflow_batch[:, 1], + outflow_batch[:, 2], + ) + + u_noslip_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.u_net, + params, + noslip_batch[:, 0], + noslip_batch[:, 1], + noslip_batch[:, 2], + ) + v_noslip_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.v_net, + params, + noslip_batch[:, 0], + noslip_batch[:, 1], + noslip_batch[:, 2], + ) + + # Consider the effect of causal weights + if self.config.weighting.use_causal: + res_batch = jnp.array( + [res_batch[:, 0].sort(), res_batch[:, 1], res_batch[:, 2]] + ).T + ru_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.ru_net, params, res_batch[:, 0], res_batch[:, 1], res_batch[:, 2] + ) + rv_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.rv_net, params, res_batch[:, 0], res_batch[:, 1], res_batch[:, 2] + ) + rc_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.rc_net, params, res_batch[:, 0], res_batch[:, 1], res_batch[:, 2] + ) + + ru_ntk = ru_ntk.reshape(self.num_chunks, -1) # shape: (num_chunks, -1) + rv_ntk = rv_ntk.reshape(self.num_chunks, -1) + rc_ntk = rc_ntk.reshape(self.num_chunks, -1) + + ru_ntk = jnp.mean( + ru_ntk, axis=1 + ) # average convergence rate over each chunk + rv_ntk = jnp.mean(rv_ntk, axis=1) + rc_ntk = jnp.mean(rc_ntk, axis=1) + + _, _, _, causal_weights = self.res_and_w(params, res_batch) + ru_ntk = ru_ntk * causal_weights # multiply by causal weights + rv_ntk = rv_ntk * causal_weights + rc_ntk = rc_ntk * causal_weights + else: + ru_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.ru_net, params, res_batch[:, 0], res_batch[:, 1], res_batch[:, 2] + ) + rv_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.rv_net, params, res_batch[:, 0], res_batch[:, 1], res_batch[:, 2] + ) + rc_ntk = vmap(ntk_fn, (None, None, 0, 0, 0))( + self.rc_net, params, res_batch[:, 0], res_batch[:, 1], res_batch[:, 2] + ) + + ntk_dict = { + "u_ic": u_ic_ntk, + "v_ic": v_ic_ntk, + "p_ic": p_ic_ntk, + "u_in": u_in_ntk, + "v_in": v_in_ntk, + "u_out": u_out_ntk, + "v_out": v_out_ntk, + "u_noslip": u_noslip_ntk, + "v_noslip": v_noslip_ntk, + "ru": ru_ntk, + "rv": rv_ntk, + "rc": rc_ntk, + } + + return ntk_dict + + @partial(jit, static_argnums=(0,)) + def losses(self, params, batch): + # Unpack batch + ic_batch = batch["ic"] + inflow_batch = batch["inflow"] + outflow_batch = batch["outflow"] + noslip_batch = batch["noslip"] + res_batch = batch["res"] + + # Initial condition loss + coords_batch, u_batch, v_batch, p_batch = ic_batch + + u_ic_pred = self.u0_pred_fn(params, 0.0, coords_batch[:, 0], coords_batch[:, 1]) + v_ic_pred = self.v0_pred_fn(params, 0.0, coords_batch[:, 0], coords_batch[:, 1]) + p_ic_pred = self.p0_pred_fn(params, 0.0, coords_batch[:, 0], coords_batch[:, 1]) + + u_ic_loss = jnp.mean((u_ic_pred - u_batch) ** 2) + v_ic_loss = jnp.mean((v_ic_pred - v_batch) ** 2) + p_ic_loss = jnp.mean((p_ic_pred - p_batch) ** 2) + + # inflow loss + u_in, _ = self.inflow_fn(inflow_batch[:, 2]) + + u_in_pred = self.u_pred_fn( + params, inflow_batch[:, 0], inflow_batch[:, 1], inflow_batch[:, 2] + ) + v_in_pred = self.v_pred_fn( + params, inflow_batch[:, 0], inflow_batch[:, 1], inflow_batch[:, 2] + ) + + u_in_loss = jnp.mean((u_in_pred - u_in) ** 2) + v_in_loss = jnp.mean(v_in_pred**2) + + # outflow loss + _, _, _, u_out_pred, v_out_pred = self.r_pred_fn( + params, outflow_batch[:, 0], outflow_batch[:, 1], outflow_batch[:, 2] + ) + + u_out_loss = jnp.mean(u_out_pred**2) + v_out_loss = jnp.mean(v_out_pred**2) + + # noslip loss + u_noslip_pred = self.u_pred_fn( + params, noslip_batch[:, 0], noslip_batch[:, 1], noslip_batch[:, 2] + ) + v_noslip_pred = self.v_pred_fn( + params, noslip_batch[:, 0], noslip_batch[:, 1], noslip_batch[:, 2] + ) + + u_noslip_loss = jnp.mean(u_noslip_pred**2) + v_noslip_loss = jnp.mean(v_noslip_pred**2) + + # residual loss + if self.config.weighting.use_causal == True: + ru_l, rv_l, rc_l, gamma = self.res_and_w(params, res_batch) + ru_loss = jnp.mean(gamma * ru_l) + rv_loss = jnp.mean(gamma * rv_l) + rc_loss = jnp.mean(gamma * rc_l) + + else: + ru_pred, rv_pred, rc_pred, _, _ = self.r_pred_fn( + params, res_batch[:, 0], res_batch[:, 1], res_batch[:, 2] + ) + ru_loss = jnp.mean(ru_pred**2) + rv_loss = jnp.mean(rv_pred**2) + rc_loss = jnp.mean(rc_pred**2) + + loss_dict = { + "u_ic": u_ic_loss, + "v_ic": v_ic_loss, + "p_ic": p_ic_loss, + "u_in": u_in_loss, + "v_in": v_in_loss, + "u_out": u_out_loss, + "v_out": v_out_loss, + "u_noslip": u_noslip_loss, + "v_noslip": v_noslip_loss, + "ru": ru_loss, + "rv": rv_loss, + "rc": rc_loss, + } + + return loss_dict + + def u_v_grads(self, params, t, x, y): + u_x = grad(self.u_net, argnums=2)(params, t, x, y) + v_x = grad(self.v_net, argnums=2)(params, t, x, y) + + u_y = grad(self.u_net, argnums=3)(params, t, x, y) + v_y = grad(self.v_net, argnums=3)(params, t, x, y) + + return u_x, v_x, u_y, v_y + + @partial(jit, static_argnums=(0,)) + def compute_drag_lift(self, params, t, U_star, L_star): + nu = 0.001 # Dimensional viscosity + radius = 0.05 # radius of cylinder + center = (0.2, 0.2) # center of cylinder + num_theta = 256 # number of points on cylinder for evaluation + + # Discretize cylinder into points + theta = jnp.linspace(0.0, 2 * jnp.pi, num_theta) + d_theta = theta[1] - theta[0] + ds = radius * d_theta + + # Cylinder coordinates + x_cyl = radius * jnp.cos(theta) + center[0] + y_cyl = radius * jnp.sin(theta) + center[1] + + # Out normals of cylinder + n_x = jnp.cos(theta) + n_y = jnp.sin(theta) + + # Nondimensionalize input cylinder coordinates + x_cyl = x_cyl / L_star + y_cyl = y_cyl / L_star + + # Nondimensionalize fonrt and back points + front = jnp.array([center[0] - radius, center[1]]) / L_star + back = jnp.array([center[0] + radius, center[1]]) / L_star + + # Predictions + u_x_pred, v_x_pred, u_y_pred, v_y_pred = vmap( + vmap(self.u_v_grads, (None, None, 0, 0)), (None, 0, None, None) + )(params, t, x_cyl, y_cyl) + + p_pred = vmap(vmap(self.p_net, (None, None, 0, 0)), (None, 0, None, None))( + params, t, x_cyl, y_cyl + ) + + p_pred = p_pred - jnp.mean(p_pred, axis=1, keepdims=True) + + p_front_pred = vmap(self.p_net, (None, 0, None, None))( + params, t, front[0], front[1] + ) + p_back_pred = vmap(self.p_net, (None, 0, None, None))( + params, t, back[0], back[1] + ) + p_diff = p_front_pred - p_back_pred + + # Dimensionalize velocity gradients and pressure + u_x_pred = u_x_pred * U_star / L_star + v_x_pred = v_x_pred * U_star / L_star + u_y_pred = u_y_pred * U_star / L_star + v_y_pred = v_y_pred * U_star / L_star + p_pred = p_pred * U_star**2 + p_diff = p_diff * U_star**2 + + I0 = (-p_pred[:, :-1] + 2 * nu * u_x_pred[:, :-1]) * n_x[:-1] + nu * ( + u_y_pred[:, :-1] + v_x_pred[:, :-1] + ) * n_y[:-1] + I1 = (-p_pred[:, 1:] + 2 * nu * u_x_pred[:, 1:]) * n_x[1:] + nu * ( + u_y_pred[:, 1:] + v_x_pred[:, 1:] + ) * n_y[1:] + + F_D = 0.5 * jnp.sum(I0 + I1, axis=1) * ds + + I0 = (-p_pred[:, :-1] + 2 * nu * v_y_pred[:, :-1]) * n_y[:-1] + nu * ( + u_y_pred[:, :-1] + v_x_pred[:, :-1] + ) * n_x[:-1] + I1 = (-p_pred[:, 1:] + 2 * nu * v_y_pred[:, 1:]) * n_y[1:] + nu * ( + u_y_pred[:, 1:] + v_x_pred[:, 1:] + ) * n_x[1:] + + F_L = 0.5 * jnp.sum(I0 + I1, axis=1) * ds + + # Nondimensionalized drag and lift and pressure difference + C_D = 2 / (U_star**2 * L_star) * F_D + C_L = 2 / (U_star**2 * L_star) * F_L + + return C_D, C_L, p_diff + + +class NavierStokesEvaluator(BaseEvaluator): + def __init__(self, config, model): + super().__init__(config, model) + + # def log_preds(self, params, x_star, y_star): + # u_pred = vmap(vmap(model.u_net, (None, None, 0)), (None, 0, None))(params, x_star, y_star) + # v_pred = vmap(vmap(model.v_net, (None, None, 0)), (None, 0, None))(params, x_star, y_star) + # U_pred = jnp.sqrt(u_pred ** 2 + v_pred ** 2) + # + # fig = plt.figure() + # plt.pcolor(U_pred.T, cmap='jet') + # log_dict['U_pred'] = fig + # fig.close() + + def __call__(self, state, batch): + self.log_dict = super().__call__(state, batch) + + if self.config.weighting.use_causal: + _, _, _, causal_weight = self.model.res_and_w(state.params, batch["res"]) + self.log_dict["cas_weight"] = causal_weight.min() + + # if self.config.logging.log_errors: + # self.log_errors(state.params, coords, u_ref, v_ref) + # + # if self.config.logging.log_preds: + # self.log_preds(state.params, coords) + + return self.log_dict diff --git a/examples/ns_unsteady_cylinder/train.py b/examples/ns_unsteady_cylinder/train.py new file mode 100644 index 00000000..bd75857a --- /dev/null +++ b/examples/ns_unsteady_cylinder/train.py @@ -0,0 +1,305 @@ +import functools +from functools import partial +import time +import os + +from absl import logging + +import jax + +import jax.numpy as jnp +from jax import random, vmap, pmap, local_device_count +from jax.tree_util import tree_map + +import matplotlib.pyplot as plt + +import numpy as np +import scipy.io +import ml_collections + +import wandb + +import models + +from jaxpi.samplers import BaseSampler, SpaceSampler, TimeSpaceSampler +from jaxpi.logging import Logger +from jaxpi.utils import save_checkpoint + +from utils import get_dataset, get_fine_mesh, parabolic_inflow + + +class ICSampler(SpaceSampler): + def __init__(self, u, v, p, coords, batch_size, rng_key=random.PRNGKey(1234)): + super().__init__(coords, batch_size, rng_key) + + self.u = u + self.v = v + self.p = p + + @partial(pmap, static_broadcasted_argnums=(0,)) + def data_generation(self, key): + "Generates data containing batch_size samples" + idx = random.choice(key, self.coords.shape[0], shape=(self.batch_size,)) + + coords_batch = self.coords[idx, :] + + u_batch = self.u[idx] + v_batch = self.v[idx] + p_batch = self.p[idx] + + batch = (coords_batch, u_batch, v_batch, p_batch) + + return batch + + +class ResSampler(BaseSampler): + def __init__( + self, + temporal_dom, + coarse_coords, + fine_coords, + batch_size, + rng_key=random.PRNGKey(1234), + ): + super().__init__(batch_size, rng_key) + + self.temporal_dom = temporal_dom + + self.coarse_coords = coarse_coords + self.fine_coords = fine_coords + + @partial(pmap, static_broadcasted_argnums=(0,)) + def data_generation(self, key): + "Generates data containing batch_size samples" + subkeys = random.split(key, 4) + + temporal_batch = random.uniform( + subkeys[0], + shape=(2 * self.batch_size, 1), + minval=self.temporal_dom[0], + maxval=self.temporal_dom[1], + ) + + coarse_idx = random.choice( + subkeys[1], + self.coarse_coords.shape[0], + shape=(self.batch_size,), + replace=True, + ) + fine_idx = random.choice( + subkeys[2], + self.fine_coords.shape[0], + shape=(self.batch_size,), + replace=True, + ) + + coarse_spatial_batch = self.coarse_coords[coarse_idx, :] + fine_spatial_batch = self.fine_coords[fine_idx, :] + spatial_batch = jnp.vstack([coarse_spatial_batch, fine_spatial_batch]) + spatial_batch = random.permutation( + subkeys[3], spatial_batch + ) # mix the coarse and fine coordinates + + batch = jnp.concatenate([temporal_batch, spatial_batch], axis=1) + + return batch + + +def train_one_window(config, workdir, model, samplers, idx): + # Initialize evaluator + evaluator = models.NavierStokesEvaluator(config, model) + + # Initialize logger + logger = Logger() + + step_offset = idx * config.training.max_steps + + # jit warm up + print("Waiting for JIT...") + for step in range(config.training.max_steps): + start_time = time.time() + + # Sample mini-batch + batch = {} + for key, sampler in samplers.items(): + batch[key] = next(sampler) + + model.state = model.step(model.state, batch) + + # Update weights if necessary + if config.weighting.scheme in ["grad_norm", "ntk"]: + if step % config.weighting.update_every_steps == 0: + model.state = model.update_weights(model.state, batch) + + # Log training metrics, only use host 0 to record results + if jax.process_index() == 0: + if step % config.logging.log_every_steps == 0: + # Get the first replica of the state and batch + state = jax.device_get(tree_map(lambda x: x[0], model.state)) + batch = jax.device_get(tree_map(lambda x: x[0], batch)) + log_dict = evaluator(state, batch) + wandb.log(log_dict, step + step_offset) + + end_time = time.time() + # Report training metrics + logger.log_iter(step, start_time, end_time, log_dict) + + # Save checkpoint + if config.saving.save_every_steps is not None: + if (step + 1) % config.saving.save_every_steps == 0 or ( + step + 1 + ) == config.training.max_steps: + path = os.path.join( + workdir, "ckpt", config.wandb.name, "time_window_{}".format(idx + 1) + ) + save_checkpoint(model.state, path, keep=config.saving.num_keep_ckpts) + + return model + + +def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): + # Initialize W&B + wandb_config = config.wandb + wandb.init(project=wandb_config.project, name=wandb_config.name) + + # Get dataset + ( + u_ref, + v_ref, + p_ref, + coords, + inflow_coords, + outflow_coords, + wall_coords, + cyl_coords, + nu, + ) = get_dataset() + ( + fine_coords, + fine_coords_near_cyl, + ) = get_fine_mesh() # finer mesh for evaluating PDE residuals + + noslip_coords = jnp.vstack((wall_coords, cyl_coords)) + + # T = 1.0 # final time of simulation + + # Nondimensionalization + if config.nondim == True: + # Nondimensionalization parameters + U_star = 1.0 # characteristic velocity + L_star = 0.1 # characteristic length + T_star = L_star / U_star # characteristic time + Re = U_star * L_star / nu + + # Nondimensionalize coordinates and inflow velocity + # T = T / T_star + inflow_coords = inflow_coords / L_star + outflow_coords = outflow_coords / L_star + noslip_coords = noslip_coords / L_star + + coords = coords / L_star + fine_coords = fine_coords / L_star + fine_coords_near_cyl = fine_coords_near_cyl / L_star + + # Nondimensionalize flow field + # u_inflow = u_inflow / U_star + u_ref = u_ref / U_star + v_ref = v_ref / U_star + p_ref = p_ref / U_star**2 + + else: + U_star = 1.0 + L_star = 1.0 + T_star = 1.0 + Re = 1 / nu + + # Temporal domain of each time window + t0 = 0.0 + t1 = 1.0 + + temporal_dom = jnp.array([t0, t1 * (1 + 0.05)]) + + # Inflow boundary conditions + U_max = 1.5 # maximum velocity + inflow_fn = lambda y: parabolic_inflow(y * L_star, U_max) + + # Set initial condition + # Use the last time step of a coarse numerical solution as the initial condition + u0 = u_ref[-1, :] + v0 = v_ref[-1, :] + p0 = p_ref[-1, :] + + for idx in range(config.training.num_time_windows): + logging.info("Training time window {}".format(idx + 1)) + + # Initialize Sampler + keys = random.split(random.PRNGKey(0), 5) + ic_sampler = iter( + ICSampler( + u0, v0, p0, coords, config.training.ic_batch_size, rng_key=keys[0] + ) + ) + inflow_sampler = iter( + TimeSpaceSampler( + temporal_dom, + inflow_coords, + config.training.inflow_batch_size, + rng_key=keys[1], + ) + ) + outflow_sampler = iter( + TimeSpaceSampler( + temporal_dom, + outflow_coords, + config.training.outflow_batch_size, + rng_key=keys[2], + ) + ) + noslip_sampler = iter( + TimeSpaceSampler( + temporal_dom, + noslip_coords, + config.training.noslip_batch_size, + rng_key=keys[3], + ) + ) + + res_sampler = iter( + ResSampler( + temporal_dom, + fine_coords, + fine_coords, + config.training.res_batch_size, + rng_key=keys[4], + ) + ) + + samplers = { + "ic": ic_sampler, + "inflow": inflow_sampler, + "outflow": outflow_sampler, + "noslip": noslip_sampler, + "res": res_sampler, + } + + # Initialize model + model = models.NavierStokes2D(config, inflow_fn, temporal_dom, coords, Re) + + # Train model for the current time window + model = train_one_window(config, workdir, model, samplers, idx) + + # Update the initial condition for the next time window + if config.training.num_time_windows > 1: + state = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], model.state)) + params = state.params + u0 = vmap(model.u_net, (None, None, 0, 0))( + params, t1, coords[:, 0], coords[:, 1] + ) + v0 = vmap(model.v_net, (None, None, 0, 0))( + params, t1, coords[:, 0], coords[:, 1] + ) + p0 = vmap(model.p_net, (None, None, 0, 0))( + params, t1, coords[:, 0], coords[:, 1] + ) + + del model, state, params diff --git a/examples/ns_unsteady_cylinder/utils.py b/examples/ns_unsteady_cylinder/utils.py new file mode 100644 index 00000000..0fe8b6c2 --- /dev/null +++ b/examples/ns_unsteady_cylinder/utils.py @@ -0,0 +1,43 @@ +import jax.numpy as jnp + + +def parabolic_inflow(y, U_max): + u = 4 * U_max * y * (0.41 - y) / (0.41**2) + v = jnp.zeros_like(y) + return u, v + + +def get_dataset(): + data = jnp.load("data/ns_unsteady.npy", allow_pickle=True).item() + u_ref = jnp.array(data["u"]) + v_ref = jnp.array(data["v"]) + p_ref = jnp.array(data["p"]) + t = jnp.array(data["t"]) + coords = jnp.array(data["coords"]) + inflow_coords = jnp.array(data["inflow_coords"]) + outflow_coords = jnp.array(data["outflow_coords"]) + wall_coords = jnp.array(data["wall_coords"]) + cylinder_coords = jnp.array(data["cylinder_coords"]) + nu = jnp.array(data["nu"]) + + return ( + u_ref, + v_ref, + p_ref, + coords, + inflow_coords, + outflow_coords, + wall_coords, + cylinder_coords, + nu, + ) + + +def get_fine_mesh(): + data = jnp.load("data/fine_mesh.npy", allow_pickle=True).item() + fine_coords = jnp.array(data["coords"]) + + data = jnp.load("data/fine_mesh_near_cylinder.npy", allow_pickle=True).item() + fine_coords_near_cyl = jnp.array(data["coords"]) + + return fine_coords, fine_coords_near_cyl