diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index 948b6883..b70ccd76 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -184,10 +184,11 @@ function DiffEqBase.__init(jump_prob::JumpProblem, end else cb = deepcopy(jump_prob.jump_callback.discrete_callbacks[end]) + rng = cb.condition.rng if seed === nothing - Random.seed!(cb.condition.rng, rand(UInt64)) + Random.seed!(rng, rand(UInt64)) else - Random.seed!(cb.condition.rng, seed) + Random.seed!(rng, seed) end end opts = (callback = CallbackSet(callback),) diff --git a/src/problem.jl b/src/problem.jl index 6fccf951..911aa2bc 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -44,8 +44,8 @@ then be passed within a single [`JumpSet`](@ref) or as subsequent sequential arg $(FIELDS) ## Keyword Arguments -- `rng`, the random number generator to use. On 1.7 and up defaults to Julia's built-in - generator, below 1.7 uses RandomNumbers.jl's `Xorshifts.Xoroshiro128Star(rand(UInt64))`. +- `rng`, the random number generator to use. Defaults to Julia's built-in + generator. - `save_positions=(true,true)`, specifies whether to save the system's state (before, after) the jump occurs. - `spatial_system`, for spatial problems the underlying spatial structure. @@ -430,14 +430,14 @@ function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAUL remake(prob; f, u0) end -function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG) - idx += 1 - condition = function (u, t, integrator) +function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG) + condition = function(u, t, integrator) u.jump_u[idx] end - affect! = function (integrator) + affect! = function(integrator) jump.affect!(integrator) integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t)) + nothing end new_cb = ContinuousCallback(condition, affect!; idxs = jump.idxs, @@ -446,26 +446,18 @@ function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG) save_positions = jump.save_positions, abstol = jump.abstol, reltol = jump.reltol) + return new_cb +end + +function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG) + idx += 1 + new_cb = wrap_jump_in_callback(idx, jump; rng) build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG) end function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG) idx += 1 - condition = function (u, t, integrator) - u.jump_u[idx] - end - affect! = function (integrator) - jump.affect!(integrator) - integrator.u.jump_u[idx] = -randexp(rng, typeof(integrator.t)) - end - new_cb = ContinuousCallback(condition, affect!; - idxs = jump.idxs, - rootfind = jump.rootfind, - interp_points = jump.interp_points, - save_positions = jump.save_positions, - abstol = jump.abstol, - reltol = jump.reltol) - CallbackSet(cb, new_cb) + CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng)) end aggregator(jp::JumpProblem{iip, P, A, C, J}) where {iip, P, A, C, J} = A diff --git a/src/solve.jl b/src/solve.jl index ee51ce98..a0b739a9 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -47,12 +47,12 @@ end function resetted_jump_problem(_jump_prob, seed) jump_prob = deepcopy(_jump_prob) + rng = jump_prob.jump_callback.discrete_callbacks[1].condition.rng if !isempty(jump_prob.jump_callback.discrete_callbacks) if seed === nothing - Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, - rand(UInt64)) + Random.seed!(rng, rand(UInt64)) else - Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, seed) + Random.seed!(rng, seed) end end diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 436ba6ba..39c8d55b 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -275,18 +275,20 @@ end # https://github.com/SciML/JumpProcesses.jl/issues/320 # note that even with the seeded StableRNG this test is not # deterministic for some reason. -function getmean(Nsims, prob, alg, dt, tsave) +function getmean(Nsims, prob, alg, dt, tsave, seed) umean = zeros(length(tsave)) for i in 1:Nsims - sol = solve(prob, alg; saveat = dt) + sol = solve(prob, alg; saveat = dt, seed) umean .+= Array(sol(tsave; idxs = 1)) + seed += 1 end umean ./= Nsims return umean end let - rng = StableRNG(12345) + seed = 12345 + rng = StableRNG(seed) b = 2.0 d = 1.0 n0 = 1 @@ -320,7 +322,8 @@ let dt = 0.1 tsave = range(tspan[1], tspan[2]; step = dt) for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization())) - umean = getmean(Nsims, sjm_prob, alg, dt, tsave) + umean = getmean(Nsims, sjm_prob, alg, dt, tsave, seed) @test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave)) + seed += Nsims end end