From a8afb3e729e68b3298cb007a4ee8399ca2505e86 Mon Sep 17 00:00:00 2001 From: hellemo Date: Mon, 5 Jun 2023 21:32:05 +0200 Subject: [PATCH] Parametrize safe/unsafe inserts (#42) * Parametrize safe/unsafe inserts * Fix formatting --------- Co-authored-by: Lars Hellemo --- src/SparseVariables.jl | 1 + src/indexedarray.jl | 22 +++++++++++++++++++--- test/runtests.jl | 3 +++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/SparseVariables.jl b/src/SparseVariables.jl index 9eb187a..96b2e63 100644 --- a/src/SparseVariables.jl +++ b/src/SparseVariables.jl @@ -14,6 +14,7 @@ export SparseArray export IndexedVarArray export insertvar! export unsafe_insertvar! +export SafeInsert, UnsafeInsert @setup_workload begin # Putting some things in `setup` can reduce the size of the diff --git a/src/indexedarray.jl b/src/indexedarray.jl index cb351af..9f52440 100644 --- a/src/indexedarray.jl +++ b/src/indexedarray.jl @@ -10,6 +10,9 @@ struct IndexedVarArray{V<:AbstractVariableRef,N,T} <: AbstractSparseArray{V,N} index_cache::Vector{Dictionary} end +struct SafeInsert end +struct UnsafeInsert end + _data(sa::IndexedVarArray) = sa.data already_defined(var, index) = haskey(_data(var), index) @@ -37,15 +40,28 @@ end Insert a new variable with the given index only after checking if keys are valid and not already defined. """ function insertvar!(var::IndexedVarArray{V,N,T}, index...) where {V,N,T} + return insertvar!(var, SafeInsert(), index...) +end +function insertvar!( + var::IndexedVarArray{V,N,T}, + ::SafeInsert = SafeInsert(), + index..., +) where {V,N,T} !valid_index(var, index) && throw(BoundsError(var, index))# "Not a valid index for $(var.name): $index"g already_defined(var, index) && error("$index already defined for array") - var[index] = var.f(index...) - clear_cache!(var) return var[index] end +function insertvar!( + var::IndexedVarArray{V,N,T}, + ::UnsafeInsert, + index..., +) where {V,N,T} + return var[index] = var.f(index...) +end + """ unsafe_insertvar!(var::indexedVarArray{V,N,T}, index...) @@ -53,7 +69,7 @@ Insert a new variable with the given index withouth checking if the index is val already assigned. """ function unsafe_insertvar!(var::IndexedVarArray{V,N,T}, index...) where {V,N,T} - return var[index] = var.f(index...) + return insertvar!(var, UnsafeInsert(), index...) end joinex(ex1, ex2) = :($ex1..., $ex2...) diff --git a/test/runtests.jl b/test/runtests.jl index 1a8a4cf..9b346a4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -273,6 +273,9 @@ end unsafe_insertvar!(x, 2, 102) @test length(x) == 2 + insertvar!(x, UnsafeInsert(), 2, 103) + @test length(x) == 3 + # When no names are provided @variable(m, y[1:3, 100:102] >= 0, container = IndexedVarArray) @test length(y) == 0