From 2ef6709ab41405bedf89eb83e9372461ce3ac450 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Tue, 30 Jul 2024 12:00:25 +0200 Subject: [PATCH] Add test for bridge --- src/Bridges/Variable/bridges/set_dot.jl | 57 +++++++++----------- src/Bridges/Variable/set_map.jl | 2 +- src/sets.jl | 4 ++ test/Bridges/Variable/set_dot.jl | 72 +++++++++++++++++++++++++ 4 files changed, 103 insertions(+), 32 deletions(-) create mode 100644 test/Bridges/Variable/set_dot.jl diff --git a/src/Bridges/Variable/bridges/set_dot.jl b/src/Bridges/Variable/bridges/set_dot.jl index ba593a717a..d8b5dd1df8 100644 --- a/src/Bridges/Variable/bridges/set_dot.jl +++ b/src/Bridges/Variable/bridges/set_dot.jl @@ -1,7 +1,7 @@ -struct DotProductsBridge{T,S,V} <: SetMapBridge{T,S,MOI.SetWithDotProducts{S,V}} +struct DotProductsBridge{T,S,A,V} <: SetMapBridge{T,S,MOI.SetWithDotProducts{S,A,V}} variables::Vector{MOI.VariableIndex} constraint::MOI.ConstraintIndex{MOI.VectorOfVariables,S} - set::MOI.SetWithDotProducts{S,V} + set::MOI.SetWithDotProducts{S,A,V} end function supports_constrained_variable( @@ -13,23 +13,23 @@ end function concrete_bridge_type( ::Type{<:DotProductsBridge{T}}, - ::Type{MOI.SetWithDotProducts{S,V}}, -) where {T,S,V} - return DotProductsBridge{T,S,V} + ::Type{MOI.SetWithDotProducts{S,A,V}}, +) where {T,S,A,V} + return DotProductsBridge{T,S,A,V} end function bridge_constrained_variable( - BT::Type{DotProductsBridge{T,S,V}}, + BT::Type{DotProductsBridge{T,S,A,V}}, model::MOI.ModelLike, - set::MOI.SetWithDotProducts{S,V}, -) where {T,S,V} + set::MOI.SetWithDotProducts{S,A,V}, +) where {T,S,A,V} variables, constraint = _add_constrained_var(model, MOI.Bridges.inverse_map_set(BT, set)) return BT(variables, constraint, set) end function MOI.Bridges.map_set(bridge::DotProductsBridge{T,S}, set::S) where {T,S} - return MOI.SetWithDotProducts(set, bridge.vectors) + return bridge.set end function MOI.Bridges.inverse_map_set( @@ -45,28 +45,23 @@ function MOI.Bridges.map_function( i::MOI.Bridges.IndexInVector, ) where {T} scalars = MOI.Utilities.eachscalar(func) - if i.value in eachindex(bridge.set.vectors) - return MOI.Utilities.set_dot( - bridge.set.vectors[i.value], - scalars, - bridge.set.set, - ) - else - return convert( - MOI.ScalarAffineFunction{T}, - scalars[i.value-length(bridge.vectors)], - ) - end + return MOI.Utilities.set_dot( + bridge.set.vectors[i.value], + scalars, + bridge.set.set, + ) end -function MOI.Bridges.inverse_map_function( - bridge::DotProductsBridge{T}, - func, -) where {T} - m = length(bridge.set.vectors) - return MOI.Utilities.operate( - vcat, - T, - MOI.Utilities.eachscalar(func)[(m+1):end], - ) +# This returns `true` by default for `SetMapBridge` +# but is is not supported for this bridge because `inverse_map_function` +# is not implemented +function MOI.supports(::MOI.ModelLike, ::MOI.VariablePrimalStart, ::Type{<:DotProductsBridge}) + return false +end + +function unbridged_map( + ::DotProductsBridge, + ::Vector{MOI.VariableIndex}, +) + return nothing end diff --git a/src/Bridges/Variable/set_map.jl b/src/Bridges/Variable/set_map.jl index 574ef805cf..6d0e39091e 100644 --- a/src/Bridges/Variable/set_map.jl +++ b/src/Bridges/Variable/set_map.jl @@ -129,7 +129,7 @@ function MOI.get( bridge::SetMapBridge, ) set = MOI.get(model, attr, bridge.constraint) - return MOI.Bridges.map_set(typeof(bridge), set) + return MOI.Bridges.map_set(bridge, set) end function MOI.set( diff --git a/src/sets.jl b/src/sets.jl index 2af3b3e4cd..ed49de0d6b 100644 --- a/src/sets.jl +++ b/src/sets.jl @@ -1814,6 +1814,10 @@ struct SetWithDotProducts{S,A,V<:AbstractVector{A}} <: AbstractVectorSet vectors::V end +function Base.:(==)(s1::SetWithDotProducts, s2::SetWithDotProducts) + return s1.set == s2.set && s1.vectors == s2.vectors +end + function Base.copy(s::SetWithDotProducts) return SetWithDotProducts(copy(s.set), copy(s.vectors)) end diff --git a/test/Bridges/Variable/set_dot.jl b/test/Bridges/Variable/set_dot.jl new file mode 100644 index 0000000000..e77d775b64 --- /dev/null +++ b/test/Bridges/Variable/set_dot.jl @@ -0,0 +1,72 @@ +# Copyright (c) 2017: Miles Lubin and contributors +# Copyright (c) 2017: Google Inc. +# +# Use of this source code is governed by an MIT-style license that can be found +# in the LICENSE.md file or at https://opensource.org/licenses/MIT. + +module TestVariableDotProducts + +using Test + +import MathOptInterface as MOI + +function runtests() + for name in names(@__MODULE__; all = true) + if startswith("$(name)", "test_") + @testset "$(name)" begin + getfield(@__MODULE__, name)() + end + end + end + return +end + +include("../utilities.jl") + +function test_psd() + MOI.Bridges.runtests( + MOI.Bridges.Variable.DotProductsBridge, + model -> begin + x, _ = MOI.add_constrained_variables(model, + MOI.SetWithDotProducts( + MOI.PositiveSemidefiniteConeTriangle(2), + MOI.TriangleVectorization.([ + [1 2.0 + 2 3], + [4 5.0 + 5 6], + ]), + ) + ) + MOI.add_constraint( + model, + 1.0x[1], + MOI.EqualTo(0.0), + ) + MOI.add_constraint( + model, + 1.0x[2], + MOI.LessThan(0.0), + ) + end, + model -> begin + Q, _ = MOI.add_constrained_variables(model, MOI.PositiveSemidefiniteConeTriangle(2)) + MOI.add_constraint( + model, + 1.0 * Q[1] + 4.0 * Q[2] + 3.0 * Q[3], + MOI.EqualTo(0.0), + ) + MOI.add_constraint( + model, + 4.0 * Q[1] + 10.0 * Q[2] + 6.0 * Q[3], + MOI.LessThan(0.0), + ) + end; + cannot_unbridge = true, + ) + return +end + +end # module + +TestVariableDotProducts.runtests()