diff --git a/firedrake/trefftz.py b/firedrake/trefftz.py index dae0e94ebb..aa40759fb8 100644 --- a/firedrake/trefftz.py +++ b/firedrake/trefftz.py @@ -17,11 +17,21 @@ class TrefftzEmbedding(object): """ This class computes the Trefftz embedding of a given function space - :arg V: the function space - :arg b: the bilinear form defining the embedding - :arg dim: the dimension of the embedding - :arg tol: the tolerance for the singular values - :arg backend: the backend to use for the computation of the SVD + Parameters + ---------- + V : :class:`.FunctionSpace` + Ambient function space. + b : :class:`.ufl.form.Form` + Bilinear form defining the Trefftz operator. + dim : int, optional + Dimension of the embedding. + Default is the dimension of the function space. + tol : float, optional + Tolerance for the singular values cutoff. + Default is 1e-12. + backend : str, optional + Backend to use for the computation of the SVD. + Default is "scipy". """ def __init__(self, V, b, dim=None, tol=1e-12, backend="scipy"): self.V = V @@ -60,6 +70,10 @@ def __init__(self): def get_appctx(ksp): """ Get the application context from the KSP + Parameters + ---------- + ksp : :class:`PETSc.KSP` + The KSP object """ from firedrake.dmhooks import get_appctx return get_appctx(ksp.getDM()).appctx @@ -67,6 +81,10 @@ def get_appctx(ksp): def setUp(self, ksp): """ Set up the Trefftz KSP + Parameters + ---------- + ksp : :class:`PETSc.KSP` + The KSP object """ appctx = self.get_appctx(ksp) self.QT, _ = appctx["trefftz_embedding"].assemble() @@ -74,6 +92,14 @@ def setUp(self, ksp): def solve(self, ksp, b, x): """ Solve the Trefftz KSP + Parameters + ---------- + ksp : :class:`PETSc.KSP` + The KSP object + b : :class:`PETSc.Vec` + The right-hand side + x : :class:`PETSc.Vec` + The solution """ A, P = ksp.getOperators() self.Q = PETSc.Mat().createTranspose(self.QT) @@ -95,11 +121,20 @@ def solve(self, ksp, b, x): class AggregationEmbedding(TrefftzEmbedding): """ This class computes the aggregation embedding of a given function space. - :arg V: the function space - :arg mesh: the mesh - :arg polyMesh: the aggregation mesh - :arg dim: the dimension of the embedding - :arg tol: the tolerance for the singular values + Parameters + ---------- + V : :class:`.FunctionSpace` + Ambient function space. + mesh : :class:`.Mesh` + The mesh on which the aggregation is defined. + polyMesh : :class:`.Function` + The function defining the aggregation. + dim : int + Dimension of the embedding. + Default is the dimension of the function space. + tol : float + Tolerance for the singular values cutoff. + Default is 1e-12. """ def __init__(self, V, mesh, polyMesh, dim=None, tol=1e-12): # Relabel facets that are inside an aggregated region @@ -139,9 +174,14 @@ def __init__(self, V, mesh, polyMesh, dim=None, tol=1e-12): def jump_normal(u, n, k): """ Compute the jump of the normal derivative of a function u - :arg u: the function - :arg n: the normal vector - :arg k: the degree of the normal derivative + Parameters + ---------- + u : :class:`.Function` + The function. + n : :class:`.ufc.Normal` + The normal vector. + k : int + The order of the normal derivative we aim to compute. """ j = 0.5*dot(n, (grad(u)("+")-grad(u)("-"))) for _ in range(1, k): @@ -152,7 +192,10 @@ def jump_normal(u, n, k): def dumb_aggregation(mesh): """ Compute a dumb aggregation of the mesh - :arg mesh: the mesh + Parameters + ---------- + mesh : :class:`.Mesh` + The mesh we aim to aggregate. """ if mesh.comm.size > 1: raise NotImplementedError("Parallel mesh aggregation not supported") diff --git a/tests/regression/test_trefftz.py b/tests/regression/test_trefftz.py index bbbe6823cf..c23566be08 100644 --- a/tests/regression/test_trefftz.py +++ b/tests/regression/test_trefftz.py @@ -1,6 +1,7 @@ from firedrake import * from firedrake.trefftz import TrefftzEmbedding, AggregationEmbedding, dumb_aggregation + @pytest.mark.skipcomplex def test_trefftz_laplace(): order = 6 @@ -38,6 +39,7 @@ def delta(u): assert (assemble(inner(uDG-f, uDG-f)*dx) < 1e-6) assert (embd.dimT < V.dim()/2) + @pytest.mark.skipcomplex def test_trefftz_aggregation(): from netgen.occ import WorkPlane, OCCGeometry