From b0e8f29adf2265ea8a9c32ded86546084e887524 Mon Sep 17 00:00:00 2001 From: Sebastian Grimberg Date: Tue, 30 Jan 2024 17:51:51 -0800 Subject: [PATCH] Add missing AddMult(Transpose) overrides --- palace/linalg/operator.cpp | 61 ++++++++++++++++++++++++++++++++++++++ palace/linalg/operator.hpp | 55 ++++++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) diff --git a/palace/linalg/operator.cpp b/palace/linalg/operator.cpp index 461f4b2f6..301bdf720 100644 --- a/palace/linalg/operator.cpp +++ b/palace/linalg/operator.cpp @@ -501,6 +501,41 @@ void BaseDiagonalOperator::Mult(const ComplexVector &x, }); } +template <> +void BaseDiagonalOperator::AddMult(const Vector &x, Vector &y, + const double a) const +{ + const int N = this->height; + const auto *D = d.Read(); + const auto *X = x.Read(); + auto *Y = y.Write(); + mfem::forall(N, [=] MFEM_HOST_DEVICE(int i) { Y[i] += a * D[i] * X[i]; }); +} + +template <> +void BaseDiagonalOperator::AddMult(const ComplexVector &x, + ComplexVector &y, + const std::complex a) const +{ + const int N = this->height; + const double ar = a.real(); + const double ai = a.imag(); + const auto *DR = d.Real().Read(); + const auto *DI = d.Imag().Read(); + const auto *XR = x.Real().Read(); + const auto *XI = x.Imag().Read(); + auto *YR = y.Real().Write(); + auto *YI = y.Imag().Write(); + mfem::forall(N, + [=] MFEM_HOST_DEVICE(int i) + { + const auto tr = DR[i] * XR[i] - DI[i] * XI[i]; + const auto ti = DI[i] * XR[i] + DR[i] * XI[i]; + YR[i] += ar * tr - ai * ti; + YI[i] += ai * ti + ar * ti; + }); +} + template <> void DiagonalOperatorHelper, ComplexOperator>::MultHermitianTranspose(const ComplexVector &x, @@ -523,6 +558,32 @@ void DiagonalOperatorHelper, }); } +template <> +void DiagonalOperatorHelper, ComplexOperator>:: + AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y, + const std::complex a) const +{ + const ComplexVector &d = + static_cast *>(this)->d; + const int N = this->height; + const double ar = a.real(); + const double ai = a.imag(); + const auto *DR = d.Real().Read(); + const auto *DI = d.Imag().Read(); + const auto *XR = x.Real().Read(); + const auto *XI = x.Imag().Read(); + auto *YR = y.Real().Write(); + auto *YI = y.Imag().Write(); + mfem::forall(N, + [=] MFEM_HOST_DEVICE(int i) + { + const auto tr = DR[i] * XR[i] + DI[i] * XI[i]; + const auto ti = -DI[i] * XR[i] + DR[i] * XI[i]; + YR[i] += ar * tr - ai * ti; + YI[i] += ai * ti + ar * ti; + }); +} + namespace linalg { diff --git a/palace/linalg/operator.hpp b/palace/linalg/operator.hpp index ab204c076..89f3890b9 100644 --- a/palace/linalg/operator.hpp +++ b/palace/linalg/operator.hpp @@ -161,6 +161,16 @@ class ProductOperatorHelper : public ComplexOp A.MultHermitianTranspose(x, z); B.MultHermitianTranspose(z, y); } + + void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y, + const std::complex a = 1.0) const override + { + const ComplexOperator &A = static_cast(this)->A; + const ComplexOperator &B = static_cast(this)->B; + ComplexVector &z = static_cast(this)->z; + A.MultHermitianTranspose(x, z); + B.AddMultHermitianTranspose(z, y, a); + } }; template @@ -171,6 +181,9 @@ class BaseProductOperator using VecType = typename std::conditional::value, ComplexVector, Vector>::type; + using ScalarType = + typename std::conditional::value, + std::complex, double>::type; private: const OperType &A, &B; @@ -194,6 +207,19 @@ class BaseProductOperator A.MultTranspose(x, z); B.MultTranspose(z, y); } + + void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override + { + B.Mult(x, z); + A.AddMult(z, y, a); + } + + void AddMultTranspose(const VecType &x, VecType &y, + const ScalarType a = 1.0) const override + { + A.MultTranspose(x, z); + B.AddMultTranspose(z, y, a); + } }; using ProductOperator = BaseProductOperator; @@ -219,6 +245,9 @@ class DiagonalOperatorHelper : public Complex DiagonalOperatorHelper(int s) : ComplexOperator(s) {} void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override; + + void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y, + const std::complex a = 1.0) const override; }; template @@ -229,6 +258,9 @@ class BaseDiagonalOperator using VecType = typename std::conditional::value, ComplexVector, Vector>::type; + using ScalarType = + typename std::conditional::value, + std::complex, double>::type; private: const VecType &d; @@ -242,6 +274,14 @@ class BaseDiagonalOperator void Mult(const VecType &x, VecType &y) const override; void MultTranspose(const VecType &x, VecType &y) const override { Mult(x, y); } + + void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override; + + void AddMultTranspose(const VecType &x, VecType &y, + const ScalarType a = 1.0) const override + { + AddMult(x, y, a); + } }; using DiagonalOperator = BaseDiagonalOperator; @@ -256,6 +296,9 @@ class BaseMultigridOperator : public OperType { using VecType = typename std::conditional::value, ComplexVector, Vector>::type; + using ScalarType = + typename std::conditional::value, + std::complex, double>::type; private: std::vector> ops, aux_ops; @@ -300,10 +343,22 @@ class BaseMultigridOperator : public OperType } void Mult(const VecType &x, VecType &y) const override { GetFinestOperator().Mult(x, y); } + void MultTranspose(const VecType &x, VecType &y) const override { GetFinestOperator().MultTranspose(x, y); } + + void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override + { + GetFinestOperator().AddMult(x, y, a); + } + + void AddMultTranspose(const VecType &x, VecType &y, + const ScalarType a = 1.0) const override + { + GetFinestOperator().AddMultTranspose(x, y, a); + } }; using MultigridOperator = BaseMultigridOperator;