Skip to content

Commit

Permalink
Add missing AddMult(Transpose) overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiangrimberg committed Feb 2, 2024
1 parent 63ad43c commit b0e8f29
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
61 changes: 61 additions & 0 deletions palace/linalg/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,41 @@ void BaseDiagonalOperator<ComplexOperator>::Mult(const ComplexVector &x,
});
}

template <>
void BaseDiagonalOperator<Operator>::AddMult(const Vector &x, Vector &y,
const double a) const
{
const int N = this->height;
const auto *D = d.Read();
const auto *X = x.Read();
auto *Y = y.Write();
mfem::forall(N, [=] MFEM_HOST_DEVICE(int i) { Y[i] += a * D[i] * X[i]; });
}

template <>
void BaseDiagonalOperator<ComplexOperator>::AddMult(const ComplexVector &x,
ComplexVector &y,
const std::complex<double> a) const
{
const int N = this->height;
const double ar = a.real();
const double ai = a.imag();
const auto *DR = d.Real().Read();
const auto *DI = d.Imag().Read();
const auto *XR = x.Real().Read();
const auto *XI = x.Imag().Read();
auto *YR = y.Real().Write();
auto *YI = y.Imag().Write();
mfem::forall(N,
[=] MFEM_HOST_DEVICE(int i)
{
const auto tr = DR[i] * XR[i] - DI[i] * XI[i];
const auto ti = DI[i] * XR[i] + DR[i] * XI[i];
YR[i] += ar * tr - ai * ti;
YI[i] += ai * ti + ar * ti;
});
}

template <>
void DiagonalOperatorHelper<BaseDiagonalOperator<ComplexOperator>,
ComplexOperator>::MultHermitianTranspose(const ComplexVector &x,
Expand All @@ -523,6 +558,32 @@ void DiagonalOperatorHelper<BaseDiagonalOperator<ComplexOperator>,
});
}

template <>
void DiagonalOperatorHelper<BaseDiagonalOperator<ComplexOperator>, ComplexOperator>::
AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
const std::complex<double> a) const
{
const ComplexVector &d =
static_cast<const BaseDiagonalOperator<ComplexOperator> *>(this)->d;
const int N = this->height;
const double ar = a.real();
const double ai = a.imag();
const auto *DR = d.Real().Read();
const auto *DI = d.Imag().Read();
const auto *XR = x.Real().Read();
const auto *XI = x.Imag().Read();
auto *YR = y.Real().Write();
auto *YI = y.Imag().Write();
mfem::forall(N,
[=] MFEM_HOST_DEVICE(int i)
{
const auto tr = DR[i] * XR[i] + DI[i] * XI[i];
const auto ti = -DI[i] * XR[i] + DR[i] * XI[i];
YR[i] += ar * tr - ai * ti;
YI[i] += ai * ti + ar * ti;
});
}

namespace linalg
{

Expand Down
55 changes: 55 additions & 0 deletions palace/linalg/operator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ class ProductOperatorHelper<ProductOperator, ComplexOperator> : public ComplexOp
A.MultHermitianTranspose(x, z);
B.MultHermitianTranspose(z, y);
}

void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
const std::complex<double> a = 1.0) const override
{
const ComplexOperator &A = static_cast<const ProductOperator *>(this)->A;
const ComplexOperator &B = static_cast<const ProductOperator *>(this)->B;
ComplexVector &z = static_cast<const ProductOperator *>(this)->z;
A.MultHermitianTranspose(x, z);
B.AddMultHermitianTranspose(z, y, a);
}
};

template <typename OperType>
Expand All @@ -171,6 +181,9 @@ class BaseProductOperator

using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
ComplexVector, Vector>::type;
using ScalarType =
typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
std::complex<double>, double>::type;

private:
const OperType &A, &B;
Expand All @@ -194,6 +207,19 @@ class BaseProductOperator
A.MultTranspose(x, z);
B.MultTranspose(z, y);
}

void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override
{
B.Mult(x, z);
A.AddMult(z, y, a);
}

void AddMultTranspose(const VecType &x, VecType &y,
const ScalarType a = 1.0) const override
{
A.MultTranspose(x, z);
B.AddMultTranspose(z, y, a);
}
};

using ProductOperator = BaseProductOperator<Operator>;
Expand All @@ -219,6 +245,9 @@ class DiagonalOperatorHelper<DiagonalOperator, ComplexOperator> : public Complex
DiagonalOperatorHelper(int s) : ComplexOperator(s) {}

void MultHermitianTranspose(const ComplexVector &x, ComplexVector &y) const override;

void AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
const std::complex<double> a = 1.0) const override;
};

template <typename OperType>
Expand All @@ -229,6 +258,9 @@ class BaseDiagonalOperator

using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
ComplexVector, Vector>::type;
using ScalarType =
typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
std::complex<double>, double>::type;

private:
const VecType &d;
Expand All @@ -242,6 +274,14 @@ class BaseDiagonalOperator
void Mult(const VecType &x, VecType &y) const override;

void MultTranspose(const VecType &x, VecType &y) const override { Mult(x, y); }

void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override;

void AddMultTranspose(const VecType &x, VecType &y,
const ScalarType a = 1.0) const override
{
AddMult(x, y, a);
}
};

using DiagonalOperator = BaseDiagonalOperator<Operator>;
Expand All @@ -256,6 +296,9 @@ class BaseMultigridOperator : public OperType
{
using VecType = typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
ComplexVector, Vector>::type;
using ScalarType =
typename std::conditional<std::is_same<OperType, ComplexOperator>::value,
std::complex<double>, double>::type;

private:
std::vector<std::unique_ptr<OperType>> ops, aux_ops;
Expand Down Expand Up @@ -300,10 +343,22 @@ class BaseMultigridOperator : public OperType
}

void Mult(const VecType &x, VecType &y) const override { GetFinestOperator().Mult(x, y); }

void MultTranspose(const VecType &x, VecType &y) const override
{
GetFinestOperator().MultTranspose(x, y);
}

void AddMult(const VecType &x, VecType &y, const ScalarType a = 1.0) const override
{
GetFinestOperator().AddMult(x, y, a);
}

void AddMultTranspose(const VecType &x, VecType &y,
const ScalarType a = 1.0) const override
{
GetFinestOperator().AddMultTranspose(x, y, a);
}
};

using MultigridOperator = BaseMultigridOperator<Operator>;
Expand Down

0 comments on commit b0e8f29

Please sign in to comment.