Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Choose an SSA if no SSA is passed in JumpProblem. #351

Merged
merged 23 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,38 @@
supports_variablerates(aggregator::AbstractAggregatorAlgorithm) = false
supports_variablerates(aggregator::Coevolve) = true

# true if aggregator supports hops, e.g. diffusion
is_spatial(aggregator::AbstractAggregatorAlgorithm) = false
is_spatial(aggregator::NSM) = true
is_spatial(aggregator::DirectCRDirect) = true

# return the fastest aggregator out of the available ones
function select_aggregator(jumps::JumpSet; vartojumps_map=nothing, jumptovars_map=nothing, dep_graph=nothing, spatial_system=nothing, hopping_constants=nothing)

# detect if a spatial SSA should be used
!isnothing(spatial_system) && !isnothing(hopping_constants) && return DirectCRDirect

Check warning on line 197 in src/aggregators/aggregators.jl

View check run for this annotation

Codecov / codecov/patch

src/aggregators/aggregators.jl#L197

Added line #L197 was not covered by tests
# if variable rate jumps are present, return one of the two SSAs that support them
if num_vrjs(jumps) != 0
any(isbounded, vrjs) && return Coevolve
return Direct
end

# if the number of jumps is small, return the Direct
num_jumps(jumps) < 10 && return Direct

Check warning on line 205 in src/aggregators/aggregators.jl

View check run for this annotation

Codecov / codecov/patch

src/aggregators/aggregators.jl#L204-L205

Added lines #L204 - L205 were not covered by tests

# if there are only massaction jumps, we can any build dependency graph
TorkelE marked this conversation as resolved.
Show resolved Hide resolved
can_build_dependency_graphs = num_crjs(jumps) == 0 && num_vrjs(jumps) == 0
have_dependency_graphs = !isnothing(vartojumps_map) && !isnothing(jumptovars_map)

# if we have the species-jumps dependency graphs or can build them, use one of the Rejection-based methods
if can_build_dependency_graphs || have_dependency_graphs
num_jumps(jumps) < 100 && return RSSA
return RSSACR
# if we have the jumps-jumps dependency graph, use the Composition-Rejection Direct method
elseif !isnothing(dep_graph)
TorkelE marked this conversation as resolved.
Show resolved Hide resolved
return DirectCR
else

Check warning on line 218 in src/aggregators/aggregators.jl

View check run for this annotation

Codecov / codecov/patch

src/aggregators/aggregators.jl#L218

Added line #L218 was not covered by tests
return Direct
end
end

Check warning on line 221 in src/aggregators/aggregators.jl

View check run for this annotation

Codecov / codecov/patch

src/aggregators/aggregators.jl#L220-L221

Added lines #L220 - L221 were not covered by tests
7 changes: 5 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,12 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::Abstr
kwargs...)
JumpProblem(prob, aggregator, JumpSet(jumps...); kwargs...)
end
function JumpProblem(prob, jumps::JumpSet; kwargs...)
JumpProblem(prob, NullAggregator(), jumps; kwargs...)
function JumpProblem(prob, jumps::JumpSet; vartojumps_map=nothing, jumptovars_map=nothing, dep_graph=nothing, spatial_system=nothing, hopping_constants=nothing, kwargs...)
isaacsas marked this conversation as resolved.
Show resolved Hide resolved
aggregator = select_aggregator(jumps::JumpSet; vartojumps_map=vartojumps_map, jumptovars_map=jumptovars_map, dep_graph=dep_graph, spatial_system=spatial_system, hopping_constants=hopping_constants)
return JumpProblem(prob, aggregator(), jumps; vartojumps_map=vartojumps_map, jumptovars_map=jumptovars_map, dep_graph=dep_graph, spatial_system=spatial_system, hopping_constants=hopping_constants, kwargs...)
end
# this makes it easier to test the aggregator selection
JumpProblem(prob, aggregator::NullAggregator, jumps::JumpSet; kwargs...) = JumpProblem(prob, jumps; kwargs...)

make_kwarg(; kwargs...) = kwargs

Expand Down
2 changes: 1 addition & 1 deletion test/geneexpr_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dotestmean = true
doprintmeans = false

# SSAs to test
SSAalgs = (RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(),
SSAalgs = (JumpProcesses.NullAggregator(), RDirect(), RSSACR(), Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(),
TorkelE marked this conversation as resolved.
Show resolved Hide resolved
NRM(), RSSA(), DirectCR(), Coevolve())

# numerical parameters
Expand Down
2 changes: 1 addition & 1 deletion test/linearreaction_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ tf = 0.1
baserate = 0.1
A0 = 100
exactmean = (t, ratevec) -> A0 * exp(-sum(ratevec) * t)
SSAalgs = [RSSACR(), Direct(), RSSA()]
SSAalgs = [RSSACR(), Direct(), RSSA(), JumpProcesses.NullAggregator()]

spec_to_dep_jumps = [collect(1:Nrxs), []]
jump_to_dep_specs = [[1, 2] for i in 1:Nrxs]
Expand Down
3 changes: 3 additions & 0 deletions test/spatial/ABC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps,
push!(jump_problems,
JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants,
spatial_system = grids[1], save_positions = (false, false), rng = rng))
push!(jump_problems,
JumpProblem(prob, majumps, hopping_constants = hopping_constants,
spatial_system = grids[1], save_positions = (false, false), rng = rng))
# setup flattenned jump prob
push!(jump_problems,
JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants,
Expand Down
Loading