Minor tweaks for supporting lower-dimensional fields in DirectProducts
kburns committed Mar 30, 2024
1 parent 1cd4424 commit 191879c
Showing 4 changed files with 20 additions and 9 deletions.

dedalus/core/field.py (2 changes: 1 addition & 1 deletion)

@@ -705,7 +705,7 @@ def local_elements(self):
         # return reduce()

     def load_from_hdf5(self, file, index, task=None):
-        """Load grid data from an hdf5 file. Task correpsonds to field name by default."""
+        """Load grid data from an hdf5 file. Task corresponds to field name by default."""
         if task is None:
             task = self.name
         dset = file['tasks'][task]
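
For context, the method touched above is typically used to restore field data from Dedalus analysis output. A hedged usage sketch (the file path, task name, and field `b` are hypothetical):

    # Hedged usage sketch (file path, task name, and field `b` are hypothetical):
    # restore one write of a saved task into an existing field's grid data.
    import h5py

    with h5py.File("snapshots/snapshots_s1.h5", "r") as file:
        b.load_from_hdf5(file, index=0)                        # task defaults to b.name
        # or explicitly: b.load_from_hdf5(file, index=0, task="buoyancy")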

dedalus/core/operators.py (11 changes: 10 additions & 1 deletion)

@@ -2894,7 +2894,7 @@ def __init__(self, operand, coordsys):
     def operate(self, out):
         """Perform operation."""
         operand = self.args[0]
-        if self.input_basis is None:
+        if hasattr(self.output_basis, "m_maps"):
             basis = self.output_basis
         else:
             basis = self.input_basis

@@ -3406,6 +3406,13 @@ class CartesianDivergence(Divergence):

     cs_type = (coords.CartesianCoordinates, coords.Coordinate)

+    @classmethod
+    def _preprocess_args(cls, operand, index=0, out=None):
+        coordsys = operand.tensorsig[index]
+        if operand.domain.get_basis(coordsys) is None:
+            raise SkipDispatchException(output=0)
+        return [operand], {'index': index, 'out': out}
+
     def __init__(self, operand, index=0, out=None):
         coordsys = operand.tensorsig[index]
         # Wrap to handle gradient wrt single coordinate
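
The `_preprocess_args` hook added here leans on Dedalus's skip-dispatch mechanism, in which preprocessing can hand back a ready-made result (here the scalar 0) instead of constructing the operator. A simplified sketch of that pattern, not the library's exact code:

    # Simplified sketch of the skip-dispatch pattern (illustrative, not the
    # exact Dedalus implementation).
    class SkipDispatchException(Exception):
        """Carry a precomputed output past operator construction."""
        def __init__(self, output):
            super().__init__(output)
            self.output = output

    class OperatorMeta(type):
        """Run _preprocess_args before instantiation; let it short-circuit."""
        def __call__(cls, *args, **kw):
            try:
                args, kw = cls._preprocess_args(*args, **kw)
            except SkipDispatchException as exc:
                return exc.output          # e.g. the scalar 0 when a basis is missing
            return super().__call__(*args, **kw)

    class Div(metaclass=OperatorMeta):
        @classmethod
        def _preprocess_args(cls, operand, index=0, out=None):
            if operand is None:            # stand-in for "no basis along coordsys"
                raise SkipDispatchException(output=0)
            return [operand], {'index': index, 'out': out}
        def __init__(self, operand, index=0, out=None):
            self.operand = operand

    assert Div(None) == 0                  # construction skipped, 0 returned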

@@ -3951,6 +3958,8 @@ def _preprocess_args(cls, operand, coordsys=None, out=None):
             coordsys = operand.dist.single_coordsys
             if coordsys is False:
                 raise ValueError("coordsys must be specified.")
+        elif not isinstance(coordsys, coords.DirectProduct) and operand.domain.get_basis(coordsys) is None:
+            raise SkipDispatchException(output=0)
         return [operand, coordsys], {'out': out}

     @classmethod
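
At the user level, the intent is that vector-calculus operators applied to lower-dimensional operands reduce cleanly instead of erroring. A hedged sketch against the public `dedalus.public` API (names and sizes are illustrative):

    # Hedged sketch using the public Dedalus v3 API (illustrative setup):
    # fields that carry no basis along some coordinate contribute zero there.
    import numpy as np
    import dedalus.public as d3

    coords = d3.CartesianCoordinates('x', 'y')
    dist = d3.Distributor(coords, dtype=np.float64)
    xbasis = d3.RealFourier(coords['x'], size=32, bounds=(0, 1))

    f = dist.Field(name='f', bases=xbasis)     # scalar with a basis only along x
    grad_f = d3.grad(f)                        # the y-part has no basis to act on

    c = dist.VectorField(coords, name='c')     # constant vector: no bases at all
    div_c = d3.div(c)                          # dispatch can skip; divergence is 0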

dedalus/core/problems.py (6 changes: 4 additions & 2 deletions)

@@ -83,12 +83,13 @@ def add_equation(self, equation, condition="True"):
         # Build basic equation dictionary
         # Note: domain determined after NCC reinitialization
         expr = LHS - RHS
-        eqn = {'LHS': LHS,
+        eqn = {'eqn': expr,
+               'LHS': LHS,
                'RHS': RHS,
                'condition': condition,
                'tensorsig': expr.tensorsig,
                'dtype': expr.dtype,
-               'valid_modes': LHS.valid_modes.copy()}
+               'valid_modes': expr.valid_modes.copy()}
         self._check_equation_conditions(eqn)
         self._build_matrix_expressions(eqn)
         self.equations.append(eqn)
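
A hedged illustration of what the equation record now carries (`problem` is a hypothetical problem instance after `add_equation` has been called):

    # Hedged illustration: the full expression LHS - RHS sees the union of
    # bases from both sides, which matters when the LHS alone is a
    # lower-dimensional (e.g. constant) field.
    eq = problem.equations[0]
    print(eq['eqn'])            # full expression LHS - RHS
    print(eq['LHS'], eq['RHS'])
    print(eq['valid_modes'])    # validity mask taken from the full expression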

@@ -227,6 +228,7 @@ def __init__(self, *args, **kw):
         # Build perturbation variables
         self.perturbations = [var.copy() for var in self.variables]
         for pert, var in zip(self.perturbations, self.variables):
+            pert.preset_scales(1)
             pert['c'] = 0
             if var.name:
                 pert.name = 'δ'+var.name
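
The new line pins each perturbation's scales before zeroing its coefficients. A hedged sketch of the same idiom on a standalone field (`dist` and `xbasis` as in the earlier snippet):

    # Hedged sketch: fix the scales to 1 (nominal resolution) before zeroing
    # the coefficient data, mirroring the perturbation setup above.
    pert = dist.Field(name='δb', bases=xbasis)
    pert.preset_scales(1)
    pert['c'] = 0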

dedalus/core/subsystems.py (10 changes: 5 additions & 5 deletions)

@@ -497,7 +497,7 @@ def build_matrices(self, names):
         eqns = self.problem.equations
         vars = self.problem.LHS_variables
         eqn_conditions = [self.check_condition(eqn) for eqn in eqns]  # HACK
-        eqn_sizes = [self.field_size(eqn['LHS']) for eqn in eqns]
+        eqn_sizes = [self.field_size(eqn['eqn']) for eqn in eqns]
         var_sizes = [self.field_size(var) for var in vars]
         I = sum(eqn_sizes)
         J = sum(var_sizes)

@@ -533,7 +533,7 @@ def build_matrices(self, names):
             matrices[name] = sparse.coo_matrix((data, (rows, cols)), shape=(I, J), dtype=dtype).tocsr()

         # Valid modes
-        valid_eqn = [self.valid_modes(eqn['LHS'], eqn['valid_modes']) for eqn in eqns]
+        valid_eqn = [self.valid_modes(eqn['eqn'], eqn['valid_modes']) for eqn in eqns]
         valid_var = [self.valid_modes(var, var.valid_modes) for var in vars]
         # Invalidate equations that fail condition test
         for n, eqn_cond in enumerate(eqn_conditions):

@@ -587,7 +587,7 @@ def build_matrices(self, names):
         eqn_dofs_by_dim = defaultdict(int)
         eqn_pass_cond = [eqn for eqn, cond in zip(eqns, eqn_conditions) if cond]
         for eqn in eqn_pass_cond:
-            eqn_dofs_by_dim[eqn['domain'].dim] += self.field_size(eqn['LHS'])
+            eqn_dofs_by_dim[eqn['domain'].dim] += self.field_size(eqn['eqn'])
         self.update_rank = sum(eqn_dofs_by_dim.values()) - eqn_dofs_by_dim[max(eqn_dofs_by_dim.keys())]

         # Store RHS conversion matrix

@@ -624,8 +624,8 @@ def left_permutation(subproblem, equations, bc_top, interleave_components):
     L0 = []
     for eqn in equations:
         L1 = []
-        vfshape = subproblem.field_shape(eqn['LHS'])
-        rank = len(eqn['LHS'].tensorsig)
+        vfshape = subproblem.field_shape(eqn['eqn'])
+        rank = len(eqn['tensorsig'])
         vfshape = (prod(vfshape[:rank]),) + vfshape[rank:]
         if vfshape[0] == 0:
             L1.append([])
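
The common thread in these hunks is that per-equation sizes, shapes, and valid-mode masks now come from the stored full expression rather than the LHS alone. A toy illustration of why that matters for lower-dimensional LHS fields (illustrative numbers only, not Dedalus code):

    # Toy illustration: if the LHS is a constant (lower-dimensional) field but
    # the RHS varies along a basis, sizing the equation's rows from the LHS
    # undercounts; the full expression reflects the union of bases.
    import numpy as np

    lhs_coeff_shape = (1,)      # constant LHS: a single coefficient
    eqn_coeff_shape = (64,)     # LHS - RHS varies along a 64-mode basis

    rows_from_lhs = int(np.prod(lhs_coeff_shape))   # 1  (too few rows)
    rows_from_eqn = int(np.prod(eqn_coeff_shape))   # 64 (matches the discretized equation)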
