Speed of GraphDynamics compared to DiffEq #471
Comments
This comment was marked as outdated.
Here's a new implementation based on the same technique as #484, and requiring the new GraphDynamics v0.2.0:
using GraphDynamics, Neuroblox, OrdinaryDiffEq, Distributions
using BenchmarkTools # provides @btime, used in the benchmarks below
using Neuroblox: AbstractNeuronBlox
using Neuroblox.GraphDynamicsInterop: BasicConnection
# I'm making this a <:AbstractNeuronBlox even though I didn't actually bother doing
# any of the MTK-side stuff we'd normally do with an AbstractNeuronBlox
struct Izh2 <: AbstractNeuronBlox
    vₚ::Float64
    vᵣ::Float64
    τᵣ::Float64
    η::Float64
    α::Float64
    a::Float64
    b::Float64
    wⱼ::Float64
    sⱼ::Float64
    eᵣ::Float64
    gₛ::Float64
    τₛ::Float64
    Iₑ::Float64
    namespace::Symbol
    name::Symbol
    function Izh2(; name,
                  vₚ=200,
                  vᵣ=-200,
                  ηₘ=0.12,
                  ηₛ=0.02,
                  α=0.6215,
                  a=0.0077,
                  b=-0.0062,
                  wⱼ=0.0189,
                  sⱼ=1.2308,
                  eᵣ=1.0,
                  gₛ=1.2308,
                  τₛ=2.6,
                  Iₑ=0.0,
                  τᵣ=nothing,
                  namespace=Symbol(""))
        # Default refractory time is 2/vₚ unless given explicitly
        τᵣ = isnothing(τᵣ) ? 2.0/vₚ : τᵣ
        # Each neuron draws its own η from a Cauchy distribution
        η = rand(Cauchy(ηₘ, ηₛ))
        new(vₚ, vᵣ, τᵣ, η, α, a, b, wⱼ, sⱼ, eᵣ, gₛ, τₛ, Iₑ, namespace, name)
    end
end
Neuroblox.nameof((;name)::Izh2) = name
GraphDynamicsInterop.issupported(::Izh2) = true
# Teach GraphDynamics how to create a GraphDynamics.Subsystem out of an Izh2 struct
function GraphDynamicsInterop.to_subsystem((;vₚ, vᵣ, τᵣ, η, α, a, b, wⱼ, sⱼ, eᵣ, gₛ, τₛ, Iₑ)::Izh2)
    states = SubsystemStates{Izh2}((; V=0.0, w=0.0, z=0.0))
    params = SubsystemParams{Izh2}((;vₚ, vᵣ, τᵣ, η, α, a, b, wⱼ, sⱼ, eᵣ, gₛ, τₛ, Iₑ))
    Subsystem(states, params)
end
GraphDynamics.initialize_input(::Subsystem{Izh2}) = 0.0
function GraphDynamics.subsystem_differential(s::Subsystem{Izh2}, _, t)
    (;V, w, z) = s
    (;vₚ, vᵣ, τᵣ, η, α, a, b, wⱼ, sⱼ, eᵣ, gₛ, τₛ, Iₑ) = s
    dV = V * (V - α) - w + η + Iₑ + gₛ * z * (eᵣ - V)
    dw = a * (b * V - w)
    dz = -z/τₛ
    SubsystemStates{Izh2}((;V=dV, w=dw, z=dz))
end
# This lets GraphDynamics do some optimizations by promising that the differential never depends on the connections
GraphDynamics.subsystem_differential_requires_inputs(::Type{<:Izh2}) = false
# Connections between Izh2 neurons shouldn't do anything (during solving). They only exist for events.
function (c::BasicConnection)(::Subsystem{Izh2}, ::Subsystem{Izh2})
    0.0
end
# Tell GraphDynamics that Izh2 has events
GraphDynamics.has_discrete_events(::Type{Izh2}) = true
# The events trigger when V exceeds some threshold
function GraphDynamics.discrete_event_condition(n::Subsystem{Izh2}, t, _)
    n.V >= n.vₚ
end
# When the event triggers, we want to update the neuron which fired *and* every destination neuron it has a
# connection to.
function GraphDynamics.apply_discrete_event!(integrator, sview, _, neuron_src::Subsystem{Izh2}, foreach_connected_neuron::F) where {F}
    # Set the neuron to the post-firing state
    sview[:V] = neuron_src.vᵣ
    sview[:w] += neuron_src.wⱼ
    # For each neuron which has a connection leading to it from neuron_src
    # we increment its `z` state
    foreach_connected_neuron() do conn, neuron_dst, states_view_dst, params_view_dst
        states_view_dst[:z] += conn.weight * neuron_src.sⱼ
    end
end
julia> let tspan=(0.0, 20.0), N = 1000
           neurons = map(1:N) do i
               Izh2(;name = Symbol(:n, i))
           end
           g = MetaDiGraph()
           add_blox!.((g,), neurons)
           for i ∈ 1:N, j ∈ 1:N
               add_edge!(g, i, j, Dict(:weight => 1/N))
           end
           @named sys = system_from_graph(g; graphdynamics=true)
           prob = ODEProblem(sys, [], tspan, [])
           sol = @btime solve($prob, Tsit5(), saveat=1)
       end;
  387.575 ms (42130 allocations: 66.82 MiB)
which is still outperforming the pure diff-eq version by about the same margin as before:
julia> let tspan=(0.0, 20.0), N=1000
           p=izh_parameters_v2(N; adj=(ones(N, N)))
           threshold = DiscreteCallback(spike_discrete_izh, reset_izh_cc!)
           u₀ = zeros(3*N, 1)
           prob = ODEProblem(izh_cc_network!, u₀, tspan, p, callback=threshold)
           sol = @btime solve($prob, Tsit5(), saveat=1)
       end;
  435.293 ms (757448 allocations: 1.81 GiB)
@agchesebro should this version of the Izh neuron be added to Neuroblox itself?
Yes, definitely! The other implementation is the default only to match the other kinds of neurons. Since this one is so much more performant, we might want to restructure the IF/LIF/QIF internals along the same lines. We can discuss that on Wednesday though. Do you want to make this block or should I?
Yeah, discussing on Wednesday would be good. I've got a few things I want to get to first, but if you want to open a PR with the MTK code, I could write up the GraphDynamics equivalent.
@MasonProtter here's the MWE we were talking about earlier (sorry for the length of the DiffEq section). A couple of notes:
IzhikevichNeuronCC() in this branch is the exact same neuron as the DiffEq code. I've also updated GraphDynamicsInterop to support it in that branch. I wasn't sure whether my implementation was what was slowing things down, which is why I'd switched back to the original one for these benchmarks, but overall I've seen similar times for either Neuroblox version.
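For anyone who wants to poke at the neuron model from this thread without installing Neuroblox or GraphDynamics, here is a minimal plain-Julia sketch of a single, uncoupled Izh2 neuron. It reuses the right-hand side, threshold condition, and reset rule from the implementation in the comments, but everything else is an assumption made for illustration: the names `izh_rhs`, `simulate`, and `params` are hypothetical, η is fixed at ηₘ = 0.12 instead of being drawn from a Cauchy distribution, and a fixed-step forward-Euler loop stands in for `Tsit5()` with callbacks. It is not meant to reproduce the benchmark numbers.

```julia
# Parameters copied from the Izh2 defaults in the thread.
# Assumption: η is pinned to ηₘ = 0.12 so the sketch is deterministic.
params = (vₚ=200.0, vᵣ=-200.0, η=0.12, α=0.6215, a=0.0077, b=-0.0062,
          wⱼ=0.0189, eᵣ=1.0, gₛ=1.2308, τₛ=2.6, Iₑ=0.0)

# Same equations as subsystem_differential, on a plain NamedTuple state.
function izh_rhs(u, p)
    (; V, w, z) = u
    dV = V * (V - p.α) - w + p.η + p.Iₑ + p.gₛ * z * (p.eᵣ - V)
    dw = p.a * (p.b * V - w)
    dz = -z / p.τₛ
    return (V=dV, w=dw, z=dz)
end

# Hypothetical helper: forward Euler with a spike check after every step.
function simulate(p; dt=1e-4, T=20.0)
    u = (V=0.0, w=0.0, z=0.0)
    nspikes = 0
    for _ in 1:round(Int, T / dt)
        du = izh_rhs(u, p)
        u = (V=u.V + dt * du.V, w=u.w + dt * du.w, z=u.z + dt * du.z)
        if u.V >= p.vₚ                        # same test as discrete_event_condition
            u = (V=p.vᵣ, w=u.w + p.wⱼ, z=u.z) # post-firing state of the source neuron
            nspikes += 1
        end
    end
    return u, nspikes
end

u, nspikes = simulate(params)
println("final state: ", u, ", spikes: ", nspikes)
```

With no incoming connections, `z` stays at zero, so this only exercises the `V` and `w` dynamics. In the networked version each spike would also add `conn.weight * sⱼ` to the `z` state of every downstream neuron.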