diff --git a/src/LinearSolvers/block_gmres.jl b/src/LinearSolvers/block_gmres.jl index 8f2b587e..a26bf070 100644 --- a/src/LinearSolvers/block_gmres.jl +++ b/src/LinearSolvers/block_gmres.jl @@ -66,14 +66,14 @@ function block_gmres(A, B::AbstractMatrix{FC}, X0::AbstractMatrix{FC}; memory::I atol::T = √eps(T), rtol::T=√eps(T), itmax::Int=0, timemax::Float64=Inf, verbose::Int=0, history::Bool=false) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} - start_time = time_ns() - solver = BlockGmresSolver(A, B; memory) - warm_start!(solver, X0) - elapsed_time = ktimer(start_time) - timemax -= elapsed_time - block_gmres!(solver, A, B; M, N, ldiv, restart, reorthogonalization, atol, rtol, itmax, timemax, verbose, history) - solver.stats.timer += elapsed_time - return solver.X, solver.stats + start_time = time_ns() + solver = BlockGmresSolver(A, B; memory) + warm_start!(solver, X0) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + block_gmres!(solver, A, B; M, N, ldiv, restart, reorthogonalization, atol, rtol, itmax, timemax, verbose, history) + solver.stats.timer += elapsed_time + return solver.X, solver.stats end function block_gmres(A, B::AbstractMatrix{FC}; memory::Int=20, M=I, N=I, @@ -81,13 +81,13 @@ function block_gmres(A, B::AbstractMatrix{FC}; memory::Int=20, M=I, N=I, atol::T = √eps(T), rtol::T=√eps(T), itmax::Int=0, timemax::Float64=Inf, verbose::Int=0, history::Bool=false) where {T <: AbstractFloat, FC <: FloatOrComplex{T}} - start_time = time_ns() - solver = BlockGmresSolver(A, B; memory) - elapsed_time = ktimer(start_time) - timemax -= elapsed_time - block_gmres!(solver, A, B; M, N, ldiv, restart, reorthogonalization, atol, rtol, itmax, timemax, verbose, history) - solver.stats.timer += elapsed_time - return solver.X, solver.stats + start_time = time_ns() + solver = BlockGmresSolver(A, B; memory) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + block_gmres!(solver, A, B; M, N, ldiv, restart, reorthogonalization, atol, rtol, itmax, timemax, verbose, history) + solver.stats.timer += elapsed_time + return solver.X, solver.stats end function block_gmres!(solver :: BlockGmresSolver{T,FC,SV,SM}, A, B::AbstractMatrix{FC}, X0::AbstractMatrix{FC}; M=I, N=I, @@ -95,13 +95,13 @@ function block_gmres!(solver :: BlockGmresSolver{T,FC,SV,SM}, A, B::AbstractMatr atol::T = √eps(T), rtol::T=√eps(T), itmax::Int=0, timemax::Float64=Inf, verbose::Int=0, history::Bool=false) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, SV <: AbstractVector{FC}, SM <: AbstractMatrix{FC}} - start_time = time_ns() - warm_start!(solver, X0) - elapsed_time = ktimer(start_time) - timemax -= elapsed_time - block_gmres!(solver, A, B; M, N, ldiv, restart, reorthogonalization, atol, rtol, itmax, timemax, verbose, history) - solver.stats.timer += elapsed_time - return solver + start_time = time_ns() + warm_start!(solver, X0) + elapsed_time = ktimer(start_time) + timemax -= elapsed_time + block_gmres!(solver, A, B; M, N, ldiv, restart, reorthogonalization, atol, rtol, itmax, timemax, verbose, history) + solver.stats.timer += elapsed_time + return solver end function block_gmres!(solver :: BlockGmresSolver{T,FC,SV,SM}, A, B::AbstractMatrix{FC}; M=I, N=I, @@ -109,240 +109,240 @@ function block_gmres!(solver :: BlockGmresSolver{T,FC,SV,SM}, A, B::AbstractMatr atol::T = √eps(T), rtol::T=√eps(T), itmax::Int=0, timemax::Float64=Inf, verbose::Int=0, history::Bool=false) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, SV <: AbstractVector{FC}, SM <: AbstractMatrix{FC}} - # Timer - start_time = time_ns() - timemax_ns = 1e9 * timemax - - n, m = size(A) - s, p = size(B) - m == n || error("System must be square") - n == s || error("Inconsistent problem size") - (verbose > 0) && @printf("BLOCK-GMRES: system of size %d with %d right-hand sides\n", n, p) - - # Check M = Iₙ and N = Iₙ - MisI = (M === I) - NisI = (N === I) - - # Check type consistency - eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-matrix products." - typeof(B) <: SM || error("typeof(B) is not a subtype of $SM") - - # Set up workspace. - allocate_if(!MisI , solver, :Q , SM, n, p) - allocate_if(!NisI , solver, :P , SM, n, p) - allocate_if(restart, solver, :ΔX, SM, n, p) - ΔX, X, W, V, Z = solver.ΔX, solver.X, solver.W, solver.V, solver.Z - C, D, R, H, τ, stats = solver.C, solver.D, solver.R, solver.H, solver.τ, solver.stats - warm_start = solver.warm_start - RNorms = stats.residuals - reset!(stats) - Q = MisI ? W : solver.Q - R₀ = MisI ? W : solver.Q - Xr = restart ? ΔX : X - - # Define the blocks D1 and D2 - D1 = view(D, 1:p, :) - D2 = view(D, p+1:2p, :) - - # Coefficients for mul! - α = -one(FC) - β = one(FC) - γ = one(FC) - - # Initial solution X₀. - fill!(X, zero(FC)) - - # Initial residual R₀. - if warm_start - mul!(W, A, ΔX) - W .= B .- W - restart && (X .+= ΔX) - else - copyto!(W, B) - end - MisI || mulorldiv!(R₀, M, W, ldiv) # R₀ = M(B - AX₀) - RNorm = norm(R₀) # ‖R₀‖_F - - history && push!(RNorms, RNorm) - ε = atol + rtol * RNorm - - mem = length(V) # Memory - npass = 0 # Number of pass - - iter = 0 # Cumulative number of iterations - inner_iter = 0 # Number of iterations in a pass - - itmax == 0 && (itmax = 2*div(n,p)) - inner_itmax = itmax - - (verbose > 0) && @printf("%5s %5s %7s %5s\n", "pass", "k", "‖Rₖ‖", "timer") - kdisplay(iter, verbose) && @printf("%5d %5d %7.1e %.2fs\n", npass, iter, RNorm, ktimer(start_time)) - - # Stopping criterion - solved = RNorm ≤ ε - tired = iter ≥ itmax - inner_tired = inner_iter ≥ inner_itmax - status = "unknown" - overtimed = false - - while !(solved || tired || overtimed) - - # Initialize workspace. - nr = 0 # Number of blocks Ψᵢⱼ stored in Rₖ. - for i = 1 : mem - fill!(V[i], zero(FC)) # Orthogonal basis of Kₖ(MAN, MR₀). - end - for Ψ in R - fill!(Ψ, zero(FC)) # Upper triangular matrix Rₖ. - end - for block in Z - fill!(block, zero(FC)) # Right-hand of the least squares problem min ‖Hₖ₊₁.ₖYₖ - ΓE₁‖₂. - end - - if restart - fill!(Xr, zero(FC)) # Xr === ΔX when restart is set to true - if npass ≥ 1 - mul!(W, A, X) + # Timer + start_time = time_ns() + timemax_ns = 1e9 * timemax + + n, m = size(A) + s, p = size(B) + m == n || error("System must be square") + n == s || error("Inconsistent problem size") + (verbose > 0) && @printf("BLOCK-GMRES: system of size %d with %d right-hand sides\n", n, p) + + # Check M = Iₙ and N = Iₙ + MisI = (M === I) + NisI = (N === I) + + # Check type consistency + eltype(A) == FC || @warn "eltype(A) ≠ $FC. This could lead to errors or additional allocations in operator-matrix products." + typeof(B) <: SM || error("typeof(B) is not a subtype of $SM") + + # Set up workspace. + allocate_if(!MisI , solver, :Q , SM, n, p) + allocate_if(!NisI , solver, :P , SM, n, p) + allocate_if(restart, solver, :ΔX, SM, n, p) + ΔX, X, W, V, Z = solver.ΔX, solver.X, solver.W, solver.V, solver.Z + C, D, R, H, τ, stats = solver.C, solver.D, solver.R, solver.H, solver.τ, solver.stats + warm_start = solver.warm_start + RNorms = stats.residuals + reset!(stats) + Q = MisI ? W : solver.Q + R₀ = MisI ? W : solver.Q + Xr = restart ? ΔX : X + + # Define the blocks D1 and D2 + D1 = view(D, 1:p, :) + D2 = view(D, p+1:2p, :) + + # Coefficients for mul! + α = -one(FC) + β = one(FC) + γ = one(FC) + + # Initial solution X₀. + fill!(X, zero(FC)) + + # Initial residual R₀. + if warm_start + mul!(W, A, ΔX) W .= B .- W - MisI || mulorldiv!(R₀, M, W, ldiv) - end + restart && (X .+= ΔX) + else + copyto!(W, B) end - - # Initial Γ and V₁ - copyto!(V[1], R₀) - householder!(V[1], Z[1], τ[1]) + MisI || mulorldiv!(R₀, M, W, ldiv) # R₀ = M(B - AX₀) + RNorm = norm(R₀) # ‖R₀‖_F - npass = npass + 1 - inner_iter = 0 - inner_tired = false + history && push!(RNorms, RNorm) + ε = atol + rtol * RNorm - while !(solved || inner_tired || overtimed) + mem = length(V) # Memory + npass = 0 # Number of pass - # Update iteration index - inner_iter = inner_iter + 1 + iter = 0 # Cumulative number of iterations + inner_iter = 0 # Number of iterations in a pass - # Update workspace if more storage is required and restart is set to false - if !restart && (inner_iter > mem) - for i = 1 : inner_iter - push!(R, SM(undef, p, p)) + itmax == 0 && (itmax = 2*div(n,p)) + inner_itmax = itmax + + (verbose > 0) && @printf("%5s %5s %7s %5s\n", "pass", "k", "‖Rₖ‖", "timer") + kdisplay(iter, verbose) && @printf("%5d %5d %7.1e %.2fs\n", npass, iter, RNorm, ktimer(start_time)) + + # Stopping criterion + solved = RNorm ≤ ε + tired = iter ≥ itmax + inner_tired = inner_iter ≥ inner_itmax + status = "unknown" + overtimed = false + + while !(solved || tired || overtimed) + + # Initialize workspace. + nr = 0 # Number of blocks Ψᵢⱼ stored in Rₖ. + for i = 1 : mem + fill!(V[i], zero(FC)) # Orthogonal basis of Kₖ(MAN, MR₀). end - push!(H, SM(undef, 2p, p)) - push!(τ, SV(undef, p)) - end - - # Continue the block-Arnoldi process. - P = NisI ? V[inner_iter] : solver.P - NisI || mulorldiv!(P, N, V[inner_iter], ldiv) # P ← NVₖ - mul!(W, A, P) # W ← ANVₖ - MisI || mulorldiv!(Q, M, W, ldiv) # Q ← MANVₖ - for i = 1 : inner_iter - mul!(R[nr+i], V[i]', Q) # Ψᵢₖ = Vᵢᴴ * Q - mul!(Q, V[i], R[nr+i], α, β) # Q = Q - Vᵢ * Ψᵢₖ - end - - # Reorthogonalization of the block-Krylov basis. - if reorthogonalization - for i = 1 : inner_iter - mul!(Ψtmp, V[i]', Q) # Ψtmp = Vᵢᴴ * Q - mul!(Q, V[i], Ψtmp, α, β) # Q = Q - Vᵢ * Ψtmp - R[nr+i] .+= Ψtmp + for Ψ in R + fill!(Ψ, zero(FC)) # Upper triangular matrix Rₖ. end - end - - # Vₖ₊₁ and Ψₖ₊₁.ₖ are stored in Q and C. - householder!(Q, C, τ[inner_iter]) - - # Update the QR factorization of Hₖ₊₁.ₖ. - # Apply previous Householder reflections Ωᵢ. - for i = 1 : inner_iter-1 - D1 .= R[nr+i] - D2 .= R[nr+i+1] - LAPACK.ormqr!('L', 'T', H[i], τ[i], D) - R[nr+i] .= D1 - R[nr+i+1] .= D2 - end - - # Compute and apply current Householder reflection Ωₖ. - H[inner_iter][1:p,:] .= R[nr+inner_iter] - H[inner_iter][p+1:2p,:] .= C - householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], compact=true) - - # Update Zₖ = (Qₖ)ᴴΓE₁ = (Λ₁, ..., Λₖ, Λbarₖ₊₁) - D1 .= Z[inner_iter] - D2 .= zero(FC) - LAPACK.ormqr!('L', 'T', H[inner_iter], τ[inner_iter], D) - Z[inner_iter] .= D1 - - # Update residual norm estimate. - # ‖ M(B - AXₖ) ‖_F = ‖Λbarₖ₊₁‖_F - C .= D2 - RNorm = norm(C) - history && push!(RNorms, RNorm) - - # Update the number of coefficients in Rₖ - nr = nr + inner_iter - - # Update stopping criterion. - solved = RNorm ≤ ε - inner_tired = restart ? inner_iter ≥ min(mem, inner_itmax) : inner_iter ≥ inner_itmax - timer = time_ns() - start_time - overtimed = timer > timemax_ns - kdisplay(iter+inner_iter, verbose) && @printf("%5d %5d %7.1e %.2fs\n", npass, iter+inner_iter, RNorm, ktimer(start_time)) - - # Compute Vₖ₊₁. - if !(solved || inner_tired || overtimed) - if !restart && (inner_iter ≥ mem) - push!(V, SM(undef, n, p)) - push!(Z, SM(undef, p, p)) + for block in Z + fill!(block, zero(FC)) # Right-hand of the least squares problem min ‖Hₖ₊₁.ₖYₖ - ΓE₁‖₂. end - copyto!(V[inner_iter+1], Q) - Z[inner_iter+1] .= D2 - end - end - # Compute Yₖ by solving RₖYₖ = Zₖ with a backward substitution by block. - Y = Z # Yᵢ = Zᵢ - for i = inner_iter : -1 : 1 - pos = nr + i - inner_iter # position of Ψᵢ.ₖ - for j = inner_iter : -1 : i+1 - mul!(Y[i], R[pos], Y[j], α, β) # Yᵢ ← Yᵢ - ΨᵢⱼYⱼ - pos = pos - j + 1 # position of Ψᵢ.ⱼ₋₁ - end - ldiv!(UpperTriangular(R[pos]), Y[i]) # Yᵢ ← Yᵢ \ Ψᵢᵢ - end + if restart + fill!(Xr, zero(FC)) # Xr === ΔX when restart is set to true + if npass ≥ 1 + mul!(W, A, X) + W .= B .- W + MisI || mulorldiv!(R₀, M, W, ldiv) + end + end - # Form Xₖ = NVₖYₖ - for i = 1 : inner_iter - mul!(Xr, V[i], Y[i], γ, β) - end - if !NisI - copyto!(solver.P, Xr) - mulorldiv!(Xr, N, solver.P, ldiv) - end - restart && (X .+= Xr) + # Initial Γ and V₁ + copyto!(V[1], R₀) + householder!(V[1], Z[1], τ[1]) + + npass = npass + 1 + inner_iter = 0 + inner_tired = false + + while !(solved || inner_tired || overtimed) + + # Update iteration index + inner_iter = inner_iter + 1 + + # Update workspace if more storage is required and restart is set to false + if !restart && (inner_iter > mem) + for i = 1 : inner_iter + push!(R, SM(undef, p, p)) + end + push!(H, SM(undef, 2p, p)) + push!(τ, SV(undef, p)) + end + + # Continue the block-Arnoldi process. + P = NisI ? V[inner_iter] : solver.P + NisI || mulorldiv!(P, N, V[inner_iter], ldiv) # P ← NVₖ + mul!(W, A, P) # W ← ANVₖ + MisI || mulorldiv!(Q, M, W, ldiv) # Q ← MANVₖ + for i = 1 : inner_iter + mul!(R[nr+i], V[i]', Q) # Ψᵢₖ = Vᵢᴴ * Q + mul!(Q, V[i], R[nr+i], α, β) # Q = Q - Vᵢ * Ψᵢₖ + end + + # Reorthogonalization of the block-Krylov basis. + if reorthogonalization + for i = 1 : inner_iter + mul!(Ψtmp, V[i]', Q) # Ψtmp = Vᵢᴴ * Q + mul!(Q, V[i], Ψtmp, α, β) # Q = Q - Vᵢ * Ψtmp + R[nr+i] .+= Ψtmp + end + end + + # Vₖ₊₁ and Ψₖ₊₁.ₖ are stored in Q and C. + householder!(Q, C, τ[inner_iter]) + + # Update the QR factorization of Hₖ₊₁.ₖ. + # Apply previous Householder reflections Ωᵢ. + for i = 1 : inner_iter-1 + D1 .= R[nr+i] + D2 .= R[nr+i+1] + LAPACK.ormqr!('L', 'T', H[i], τ[i], D) + R[nr+i] .= D1 + R[nr+i+1] .= D2 + end + + # Compute and apply current Householder reflection Ωₖ. + H[inner_iter][1:p,:] .= R[nr+inner_iter] + H[inner_iter][p+1:2p,:] .= C + householder!(H[inner_iter], R[nr+inner_iter], τ[inner_iter], compact=true) + + # Update Zₖ = (Qₖ)ᴴΓE₁ = (Λ₁, ..., Λₖ, Λbarₖ₊₁) + D1 .= Z[inner_iter] + D2 .= zero(FC) + LAPACK.ormqr!('L', 'T', H[inner_iter], τ[inner_iter], D) + Z[inner_iter] .= D1 + + # Update residual norm estimate. + # ‖ M(B - AXₖ) ‖_F = ‖Λbarₖ₊₁‖_F + C .= D2 + RNorm = norm(C) + history && push!(RNorms, RNorm) + + # Update the number of coefficients in Rₖ + nr = nr + inner_iter + + # Update stopping criterion. + solved = RNorm ≤ ε + inner_tired = restart ? inner_iter ≥ min(mem, inner_itmax) : inner_iter ≥ inner_itmax + timer = time_ns() - start_time + overtimed = timer > timemax_ns + kdisplay(iter+inner_iter, verbose) && @printf("%5d %5d %7.1e %.2fs\n", npass, iter+inner_iter, RNorm, ktimer(start_time)) + + # Compute Vₖ₊₁. + if !(solved || inner_tired || overtimed) + if !restart && (inner_iter ≥ mem) + push!(V, SM(undef, n, p)) + push!(Z, SM(undef, p, p)) + end + copyto!(V[inner_iter+1], Q) + Z[inner_iter+1] .= D2 + end + end - # Update inner_itmax, iter, tired and overtimed variables. - inner_itmax = inner_itmax - inner_iter - iter = iter + inner_iter - tired = iter ≥ itmax - timer = time_ns() - start_time - overtimed = timer > timemax_ns - end - (verbose > 0) && @printf("\n") - - # Termination status - tired && (status = "maximum number of iterations exceeded") - solved && (status = "solution good enough given atol and rtol") - overtimed && (status = "time limit exceeded") - - # Update Xₖ - warm_start && !restart && (X .+= ΔX) - solver.warm_start = false - - # Update stats - stats.niter = iter - stats.solved = solved - stats.timer = ktimer(start_time) - stats.status = status - return solver + # Compute Yₖ by solving RₖYₖ = Zₖ with a backward substitution by block. + Y = Z # Yᵢ = Zᵢ + for i = inner_iter : -1 : 1 + pos = nr + i - inner_iter # position of Ψᵢ.ₖ + for j = inner_iter : -1 : i+1 + mul!(Y[i], R[pos], Y[j], α, β) # Yᵢ ← Yᵢ - ΨᵢⱼYⱼ + pos = pos - j + 1 # position of Ψᵢ.ⱼ₋₁ + end + ldiv!(UpperTriangular(R[pos]), Y[i]) # Yᵢ ← Yᵢ \ Ψᵢᵢ + end + + # Form Xₖ = NVₖYₖ + for i = 1 : inner_iter + mul!(Xr, V[i], Y[i], γ, β) + end + if !NisI + copyto!(solver.P, Xr) + mulorldiv!(Xr, N, solver.P, ldiv) + end + restart && (X .+= Xr) + + # Update inner_itmax, iter, tired and overtimed variables. + inner_itmax = inner_itmax - inner_iter + iter = iter + inner_iter + tired = iter ≥ itmax + timer = time_ns() - start_time + overtimed = timer > timemax_ns + end + (verbose > 0) && @printf("\n") + + # Termination status + tired && (status = "maximum number of iterations exceeded") + solved && (status = "solution good enough given atol and rtol") + overtimed && (status = "time limit exceeded") + + # Update Xₖ + warm_start && !restart && (X .+= ΔX) + solver.warm_start = false + + # Update stats + stats.niter = iter + stats.solved = solved + stats.timer = ktimer(start_time) + stats.status = status + return solver end diff --git a/src/LinearSolvers/utils.jl b/src/LinearSolvers/utils.jl index 1ccb5bab..bbbd2334 100644 --- a/src/LinearSolvers/utils.jl +++ b/src/LinearSolvers/utils.jl @@ -60,91 +60,91 @@ may be used in order to create these vectors. `memory` is set to `div(n,p)` if the value given is larger than `div(n,p)`. """ mutable struct BlockGmresSolver{T,FC,SV,SM} <: BlockKrylovSolver{T,FC,SV,SM} - m :: Int - n :: Int - p :: Int - ΔX :: SM - X :: SM - W :: SM - P :: SM - Q :: SM - C :: SM - D :: SM - V :: Vector{SM} - Z :: Vector{SM} - R :: Vector{SM} - H :: Vector{SM} - τ :: Vector{SV} - warm_start :: Bool - stats :: BlockGmresStats{T} + m :: Int + n :: Int + p :: Int + ΔX :: SM + X :: SM + W :: SM + P :: SM + Q :: SM + C :: SM + D :: SM + V :: Vector{SM} + Z :: Vector{SM} + R :: Vector{SM} + H :: Vector{SM} + τ :: Vector{SV} + warm_start :: Bool + stats :: BlockGmresStats{T} end function BlockGmresSolver(m, n, p, memory, SV, SM) - memory = min(div(n,p), memory) - FC = eltype(SV) - T = real(FC) - ΔX = SM(undef, 0, 0) - X = SM(undef, n, p) - W = SM(undef, n, p) - P = SM(undef, 0, 0) - Q = SM(undef, 0, 0) - C = SM(undef, p, p) - D = SM(undef, 2p, p) - V = SM[SM(undef, n, p) for i = 1 : memory] - Z = SM[SM(undef, p, p) for i = 1 : memory] - R = SM[SM(undef, p, p) for i = 1 : div(memory * (memory+1), 2)] - H = SM[SM(undef, 2p, p) for i = 1 : memory] - τ = SV[SV(undef, p) for i = 1 : memory] - stats = BlockGmresStats(0, false, T[], 0.0, "unknown") - solver = BlockGmresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, false, stats) - return solver + memory = min(div(n,p), memory) + FC = eltype(SV) + T = real(FC) + ΔX = SM(undef, 0, 0) + X = SM(undef, n, p) + W = SM(undef, n, p) + P = SM(undef, 0, 0) + Q = SM(undef, 0, 0) + C = SM(undef, p, p) + D = SM(undef, 2p, p) + V = SM[SM(undef, n, p) for i = 1 : memory] + Z = SM[SM(undef, p, p) for i = 1 : memory] + R = SM[SM(undef, p, p) for i = 1 : div(memory * (memory+1), 2)] + H = SM[SM(undef, 2p, p) for i = 1 : memory] + τ = SV[SV(undef, p) for i = 1 : memory] + stats = BlockGmresStats(0, false, T[], 0.0, "unknown") + solver = BlockGmresSolver{T,FC,SV,SM}(m, n, p, ΔX, X, W, P, Q, C, D, V, Z, R, H, τ, false, stats) + return solver end function BlockGmresSolver(A, B; memory::Int=5) - m, n = size(A) - s, p = size(B) - SM = typeof(B) - SV = matrix_to_vector(SM) - BlockGmresSolver(m, n, p, memory, SV, SM) + m, n = size(A) + s, p = size(B) + SM = typeof(B) + SV = matrix_to_vector(SM) + BlockGmresSolver(m, n, p, memory, SV, SM) end for (KS, fun, nsol, nA, nAt) in ((:BlockGmresSolver, :block_gmres!, 1, 1, 0),) - @eval begin - size(solver :: $KS) = solver.m, solver.n - nrhs(solver :: $KS) = solver.p - statistics(solver :: $KS) = solver.stats - niterations(solver :: $KS) = solver.stats.niter - Aprod(solver :: $KS) = $nA * solver.stats.niter - Atprod(solver :: $KS) = $nAt * solver.stats.niter - nsolution(solver :: $KS) = $nsol - if $nsol == 1 - solution(solver :: $KS) = solver.X - solution(solver :: $KS, p :: Integer) = (p == 1) ? solution(solver) : error("solution(solver) has only one output.") - end - issolved(solver :: $KS) = solver.stats.solved - function warm_start!(solver :: $KS, X0) - n, p = size(solver. X) - n2, p2 = size(X0) - SM = typeof(solver.X) - (n == n2 && p == p2) || error("X0 should have size ($n, $p)") - allocate_if(true, solver, :ΔX, SM, n, p) - copyto!(solver.ΔX, X0) - solver.warm_start = true - return solver + @eval begin + size(solver :: $KS) = solver.m, solver.n + nrhs(solver :: $KS) = solver.p + statistics(solver :: $KS) = solver.stats + niterations(solver :: $KS) = solver.stats.niter + Aprod(solver :: $KS) = $nA * solver.stats.niter + Atprod(solver :: $KS) = $nAt * solver.stats.niter + nsolution(solver :: $KS) = $nsol + if $nsol == 1 + solution(solver :: $KS) = solver.X + solution(solver :: $KS, p :: Integer) = (p == 1) ? solution(solver) : error("solution(solver) has only one output.") + end + issolved(solver :: $KS) = solver.stats.solved + function warm_start!(solver :: $KS, X0) + n, p = size(solver. X) + n2, p2 = size(X0) + SM = typeof(solver.X) + (n == n2 && p == p2) || error("X0 should have size ($n, $p)") + allocate_if(true, solver, :ΔX, SM, n, p) + copyto!(solver.ΔX, X0) + solver.warm_start = true + return solver + end end - end end function sizeof(stats_solver :: BlockKrylovSolver) - type = typeof(stats_solver) - nfields = fieldcount(type) - storage = 0 - for i = 1:nfields - field_i = getfield(stats_solver, i) - size_i = ksizeof(field_i) - storage += size_i - end - return storage + type = typeof(stats_solver) + nfields = fieldcount(type) + storage = 0 + for i = 1:nfields + field_i = getfield(stats_solver, i) + size_i = ksizeof(field_i) + storage += size_i + end + return storage end """ @@ -153,41 +153,41 @@ end Statistics of `solver` are displayed if `show_stats` is set to true. """ function show(io :: IO, solver :: BlockKrylovSolver{T,FC,S}; show_stats :: Bool=true) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: AbstractVector{FC}} - workspace = typeof(solver) - name_solver = string(workspace.name.name) - name_stats = string(typeof(solver.stats).name.name) - nbytes = sizeof(solver) - storage = format_bytes(nbytes) - architecture = S <: Vector ? "CPU" : "GPU" - l1 = max(length(name_solver), length(string(FC)) + 11) # length("Precision: ") = 11 - nchar = workspace <: BlockGmresSolver ? 8 : 0 # length("Vector{}") = 8 - l2 = max(ndigits(solver.m) + 7, length(architecture) + 14, length(string(S)) + nchar) # length("nrows: ") = 7 and length("Architecture: ") = 14 - l2 = max(l2, length(name_stats) + 2 + length(string(T))) # length("{}") = 2 - l3 = max(ndigits(solver.n) + 7, length(storage) + 9) # length("Storage: ") = 9 and length("cols: ") = 7 - format = Printf.Format("│%$(l1)s│%$(l2)s│%$(l3)s│\n") - format2 = Printf.Format("│%$(l1+1)s│%$(l2)s│%$(l3)s│\n") - @printf(io, "┌%s┬%s┬%s┐\n", "─"^l1, "─"^l2, "─"^l3) - Printf.format(io, format, "$(name_solver)", "nrows: $(solver.m)", "ncols: $(solver.n)") - @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) - Printf.format(io, format, "Precision: $FC", "Architecture: $architecture","Storage: $storage") - @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) - Printf.format(io, format, "Attribute", "Type", "Size") - @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) - for i=1:fieldcount(workspace) - name_i = fieldname(workspace, i) - type_i = fieldtype(workspace, i) - field_i = getfield(solver, name_i) - size_i = ksizeof(field_i) - if (name_i::Symbol in [:w̅, :w̄, :d̅]) && (VERSION < v"1.8.0-DEV") - (size_i ≠ 0) && Printf.format(io, format2, string(name_i), type_i, format_bytes(size_i)) - else - (size_i ≠ 0) && Printf.format(io, format, string(name_i), type_i, format_bytes(size_i)) + workspace = typeof(solver) + name_solver = string(workspace.name.name) + name_stats = string(typeof(solver.stats).name.name) + nbytes = sizeof(solver) + storage = format_bytes(nbytes) + architecture = S <: Vector ? "CPU" : "GPU" + l1 = max(length(name_solver), length(string(FC)) + 11) # length("Precision: ") = 11 + nchar = workspace <: BlockGmresSolver ? 8 : 0 # length("Vector{}") = 8 + l2 = max(ndigits(solver.m) + 7, length(architecture) + 14, length(string(S)) + nchar) # length("nrows: ") = 7 and length("Architecture: ") = 14 + l2 = max(l2, length(name_stats) + 2 + length(string(T))) # length("{}") = 2 + l3 = max(ndigits(solver.n) + 7, length(storage) + 9) # length("Storage: ") = 9 and length("cols: ") = 7 + format = Printf.Format("│%$(l1)s│%$(l2)s│%$(l3)s│\n") + format2 = Printf.Format("│%$(l1+1)s│%$(l2)s│%$(l3)s│\n") + @printf(io, "┌%s┬%s┬%s┐\n", "─"^l1, "─"^l2, "─"^l3) + Printf.format(io, format, "$(name_solver)", "nrows: $(solver.m)", "ncols: $(solver.n)") + @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) + Printf.format(io, format, "Precision: $FC", "Architecture: $architecture","Storage: $storage") + @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) + Printf.format(io, format, "Attribute", "Type", "Size") + @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) + for i=1:fieldcount(workspace) + name_i = fieldname(workspace, i) + type_i = fieldtype(workspace, i) + field_i = getfield(solver, name_i) + size_i = ksizeof(field_i) + if (name_i::Symbol in [:w̅, :w̄, :d̅]) && (VERSION < v"1.8.0-DEV") + (size_i ≠ 0) && Printf.format(io, format2, string(name_i), type_i, format_bytes(size_i)) + else + (size_i ≠ 0) && Printf.format(io, format, string(name_i), type_i, format_bytes(size_i)) + end + end + @printf(io, "└%s┴%s┴%s┘\n","─"^l1,"─"^l2,"─"^l3) + if show_stats + @printf(io, "\n") + show(io, solver.stats) end - end - @printf(io, "└%s┴%s┴%s┘\n","─"^l1,"─"^l2,"─"^l3) - if show_stats - @printf(io, "\n") - show(io, solver.stats) - end - return nothing + return nothing end