From baff75558d16d9e7f916c90797234d2f758fb6d0 Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Fri, 5 Jul 2024 17:08:16 +0100 Subject: [PATCH] utility_meshes: enable PeriodicBoxMesh(..., hexahedral=True) --- firedrake/utility_meshes.py | 300 ++++++++++++----------- tests/regression/test_mesh_generation.py | 10 + 2 files changed, 172 insertions(+), 138 deletions(-) diff --git a/firedrake/utility_meshes.py b/firedrake/utility_meshes.py index b0b3b94490..4599c53d99 100644 --- a/firedrake/utility_meshes.py +++ b/firedrake/utility_meshes.py @@ -1609,6 +1609,38 @@ def TensorBoxMesh( return m +def _firedrake_box_mesh_hexahedral_mark_boundaries(plex, nx, ny, nz, Lx, Ly, Lz): + plex.removeLabel(dmcommon.FACE_SETS_LABEL) + nvert = 4 # num. vertices on faect + # Apply boundary IDs + plex.createLabel(dmcommon.FACE_SETS_LABEL) + plex.markBoundaryFaces("boundary_faces") + coords = plex.getCoordinates() + coord_sec = plex.getCoordinateSection() + cdim = plex.getCoordinateDim() + assert cdim == 3 + if plex.getStratumSize("boundary_faces", 1) > 0: + boundary_faces = plex.getStratumIS("boundary_faces", 1).getIndices() + xtol = Lx / (2 * nx) + ytol = Ly / (2 * ny) + ztol = Lz / (2 * nz) + for face in boundary_faces: + face_coords = plex.vecGetClosure(coord_sec, coords, face) + if all([abs(face_coords[0 + cdim * i]) < xtol for i in range(nvert)]): + plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 1) + if all([abs(face_coords[0 + cdim * i] - Lx) < xtol for i in range(nvert)]): + plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 2) + if all([abs(face_coords[1 + cdim * i]) < ytol for i in range(nvert)]): + plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 3) + if all([abs(face_coords[1 + cdim * i] - Ly) < ytol for i in range(nvert)]): + plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 4) + if all([abs(face_coords[2 + cdim * i]) < ztol for i in range(nvert)]): + plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 5) + if all([abs(face_coords[2 + cdim * i] - Lz) < ztol for i in range(nvert)]): + plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 6) + plex.removeLabel("boundary_faces") + + @PETSc.Log.EventDecorator() def BoxMesh( nx, @@ -1657,46 +1689,15 @@ def BoxMesh( raise ValueError("Number of cells must be a postive integer") if hexahedral: plex = PETSc.DMPlex().createBoxMesh((nx, ny, nz), lower=(0., 0., 0.), upper=(Lx, Ly, Lz), simplex=False, periodic=False, interpolate=True, comm=comm) - plex.removeLabel(dmcommon.FACE_SETS_LABEL) - nvert = 4 # num. vertices on faect - - # Apply boundary IDs - plex.createLabel(dmcommon.FACE_SETS_LABEL) - plex.markBoundaryFaces("boundary_faces") - coords = plex.getCoordinates() - coord_sec = plex.getCoordinateSection() - cdim = plex.getCoordinateDim() - assert cdim == 3 - if plex.getStratumSize("boundary_faces", 1) > 0: - boundary_faces = plex.getStratumIS("boundary_faces", 1).getIndices() - xtol = Lx / (2 * nx) - ytol = Ly / (2 * ny) - ztol = Lz / (2 * nz) - for face in boundary_faces: - face_coords = plex.vecGetClosure(coord_sec, coords, face) - if all([abs(face_coords[0 + cdim * i]) < xtol for i in range(nvert)]): - plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 1) - if all([abs(face_coords[0 + cdim * i] - Lx) < xtol for i in range(nvert)]): - plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 2) - if all([abs(face_coords[1 + cdim * i]) < ytol for i in range(nvert)]): - plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 3) - if all([abs(face_coords[1 + cdim * i] - Ly) < ytol for i in range(nvert)]): - plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 4) - if all([abs(face_coords[2 + cdim * i]) < ztol for i in range(nvert)]): - plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 5) - if all([abs(face_coords[2 + cdim * i] - Lz) < ztol for i in range(nvert)]): - plex.setLabelValue(dmcommon.FACE_SETS_LABEL, face, 6) - plex.removeLabel("boundary_faces") - m = mesh.Mesh( + _firedrake_box_mesh_hexahedral_mark_boundaries(plex, nx, ny, nz, Lx, Ly, Lz) + return mesh.Mesh( plex, reorder=reorder, distribution_parameters=distribution_parameters, name=name, distribution_name=distribution_name, permutation_name=permutation_name, - comm=comm, - ) - return m + comm=comm) else: xcoords = np.linspace(0, Lx, nx + 1, dtype=np.double) ycoords = np.linspace(0, Ly, ny + 1, dtype=np.double) @@ -1837,6 +1838,8 @@ def PeriodicBoxMesh( Lx, Ly, Lz, + directions=None, + hexahedral=False, reorder=None, distribution_parameters=None, comm=COMM_WORLD, @@ -1869,113 +1872,130 @@ def PeriodicBoxMesh( raise ValueError( "3D periodic meshes with fewer than 3 cells are not currently supported" ) + if hexahedral: + if directions is None: + directions = [True, True, True] + if len(directions) != 3: + raise ValueError(f"directions must have exactly dim ( = 3) elements : Got {directions}") + plex = PETSc.DMPlex().createBoxMesh((nx, ny, nz), lower=(0., 0., 0.), upper=(Lx, Ly, Lz), simplex=False, periodic=directions, interpolate=True, sparseLocalize=False, comm=comm) + _firedrake_box_mesh_hexahedral_mark_boundaries(plex, nx, ny, nz, Lx, Ly, Lz) + return mesh.Mesh( + plex, + reorder=reorder, + distribution_parameters=distribution_parameters, + name=name, + distribution_name=distribution_name, + permutation_name=permutation_name, + comm=comm) + else: + if directions is not None: + raise NotImplementedError("Can only specify directions with hexahedral = True") + xcoords = np.arange(0.0, Lx, Lx / nx, dtype=np.double) + ycoords = np.arange(0.0, Ly, Ly / ny, dtype=np.double) + zcoords = np.arange(0.0, Lz, Lz / nz, dtype=np.double) + coords = ( + np.asarray(np.meshgrid(xcoords, ycoords, zcoords)).swapaxes(0, 3).reshape(-1, 3) + ) + i, j, k = np.meshgrid( + np.arange(nx, dtype=np.int32), + np.arange(ny, dtype=np.int32), + np.arange(nz, dtype=np.int32), + ) + v0 = k * nx * ny + j * nx + i + v1 = k * nx * ny + j * nx + (i + 1) % nx + v2 = k * nx * ny + ((j + 1) % ny) * nx + i + v3 = k * nx * ny + ((j + 1) % ny) * nx + (i + 1) % nx + v4 = ((k + 1) % nz) * nx * ny + j * nx + i + v5 = ((k + 1) % nz) * nx * ny + j * nx + (i + 1) % nx + v6 = ((k + 1) % nz) * nx * ny + ((j + 1) % ny) * nx + i + v7 = ((k + 1) % nz) * nx * ny + ((j + 1) % ny) * nx + (i + 1) % nx - xcoords = np.arange(0.0, Lx, Lx / nx, dtype=np.double) - ycoords = np.arange(0.0, Ly, Ly / ny, dtype=np.double) - zcoords = np.arange(0.0, Lz, Lz / nz, dtype=np.double) - coords = ( - np.asarray(np.meshgrid(xcoords, ycoords, zcoords)).swapaxes(0, 3).reshape(-1, 3) - ) - i, j, k = np.meshgrid( - np.arange(nx, dtype=np.int32), - np.arange(ny, dtype=np.int32), - np.arange(nz, dtype=np.int32), - ) - v0 = k * nx * ny + j * nx + i - v1 = k * nx * ny + j * nx + (i + 1) % nx - v2 = k * nx * ny + ((j + 1) % ny) * nx + i - v3 = k * nx * ny + ((j + 1) % ny) * nx + (i + 1) % nx - v4 = ((k + 1) % nz) * nx * ny + j * nx + i - v5 = ((k + 1) % nz) * nx * ny + j * nx + (i + 1) % nx - v6 = ((k + 1) % nz) * nx * ny + ((j + 1) % ny) * nx + i - v7 = ((k + 1) % nz) * nx * ny + ((j + 1) % ny) * nx + (i + 1) % nx - - cells = [ - [v0, v1, v3, v7], - [v0, v1, v7, v5], - [v0, v5, v7, v4], - [v0, v3, v2, v7], - [v0, v6, v4, v7], - [v0, v2, v6, v7], - ] - cells = np.asarray(cells).reshape(-1, ny, nx, nz).swapaxes(0, 3).reshape(-1, 4) - plex = mesh.plex_from_cell_list( - 3, cells, coords, comm, mesh._generate_default_mesh_topology_name(name) - ) - m = mesh.Mesh( - plex, - reorder=reorder_noop, - distribution_parameters=distribution_parameters_no_overlap, - name=name, - distribution_name=distribution_name, - permutation_name=permutation_name, - comm=comm, - ) - - old_coordinates = m.coordinates - new_coordinates = Function( - VectorFunctionSpace( - m, FiniteElement("DG", tetrahedron, 1, variant="equispaced") - ), - name=mesh._generate_default_mesh_coordinates_name(name), - ) + cells = [ + [v0, v1, v3, v7], + [v0, v1, v7, v5], + [v0, v5, v7, v4], + [v0, v3, v2, v7], + [v0, v6, v4, v7], + [v0, v2, v6, v7], + ] + cells = np.asarray(cells).reshape(-1, ny, nx, nz).swapaxes(0, 3).reshape(-1, 4) + plex = mesh.plex_from_cell_list( + 3, cells, coords, comm, mesh._generate_default_mesh_topology_name(name) + ) + m = mesh.Mesh( + plex, + reorder=reorder_noop, + distribution_parameters=distribution_parameters_no_overlap, + name=name, + distribution_name=distribution_name, + permutation_name=permutation_name, + comm=comm, + ) - domain = "" - instructions = f""" - <{RealType}> x0 = real(old_coords[0, 0]) - <{RealType}> x1 = real(old_coords[1, 0]) - <{RealType}> x2 = real(old_coords[2, 0]) - <{RealType}> x3 = real(old_coords[3, 0]) - <{RealType}> x_max = fmax(fmax(fmax(x0, x1), x2), x3) - <{RealType}> y0 = real(old_coords[0, 1]) - <{RealType}> y1 = real(old_coords[1, 1]) - <{RealType}> y2 = real(old_coords[2, 1]) - <{RealType}> y3 = real(old_coords[3, 1]) - <{RealType}> y_max = fmax(fmax(fmax(y0, y1), y2), y3) - <{RealType}> z0 = real(old_coords[0, 2]) - <{RealType}> z1 = real(old_coords[1, 2]) - <{RealType}> z2 = real(old_coords[2, 2]) - <{RealType}> z3 = real(old_coords[3, 2]) - <{RealType}> z_max = fmax(fmax(fmax(z0, z1), z2), z3) - - new_coords[0, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[0, 0] == 0.) else old_coords[0, 0] - new_coords[0, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[0, 1] == 0.) else old_coords[0, 1] - new_coords[0, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[0, 2] == 0.) else old_coords[0, 2] - - new_coords[1, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[1, 0] == 0.) else old_coords[1, 0] - new_coords[1, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[1, 1] == 0.) else old_coords[1, 1] - new_coords[1, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[1, 2] == 0.) else old_coords[1, 2] - - new_coords[2, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[2, 0] == 0.) else old_coords[2, 0] - new_coords[2, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[2, 1] == 0.) else old_coords[2, 1] - new_coords[2, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[2, 2] == 0.) else old_coords[2, 2] - - new_coords[3, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[3, 0] == 0.) else old_coords[3, 0] - new_coords[3, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[3, 1] == 0.) else old_coords[3, 1] - new_coords[3, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[3, 2] == 0.) else old_coords[3, 2] - """ - hx = Constant(Lx / nx) - hy = Constant(Ly / ny) - hz = Constant(Lz / nz) + old_coordinates = m.coordinates + new_coordinates = Function( + VectorFunctionSpace( + m, FiniteElement("DG", tetrahedron, 1, variant="equispaced") + ), + name=mesh._generate_default_mesh_coordinates_name(name), + ) - par_loop( - (domain, instructions), - dx, - { - "new_coords": (new_coordinates, WRITE), - "old_coords": (old_coordinates, READ), - "hx": (hx, READ), - "hy": (hy, READ), - "hz": (hz, READ), - }, - ) - return _postprocess_periodic_mesh(new_coordinates, - comm, - distribution_parameters, - reorder, - name, - distribution_name, - permutation_name) + domain = "" + instructions = f""" + <{RealType}> x0 = real(old_coords[0, 0]) + <{RealType}> x1 = real(old_coords[1, 0]) + <{RealType}> x2 = real(old_coords[2, 0]) + <{RealType}> x3 = real(old_coords[3, 0]) + <{RealType}> x_max = fmax(fmax(fmax(x0, x1), x2), x3) + <{RealType}> y0 = real(old_coords[0, 1]) + <{RealType}> y1 = real(old_coords[1, 1]) + <{RealType}> y2 = real(old_coords[2, 1]) + <{RealType}> y3 = real(old_coords[3, 1]) + <{RealType}> y_max = fmax(fmax(fmax(y0, y1), y2), y3) + <{RealType}> z0 = real(old_coords[0, 2]) + <{RealType}> z1 = real(old_coords[1, 2]) + <{RealType}> z2 = real(old_coords[2, 2]) + <{RealType}> z3 = real(old_coords[3, 2]) + <{RealType}> z_max = fmax(fmax(fmax(z0, z1), z2), z3) + + new_coords[0, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[0, 0] == 0.) else old_coords[0, 0] + new_coords[0, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[0, 1] == 0.) else old_coords[0, 1] + new_coords[0, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[0, 2] == 0.) else old_coords[0, 2] + + new_coords[1, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[1, 0] == 0.) else old_coords[1, 0] + new_coords[1, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[1, 1] == 0.) else old_coords[1, 1] + new_coords[1, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[1, 2] == 0.) else old_coords[1, 2] + + new_coords[2, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[2, 0] == 0.) else old_coords[2, 0] + new_coords[2, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[2, 1] == 0.) else old_coords[2, 1] + new_coords[2, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[2, 2] == 0.) else old_coords[2, 2] + + new_coords[3, 0] = x_max+hx[0] if (x_max > real(1.5*hx[0]) and old_coords[3, 0] == 0.) else old_coords[3, 0] + new_coords[3, 1] = y_max+hy[0] if (y_max > real(1.5*hy[0]) and old_coords[3, 1] == 0.) else old_coords[3, 1] + new_coords[3, 2] = z_max+hz[0] if (z_max > real(1.5*hz[0]) and old_coords[3, 2] == 0.) else old_coords[3, 2] + """ + hx = Constant(Lx / nx) + hy = Constant(Ly / ny) + hz = Constant(Lz / nz) + + par_loop( + (domain, instructions), + dx, + { + "new_coords": (new_coordinates, WRITE), + "old_coords": (old_coordinates, READ), + "hx": (hx, READ), + "hy": (hy, READ), + "hz": (hz, READ), + }, + ) + return _postprocess_periodic_mesh(new_coordinates, + comm, + distribution_parameters, + reorder, + name, + distribution_name, + permutation_name) @PETSc.Log.EventDecorator() @@ -1983,6 +2003,8 @@ def PeriodicUnitCubeMesh( nx, ny, nz, + directions=None, + hexahedral=False, reorder=None, distribution_parameters=None, comm=COMM_WORLD, @@ -2014,6 +2036,8 @@ def PeriodicUnitCubeMesh( 1.0, 1.0, 1.0, + directions=directions, + hexahedral=hexahedral, reorder=reorder, distribution_parameters=distribution_parameters, comm=comm, diff --git a/tests/regression/test_mesh_generation.py b/tests/regression/test_mesh_generation.py index 637b0a1502..ffe29dbd1d 100644 --- a/tests/regression/test_mesh_generation.py +++ b/tests/regression/test_mesh_generation.py @@ -476,6 +476,16 @@ def test_boxmesh_kind(kind, num_cells): assert m.num_cells() == num_cells +def test_periodic_unit_cube_hex(): + mesh = PeriodicBoxMesh(3, 3, 3, 1., 1., 1., directions=[True, True, False], hexahedral=True) + x, y, z = SpatialCoordinate(mesh) + V = FunctionSpace(mesh, "CG", 3) + expr = (1 - x) * x + (1 - y) * y + z + f = Function(V).interpolate(expr) + error = assemble((f - expr) ** 2 * dx) + assert error < 1.e-30 + + @pytest.mark.parallel(nprocs=4) def test_split_comm_dm_mesh(): nspace = 2