diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 4df55f13..08151586 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -63,6 +63,7 @@ include("spatial/bracketing.jl") include("spatial/nsm.jl") include("spatial/directcrdirect.jl") +include("spatial/directcrrssa.jl") include("aggregators/aggregated_api.jl") @@ -99,6 +100,6 @@ export ExtendedJumpArray export CartesianGrid, CartesianGridRej export SpatialMassActionJump export outdegree, num_sites, neighbors -export NSM, DirectCRDirect +export NSM, DirectCRDirect, DirectCRRSSA end # module diff --git a/src/aggregators/aggregators.jl b/src/aggregators/aggregators.jl index e64501d7..6fda0b49 100644 --- a/src/aggregators/aggregators.jl +++ b/src/aggregators/aggregators.jl @@ -163,6 +163,8 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108 """ struct DirectCRDirect <: AbstractAggregatorAlgorithm end +struct DirectCRRSSA <: AbstractAggregatorAlgorithm end + const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(), FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve()) @@ -192,6 +194,7 @@ supports_variablerates(aggregator::Coevolve) = true is_spatial(aggregator::AbstractAggregatorAlgorithm) = false is_spatial(aggregator::NSM) = true is_spatial(aggregator::DirectCRDirect) = true +is_spatial(aggregator::DirectCRRSSA) = true # return the fastest aggregator out of the available ones function select_aggregator(jumps::JumpSet; vartojumps_map = nothing, diff --git a/src/spatial/bracketing.jl b/src/spatial/bracketing.jl index 6346fa4b..431913a5 100644 --- a/src/spatial/bracketing.jl +++ b/src/spatial/bracketing.jl @@ -5,9 +5,15 @@ struct LowHigh{T} low::T high::T - LowHigh(low::T, high::T) where {T} = new{T}(deepcopy(low), deepcopy(high)) - LowHigh(pair::Tuple{T, T}) where {T} = new{T}(pair[1], pair[2]) - LowHigh(low_and_high::T) where {T} = new{T}(low_and_high, deepcopy(low_and_high)) + function LowHigh(low::T, high::T; do_copy = true) where {T} + if do_copy + return new{T}(deepcopy(low), deepcopy(high)) + else + return new{T}(low, high) + end + end + LowHigh(pair::Tuple{T,T}; kwargs...) where {T} = LowHigh(pair[1], pair[2]; kwargs...) + LowHigh(low_and_high::T; kwargs...) where {T} = LowHigh(low_and_high, low_and_high; kwargs...) end function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) @@ -16,22 +22,32 @@ function Base.show(io::IO, ::MIME"text/plain", low_high::LowHigh) end @inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) - @inbounds for (i, uval) in enumerate(u) - u_low_high[i] = LowHigh(get_spec_brackets(bracket_data, i, uval)) + num_species, num_sites = size(u) + update_u_brackets!(u_low_high, bracket_data, u, 1:num_species, 1:num_sites) +end + +@inline function update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix, species_vec, sites) + @inbounds for site in sites + for species in species_vec + u_low_high[species, site] = LowHigh(get_spec_brackets(bracket_data, species, u[species, site])) + end end nothing end +function is_inside_brackets(u_low_high::LowHigh{M}, u::M, species, site) where {M} + return u_low_high.low[species, site] < u[species, site] < u_low_high.high[species, site] +end + ### convenience functions for LowHigh ### -function setindex!(low_high::LowHigh, val::LowHigh, i) - low_high.low[i] = val.low - low_high.high[i] = val.high +function setindex!(low_high::LowHigh{A}, val::LowHigh, i...) where {A <: AbstractArray} + low_high.low[i...] = val.low + low_high.high[i...] = val.high val end +getindex(low_high::LowHigh{A}, i) where {A <: AbstractArray} = LowHigh(low_high.low[i], low_high.high[i]) -function getindex(low_high::LowHigh, i) - return LowHigh(low_high.low[i], low_high.high[i]) -end +get_majumps(rx_rates::LowHigh{R}) where {R <: RxRates} = get_majumps(rx_rates.low) function total_site_rate(rx_rates::LowHigh, hop_rates::LowHigh, site) return LowHigh( @@ -48,3 +64,8 @@ function update_hop_rates!(hop_rates::LowHigh, species, u_low_high, site, spatia update_hop_rates!(hop_rates.low, species, u_low_high.low, site, spatial_system) update_hop_rates!(hop_rates.high, species, u_low_high.high, site, spatial_system) end + +function reset!(low_high::LowHigh) + reset!(low_high.low) + reset!(low_high.high) +end \ No newline at end of file diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index bec547c4..bf9441fe 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -4,7 +4,6 @@ const MINJUMPRATE = 2.0^exponent(1e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j -#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, SS, U <: PriorityTable, W <: Function} <: @@ -107,12 +106,12 @@ end function initialize!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) - generate_jumps!(p, integrator, params, u, t) + generate_jumps!(p, integrator, u, params, t) nothing end # calculate the next jump / jump time -function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u, t) +function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, u, params, t) p.next_jump_time = t + randexp(p.rng) / p.rt.gsum p.next_jump_time >= p.end_time && return nothing site = sample(p.rt, p.site_rates, p.rng) diff --git a/src/spatial/directcrrssa.jl b/src/spatial/directcrrssa.jl new file mode 100644 index 00000000..500a6f60 --- /dev/null +++ b/src/spatial/directcrrssa.jl @@ -0,0 +1,260 @@ +# site chosen with DirectCR, rx or hop chosen with RSSA + +############################ DirectCRRSSA ################################### +const MINJUMPRATE = 2.0^exponent(1e-12) + +#NOTE state vector u is a matrix. u[i,j] is species i, site j +mutable struct DirectCRRSSAJumpAggregation{T, BD, M, RNG, J, RX, HOP, DEPGR, + VJMAP, JVMAP, SS, U <: PriorityTable, S, F1, F2} <: + AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + next_jump::SpatialJump{J} + prev_jump::SpatialJump{J} + next_jump_time::T + end_time::T + bracket_data::BD + u_low_high::LowHigh{M} # species bracketing + rx_rates::LowHigh{RX} + hop_rates::LowHigh{HOP} + site_rates_high::Vector{T} # we do not need site_rates_low + save_positions::Tuple{Bool, Bool} + rng::RNG + dep_gr::DEPGR #dep graph is same for each site + vartojumps_map::VJMAP #vartojumps_map is same for each site + jumptovars_map::JVMAP #jumptovars_map is same for each site + spatial_system::SS + numspecies::Int #number of species + rt::U + rates::F1 # legacy, not used + affects!::F2 # legacy, not used +end + +function DirectCRRSSAJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, bd::BD, + u_low_high::LowHigh{M}, rx_rates::LowHigh{RX}, + hop_rates::LowHigh{HOP}, site_rates_high::Vector{T}, + sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; + num_specs, minrate = convert(T, MINJUMPRATE), + vartojumps_map = nothing, jumptovars_map = nothing, + dep_graph = nothing, + kwargs...) where {J, T, BD, RX, HOP, RNG, SS, M} + + # a dependency graph is needed + if dep_graph === nothing + dg = make_dependency_graph(num_specs, get_majumps(rx_rates)) + else + dg = dep_graph + # make sure each jump depends on itself + add_self_dependencies!(dg) + end + + # a species-to-reactions graph is needed + if vartojumps_map === nothing + vtoj_map = var_to_jumps_map(num_specs, get_majumps(rx_rates)) + else + vtoj_map = vartojumps_map + end + + if jumptovars_map === nothing + jtov_map = jump_to_vars_map(get_majumps(rx_rates)) + else + jtov_map = jumptovars_map + end + + # mapping from jump rate to group id + minexponent = exponent(minrate) + + # use the largest power of two that is <= the passed in minrate + minrate = 2.0^minexponent + ratetogroup = rate -> priortogid(rate, minexponent) + + # construct an empty initial priority table -- we'll reset this in init + rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) + + DirectCRRSSAJumpAggregation{ + T, + BD, + M, + RNG, + J, + RX, + HOP, + typeof(dg), + typeof(vtoj_map), + typeof(jtov_map), + SS, + typeof(rt), + Nothing, + Nothing, + Nothing, + }(nj, nj, njt, et, bd, u_low_high, rx_rates, hop_rates, site_rates_high, sps, rng, dg, + vtoj_map, jtov_map, spatial_system, num_specs, rt, nothing, nothing) +end + +############################# Required Functions ############################## +# creating the JumpAggregation structure (function wrapper-based constant jumps) +function aggregate(aggregator::DirectCRRSSA, starting_state, p, t, end_time, + constant_jumps, ma_jumps, save_positions, rng; hopping_constants, + spatial_system, bracket_data = nothing, kwargs...) + T = typeof(end_time) + num_species = size(starting_state, 1) + majumps = ma_jumps + if majumps === nothing + majumps = MassActionJump(Vector{T}(), + Vector{Vector{Pair{Int, Int}}}(), + Vector{Vector{Pair{Int, Int}}}()) + end + + next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder + next_jump_time = typemax(T) + rx_rates = LowHigh(RxRates(num_sites(spatial_system), majumps), + RxRates(num_sites(spatial_system), majumps); + do_copy = false) # do not copy ma_jumps + hop_rates = LowHigh(HopRates(hopping_constants, spatial_system), + HopRates(hopping_constants, spatial_system); + do_copy = false) # do not copy hopping_constants + site_rates_high = zeros(T, num_sites(spatial_system)) + bd = (bracket_data === nothing) ? BracketData{T, eltype(starting_state)}() : + bracket_data + u_low_high = LowHigh(starting_state) + + DirectCRRSSAJumpAggregation(next_jump, next_jump_time, end_time, bd, u_low_high, + rx_rates, hop_rates, + site_rates_high, save_positions, rng, spatial_system; + num_specs = num_species, kwargs...) +end + +# set up a new simulation and calculate the first jump / jump time +function initialize!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t) + p.end_time = integrator.sol.prob.tspan[2] + fill_rates_and_get_times!(p, integrator, t) + generate_jumps!(p, integrator, u, params, t) + nothing +end + +# calculate the next jump / jump time +function generate_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t) + @unpack rng, rt, site_rates_high, rx_rates, hop_rates, spatial_system = p + time_delta = zero(t) + while true + site = sample(rt, site_rates_high, rng) + jump = sample_jump_direct(rx_rates.high, hop_rates.high, site, spatial_system, rng) + time_delta += randexp(rng) + if accept_jump(p, u, jump) + p.next_jump_time = t + time_delta / groupsum(rt) + p.next_jump = jump + break + end + end + nothing +end + +# execute one jump, changing the system state +function execute_jumps!(p::DirectCRRSSAJumpAggregation, integrator, u, params, t, + affects!) + update_state!(p, integrator) + update_dependent_rates!(p, integrator, t) + nothing +end + +######################## SSA specific helper routines ######################## +# Return true if site is accepted. +@inline accept_jump(p, u, jump) = is_hop(p, jump) ? accept_hop(p, u, jump) : accept_rx(p, u, jump) + + +function accept_hop(p, u, jump) + @unpack hop_rates, spatial_system, rng = p + species, site = jump.jidx, jump.src + acceptance_threshold = rand(rng) * hop_rate(hop_rates.high, species, site) + if hop_rate(hop_rates.low, species, site) > acceptance_threshold + return true + else + # compute the real rate. Could have used hop_rates.high as well. + real_rate = evalhoprate(hop_rates.low, u, species, site, spatial_system) + return real_rate > acceptance_threshold + end +end + +function accept_rx(p, u, jump) + @unpack rx_rates, rng = p + rx, site = reaction_id_from_jump(p, jump), jump.src + acceptance_threshold = rand(rng) * rx_rate(rx_rates.high, rx, site) + if rx_rate(rx_rates.low, rx, site) > acceptance_threshold + return true + else + # compute the real rate. Could have used rx_rates.high as well. + real_rate = evalrxrate(rx_rates.low, u, rx, site) + return real_rate > acceptance_threshold + end +end + +""" + fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, u, t) + +reset all stucts, reevaluate all rates, repopulate the priority table +""" +function fill_rates_and_get_times!(aggregation::DirectCRRSSAJumpAggregation, integrator, t) + @unpack bracket_data, u_low_high, spatial_system, rx_rates, hop_rates, site_rates_high, rt = aggregation + u = integrator.u + update_u_brackets!(u_low_high::LowHigh, bracket_data, u::AbstractMatrix) + + reset!(rx_rates) + reset!(hop_rates) + fill!(site_rates_high, zero(eltype(site_rates_high))) + + rxs = 1:num_rxs(rx_rates.low) + species = 1:(aggregation.numspecies) + + for site in 1:num_sites(spatial_system) + update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) + update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) + site_rates_high[site] = total_site_rate(rx_rates.high, hop_rates.high, site) + end + + # setup PriorityTable + reset!(rt) + for (pid, priority) in enumerate(site_rates_high) + insert!(rt, pid, priority) + end + nothing +end + +""" + update_dependent_rates!(p, integrator, t) + +recalculate jump rates for jumps that depend on the just executed jump (p.prev_jump) +""" +function update_dependent_rates!(p::DirectCRRSSAJumpAggregation, integrator, t) + jump = p.prev_jump + if is_hop(p, jump) + update_brackets!(p, integrator, jump.jidx, (jump.src, jump.dst)) + else + update_brackets!(p, integrator, p.jumptovars_map[reaction_id_from_jump(p, jump)], jump.src) + end +end + +function update_brackets!(p, integrator, species_to_update, sites_to_update) + @unpack rx_rates, hop_rates, site_rates_high, u_low_high, bracket_data, vartojumps_map, spatial_system = p + u = integrator.u + for site in sites_to_update, species in species_to_update + if !is_inside_brackets(u_low_high, u, species, site) + update_u_brackets!(u_low_high, bracket_data, u, species, site) + update_rx_rates!(rx_rates, + vartojumps_map[species], + u_low_high, + integrator, + site) + update_hop_rates!(hop_rates, species, u_low_high, site, spatial_system) + + oldrate = site_rates_high[site] + site_rates_high[site] = total_site_rate(rx_rates.high, hop_rates.high, site) + update!(p.rt, site, oldrate, site_rates_high[site]) + end + end + nothing +end + +""" + num_constant_rate_jumps(aggregator::DirectCRRSSAJumpAggregation) + +number of constant rate jumps +""" +num_constant_rate_jumps(aggregator::DirectCRRSSAJumpAggregation) = 0 \ No newline at end of file diff --git a/src/spatial/hop_rates.jl b/src/spatial/hop_rates.jl index bdb70f0a..fdab78aa 100644 --- a/src/spatial/hop_rates.jl +++ b/src/spatial/hop_rates.jl @@ -68,6 +68,8 @@ function update_hop_rates!(hop_rates::AbstractHopRates, species::AbstractArray, end end +hop_rate(hop_rates, species, site) = @inbounds hop_rates.rates[species, site] + """ update_hop_rate!(hop_rates::HopRatesGraphDsi, species, u, site, spatial_system) diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 85e5cc42..dfed44b5 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -2,7 +2,6 @@ ############################ NSM ################################### #NOTE state vector u is a matrix. u[i,j] is species i, site j -#NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j mutable struct NSMJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, PQ, SS} <: AbstractSSAJumpAggregator{T, S, F1, F2, RNG} @@ -96,12 +95,12 @@ end function initialize!(p::NSMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] fill_rates_and_get_times!(p, integrator, t) - generate_jumps!(p, integrator, params, u, t) + generate_jumps!(p, integrator, u, params, t) nothing end # calculate the next jump / jump time -function generate_jumps!(p::NSMJumpAggregation, integrator, params, u, t) +function generate_jumps!(p::NSMJumpAggregation, integrator, u, params, t) p.next_jump_time, site = top_with_handle(p.pq) p.next_jump_time >= p.end_time && return nothing p.next_jump = sample_jump_direct(p, site) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 737cc5c9..0ccd908f 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -26,6 +26,7 @@ function RxRates(num_sites::Int, ma_jumps::M) where {M} end num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) +get_majumps(rx_rates::RxRates) = rx_rates.ma_jumps """ reset!(rx_rates::RxRates) @@ -38,6 +39,9 @@ function reset!(rx_rates::RxRates) nothing end +rx_rate(rx_rates, rx, site) = rx_rates.rates[rx, site] +evalrxrate(rx_rates, u, rx, site) = eval_massaction_rate(u, rx, rx_rates.ma_jumps, site) + """ total_site_rx_rate(rx_rates::RxRates, site) diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index ad3c2cb3..42de1308 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -27,14 +27,17 @@ end sample jump at site with direct method """ -function sample_jump_direct(p, site) - if rand(p.rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < - total_site_rx_rate(p.rx_rates, site) - rx = sample_rx_at_site(p.rx_rates, site, p.rng) - return SpatialJump(site, rx + p.numspecies, site) +sample_jump_direct(p, site) = sample_jump_direct(p.rx_rates, p.hop_rates, site, p.spatial_system, p.rng) + +function sample_jump_direct(rx_rates, hop_rates, site, spatial_system, rng) + numspecies = size(hop_rates.rates, 1) + if rand(rng) * (total_site_rate(rx_rates, hop_rates, site)) < + total_site_rx_rate(rx_rates, site) + rx = sample_rx_at_site(rx_rates, site, rng) + return SpatialJump(site, rx + numspecies, site) else - species_to_diffuse, target_site = sample_hop_at_site(p.hop_rates, site, p.rng, - p.spatial_system) + species_to_diffuse, target_site = sample_hop_at_site(hop_rates, site, rng, + spatial_system) return SpatialJump(site, species_to_diffuse, target_site) end end @@ -52,10 +55,10 @@ end function update_rates_after_hop!(p, integrator, source_site, target_site, species) u = integrator.u update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, source_site) - update_hop_rate!(p.hop_rates, species, u, source_site, p.spatial_system) + update_hop_rates!(p.hop_rates, species, u, source_site, p.spatial_system) update_rx_rates!(p.rx_rates, p.vartojumps_map[species], integrator, target_site) - update_hop_rate!(p.hop_rates, species, u, target_site, p.spatial_system) + update_hop_rates!(p.hop_rates, species, u, target_site, p.spatial_system) end """ @@ -70,7 +73,7 @@ function update_state!(p, integrator) else rx_index = reaction_id_from_jump(p, jump) @inbounds executerx!((@view integrator.u[:, jump.src]), rx_index, - p.rx_rates.ma_jumps) + get_majumps(p.rx_rates)) end # save jump that was just executed p.prev_jump = jump diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 8d7230a7..61c5a1dd 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -50,13 +50,15 @@ end # testing grids = [CartesianGridRej(dims), Graphs.grid(dims)] jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, - hopping_constants = hopping_constants, - spatial_system = grid, - save_positions = (false, false), rng = rng) - for grid in grids] -push!(jump_problems, - JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + hopping_constants = hopping_constants, + spatial_system = grid, + save_positions = (false, false), rng = rng) for grid in grids] + +# SSAs +for alg in [DirectCRDirect(), DirectCRRSSA()] + push!(jump_problems, JumpProblem(prob, alg, majumps; hopping_constants, spatial_system = grids[1], save_positions = (false, false), rng)) +end + # setup flattenned jump prob push!(jump_problems, JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, diff --git a/test/spatial/bracketing.jl b/test/spatial/bracketing.jl index 31c1e23b..89006f2e 100644 --- a/test/spatial/bracketing.jl +++ b/test/spatial/bracketing.jl @@ -10,7 +10,6 @@ n = 3 # number of sites # set up spatial system spatial_system = CartesianGrid((n,)) # n sites -site_rates = JP.LowHigh(zeros(n), zeros(n)) # set up reaction rates majump_rates = [0.1] # death at rate 0.1 @@ -36,7 +35,6 @@ integrator = Nothing # only needed for constant rate jumps for site in 1:num_sites(spatial_system) JP.update_rx_rates!(rx_rates, rxs, u_low_high, integrator, site) JP.update_hop_rates!(hop_rates, species_vec, u_low_high, site, spatial_system) - site_rates[site] = JP.total_site_rate(rx_rates, hop_rates, site) end # test species brackets diff --git a/test/spatial/diffusion.jl b/test/spatial/diffusion.jl index c014b5c5..db58e1e7 100644 --- a/test/spatial/diffusion.jl +++ b/test/spatial/diffusion.jl @@ -61,7 +61,7 @@ Nsims = 10000 rel_tol = 0.02 times = 0.0:(tf / num_time_points):tf -algs = [NSM(), DirectCRDirect()] +algs = [NSM(), DirectCRDirect(), DirectCRRSSA()] grids = [CartesianGridRej(dims), Graphs.grid(dims)] jump_problems = JumpProblem[JumpProblem(prob, algs[2], majumps, hopping_constants = hopping_constants,