diff --git a/Project.toml b/Project.toml index bd9ee556..e060d1ff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqJump" uuid = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12" authors = ["Chris Rackauckas "] -version = "6.6.3" +version = "6.7.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/src/jumps.jl b/src/jumps.jl index 4d6b4c2e..e3454c6f 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -24,14 +24,19 @@ VariableRateJump(rate,affect!; rootfind,interp_points, save_positions,abstol,reltol) -struct RegularJump{R,C,MD} +struct RegularJump{iip,R,C,MD} rate::R c::C numjumps::Int mark_dist::MD + function RegularJump{iip}(rate,c,numjumps::Int; mark_dist = nothing) where iip + new{iip,typeof(rate),typeof(c),typeof(mark_dist)}(rate,c,numjumps,mark_dist) + end end -RegularJump(rate,c,numjumps::Int; mark_dist = nothing) = RegularJump(rate,c,numjumps,mark_dist) +DiffEqBase.isinplace(::RegularJump{iip,R,C,MD}) where {iip,R,C,MD} = iip + +RegularJump(rate,args...; kwargs...) = RegularJump{DiffEqBase.isinplace(rate,4)}(rate,args...;kwargs...) # deprecate old call function RegularJump(rate,c,dc::AbstractMatrix; constant_c=false, mark_dist = nothing) diff --git a/src/problem.jl b/src/problem.jl index 06f06850..ee25118a 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -1,4 +1,4 @@ -mutable struct JumpProblem{P,A,C,J<:Union{Nothing,AbstractJumpAggregator},J2,J3,J4} <: DiffEqBase.AbstractJumpProblem{P,J} +mutable struct JumpProblem{iip,P,A,C,J<:Union{Nothing,AbstractJumpAggregator},J2,J3,J4} <: DiffEqBase.AbstractJumpProblem{P,J} prob::P aggregator::A discrete_jump_aggregation::J @@ -7,6 +7,8 @@ mutable struct JumpProblem{P,A,C,J<:Union{Nothing,AbstractJumpAggregator},J2,J3, regular_jump::J3 massaction_jump::J4 end + +DiffEqBase.isinplace(::JumpProblem{iip}) where {iip} = iip JumpProblem(prob::JumpProblem) = prob JumpProblem(prob,jumps::ConstantRateJump;kwargs...) = JumpProblem(prob,JumpSet(jumps);kwargs...) @@ -36,6 +38,13 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS constant_jump_callback = DiscreteCallback(disc) end + iip = if prob isa DiscreteProblem && prob.f === DiffEqBase.DISCRETE_INPLACE_DEFAULT + # Just a default discrete problem f, so don't use it for iip + DiffEqBase.isinplace(jumps.regular_jump) + else + DiffEqBase.isinplace(prob) + end + ## Variable Rate Handling if typeof(jumps.variable_jumps) <: Tuple{} new_prob = prob @@ -45,7 +54,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS variable_jump_callback = build_variable_callback(CallbackSet(),0,jumps.variable_jumps...) end callbacks = CallbackSet(constant_jump_callback,variable_jump_callback) - JumpProblem{typeof(new_prob),typeof(aggregator),typeof(callbacks), + JumpProblem{iip,typeof(new_prob),typeof(aggregator),typeof(callbacks), typeof(disc),typeof(jumps.variable_jumps), typeof(jumps.regular_jump),typeof(jumps.massaction_jump)}( new_prob,aggregator,disc,