Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix vrj reinitialization bug #450

Merged
merged 8 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ function resetted_jump_problem(_jump_prob, seed)

if !isempty(jump_prob.variable_jumps)
@assert jump_prob.prob.u0 isa ExtendedJumpArray
jump_prob.prob.u0.jump_u .= -randexp.(_jump_prob.rng, eltype(_jump_prob.prob.tspan))
randexp!(_jump_prob.rng, jump_prob.prob.u0.jump_u)
jump_prob.prob.u0.jump_u .*= -1
end
jump_prob
end
Expand All @@ -70,6 +71,7 @@ function reset_jump_problem!(jump_prob, seed)

if !isempty(jump_prob.variable_jumps)
@assert jump_prob.prob.u0 isa ExtendedJumpArray
jump_prob.prob.u0.jump_u .= -randexp.(jump_prob.rng, eltype(jump_prob.prob.tspan))
randexp!(jump_prob.rng, jump_prob.prob.u0.jump_u)
jump_prob.prob.u0.jump_u .*= -1
end
end
88 changes: 88 additions & 0 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DiffEqBase, JumpProcesses, OrdinaryDiffEq, StochasticDiffEq, Test
using Random
using StableRNGs
rng = StableRNG(12345)

Expand Down Expand Up @@ -229,3 +230,90 @@ let
solve(jprob, SSAStepper())
end
end

# test u0 resets correctly
let
b = 2.0
d = 1.0
n0 = 1
tspan = (0.0, 4.0)
Nsims = 10
u0 = [n0]
p = [b, d]

function ode_fxn(du, u, p, t)
du .= 0
nothing
end
b_rate(u, p, t) = (u[1] * p[1])
function birth!(integrator)
integrator.u[1] += 1
nothing
end
b_jump = VariableRateJump(b_rate, birth!)

d_rate(u, p, t) = (u[1] * p[2])
function death!(integrator)
integrator.u[1] -= 1
nothing
end
d_jump = VariableRateJump(d_rate, death!)

ode_prob = ODEProblem(ode_fxn, u0, tspan, p)
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; rng)
@test allunique(sjm_prob.prob.u0.jump_u)
u0old = copy(sjm_prob.prob.u0.jump_u)
for i in 1:Nsims
sol = solve(sjm_prob, Tsit5(); saveat = tspan[2])
@test allunique(sjm_prob.prob.u0.jump_u)
@test all(u0old != sjm_prob.prob.u0.jump_u)
u0old .= sjm_prob.prob.u0.jump_u
end
end

# accuracy test based on
# https://github.com/SciML/JumpProcesses.jl/issues/320
# note that even with the seeded StableRNG this test is not
# deterministic for some reason.
let
rng = StableRNG(12345)
b = 2.0
d = 1.0
n0 = 1
tspan = (0.0, 4.0)
Nsims = 10000
n(t) = n0 * exp((b - d) * t)
u0 = [n0]
p = [b, d]

function ode_fxn(du, u, p, t)
du .= 0
nothing
end

b_rate(u, p, t) = (u[1] * p[1])
function birth!(integrator)
integrator.u[1] += 1
nothing
end
b_jump = VariableRateJump(b_rate, birth!)

d_rate(u, p, t) = (u[1] * p[2])
function death!(integrator)
integrator.u[1] -= 1
nothing
end
d_jump = VariableRateJump(d_rate, death!)

ode_prob = ODEProblem(ode_fxn, u0, tspan, p)
sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; rng)
dt = 0.1
tsave = range(tspan[1], tspan[2]; step = dt)
umean = zeros(length(tsave))
for i in 1:Nsims
sol = solve(sjm_prob, Tsit5(); saveat = dt)
umean .+= Array(sol(tsave; idxs = 1))
end
umean ./= Nsims
@test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave))
end
Loading