Skip to content

Commit

Permalink
Workspace allocation for eigenvalue solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiangrimberg committed Feb 19, 2024
1 parent 0bc99ce commit dc8fbe8
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 164 deletions.
96 changes: 47 additions & 49 deletions palace/linalg/arpack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// clang-format on
#include "linalg/divfree.hpp"
#include "utils/communication.hpp"
#include "utils/workspace.hpp"

namespace
{
Expand Down Expand Up @@ -475,13 +476,15 @@ double ArpackEigenvalueSolver::GetError(int i, EigenvalueSolver::ErrorType type)

// Computes, for each of the first num_eig eigenpairs, a scaling factor (stored in xscale)
// and a relative residual (stored in res). Both arrays are reallocated here with num_eig
// entries each.
//
// NOTE(review): this listing is a rendered diff without +/- markers. The statements using
// x1/y1 look like the removed (pre-change) side and the statements using x/r like the
// added (post-change) side -- confirm against the actual commit before reading the loop
// body as a single sequence of statements.
void ArpackEigenvalueSolver::RescaleEigenvectors(int num_eig)
{
// Two complex vectors of size n requested from the workspace helper; x receives the
// current eigenvector and r is handed to the norm/residual routines as their second
// vector argument (presumably scratch storage -- confirm against their definitions).
auto x = workspace::NewVector<ComplexVector>(n);
auto r = workspace::NewVector<ComplexVector>(n);
res = std::make_unique<double[]>(num_eig);
xscale = std::make_unique<double[]>(num_eig);
for (int i = 0; i < num_eig; i++)
{
// Variant using the x1/y1 vectors (presumably the old class-member workspace).
x1.Set(V.get() + i * n, n, false);
xscale.get()[i] = 1.0 / GetEigenvectorNorm(x1, y1);
res.get()[i] = GetResidualNorm(eig.get()[i], x1, y1) / linalg::Norml2(comm, x1);
// Variant using the local workspace vectors: the i-th eigenvector occupies the n entries
// of V starting at offset i * n.
x.Set(V.get() + i * n, n, false);
// Scale factor is the reciprocal of the eigenvector norm.
xscale.get()[i] = 1.0 / GetEigenvectorNorm(x, r);
// Residual norm for eigenvalue eig[i], divided by the l2 norm of the eigenvector over comm.
res.get()[i] = GetResidualNorm(eig.get()[i], x, r) / linalg::Norml2(comm, x);
}
}

Expand Down Expand Up @@ -512,22 +515,15 @@ void ArpackEPSSolver::SetOperators(const ComplexOperator &K, const ComplexOperat
delta = 2.0 / normK;
}
}

// Set up workspace.
x1.SetSize(opK->Height());
y1.SetSize(opK->Height());
z1.SetSize(opK->Height());
x1.UseDevice(true);
y1.UseDevice(true);
z1.UseDevice(true);
n = opK->Height();
}

int ArpackEPSSolver::Solve()
{
// Set some defaults (default maximum iterations from SLEPc).
CheckParameters();
HYPRE_BigInt N = linalg::GlobalSize(comm, z1);
HYPRE_BigInt N = n;
Mpi::GlobalSum(1, &N, comm);
if (ncv > N)
{
ncv = mfem::internal::to_int(N);
Expand Down Expand Up @@ -580,37 +576,42 @@ void ArpackEPSSolver::ApplyOp(const std::complex<double> *px,
// Case 2: Shift-and-invert spectral transformation (opInv = (K - σ M)⁻¹)
// y = (K - σ M)⁻¹ M x .
// The input pointers are always to host memory (ARPACK runs on host).
x1.Set(px, n, false);
auto x = workspace::NewVector<ComplexVector>(n);
auto y = workspace::NewVector<ComplexVector>(n);
auto z = workspace::NewVector<ComplexVector>(n);
x.Set(px, n, false);
if (!sinvert)
{
opK->Mult(x1, z1);
opInv->Mult(z1, y1);
y1 *= 1.0 / gamma;
opK->Mult(x, z);
opInv->Mult(z, y);
y *= 1.0 / gamma;
}
else
{
opM->Mult(x1, z1);
opInv->Mult(z1, y1);
y1 *= gamma;
opM->Mult(x, z);
opInv->Mult(z, y);
y *= gamma;
}
if (opProj)
{
// Mpi::Print(" Before projection: {:e}\n", linalg::Norml2(comm, y1));
opProj->Mult(y1);
// Mpi::Print(" After projection: {:e}\n", linalg::Norml2(comm, y1));
// Mpi::Print(" Before projection: {:e}\n", linalg::Norml2(comm, y));
opProj->Mult(y);
// Mpi::Print(" After projection: {:e}\n", linalg::Norml2(comm, y));
}
y1.Get(py, n, false);
y.Get(py, n, false);
}

// Applies the B operator of the weighted inner product to the n complex entries at px and
// writes the result, multiplied by delta * gamma, to the n entries at py. opB must be set;
// this is checked with MFEM_VERIFY. opB is applied separately to the real and imaginary
// parts of the input.
//
// NOTE(review): this listing is a rendered diff without +/- markers. The statements using
// x1/y1 look like the removed (pre-change) side and the statements using x/y like the
// added (post-change) side -- confirm against the actual commit.
void ArpackEPSSolver::ApplyOpB(const std::complex<double> *px,
std::complex<double> *py) const
{
MFEM_VERIFY(opB, "No B operator for weighted inner product in ARPACK solve!");
// Variant using the x1/y1 vectors (presumably the old class-member workspace).
x1.Set(px, n, false);
opB->Mult(x1.Real(), y1.Real());
opB->Mult(x1.Imag(), y1.Imag());
y1 *= delta * gamma;
y1.Get(py, n, false);
// Variant using vectors of size n requested from the workspace helper for this call.
auto x = workspace::NewVector<ComplexVector>(n);
auto y = workspace::NewVector<ComplexVector>(n);
// Load the input entries from px into x.
x.Set(px, n, false);
// Apply opB to the real and imaginary parts independently.
opB->Mult(x.Real(), y.Real());
opB->Mult(x.Imag(), y.Imag());
y *= delta * gamma;
// Store the scaled result into the output buffer py.
y.Get(py, n, false);
}

double ArpackEPSSolver::GetResidualNorm(std::complex<double> l, const ComplexVector &x,
Expand Down Expand Up @@ -668,27 +669,16 @@ void ArpackPEPSolver::SetOperators(const ComplexOperator &K, const ComplexOperat
delta = 2.0 / (normK + gamma * normC);
}
}

// Set up workspace.
x1.SetSize(opK->Height());
x2.SetSize(opK->Height());
y1.SetSize(opK->Height());
y2.SetSize(opK->Height());
z1.SetSize(opK->Height());
x1.UseDevice(true);
x2.UseDevice(true);
y1.UseDevice(true);
y2.UseDevice(true);
z1.UseDevice(true);
n = opK->Height();
}

int ArpackPEPSolver::Solve()
{
// Set some defaults (from SLEPc ARPACK interface). The problem size is the size of the
// 2x2 block linearized problem.
// 2x2 block linearized problem.
CheckParameters();
HYPRE_BigInt N = linalg::GlobalSize(comm, z1);
HYPRE_BigInt N = n;
Mpi::GlobalSum(1, &N, comm);
if (ncv > 2 * N)
{
ncv = mfem::internal::to_int(2 * N);
Expand Down Expand Up @@ -754,6 +744,11 @@ void ArpackPEPSolver::ApplyOp(const std::complex<double> *px,
// L₀ = [ -K 0 ] L₁ = [ C M ]
// [ 0 M ] , [ M 0 ] .
// The input pointers are always to host memory (ARPACK runs on host).
auto x1 = workspace::NewVector<ComplexVector>(n);
auto x2 = workspace::NewVector<ComplexVector>(n);
auto y1 = workspace::NewVector<ComplexVector>(n);
auto y2 = workspace::NewVector<ComplexVector>(n);
auto z = workspace::NewVector<ComplexVector>(n);
x1.Set(px, n, false);
x2.Set(px + n, n, false);
if (!sinvert)
Expand All @@ -765,10 +760,9 @@ void ArpackPEPSolver::ApplyOp(const std::complex<double> *px,
opProj->Mult(y1);
// Mpi::Print(" Before projection: {:e}\n", linalg::Norml2(comm, y1));
}

opK->Mult(x1, z1);
opC->AddMult(x2, z1, std::complex<double>(gamma, 0.0));
opInv->Mult(z1, y2);
opK->Mult(x1, z);
opC->AddMult(x2, z, std::complex<double>(gamma, 0.0));
opInv->Mult(z, y2);
y2 *= -1.0 / (gamma * gamma);
if (opProj)
{
Expand All @@ -780,9 +774,9 @@ void ArpackPEPSolver::ApplyOp(const std::complex<double> *px,
else
{
y2.AXPBYPCZ(sigma, x1, gamma, x2, 0.0); // Just temporarily
opM->Mult(y2, z1);
opC->AddMult(x1, z1, std::complex<double>(1.0, 0.0));
opInv->Mult(z1, y1);
opM->Mult(y2, z);
opC->AddMult(x1, z, std::complex<double>(1.0, 0.0));
opInv->Mult(z, y1);
y1 *= -gamma;
if (opProj)
{
Expand All @@ -807,6 +801,10 @@ void ArpackPEPSolver::ApplyOpB(const std::complex<double> *px,
std::complex<double> *py) const
{
MFEM_VERIFY(opB, "No B operator for weighted inner product in ARPACK solve!");
auto x1 = workspace::NewVector<ComplexVector>(n);
auto x2 = workspace::NewVector<ComplexVector>(n);
auto y1 = workspace::NewVector<ComplexVector>(n);
auto y2 = workspace::NewVector<ComplexVector>(n);
x1.Set(px, n, false);
x2.Set(px + n, n, false);
opB->Mult(x1.Real(), y1.Real());
Expand Down
6 changes: 0 additions & 6 deletions palace/linalg/arpack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ class ArpackEigenvalueSolver : public EigenvalueSolver
// which case identity is used.
const Operator *opB;

// Workspace vector for operator applications.
mutable ComplexVector x1, y1, z1;

// Perform the ARPACK RCI loop.
int SolveInternal(int n, std::complex<double> *r, std::complex<double> *V,
std::complex<double> *eig, int *perm);
Expand Down Expand Up @@ -218,9 +215,6 @@ class ArpackPEPSolver : public ArpackEigenvalueSolver
// Operator norms for scaling.
mutable double normK, normC, normM;

// Workspace vectors for operator applications.
mutable ComplexVector x2, y2;

protected:
void ApplyOp(const std::complex<double> *px, std::complex<double> *py) const override;
void ApplyOpB(const std::complex<double> *px, std::complex<double> *py) const override;
Expand Down
Loading

0 comments on commit dc8fbe8

Please sign in to comment.