From 8cb59a4bdbd8b5f33499beba39453aabf008794d Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Thu, 15 Jun 2023 23:10:42 +0200 Subject: [PATCH 1/8] Add TranslateVarArray for simplified indexing --- src/SparseVariables.jl | 2 + src/translate.jl | 122 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 28 ++++++++++ 3 files changed, 152 insertions(+) create mode 100644 src/translate.jl diff --git a/src/SparseVariables.jl b/src/SparseVariables.jl index 96b2e63..e88dd92 100644 --- a/src/SparseVariables.jl +++ b/src/SparseVariables.jl @@ -10,6 +10,8 @@ include("dictionaries.jl") include("indexedarray.jl") include("tables.jl") +include("translate.jl") + export SparseArray export IndexedVarArray export insertvar! diff --git a/src/translate.jl b/src/translate.jl new file mode 100644 index 0000000..2e82266 --- /dev/null +++ b/src/translate.jl @@ -0,0 +1,122 @@ +""" + TranslateVarArray{V,N,T} + + Structure for holding an optimization variable with a sparse structure with extra indexing + Translate from abstract types to type stable id for performance in dictionaries +""" +struct TranslateVarArray{V<:AbstractVariableRef,N,T} <: AbstractSparseArray{V,N} + f::Function + data::Dictionary{T,V} + index_names::NamedTuple + index_cache::Vector{Dictionary} +end + +""" + translate + Function to translate from type instance to type stable (e.g. Int or String) for performance in dictionaries. + Extend this for types (e.g. Node or Link types) to improve performance +""" +function translate(x) + return x +end +function translate(::Type{T}) where {T} + return T +end + +_data(sa::TranslateVarArray) = sa.data + +""" + insertvar!(var::IndexedVarArray{V,N,T}, index...) + +Insert a new variable with the given index only after checking if keys are valid and not already defined. +""" +function insertvar!(var::TranslateVarArray{V,N,T}, index...) where {V,N,T} + return insertvar!(var, SafeInsert(), index...) +end +function insertvar!( + var::TranslateVarArray{V,N,T}, + ::SafeInsert = SafeInsert(), + index..., +) where {V,N,T} + tindex = translate.(index) + !valid_index(var, index) && throw(BoundsError(var, index))# "Not a valid index for $(var.name): $index"g + already_defined(var, tindex) && error("$index already defined for array") + var[tindex] = var.f(tindex...) + clear_cache!(var) + return var[tindex] +end + +# Extension for standard JuMP macros +function Containers.container( + f::Function, + indices, + D::Type{TranslateVarArray}, + names, +) + iva_names = NamedTuple{tuple(names...)}(indices.prod.iterators) + T = Tuple{translate.(eltype.(indices.prod.iterators))...} + N = length(names) + V = first(Base.return_types(f)) + return TranslateVarArray{V,N,T}( + f, + Dictionary{T,V}(), + iva_names, + Vector{Dictionary}(undef, 2^N), + ) +end + +@generated function _getindex( + sa::TranslateVarArray{T,N}, + tpl::Tuple, +) where {T,N} + lookup = true + slice = true + for t in fieldtypes(tpl) + if !isfixed(t) + lookup = false + if !iscolon(t) + slice = false + end + end + end + + if lookup + return :(get(_data(sa), translate.(tpl), zero(T))) + elseif !slice + return :(retval = select(_data(sa), translate.(tpl)); + length(retval) > 0 ? retval : zero(T)) + else # Return selection or zero if empty to avoid reduction of empty iterate + return :(retval = _select_var(sa, translate.(tpl)); + length(retval) > 0 ? retval : zero(T)) + end +end + +function Base.firstindex(sa::TranslateVarArray, d) + return first(sort(sa.index_names[d])) +end +function Base.lastindex(sa::TranslateVarArray, d) + return last(sort(sa.index_names[d])) +end + +function build_cache!(cache, pat, sa::TranslateVarArray{V,N,T}) where {V,N,T} + if isempty(cache) + for v in keys(sa) + vred = _active(v, translate.(pat)) + nv = get!(cache, vred, T[]) + push!(nv, v) + end + end + return cache +end + +function _select_cached(sa::TranslateVarArray{V,N,T}, pat) where {V,N,T} + # TODO: Benchmark to find good cutoff-value for caching + # TODO: Return same type for type stability + tpat = translate.(pat) + length(_data(sa)) < 100 && return _select_gen(keys(_data(sa)), tpat) + cache = + _getcache(sa, tpat)::Dictionary{_decode_nonslices(sa, tpat),Vector{T}} + build_cache!(cache, tpat, sa) + vals = _dropslices_gen(tpat) + return get!(cache, vals, T[]) +end diff --git a/test/runtests.jl b/test/runtests.jl index 9b346a4..58de37e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -324,3 +324,31 @@ end @test sum(x) == sum(x[:, :]) @test typeof(sum(x)) <: GenericAffExpr{Float64,MockVariableRef} end + +@testset "TranslateVarArray" begin + # Demo custom types + abstract type Node end + struct Source <: Node + id::Int + end + struct Sink <: Node + id::Int + end + + # Translation to type-stable index + SV.translate(n::Node) = n.id + SV.translate(::Type{<:Node}) = Int + + # Test variable construction + m = Model() + @variable( + m, + x[i = Source.(1:10), j = Sink.(1:10)], + container = SV.TranslateVarArray + ) + for i in Source.(1:10), j in Sink.(1:10) + insertvar!(x, i, j) + end + @test length(x) == 100 + @test typeof(x) == SV.TranslateVarArray{VariableRef,2,Tuple{Int,Int}} +end From 4277efac29e3cab1ca71f96981b890e87648e0c7 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Sat, 17 Jun 2023 14:15:39 +0200 Subject: [PATCH 2/8] Fix cache for TranslateVarArray --- src/translate.jl | 34 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 25 ++++++++++++++++++++++--- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/translate.jl b/src/translate.jl index 2e82266..287f9ae 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -120,3 +120,37 @@ function _select_cached(sa::TranslateVarArray{V,N,T}, pat) where {V,N,T} vals = _dropslices_gen(tpat) return get!(cache, vals, T[]) end + +function _getcache(sa::TranslateVarArray{V,N,T}, pat::P) where {V,N,T,P} + t = _get_cache_index(pat) + if isassigned(sa.index_cache, t) + return sa.index_cache[t] + else + sa.index_cache[t] = Dictionary{_decode_nonslices(sa, t),Vector{T}}() + end + return sa.index_cache[t] +end + +""" + _decode_nonslices(::IndexedVarArray{V,N,T}, ::P) + +Reconstruct types of a pattern from the array types and the pattern type +""" +@generated function _decode_nonslices( + ::TranslateVarArray{V,N,T}, + ::P, +) where {V,N,T,P} + fts = fieldtypes(T) + fts2 = fieldtypes(P) + t = Tuple{ + (translate(fts[i]) for (i, v) in enumerate(fts2) if v != Colon)..., + } + return :($t) +end + +function _decode_nonslices(::TranslateVarArray{V,N,T}, v::Integer) where {V,N,T} + fts = fieldtypes(T) + return Tuple{ + (fts[i] for (i, c) in enumerate(last(bitstring(v), N)) if c == '1')..., + } +end diff --git a/test/runtests.jl b/test/runtests.jl index 58de37e..621d04a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -341,14 +341,33 @@ end # Test variable construction m = Model() + # length < 100 -> filtering (no cache) @variable( m, - x[i = Source.(1:10), j = Sink.(1:10)], + x[i = Source.(1:10), j = Sink.(1:5)], container = SV.TranslateVarArray ) - for i in Source.(1:10), j in Sink.(1:10) + for i in Source.(1:10), j in Sink.(1:5) insertvar!(x, i, j) end - @test length(x) == 100 + @test length(x) == 50 @test typeof(x) == SV.TranslateVarArray{VariableRef,2,Tuple{Int,Int}} + @test length(x[:, 1]) == 10 + @test length(x[1, :]) == 5 + @test length(x[1:2, :]) == 10 + + # length > 100 -> cache + @variable( + m, + y[i = Source.(1:10), j = Sink.(1:15)], + container = SV.TranslateVarArray + ) + for i in Source.(1:10), j in Sink.(1:15) + insertvar!(y, i, j) + end + @test length(y) == 150 + @test typeof(y) == SV.TranslateVarArray{VariableRef,2,Tuple{Int,Int}} + @test length(y[:, 1]) == 10 + @test length(y[1, :]) == 15 + @test length(y[1:2, :]) == 30 end From e3c0dbfd7800b43b29a4e88dc2708051b4fc0d50 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Tue, 20 Jun 2023 14:06:09 +0200 Subject: [PATCH 3/8] Performance improvement from JuMP v 1.12 https://github.com/jump-dev/JuMP.jl/pull/3410 --- src/tables.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/tables.jl b/src/tables.jl index 4372049..f4852c5 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -1,4 +1,4 @@ -function _rows(x::Union{SparseArray,IndexedVarArray}) +function Containers._rows(x::Union{SparseArray,IndexedVarArray}) return zip(eachindex(x.data), keys(x.data)) end @@ -6,7 +6,7 @@ function JuMP.Containers.rowtable( f::Function, x::AbstractSparseArray; header::Vector{Symbol} = Symbol[], -) +)::Vector{<:NamedTuple} if isempty(header) header = Symbol[Symbol("x$i") for i in 1:ndims(x)] push!(header, :y) @@ -17,8 +17,8 @@ function JuMP.Containers.rowtable( "Invalid number of column names provided: Got $got, expected $want.", ) end - names = tuple(header...) - return [NamedTuple{names}((args..., f(x[i]))) for (i, args) in _rows(x)] + elements = [(args..., f(x[i])) for (i, args) in Containers._rows(x)] + return NamedTuple{tuple(header...)}.(elements) end function JuMP.Containers.rowtable( From 7a9fe4de59c5250cf962b31360a711a7bf4d3a44 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Tue, 20 Jun 2023 14:12:51 +0200 Subject: [PATCH 4/8] Tables support for TranslateVarArray --- src/SparseVariables.jl | 3 ++- src/tables.jl | 4 +++- src/translate.jl | 16 ++++++++++++++++ test/runtests.jl | 4 ++++ 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/SparseVariables.jl b/src/SparseVariables.jl index e88dd92..43cadb0 100644 --- a/src/SparseVariables.jl +++ b/src/SparseVariables.jl @@ -8,12 +8,13 @@ using PrecompileTools include("sparsearray.jl") include("dictionaries.jl") include("indexedarray.jl") +include("translate.jl") include("tables.jl") -include("translate.jl") export SparseArray export IndexedVarArray +export TranslateVarArray export insertvar! export unsafe_insertvar! export SafeInsert, UnsafeInsert diff --git a/src/tables.jl b/src/tables.jl index f4852c5..9531f1a 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -1,4 +1,6 @@ -function Containers._rows(x::Union{SparseArray,IndexedVarArray}) +function Containers._rows( + x::Union{SparseArray,IndexedVarArray,TranslateVarArray}, +) return zip(eachindex(x.data), keys(x.data)) end diff --git a/src/translate.jl b/src/translate.jl index 287f9ae..665d4cf 100644 --- a/src/translate.jl +++ b/src/translate.jl @@ -154,3 +154,19 @@ function _decode_nonslices(::TranslateVarArray{V,N,T}, v::Integer) where {V,N,T} (fts[i] for (i, c) in enumerate(last(bitstring(v), N)) if c == '1')..., } end + +function JuMP.Containers.rowtable( + f::Function, + x::TranslateVarArray, + col_header::Symbol, +) + header = Symbol[k for k in keys(x.index_names)] + push!(header, col_header) + return JuMP.Containers.rowtable(f, x; header = header) +end + +function JuMP.Containers.rowtable(f::Function, x::TranslateVarArray) + header = Symbol[k for k in keys(x.index_names)] + push!(header, Symbol(f)) + return JuMP.Containers.rowtable(f, x; header = header) +end diff --git a/test/runtests.jl b/test/runtests.jl index 621d04a..dc1a2a7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -370,4 +370,8 @@ end @test length(y[:, 1]) == 10 @test length(y[1, :]) == 15 @test length(y[1:2, :]) == 30 + + # Tables + @test length(Containers.rowtable(x)) == 50 + @test length(Containers.rowtable(y)) == 150 end From 4a96474dad8e94ccaa335429333156a805901066 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Tue, 20 Jun 2023 14:18:01 +0200 Subject: [PATCH 5/8] Format fix --- src/SparseVariables.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/SparseVariables.jl b/src/SparseVariables.jl index 43cadb0..a3590d0 100644 --- a/src/SparseVariables.jl +++ b/src/SparseVariables.jl @@ -11,7 +11,6 @@ include("indexedarray.jl") include("translate.jl") include("tables.jl") - export SparseArray export IndexedVarArray export TranslateVarArray From 96375bb742fc6b8c93ce6a570a9ce40480cd7d73 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Tue, 20 Jun 2023 14:29:06 +0200 Subject: [PATCH 6/8] Test with custom header for improved coverage --- test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index dc1a2a7..f61839b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -373,5 +373,7 @@ end # Tables @test length(Containers.rowtable(x)) == 50 + @test length(Containers.rowtable(x; header = [:x, :y, :z])) == 50 @test length(Containers.rowtable(y)) == 150 + @test length(Containers.rowtable(y; header = [:x, :y, :z])) == 150 end From eb695f088c82e2cf8b26515a0c4e244f89fdec18 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Tue, 20 Jun 2023 14:46:09 +0200 Subject: [PATCH 7/8] Test lookup in TranslateVarArray --- test/runtests.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index f61839b..5f1ba9a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -370,7 +370,8 @@ end @test length(y[:, 1]) == 10 @test length(y[1, :]) == 15 @test length(y[1:2, :]) == 30 - + @test typeof(x[Source(1), Sink(2)]) == VariableRef + # Tables @test length(Containers.rowtable(x)) == 50 @test length(Containers.rowtable(x; header = [:x, :y, :z])) == 50 From 220d31fe10f98f7d7644e5ffbbd1e782a7b7b111 Mon Sep 17 00:00:00 2001 From: Lars Hellemo Date: Tue, 20 Jun 2023 14:52:46 +0200 Subject: [PATCH 8/8] More rowtable tests --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5f1ba9a..e977c1d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -371,10 +371,12 @@ end @test length(y[1, :]) == 15 @test length(y[1:2, :]) == 30 @test typeof(x[Source(1), Sink(2)]) == VariableRef - + # Tables @test length(Containers.rowtable(x)) == 50 + @test length(Containers.rowtable(identity, x, :a)) == 50 @test length(Containers.rowtable(x; header = [:x, :y, :z])) == 50 @test length(Containers.rowtable(y)) == 150 + @test length(Containers.rowtable(identity, y, :a)) == 150 @test length(Containers.rowtable(y; header = [:x, :y, :z])) == 150 end