Skip to content

Commit

Permalink
plottable_indices dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Nov 8, 2024
1 parent b4cd99b commit bcab934
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ PoissonRandom = "0.4"
RandomNumbers = "1.5"
RecursiveArrayTools = "3.12"
Reexport = "1.0"
SciMLBase = "2.46"
SciMLBase = "2.59"
Setfield = "1"
StaticArrays = "1.9"
SymbolicIndexingInterface = "0.3.13"
Expand Down
3 changes: 3 additions & 0 deletions src/extended_jump_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ function Base.similar(A::ExtendedJumpArray, ::Type{S},
ExtendedJumpArray(similar(A.u, S), similar(A.jump_u, S))
end

# plotting
SciMLBase.plottable_indices(u::ExtendedJumpArray) = SciMLBase.plottable_indices(u.u)

# ODE norm to prevent type-unstable fallback
@inline function DiffEqBase.ODE_DEFAULT_NORM(u::ExtendedJumpArray, t)
Base.FastMath.sqrt_fast(real(sum(abs2, u)) / max(length(u), 1))
Expand Down
5 changes: 3 additions & 2 deletions test/extended_jump_array.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Test, JumpProcesses, DiffEqBase, OrdinaryDiffEq
using Test, JumpProcesses, DiffEqBase, OrdinaryDiffEq, SciMLBase
using FastBroadcast
using StableRNGs

rng = StableRNG(123)
Expand Down Expand Up @@ -56,7 +57,6 @@ out_result .= bc_dtype_1 .+ bc_dtype_2 .* 2
@test out_result result

# Test that fast broadcasting also gives the correct results
using FastBroadcast
@.. bc_out = 3.14 * bc_eja_1 + 2.7 * bc_eja_2
@test bc_out 3.14 * bc_eja_1 + 2.7 * bc_eja_2

Expand Down Expand Up @@ -118,4 +118,5 @@ let
jprob = JumpProblem(oprob, Direct(), vrj, deathvrj)
sol = solve(jprob, Tsit5())
@test eltype(sol.u) <: ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}
@test SciMLBase.plottable_indices(sol.u[1]) == 1:length(u₀)
end

0 comments on commit bcab934

Please sign in to comment.