diff --git a/src/SparseVariables.jl b/src/SparseVariables.jl index 96b2e63..a3590d0 100644 --- a/src/SparseVariables.jl +++ b/src/SparseVariables.jl @@ -8,10 +8,12 @@ using PrecompileTools include("sparsearray.jl") include("dictionaries.jl") include("indexedarray.jl") +include("translate.jl") include("tables.jl") export SparseArray export IndexedVarArray +export TranslateVarArray export insertvar! export unsafe_insertvar! export SafeInsert, UnsafeInsert diff --git a/src/tables.jl b/src/tables.jl index 4372049..9531f1a 100644 --- a/src/tables.jl +++ b/src/tables.jl @@ -1,4 +1,6 @@ -function _rows(x::Union{SparseArray,IndexedVarArray}) +function Containers._rows( + x::Union{SparseArray,IndexedVarArray,TranslateVarArray}, +) return zip(eachindex(x.data), keys(x.data)) end @@ -6,7 +8,7 @@ function JuMP.Containers.rowtable( f::Function, x::AbstractSparseArray; header::Vector{Symbol} = Symbol[], -) +)::Vector{<:NamedTuple} if isempty(header) header = Symbol[Symbol("x$i") for i in 1:ndims(x)] push!(header, :y) @@ -17,8 +19,8 @@ function JuMP.Containers.rowtable( "Invalid number of column names provided: Got $got, expected $want.", ) end - names = tuple(header...) - return [NamedTuple{names}((args..., f(x[i]))) for (i, args) in _rows(x)] + elements = [(args..., f(x[i])) for (i, args) in Containers._rows(x)] + return NamedTuple{tuple(header...)}.(elements) end function JuMP.Containers.rowtable( diff --git a/src/translate.jl b/src/translate.jl new file mode 100644 index 0000000..665d4cf --- /dev/null +++ b/src/translate.jl @@ -0,0 +1,172 @@ +""" + TranslateVarArray{V,N,T} + + Structure for holding an optimization variable with a sparse structure with extra indexing + Translate from abstract types to type stable id for performance in dictionaries +""" +struct TranslateVarArray{V<:AbstractVariableRef,N,T} <: AbstractSparseArray{V,N} + f::Function + data::Dictionary{T,V} + index_names::NamedTuple + index_cache::Vector{Dictionary} +end + +""" + translate + Function to translate from type instance to type stable (e.g. Int or String) for performance in dictionaries. + Extend this for types (e.g. Node or Link types) to improve performance +""" +function translate(x) + return x +end +function translate(::Type{T}) where {T} + return T +end + +_data(sa::TranslateVarArray) = sa.data + +""" + insertvar!(var::IndexedVarArray{V,N,T}, index...) + +Insert a new variable with the given index only after checking if keys are valid and not already defined. +""" +function insertvar!(var::TranslateVarArray{V,N,T}, index...) where {V,N,T} + return insertvar!(var, SafeInsert(), index...) +end +function insertvar!( + var::TranslateVarArray{V,N,T}, + ::SafeInsert = SafeInsert(), + index..., +) where {V,N,T} + tindex = translate.(index) + !valid_index(var, index) && throw(BoundsError(var, index))# "Not a valid index for $(var.name): $index"g + already_defined(var, tindex) && error("$index already defined for array") + var[tindex] = var.f(tindex...) + clear_cache!(var) + return var[tindex] +end + +# Extension for standard JuMP macros +function Containers.container( + f::Function, + indices, + D::Type{TranslateVarArray}, + names, +) + iva_names = NamedTuple{tuple(names...)}(indices.prod.iterators) + T = Tuple{translate.(eltype.(indices.prod.iterators))...} + N = length(names) + V = first(Base.return_types(f)) + return TranslateVarArray{V,N,T}( + f, + Dictionary{T,V}(), + iva_names, + Vector{Dictionary}(undef, 2^N), + ) +end + +@generated function _getindex( + sa::TranslateVarArray{T,N}, + tpl::Tuple, +) where {T,N} + lookup = true + slice = true + for t in fieldtypes(tpl) + if !isfixed(t) + lookup = false + if !iscolon(t) + slice = false + end + end + end + + if lookup + return :(get(_data(sa), translate.(tpl), zero(T))) + elseif !slice + return :(retval = select(_data(sa), translate.(tpl)); + length(retval) > 0 ? retval : zero(T)) + else # Return selection or zero if empty to avoid reduction of empty iterate + return :(retval = _select_var(sa, translate.(tpl)); + length(retval) > 0 ? retval : zero(T)) + end +end + +function Base.firstindex(sa::TranslateVarArray, d) + return first(sort(sa.index_names[d])) +end +function Base.lastindex(sa::TranslateVarArray, d) + return last(sort(sa.index_names[d])) +end + +function build_cache!(cache, pat, sa::TranslateVarArray{V,N,T}) where {V,N,T} + if isempty(cache) + for v in keys(sa) + vred = _active(v, translate.(pat)) + nv = get!(cache, vred, T[]) + push!(nv, v) + end + end + return cache +end + +function _select_cached(sa::TranslateVarArray{V,N,T}, pat) where {V,N,T} + # TODO: Benchmark to find good cutoff-value for caching + # TODO: Return same type for type stability + tpat = translate.(pat) + length(_data(sa)) < 100 && return _select_gen(keys(_data(sa)), tpat) + cache = + _getcache(sa, tpat)::Dictionary{_decode_nonslices(sa, tpat),Vector{T}} + build_cache!(cache, tpat, sa) + vals = _dropslices_gen(tpat) + return get!(cache, vals, T[]) +end + +function _getcache(sa::TranslateVarArray{V,N,T}, pat::P) where {V,N,T,P} + t = _get_cache_index(pat) + if isassigned(sa.index_cache, t) + return sa.index_cache[t] + else + sa.index_cache[t] = Dictionary{_decode_nonslices(sa, t),Vector{T}}() + end + return sa.index_cache[t] +end + +""" + _decode_nonslices(::IndexedVarArray{V,N,T}, ::P) + +Reconstruct types of a pattern from the array types and the pattern type +""" +@generated function _decode_nonslices( + ::TranslateVarArray{V,N,T}, + ::P, +) where {V,N,T,P} + fts = fieldtypes(T) + fts2 = fieldtypes(P) + t = Tuple{ + (translate(fts[i]) for (i, v) in enumerate(fts2) if v != Colon)..., + } + return :($t) +end + +function _decode_nonslices(::TranslateVarArray{V,N,T}, v::Integer) where {V,N,T} + fts = fieldtypes(T) + return Tuple{ + (fts[i] for (i, c) in enumerate(last(bitstring(v), N)) if c == '1')..., + } +end + +function JuMP.Containers.rowtable( + f::Function, + x::TranslateVarArray, + col_header::Symbol, +) + header = Symbol[k for k in keys(x.index_names)] + push!(header, col_header) + return JuMP.Containers.rowtable(f, x; header = header) +end + +function JuMP.Containers.rowtable(f::Function, x::TranslateVarArray) + header = Symbol[k for k in keys(x.index_names)] + push!(header, Symbol(f)) + return JuMP.Containers.rowtable(f, x; header = header) +end diff --git a/test/runtests.jl b/test/runtests.jl index 9b346a4..e977c1d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -324,3 +324,59 @@ end @test sum(x) == sum(x[:, :]) @test typeof(sum(x)) <: GenericAffExpr{Float64,MockVariableRef} end + +@testset "TranslateVarArray" begin + # Demo custom types + abstract type Node end + struct Source <: Node + id::Int + end + struct Sink <: Node + id::Int + end + + # Translation to type-stable index + SV.translate(n::Node) = n.id + SV.translate(::Type{<:Node}) = Int + + # Test variable construction + m = Model() + # length < 100 -> filtering (no cache) + @variable( + m, + x[i = Source.(1:10), j = Sink.(1:5)], + container = SV.TranslateVarArray + ) + for i in Source.(1:10), j in Sink.(1:5) + insertvar!(x, i, j) + end + @test length(x) == 50 + @test typeof(x) == SV.TranslateVarArray{VariableRef,2,Tuple{Int,Int}} + @test length(x[:, 1]) == 10 + @test length(x[1, :]) == 5 + @test length(x[1:2, :]) == 10 + + # length > 100 -> cache + @variable( + m, + y[i = Source.(1:10), j = Sink.(1:15)], + container = SV.TranslateVarArray + ) + for i in Source.(1:10), j in Sink.(1:15) + insertvar!(y, i, j) + end + @test length(y) == 150 + @test typeof(y) == SV.TranslateVarArray{VariableRef,2,Tuple{Int,Int}} + @test length(y[:, 1]) == 10 + @test length(y[1, :]) == 15 + @test length(y[1:2, :]) == 30 + @test typeof(x[Source(1), Sink(2)]) == VariableRef + + # Tables + @test length(Containers.rowtable(x)) == 50 + @test length(Containers.rowtable(identity, x, :a)) == 50 + @test length(Containers.rowtable(x; header = [:x, :y, :z])) == 50 + @test length(Containers.rowtable(y)) == 150 + @test length(Containers.rowtable(identity, y, :a)) == 150 + @test length(Containers.rowtable(y; header = [:x, :y, :z])) == 150 +end