From a228960d99387daa15088a32901c13c3a2fe592f Mon Sep 17 00:00:00 2001 From: Sebastian Grimberg Date: Tue, 27 Feb 2024 14:15:16 -0800 Subject: [PATCH] Fix device-host aliasing issue on GPU for wave ports --- palace/linalg/vector.cpp | 87 +++++++++++++++++++----------- palace/linalg/vector.hpp | 10 ++++ palace/models/waveportoperator.cpp | 38 ++++++------- palace/models/waveportoperator.hpp | 2 +- 4 files changed, 82 insertions(+), 55 deletions(-) diff --git a/palace/linalg/vector.cpp b/palace/linalg/vector.cpp index 888895af7..0bbbc2783 100644 --- a/palace/linalg/vector.cpp +++ b/palace/linalg/vector.cpp @@ -204,14 +204,20 @@ std::complex ComplexVector::TransposeDot(const ComplexVector &y) const void ComplexVector::AXPY(std::complex alpha, const ComplexVector &x) { - const bool use_dev = UseDevice() || x.UseDevice(); - const int N = Size(); + AXPY(alpha, x.Real(), x.Imag(), Real(), Imag()); +} + +void ComplexVector::AXPY(std::complex alpha, const Vector &xr, const Vector &xi, + Vector &yr, Vector &yi) +{ + const bool use_dev = yr.UseDevice() || xr.UseDevice(); + const int N = yr.Size(); const double ar = alpha.real(); const double ai = alpha.imag(); - const auto *XR = x.Real().Read(use_dev); - const auto *XI = x.Imag().Read(use_dev); - auto *YR = Real().ReadWrite(use_dev); - auto *YI = Imag().ReadWrite(use_dev); + const auto *XR = xr.Read(use_dev); + const auto *XI = xi.Read(use_dev); + auto *YR = yr.ReadWrite(use_dev); + auto *YI = yi.ReadWrite(use_dev); if (ai == 0.0) { mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { YR[i] += ar * XR[i]; }); @@ -222,8 +228,9 @@ void ComplexVector::AXPY(std::complex alpha, const ComplexVector &x) mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { + const auto t = ai * XR[i] + ar * XI[i]; YR[i] += ar * XR[i] - ai * XI[i]; - YI[i] += ai * XR[i] + ar * XI[i]; + YI[i] += t; }); } } @@ -231,16 +238,22 @@ void ComplexVector::AXPY(std::complex alpha, const ComplexVector &x) void ComplexVector::AXPBY(std::complex alpha, const ComplexVector &x, std::complex beta) { - const bool use_dev = UseDevice() || x.UseDevice(); - const int N = Size(); + AXPBY(alpha, x.Real(), x.Imag(), beta, Real(), Imag()); +} + +void ComplexVector::AXPBY(std::complex alpha, const Vector &xr, const Vector &xi, + std::complex beta, Vector &yr, Vector &yi) +{ + const bool use_dev = yr.UseDevice() || xr.UseDevice(); + const int N = yr.Size(); const double ar = alpha.real(); const double ai = alpha.imag(); - const auto *XR = x.Real().Read(use_dev); - const auto *XI = x.Imag().Read(use_dev); + const auto *XR = xr.Read(use_dev); + const auto *XI = xi.Read(use_dev); if (beta == 0.0) { - auto *YR = Real().Write(use_dev); - auto *YI = Imag().Write(use_dev); + auto *YR = yr.Write(use_dev); + auto *YI = yi.Write(use_dev); if (ai == 0.0) { mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { YR[i] = ar * XR[i]; }); @@ -251,8 +264,9 @@ void ComplexVector::AXPBY(std::complex alpha, const ComplexVector &x, mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { + const auto t = ai * XR[i] + ar * XI[i]; YR[i] = ar * XR[i] - ai * XI[i]; - YI[i] = ai * XR[i] + ar * XI[i]; + YI[i] = t; }); } } @@ -260,8 +274,8 @@ void ComplexVector::AXPBY(std::complex alpha, const ComplexVector &x, { const double br = beta.real(); const double bi = beta.imag(); - auto *YR = Real().ReadWrite(use_dev); - auto *YI = Imag().ReadWrite(use_dev); + auto *YR = yr.ReadWrite(use_dev); + auto *YI = yi.ReadWrite(use_dev); if (ai == 0.0 && bi == 0.0) { mfem::forall_switch(use_dev, N, @@ -274,9 +288,10 @@ void ComplexVector::AXPBY(std::complex alpha, const ComplexVector &x, mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { - const auto t = bi * YR[i] + br * YI[i]; + const auto t = + ai * XR[i] + ar * XI[i] + bi * YR[i] + br * YI[i]; YR[i] = ar * XR[i] - ai * XI[i] + br * YR[i] - bi * YI[i]; - YI[i] = ai * XR[i] + ar * XI[i] + t; + YI[i] = t; }); } } @@ -286,20 +301,27 @@ void ComplexVector::AXPBYPCZ(std::complex alpha, const ComplexVector &x, std::complex beta, const ComplexVector &y, std::complex gamma) { - const bool use_dev = UseDevice() || x.UseDevice() || y.UseDevice(); - const int N = Size(); + AXPBYPCZ(alpha, x.Real(), x.Imag(), beta, y.Real(), y.Imag(), gamma, Real(), Imag()); +} + +void ComplexVector::AXPBYPCZ(std::complex alpha, const Vector &xr, const Vector &xi, + std::complex beta, const Vector &yr, const Vector &yi, + std::complex gamma, Vector &zr, Vector &zi) +{ + const bool use_dev = zr.UseDevice() || xr.UseDevice() || yr.UseDevice(); + const int N = zr.Size(); const double ar = alpha.real(); const double ai = alpha.imag(); const double br = beta.real(); const double bi = beta.imag(); - const auto *XR = x.Real().Read(use_dev); - const auto *XI = x.Imag().Read(use_dev); - const auto *YR = y.Real().Read(use_dev); - const auto *YI = y.Imag().Read(use_dev); + const auto *XR = xr.Read(use_dev); + const auto *XI = xi.Read(use_dev); + const auto *YR = yr.Read(use_dev); + const auto *YI = yi.Read(use_dev); if (gamma == 0.0) { - auto *ZR = Real().Write(use_dev); - auto *ZI = Imag().Write(use_dev); + auto *ZR = zr.Write(use_dev); + auto *ZI = zi.Write(use_dev); if (ai == 0.0 && bi == 0.0) { mfem::forall_switch(use_dev, N, @@ -312,8 +334,10 @@ void ComplexVector::AXPBYPCZ(std::complex alpha, const ComplexVector &x, mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { + const auto t = + ai * XR[i] + ar * XI[i] + bi * YR[i] + br * YI[i]; ZR[i] = ar * XR[i] - ai * XI[i] + br * YR[i] - bi * YI[i]; - ZI[i] = ai * XR[i] + ar * XI[i] + bi * YR[i] + br * YI[i]; + ZI[i] = t; }); } } @@ -321,8 +345,8 @@ void ComplexVector::AXPBYPCZ(std::complex alpha, const ComplexVector &x, { const double gr = gamma.real(); const double gi = gamma.imag(); - auto *ZR = Real().ReadWrite(use_dev); - auto *ZI = Imag().ReadWrite(use_dev); + auto *ZR = zr.ReadWrite(use_dev); + auto *ZI = zi.ReadWrite(use_dev); if (ai == 0.0 && bi == 0.0 && gi == 0.0) { mfem::forall_switch(use_dev, N, @@ -337,10 +361,11 @@ void ComplexVector::AXPBYPCZ(std::complex alpha, const ComplexVector &x, mfem::forall_switch(use_dev, N, [=] MFEM_HOST_DEVICE(int i) { - const auto t = gi * ZR[i] + gr * ZI[i]; + const auto t = ai * XR[i] + ar * XI[i] + bi * YR[i] + + br * YI[i] + gi * ZR[i] + gr * ZI[i]; ZR[i] = ar * XR[i] - ai * XI[i] + br * YR[i] - bi * YI[i] + gr * ZR[i] - gi * ZI[i]; - ZI[i] = ai * XR[i] + ar * XI[i] + bi * YR[i] + br * YI[i] + t; + ZI[i] = t; }); } } diff --git a/palace/linalg/vector.hpp b/palace/linalg/vector.hpp index 469a9f056..e47e45e93 100644 --- a/palace/linalg/vector.hpp +++ b/palace/linalg/vector.hpp @@ -116,6 +116,16 @@ class ComplexVector void AXPBYPCZ(std::complex alpha, const ComplexVector &x, std::complex beta, const ComplexVector &y, std::complex gamma); + + static void AXPY(std::complex alpha, const Vector &xr, const Vector &xi, + Vector &yr, Vector &yi); + + static void AXPBY(std::complex alpha, const Vector &xr, const Vector &xi, + std::complex beta, Vector &yr, Vector &yi); + + static void AXPBYPCZ(std::complex alpha, const Vector &xr, const Vector &xi, + std::complex beta, const Vector &yr, const Vector &yi, + std::complex gamma, Vector &zr, Vector &zi); }; namespace linalg diff --git a/palace/models/waveportoperator.cpp b/palace/models/waveportoperator.cpp index 3eb6620e1..c56c9a325 100644 --- a/palace/models/waveportoperator.cpp +++ b/palace/models/waveportoperator.cpp @@ -294,26 +294,16 @@ void Normalize(const mfem::ParGridFunction &S0t, mfem::ParComplexGridFunction &E // functions. We choose a (rather arbitrary) phase constraint to at least make results for // the same port consistent between frequencies/meshes. - // |E x H⋆| ⋅ n = |E ⋅ (-n x H⋆)| + // |E x H⋆| ⋅ n = |E ⋅ (-n x H⋆)|. This also updates the n x H coefficients depending on + // Et, En. Update linear forms for postprocessing too. std::complex dot[2] = { {sr * S0t, si * S0t}, {-(sr * E0t.real()) - (si * E0t.imag()), -(sr * E0t.imag()) + (si * E0t.real())}}; Mpi::GlobalSum(2, dot, S0t.ParFESpace()->GetComm()); auto scale = std::abs(dot[0]) / (dot[0] * std::sqrt(std::abs(dot[1]))); - - // This also updates the n x H coefficients depending on Et, En. - Vector tmp = E0t.real(); - add(scale.real(), E0t.real(), -scale.imag(), E0t.imag(), E0t.real()); - add(scale.imag(), tmp, scale.real(), E0t.imag(), E0t.imag()); - - tmp = E0n.real(); - add(scale.real(), E0n.real(), -scale.imag(), E0n.imag(), E0n.real()); - add(scale.imag(), tmp, scale.real(), E0n.imag(), E0n.imag()); - - // Update linear forms for postprocessing too. - tmp = sr; - add(scale.real(), sr, -scale.imag(), si, sr); - add(scale.imag(), tmp, scale.real(), si, si); + ComplexVector::AXPBY(scale, E0t.real(), E0t.imag(), 0.0, E0t.real(), E0t.imag()); + ComplexVector::AXPBY(scale, E0n.real(), E0n.imag(), 0.0, E0n.real(), E0n.imag()); + ComplexVector::AXPBY(scale, sr, si, 0.0, sr, si); // This parallel communication is not required since wave port boundaries are true one- // sided boundaries. @@ -596,12 +586,10 @@ WavePortData::WavePortData(const config::WavePortData &data, } } - // Create vector for initial space for eigenvalue solves. + // Create vector for initial space for eigenvalue solves and eigenmode solution. GetInitialSpace(*port_nd_fespace, *port_h1_fespace, port_dbc_tdof_list, v0); e0.SetSize(port_nd_fespace->GetTrueVSize() + port_h1_fespace->GetTrueVSize()); - e0n.SetSize(port_h1_fespace->GetTrueVSize()); e0.UseDevice(true); - e0n.UseDevice(true); // The operators for the generalized eigenvalue problem are: // [Aₜₜ Aₜₙ] [eₜ] = -kₙ² [Bₜₜ 0ₜₙ] [eₜ] @@ -907,6 +895,8 @@ void WavePortData::Initialize(double omega) MFEM_ASSERT(e0.Size() == 0, "Unexpected non-empty port FE space in wave port boundary mode solve!"); } + e0.Real().ReadWrite(); // Ensure memory is allocated on device before aliasing + e0.Imag().ReadWrite(); Vector e0tr(e0.Real(), 0, port_nd_fespace->GetTrueVSize()); Vector e0nr(e0.Real(), port_nd_fespace->GetTrueVSize(), port_h1_fespace->GetTrueVSize()); @@ -917,13 +907,15 @@ void WavePortData::Initialize(double omega) e0nr.UseDevice(true); e0ti.UseDevice(true); e0ni.UseDevice(true); - e0n.Real() = e0nr; - e0n.Imag() = e0ni; - e0n *= 1.0 / (1i * kn0); + ComplexVector::AXPBY(1.0 / (1i * kn0), e0nr, e0ni, 0.0, e0nr, e0ni); port_E0t->real().SetFromTrueDofs(e0tr); // Parallel distribute port_E0t->imag().SetFromTrueDofs(e0ti); - port_E0n->real().SetFromTrueDofs(e0n.Real()); - port_E0n->imag().SetFromTrueDofs(e0n.Imag()); + port_E0n->real().SetFromTrueDofs(e0nr); + port_E0n->imag().SetFromTrueDofs(e0ni); + port_E0t->real().HostRead(); // Read on host for linear form assembly + port_E0t->imag().HostRead(); + port_E0n->real().HostRead(); + port_E0n->imag().HostRead(); } // Configure the linear forms for computing S-parameters (projection of the field onto the diff --git a/palace/models/waveportoperator.hpp b/palace/models/waveportoperator.hpp index bef39892e..ed736839f 100644 --- a/palace/models/waveportoperator.hpp +++ b/palace/models/waveportoperator.hpp @@ -65,7 +65,7 @@ class WavePortData // Operator storage for repeated boundary mode eigenvalue problem solves. std::unique_ptr Atnr, Atni, Antr, Anti, Annr, Anni, Br, Bi; - ComplexVector v0, e0, e0n; + ComplexVector v0, e0; // Eigenvalue solver for boundary modes. MPI_Comm port_comm;