Skip to content

Commit

Permalink
Merge pull request #125 from SciML/saveat
Browse files Browse the repository at this point in the history
fix SSAStepper saveat
  • Loading branch information
ChrisRackauckas authored Jun 17, 2020
2 parents 76451d5 + 8843ca1 commit e1faf34
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 12 deletions.
32 changes: 20 additions & 12 deletions src/SSA_stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ end
DiffEqBase.add_tstop!(integrator::SSAIntegrator,tstop) = integrator.tstop = tstop

function DiffEqBase.step!(integrator::SSAIntegrator)
doaffect = false
if !isempty(integrator.tstops) &&
integrator.tstops_idx <= length(integrator.tstops) &&
integrator.tstops[integrator.tstops_idx] < integrator.tstop
Expand All @@ -147,9 +148,23 @@ function DiffEqBase.step!(integrator::SSAIntegrator)
integrator.tstops_idx += 1
else
integrator.t = integrator.tstop
integrator.cb.affect!(integrator)
doaffect = true # delay effect until after saveat
end

@inbounds if integrator.saveat !== nothing && !isempty(integrator.saveat)
# Split to help prediction
while integrator.cur_saveat < length(integrator.saveat) &&
integrator.saveat[integrator.cur_saveat] < integrator.t

saved = true
push!(integrator.sol.t,integrator.saveat[integrator.cur_saveat])
push!(integrator.sol.u,copy(integrator.u))
integrator.cur_saveat += 1
end
end

doaffect && integrator.cb.affect!(integrator)

if !(typeof(integrator.opts.callback.discrete_callbacks)<:Tuple{})
discrete_modified,saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,integrator.opts.callback.discrete_callbacks...)
else
Expand All @@ -162,23 +177,16 @@ end

function DiffEqBase.savevalues!(integrator::SSAIntegrator,force=false)
saved, savedexactly = false, false

# No saveat in here since it would only use previous values,
# so in the specific case of SSAStepper it's already handled

if integrator.save_everystep || force
savedexactly = true
push!(integrator.sol.t,integrator.t)
push!(integrator.sol.u,copy(integrator.u))
end
@inbounds if integrator.saveat !== nothing && !isempty(integrator.saveat)
# Split to help prediction
while integrator.cur_saveat < length(integrator.saveat) &&
integrator.saveat[integrator.cur_saveat] < integrator.t

saved = true
push!(integrator.sol.t,integrator.saveat[integrator.cur_saveat])
push!(integrator.sol.u,copy(integrator.u))
integrator.cur_saveat += 1

end
end
saved, savedexactly
end

Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using DiffEqJump, DiffEqBase, Test
@time @testset "Mass Action Jump Tests; Special Cases" begin include("degenerate_rx_cases.jl") end
@time @testset "Composition-Rejection Table Tests" begin include("table_test.jl") end
@time @testset "Extinction test" begin include("extinction_test.jl") end
@time @testset "Saveat Regression test" begin include("saveat_regression.jl") end
@time @testset "Ensemble Uniqueness test" begin include("ensemble_uniqueness.jl") end
@time @testset "Thread Safety test" begin include("thread_safety.jl") end
end
37 changes: 37 additions & 0 deletions test/saveat_regression.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using DiffEqBase, DiffEqJump, Test
rate_consts = [10.0]
reactant_stoich = [[1 => 1, 2 => 1]]
net_stoich = [[1 => -1, 2 => -1, 3 => 1]]
maj = MassActionJump(rate_consts, reactant_stoich, net_stoich)

n0 = [1,1,0]
tspan = (0,.2)
dprob = DiscreteProblem(n0, tspan)
jprob = JumpProblem(dprob, Direct(), maj, save_positions=(false,false))
ts = collect(0:.002:tspan[2])
NA = zeros(length(ts))
Nsims = 10_000
sol = DiffEqJump.solve(EnsembleProblem(jprob), SSAStepper(), saveat=ts, trajectories=Nsims)

for i in 1:length(sol)
NA .+= sol[i][1,:]
end

for i in 1:length(ts)
@test NA[i] / Nsims exp(-10*ts[i]) rtol=1e-1
end

NA = zeros(length(ts))
jprob = JumpProblem(dprob, Direct(), maj)
sol = nothing; GC.gc()
sol = DiffEqJump.solve(EnsembleProblem(jprob), SSAStepper(), trajectories=Nsims)

for i = 1:Nsims
for n = 1:length(ts)
NA[n] += sol[i](ts[n])[1]
end
end

for i in 1:length(ts)
@test NA[i] / Nsims exp(-10*ts[i]) rtol=1e-1
end

0 comments on commit e1faf34

Please sign in to comment.