Skip to content

Commit

Permalink
utility_meshes: enable PeriodicBoxMesh(..., hexahedral=True)
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Jul 5, 2024
1 parent 12e2082 commit baff755
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 138 deletions.
300 changes: 162 additions & 138 deletions firedrake/utility_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,38 @@ def TensorBoxMesh(
return m


def _firedrake_box_mesh_hexahedral_mark_boundaries(plex, nx, ny, nz, Lx, Ly, Lz):
plex.removeLabel(dmcommon.FACE_SETS_LABEL)
nvert = 4 # num. vertices on faect
# Apply boundary IDs
plex.createLabel(dmcommon.FACE_SETS_LABEL)
plex.markBoundaryFaces("boundary_faces")
coords = plex.getCoordinates()
coord_sec = plex.getCoordinateSection()
cdim = plex.getCoordinateDim()
assert cdim == 3
if plex.getStratumSize("boundary_faces", 1) > 0:
boundary_faces = plex.getStratumIS("boundary_faces", 1).getIndices()
xtol = Lx / (2 * nx)
ytol = Ly / (2 * ny)
ztol = Lz / (2 * nz)
for face in boundary_faces:
face_coords = plex.vecGetClosure(coord_sec, coords, face)
if all([abs(face_coords[0 + cdim * i]) < xtol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 1)
if all([abs(face_coords[0 + cdim * i] - Lx) < xtol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 2)
if all([abs(face_coords[1 + cdim * i]) < ytol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 3)
if all([abs(face_coords[1 + cdim * i] - Ly) < ytol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 4)
if all([abs(face_coords[2 + cdim * i]) < ztol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 5)
if all([abs(face_coords[2 + cdim * i] - Lz) < ztol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 6)
plex.removeLabel("boundary_faces")


@PETSc.Log.EventDecorator()
def BoxMesh(
nx,
Expand Down Expand Up @@ -1657,46 +1689,15 @@ def BoxMesh(
raise ValueError("Number of cells must be a postive integer")
if hexahedral:
plex = PETSc.DMPlex().createBoxMesh((nx, ny, nz), lower=(0., 0., 0.), upper=(Lx, Ly, Lz), simplex=False, periodic=False, interpolate=True, comm=comm)
plex.removeLabel(dmcommon.FACE_SETS_LABEL)
nvert = 4 # num. vertices on faect

# Apply boundary IDs
plex.createLabel(dmcommon.FACE_SETS_LABEL)
plex.markBoundaryFaces("boundary_faces")
coords = plex.getCoordinates()
coord_sec = plex.getCoordinateSection()
cdim = plex.getCoordinateDim()
assert cdim == 3
if plex.getStratumSize("boundary_faces", 1) > 0:
boundary_faces = plex.getStratumIS("boundary_faces", 1).getIndices()
xtol = Lx / (2 * nx)
ytol = Ly / (2 * ny)
ztol = Lz / (2 * nz)
for face in boundary_faces:
face_coords = plex.vecGetClosure(coord_sec, coords, face)
if all([abs(face_coords[0 + cdim * i]) < xtol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 1)
if all([abs(face_coords[0 + cdim * i] - Lx) < xtol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 2)
if all([abs(face_coords[1 + cdim * i]) < ytol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 3)
if all([abs(face_coords[1 + cdim * i] - Ly) < ytol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 4)
if all([abs(face_coords[2 + cdim * i]) < ztol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 5)
if all([abs(face_coords[2 + cdim * i] - Lz) < ztol for i in range(nvert)]):
plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 6)
plex.removeLabel("boundary_faces")
m = mesh.Mesh(
_firedrake_box_mesh_hexahedral_mark_boundaries(plex, nx, ny, nz, Lx, Ly, Lz)
return mesh.Mesh(
plex,
reorder=reorder,
distribution_parameters=distribution_parameters,
name=name,
distribution_name=distribution_name,
permutation_name=permutation_name,
comm=comm,
)
return m
comm=comm)
else:
xcoords = np.linspace(0, Lx, nx + 1, dtype=np.double)
ycoords = np.linspace(0, Ly, ny + 1, dtype=np.double)
Expand Down Expand Up @@ -1837,6 +1838,8 @@ def PeriodicBoxMesh(
Lx,
Ly,
Lz,
directions=None,
hexahedral=False,
reorder=None,
distribution_parameters=None,
comm=COMM_WORLD,
Expand Down Expand Up @@ -1869,120 +1872,139 @@ def PeriodicBoxMesh(
raise ValueError(
"3D periodic meshes with fewer than 3 cells are not currently supported"
)
if hexahedral:
if directions is None:
directions = [True, True, True]
if len(directions) != 3:
raise ValueError(f"directions must have exactly dim ( = 3) elements : Got {directions}")
plex = PETSc.DMPlex().createBoxMesh((nx, ny, nz), lower=(0., 0., 0.), upper=(Lx, Ly, Lz), simplex=False, periodic=directions, interpolate=True, sparseLocalize=False, comm=comm)
_firedrake_box_mesh_hexahedral_mark_boundaries(plex, nx, ny, nz, Lx, Ly, Lz)
return mesh.Mesh(
plex,
reorder=reorder,
distribution_parameters=distribution_parameters,
name=name,
distribution_name=distribution_name,
permutation_name=permutation_name,
comm=comm)
else:
if directions is not None:
raise NotImplementedError("Can only specify directions with hexahedral = True")
xcoords = np.arange(0.0, Lx, Lx / nx, dtype=np.double)
ycoords = np.arange(0.0, Ly, Ly / ny, dtype=np.double)
zcoords = np.arange(0.0, Lz, Lz / nz, dtype=np.double)
coords = (
np.asarray(np.meshgrid(xcoords, ycoords, zcoords)).swapaxes(0, 3).reshape(-1, 3)
)
i, j, k = np.meshgrid(
np.arange(nx, dtype=np.int32),
np.arange(ny, dtype=np.int32),
np.arange(nz, dtype=np.int32),
)
v0 = k * nx * ny + j * nx + i
v1 = k * nx * ny + j * nx + (i + 1) % nx
v2 = k * nx * ny + ((j + 1) % ny) * nx + i
v3 = k * nx * ny + ((j + 1) % ny) * nx + (i + 1) % nx
v4 = ((k + 1) % nz) * nx * ny + j * nx + i
v5 = ((k + 1) % nz) * nx * ny + j * nx + (i + 1) % nx
v6 = ((k + 1) % nz) * nx * ny + ((j + 1) % ny) * nx + i
v7 = ((k + 1) % nz) * nx * ny + ((j + 1) % ny) * nx + (i + 1) % nx

xcoords = np.arange(0.0, Lx, Lx / nx, dtype=np.double)
ycoords = np.arange(0.0, Ly, Ly / ny, dtype=np.double)
zcoords = np.arange(0.0, Lz, Lz / nz, dtype=np.double)
coords = (
np.asarray(np.meshgrid(xcoords, ycoords, zcoords)).swapaxes(0, 3).reshape(-1, 3)
)
i, j, k = np.meshgrid(
np.arange(nx, dtype=np.int32),
np.arange(ny, dtype=np.int32),
np.arange(nz, dtype=np.int32),
)
v0 = k * nx * ny + j * nx + i
v1 = k * nx * ny + j * nx + (i + 1) % nx
v2 = k * nx * ny + ((j + 1) % ny) * nx + i
v3 = k * nx * ny + ((j + 1) % ny) * nx + (i + 1) % nx
v4 = ((k + 1) % nz) * nx * ny + j * nx + i
v5 = ((k + 1) % nz) * nx * ny + j * nx + (i + 1) % nx
v6 = ((k + 1) % nz) * nx * ny + ((j + 1) % ny) * nx + i
v7 = ((k + 1) % nz) * nx * ny + ((j + 1) % ny) * nx + (i + 1) % nx

cells = [
[v0, v1, v3, v7],
[v0, v1, v7, v5],
[v0, v5, v7, v4],
[v0, v3, v2, v7],
[v0, v6, v4, v7],
[v0, v2, v6, v7],
]
cells = np.asarray(cells).reshape(-1, ny, nx, nz).swapaxes(0, 3).reshape(-1, 4)
plex = mesh.plex_from_cell_list(
3, cells, coords, comm, mesh._generate_default_mesh_topology_name(name)
)
m = mesh.Mesh(
plex,
reorder=reorder_noop,
distribution_parameters=distribution_parameters_no_overlap,
name=name,
distribution_name=distribution_name,
permutation_name=permutation_name,
comm=comm,
)

old_coordinates = m.coordinates
new_coordinates = Function(
VectorFunctionSpace(
m, FiniteElement("DG", tetrahedron, 1, variant="equispaced")
),
name=mesh._generate_default_mesh_coordinates_name(name),
)
cells = [
[v0, v1, v3, v7],
[v0, v1, v7, v5],
[v0, v5, v7, v4],
[v0, v3, v2, v7],
[v0, v6, v4, v7],
[v0, v2, v6, v7],
]
cells = np.asarray(cells).reshape(-1, ny, nx, nz).swapaxes(0, 3).reshape(-1, 4)
plex = mesh.plex_from_cell_list(
3, cells, coords, comm, mesh._generate_default_mesh_topology_name(name)
)
m = mesh.Mesh(
plex,
reorder=reorder_noop,
distribution_parameters=distribution_parameters_no_overlap,
name=name,
distribution_name=distribution_name,
permutation_name=permutation_name,
comm=comm,
)

domain = ""
instructions = f"""
<{RealType}> x0 = real(old_coords[0, 0])
<{RealType}> x1 = real(old_coords[1, 0])
<{RealType}> x2 = real(old_coords[2, 0])
<{RealType}> x3 = real(old_coords[3, 0])
<{RealType}> x_max = fmax(fmax(fmax(x0, x1), x2), x3)
<{RealType}> y0 = real(old_coords[0, 1])
<{RealType}> y1 = real(old_coords[1, 1])
<{RealType}> y2 = real(old_coords[2, 1])
<{RealType}> y3 = real(old_coords[3, 1])
<{RealType}> y_max = fmax(fmax(fmax(y0, y1), y2), y3)
<{RealType}> z0 = real(old_coords[0, 2])
<{RealType}> z1 = real(old_coords[1, 2])
<{RealType}> z2 = real(old_coords[2, 2])
<{RealType}> z3 = real(old_coords[3, 2])
<{RealType}> z_max = fmax(fmax(fmax(z0, z1), z2), z3)
new_coords[0, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[0, 0] == 0.) else old_coords[0, 0]
new_coords[0, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[0, 1] == 0.) else old_coords[0, 1]
new_coords[0, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[0, 2] == 0.) else old_coords[0, 2]
new_coords[1, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[1, 0] == 0.) else old_coords[1, 0]
new_coords[1, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[1, 1] == 0.) else old_coords[1, 1]
new_coords[1, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[1, 2] == 0.) else old_coords[1, 2]
new_coords[2, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[2, 0] == 0.) else old_coords[2, 0]
new_coords[2, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[2, 1] == 0.) else old_coords[2, 1]
new_coords[2, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[2, 2] == 0.) else old_coords[2, 2]
new_coords[3, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[3, 0] == 0.) else old_coords[3, 0]
new_coords[3, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[3, 1] == 0.) else old_coords[3, 1]
new_coords[3, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[3, 2] == 0.) else old_coords[3, 2]
"""
hx = Constant(Lx / nx)
hy = Constant(Ly / ny)
hz = Constant(Lz / nz)
old_coordinates = m.coordinates
new_coordinates = Function(
VectorFunctionSpace(
m, FiniteElement("DG", tetrahedron, 1, variant="equispaced")
),
name=mesh._generate_default_mesh_coordinates_name(name),
)

par_loop(
(domain, instructions),
dx,
{
"new_coords": (new_coordinates, WRITE),
"old_coords": (old_coordinates, READ),
"hx": (hx, READ),
"hy": (hy, READ),
"hz": (hz, READ),
},
)
return _postprocess_periodic_mesh(new_coordinates,
comm,
distribution_parameters,
reorder,
name,
distribution_name,
permutation_name)
domain = ""
instructions = f"""
<{RealType}> x0 = real(old_coords[0, 0])
<{RealType}> x1 = real(old_coords[1, 0])
<{RealType}> x2 = real(old_coords[2, 0])
<{RealType}> x3 = real(old_coords[3, 0])
<{RealType}> x_max = fmax(fmax(fmax(x0, x1), x2), x3)
<{RealType}> y0 = real(old_coords[0, 1])
<{RealType}> y1 = real(old_coords[1, 1])
<{RealType}> y2 = real(old_coords[2, 1])
<{RealType}> y3 = real(old_coords[3, 1])
<{RealType}> y_max = fmax(fmax(fmax(y0, y1), y2), y3)
<{RealType}> z0 = real(old_coords[0, 2])
<{RealType}> z1 = real(old_coords[1, 2])
<{RealType}> z2 = real(old_coords[2, 2])
<{RealType}> z3 = real(old_coords[3, 2])
<{RealType}> z_max = fmax(fmax(fmax(z0, z1), z2), z3)
new_coords[0, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[0, 0] == 0.) else old_coords[0, 0]
new_coords[0, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[0, 1] == 0.) else old_coords[0, 1]
new_coords[0, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[0, 2] == 0.) else old_coords[0, 2]
new_coords[1, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[1, 0] == 0.) else old_coords[1, 0]
new_coords[1, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[1, 1] == 0.) else old_coords[1, 1]
new_coords[1, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[1, 2] == 0.) else old_coords[1, 2]
new_coords[2, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[2, 0] == 0.) else old_coords[2, 0]
new_coords[2, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[2, 1] == 0.) else old_coords[2, 1]
new_coords[2, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[2, 2] == 0.) else old_coords[2, 2]
new_coords[3, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[3, 0] == 0.) else old_coords[3, 0]
new_coords[3, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[3, 1] == 0.) else old_coords[3, 1]
new_coords[3, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[3, 2] == 0.) else old_coords[3, 2]
"""
hx = Constant(Lx / nx)
hy = Constant(Ly / ny)
hz = Constant(Lz / nz)

par_loop(
(domain, instructions),
dx,
{
"new_coords": (new_coordinates, WRITE),
"old_coords": (old_coordinates, READ),
"hx": (hx, READ),
"hy": (hy, READ),
"hz": (hz, READ),
},
)
return _postprocess_periodic_mesh(new_coordinates,
comm,
distribution_parameters,
reorder,
name,
distribution_name,
permutation_name)


@PETSc.Log.EventDecorator()
def PeriodicUnitCubeMesh(
nx,
ny,
nz,
directions=None,
hexahedral=False,
reorder=None,
distribution_parameters=None,
comm=COMM_WORLD,
Expand Down Expand Up @@ -2014,6 +2036,8 @@ def PeriodicUnitCubeMesh(
1.0,
1.0,
1.0,
directions=directions,
hexahedral=hexahedral,
reorder=reorder,
distribution_parameters=distribution_parameters,
comm=comm,
Expand Down
10 changes: 10 additions & 0 deletions tests/regression/test_mesh_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,16 @@ def test_boxmesh_kind(kind, num_cells):
assert m.num_cells() == num_cells


def test_periodic_unit_cube_hex():
mesh = PeriodicBoxMesh(3, 3, 3, 1., 1., 1., directions=[True, True, False], hexahedral=True)
x, y, z = SpatialCoordinate(mesh)
V = FunctionSpace(mesh, "CG", 3)
expr = (1 - x) * x + (1 - y) * y + z
f = Function(V).interpolate(expr)
error = assemble((f - expr) ** 2 * dx)
assert error < 1.e-30


@pytest.mark.parallel(nprocs=4)
def test_split_comm_dm_mesh():
nspace = 2
Expand Down

0 comments on commit baff755

Please sign in to comment.