From 29568d2ad0f20fe56a46c2e4b4434c84920f9c05 Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Thu, 5 Jan 2023 12:37:53 +0100 Subject: [PATCH] MutableArithmetics for IPM/HSD Use the `BigFloat` dot product from MutableArithmetics (MA) in the HSD code. This helps the performance of `BigFloat` arithmetic. The change shouldn't affect other arithmetic types, and it is coded so that it would be easy to extend to other mutable arithmetic types besides `BigFloat`, if necessary, provided such a type supports MutableArithmetics. Apart from improving performance, this change could also benefit LP problems with numerical issues (when using `BigFloat`), because the MA dot product uses a summation algorithm that is more accurate than naive summation. A performance experiment is presented in the commit message of the following commit. The conclusion is that this commit on its own improves run time only by a tiny bit, and likewise reduces allocation only slightly. --- Project.toml | 1 + src/IPM/HSD/HSD.jl | 55 ++++++++++++++---- src/IPM/HSD/dot_for_mutable.jl | 102 +++++++++++++++++++++++++++++++++ src/IPM/HSD/step.jl | 66 ++++++++++++++------- src/Tulip.jl | 1 + 5 files changed, 193 insertions(+), 32 deletions(-) create mode 100644 src/IPM/HSD/dot_for_mutable.jl diff --git a/Project.toml b/Project.toml index 62531f7b..b74b6d5e 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee" +MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" QPSReader = "10f199a5-22af-520b-b891-7ce84a7b1bd0" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/IPM/HSD/HSD.jl b/src/IPM/HSD/HSD.jl index e2681a11..f08dd902 100644 --- a/src/IPM/HSD/HSD.jl +++ b/src/IPM/HSD/HSD.jl @@ -64,6 +64,8 @@ mutable struct HSD{T, Tv, Tb, Ta, Tk} <: 
AbstractIPMOptimizer{T} end +include("dot_for_mutable.jl") + include("step.jl") @@ -101,13 +103,22 @@ function compute_residuals!(hsd::HSD{T} mul!(res.rd, transpose(dat.A), pt.y, -one(T), one(T)) @. res.rd += pt.zu .* dat.uflag - pt.zl .* dat.lflag + dot_buf = buffer_for_dot_weighted_sum(T) + # Gap residual # rg = c'x - (b'y + l'zl - u'zu) + k - res.rg = pt.κ + (dot(dat.c, pt.x) - ( - dot(dat.b, pt.y) - + dot(dat.l .* dat.lflag, pt.zl) - - dot(dat.u .* dat.uflag, pt.zu) - )) + res.rg = pt.κ + buffered_dot_weighted_sum!!( + dot_buf, + ( + (dat.c, pt.x), + (dat.b, pt.y), + (dat.l .* dat.lflag, pt.zl), + (dat.u .* dat.uflag, pt.zu), + ), + ( + 1, -1, -1, 1, + ), + ) # Residuals norm res.rp_nrm = norm(res.rp, Inf) @@ -117,11 +128,17 @@ function compute_residuals!(hsd::HSD{T} res.rg_nrm = norm(res.rg, Inf) # Compute primal and dual bounds - hsd.primal_objective = dot(dat.c, pt.x) / pt.τ + dat.c0 - hsd.dual_objective = ( - dot(dat.b, pt.y) - + dot(dat.l .* dat.lflag, pt.zl) - - dot(dat.u .* dat.uflag, pt.zu) + hsd.primal_objective = buffered_dot_product!!(dot_buf.dot, dat.c, pt.x) / pt.τ + dat.c0 + hsd.dual_objective = buffered_dot_weighted_sum!!( + dot_buf, + ( + (dat.b, pt.y), + (dat.l .* dat.lflag, pt.zl), + (dat.u .* dat.uflag, pt.zu), + ), + ( + 1, 1, -1, + ), ) / pt.τ + dat.c0 return nothing @@ -168,12 +185,15 @@ function update_solver_status!(hsd::HSD{T}, ϵp::T, ϵd::T, ϵg::T, ϵi::T) wher return nothing end + dot_buf = buffer_for_dot_weighted_sum(T) + # Check for infeasibility certificates if max( norm(dat.A * pt.x, Inf), norm((pt.x .- pt.xl) .* dat.lflag, Inf), norm((pt.x .+ pt.xu) .* dat.uflag, Inf) - ) * (norm(dat.c, Inf) / max(1, norm(dat.b, Inf))) < - ϵi * dot(dat.c, pt.x) + ) * (norm(dat.c, Inf) / max(1, norm(dat.b, Inf))) < + -ϵi * buffered_dot_product!!(dot_buf.dot, dat.c, pt.x) # Dual infeasible, i.e., primal unbounded hsd.primal_status = Sln_InfeasibilityCertificate hsd.solver_status = Trm_DualInfeasible @@ -185,7 +205,18 @@ function 
update_solver_status!(hsd::HSD{T}, ϵp::T, ϵd::T, ϵg::T, ϵi::T) wher norm(dat.l .* dat.lflag, Inf), norm(dat.u .* dat.uflag, Inf), norm(dat.b, Inf) - ) / (max(one(T), norm(dat.c, Inf))) < (dot(dat.b, pt.y) + dot(dat.l .* dat.lflag, pt.zl)- dot(dat.u .* dat.uflag, pt.zu)) * ϵi + ) / (max(one(T), norm(dat.c, Inf))) < buffered_dot_weighted_sum!!( + dot_buf, + ( + (dat.b, pt.y), + (dat.l .* dat.lflag, pt.zl), + (dat.u .* dat.uflag, pt.zu), + ), + ( + 1, 1, -1, + ), + ) * ϵi + # Primal infeasible hsd.dual_status = Sln_InfeasibilityCertificate hsd.solver_status = Trm_PrimalInfeasible diff --git a/src/IPM/HSD/dot_for_mutable.jl b/src/IPM/HSD/dot_for_mutable.jl new file mode 100644 index 00000000..0bece3d4 --- /dev/null +++ b/src/IPM/HSD/dot_for_mutable.jl @@ -0,0 +1,102 @@ +# Right now this is just `BigFloat`, but in principle it could be expanded to a whitelist +# that would include other mutable types. +const SupportedMutableArithmetics = BigFloat + +buffer_for_dot_product(::Type{V}) where {V <: AbstractVector{<:Real}} = + buffer_for(LinearAlgebra.dot, V, V) + +buffer_for_dot_product(::Type{F}) where {F <: Real} = + buffer_for_dot_product(Vector{F}) + +buffered_dot_product_to!( + buf::B, + result::F, + x::V, + y::V, +) where {B <: Any, F <: SupportedMutableArithmetics, V <: AbstractVector{F}} = + buffered_operate_to!(buf, result, LinearAlgebra.dot, x, y) + +function buffered_dot_product!!( + buf::B, + x::V, + y::V, +) where {B <: Any, F <: SupportedMutableArithmetics, V <: AbstractVector{F}} + ret = zero(F) + ret = buffered_dot_product_to!(buf, ret, x, y) + return ret +end + +buffered_dot_product!!(::Nothing, x::V, y::V) where {F <: Real, V <: AbstractVector{F}} = + dot(x, y) + +struct DotWeightedSumBuffer{F <: Real, DotBuffer <: Any} + tmp::F + dot::DotBuffer + + function DotWeightedSumBuffer{F}() where {F <: Real} + dot_buffer = buffer_for_dot_product(F) + return new{F, typeof(dot_buffer)}(zero(F), dot_buffer) + end +end + +struct DotWeightedSumBufferDummy + 
dot::Nothing + + DotWeightedSumBufferDummy() = new(nothing) +end + +buffer_for_dot_weighted_sum(::Type{F}) where {F <: SupportedMutableArithmetics} = + DotWeightedSumBuffer{F}() + +buffer_for_dot_weighted_sum(::Type{F}) where {F <: Real} = + DotWeightedSumBufferDummy() + +function buffered_dot_weighted_sum_to_inner!( + buf::DotWeightedSumBuffer{F}, + sum::F, + vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}}, + weights::NTuple{n, <:Real}, +) where {n, F <: SupportedMutableArithmetics} + sum = zero!!(sum) + + for i in 1:n + weight = weights[i] + (x, y) = vecs[i] + + buffered_dot_product_to!(buf.dot, buf.tmp, x, y) + mul!!(buf.tmp, weight) + + sum = add!!(sum, buf.tmp) + end + + return sum +end + +buffered_dot_weighted_sum_to!( + buf::DotWeightedSumBuffer{F}, + sum::F, + vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}}, + weights::NTuple{n, Int}) where {n, F <: SupportedMutableArithmetics} = + # It seems like the specialization + # *(x::BigFloat, c::Int8) + # could be more efficient than + # *(x::BigFloat, c::Int) + # MPFR has separate functions for those, and Julia uses them, + # there must be a good (performance) reason for that. 
+ buffered_dot_weighted_sum_to_inner!(buf, sum, vecs, map(Int8, weights)) + +function buffered_dot_weighted_sum!!( + buf::DotWeightedSumBuffer{F}, + vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}}, + weights::NTuple{n, Int}, +) where {n, F <: SupportedMutableArithmetics} + ret = zero(F) + ret = buffered_dot_weighted_sum_to!(buf, ret, vecs, weights) + return ret +end + +buffered_dot_weighted_sum!!( + buf::DotWeightedSumBufferDummy, + vecs::NTuple{n, NTuple{2, <:AbstractVector{F}}}, + weights::NTuple{n, Int}) where {n, F <: Real} = + mapreduce((vec2, weight) -> weight*dot(vec2...), +, vecs, weights, init = zero(F)) diff --git a/src/IPM/HSD/step.jl b/src/IPM/HSD/step.jl index cc7b7abc..f37ee2ee 100644 --- a/src/IPM/HSD/step.jl +++ b/src/IPM/HSD/step.jl @@ -61,17 +61,23 @@ function compute_step!(hsd::HSD{T, Tv}, params::IPMOptions{T}) where{T, Tv<:Abst ξ_ = @. (dat.c - ((pt.zl / pt.xl) * dat.l) * dat.lflag - ((pt.zu / pt.xu) * dat.u) * dat.uflag) KKT.solve!(hx, hy, hsd.kkt, dat.b, ξ_) + dot_buf = buffer_for_dot_weighted_sum(T) + # Recover h0 = ρg + κ / τ - c'hx + b'hy - u'hz # Some of the summands may take large values, # so care must be taken for numerical stability - h0 = ( - dot(dat.l .* dat.lflag, (dat.l .* θl) .* dat.lflag) - + dot(dat.u .* dat.uflag, (dat.u .* θu) .* dat.uflag) - - dot((@. (c + (θl * dat.l) * dat.lflag + (θu * dat.u) * dat.uflag)), hx) - + dot(b, hy) - + pt.κ / pt.τ - + hsd.regG - ) + h0 = buffered_dot_weighted_sum!!( + dot_buf, + ( + (dat.l .* dat.lflag, (dat.l .* θl) .* dat.lflag), + (dat.u .* dat.uflag, (dat.u .* θu) .* dat.uflag), + ((@. 
(c + (θl * dat.l) * dat.lflag + (θu * dat.u) * dat.uflag)), hx), + (b, hy), + ), + ( + 1, 1, -1, 1, + ), + ) + pt.κ / pt.τ + hsd.regG # Affine-scaling direction @timeit hsd.timer "Newton" solve_newton_system!(Δ, hsd, hx, hy, h0, @@ -211,22 +217,42 @@ function solve_newton_system!(Δ::Point{T, Tv}, end @timeit hsd.timer "KKT" KKT.solve!(Δ.x, Δ.y, hsd.kkt, ξp, ξd_) + dot_buf = buffer_for_dot_weighted_sum(T) + # II. Recover Δτ, Δx, Δy # Compute Δτ - @timeit hsd.timer "ξg_" ξg_ = (ξg + ξtk / pt.τ - - dot((ξxzl ./ pt.xl) .* dat.lflag, dat.l .* dat.lflag) # l'(Xl)^-1 * ξxzl - + dot((ξxzu ./ pt.xu) .* dat.uflag, dat.u .* dat.uflag) - - dot(((pt.zl ./ pt.xl) .* ξl) .* dat.lflag, dat.l .* dat.lflag) - - dot(((pt.zu ./ pt.xu) .* ξu) .* dat.uflag, dat.u .* dat.uflag) # - ) + @timeit hsd.timer "ξg_" ξg_ = ξg + ξtk / pt.τ + + buffered_dot_weighted_sum!!( + dot_buf, + ( + ((ξxzl ./ pt.xl) .* dat.lflag, dat.l .* dat.lflag), # l'(Xl)^-1 * ξxzl + ((ξxzu ./ pt.xu) .* dat.uflag, dat.u .* dat.uflag), + (((pt.zl ./ pt.xl) .* ξl) .* dat.lflag, dat.l .* dat.lflag), + (((pt.zu ./ pt.xu) .* ξu) .* dat.uflag, dat.u .* dat.uflag), + ), + ( + -1, 1, -1, -1, + ), + ) @timeit hsd.timer "Δτ" Δ.τ = ( - ξg_ - + dot((@. (dat.c - + ((pt.zl / pt.xl) * dat.l) * dat.lflag - + ((pt.zu / pt.xu) * dat.u) * dat.uflag)) - , Δ.x) - - dot(dat.b, Δ.y) + ξg_ + + buffered_dot_weighted_sum!!( + dot_buf, + ( + ( + (@. ( + dat.c + + ((pt.zl / pt.xl) * dat.l) * dat.lflag + + ((pt.zu / pt.xu) * dat.u) * dat.uflag)), + Δ.x, + ), + (dat.b, Δ.y), + ), + ( + 1, -1, + ), + ) ) / h0 diff --git a/src/Tulip.jl b/src/Tulip.jl index ac52b5cb..224640e9 100644 --- a/src/Tulip.jl +++ b/src/Tulip.jl @@ -2,6 +2,7 @@ module Tulip using LinearAlgebra using Logging +using MutableArithmetics using Printf using SparseArrays using TOML