Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper of options #471

Merged
merged 26 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
573ad52
started MixedOptions and an integration test for this
ta440 Dec 20, 2023
e04c73a
MixedOptions dictionary added
ta440 Dec 21, 2023
4a13df3
improvements to MixedOptions
ta440 Dec 22, 2023
6165057
more mixed options changes
ta440 Jan 9, 2024
775d3b2
more changes to mixed options
ta440 Jan 10, 2024
793768e
working on new test functions with mixed options
ta440 Jan 10, 2024
5f05367
more changes to test functions for multiple wrappers
ta440 Jan 11, 2024
5c13db1
more changes to mixed options with replacing test functions
ta440 Jan 16, 2024
806a49e
mixed options works for embedded dg and recovery
ta440 Jan 17, 2024
2be1354
moved DG1-DG1 equispaced test from test_limiters to test_mixed_fs_opt…
ta440 Jan 22, 2024
6d2fc9b
finalising mixed options test script
ta440 Jan 22, 2024
25ed243
lint
ta440 Jan 22, 2024
f410d2c
lint
ta440 Jan 22, 2024
d166a22
lint
ta440 Jan 22, 2024
aa3d7c0
lint
ta440 Jan 22, 2024
f1e958a
lint
ta440 Jan 22, 2024
51f6f7b
remove dubugging statements
ta440 Jan 23, 2024
55e0856
separate MixedFSOptions and MixedFSWrapper to align with exising code
ta440 Feb 20, 2024
164041c
lint
ta440 Feb 20, 2024
edcc929
lint
ta440 Feb 20, 2024
b317f39
lint
ta440 Feb 20, 2024
0700381
set up original spaces for MixedFSWrapper within the subwrappers
ta440 Feb 28, 2024
2f517e1
make original_space a base wrapper property
ta440 Feb 28, 2024
5c2980d
make original_space a required argument for wrapper.setup()
ta440 Feb 29, 2024
222c396
lint
ta440 Feb 29, 2024
66ce750
Merge branch 'main' into wrapper_of_options
tommbendall Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 67 additions & 24 deletions gusto/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from firedrake import (
Function, TestFunction, NonlinearVariationalProblem,
Function, TestFunction, TestFunctions, NonlinearVariationalProblem,
NonlinearVariationalSolver, DirichletBC, split, Constant
)
from firedrake.fml import (
Expand Down Expand Up @@ -87,16 +87,32 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
self.courant_max = None

if options is not None:
self.wrapper_name = options.name
if self.wrapper_name == "embedded_dg":
self.wrapper = EmbeddedDGWrapper(self, options)
elif self.wrapper_name == "recovered":
self.wrapper = RecoveryWrapper(self, options)
elif self.wrapper_name == "supg":
self.wrapper = SUPGWrapper(self, options)
if type(options) == MixedOptions:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if type(options) == MixedOptions:
if type(options) is MixedOptions:

I think when type checking, it's slightly better to use is rather than ==

self.wrapper = options
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that what you've done isn't a sensible way of implementing this, but it's occurred to me that it isn't quite consistent with how the existing wrappers work.

At the moment we have two broad types of objects: Options and Wrappers, and I think what we've done here is mix up the two a bit.

So now I'm thinking the best way of mirroring this structure for mixed function spaces would be to have:

  1. MixedFSOptions which is defined in configuration.py and is essentially a glorified dictionary
  2. MixedFSWrapper which is defined in wrappers.py and has the same methods as other Wrapper objects


for field, suboption in self.wrapper.suboptions.items():
if suboption.name == 'embedded_dg':
self.wrapper.subwrappers.update({field: EmbeddedDGWrapper(self, suboption)})
elif suboption.name == "recovered":
self.wrapper.subwrappers.update({field: RecoveryWrapper(self, suboption)})
elif suboption.name == "supg":
raise RuntimeError(
'Time discretisation: suboption SUPG is currently not implemented within MixedOptions')
else:
raise RuntimeError(
f'Time discretisation: suboption wrapper {wrapper_name} not implemented')

else:
raise RuntimeError(
f'Time discretisation: wrapper {self.wrapper_name} not implemented')
self.wrapper_name = options.name
if self.wrapper_name == "embedded_dg":
self.wrapper = EmbeddedDGWrapper(self, options)
elif self.wrapper_name == "recovered":
self.wrapper = RecoveryWrapper(self, options)
elif self.wrapper_name == "supg":
self.wrapper = SUPGWrapper(self, options)
else:
raise RuntimeError(
f'Time discretisation: wrapper {self.wrapper_name} not implemented')
else:
self.wrapper = None
self.wrapper_name = None
Expand Down Expand Up @@ -159,21 +175,48 @@ def setup(self, equation, apply_bcs=True, *active_labels):
# -------------------------------------------------------------------- #

if self.wrapper is not None:
self.wrapper.setup()
self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
new_test = TestFunction(self.wrapper.test_space)
# SUPG has a special wrapper
if self.wrapper_name == "supg":
new_test = self.wrapper.test

# Replace the original test function with the one from the wrapper
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))
if type(self.wrapper) == MixedOptions:

for field, subwrapper in self.wrapper.subwrappers.items():

subwrapper.mixed_options = True

field_idx = equation.field_names.index(field)

# Store the original space of the tracer
subwrapper.tracer_fs = self.equation.spaces[field_idx]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
subwrapper.tracer_fs = self.equation.spaces[field_idx]


self.residual = self.wrapper.label_terms(self.residual)
subwrapper.setup()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
subwrapper.setup()
subwrapper.setup(self.equation.spaces[field_idx])


# Update the function space to that needed by the wrapper
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space

self.wrapper.setup()
self.fs = self.wrapper.function_space
new_test_mixed = TestFunctions(self.fs)

# Replace the original test function with one from the new
# function space defined by the subwrappers
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test_mixed))

else:
self.wrapper.setup()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.wrapper.setup()
self.wrapper.setup(self.fs)

I think this should be self.fs but I'm not sure...? Maybe it is just None

Copy link
Collaborator Author

@ta440 ta440 Feb 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we would want this to be None, as we want self.original_space=None when not using the mixed wrapper.

self.fs = self.wrapper.function_space
if self.solver_parameters is None:
self.solver_parameters = self.wrapper.solver_parameters
new_test = TestFunction(self.wrapper.test_space)
# SUPG has a special wrapper
if self.wrapper_name == "supg":
new_test = self.wrapper.test

# Replace the original test function with the one from the wrapper
self.residual = self.residual.label_map(
all_terms,
map_if_true=replace_test_function(new_test))

self.residual = self.wrapper.label_terms(self.residual)

# -------------------------------------------------------------------- #
# Make boundary conditions
Expand Down
109 changes: 101 additions & 8 deletions gusto/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
from abc import ABCMeta, abstractmethod
from firedrake import (
FunctionSpace, Function, BrokenElement, Projector, Interpolator,
VectorElement, Constant, as_ufl, dot, grad, TestFunction
VectorElement, Constant, as_ufl, dot, grad, TestFunction, MixedFunctionSpace
)
from firedrake.fml import Term
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, SUPGOptions
from gusto.recovery import Recoverer, ReversibleRecoverer
from gusto.labels import transporting_velocity
import ufl

__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper"]
__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper", "MixedOptions"]


class Wrapper(object, metaclass=ABCMeta):
Expand All @@ -33,6 +33,9 @@ def __init__(self, time_discretisation, wrapper_options):
self.time_discretisation = time_discretisation
self.options = wrapper_options
self.solver_parameters = None
self.idx = None
self.mixed_options = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little bit uncertain about having a general Wrapper object know about mixed_options...

It feels like the inheritance of properties could be the wrong way around, as a MixedFSWrapper has several subwrappers, but the subwrappers then need to know that they are subwrapper. Ideally it would be better if they don't know that they are subwrappers!

Looking below at why you've had to do this, it's to set the function space. I'm just wondering if we can find a cleaner way to set original_space? Maybe idx could be an argument to the __init__ or setup method of the base Wrapper class?

self.tracer_fs = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.tracer_fs = None
self.original_space = None


@abstractmethod
def setup(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add original_space as an argument here

Expand Down Expand Up @@ -82,10 +85,14 @@ def setup(self):
assert isinstance(self.options, EmbeddedDGOptions), \
'Embedded DG wrapper can only be used with Embedded DG Options'

original_space = self.time_discretisation.fs
domain = self.time_discretisation.domain
equation = self.time_discretisation.equation

if self.mixed_options:
original_space = self.tracer_fs
else:
original_space = self.time_discretisation.fs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this snippet changes to be:
self.original_space = original_space

# -------------------------------------------------------------------- #
# Set up spaces to be used with wrapper
# -------------------------------------------------------------------- #
Expand All @@ -104,7 +111,10 @@ def setup(self):

self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)
if self.time_discretisation.idx is None:

if self.mixed_options:
self.x_projected = Function(self.tracer_fs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.x_projected = Function(self.tracer_fs)
self.x_projected = Function(self.original_space)

?

elif self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])
Expand Down Expand Up @@ -161,13 +171,19 @@ class RecoveryWrapper(Wrapper):
def setup(self):
"""Sets up function spaces and fields needed for this wrapper."""

print(self.options)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This just needs removing!


assert isinstance(self.options, RecoveryOptions), \
'Embedded DG wrapper can only be used with Recovery Options'
'Recovery wrapper can only be used with Recovery Options'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for correcting this!


original_space = self.time_discretisation.fs
domain = self.time_discretisation.domain
equation = self.time_discretisation.equation

if self.mixed_options:
original_space = self.tracer_fs
else:
original_space = self.time_discretisation.fs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again this snippet changes to

self.original_space = original_space

# -------------------------------------------------------------------- #
# Set up spaces to be used with wrapper
# -------------------------------------------------------------------- #
Expand All @@ -184,10 +200,17 @@ def setup(self):
# Internal variables to be used
# -------------------------------------------------------------------- #

self.x_in_tmp = Function(self.time_discretisation.fs)
if self.mixed_options:
self.x_in_tmp = Function(self.tracer_fs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.x_in_tmp = Function(self.tracer_fs)
self.x_in_tmp = Function(self.original_space)

else:
self.x_in_tmp = Function(self.time_discretisation.fs)

self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)
if self.time_discretisation.idx is None:

if self.mixed_options:
self.x_projected = Function(self.tracer_fs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.x_projected = Function(self.tracer_fs)
self.x_projected = Function(self.original_space)

elif self.time_discretisation.idx is None:
self.x_projected = Function(equation.function_space)
else:
self.x_projected = Function(equation.spaces[self.time_discretisation.idx])
Expand Down Expand Up @@ -361,3 +384,73 @@ def label_terms(self, residual):
new_residual = transporting_velocity.update_value(new_residual, self.transporting_velocity)

return new_residual


class MixedOptions(object):
"""
An object to hold a dictionary with different options for different
tracers. This means that different tracers can be solved simultaneously
using a CoupledTransportEquation, whilst being in different spaces
and needing different implementation options.
"""

def __init__(self, equation, suboptions):
"""
Args:
equation (:class: `PrognosticEquationSet`): the prognostic equation(s)
suboptions (dict): A dictionary holding options defined for individual prognostic variables
Raises:
ValueError: If an option is defined for a field that is not in the prognostic variable set
"""
self.wrapper_spaces = equation.spaces
self.field_names = equation.field_names
self.suboptions = suboptions
self.subwrappers = {}

for field, suboption in suboptions.items():
# Check that the field is in the prognostic variable set:
if field not in equation.field_names:
raise ValueError(f"The limiter defined for {field} is for a field that does not exist in the equation set")

def setup(self):
""" Compute the new mixed function space from the subwrappers """

self.function_space = MixedFunctionSpace(self.wrapper_spaces)
self.x_in = Function(self.function_space)
self.x_out = Function(self.function_space)

def pre_apply(self, x_in):
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method looks exactly right to me

Perform the pre-applications for all fields
with an associated subwrapper.
"""

for field_name in self.field_names:
field_idx = self.field_names.index(field_name)
field = x_in.subfunctions[field_idx]
x_in_sub = self.x_in.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
subwrapper.pre_apply(field)
x_in_sub.assign(subwrapper.x_in)
else:
x_in_sub.assign(field)

def post_apply(self, x_out):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method looks exactly right to me

"""
Perform the post-applications for all fields
with an associated subwrapper.
"""

for field_name in self.field_names:
field_idx = self.field_names.index(field_name)
field = self.x_out.subfunctions[field_idx]
x_out_sub = x_out.subfunctions[field_idx]

if field_name in self.subwrappers:
subwrapper = self.subwrappers[field_name]
subwrapper.x_out.assign(field)
subwrapper.post_apply(x_out_sub)
else:
x_out_sub.assign(field)
Loading
Loading