Skip to content

Commit

Permalink
Numpy doc style
Browse files Browse the repository at this point in the history
Signed-off-by: Umberto Zerbinati <[email protected]>
  • Loading branch information
Umberto Zerbinati committed Sep 18, 2024
1 parent b84833c commit fe9643f
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 14 deletions.
71 changes: 57 additions & 14 deletions firedrake/trefftz.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,21 @@
class TrefftzEmbedding(object):
"""
This class computes the Trefftz embedding of a given function space
:arg V: the function space
:arg b: the bilinear form defining the embedding
:arg dim: the dimension of the embedding
:arg tol: the tolerance for the singular values
:arg backend: the backend to use for the computation of the SVD
Parameters
----------
V : :class:`.FunctionSpace`
Ambient function space.
b : :class:`.ufl.form.Form`
Bilinear form defining the Trefftz operator.
dim : int, optional
Dimension of the embedding.
Default is the dimension of the function space.
tol : float, optional
Tolerance for the singular values cutoff.
Default is 1e-12.
backend : str, optional
Backend to use for the computation of the SVD.
Default is "scipy".
"""
def __init__(self, V, b, dim=None, tol=1e-12, backend="scipy"):
self.V = V
Expand Down Expand Up @@ -60,20 +70,36 @@ def __init__(self):
def get_appctx(ksp):
"""
Get the application context from the KSP
Parameters
----------
ksp : :class:`PETSc.KSP`
The KSP object
"""
from firedrake.dmhooks import get_appctx
return get_appctx(ksp.getDM()).appctx

def setUp(self, ksp):
"""
Set up the Trefftz KSP
Parameters
----------
ksp : :class:`PETSc.KSP`
The KSP object
"""
appctx = self.get_appctx(ksp)
self.QT, _ = appctx["trefftz_embedding"].assemble()

def solve(self, ksp, b, x):
"""
Solve the Trefftz KSP
Parameters
----------
ksp : :class:`PETSc.KSP`
The KSP object
b : :class:`PETSc.Vec`
The right-hand side
x : :class:`PETSc.Vec`
The solution
"""
A, P = ksp.getOperators()
self.Q = PETSc.Mat().createTranspose(self.QT)
Expand All @@ -95,11 +121,20 @@ def solve(self, ksp, b, x):
class AggregationEmbedding(TrefftzEmbedding):
"""
This class computes the aggregation embedding of a given function space.
:arg V: the function space
:arg mesh: the mesh
:arg polyMesh: the aggregation mesh
:arg dim: the dimension of the embedding
:arg tol: the tolerance for the singular values
Parameters
----------
V : :class:`.FunctionSpace`
Ambient function space.
mesh : :class:`.Mesh`
The mesh on which the aggregation is defined.
polyMesh : :class:`.Function`
The function defining the aggregation.
dim : int
Dimension of the embedding.
Default is the dimension of the function space.
tol : float
Tolerance for the singular values cutoff.
Default is 1e-12.
"""
def __init__(self, V, mesh, polyMesh, dim=None, tol=1e-12):
# Relabel facets that are inside an aggregated region
Expand Down Expand Up @@ -139,9 +174,14 @@ def __init__(self, V, mesh, polyMesh, dim=None, tol=1e-12):
def jump_normal(u, n, k):
"""
Compute the jump of the normal derivative of a function u
:arg u: the function
:arg n: the normal vector
:arg k: the degree of the normal derivative
Parameters
----------
u : :class:`.Function`
The function.
n : :class:`.ufc.Normal`
The normal vector.
k : int
The order of the normal derivative we aim to compute.
"""
j = 0.5*dot(n, (grad(u)("+")-grad(u)("-")))
for _ in range(1, k):
Expand All @@ -152,7 +192,10 @@ def jump_normal(u, n, k):
def dumb_aggregation(mesh):
"""
Compute a dumb aggregation of the mesh
:arg mesh: the mesh
Parameters
----------
mesh : :class:`.Mesh`
The mesh we aim to aggregate.
"""
if mesh.comm.size > 1:
raise NotImplementedError("Parallel mesh aggregation not supported")
Expand Down
2 changes: 2 additions & 0 deletions tests/regression/test_trefftz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from firedrake import *

Check failure on line 1 in tests/regression/test_trefftz.py

View workflow job for this annotation

GitHub Actions / Firedrake complex

test_trefftz.tests.regression.test_trefftz

tests.regression.test_trefftz

Check failure on line 1 in tests/regression/test_trefftz.py

View workflow job for this annotation

GitHub Actions / Firedrake real

test_trefftz.tests.regression.test_trefftz

tests.regression.test_trefftz
from firedrake.trefftz import TrefftzEmbedding, AggregationEmbedding, dumb_aggregation


@pytest.mark.skipcomplex
def test_trefftz_laplace():
order = 6
Expand Down Expand Up @@ -38,6 +39,7 @@ def delta(u):
assert (assemble(inner(uDG-f, uDG-f)*dx) < 1e-6)
assert (embd.dimT < V.dim()/2)


@pytest.mark.skipcomplex
def test_trefftz_aggregation():
from netgen.occ import WorkPlane, OCCGeometry
Expand Down

0 comments on commit fe9643f

Please sign in to comment.