Skip to content

Commit

Permalink
Fix device-host aliasing issue on GPU for wave ports
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastiangrimberg committed Feb 27, 2024
1 parent a6eb633 commit a228960
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 55 deletions.
87 changes: 56 additions & 31 deletions palace/linalg/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,14 +204,20 @@ std::complex<double> ComplexVector::TransposeDot(const ComplexVector &y) const

// In-place complex update of this vector, y += alpha * x, where y is *this. The work is
// forwarded to the static overload, which operates directly on the real and imaginary
// parts; that overload determines the vector size and whether to execute on device from
// its arguments, so nothing needs to be computed here.
void ComplexVector::AXPY(std::complex<double> alpha, const ComplexVector &x)
{
  AXPY(alpha, x.Real(), x.Imag(), Real(), Imag());
}

void ComplexVector::AXPY(std::complex<double> alpha, const Vector &xr, const Vector &xi,
Vector &yr, Vector &yi)
{
const bool use_dev = yr.UseDevice() || xr.UseDevice();
const int N = yr.Size();
const double ar = alpha.real();
const double ai = alpha.imag();
const auto *XR = x.Real().Read(use_dev);
const auto *XI = x.Imag().Read(use_dev);
auto *YR = Real().ReadWrite(use_dev);
auto *YI = Imag().ReadWrite(use_dev);
const auto *XR = xr.Read(use_dev);
const auto *XI = xi.Read(use_dev);
auto *YR = yr.ReadWrite(use_dev);
auto *YI = yi.ReadWrite(use_dev);
if (ai == 0.0)
{
mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { YR[i] += ar * XR[i]; });
Expand All @@ -222,25 +228,32 @@ void ComplexVector::AXPY(std::complex<double> alpha, const ComplexVector &x)
mfem::forall_switch(use_dev, N,
[=] MFEM_HOST_DEVICE(int i)
{
const auto t = ai * XR[i] + ar * XI[i];
YR[i] += ar * XR[i] - ai * XI[i];
YI[i] += ai * XR[i] + ar * XI[i];
YI[i] += t;
});
}
}

// In-place complex update of this vector, y = alpha * x + beta * y, where y is *this. The
// work is forwarded to the static overload, which operates directly on the real and
// imaginary parts; that overload determines the vector size and whether to execute on
// device from its arguments, so nothing needs to be computed here.
void ComplexVector::AXPBY(std::complex<double> alpha, const ComplexVector &x,
                          std::complex<double> beta)
{
  AXPBY(alpha, x.Real(), x.Imag(), beta, Real(), Imag());
}

void ComplexVector::AXPBY(std::complex<double> alpha, const Vector &xr, const Vector &xi,
std::complex<double> beta, Vector &yr, Vector &yi)
{
const bool use_dev = yr.UseDevice() || xr.UseDevice();
const int N = yr.Size();
const double ar = alpha.real();
const double ai = alpha.imag();
const auto *XR = x.Real().Read(use_dev);
const auto *XI = x.Imag().Read(use_dev);
const auto *XR = xr.Read(use_dev);
const auto *XI = xi.Read(use_dev);
if (beta == 0.0)
{
auto *YR = Real().Write(use_dev);
auto *YI = Imag().Write(use_dev);
auto *YR = yr.Write(use_dev);
auto *YI = yi.Write(use_dev);
if (ai == 0.0)
{
mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { YR[i] = ar * XR[i]; });
Expand All @@ -251,17 +264,18 @@ void ComplexVector::AXPBY(std::complex<double> alpha, const ComplexVector &x,
mfem::forall_switch(use_dev, N,
[=] MFEM_HOST_DEVICE(int i)
{
const auto t = ai * XR[i] + ar * XI[i];
YR[i] = ar * XR[i] - ai * XI[i];
YI[i] = ai * XR[i] + ar * XI[i];
YI[i] = t;
});
}
}
else
{
const double br = beta.real();
const double bi = beta.imag();
auto *YR = Real().ReadWrite(use_dev);
auto *YI = Imag().ReadWrite(use_dev);
auto *YR = yr.ReadWrite(use_dev);
auto *YI = yi.ReadWrite(use_dev);
if (ai == 0.0 && bi == 0.0)
{
mfem::forall_switch(use_dev, N,
Expand All @@ -274,9 +288,10 @@ void ComplexVector::AXPBY(std::complex<double> alpha, const ComplexVector &x,
mfem::forall_switch(use_dev, N,
[=] MFEM_HOST_DEVICE(int i)
{
const auto t = bi * YR[i] + br * YI[i];
const auto t =
ai * XR[i] + ar * XI[i] + bi * YR[i] + br * YI[i];
YR[i] = ar * XR[i] - ai * XI[i] + br * YR[i] - bi * YI[i];
YI[i] = ai * XR[i] + ar * XI[i] + t;
YI[i] = t;
});
}
}
Expand All @@ -286,20 +301,27 @@ void ComplexVector::AXPBYPCZ(std::complex<double> alpha, const ComplexVector &x,
std::complex<double> beta, const ComplexVector &y,
std::complex<double> gamma)
{
const bool use_dev = UseDevice() || x.UseDevice() || y.UseDevice();
const int N = Size();
AXPBYPCZ(alpha, x.Real(), x.Imag(), beta, y.Real(), y.Imag(), gamma, Real(), Imag());
}

void ComplexVector::AXPBYPCZ(std::complex<double> alpha, const Vector &xr, const Vector &xi,
std::complex<double> beta, const Vector &yr, const Vector &yi,
std::complex<double> gamma, Vector &zr, Vector &zi)
{
const bool use_dev = zr.UseDevice() || xr.UseDevice() || yr.UseDevice();
const int N = zr.Size();
const double ar = alpha.real();
const double ai = alpha.imag();
const double br = beta.real();
const double bi = beta.imag();
const auto *XR = x.Real().Read(use_dev);
const auto *XI = x.Imag().Read(use_dev);
const auto *YR = y.Real().Read(use_dev);
const auto *YI = y.Imag().Read(use_dev);
const auto *XR = xr.Read(use_dev);
const auto *XI = xi.Read(use_dev);
const auto *YR = yr.Read(use_dev);
const auto *YI = yi.Read(use_dev);
if (gamma == 0.0)
{
auto *ZR = Real().Write(use_dev);
auto *ZI = Imag().Write(use_dev);
auto *ZR = zr.Write(use_dev);
auto *ZI = zi.Write(use_dev);
if (ai == 0.0 && bi == 0.0)
{
mfem::forall_switch(use_dev, N,
Expand All @@ -312,17 +334,19 @@ void ComplexVector::AXPBYPCZ(std::complex<double> alpha, const ComplexVector &x,
mfem::forall_switch(use_dev, N,
[=] MFEM_HOST_DEVICE(int i)
{
const auto t =
ai * XR[i] + ar * XI[i] + bi * YR[i] + br * YI[i];
ZR[i] = ar * XR[i] - ai * XI[i] + br * YR[i] - bi * YI[i];
ZI[i] = ai * XR[i] + ar * XI[i] + bi * YR[i] + br * YI[i];
ZI[i] = t;
});
}
}
else
{
const double gr = gamma.real();
const double gi = gamma.imag();
auto *ZR = Real().ReadWrite(use_dev);
auto *ZI = Imag().ReadWrite(use_dev);
auto *ZR = zr.ReadWrite(use_dev);
auto *ZI = zi.ReadWrite(use_dev);
if (ai == 0.0 && bi == 0.0 && gi == 0.0)
{
mfem::forall_switch(use_dev, N,
Expand All @@ -337,10 +361,11 @@ void ComplexVector::AXPBYPCZ(std::complex<double> alpha, const ComplexVector &x,
mfem::forall_switch(use_dev, N,
[=] MFEM_HOST_DEVICE(int i)
{
const auto t = gi * ZR[i] + gr * ZI[i];
const auto t = ai * XR[i] + ar * XI[i] + bi * YR[i] +
br * YI[i] + gi * ZR[i] + gr * ZI[i];
ZR[i] = ar * XR[i] - ai * XI[i] + br * YR[i] - bi * YI[i] +
gr * ZR[i] - gi * ZI[i];
ZI[i] = ai * XR[i] + ar * XI[i] + bi * YR[i] + br * YI[i] + t;
ZI[i] = t;
});
}
}
Expand Down
10 changes: 10 additions & 0 deletions palace/linalg/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ class ComplexVector
void AXPBYPCZ(std::complex<double> alpha, const ComplexVector &x,
std::complex<double> beta, const ComplexVector &y,
std::complex<double> gamma);

// Static version of AXPY acting on separate real and imaginary parts: computes
// y += alpha * x with x = xr + i xi and y = yr + i yi. The size is taken from yr, and the
// operation runs on device if either yr or xr is flagged for device use.
// NOTE(review): the general-alpha kernel evaluates the imaginary update before modifying
// the real part, which looks intended to allow x and y to share storage; confirm.
static void AXPY(std::complex<double> alpha, const Vector &xr, const Vector &xi,
Vector &yr, Vector &yi);

// Static version of AXPBY acting on separate real and imaginary parts: computes
// y = alpha * x + beta * y with x = xr + i xi and y = yr + i yi. When beta == 0, y is
// accessed write-only (its previous contents are not read). The size is taken from yr, and
// the operation runs on device if either yr or xr is flagged for device use.
// NOTE(review): the wave port code calls this with (xr, xi) and (yr, yi) being the same
// vectors and beta == 0 to scale in place; the imaginary result is computed before the
// real part is overwritten, presumably to support exactly that; confirm.
static void AXPBY(std::complex<double> alpha, const Vector &xr, const Vector &xi,
std::complex<double> beta, Vector &yr, Vector &yi);

// Static version of AXPBYPCZ acting on separate real and imaginary parts: computes
// z = alpha * x + beta * y + gamma * z with x = xr + i xi, y = yr + i yi, z = zr + i zi.
// When gamma == 0, z is accessed write-only (its previous contents are not read). The size
// is taken from zr, and the operation runs on device if any of zr, xr, or yr is flagged
// for device use.
// NOTE(review): the imaginary result is computed before the real part of z is written,
// which looks intended to allow z to share storage with x or y; confirm.
static void AXPBYPCZ(std::complex<double> alpha, const Vector &xr, const Vector &xi,
std::complex<double> beta, const Vector &yr, const Vector &yi,
std::complex<double> gamma, Vector &zr, Vector &zi);
};

namespace linalg
Expand Down
38 changes: 15 additions & 23 deletions palace/models/waveportoperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,26 +294,16 @@ void Normalize(const mfem::ParGridFunction &S0t, mfem::ParComplexGridFunction &E
// functions. We choose a (rather arbitrary) phase constraint to at least make results for
// the same port consistent between frequencies/meshes.

// |E x H⋆| ⋅ n = |E ⋅ (-n x H⋆)|
// |E x H⋆| ⋅ n = |E ⋅ (-n x H⋆)|. This also updates the n x H coefficients depending on
// Et, En. Update linear forms for postprocessing too.
std::complex<double> dot[2] = {
{sr * S0t, si * S0t},
{-(sr * E0t.real()) - (si * E0t.imag()), -(sr * E0t.imag()) + (si * E0t.real())}};
Mpi::GlobalSum(2, dot, S0t.ParFESpace()->GetComm());
auto scale = std::abs(dot[0]) / (dot[0] * std::sqrt(std::abs(dot[1])));

// This also updates the n x H coefficients depending on Et, En.
Vector tmp = E0t.real();
add(scale.real(), E0t.real(), -scale.imag(), E0t.imag(), E0t.real());
add(scale.imag(), tmp, scale.real(), E0t.imag(), E0t.imag());

tmp = E0n.real();
add(scale.real(), E0n.real(), -scale.imag(), E0n.imag(), E0n.real());
add(scale.imag(), tmp, scale.real(), E0n.imag(), E0n.imag());

// Update linear forms for postprocessing too.
tmp = sr;
add(scale.real(), sr, -scale.imag(), si, sr);
add(scale.imag(), tmp, scale.real(), si, si);
ComplexVector::AXPBY(scale, E0t.real(), E0t.imag(), 0.0, E0t.real(), E0t.imag());
ComplexVector::AXPBY(scale, E0n.real(), E0n.imag(), 0.0, E0n.real(), E0n.imag());
ComplexVector::AXPBY(scale, sr, si, 0.0, sr, si);

// This parallel communication is not required since wave port boundaries are true one-
// sided boundaries.
Expand Down Expand Up @@ -596,12 +586,10 @@ WavePortData::WavePortData(const config::WavePortData &data,
}
}

// Create vector for initial space for eigenvalue solves.
// Create vector for initial space for eigenvalue solves and eigenmode solution.
GetInitialSpace(*port_nd_fespace, *port_h1_fespace, port_dbc_tdof_list, v0);
e0.SetSize(port_nd_fespace->GetTrueVSize() + port_h1_fespace->GetTrueVSize());
e0n.SetSize(port_h1_fespace->GetTrueVSize());
e0.UseDevice(true);
e0n.UseDevice(true);

// The operators for the generalized eigenvalue problem are:
// [Aₜₜ Aₜₙ] [eₜ] = -kₙ² [Bₜₜ 0ₜₙ] [eₜ]
Expand Down Expand Up @@ -907,6 +895,8 @@ void WavePortData::Initialize(double omega)
MFEM_ASSERT(e0.Size() == 0,
"Unexpected non-empty port FE space in wave port boundary mode solve!");
}
e0.Real().ReadWrite(); // Ensure memory is allocated on device before aliasing
e0.Imag().ReadWrite();
Vector e0tr(e0.Real(), 0, port_nd_fespace->GetTrueVSize());
Vector e0nr(e0.Real(), port_nd_fespace->GetTrueVSize(),
port_h1_fespace->GetTrueVSize());
Expand All @@ -917,13 +907,15 @@ void WavePortData::Initialize(double omega)
e0nr.UseDevice(true);
e0ti.UseDevice(true);
e0ni.UseDevice(true);
e0n.Real() = e0nr;
e0n.Imag() = e0ni;
e0n *= 1.0 / (1i * kn0);
ComplexVector::AXPBY(1.0 / (1i * kn0), e0nr, e0ni, 0.0, e0nr, e0ni);
port_E0t->real().SetFromTrueDofs(e0tr); // Parallel distribute
port_E0t->imag().SetFromTrueDofs(e0ti);
port_E0n->real().SetFromTrueDofs(e0n.Real());
port_E0n->imag().SetFromTrueDofs(e0n.Imag());
port_E0n->real().SetFromTrueDofs(e0nr);
port_E0n->imag().SetFromTrueDofs(e0ni);
port_E0t->real().HostRead(); // Read on host for linear form assembly
port_E0t->imag().HostRead();
port_E0n->real().HostRead();
port_E0n->imag().HostRead();
}

// Configure the linear forms for computing S-parameters (projection of the field onto the
Expand Down
2 changes: 1 addition & 1 deletion palace/models/waveportoperator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class WavePortData

// Operator storage for repeated boundary mode eigenvalue problem solves.
std::unique_ptr<mfem::HypreParMatrix> Atnr, Atni, Antr, Anti, Annr, Anni, Br, Bi;
ComplexVector v0, e0, e0n;
ComplexVector v0, e0;

// Eigenvalue solver for boundary modes.
MPI_Comm port_comm;
Expand Down

0 comments on commit a228960

Please sign in to comment.