Skip to content

Commit

Permalink
Merge pull request #408 from isaacsas/indexing_fixes
Browse files Browse the repository at this point in the history
Indexing fixes
  • Loading branch information
ChrisRackauckas authored Mar 25, 2024
2 parents fe5dd7b + 0f14821 commit 53772a0
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 33 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
- Core
version:
- '1'
- '1.6'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
26 changes: 14 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
Expand All @@ -28,20 +29,21 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
JumpProcessFastBroadcastExt = "FastBroadcast"

[compat]
ArrayInterface = "6, 7"
DataStructures = "0.17, 0.18"
DiffEqBase = "6.122"
DocStringExtensions = "0.8.6, 0.9"
FunctionWrappers = "1.0"
Graphs = "1.4"
ArrayInterface = "7.9"
DataStructures = "0.18"
DiffEqBase = "6.148"
DocStringExtensions = "0.9"
FunctionWrappers = "1.1"
Graphs = "1.9"
PoissonRandom = "0.4"
RandomNumbers = "1.3"
RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
SciMLBase = "1.51, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
RandomNumbers = "1.5"
RecursiveArrayTools = "3.12"
Reexport = "1.0"
SciMLBase = "2.30.1"
StaticArrays = "1.9"
SymbolicIndexingInterface = "0.3.11"
UnPack = "1.0.2"
julia = "1.6"
julia = "1.10"

[extras]
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand Down
1 change: 1 addition & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Graphs: neighbors, outdegree

import RecursiveArrayTools: recursivecopy!
using StaticArrays, Base.Threads
import SymbolicIndexingInterface as SII

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
Expand Down
2 changes: 1 addition & 1 deletion src/aggregators/coevolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ end

# executing jump at the next jump time
function (p::CoevolveJumpAggregation)(integrator::I) where {I <:
AbstractSSAIntegrator}
AbstractSSAIntegrator}
if !accept_next_jump!(p, integrator, integrator.u, integrator.p, integrator.t)
return nothing
end
Expand Down
3 changes: 1 addition & 2 deletions src/aggregators/ssajump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ end
end

@inline function concretize_affects!(p::AbstractSSAJumpAggregator{T, S, F1, F2},
::I) where {T, S, F1, F2 <: Tuple,
I <: DiffEqBase.DEIntegrator}
::I) where {T, S, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator}
nothing
end

Expand Down
24 changes: 14 additions & 10 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,17 @@ end
# when setindex! is used.
function Base.setindex!(prob::JumpProblem, args...; kwargs...)
SciMLBase.___internal_setindex!(prob.prob, args...; kwargs...)
end

# for updating parameters in JumpProblems to update MassActionJumps
function SII.set_parameter!(prob::JumpProblem, val, idx)
ans = SII.set_parameter!(SII.parameter_values(prob), val, idx)

if using_params(prob.massaction_jump)
update_parameters!(prob.massaction_jump, prob.prob.p)
end

ans
end

# when getindex is used.
Expand All @@ -151,14 +159,12 @@ function JumpProblem(prob, jumps::AbstractJump...; kwargs...)
JumpProblem(prob, JumpSet(jumps...); kwargs...)
end

function JumpProblem(
prob, aggregator::AbstractAggregatorAlgorithm, jumps::ConstantRateJump;
kwargs...)
function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm,
jumps::ConstantRateJump; kwargs...)
JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...)
end
function JumpProblem(
prob, aggregator::AbstractAggregatorAlgorithm, jumps::VariableRateJump;
kwargs...)
function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm,
jumps::VariableRateJump; kwargs...)
JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...)
end
function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::RegularJump;
Expand Down Expand Up @@ -321,8 +327,7 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:length(jumps)])
remake(prob, f = SDEFunction{isinplace(prob)}(jump_f, jump_g), g = jump_g, u0 = u0)
end

Expand All @@ -347,8 +352,7 @@ function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAUL
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:length(jumps)])
remake(prob, f = DDEFunction{isinplace(prob)}(jump_f), u0 = u0)
end

Expand Down
31 changes: 24 additions & 7 deletions test/jprob_symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
# prepares the problem
using JumpProcesses, Test
using JumpProcesses, Test, SymbolicIndexingInterface
rate1(u, p, t) = p[1]
rate2(u, p, t) = p[2]
affect1!(integ) = (integ.u[1] += 1)
affect2!(integ) = (integ.u[2] += 1)
crj1 = ConstantRateJump(rate1, affect1!)
crj2 = ConstantRateJump(rate2, affect2!)
g = DiscreteFunction((du, u, p, t) -> nothing; syms = [:a, :b], paramsyms = [:p1, :p2])
maj = MassActionJump([[1 => 1], [1 => 1]], [[1 => -1], [1 => -1]]; param_idxs = [1,2])
g = DiscreteFunction((du, u, p, t) -> nothing;
sys = SymbolicIndexingInterface.SymbolCache([:a, :b], [:p1, :p2], :t))
dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0])
jprob = JumpProblem(dprob, Direct(), crj1, crj2)
jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj)

# runs the tests
# test basic querying of u0 and p
@test jprob[:a] == 0
@test jprob[:b] == 10
@test jprob[:p1] == 1.0
@test jprob[:p2] == 2.0
@test getp(jprob,:p1)(jprob) == 1.0
@test getp(jprob,:p2)(jprob) == 2.0
@test jprob.ps[:p1] == 1.0
@test jprob.ps[:p2] == 2.0

# tests for setindex (e.g. `jprob[:a] = 10`) not possible, this requires the problem to have a .f.sys filed.,
# test updating u0
jprob[:a] = 20
@test jprob[:a] == 20

# test mass action jumps update with parameter mutation in problems
@test jprob.massaction_jump.scaled_rates[1] == 1.0
jprob.ps[:p1] = 3.0
@test jprob.ps[:p1] == 3.0
@test jprob.massaction_jump.scaled_rates[1] == 3.0
p1setter = setp(jprob, [:p1, :p2])
p1setter(jprob, [4.0, 10.0])
@test jprob.ps[:p1] == 4.0
@test jprob.ps[:p2] == 10.0
@test jprob.massaction_jump.scaled_rates == [4.0, 10.0]

0 comments on commit 53772a0

Please sign in to comment.