MPIQR.jl: QR factorisation distributed over MPI

QR-factorise your MPI-distributed matrix using Householder reflections, then solve your square or least-squares problem.
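As a point of reference for what "Householder reflections, then solve" means, here is a minimal serial sketch in plain Julia. It is not the package's implementation: it is unblocked, runs on a single process, and the name householder_lstsq is hypothetical. It only illustrates the reflect-then-back-substitute idea on a tall matrix.

```julia
using LinearAlgebra

# Hypothetical, serial illustration only; not part of the package API.
# Works on copies of A and b, so the inputs are left untouched.
function householder_lstsq(A::AbstractMatrix{T}, b::AbstractVector{T}) where {T}
  m, n = size(A)
  @assert m >= n "expects a square or tall matrix"
  R = copy(A)
  y = copy(b)
  for k in 1:n
    x = R[k:m, k]
    normx = norm(x)
    iszero(normx) && continue
    # phase of the leading entry, chosen to avoid cancellation in v[1]
    alpha = iszero(x[1]) ? one(T) : x[1] / abs(x[1])
    v = copy(x)
    v[1] += alpha * normx
    v ./= norm(v)
    # apply the reflector H = I - 2 v v' to the trailing columns and to the rhs
    R[k:m, k:n] .-= 2 .* v .* (v' * R[k:m, k:n])
    y[k:m] .-= 2 .* v .* dot(v, y[k:m])
  end
  # back-substitute with the upper-triangular n-by-n block
  return UpperTriangular(R[1:n, 1:n]) \ y[1:n]
end

A = rand(ComplexF64, 20, 5)
b = rand(ComplexF64, 20)
x = householder_lstsq(A, b)
@assert x ≈ A \ b # agrees with Julia's built-in least-squares solve
```

The distributed version has to apply the same reflectors to columns held on other ranks, which is where the MPI communication comes in.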

Currently this only works for Float32, Float64, ComplexF32 and ComplexF64 matrix and vector eltypes. These isbits types make use of calls to threaded BLAS.gemm! and BLAS.axpy! on each rank for highest performance, which is on a par with threaded LAPACK getrf speeds (with a comparable number of cores).
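For readers unfamiliar with the two BLAS routines named above, the following sketch shows what they compute using Julia's standard LinearAlgebra.BLAS wrappers. The array shapes and variable names are illustrative assumptions; how the package arranges these calls internally is not shown here.

```julia
using LinearAlgebra
using LinearAlgebra: BLAS

T = ComplexF64 # any of the four supported eltypes would do

# gemm!: C ← alpha * V * W + beta * C, here used as a rank-2 update C ← C - V * W
V = rand(T, 6, 2)
W = rand(T, 2, 4)
C = rand(T, 6, 4)
C0 = copy(C)
BLAS.gemm!('N', 'N', -one(T), V, W, one(T), C)
@assert C ≈ C0 - V * W

# axpy!: y ← a * x + y
x = rand(T, 6)
y = rand(T, 6)
y0 = copy(y)
BLAS.axpy!(T(2), x, y)
@assert y ≈ 2x + y0
```

Both routines exist only for the four BLAS floating-point types, which is consistent with the eltype restriction stated above.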

There is a tunable parameter, blocksize, which enables larger calls to gemm! and reduces the number of MPI communications, albeit at the cost of larger communications when they do happen; these should themselves be well hidden by compute.
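To make the trade-off concrete, here is a small sketch of how a block size could map columns onto ranks. It assumes a block-cyclic layout, which is a common choice but is not confirmed here as the package's actual scheme; blockcyclic_columns is a hypothetical helper, not the package's localcolumns function.

```julia
# Hypothetical helper: columns owned by `rank` (0-based) when blocks of
# `blocksize` contiguous columns are dealt out to `nranks` ranks in turn.
function blockcyclic_columns(rank, n, blocksize, nranks)
  cols = Int[]
  for start in (rank * blocksize + 1):(nranks * blocksize):n
    append!(cols, start:min(start + blocksize - 1, n))
  end
  return cols
end

n, nranks = 16, 4
for bs in (1, 2, 4)
  owned = [blockcyclic_columns(r, n, bs, nranks) for r in 0:(nranks - 1)]
  # under this assumed layout there is one panel of `bs` columns per step
  println("blocksize = $bs, panels = $(cld(n, bs)), rank 0 owns $(owned[1])")
end
```

Under this assumption, doubling the block size halves the number of panels that need to be shared between ranks while doubling the size of each one, which matches the behaviour described above.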

An example:

using MPIQR
using LinearAlgebra, MPI, Distributed, MPIClusterManagers
using ProgressMeter # optional

MPI.Init(;threadlevel=MPI.THREAD_SERIALIZED)
const rnk = MPI.Comm_rank(MPI.COMM_WORLD)

function run(T=ComplexF64;)
  # increase blocksize to improve usage of BLAS and decrease MPI comms
  blocksize = 2
  m, n = 2048, 1024
  A0 = zeros(T, 0, 0)
  x1 = b0 = zeros(T, 0)
  if rnk == 0 # assemble and solve serially to compare with MPIQR later
    A0 = rand(T, m, n) # the original matrix
    b0 = rand(T, m) # the original rhs
    A1 = deepcopy(A0) # this will get mutated
    b1 = deepcopy(b0) # as will this
    x1 = qr!(A1) \ b1
    y1 = A0 * x1 # this is the matrix vector product, not the least squares solution
  end
  Aall = MPI.bcast(A0, 0, MPI.COMM_WORLD) # lhs matrix on all ranks
  ball = MPI.bcast(b0, 0, MPI.COMM_WORLD) # rhs vector on all ranks
  xall = MPI.bcast(x1, 0, MPI.COMM_WORLD) # solution vector on all ranks

  # get the columns of the matrix that will be local to this rank
  localcols = MPIQR.localcolumns(rnk, n, blocksize, MPI.Comm_size(MPI.COMM_WORLD))
  b = deepcopy(ball)

  # distribute the serial matrix onto the columns local to this rank
  A = MPIQR.MPIQRMatrix(deepcopy(Aall[:, localcols]), size(Aall); blocksize=blocksize)
  y2 = A * xall # make sure matrix vector multiplication works...
  if iszero(rnk) # ... and is correct.
    @assert y2 ≈ y1
  end

  # qr! optionally accepts a progress meter
  # qr factorize A in-place and solve
  x2 = qr!(A; progress=Progress(A; showspeed=true)) \ b

  if iszero(rnk) # now see if the answer is right...
    @assert norm(Aall' * Aall * xall .- Aall' * ball) < 1e-8
    @show residual = norm(Aall' * Aall * x2 .- Aall' * ball)
  end
end
run()

MPI.Finalize()
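
The final check in the example rests on the normal equations. A vector $x$ minimises $\|Ax - b\|_2$ exactly when the residual $Ax - b$ is orthogonal to the columns of $A$:

$$
A^{H}(Ax - b) = 0 \quad\Longleftrightarrow\quad A^{H}A\,x = A^{H}b
$$

Here $A^{H}$ is the conjugate transpose, written Aall' in the code, so the printed residual norm(Aall' * Aall * x2 .- Aall' * ball) should be close to zero for a correct least-squares solution.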