Skip to content

Commit

Permalink
MixedFunctionSpace: interpolate and restrict (#3868)
Browse files Browse the repository at this point in the history
* Fix restrict=True for MixedFunctionSpace

* Interpolation of mixed UFL expressions
  • Loading branch information
pbrubeck authored Dec 6, 2024
1 parent fdfc908 commit 3285cfc
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 51 deletions.
69 changes: 66 additions & 3 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,10 @@ def __init__(self, V, g, sub_domain, method=None):
warnings.simplefilter('always', DeprecationWarning)
warnings.warn("Selecting a bcs method is deprecated. Only topological association is supported",
DeprecationWarning)
if len(V.boundary_set) and sub_domain not in V.boundary_set:
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
if len(V.boundary_set):
subs = [sub_domain] if type(sub_domain) in {int, str} else sub_domain
if any(sub not in V.boundary_set for sub in subs):
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
super().__init__(V, sub_domain)
if len(V) > 1:
raise ValueError("Cannot apply boundary conditions on mixed spaces directly.\n"
Expand All @@ -311,10 +313,12 @@ def function_arg(self):
return self._function_arg

@PETSc.Log.EventDecorator()
def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=False):
def reconstruct(self, field=None, V=None, g=None, sub_domain=None, use_split=False, indices=()):
fs = self.function_space()
if V is None:
V = fs
for index in indices:
V = V.sub(index)
if g is None:
g = self._original_arg
if sub_domain is None:
Expand Down Expand Up @@ -686,3 +690,62 @@ def homogenize(bc):
return DirichletBC(bc.function_space(), 0, bc.sub_domain)
else:
raise TypeError("homogenize only takes a DirichletBC or a list/tuple of DirichletBCs")


def extract_subdomain_ids(bcs):
"""Return a tuple of subdomain ids for each component of a MixedFunctionSpace.
Parameters
----------
bcs :
A list of boundary conditions.
Returns
-------
A tuple of subdomain ids for each component of a MixedFunctionSpace.
"""
if isinstance(bcs, DirichletBC):
bcs = (bcs,)
if len(bcs) == 0:
return None

V = bcs[0].function_space()
while V.parent:
V = V.parent

_chain = itertools.chain.from_iterable
_to_tuple = lambda s: (s,) if isinstance(s, (int, str)) else s
subdomain_ids = tuple(tuple(_chain(_to_tuple(bc.sub_domain)
for bc in bcs if bc.function_space() == Vsub))
for Vsub in V)
return subdomain_ids


def restricted_function_space(V, ids):
"""Create a :class:`.RestrictedFunctionSpace` from a tuple of subdomain ids.
Parameters
----------
V :
FunctionSpace object to restrict
ids :
A tuple of subdomain ids.
Returns
-------
The RestrictedFunctionSpace.
"""
if not ids:
return V

assert len(ids) == len(V)
spaces = [Vsub if len(boundary_set) == 0 else
firedrake.RestrictedFunctionSpace(Vsub, boundary_set=boundary_set)
for Vsub, boundary_set in zip(V, ids)]

if len(spaces) == 1:
return spaces[0]
else:
return firedrake.MixedFunctionSpace(spaces)
7 changes: 3 additions & 4 deletions firedrake/eigensolver.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
"""Specify and solve finite element eigenproblems."""
from firedrake.assemble import assemble
from firedrake.bcs import DirichletBC
from firedrake.bcs import extract_subdomain_ids, restricted_function_space
from firedrake.function import Function
from firedrake.functionspace import RestrictedFunctionSpace
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake import utils
from firedrake.petsc import OptionsManager, flatten_parameters
Expand Down Expand Up @@ -71,12 +70,12 @@ def __init__(self, A, M=None, bcs=None, bc_shift=0.0, restrict=True):
M = inner(u, v) * dx

if restrict and bcs: # assumed u and v are in the same space here
V_res = RestrictedFunctionSpace(self.output_space, boundary_set=set([bc.sub_domain for bc in bcs]))
V_res = restricted_function_space(self.output_space, extract_subdomain_ids(bcs))
u_res = TrialFunction(V_res)
v_res = TestFunction(V_res)
self.M = replace(M, {u: u_res, v: v_res})
self.A = replace(A, {u: u_res, v: v_res})
self.bcs = [DirichletBC(V_res, bc.function_arg, bc.sub_domain) for bc in bcs]
self.bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.restricted_space = V_res
else:
self.A = A # LHS
Expand Down
4 changes: 2 additions & 2 deletions firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def split(self):
@utils.cached_property
def _components(self):
if self.function_space().rank == 0:
return tuple((self, ))
return (self, )
else:
if self.dof_dset.cdim == 1:
return (CoordinatelessFunction(self.function_space().sub(0), val=self.dat,
Expand Down Expand Up @@ -335,7 +335,7 @@ def split(self):
@utils.cached_property
def _components(self):
if self.function_space().rank == 0:
return tuple((self, ))
return (self, )
else:
return tuple(type(self)(self.function_space().sub(i), self.topological.sub(i))
for i in range(self.function_space().block_size))
Expand Down
11 changes: 6 additions & 5 deletions firedrake/functionspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,20 +308,21 @@ def rec(eles):


@PETSc.Log.EventDecorator("CreateFunctionSpace")
def RestrictedFunctionSpace(function_space, name=None, boundary_set=[]):
def RestrictedFunctionSpace(function_space, boundary_set=[], name=None):
"""Create a :class:`.RestrictedFunctionSpace`.
Parameters
----------
function_space :
FunctionSpace object to restrict
name :
An optional name for the function space.
boundary_set :
A set of subdomains of the mesh in which Dirichlet boundary conditions
will be applied.
name :
An optional name for the function space.
"""
return impl.WithGeometry.create(impl.RestrictedFunctionSpace(function_space, name=name,
boundary_set=boundary_set),
return impl.WithGeometry.create(impl.RestrictedFunctionSpace(function_space,
boundary_set=boundary_set,
name=name),
function_space.mesh())
8 changes: 3 additions & 5 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,17 +865,16 @@ class RestrictedFunctionSpace(FunctionSpace):
output of the solver.
:arg function_space: The :class:`FunctionSpace` to restrict.
:kwarg boundary_set: A set of subdomains on which a DirichletBC will be applied.
:kwarg name: An optional name for this :class:`RestrictedFunctionSpace`,
useful for later identification.
:kwarg boundary_set: A set of subdomains on which a DirichletBC will be
applied.
Notes
-----
If using this class to solve or similar, a list of DirichletBCs will still
need to be specified on this space and passed into the function.
"""
def __init__(self, function_space, name=None, boundary_set=frozenset()):
def __init__(self, function_space, boundary_set=frozenset(), name=None):
label = ""
for boundary_domain in boundary_set:
label += str(boundary_domain)
Expand All @@ -889,8 +888,7 @@ def __init__(self, function_space, name=None, boundary_set=frozenset()):
label=self._label)
self.function_space = function_space
self.name = name or (function_space.name or "Restricted" + "_"
+ "_".join(sorted(
[str(i) for i in self.boundary_set])))
+ "_".join(sorted(map(str, self.boundary_set))))

def set_shared_data(self):
sdata = get_shared_data(self._mesh, self.ufl_element(), self.boundary_set)
Expand Down
29 changes: 25 additions & 4 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,8 @@ def make_interpolator(expr, V, subset, access, bcs=None):
elif len(arguments) == 1:
if isinstance(V, firedrake.Function):
raise ValueError("Cannot interpolate an expression with an argument into a Function")
if len(V) > 1:
raise NotImplementedError("Interpolation of mixed expressions with arguments is not supported")
argfs = arguments[0].function_space()
source_mesh = argfs.mesh()
argfs_map = argfs.cell_node_map()
Expand Down Expand Up @@ -992,10 +994,29 @@ def callable():
if numpy.prod(expr.ufl_shape, dtype=int) != V.value_size:
raise RuntimeError('Expression of length %d required, got length %d'
% (V.value_size, numpy.prod(expr.ufl_shape, dtype=int)))
if len(V) > 1:
raise NotImplementedError(
"UFL expressions for mixed functions are not yet supported.")
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))

if len(V) == 1:
loops.extend(_interpolator(V, tensor, expr, subset, arguments, access, bcs=bcs))
else:
if (hasattr(expr, "subfunctions") and len(expr.subfunctions) == len(V)
and all(sub_expr.ufl_shape == Vsub.value_shape for Vsub, sub_expr in zip(V, expr.subfunctions))):
# Use subfunctions if they match the target shapes
expressions = expr.subfunctions
else:
# Unflatten the expression into the shapes of the mixed components
offset = 0
expressions = []
for Vsub in V:
if len(Vsub.value_shape) == 0:
expressions.append(expr[offset])
else:
components = [expr[offset + j] for j in range(Vsub.value_size)]
expressions.append(ufl.as_tensor(numpy.reshape(components, Vsub.value_shape)))
offset += Vsub.value_size
# Interpolate each sub expression into each function space
for Vsub, sub_tensor, sub_expr in zip(V, tensor, expressions):
loops.extend(_interpolator(Vsub, sub_tensor, sub_expr, subset, arguments, access, bcs=bcs))

if bcs and len(arguments) == 0:
loops.extend(partial(bc.apply, f) for bc in bcs)

Expand Down
15 changes: 6 additions & 9 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
DEFAULT_SNES_PARAMETERS
)
from firedrake.function import Function
from firedrake.functionspace import RestrictedFunctionSpace
from firedrake.ufl_expr import TrialFunction, TestFunction
from firedrake.bcs import DirichletBC, EquationBC
from firedrake.bcs import DirichletBC, EquationBC, extract_subdomain_ids, restricted_function_space
from firedrake.adjoint_utils import NonlinearVariationalProblemMixin, NonlinearVariationalSolverMixin
from ufl import replace

Expand Down Expand Up @@ -88,19 +87,17 @@ def __init__(self, F, u, bcs=None, J=None,
self.restrict = restrict

if restrict and bcs:
V_res = RestrictedFunctionSpace(V, boundary_set=set([bc.sub_domain for bc in bcs]))
bcs = [DirichletBC(V_res, bc.function_arg, bc.sub_domain) for bc in bcs]
V_res = restricted_function_space(V, extract_subdomain_ids(bcs))
bcs = [bc.reconstruct(V=V_res, indices=bc._indices) for bc in bcs]
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
F_arg, = F.arguments()
replace_dict = {F_arg: v_res}
replace_dict[self.u] = self.u_restrict
self.F = replace(F, replace_dict)
self.F = replace(F, {F_arg: v_res, self.u: self.u_restrict})
v_arg, u_arg = self.J.arguments()
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res})
self.J = replace(self.J, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
if self.Jp:
v_arg, u_arg = self.Jp.arguments()
self.Jp = replace(self.Jp, {v_arg: v_res, u_arg: u_res})
self.Jp = replace(self.Jp, {v_arg: v_res, u_arg: u_res, self.u: self.u_restrict})
self.restricted_space = V_res
else:
self.u_restrict = u
Expand Down
41 changes: 41 additions & 0 deletions tests/firedrake/regression/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,47 @@ def test_function():
assert np.allclose(g.dat.data, h.dat.data)


def test_mixed_expression():
m = UnitTriangleMesh()
x = SpatialCoordinate(m)
V1 = FunctionSpace(m, 'P', 1)
V2 = FunctionSpace(m, 'P', 2)

V = V1 * V2
expressions = [x[0], x[0]*x[1]]
expr = as_vector(expressions)
fg = assemble(interpolate(expr, V))
f, g = fg.subfunctions

f1 = Function(V1).interpolate(expressions[0])
g1 = Function(V2).interpolate(expressions[1])
assert np.allclose(f.dat.data, f1.dat.data)
assert np.allclose(g.dat.data, g1.dat.data)


def test_mixed_function():
m = UnitTriangleMesh()
x = SpatialCoordinate(m)
V1 = FunctionSpace(m, 'RT', 1)
V2 = FunctionSpace(m, 'DG', 0)
V = V1 * V2

expressions = [x[0], x[1], Constant(0.444)]
expr = as_vector(expressions)
v = assemble(interpolate(expr, V))

W1 = FunctionSpace(m, 'RT', 2)
W2 = FunctionSpace(m, 'DG', 1)
W = W1 * W2
w = assemble(interpolate(v, W))

f, g = w.subfunctions
f1 = Function(W1).interpolate(x)
g1 = Function(W2).interpolate(expressions[-1])
assert np.allclose(f.dat.data, f1.dat.data)
assert np.allclose(g.dat.data, g1.dat.data)


def test_inner():
m = UnitTriangleMesh()
V1 = FunctionSpace(m, 'P', 1)
Expand Down
29 changes: 26 additions & 3 deletions tests/firedrake/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,40 @@ def test_restricted_function_space_coord_change(j):
new_mesh = Mesh(Function(V).interpolate(as_vector([x, y])))
new_V = FunctionSpace(new_mesh, "CG", j)
bc = DirichletBC(new_V, 0, 1)
new_V_restricted = RestrictedFunctionSpace(new_V, name="Restricted", boundary_set=[1])
new_V_restricted = RestrictedFunctionSpace(new_V, boundary_set=[1], name="Restricted")

compare_function_space_assembly(new_V, new_V_restricted, [bc])


def test_poisson_restricted_mixed_space():
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
Z = V * Q

u, p = TrialFunctions(Z)
v, q = TestFunctions(Z)
a = inner(u, v)*dx + inner(p, div(v))*dx + inner(div(u), q)*dx
L = inner(1, q)*dx

bcs = [DirichletBC(Z.sub(0), 0, [1])]

w = Function(Z)
solve(a == L, w, bcs=bcs, restrict=False)

w2 = Function(Z)
solve(a == L, w2, bcs=bcs, restrict=True)

assert errornorm(w.subfunctions[0], w2.subfunctions[0]) < 1.e-12
assert errornorm(w.subfunctions[1], w2.subfunctions[1]) < 1.e-12


@pytest.mark.parametrize(["i", "j"], [(1, 0), (2, 0), (2, 1)])
def test_restricted_mixed_spaces(i, j):
def test_poisson_mixed_restricted_spaces(i, j):
mesh = UnitSquareMesh(1, 1)
DG = FunctionSpace(mesh, "DG", j)
CG = VectorFunctionSpace(mesh, "CG", i)
CG_res = RestrictedFunctionSpace(CG, "Restricted", boundary_set=[4])
CG_res = RestrictedFunctionSpace(CG, boundary_set=[4], name="Restricted")
W = CG * DG
W_res = CG_res * DG
bc = DirichletBC(W.sub(0), 0, 4)
Expand Down
Loading

0 comments on commit 3285cfc

Please sign in to comment.