Skip to content

Commit

Permalink
Parallel prolongation and restriction sometimes do not have appropria…
Browse files Browse the repository at this point in the history
…te AddMult/AddMultTranspose overrides, causing temporary vector allocation

So, fix it by handling this at the ParOperator level.
  • Loading branch information
sebastiangrimberg committed Feb 2, 2024
1 parent b0e8f29 commit d633677
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 94 deletions.
118 changes: 26 additions & 92 deletions palace/linalg/rap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,21 @@ void ParOperator::EliminateRHS(const Vector &x, Vector &b) const
}

MFEM_VERIFY(A, "No local matrix available for ParOperator::EliminateRHS!");
auto &tx = trial_fespace.GetTVector<Vector>();
auto &lx = trial_fespace.GetLVector<Vector>();
auto &ly = GetTestLVector();
tx = 0.0;
linalg::SetSubVector(tx, *dbc_tdof_list, x);
trial_fespace.GetProlongationMatrix()->Mult(tx, lx);
{
auto &tx = trial_fespace.GetTVector<Vector>();
tx = 0.0;
linalg::SetSubVector(tx, *dbc_tdof_list, x);
trial_fespace.GetProlongationMatrix()->Mult(tx, lx);
}

// Apply the unconstrained operator.
A->Mult(lx, ly);
ly *= -1.0;

RestrictionMatrixAddMult(ly, b);
auto &ty = test_fespace.GetTVector<Vector>();
RestrictionMatrixMult(ly, ty);
b.Add(-1.0, ty);
if (diag_policy == DiagonalPolicy::DIAG_ONE)
{
linalg::SetSubVector(b, *dbc_tdof_list, x);
Expand Down Expand Up @@ -292,10 +295,10 @@ void ParOperator::AddMult(const Vector &x, Vector &y, const double a) const
// Apply the operator on the L-vector.
A->Mult(lx, ly);

auto &ty = test_fespace.GetTVector<Vector>();
RestrictionMatrixMult(ly, ty);
if (dbc_tdof_list)
{
auto &ty = test_fespace.GetTVector<Vector>();
RestrictionMatrixMult(ly, ty);
if (diag_policy == DiagonalPolicy::DIAG_ONE)
{
linalg::SetSubVector(ty, *dbc_tdof_list, x);
Expand All @@ -304,16 +307,8 @@ void ParOperator::AddMult(const Vector &x, Vector &y, const double a) const
{
linalg::SetSubVector(ty, *dbc_tdof_list, 0.0);
}
y.Add(a, ty);
}
else
{
if (a != 1.0)
{
ly *= a;
}
RestrictionMatrixAddMult(ly, y);
}
y.Add(a, ty);
}

void ParOperator::AddMultTranspose(const Vector &x, Vector &y, const double a) const
Expand Down Expand Up @@ -343,10 +338,10 @@ void ParOperator::AddMultTranspose(const Vector &x, Vector &y, const double a) c
// Apply the operator on the L-vector.
A->MultTranspose(ly, lx);

auto &tx = trial_fespace.GetTVector<Vector>();
trial_fespace.GetProlongationMatrix()->MultTranspose(lx, tx);
if (dbc_tdof_list)
{
auto &tx = trial_fespace.GetTVector<Vector>();
trial_fespace.GetProlongationMatrix()->MultTranspose(lx, tx);
if (diag_policy == DiagonalPolicy::DIAG_ONE)
{
linalg::SetSubVector(tx, *dbc_tdof_list, x);
Expand All @@ -355,16 +350,8 @@ void ParOperator::AddMultTranspose(const Vector &x, Vector &y, const double a) c
{
linalg::SetSubVector(tx, *dbc_tdof_list, 0.0);
}
y.Add(a, tx);
}
else
{
if (a != 1.0)
{
lx *= a;
}
trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx, y);
}
y.Add(a, tx);
}

void ParOperator::RestrictionMatrixMult(const Vector &ly, Vector &ty) const
Expand All @@ -379,18 +366,6 @@ void ParOperator::RestrictionMatrixMult(const Vector &ly, Vector &ty) const
}
}

void ParOperator::RestrictionMatrixAddMult(const Vector &ly, Vector &ty) const
{
if (!use_R)
{
test_fespace.GetProlongationMatrix()->AddMultTranspose(ly, ty);
}
else
{
test_fespace.GetRestrictionMatrix()->AddMult(ly, ty);
}
}

void ParOperator::RestrictionMatrixMultTranspose(const Vector &ty, Vector &ly) const
{
if (!use_R)
Expand Down Expand Up @@ -623,10 +598,10 @@ void ComplexParOperator::AddMult(const ComplexVector &x, ComplexVector &y,
// Apply the operator on the L-vector.
A->Mult(lx, ly);

auto &ty = test_fespace.GetTVector<ComplexVector>();
RestrictionMatrixMult(ly, ty);
if (dbc_tdof_list)
{
auto &ty = test_fespace.GetTVector<ComplexVector>();
RestrictionMatrixMult(ly, ty);
if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
{
linalg::SetSubVector(ty, *dbc_tdof_list, x);
Expand All @@ -635,16 +610,8 @@ void ComplexParOperator::AddMult(const ComplexVector &x, ComplexVector &y,
{
linalg::SetSubVector(ty, *dbc_tdof_list, 0.0);
}
y.AXPY(a, ty);
}
else
{
if (a != 1.0)
{
ly *= a;
}
RestrictionMatrixAddMult(ly, y);
}
y.AXPY(a, ty);
}

void ComplexParOperator::AddMultTranspose(const ComplexVector &x, ComplexVector &y,
Expand All @@ -669,11 +636,11 @@ void ComplexParOperator::AddMultTranspose(const ComplexVector &x, ComplexVector
// Apply the operator on the L-vector.
A->MultTranspose(ly, lx);

auto &tx = trial_fespace.GetTVector<ComplexVector>();
trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real());
trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag());
if (dbc_tdof_list)
{
auto &tx = trial_fespace.GetTVector<ComplexVector>();
trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real());
trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag());
if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
{
linalg::SetSubVector(tx, *dbc_tdof_list, x);
Expand All @@ -682,17 +649,8 @@ void ComplexParOperator::AddMultTranspose(const ComplexVector &x, ComplexVector
{
linalg::SetSubVector(tx, *dbc_tdof_list, 0.0);
}
y.AXPY(a, tx);
}
else
{
if (a != 1.0)
{
lx *= a;
}
trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx.Real(), y.Real());
trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx.Imag(), y.Imag());
}
y.AXPY(a, tx);
}

void ComplexParOperator::AddMultHermitianTranspose(const ComplexVector &x, ComplexVector &y,
Expand All @@ -717,11 +675,11 @@ void ComplexParOperator::AddMultHermitianTranspose(const ComplexVector &x, Compl
// Apply the operator on the L-vector.
A->MultHermitianTranspose(ly, lx);

auto &tx = trial_fespace.GetTVector<ComplexVector>();
trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real());
trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag());
if (dbc_tdof_list)
{
auto &tx = trial_fespace.GetTVector<ComplexVector>();
trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Real(), tx.Real());
trial_fespace.GetProlongationMatrix()->MultTranspose(lx.Imag(), tx.Imag());
if (diag_policy == Operator::DiagonalPolicy::DIAG_ONE)
{
linalg::SetSubVector(tx, *dbc_tdof_list, x);
Expand All @@ -730,17 +688,8 @@ void ComplexParOperator::AddMultHermitianTranspose(const ComplexVector &x, Compl
{
linalg::SetSubVector(tx, *dbc_tdof_list, 0.0);
}
y.AXPY(a, tx);
}
else
{
if (a != 1.0)
{
lx *= a;
}
trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx.Real(), y.Real());
trial_fespace.GetProlongationMatrix()->AddMultTranspose(lx.Imag(), y.Imag());
}
y.AXPY(a, tx);
}

void ComplexParOperator::RestrictionMatrixMult(const ComplexVector &ly,
Expand All @@ -758,21 +707,6 @@ void ComplexParOperator::RestrictionMatrixMult(const ComplexVector &ly,
}
}

void ComplexParOperator::RestrictionMatrixAddMult(const ComplexVector &ly,
ComplexVector &ty) const
{
if (!use_R)
{
test_fespace.GetProlongationMatrix()->AddMultTranspose(ly.Real(), ty.Real());
test_fespace.GetProlongationMatrix()->AddMultTranspose(ly.Imag(), ty.Imag());
}
else
{
test_fespace.GetRestrictionMatrix()->AddMult(ly.Real(), ty.Real());
test_fespace.GetRestrictionMatrix()->AddMult(ly.Imag(), ty.Imag());
}
}

void ComplexParOperator::RestrictionMatrixMultTranspose(const ComplexVector &ty,
ComplexVector &ly) const
{
Expand Down
2 changes: 0 additions & 2 deletions palace/linalg/rap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class ParOperator : public Operator

// Helper methods for operator application.
void RestrictionMatrixMult(const Vector &ly, Vector &ty) const;
void RestrictionMatrixAddMult(const Vector &ly, Vector &ty) const;
void RestrictionMatrixMultTranspose(const Vector &ty, Vector &ly) const;
Vector &GetTestLVector() const;

Expand Down Expand Up @@ -130,7 +129,6 @@ class ComplexParOperator : public ComplexOperator

// Helper methods for operator application.
void RestrictionMatrixMult(const ComplexVector &ly, ComplexVector &ty) const;
void RestrictionMatrixAddMult(const ComplexVector &ly, ComplexVector &ty) const;
void RestrictionMatrixMultTranspose(const ComplexVector &ty, ComplexVector &ly) const;
ComplexVector &GetTestLVector() const;

Expand Down

0 comments on commit d633677

Please sign in to comment.