Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper of options #471

Merged
merged 26 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
573ad52
started MixedOptions and an integration test for this
ta440 Dec 20, 2023
e04c73a
MixedOptions dictionary added
ta440 Dec 21, 2023
4a13df3
improvements to MixedOptions
ta440 Dec 22, 2023
6165057
more mixed options changes
ta440 Jan 9, 2024
775d3b2
more changes to mixed options
ta440 Jan 10, 2024
793768e
working on new test functions with mixed options
ta440 Jan 10, 2024
5f05367
more changes to test functions for multiple wrappers
ta440 Jan 11, 2024
5c13db1
more changes to mixed options with replacing test functions
ta440 Jan 16, 2024
806a49e
mixed options works for embedded dg and recovery
ta440 Jan 17, 2024
2be1354
moved DG1-DG1 equispaced test from test_limiters to test_mixed_fs_opt…
ta440 Jan 22, 2024
6d2fc9b
finalising mixed options test script
ta440 Jan 22, 2024
25ed243
lint
ta440 Jan 22, 2024
f410d2c
lint
ta440 Jan 22, 2024
d166a22
lint
ta440 Jan 22, 2024
aa3d7c0
lint
ta440 Jan 22, 2024
f1e958a
lint
ta440 Jan 22, 2024
51f6f7b
remove dubugging statements
ta440 Jan 23, 2024
55e0856
separate MixedFSOptions and MixedFSWrapper to align with exising code
ta440 Feb 20, 2024
164041c
lint
ta440 Feb 20, 2024
edcc929
lint
ta440 Feb 20, 2024
b317f39
lint
ta440 Feb 20, 2024
0700381
set up original spaces for MixedFSWrapper within the subwrappers
ta440 Feb 28, 2024
2f517e1
make original_space a base wrapper property
ta440 Feb 28, 2024
5c2980d
make original_space a required argument for wrapper.setup()
ta440 Feb 29, 2024
222c396
lint
ta440 Feb 29, 2024
66ce750
Merge branch 'main' into wrapper_of_options
tommbendall Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion gusto/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
__all__ = [
"IntegrateByParts", "TransportEquationType", "OutputParameters",
"CompressibleParameters", "ShallowWaterParameters",
"EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions",
"EmbeddedDGOptions", "RecoveryOptions", "SUPGOptions", "MixedFSOptions",
"SpongeLayerParameters", "DiffusionParameters", "BoundaryLayerParameters"
]

Expand Down Expand Up @@ -172,6 +172,15 @@ class SUPGOptions(WrapperOptions):
ibp = IntegrateByParts.TWICE


class MixedFSOptions(WrapperOptions):
"""Specifies options for a mixed finite element formulation
where different suboptions are applied to different
prognostic variables."""

name = "mixed_options"
suboptions = {}


class SpongeLayerParameters(Configuration):
"""Specifies parameters describing a 'sponge' (damping) layer."""

Expand Down
76 changes: 60 additions & 16 deletions gusto/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from firedrake import (
Function, TestFunction, NonlinearVariationalProblem,
Function, TestFunction, TestFunctions, NonlinearVariationalProblem,
NonlinearVariationalSolver, DirichletBC, split, Constant
)
from firedrake.fml import (
Expand Down Expand Up @@ -88,7 +88,21 @@ def __init__(self, domain, field_name=None, solver_parameters=None,

if options is not None:
self.wrapper_name = options.name
if self.wrapper_name == "embedded_dg":
if self.wrapper_name == "mixed_options":
self.wrapper = MixedFSWrapper()

for field, suboption in options.suboptions.items():
if suboption.name == 'embedded_dg':
self.wrapper.subwrappers.update({field: EmbeddedDGWrapper(self, suboption)})
elif suboption.name == "recovered":
self.wrapper.subwrappers.update({field: RecoveryWrapper(self, suboption)})
elif suboption.name == "supg":
raise RuntimeError(
'Time discretisation: suboption SUPG is currently not implemented within MixedOptions')
else:
raise RuntimeError(
f'Time discretisation: suboption wrapper {wrapper_name} not implemented')
elif self.wrapper_name == "embedded_dg":
self.wrapper = EmbeddedDGWrapper(self, options)
elif self.wrapper_name == "recovered":
self.wrapper = RecoveryWrapper(self, options)
Expand Down Expand Up @@ -159,21 +173,51 @@ def setup(self, equation, apply_bcs=True, *active_labels):
# -------------------------------------------------------------------- #

if self.wrapper is not None:
self.wrapper.setup()
self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
new_test = TestFunction(self.wrapper.test_space)
# SUPG has a special wrapper
if self.wrapper_name == "supg":
new_test = self.wrapper.test

# Replace the original test function with the one from the wrapper
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))
if self.wrapper_name == "mixed_options":

self.wrapper.wrapper_spaces = equation.spaces
self.wrapper.field_names = equation.field_names

for field, subwrapper in self.wrapper.subwrappers.items():

if field not in equation.field_names:
raise ValueError(f"The option defined for {field} is for a field that does not exist in the equation set")

field_idx = equation.field_names.index(field)
subwrapper.setup(equation.spaces[field_idx])

self.residual = self.wrapper.label_terms(self.residual)
# Update the function space to that needed by the wrapper
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space

self.wrapper.setup()
self.fs = self.wrapper.function_space
new_test_mixed = TestFunctions(self.fs)

# Replace the original test function with one from the new
# function space defined by the subwrappers
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test_mixed))

else:
if self.wrapper_name == "supg":
self.wrapper.setup()
else:
self.wrapper.setup(self.fs)
self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
new_test = TestFunction(self.wrapper.test_space)
# SUPG has a special wrapper
if self.wrapper_name == "supg":
new_test = self.wrapper.test

# Replace the original test function with the one from the wrapper
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))

self.residual = self.wrapper.label_terms(self.residual)

# -------------------------------------------------------------------- #
# Make boundary conditions
Expand Down
103 changes: 87 additions & 16 deletions gusto/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from abc import ABCMeta, abstractmethod
from firedrake import (
FunctionSpace, Function, BrokenElement, Projector, Interpolator,
VectorElement, Constant, as_ufl, dot, grad, TestFunction
VectorElement, Constant, as_ufl, dot, grad, TestFunction, MixedFunctionSpace
)
from firedrake.fml import Term
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, SUPGOptions
from gusto.recovery import Recoverer, ReversibleRecoverer
from gusto.labels import transporting_velocity
import ufl

__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper"]
__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper", "MixedFSWrapper"]


class Wrapper(object, metaclass=ABCMeta):
Expand All @@ -33,14 +33,23 @@ def __init__(self, time_discretisation, wrapper_options):
self.time_discretisation = time_discretisation
self.options = wrapper_options
self.solver_parameters = None
self.original_space = None

@abstractmethod
def setup(self):
def setup(self, original_space):
"""
Performs standard set up routines, and is to be called by the setup
method of the underlying time discretisation.
Store the original function space of the prognostic variable.

Within each child wrapper, setup performs standard set up routines,
and is to be called by the setup method of the underlying
time discretisation.

Args:
original_space (:class:`FunctionSpace`): the space that the
prognostic variable is defined on. This is a subset space of
a mixed function space when using a MixedFSWrapper.
"""
pass
self.original_space = original_space

@abstractmethod
def pre_apply(self):
Expand Down Expand Up @@ -76,13 +85,14 @@ class EmbeddedDGWrapper(Wrapper):
the original space.
"""

def setup(self):
def setup(self, original_space):
"""Sets up function spaces and fields needed for this wrapper."""

assert isinstance(self.options, EmbeddedDGOptions), \
'Embedded DG wrapper can only be used with Embedded DG Options'

original_space = self.time_discretisation.fs
super().setup(original_space)

domain = self.time_discretisation.domain
equation = self.time_discretisation.equation

Expand All @@ -91,7 +101,7 @@ def setup(self):
# -------------------------------------------------------------------- #

if self.options.embedding_space is None:
V_elt = BrokenElement(original_space.ufl_element())
V_elt = BrokenElement(self.original_space.ufl_element())
self.function_space = FunctionSpace(domain.mesh, V_elt)
else:
self.function_space = self.options.embedding_space
Expand All @@ -104,8 +114,9 @@ def setup(self):

self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

if self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
self.x_projected = Function(self.original_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])

Expand Down Expand Up @@ -158,13 +169,14 @@ class RecoveryWrapper(Wrapper):
field is then returned to the original space.
"""

def setup(self):
def setup(self, original_space):
"""Sets up function spaces and fields needed for this wrapper."""

assert isinstance(self.options, RecoveryOptions), \
'Embedded DG wrapper can only be used with Recovery Options'
'Recovery wrapper can only be used with Recovery Options'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for correcting this!


super().setup(original_space)

original_space = self.time_discretisation.fs
domain = self.time_discretisation.domain
equation = self.time_discretisation.equation

Expand All @@ -173,7 +185,7 @@ def setup(self):
# -------------------------------------------------------------------- #

if self.options.embedding_space is None:
V_elt = BrokenElement(original_space.ufl_element())
V_elt = BrokenElement(self.original_space.ufl_element())
self.function_space = FunctionSpace(domain.mesh, V_elt)
else:
self.function_space = self.options.embedding_space
Expand All @@ -184,11 +196,12 @@ def setup(self):
# Internal variables to be used
# -------------------------------------------------------------------- #

self.x_in_tmp = Function(self.time_discretisation.fs)
self.x_in_tmp = Function(self.original_space)
self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

if self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
self.x_projected = Function(self.original_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])

Expand Down Expand Up @@ -361,3 +374,61 @@ def label_terms(self, residual):
new_residual = transporting_velocity.update_value(new_residual, self.transporting_velocity)

return new_residual


class MixedFSWrapper(object):
"""
An object to hold a subwrapper dictionary with different wrappers for
different tracers. This means that different tracers can be solved
simultaneously using a CoupledTransportEquation, whilst being in
different spaces and needing different implementation options.
"""

def __init__(self):

self.wrapper_spaces = None
self.field_names = None
self.subwrappers = {}

def setup(self):
""" Compute the new mixed function space from the subwrappers """

self.function_space = MixedFunctionSpace(self.wrapper_spaces)
self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

def pre_apply(self, x_in):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method looks exactly right to me

Perform the pre-applications for all fields
with an associated subwrapper.
"""

for field_name in self.field_names:
field_idx = self.field_names.index(field_name)
field = x_in.subfunctions[field_idx]
x_in_sub = self.x_in.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
subwrapper.pre_apply(field)
x_in_sub.assign(subwrapper.x_in)
else:
x_in_sub.assign(field)

def post_apply(self, x_out):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method looks exactly right to me

"""
Perform the post-applications for all fields
with an associated subwrapper.
"""

for field_name in self.field_names:
field_idx = self.field_names.index(field_name)
field = self.x_out.subfunctions[field_idx]
x_out_sub = x_out.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
subwrapper.x_out.assign(field)
subwrapper.post_apply(x_out_sub)
else:
x_out_sub.assign(field)
Loading
Loading