Skip to content

Commit

Permalink
Modified CartesianTrace to work for tensors with rank >2. Added rank …
Browse files Browse the repository at this point in the history
…3 tests to test_cartesian_operators.py which pass with the changes to operators.py, but did not pass before. (#292)
  • Loading branch information
lecoanet authored Jun 5, 2024
1 parent 70bd15e commit 961ed0b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
7 changes: 5 additions & 2 deletions dedalus/core/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1829,8 +1829,11 @@ class CartesianTrace(Trace):
def subproblem_matrix(self, subproblem):
dim = self.coordsys.dim
trace = np.ravel(np.eye(dim))
# Assume all components have the same n_size
eye = sparse.identity(subproblem.coeff_size(self.domain), self.dtype, format='csr')
# Kronecker up identity for remaining tensor components
n_eye = prod(cs.dim for cs in self.tensorsig)
# Kronecker up identity for coeff size
n_eye *= subproblem.coeff_size(self.domain)
eye = sparse.identity(n_eye, self.dtype, format='csr')
matrix = sparse.kron(trace, eye)
return matrix

Expand Down
40 changes: 40 additions & 0 deletions dedalus/tests/test_cartesian_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,23 @@ def test_trace_explicit(basis, N, dealias, dtype, layout):
assert np.allclose(g[layout], np.trace(f[layout]))


@pytest.mark.parametrize('basis', [build_FF, build_FC, build_CC, build_FFF, build_FFC])
@pytest.mark.parametrize('N', N_range)
@pytest.mark.parametrize('dealias', dealias_range)
@pytest.mark.parametrize('dtype', dtype_range)
@pytest.mark.parametrize('layout', ['c', 'g'])
def test_trace_rank3_explicit(basis, N, dealias, dtype, layout):
"""Test explicit evaluation of trace operator for correctness."""
c, d, b, r = basis(N, dealias, dtype)
# Random tensor field
f = d.TensorField((c,c,c), bases=b)
f.fill_random(layout='g')
# Evaluate trace
f.change_layout(layout)
g = d3.trace(f).evaluate()
assert np.allclose(g[layout], np.trace(f[layout]))


@pytest.mark.parametrize('basis', [build_FF, build_FC, build_CC, build_FFF, build_FFC])
@pytest.mark.parametrize('N', N_range)
@pytest.mark.parametrize('dealias', dealias_range)
Expand All @@ -170,6 +187,29 @@ def test_trace_implicit(basis, N, dealias, dtype):
assert np.allclose(u['c'], f['c'])


@pytest.mark.parametrize('basis', [build_FF, build_FC, build_CC, build_FFF, build_FFC])
@pytest.mark.parametrize('N', N_range)
@pytest.mark.parametrize('dealias', dealias_range)
@pytest.mark.parametrize('dtype', dtype_range)
def test_trace_rank3_implicit(basis, N, dealias, dtype):
"""Test implicit evaluation of trace operator for correctness."""
c, d, b, r = basis(N, dealias, dtype)
# Random scalar field
f = d.VectorField(c, bases=b)
f.fill_random(layout='g')
# Trace LBVP
u = d.VectorField(c, bases=b)
I = d.TensorField((c,c))
dim = len(r)
for i in range(dim):
I['g'][i,i] = 1
problem = d3.LBVP([u], namespace=locals())
problem.add_equation("trace(I*u) = dim*f")
solver = problem.build_solver()
solver.solve()
assert np.allclose(u['c'], f['c'])


@pytest.mark.parametrize('basis', [build_FF, build_FC, build_CC, build_FFF, build_FFC])
@pytest.mark.parametrize('N', N_range)
@pytest.mark.parametrize('dealias', dealias_range)
Expand Down

0 comments on commit 961ed0b

Please sign in to comment.