Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Amesos2: Update the STRUMPACK interface #12227

Merged
merged 3 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions packages/amesos2/src/Amesos2_STRUMPACK_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class STRUMPACK : public SolverCore<Amesos2::STRUMPACK, Matrix, Vector>
typedef typename strum_type::global_size_type global_size_type;
typedef typename strum_type::node_type node_type;

typedef Kokkos::DefaultHostExecutionSpace HostExecSpaceType;
typedef Kokkos::View<global_ordinal_type*, HostExecSpaceType> host_ordinal_type_array;
typedef Kokkos::View<scalar_type*, HostExecSpaceType> host_value_type_array;

/// \name Constructor/Destructor methods
//@{

Expand Down Expand Up @@ -233,11 +237,11 @@ class STRUMPACK : public SolverCore<Amesos2::STRUMPACK, Matrix, Vector>

// The following Arrays are persisting storage arrays for A, X, and B
/// Stores the values of the nonzero entries for STRUMPACK
Teuchos::Array<scalar_type> nzvals_;
// /// Stores the row indices of the nonzero entries
Teuchos::Array<global_ordinal_type> colind_;
// /// Stores the location in \c Ai_ and Aval_ that starts row j
Teuchos::Array<global_ordinal_type> rowptr_;
host_value_type_array nzvals_view_;
/// Stores the row indices of the nonzero entries
host_ordinal_type_array colind_view_;
/// Stores the location in \c Ai_ and Aval_ that starts row j
host_ordinal_type_array rowptr_view_;
// /// 1D store for B values
mutable Teuchos::Array<scalar_type> bvals_;
// /// 1D store for X values
Expand Down
66 changes: 34 additions & 32 deletions packages/amesos2/src/Amesos2_STRUMPACK_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ namespace Amesos2 {

{
using Teuchos::Comm;
#ifdef HAVE_MPI
#ifdef HAVE_MPI
using Teuchos::MpiComm;
#endif
using Teuchos::ParameterList;
Expand All @@ -79,7 +79,7 @@ namespace Amesos2 {
using Teuchos::rcp;
using Teuchos::rcp_dynamic_cast;
typedef global_ordinal_type GO;
#ifdef HAVE_MPI
#ifdef HAVE_MPI
typedef Tpetra::Map<local_ordinal_type, GO, node_type> map_type;
#endif
RCP<const Comm<int> > comm = this->getComm ();
Expand All @@ -93,16 +93,17 @@ namespace Amesos2 {
MPI_Comm rawMpiComm = (* (mpiComm->getRawMpiComm ())) ();

sp_ = Teuchos::RCP<strumpack::StrumpackSparseSolverMPIDist<scalar_type,GO>>
(new strumpack::StrumpackSparseSolverMPIDist<scalar_type,GO>(rawMpiComm, this->control_.verbose_));
// (new strumpack::StrumpackSparseSolverMPIDist<scalar_type,GO>(rawMpiComm, this->control_.verbose_));
(new strumpack::StrumpackSparseSolverMPIDist<scalar_type,GO>(rawMpiComm, true));
#else
sp_ = Teuchos::RCP<strumpack::StrumpackSparseSolver<scalar_type,GO>>
(new strumpack::StrumpackSparseSolver<scalar_type,GO>(this->control_.verbose_, this->root_));

#endif

/*
Do we need this?
(What parameters do we set here that are not already provided?)
/*
Do we need this?
(What parameters do we set here that are not already provided?)
*/
RCP<ParameterList> default_params =
parameterList (* (this->getValidParameters ()));
Expand Down Expand Up @@ -208,11 +209,11 @@ namespace Amesos2 {
// this processor
const size_t local_len_rhs = strumpack_rowmap_->getLocalNumElements();
const global_size_type nrhs = X->getGlobalNumVectors();

// make sure our multivector storage is sized appropriately
bvals_.resize(nrhs * local_len_rhs);
xvals_.resize(nrhs * local_len_rhs);

{

#ifdef HAVE_AMESOS2_TIMERS
Expand All @@ -239,7 +240,7 @@ namespace Amesos2 {
strumpack::DenseMatrixWrapper<scalar_type>
Bsp(local_len_rhs, nrhs, bvals_().getRawPtr(), local_len_rhs),
Xsp(local_len_rhs, nrhs, xvals_().getRawPtr(), local_len_rhs);
strumpack::ReturnCode ret =sp_->solve(Bsp, Xsp);
strumpack::ReturnCode ret =sp_->solve(Bsp, Xsp);

TEUCHOS_TEST_FOR_EXCEPTION( ret != strumpack::ReturnCode::SUCCESS,
std::runtime_error,
Expand Down Expand Up @@ -429,7 +430,7 @@ namespace Amesos2 {
pl->set("ReplaceTinyPivot", true, "Specifies whether to replace tiny diagonals during LU factorization");


// There are multiple options available for an iterative refinement,
// There are multiple options available for an iterative refinement,
// however we recommend the use of "DIRECT" within the Amesos2 interface
setStringToIntegralParameter<strumpack::KrylovSolver>("IterRefine", "DIRECT",
"Type of iterative refinement to use",
Expand All @@ -450,7 +451,7 @@ namespace Amesos2 {
strumpack::KrylovSolver::BICGSTAB),
pl.getRawPtr());

// There are multiple options available for the compression of the matrix,
// There are multiple options available for the compression of the matrix,
// we recommend the use of "NONE" within the Amesos2 interface
setStringToIntegralParameter<strumpack::CompressionType>("Compression", "NONE",
"Type of compression to use",
Expand Down Expand Up @@ -521,21 +522,21 @@ namespace Amesos2 {

MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, dist.data()+1, sizeof(GO), MPI_BYTE,
mpiComm->getRawMpiComm()->operator()());
nzvals_.resize(l_nnz);
colind_.resize(l_nnz);
rowptr_.resize(l_rows + 1);
Kokkos::resize(nzvals_view_, l_nnz);
Kokkos::resize(colind_view_, l_nnz);
Kokkos::resize(rowptr_view_, l_rows + 1);



GO nnz_ret = 0;
{
#ifdef HAVE_AMESOS2_TIMERS
Teuchos::TimeMonitor mtxRedistTimer( this->timers_.mtxRedistTime_ );
#endif

Util::get_crs_helper<
MatrixAdapter<Matrix>,
scalar_type, GO, GO >::do_get(redist_mat.ptr(),
nzvals_(), colind_(), rowptr_(),
Util::get_crs_helper_kokkos_view<MatrixAdapter<Matrix>,
host_value_type_array, host_ordinal_type_array, host_ordinal_type_array >::do_get(
redist_mat.ptr(),
nzvals_view_, colind_view_, rowptr_view_,
nnz_ret,
ptrInArg(*strumpack_rowmap_),
ROOTED,
Expand All @@ -549,8 +550,9 @@ namespace Amesos2 {

// Get the csr data type for this type of matrix
sp_->set_distributed_csr_matrix
(l_rows, rowptr_.getRawPtr(), colind_.getRawPtr(),
nzvals_.getRawPtr(), dist.getRawPtr(), false);
(l_rows, rowptr_view_.data(), colind_view_.data(),
nzvals_view_.data(), dist.getRawPtr(), false);

#else
#ifdef HAVE_AMESOS2_TIMERS
Teuchos::TimeMonitor convTimer(this->timers_.mtxConvTime_);
Expand All @@ -560,19 +562,19 @@ namespace Amesos2 {
GO nnz_ret = 0;

if( this->root_ ){
nzvals_.resize(this->globalNumNonZeros_);
colind_.resize(this->globalNumNonZeros_);
rowptr_.resize(this->globalNumRows_ + 1);
}
Kokkos::resize(nzvals_view_, this->globalNumNonZeros_);
Kokkos::resize(colind_view_, this->globalNumNonZeros_);
Kokkos::resize(rowptr_view_, this->globalNumRows_ + 1);
}
{
#ifdef HAVE_AMESOS2_TIMERS
Teuchos::TimeMonitor mtxRedistTimer( this->timers_.mtxRedistTime_ );
#endif

Util::get_crs_helper<
MatrixAdapter<Matrix>,
scalar_type, GO, GO >::do_get(this->matrixA_.ptr(),
nzvals_(), colind_(), rowptr_(),
Util::get_crs_helper_kokkos_view<MatrixAdapter<Matrix>,
host_value_type_array, host_ordinal_type_array, host_ordinal_type_array >::do_get(
this->matrixA_.ptr(),
nzvals_view_, colind_view_, rowptr_view_,
nnz_ret,
ROOTED,
ARBITRARY, this->rowIndexBase_);
Expand All @@ -583,10 +585,10 @@ namespace Amesos2 {
"Did not get the expected number of non-zero vals");

// Get the csr data type for this type of matrix
sp_->set_csr_matrix(this->globalNumRows_, rowptr_.getRawPtr(), colind_.getRawPtr(),
nzvals_.getRawPtr(), false);
sp_->set_csr_matrix(this->globalNumRows_, rowptr_view_.data(), colind_view_.data(),
nzvals_view_.data(), false);

#endif
#endif
return true;
}

Expand Down
23 changes: 22 additions & 1 deletion packages/amesos2/test/solvers/STRUMPACK_UnitTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,22 @@ namespace {
using Teuchos::TRANS;
using Teuchos::NO_TRANS;


using Tpetra::global_size_t;
using Tpetra::CrsMatrix;
using Tpetra::MultiVector;
using Tpetra::Map;
using Tpetra::createContigMap;
using Tpetra::createUniformContigMap;



using Amesos2::STRUMPACK;

using Amesos2::Meta::is_same;

typedef Tpetra::Map<>::node_type Node;


bool testMpi = true;

// Where to look for input files
Expand Down Expand Up @@ -700,7 +703,25 @@ namespace {
UNIT_TEST_GROUP_ORDINAL_DOUBLE(LO, GO) \
UNIT_TEST_GROUP_ORDINAL_COMPLEX_DOUBLE(LO,GO)

// UNIT_TEST_GROUP_ORDINAL(int)

#ifndef HAVE_AMESOS2_EXPLICIT_INSTANTIATION
UNIT_TEST_GROUP_ORDINAL(int)
typedef long int LongInt;
UNIT_TEST_GROUP_ORDINAL_ORDINAL( int, LongInt )
#ifdef HAVE_TPETRA_INT_LONG_LONG
typedef long long int LongLongInt;
UNIT_TEST_GROUP_ORDINAL_ORDINAL( int, LongLongInt )
#endif
#else //ETI
#ifdef HAVE_TPETRA_INST_INT_INT
UNIT_TEST_GROUP_ORDINAL(int)
#endif
#ifdef HAVE_TPETRA_INST_INT_LONG
typedef long int LongInt;
UNIT_TEST_GROUP_ORDINAL_ORDINAL(int,LongInt)
#endif
#endif // EXPL-INST

//# ifndef HAVE_AMESOS2_EXPLICIT_INSTANTIATION
//typedef long int LongInt;
Expand Down