-
Notifications
You must be signed in to change notification settings - Fork 161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Trefftz support for Firedrake #3775
Draft
UZerbinati
wants to merge
13
commits into
master
Choose a base branch
from
uz/trefftz
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
1d05e50
Trefftz support for Firedrake
48109d3
Clean up
4e19060
Making the linter happy
b84833c
Skip complex
fe9643f
Numpy doc style
4cb3d06
Merge remote-tracking branch 'origin' into uz/trefftz
cf7fbfa
WIP
08cd4a4
WIP
00dad68
Fix doc string
f763ba0
Merge branch 'master' of github.com:firedrakeproject/firedrake into u…
4c62f42
Added missing pytest import
9292c1d
Merge branch 'uz/trefftz' of github.com:firedrakeproject/firedrake in…
99df6dc
Fix ksp_python_type name
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
""" | ||
Provides a class to compute the Trefftz embedding of a given function space. | ||
It is also used to compute aggregation embedding of a given function space. | ||
""" | ||
from typing import Optional | ||
from firedrake.petsc import PETSc | ||
from firedrake.cython.dmcommon import FACE_SETS_LABEL, CELL_SETS_LABEL | ||
from firedrake.assemble import assemble | ||
from firedrake.mesh import Mesh | ||
from firedrake.functionspace import FunctionSpace | ||
from firedrake.function import Function | ||
from firedrake.ufl_expr import TestFunction, TrialFunction | ||
from firedrake.constant import Constant | ||
from ufl import dx, dS, inner, jump, grad, dot, CellDiameter, FacetNormal, Form | ||
import scipy.sparse as sp | ||
import numpy as np | ||
|
||
__all__ = ["TrefftzEmbedding", "TrefftzKSP", "AggregationEmbedding", "dumb_aggregation"] | ||
|
||
class TrefftzEmbedding: | ||
""" | ||
Computes the Trefftz embedding of a given function space | ||
Parameters | ||
---------- | ||
V : Ambient function space. | ||
b : Bilinear form defining the Trefftz operator. | ||
dim : Dimension of the embedding. | ||
Default is the dimension of the function space. | ||
tol : Tolerance for the singular values cutoff. | ||
Default is 1e-12. | ||
backend : Backend to use for the computation of the SVD. | ||
Default is "scipy". | ||
""" | ||
def __init__(self, V: FunctionSpace, b: Form, dim: Optional[int] = None, | ||
tol: Optional[float]=1e-12): | ||
Check failure on line 35 in firedrake/trefftz.py GitHub Actions / Run linterE252
|
||
self.V = V | ||
self.b = b | ||
self.dim = V.dim() if not dim else dim + 1 | ||
self.tol = tol | ||
self.svdsolver = "scipy" | ||
|
||
def assemble(self) -> tuple[PETSc.Mat, np.array]: | ||
""" | ||
Assemble the embedding, compute the SVD and return the embedding matrix | ||
""" | ||
self.B = assemble(self.b).M.handle | ||
if self.svdsolver == "scipy": | ||
indptr, indices, data = self.B.getValuesCSR() | ||
Bsp = sp.csr_matrix((data, indices, indptr), shape=self.B.getSize()) | ||
_, sig, VT = sp.linalg.svds(Bsp, k=self.dim-1, which="SM") | ||
QT = sp.csr_matrix(VT[0:sum(sig < self.tol), :]) | ||
QTpsc = PETSc.Mat().createAIJ(size=QT.shape, csr=(QT.indptr, QT.indices, QT.data)) | ||
self.dimT = QT.shape[0] | ||
self.sig = sig | ||
else: | ||
raise NotImplementedError("Only scipy backend is supported") | ||
return QTpsc, sig | ||
|
||
|
||
class TrefftzKSP: | ||
""" | ||
Wraps a PETSc KSP object to solve the reduced | ||
system obtained by the Trefftz embedding. | ||
|
||
There will bne no type hinting following petsc4py's style. | ||
""" | ||
def __init__(self): | ||
pass | ||
|
||
@staticmethod | ||
def get_appctx(ksp: PETSc.KSP): | ||
""" | ||
Get the application context from the KSP | ||
Parameters | ||
---------- | ||
ksp : The KSP object | ||
""" | ||
from firedrake.dmhooks import get_appctx | ||
return get_appctx(ksp.getDM()).appctx | ||
|
||
def setUp(self, ksp: PETSc.KSP): | ||
""" | ||
Set up the Trefftz KSP | ||
Parameters | ||
---------- | ||
ksp : The KSP object | ||
""" | ||
appctx = self.get_appctx(ksp) | ||
self.QT, _ = appctx["trefftz_embedding"].assemble() | ||
|
||
def solve(self, ksp: PETSc.KSP, b: PETSc.Vec, x:PETSc.Vec): | ||
""" | ||
Solve the Trefftz KSP | ||
Parameters | ||
---------- | ||
ksp : The KSP object | ||
b : The right-hand side | ||
x : The solution | ||
""" | ||
A, P = ksp.getOperators() | ||
self.Q = PETSc.Mat().createTranspose(self.QT) | ||
ATF = self.QT @ A @ self.Q | ||
PTF = self.QT @ P @ self.Q | ||
bTF = self.QT.createVecLeft() | ||
self.QT.mult(b, bTF) | ||
|
||
tiny_ksp = PETSc.KSP().create() | ||
tiny_ksp.setOperators(ATF, PTF) | ||
tiny_ksp.setOptionsPrefix("trefftz_") | ||
tiny_ksp.setFromOptions() | ||
xTF = ATF.createVecRight() | ||
tiny_ksp.solve(bTF, xTF) | ||
self.QT.multTranspose(xTF, x) | ||
ksp.setConvergedReason(tiny_ksp.getConvergedReason()) | ||
|
||
|
||
class AggregationEmbedding(TrefftzEmbedding): | ||
""" | ||
This class computes the aggregation embedding of a given function space. | ||
Parameters | ||
---------- | ||
V : Ambient function space. | ||
mesh : The mesh on which the aggregation is defined. | ||
polyMesh : The function defining the aggregation. | ||
dim : Dimension of the embedding. | ||
Default is the dimension of the function space. | ||
tol : Tolerance for the singular values cutoff. | ||
Default is 1e-12. | ||
""" | ||
def __init__(self, V: FunctionSpace, mesh: Mesh, polyMesh: Function, | ||
dim: Optional[int] = None, tol: Optional[float]=1e-12): | ||
Check failure on line 131 in firedrake/trefftz.py GitHub Actions / Run linterE252
|
||
# Relabel facets that are inside an aggregated region | ||
offset = 1 + mesh.topology_dm.getLabelSize(FACE_SETS_LABEL) | ||
offset += mesh.topology_dm.getLabelSize(CELL_SETS_LABEL) | ||
nPoly = int(max(polyMesh.dat.data[:])) # Number of aggregates | ||
getIdx = mesh._cell_numbering.getOffset | ||
plex = mesh.topology_dm | ||
pStart, pEnd = plex.getDepthStratum(2) | ||
self.facet_index = [] | ||
for poly in range(nPoly+1): | ||
facets = [] | ||
for i in range(pStart, pEnd): | ||
if polyMesh.dat.data[getIdx(i)] == poly: | ||
for f in plex.getCone(i): | ||
if f in facets: | ||
plex.setLabelValue(FACE_SETS_LABEL, f, offset+poly) | ||
if offset+poly not in self.facet_index: | ||
self.facet_index = self.facet_index + [offset+poly] | ||
facets = facets + list(plex.getCone(i)) | ||
self.mesh = Mesh(plex) | ||
h = CellDiameter(self.mesh) | ||
n = FacetNormal(self.mesh) | ||
W = FunctionSpace(self.mesh, V.ufl_element()) | ||
u = TrialFunction(W) | ||
v = TestFunction(W) | ||
self.b = Constant(0)*inner(u, v)*dx | ||
for i in self.facet_index: | ||
self.b += inner(jump(u), jump(v))*dS(i) | ||
for k in range(1, V.ufl_element().degree()+1): | ||
for i in self.facet_index: | ||
self.b += ((0.5 * h("+") + 0.5 * h("-"))**(2*k)) *\ | ||
inner(jump_normal(u, n("+"), k), jump_normal(v, n("+"), k))*dS(i) | ||
super().__init__(W, self.b, dim, tol) | ||
|
||
|
||
def jump_normal(u: Function, n: FacetNormal, k: int): | ||
""" | ||
Compute the jump of the normal derivative of a function u | ||
Parameters | ||
---------- | ||
u : The function. | ||
n : The normal vector. | ||
k : The order of the normal derivative we aim to compute. | ||
""" | ||
j = 0.5*dot(n, (grad(u)("+")-grad(u)("-"))) | ||
for _ in range(1, k): | ||
j = 0.5*dot(n, (grad(j)-grad(j))) | ||
return j | ||
|
||
|
||
def dumb_aggregation(mesh: Mesh) -> Function: | ||
""" | ||
Compute a dumb aggregation of the mesh | ||
Parameters | ||
---------- | ||
mesh : The mesh we aim to aggregate. | ||
""" | ||
if mesh.comm.size > 1: | ||
raise NotImplementedError("Parallel mesh aggregation not supported") | ||
plex = mesh.topology_dm | ||
pStart, pEnd = plex.getDepthStratum(2) | ||
_, eEnd = plex.getDepthStratum(1) | ||
adjacency = [] | ||
for i in range(pStart, pEnd): | ||
ad = plex.getAdjacency(i) | ||
local = [] | ||
for a in ad: | ||
supp = plex.getSupport(a) | ||
supp = supp[supp < eEnd] | ||
for s in supp: | ||
if s < pEnd and s != ad[0]: | ||
local = local + [s] | ||
adjacency = adjacency + [(i, local)] | ||
adjacency = sorted(adjacency, key=lambda x: len(x[1]))[::-1] | ||
u = Function(FunctionSpace(mesh, "DG", 0)) | ||
|
||
getIdx = mesh._cell_numbering.getOffset | ||
av = list(range(pStart, pEnd)) | ||
col = 0 | ||
for a in adjacency: | ||
if a[0] in av: | ||
for k in a[1]: | ||
if k in av: | ||
av.remove(k) | ||
u.dat.data[getIdx(k)] = col | ||
av.remove(a[0]) | ||
u.dat.data[getIdx(a[0])] = col | ||
col = col + 1 | ||
return u |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from firedrake import * | ||
from firedrake.trefftz import TrefftzEmbedding, AggregationEmbedding, dumb_aggregation | ||
import pytest | ||
|
||
|
||
@pytest.mark.skipcomplex | ||
def test_trefftz_laplace(): | ||
order = 6 | ||
mesh = UnitSquareMesh(2, 2) | ||
x, y = SpatialCoordinate(mesh) | ||
h = CellDiameter(mesh) | ||
n = FacetNormal(mesh) | ||
V = FunctionSpace(mesh, "DG", order) | ||
u = TrialFunction(V) | ||
v = TestFunction(V) | ||
|
||
def delta(u): | ||
return div(grad(u)) | ||
|
||
a = inner(delta(u), delta(v)) * dx | ||
alpha = 4 | ||
mean_dudn = 0.5 * dot(grad(u("+"))+grad(u("-")), n("+")) | ||
mean_dvdn = 0.5 * dot(grad(v("+"))+grad(v("-")), n("+")) | ||
aDG = inner(grad(u), grad(v)) * dx | ||
aDG += inner((alpha*order**2/(h("+")+h("-")))*jump(u), jump(v))*dS | ||
aDG += inner(-mean_dudn, jump(v))*dS-inner(mean_dvdn, jump(u))*dS | ||
aDG += alpha*order**2/h*inner(u, v)*ds | ||
aDG += -inner(dot(n, grad(u)), v)*ds - inner(dot(n, grad(v)), u)*ds | ||
f = Function(V).interpolate(exp(x)*sin(y)) | ||
L = alpha*order**2/h*inner(f, v)*ds - inner(dot(n, grad(v)), f)*ds | ||
# Solve the problem | ||
uDG = Function(V) | ||
uDG.rename("uDG") | ||
embd = TrefftzEmbedding(V, a, tol=1e-8) | ||
appctx = {"trefftz_embedding": embd} | ||
uDG = Function(V) | ||
solve(aDG == L, uDG, solver_parameters={"ksp_type": "python", | ||
"ksp_python_type": "firedrake.TrefftzKSP"}, | ||
appctx=appctx) | ||
assert (assemble(inner(uDG-f, uDG-f)*dx) < 1e-6) | ||
assert (embd.dimT < V.dim()/2) | ||
|
||
|
||
@pytest.mark.skipcomplex | ||
def test_trefftz_aggregation(): | ||
try: | ||
from netgen.occ import WorkPlane, OCCGeometry | ||
except ImportError: | ||
# Netgen is not installed | ||
pytest.skip("Netgen/ngsPETSc not installed", allow_module_level=True) | ||
|
||
Rectangle = WorkPlane().Rectangle(1, 1).Face() | ||
geo = OCCGeometry(Rectangle, dim=2) | ||
ngmesh = geo.GenerateMesh(maxh=0.3) | ||
mesh = Mesh(ngmesh) | ||
|
||
polymesh = dumb_aggregation(mesh) | ||
|
||
order = 3 | ||
x, y = SpatialCoordinate(mesh) | ||
h = CellDiameter(mesh) | ||
n = FacetNormal(mesh) | ||
V = FunctionSpace(mesh, "DG", order) | ||
u = TrialFunction(V) | ||
v = TestFunction(V) | ||
|
||
alpha = 1e3 | ||
mean_dudn = 0.5 * dot(grad(u("+"))+grad(u("-")), n("+")) | ||
mean_dvdn = 0.5 * dot(grad(v("+"))+grad(v("-")), n("+")) | ||
aDG = inner(grad(u), grad(v)) * dx | ||
aDG += inner((alpha*order**2/(h("+")+h("-")))*jump(u), jump(v))*dS | ||
aDG += inner(-mean_dudn, jump(v))*dS-inner(mean_dvdn, jump(u))*dS | ||
aDG += alpha*order**2/h*inner(u, v)*ds | ||
aDG += -inner(dot(n, grad(u)), v)*ds - inner(dot(n, grad(v)), u)*ds | ||
f = Function(V).interpolate(exp(x)*sin(y)) | ||
L = alpha*order**2/h*inner(f, v)*ds - inner(dot(n, grad(v)), f)*ds | ||
agg_embd = AggregationEmbedding(V, mesh, polymesh) | ||
appctx = {"trefftz_embedding": agg_embd} | ||
|
||
uDG = Function(V) | ||
solve(aDG == L, uDG, solver_parameters={"ksp_type": "python", | ||
"ksp_python_type": "firedrake.TrefftzKSP"}, | ||
appctx=appctx) | ||
|
||
assert (assemble(inner(uDG-f, uDG-f)*dx) < 1e-9) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please define
__all__
to avoid pollution of the namespaceThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or equivalently you can replace
from firedrake.trefftz import *
withfrom firedrake.trefftz import OnlyWhatIWant
inside__init__.py
.