diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index 4d768b2f..789bf2f8 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -139,6 +139,7 @@ end DiffEqBase.add_tstop!(integrator::SSAIntegrator,tstop) = integrator.tstop = tstop function DiffEqBase.step!(integrator::SSAIntegrator) + doaffect = false if !isempty(integrator.tstops) && integrator.tstops_idx <= length(integrator.tstops) && integrator.tstops[integrator.tstops_idx] < integrator.tstop @@ -147,9 +148,23 @@ function DiffEqBase.step!(integrator::SSAIntegrator) integrator.tstops_idx += 1 else integrator.t = integrator.tstop - integrator.cb.affect!(integrator) + doaffect = true # delay effect until after saveat end + @inbounds if integrator.saveat !== nothing && !isempty(integrator.saveat) + # Split to help prediction + while integrator.cur_saveat < length(integrator.saveat) && + integrator.saveat[integrator.cur_saveat] < integrator.t + + saved = true + push!(integrator.sol.t,integrator.saveat[integrator.cur_saveat]) + push!(integrator.sol.u,copy(integrator.u)) + integrator.cur_saveat += 1 + end + end + + doaffect && integrator.cb.affect!(integrator) + if !(typeof(integrator.opts.callback.discrete_callbacks)<:Tuple{}) discrete_modified,saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,integrator.opts.callback.discrete_callbacks...) else @@ -162,23 +177,16 @@ end function DiffEqBase.savevalues!(integrator::SSAIntegrator,force=false) saved, savedexactly = false, false + + # No saveat in here since it would only use previous values, + # so in the specific case of SSAStepper it's already handled + if integrator.save_everystep || force savedexactly = true push!(integrator.sol.t,integrator.t) push!(integrator.sol.u,copy(integrator.u)) end - @inbounds if integrator.saveat !== nothing && !isempty(integrator.saveat) - # Split to help prediction - while integrator.cur_saveat < length(integrator.saveat) && - integrator.saveat[integrator.cur_saveat] < integrator.t - saved = true - push!(integrator.sol.t,integrator.saveat[integrator.cur_saveat]) - push!(integrator.sol.u,copy(integrator.u)) - integrator.cur_saveat += 1 - - end - end saved, savedexactly end diff --git a/test/runtests.jl b/test/runtests.jl index 2bf7fb7a..cd00ee55 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,7 @@ using DiffEqJump, DiffEqBase, Test @time @testset "Mass Action Jump Tests; Special Cases" begin include("degenerate_rx_cases.jl") end @time @testset "Composition-Rejection Table Tests" begin include("table_test.jl") end @time @testset "Extinction test" begin include("extinction_test.jl") end + @time @testset "Saveat Regression test" begin include("saveat_regression.jl") end @time @testset "Ensemble Uniqueness test" begin include("ensemble_uniqueness.jl") end @time @testset "Thread Safety test" begin include("thread_safety.jl") end end diff --git a/test/saveat_regression.jl b/test/saveat_regression.jl new file mode 100644 index 00000000..1055de52 --- /dev/null +++ b/test/saveat_regression.jl @@ -0,0 +1,37 @@ +using DiffEqBase, DiffEqJump, Test +rate_consts = [10.0] +reactant_stoich = [[1 => 1, 2 => 1]] +net_stoich = [[1 => -1, 2 => -1, 3 => 1]] +maj = MassActionJump(rate_consts, reactant_stoich, net_stoich) + +n0 = [1,1,0] +tspan = (0,.2) +dprob = DiscreteProblem(n0, tspan) +jprob = JumpProblem(dprob, Direct(), maj, save_positions=(false,false)) +ts = collect(0:.002:tspan[2]) +NA = zeros(length(ts)) +Nsims = 10_000 +sol = DiffEqJump.solve(EnsembleProblem(jprob), SSAStepper(), saveat=ts, trajectories=Nsims) + +for i in 1:length(sol) + NA .+= sol[i][1,:] +end + +for i in 1:length(ts) + @test NA[i] / Nsims ≈ exp(-10*ts[i]) rtol=1e-1 +end + +NA = zeros(length(ts)) +jprob = JumpProblem(dprob, Direct(), maj) +sol = nothing; GC.gc() +sol = DiffEqJump.solve(EnsembleProblem(jprob), SSAStepper(), trajectories=Nsims) + +for i = 1:Nsims + for n = 1:length(ts) + NA[n] += sol[i](ts[n])[1] + end +end + +for i in 1:length(ts) + @test NA[i] / Nsims ≈ exp(-10*ts[i]) rtol=1e-1 +end