Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ksagiyam/restricted extrusion #3905

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class BCBase(object):
def __init__(self, V, sub_domain):

self._function_space = V
self.sub_domain = sub_domain
self.sub_domain = (sub_domain, ) if isinstance(sub_domain, str) else as_tuple(sub_domain)
# If this BC is defined on a subspace (IndexedFunctionSpace or
# ComponentFunctionSpace, possibly recursively), pull out the appropriate
# indices.
Expand Down Expand Up @@ -289,11 +289,9 @@ def __init__(self, V, g, sub_domain, method=None):
warnings.simplefilter('always', DeprecationWarning)
warnings.warn("Selecting a bcs method is deprecated. Only topological association is supported",
DeprecationWarning)
if len(V.boundary_set):
subs = [sub_domain] if type(sub_domain) in {int, str} else sub_domain
if any(sub not in V.boundary_set for sub in subs):
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
super().__init__(V, sub_domain)
if len(V.boundary_set) and not set(self.sub_domain).issubset(V.boundary_set):
raise ValueError(f"Sub-domain {self.sub_domain} not in the boundary set of the restricted space {V.boundary_set}.")
if len(V) > 1:
raise ValueError("Cannot apply boundary conditions on mixed spaces directly.\n"
"Apply to the components by indexing the space with .sub(...)")
Expand Down
112 changes: 83 additions & 29 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1205,14 +1205,18 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
PETSc.DM dm
PETSc.Section section
PETSc.IS renumbering
PetscInt i, p, layers, pStart, pEnd, dof, j
PetscInt i, p, layers, offset_top, pStart, pEnd, dof, j, k
PetscInt dimension, ndof
PetscInt *dof_array = NULL
const PetscInt *entity_point_map
np.ndarray nodes
np.ndarray layer_extents
np.ndarray points
bint variable, extruded, on_base_
PETSc.SF point_sf
PetscInt nleaves
const PetscInt *ilocal = NULL
PetscInt factor

dm = mesh.topology_dm
if isinstance(dm, PETSc.DMSwarm) and on_base:
Expand All @@ -1221,32 +1225,31 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
extruded = mesh.cell_set._extruded
extruded_periodic = mesh.cell_set._extruded_periodic
on_base_ = on_base
dimension = get_topological_dimension(dm)
nodes_per_entity = np.asarray(nodes_per_entity, dtype=IntType)
if variable:
layer_extents = mesh.layer_extents
nodes = nodes_per_entity.reshape(dimension + 1, -1)
elif extruded:
if on_base:
nodes_per_entity = sum(nodes_per_entity[:, i] for i in range(2))
nodes = sum(nodes_per_entity[:, i] for i in range(2)).reshape(dimension + 1, -1)
else:
if extruded_periodic:
nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - 1) for i in range(2))
nodes = sum(nodes_per_entity[:, i]*(mesh.layers - 1) for i in range(2)).reshape(dimension + 1, -1)
else:
nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2))
nodes = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2)).reshape(dimension + 1, -1)
else:
nodes = nodes_per_entity.reshape(dimension + 1, -1)
section = PETSc.Section().create(comm=mesh._comm)
get_chart(dm.dm, &pStart, &pEnd)
section.setChart(pStart, pEnd)

if boundary_set:
renumbering, (constrainedStart, constrainedEnd) = plex_renumbering(dm,
mesh._entity_classes, reordering=mesh._default_reordering, boundary_set=boundary_set)
if boundary_set and not extruded:
renumbering = plex_renumbering(dm, mesh._entity_classes, reordering=mesh._default_reordering, boundary_set=boundary_set)
else:
renumbering = mesh._dm_renumbering
constrainedStart = -1
constrainedEnd = -1

CHKERR(PetscSectionSetPermutation(section.sec, renumbering.iset))
dimension = get_topological_dimension(dm)
nodes = nodes_per_entity.reshape(dimension + 1, -1)
for i in range(dimension + 1):
get_depth_stratum(dm.dm, i, &pStart, &pEnd) # gets all points at dim i
if not variable:
Expand All @@ -1260,9 +1263,27 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
ndof = layers*nodes[i, 0] + (layers - 1)*nodes[i, 1]
CHKERR(PetscSectionSetDof(section.sec, p, block_size * ndof))

if boundary_set and extruded and variable:
raise NotImplementedError("Not implemented for variable layer extrusion")
if boundary_set:
# Handle "bottom" and "top" first.
if "bottom" in boundary_set and "top" in boundary_set:
factor = 2
elif "bottom" in boundary_set or "top" in boundary_set:
factor = 1
else:
factor = 0
if factor > 0:
for i in range(dimension + 1):
get_depth_stratum(dm.dm, i, &pStart, &pEnd)
dof = nodes_per_entity[i, 0]
for p in range(pStart, pEnd):
CHKERR(PetscSectionSetConstraintDof(section.sec, p, factor * dof))
# Potentially overwrite ds_t and dS_t constrained DoFs set in the {"bottom", "top"} cases.
for marker in boundary_set:
if marker == "on_boundary":
if marker in ["bottom", "top"]:
continue
elif marker == "on_boundary":
label = "exterior_facets"
marker = 1
else:
Expand All @@ -1276,11 +1297,36 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
CHKERR(PetscSectionGetDof(section.sec, p, &dof))
CHKERR(PetscSectionSetConstraintDof(section.sec, p, dof))
section.setUp()

if boundary_set:
# have to loop again as we need to call section.setUp() first
CHKERR(PetscSectionGetMaxDof(section.sec, &dof))
CHKERR(PetscMalloc1(dof, &dof_array))
for i in range(dof):
dof_array[i] = -1
if "bottom" in boundary_set or "top" in boundary_set:
for i in range(dimension + 1):
get_depth_stratum(dm.dm, i, &pStart, &pEnd)
if pEnd == pStart:
continue
dof = nodes_per_entity[i, 0]
j = 0
if "bottom" in boundary_set:
for k in range(dof):
dof_array[j] = k
j += 1
if "top" in boundary_set:
offset_top = (nodes_per_entity[i, 0] + nodes_per_entity[i, 1]) * (mesh.layers - 1)
for k in range(dof):
dof_array[j] = offset_top + k
j += 1
for p in range(pStart, pEnd):
# Potentially set wrong values for ds_t and dS_t constrained DoFs here,
# but we will overwrite them in the below.
CHKERR(PetscSectionSetConstraintIndices(section.sec, p, dof_array))
for marker in boundary_set:
if marker == "on_boundary":
if marker in ["bottom", "top"]:
continue
elif marker == "on_boundary":
label = "exterior_facets"
marker = 1
else:
Expand All @@ -1289,24 +1335,24 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
if n == 0:
continue
points = dm.getStratumIS(label, marker).indices
CHKERR(PetscSectionGetMaxDof(section.sec, &dof))
CHKERR(PetscMalloc1(dof, &dof_array))
for i in range(n):
p = points[i]
CHKERR(PetscSectionGetDof(section.sec, p, &dof))
for j in range(dof):
dof_array[j] = j
CHKERR(PetscSectionSetConstraintIndices(section.sec, p, dof_array))
CHKERR(PetscFree(dof_array))

CHKERR(PetscFree(dof_array))
constrained_nodes = 0

CHKERR(ISGetIndices(renumbering.iset, &entity_point_map))
for entity in range(constrainedStart, constrainedEnd):
CHKERR(PetscSectionGetDof(section.sec, entity_point_map[entity], &dof))
get_chart(dm.dm, &pStart, &pEnd)
point_sf = dm.getPointSF()
CHKERR(PetscSFGetGraph(point_sf.sf, NULL, &nleaves, &ilocal, NULL))
for p in range(pStart, pEnd):
CHKERR(PetscSectionGetConstraintDof(section.sec, p, &dof))
constrained_nodes += dof
CHKERR(ISRestoreIndices(renumbering.iset, &entity_point_map))

for i in range(nleaves):
p = ilocal[i] if ilocal else i
CHKERR(PetscSectionGetConstraintDof(section.sec, p, &dof))
constrained_nodes -= dof
return section, constrained_nodes


Expand Down Expand Up @@ -2460,7 +2506,7 @@ def plex_renumbering(PETSc.DM plex,
perm_is.setType("general")
CHKERR(ISGeneralSetIndices(perm_is.iset, pEnd - pStart,
perm, PETSC_OWN_POINTER))
return perm_is, (lidx[1], lidx[3])
return perm_is

@cython.boundscheck(False)
@cython.wraparound(False)
Expand Down Expand Up @@ -3310,23 +3356,31 @@ def make_global_numbering(PETSc.Section lsec, PETSc.Section gsec):
:arg lsec: Section describing local dof layout and numbers.
:arg gsec: Section describing global dof layout and numbers."""
cdef:
PetscInt c, p, pStart, pEnd, dof, cdof, loff, goff
PetscInt c, cc, p, pStart, pEnd, dof, cdof, loff, goff
np.ndarray val
PetscInt *dof_array = NULL

val = np.empty(lsec.getStorageSize(), dtype=IntType)
pStart, pEnd = lsec.getChart()

for p in range(pStart, pEnd):
CHKERR(PetscSectionGetDof(lsec.sec, p, &dof))
CHKERR(PetscSectionGetConstraintDof(lsec.sec, p, &cdof))
if dof > 0:
CHKERR(PetscSectionGetOffset(lsec.sec, p, &loff))
CHKERR(PetscSectionGetOffset(gsec.sec, p, &goff))
goff = cabs(goff)
if cdof > 0:
CHKERR(PetscSectionGetConstraintIndices(lsec.sec, p, &dof_array))
for c in range(dof):
val[loff + c] = -2
for c in range(cdof):
val[loff + dof_array[c]] = -1
cc = 0
for c in range(dof):
val[loff + c] = -1
if val[loff + c] < -1:
val[loff + c] = goff + cc
cc += 1
else:
goff = cabs(goff)
for c in range(dof):
val[loff + c] = goff + c
return val
Expand Down
1 change: 1 addition & 0 deletions firedrake/cython/petschdr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ cdef extern from "petscis.h" nogil:
int PetscSectionGetConstraintDof(PETSc.PetscSection,PetscInt,PetscInt*)
int PetscSectionSetConstraintDof(PETSc.PetscSection,PetscInt,PetscInt)
int PetscSectionSetConstraintIndices(PETSc.PetscSection,PetscInt, PetscInt[])
int PetscSectionGetConstraintIndices(PETSc.PetscSection,PetscInt, const PetscInt**)
int PetscSectionGetMaxDof(PETSc.PetscSection,PetscInt*)
int PetscSectionSetPermutation(PETSc.PetscSection,PETSc.PetscIS)
int ISGetIndices(PETSc.PetscIS,PetscInt*[])
Expand Down
14 changes: 13 additions & 1 deletion firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import finat.ufl

from pyop2 import op2, mpi
from pyop2.utils import as_tuple

from firedrake import dmhooks, utils
from firedrake.functionspacedata import get_shared_data, create_element
Expand Down Expand Up @@ -876,6 +877,16 @@ class RestrictedFunctionSpace(FunctionSpace):
"""
def __init__(self, function_space, boundary_set=frozenset(), name=None):
label = ""
boundary_set_ = []
for boundary_domain in boundary_set:
if isinstance(boundary_domain, str):
boundary_set_.append(boundary_domain)
else:
# Currently, can not handle intersection of boundaries;
# e.g., boundary_set = [(1, 2)], which is different from [1, 2].
bd, = as_tuple(boundary_domain)
boundary_set_.append(bd)
boundary_set = boundary_set_
for boundary_domain in boundary_set:
label += str(boundary_domain)
label += "_"
Expand All @@ -896,7 +907,8 @@ def set_shared_data(self):
self.node_set = sdata.node_set
r"""A :class:`pyop2.types.set.Set` representing the function space nodes."""
self.dof_dset = op2.DataSet(self.node_set, self.shape or 1,
name="%s_nodes_dset" % self.name)
name="%s_nodes_dset" % self.name,
apply_local_global_filter=sdata.extruded)
r"""A :class:`pyop2.types.dataset.DataSet` representing the function space
degrees of freedom."""

Expand Down
4 changes: 2 additions & 2 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ def _renumber_entities(self, reorder):
else:
# No reordering
reordering = None
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering)[0]
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering)

@utils.cached_property
def cell_closure(self):
Expand Down Expand Up @@ -1979,7 +1979,7 @@ def _renumber_entities(self, reorder):
perm_is.setIndices(perm)
return perm_is
else:
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None)[0]
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None)

@utils.cached_property # TODO: Recalculate if mesh moves
def cell_closure(self):
Expand Down
8 changes: 1 addition & 7 deletions firedrake/slate/static_condensation/hybridization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import numbers

import numpy as np
import ufl
Expand All @@ -11,7 +10,6 @@
from firedrake.petsc import PETSc
from firedrake.parloops import par_loop, READ, INC
from firedrake.slate.slate import Tensor, AssembledVector
from pyop2.utils import as_tuple
from firedrake.slate.static_condensation.la_utils import SchurComplementBuilder
from firedrake.ufl_expr import adjoint

Expand Down Expand Up @@ -153,11 +151,7 @@ def initialize(self, pc):
if bc.function_space().index != self.vidx:
raise NotImplementedError("Dirichlet bc set on unsupported space.")
# append the set of sub domains
subdom = bc.sub_domain
if isinstance(subdom, str):
neumann_subdomains |= set([subdom])
else:
neumann_subdomains |= set(as_tuple(subdom, numbers.Integral))
neumann_subdomains |= set(bc.sub_domain)

# separate out the top and bottom bcs
extruded_neumann_subdomains = neumann_subdomains & {"top", "bottom"}
Expand Down
24 changes: 23 additions & 1 deletion pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,15 +807,37 @@ def _vec(self):
# But use getSizes to save an Allreduce in computing the
# global size.
size = self.dataset.layout_vec.getSizes()
data = self._data[:size[0]]
if self.dataset._apply_local_global_filter:
data = self._data_filtered
else:
data = self._data[:size[0]]
return PETSc.Vec().createWithArray(data, size=size, bsize=self.cdim, comm=self.comm)

@utils.cached_property
def _data_filtered(self):
size, _ = self.dataset.layout_vec.getSizes()
size //= self.dataset.layout_vec.block_size
data = self._data[:size]
return np.empty_like(data)

@utils.cached_property
def _data_filter(self):
lgmap = self.dataset.lgmap
n = self.dataset.size
lgmap_owned = lgmap.block_indices[:n]
return lgmap_owned >= 0

@contextlib.contextmanager
def vec_context(self, access):
r"""A context manager for a :class:`PETSc.Vec` from a :class:`Dat`.

:param access: Access descriptor: READ, WRITE, or RW."""
size = self.dataset.size
if self.dataset._apply_local_global_filter and access is not Access.WRITE:
self._data_filtered[:] = self._data[:size][self._data_filter]
yield self._vec
if self.dataset._apply_local_global_filter and access is not Access.READ:
self._data[:size][self._data_filter] = self._data_filtered[:]
if access is not Access.READ:
self.halo_valid = False

Expand Down
Loading
Loading