Skip to content

Commit

Permalink
enable restricted function space on extruded meshes (#3905)
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam authored Dec 11, 2024
1 parent 7b2b81a commit be411c5
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 54 deletions.
8 changes: 3 additions & 5 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class BCBase(object):
def __init__(self, V, sub_domain):

self._function_space = V
self.sub_domain = sub_domain
self.sub_domain = (sub_domain, ) if isinstance(sub_domain, str) else as_tuple(sub_domain)
# If this BC is defined on a subspace (IndexedFunctionSpace or
# ComponentFunctionSpace, possibly recursively), pull out the appropriate
# indices.
Expand Down Expand Up @@ -289,11 +289,9 @@ def __init__(self, V, g, sub_domain, method=None):
warnings.simplefilter('always', DeprecationWarning)
warnings.warn("Selecting a bcs method is deprecated. Only topological association is supported",
DeprecationWarning)
if len(V.boundary_set):
subs = [sub_domain] if type(sub_domain) in {int, str} else sub_domain
if any(sub not in V.boundary_set for sub in subs):
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
super().__init__(V, sub_domain)
if len(V.boundary_set) and not set(self.sub_domain).issubset(V.boundary_set):
raise ValueError(f"Sub-domain {self.sub_domain} not in the boundary set of the restricted space {V.boundary_set}.")
if len(V) > 1:
raise ValueError("Cannot apply boundary conditions on mixed spaces directly.\n"
"Apply to the components by indexing the space with .sub(...)")
Expand Down
112 changes: 83 additions & 29 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1205,14 +1205,18 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
PETSc.DM dm
PETSc.Section section
PETSc.IS renumbering
PetscInt i, p, layers, pStart, pEnd, dof, j
PetscInt i, p, layers, offset_top, pStart, pEnd, dof, j, k
PetscInt dimension, ndof
PetscInt *dof_array = NULL
const PetscInt *entity_point_map
np.ndarray nodes
np.ndarray layer_extents
np.ndarray points
bint variable, extruded, on_base_
PETSc.SF point_sf
PetscInt nleaves
const PetscInt *ilocal = NULL
PetscInt factor

dm = mesh.topology_dm
if isinstance(dm, PETSc.DMSwarm) and on_base:
Expand All @@ -1221,32 +1225,31 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
extruded = mesh.cell_set._extruded
extruded_periodic = mesh.cell_set._extruded_periodic
on_base_ = on_base
dimension = get_topological_dimension(dm)
nodes_per_entity = np.asarray(nodes_per_entity, dtype=IntType)
if variable:
layer_extents = mesh.layer_extents
nodes = nodes_per_entity.reshape(dimension + 1, -1)
elif extruded:
if on_base:
nodes_per_entity = sum(nodes_per_entity[:, i] for i in range(2))
nodes = sum(nodes_per_entity[:, i] for i in range(2)).reshape(dimension + 1, -1)
else:
if extruded_periodic:
nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - 1) for i in range(2))
nodes = sum(nodes_per_entity[:, i]*(mesh.layers - 1) for i in range(2)).reshape(dimension + 1, -1)
else:
nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2))
nodes = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2)).reshape(dimension + 1, -1)
else:
nodes = nodes_per_entity.reshape(dimension + 1, -1)
section = PETSc.Section().create(comm=mesh._comm)
get_chart(dm.dm, &pStart, &pEnd)
section.setChart(pStart, pEnd)

if boundary_set:
renumbering, (constrainedStart, constrainedEnd) = plex_renumbering(dm,
mesh._entity_classes, reordering=mesh._default_reordering, boundary_set=boundary_set)
if boundary_set and not extruded:
renumbering = plex_renumbering(dm, mesh._entity_classes, reordering=mesh._default_reordering, boundary_set=boundary_set)
else:
renumbering = mesh._dm_renumbering
constrainedStart = -1
constrainedEnd = -1

CHKERR(PetscSectionSetPermutation(section.sec, renumbering.iset))
dimension = get_topological_dimension(dm)
nodes = nodes_per_entity.reshape(dimension + 1, -1)
for i in range(dimension + 1):
get_depth_stratum(dm.dm, i, &pStart, &pEnd) # gets all points at dim i
if not variable:
Expand All @@ -1260,9 +1263,27 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
ndof = layers*nodes[i, 0] + (layers - 1)*nodes[i, 1]
CHKERR(PetscSectionSetDof(section.sec, p, block_size * ndof))

if boundary_set and extruded and variable:
raise NotImplementedError("Not implemented for variable layer extrusion")
if boundary_set:
# Handle "bottom" and "top" first.
if "bottom" in boundary_set and "top" in boundary_set:
factor = 2
elif "bottom" in boundary_set or "top" in boundary_set:
factor = 1
else:
factor = 0
if factor > 0:
for i in range(dimension + 1):
get_depth_stratum(dm.dm, i, &pStart, &pEnd)
dof = nodes_per_entity[i, 0]
for p in range(pStart, pEnd):
CHKERR(PetscSectionSetConstraintDof(section.sec, p, factor * dof))
# Potentially overwrite ds_t and dS_t constrained DoFs set in the {"bottom", "top"} cases.
for marker in boundary_set:
if marker == "on_boundary":
if marker in ["bottom", "top"]:
continue
elif marker == "on_boundary":
label = "exterior_facets"
marker = 1
else:
Expand All @@ -1276,11 +1297,36 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
CHKERR(PetscSectionGetDof(section.sec, p, &dof))
CHKERR(PetscSectionSetConstraintDof(section.sec, p, dof))
section.setUp()

if boundary_set:
# have to loop again as we need to call section.setUp() first
CHKERR(PetscSectionGetMaxDof(section.sec, &dof))
CHKERR(PetscMalloc1(dof, &dof_array))
for i in range(dof):
dof_array[i] = -1
if "bottom" in boundary_set or "top" in boundary_set:
for i in range(dimension + 1):
get_depth_stratum(dm.dm, i, &pStart, &pEnd)
if pEnd == pStart:
continue
dof = nodes_per_entity[i, 0]
j = 0
if "bottom" in boundary_set:
for k in range(dof):
dof_array[j] = k
j += 1
if "top" in boundary_set:
offset_top = (nodes_per_entity[i, 0] + nodes_per_entity[i, 1]) * (mesh.layers - 1)
for k in range(dof):
dof_array[j] = offset_top + k
j += 1
for p in range(pStart, pEnd):
# Potentially set wrong values for ds_t and dS_t constrained DoFs here,
# but we will overwrite them in the below.
CHKERR(PetscSectionSetConstraintIndices(section.sec, p, dof_array))
for marker in boundary_set:
if marker == "on_boundary":
if marker in ["bottom", "top"]:
continue
elif marker == "on_boundary":
label = "exterior_facets"
marker = 1
else:
Expand All @@ -1289,24 +1335,24 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
if n == 0:
continue
points = dm.getStratumIS(label, marker).indices
CHKERR(PetscSectionGetMaxDof(section.sec, &dof))
CHKERR(PetscMalloc1(dof, &dof_array))
for i in range(n):
p = points[i]
CHKERR(PetscSectionGetDof(section.sec, p, &dof))
for j in range(dof):
dof_array[j] = j
CHKERR(PetscSectionSetConstraintIndices(section.sec, p, dof_array))
CHKERR(PetscFree(dof_array))

CHKERR(PetscFree(dof_array))
constrained_nodes = 0

CHKERR(ISGetIndices(renumbering.iset, &entity_point_map))
for entity in range(constrainedStart, constrainedEnd):
CHKERR(PetscSectionGetDof(section.sec, entity_point_map[entity], &dof))
get_chart(dm.dm, &pStart, &pEnd)
point_sf = dm.getPointSF()
CHKERR(PetscSFGetGraph(point_sf.sf, NULL, &nleaves, &ilocal, NULL))
for p in range(pStart, pEnd):
CHKERR(PetscSectionGetConstraintDof(section.sec, p, &dof))
constrained_nodes += dof
CHKERR(ISRestoreIndices(renumbering.iset, &entity_point_map))

for i in range(nleaves):
p = ilocal[i] if ilocal else i
CHKERR(PetscSectionGetConstraintDof(section.sec, p, &dof))
constrained_nodes -= dof
return section, constrained_nodes


Expand Down Expand Up @@ -2460,7 +2506,7 @@ def plex_renumbering(PETSc.DM plex,
perm_is.setType("general")
CHKERR(ISGeneralSetIndices(perm_is.iset, pEnd - pStart,
perm, PETSC_OWN_POINTER))
return perm_is, (lidx[1], lidx[3])
return perm_is

@cython.boundscheck(False)
@cython.wraparound(False)
Expand Down Expand Up @@ -3310,23 +3356,31 @@ def make_global_numbering(PETSc.Section lsec, PETSc.Section gsec):
:arg lsec: Section describing local dof layout and numbers.
:arg gsec: Section describing global dof layout and numbers."""
cdef:
PetscInt c, p, pStart, pEnd, dof, cdof, loff, goff
PetscInt c, cc, p, pStart, pEnd, dof, cdof, loff, goff
np.ndarray val
PetscInt *dof_array = NULL

val = np.empty(lsec.getStorageSize(), dtype=IntType)
pStart, pEnd = lsec.getChart()

for p in range(pStart, pEnd):
CHKERR(PetscSectionGetDof(lsec.sec, p, &dof))
CHKERR(PetscSectionGetConstraintDof(lsec.sec, p, &cdof))
if dof > 0:
CHKERR(PetscSectionGetOffset(lsec.sec, p, &loff))
CHKERR(PetscSectionGetOffset(gsec.sec, p, &goff))
goff = cabs(goff)
if cdof > 0:
CHKERR(PetscSectionGetConstraintIndices(lsec.sec, p, &dof_array))
for c in range(dof):
val[loff + c] = -2
for c in range(cdof):
val[loff + dof_array[c]] = -1
cc = 0
for c in range(dof):
val[loff + c] = -1
if val[loff + c] < -1:
val[loff + c] = goff + cc
cc += 1
else:
goff = cabs(goff)
for c in range(dof):
val[loff + c] = goff + c
return val
Expand Down
1 change: 1 addition & 0 deletions firedrake/cython/petschdr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ cdef extern from "petscis.h" nogil:
int PetscSectionGetConstraintDof(PETSc.PetscSection,PetscInt,PetscInt*)
int PetscSectionSetConstraintDof(PETSc.PetscSection,PetscInt,PetscInt)
int PetscSectionSetConstraintIndices(PETSc.PetscSection,PetscInt, PetscInt[])
int PetscSectionGetConstraintIndices(PETSc.PetscSection,PetscInt, const PetscInt**)
int PetscSectionGetMaxDof(PETSc.PetscSection,PetscInt*)
int PetscSectionSetPermutation(PETSc.PetscSection,PETSc.PetscIS)
int ISGetIndices(PETSc.PetscIS,PetscInt*[])
Expand Down
14 changes: 13 additions & 1 deletion firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import finat.ufl

from pyop2 import op2, mpi
from pyop2.utils import as_tuple

from firedrake import dmhooks, utils
from firedrake.functionspacedata import get_shared_data, create_element
Expand Down Expand Up @@ -876,6 +877,16 @@ class RestrictedFunctionSpace(FunctionSpace):
"""
def __init__(self, function_space, boundary_set=frozenset(), name=None):
label = ""
boundary_set_ = []
for boundary_domain in boundary_set:
if isinstance(boundary_domain, str):
boundary_set_.append(boundary_domain)
else:
# Currently, can not handle intersection of boundaries;
# e.g., boundary_set = [(1, 2)], which is different from [1, 2].
bd, = as_tuple(boundary_domain)
boundary_set_.append(bd)
boundary_set = boundary_set_
for boundary_domain in boundary_set:
label += str(boundary_domain)
label += "_"
Expand All @@ -896,7 +907,8 @@ def set_shared_data(self):
self.node_set = sdata.node_set
r"""A :class:`pyop2.types.set.Set` representing the function space nodes."""
self.dof_dset = op2.DataSet(self.node_set, self.shape or 1,
name="%s_nodes_dset" % self.name)
name="%s_nodes_dset" % self.name,
apply_local_global_filter=sdata.extruded)
r"""A :class:`pyop2.types.dataset.DataSet` representing the function space
degrees of freedom."""

Expand Down
4 changes: 2 additions & 2 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ def _renumber_entities(self, reorder):
else:
# No reordering
reordering = None
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering)[0]
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering)

@utils.cached_property
def cell_closure(self):
Expand Down Expand Up @@ -1979,7 +1979,7 @@ def _renumber_entities(self, reorder):
perm_is.setIndices(perm)
return perm_is
else:
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None)[0]
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None)

@utils.cached_property # TODO: Recalculate if mesh moves
def cell_closure(self):
Expand Down
8 changes: 1 addition & 7 deletions firedrake/slate/static_condensation/hybridization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import numbers

import numpy as np
import ufl
Expand All @@ -11,7 +10,6 @@
from firedrake.petsc import PETSc
from firedrake.parloops import par_loop, READ, INC
from firedrake.slate.slate import Tensor, AssembledVector
from pyop2.utils import as_tuple
from firedrake.slate.static_condensation.la_utils import SchurComplementBuilder
from firedrake.ufl_expr import adjoint

Expand Down Expand Up @@ -153,11 +151,7 @@ def initialize(self, pc):
if bc.function_space().index != self.vidx:
raise NotImplementedError("Dirichlet bc set on unsupported space.")
# append the set of sub domains
subdom = bc.sub_domain
if isinstance(subdom, str):
neumann_subdomains |= set([subdom])
else:
neumann_subdomains |= set(as_tuple(subdom, numbers.Integral))
neumann_subdomains |= set(bc.sub_domain)

# separate out the top and bottom bcs
extruded_neumann_subdomains = neumann_subdomains & {"top", "bottom"}
Expand Down
24 changes: 23 additions & 1 deletion pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,15 +807,37 @@ def _vec(self):
# But use getSizes to save an Allreduce in computing the
# global size.
size = self.dataset.layout_vec.getSizes()
data = self._data[:size[0]]
if self.dataset._apply_local_global_filter:
data = self._data_filtered
else:
data = self._data[:size[0]]
return PETSc.Vec().createWithArray(data, size=size, bsize=self.cdim, comm=self.comm)

@utils.cached_property
def _data_filtered(self):
size, _ = self.dataset.layout_vec.getSizes()
size //= self.dataset.layout_vec.block_size
data = self._data[:size]
return np.empty_like(data)

@utils.cached_property
def _data_filter(self):
lgmap = self.dataset.lgmap
n = self.dataset.size
lgmap_owned = lgmap.block_indices[:n]
return lgmap_owned >= 0

@contextlib.contextmanager
def vec_context(self, access):
r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`.
:param access: Access descriptor: READ, WRITE, or RW."""
size = self.dataset.size
if self.dataset._apply_local_global_filter and access is not Access.WRITE:
self._data_filtered[:] = self._data[:size][self._data_filter]
yield self._vec
if self.dataset._apply_local_global_filter and access is not Access.READ:
self._data[:size][self._data_filter] = self._data_filtered[:]
if access is not Access.READ:
self.halo_valid = False

Expand Down
Loading

0 comments on commit be411c5

Please sign in to comment.