-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Wrapper of options #471
Wrapper of options #471
Changes from 17 commits
573ad52
e04c73a
4a13df3
6165057
775d3b2
793768e
5f05367
5c13db1
806a49e
2be1354
6d2fc9b
25ed243
f410d2c
d166a22
aa3d7c0
f1e958a
51f6f7b
55e0856
164041c
edcc929
b317f39
0700381
2f517e1
5c2980d
222c396
66ce750
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -10,7 +10,7 @@ | |||||
import numpy as np | ||||||
|
||||||
from firedrake import ( | ||||||
Function, TestFunction, NonlinearVariationalProblem, | ||||||
Function, TestFunction, TestFunctions, NonlinearVariationalProblem, | ||||||
NonlinearVariationalSolver, DirichletBC, split, Constant | ||||||
) | ||||||
from firedrake.fml import ( | ||||||
|
@@ -87,16 +87,32 @@ def __init__(self, domain, field_name=None, solver_parameters=None, | |||||
self.courant_max = None | ||||||
|
||||||
if options is not None: | ||||||
self.wrapper_name = options.name | ||||||
if self.wrapper_name == "embedded_dg": | ||||||
self.wrapper = EmbeddedDGWrapper(self, options) | ||||||
elif self.wrapper_name == "recovered": | ||||||
self.wrapper = RecoveryWrapper(self, options) | ||||||
elif self.wrapper_name == "supg": | ||||||
self.wrapper = SUPGWrapper(self, options) | ||||||
if type(options) == MixedOptions: | ||||||
self.wrapper = options | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not that what you've done isn't a sensible way of implementing this, but it's occurred to me that it isn't quite consistent with how the existing wrappers work. At the moment we have two broad types of objects: So now I'm thinking the best way of mirroring this structure for mixed function spaces would be to have:
|
||||||
|
||||||
for field, suboption in self.wrapper.suboptions.items(): | ||||||
if suboption.name == 'embedded_dg': | ||||||
self.wrapper.subwrappers.update({field: EmbeddedDGWrapper(self, suboption)}) | ||||||
elif suboption.name == "recovered": | ||||||
self.wrapper.subwrappers.update({field: RecoveryWrapper(self, suboption)}) | ||||||
elif suboption.name == "supg": | ||||||
raise RuntimeError( | ||||||
'Time discretisation: suboption SUPG is currently not implemented within MixedOptions') | ||||||
else: | ||||||
raise RuntimeError( | ||||||
f'Time discretisation: suboption wrapper {wrapper_name} not implemented') | ||||||
|
||||||
else: | ||||||
raise RuntimeError( | ||||||
f'Time discretisation: wrapper {self.wrapper_name} not implemented') | ||||||
self.wrapper_name = options.name | ||||||
if self.wrapper_name == "embedded_dg": | ||||||
self.wrapper = EmbeddedDGWrapper(self, options) | ||||||
elif self.wrapper_name == "recovered": | ||||||
self.wrapper = RecoveryWrapper(self, options) | ||||||
elif self.wrapper_name == "supg": | ||||||
self.wrapper = SUPGWrapper(self, options) | ||||||
else: | ||||||
raise RuntimeError( | ||||||
f'Time discretisation: wrapper {self.wrapper_name} not implemented') | ||||||
else: | ||||||
self.wrapper = None | ||||||
self.wrapper_name = None | ||||||
|
@@ -159,21 +175,48 @@ def setup(self, equation, apply_bcs=True, *active_labels): | |||||
# -------------------------------------------------------------------- # | ||||||
|
||||||
if self.wrapper is not None: | ||||||
self.wrapper.setup() | ||||||
self.fs = self.wrapper.function_space | ||||||
if self.solver_parameters is None: | ||||||
self.solver_parameters = self.wrapper.solver_parameters | ||||||
new_test = TestFunction(self.wrapper.test_space) | ||||||
# SUPG has a special wrapper | ||||||
if self.wrapper_name == "supg": | ||||||
new_test = self.wrapper.test | ||||||
|
||||||
# Replace the original test function with the one from the wrapper | ||||||
self.residual = self.residual.label_map( | ||||||
all_terms, | ||||||
map_if_true=replace_test_function(new_test)) | ||||||
if type(self.wrapper) == MixedOptions: | ||||||
|
||||||
for field, subwrapper in self.wrapper.subwrappers.items(): | ||||||
|
||||||
subwrapper.mixed_options = True | ||||||
|
||||||
field_idx = equation.field_names.index(field) | ||||||
|
||||||
# Store the original space of the tracer | ||||||
subwrapper.tracer_fs = self.equation.spaces[field_idx] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
self.residual = self.wrapper.label_terms(self.residual) | ||||||
subwrapper.setup() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
# Update the function space to that needed by the wrapper | ||||||
self.wrapper.wrapper_spaces[field_idx] = subwrapper.function_space | ||||||
|
||||||
self.wrapper.setup() | ||||||
self.fs = self.wrapper.function_space | ||||||
new_test_mixed = TestFunctions(self.fs) | ||||||
|
||||||
# Replace the original test function with one from the new | ||||||
# function space defined by the subwrappers | ||||||
self.residual = self.residual.label_map( | ||||||
all_terms, | ||||||
map_if_true=replace_test_function(new_test_mixed)) | ||||||
|
||||||
else: | ||||||
self.wrapper.setup() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think this should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we would want this to be |
||||||
self.fs = self.wrapper.function_space | ||||||
if self.solver_parameters is None: | ||||||
self.solver_parameters = self.wrapper.solver_parameters | ||||||
new_test = TestFunction(self.wrapper.test_space) | ||||||
# SUPG has a special wrapper | ||||||
if self.wrapper_name == "supg": | ||||||
new_test = self.wrapper.test | ||||||
|
||||||
# Replace the original test function with the one from the wrapper | ||||||
self.residual = self.residual.label_map( | ||||||
all_terms, | ||||||
map_if_true=replace_test_function(new_test)) | ||||||
|
||||||
self.residual = self.wrapper.label_terms(self.residual) | ||||||
|
||||||
# -------------------------------------------------------------------- # | ||||||
# Make boundary conditions | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7,15 +7,15 @@ | |||||
from abc import ABCMeta, abstractmethod | ||||||
from firedrake import ( | ||||||
FunctionSpace, Function, BrokenElement, Projector, Interpolator, | ||||||
VectorElement, Constant, as_ufl, dot, grad, TestFunction | ||||||
VectorElement, Constant, as_ufl, dot, grad, TestFunction, MixedFunctionSpace | ||||||
) | ||||||
from firedrake.fml import Term | ||||||
from gusto.configuration import EmbeddedDGOptions, RecoveryOptions, SUPGOptions | ||||||
from gusto.recovery import Recoverer, ReversibleRecoverer | ||||||
from gusto.labels import transporting_velocity | ||||||
import ufl | ||||||
|
||||||
__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper"] | ||||||
__all__ = ["EmbeddedDGWrapper", "RecoveryWrapper", "SUPGWrapper", "MixedOptions"] | ||||||
|
||||||
|
||||||
class Wrapper(object, metaclass=ABCMeta): | ||||||
|
@@ -33,6 +33,9 @@ def __init__(self, time_discretisation, wrapper_options): | |||||
self.time_discretisation = time_discretisation | ||||||
self.options = wrapper_options | ||||||
self.solver_parameters = None | ||||||
self.idx = None | ||||||
self.mixed_options = False | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a little bit uncertain about having a general It feels like the inheritance of properties could be the wrong way around, as a Looking below at why you've had to do this, it's to set the function space. I'm just wondering if we can find a cleaner way to set |
||||||
self.tracer_fs = None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
@abstractmethod | ||||||
def setup(self): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add |
||||||
|
@@ -82,10 +85,14 @@ def setup(self): | |||||
assert isinstance(self.options, EmbeddedDGOptions), \ | ||||||
'Embedded DG wrapper can only be used with Embedded DG Options' | ||||||
|
||||||
original_space = self.time_discretisation.fs | ||||||
domain = self.time_discretisation.domain | ||||||
equation = self.time_discretisation.equation | ||||||
|
||||||
if self.mixed_options: | ||||||
original_space = self.tracer_fs | ||||||
else: | ||||||
original_space = self.time_discretisation.fs | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this snippet changes to be: |
||||||
# -------------------------------------------------------------------- # | ||||||
# Set up spaces to be used with wrapper | ||||||
# -------------------------------------------------------------------- # | ||||||
|
@@ -104,7 +111,10 @@ def setup(self): | |||||
|
||||||
self.x_in = Function(self.function_space) | ||||||
self.x_out = Function(self.function_space) | ||||||
if self.time_discretisation.idx is None: | ||||||
|
||||||
if self.mixed_options: | ||||||
self.x_projected = Function(self.tracer_fs) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
? |
||||||
elif self.time_discretisation.idx is None: | ||||||
self.x_projected = Function(equation.function_space) | ||||||
else: | ||||||
self.x_projected = Function(equation.spaces[self.time_discretisation.idx]) | ||||||
|
@@ -161,13 +171,19 @@ class RecoveryWrapper(Wrapper): | |||||
def setup(self): | ||||||
"""Sets up function spaces and fields needed for this wrapper.""" | ||||||
|
||||||
print(self.options) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This just needs removing! |
||||||
|
||||||
assert isinstance(self.options, RecoveryOptions), \ | ||||||
'Embedded DG wrapper can only be used with Recovery Options' | ||||||
'Recovery wrapper can only be used with Recovery Options' | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for correcting this! |
||||||
|
||||||
original_space = self.time_discretisation.fs | ||||||
domain = self.time_discretisation.domain | ||||||
equation = self.time_discretisation.equation | ||||||
|
||||||
if self.mixed_options: | ||||||
original_space = self.tracer_fs | ||||||
else: | ||||||
original_space = self.time_discretisation.fs | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again this snippet changes to
|
||||||
# -------------------------------------------------------------------- # | ||||||
# Set up spaces to be used with wrapper | ||||||
# -------------------------------------------------------------------- # | ||||||
|
@@ -184,10 +200,17 @@ def setup(self): | |||||
# Internal variables to be used | ||||||
# -------------------------------------------------------------------- # | ||||||
|
||||||
self.x_in_tmp = Function(self.time_discretisation.fs) | ||||||
if self.mixed_options: | ||||||
self.x_in_tmp = Function(self.tracer_fs) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
else: | ||||||
self.x_in_tmp = Function(self.time_discretisation.fs) | ||||||
|
||||||
self.x_in = Function(self.function_space) | ||||||
self.x_out = Function(self.function_space) | ||||||
if self.time_discretisation.idx is None: | ||||||
|
||||||
if self.mixed_options: | ||||||
self.x_projected = Function(self.tracer_fs) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
elif self.time_discretisation.idx is None: | ||||||
self.x_projected = Function(equation.function_space) | ||||||
else: | ||||||
self.x_projected = Function(equation.spaces[self.time_discretisation.idx]) | ||||||
|
@@ -361,3 +384,73 @@ def label_terms(self, residual): | |||||
new_residual = transporting_velocity.update_value(new_residual, self.transporting_velocity) | ||||||
|
||||||
return new_residual | ||||||
|
||||||
|
||||||
class MixedOptions(object): | ||||||
""" | ||||||
An object to hold a dictionary with different options for different | ||||||
tracers. This means that different tracers can be solved simultaneously | ||||||
using a CoupledTransportEquation, whilst being in different spaces | ||||||
and needing different implementation options. | ||||||
""" | ||||||
|
||||||
def __init__(self, equation, suboptions): | ||||||
""" | ||||||
Args: | ||||||
equation (:class: `PrognosticEquationSet`): the prognostic equation(s) | ||||||
suboptions (dict): A dictionary holding options defined for individual prognostic variables | ||||||
Raises: | ||||||
ValueError: If an option is defined for a field that is not in the prognostic variable set | ||||||
""" | ||||||
self.wrapper_spaces = equation.spaces | ||||||
self.field_names = equation.field_names | ||||||
self.suboptions = suboptions | ||||||
self.subwrappers = {} | ||||||
|
||||||
for field, suboption in suboptions.items(): | ||||||
# Check that the field is in the prognostic variable set: | ||||||
if field not in equation.field_names: | ||||||
raise ValueError(f"The limiter defined for {field} is for a field that does not exist in the equation set") | ||||||
|
||||||
def setup(self): | ||||||
""" Compute the new mixed function space from the subwrappers """ | ||||||
|
||||||
self.function_space = MixedFunctionSpace(self.wrapper_spaces) | ||||||
self.x_in = Function(self.function_space) | ||||||
self.x_out = Function(self.function_space) | ||||||
|
||||||
def pre_apply(self, x_in): | ||||||
""" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method looks exactly right to me |
||||||
Perform the pre-applications for all fields | ||||||
with an associated subwrapper. | ||||||
""" | ||||||
|
||||||
for field_name in self.field_names: | ||||||
field_idx = self.field_names.index(field_name) | ||||||
field = x_in.subfunctions[field_idx] | ||||||
x_in_sub = self.x_in.subfunctions[field_idx] | ||||||
|
||||||
if field_name in self.subwrappers: | ||||||
subwrapper = self.subwrappers[field_name] | ||||||
subwrapper.pre_apply(field) | ||||||
x_in_sub.assign(subwrapper.x_in) | ||||||
else: | ||||||
x_in_sub.assign(field) | ||||||
|
||||||
def post_apply(self, x_out): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method looks exactly right to me |
||||||
""" | ||||||
Perform the post-applications for all fields | ||||||
with an associated subwrapper. | ||||||
""" | ||||||
|
||||||
for field_name in self.field_names: | ||||||
field_idx = self.field_names.index(field_name) | ||||||
field = self.x_out.subfunctions[field_idx] | ||||||
x_out_sub = x_out.subfunctions[field_idx] | ||||||
|
||||||
if field_name in self.subwrappers: | ||||||
subwrapper = self.subwrappers[field_name] | ||||||
subwrapper.x_out.assign(field) | ||||||
subwrapper.post_apply(x_out_sub) | ||||||
else: | ||||||
x_out_sub.assign(field) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think when type checking, it's slightly better to use
is
rather than==