
Speed of GraphDynamics compared to DiffEq #471

Open
agchesebro opened this issue Oct 23, 2024 · 4 comments

@agchesebro
Member

@MasonProtter, here's the MWE we talked about earlier (sorry for the length of the DiffEq section). A couple of notes:

  • GraphDynamics seems to scale worse with simulation time than DiffEq does. For example, on my computer, doubling the simulated time from 20 ms to 40 ms makes GraphDynamics about 2.8× slower but DiffEq only about 1.2× slower:
    • 20 ms simulation: GraphDynamics = ~134 s; DiffEq = ~9 s
    • 40 ms simulation: GraphDynamics = ~370 s; DiffEq = ~11 s
  • The neurons handle synapses slightly differently (Neuroblox neurons have one extra state), so this isn't an exact comparison. If you'd like an exact comparison, IzhikevichNeuronCC() in this branch is the same neuron as in the DiffEq code, and I've also updated GraphDynamicsInterop to support it in that branch. I wasn't sure whether my implementation was what was slowing things down, which is why I switched back to the original IzhikevichNeuron for these benchmarks, but overall I've seen similar times with either Neuroblox version.
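For reference, these are the dynamics that `izh_cc_network!` and `reset_izh_cc!` in the DiffEq section implement, transcribed from that code. $A$ stands for `adj`, with $A_{ki}$ the connection from neuron $i$ to neuron $k$, and $w_j$, $s_j$ denote the scalar parameters `wⱼ` and `sⱼ` (not indexed quantities). With the all-ones `adj` used in the benchmarks, the synaptic jump per spike is $s_j / N$ for every target.

```latex
\begin{aligned}
\dot V_i &= V_i\,(V_i - \alpha) - w_i + \eta_i + I_e + g_s\, z_i\,(e_r - V_i),
  \qquad \eta_i \sim \mathrm{Cauchy}(\eta_m, \eta_s) \\
\dot w_i &= a\,(b\,V_i - w_i) \\
\dot z_i &= -z_i / \tau_s \\
V_i \ge v_p &\;\Longrightarrow\; V_i \leftarrow v_r,\quad w_i \leftarrow w_i + w_j,\quad
  z_k \leftarrow z_k + \frac{s_j}{N}\, A_{ki} \ \text{ for every } k
\end{aligned}
```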
using Neuroblox, Random, Distributions, OrdinaryDiffEq

# Common parameters for both simulations
# --------------------------------------
N = 1000
tspan = (0.0, 20.0)

# GraphDynamics/Neuroblox Version 
# ----------------------

η_dist = Cauchy(0.12, 0.02)
neurons = map(1:N) do i 
    IzhikevichNeuron(name=Symbol(:n,i), η=rand(η_dist))
end

g = MetaDiGraph()
add_blox!.((g,), neurons)
for i ∈ 1:N
    for j ∈ 1:N
        add_edge!(g, i, j, Dict(:weight => 2*randn(), :connection_rule => "basic"))
    end
end
@named sys = system_from_graph(g; graphdynamics=true)
prob = ODEProblem(sys, [], tspan)
println("GraphDynamics ----------------")
print("time-to-first-solve:  "); sol = @time solve(prob, Tsit5())

# DiffEq Version
# ----------------------

function izh_parameters_v2(N;
                           vₚ=200,
                           vᵣ=-200,
                           ηₘ=0.12,
                           ηₛ=0.02,
                           α=0.6215,
                           a=0.0077,
                           b=-0.0062,
                           wⱼ=0.0189,
                           sⱼ=1.2308,
                           eᵣ=1.0,
                           gₛ=1.2308,
                           τₛ=2.6,
                           Iₑ=0.0, 
                           τᵣ=nothing,
                           adj=nothing)
    lst = zeros(N, 1)
    τᵣ = isnothing(τᵣ) ? 2.0/vₚ : τᵣ
    η = rand(Cauchy(ηₘ, ηₛ), N, 1)
    adj = isnothing(adj) ? ones(N, N) : adj
    return (vₚ=vₚ, vᵣ=vᵣ, lst=lst, τᵣ=τᵣ, η=η, α=α, a=a, b=b, wⱼ=wⱼ, sⱼ=sⱼ, eᵣ=eᵣ, gₛ=gₛ, τₛ=τₛ, Iₑ=Iₑ, adj=adj)
end

function spike_discrete_izh(u, t, integrator)
    return sum(integrator.u[1:3:end] .>= integrator.p.vₚ) > 0
end

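# The state vector is interleaved: u = [V₁, w₁, z₁, V₂, w₂, z₂, ...], so u[1:3:end] are the
# membrane potentials, u[2:3:end] the recovery variables and u[3:3:end] the synaptic variables.
function izh_cc_network!(du, u, p, t)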
    du[1:3:end] .= u[1:3:end] .* (u[1:3:end] .- p.α) .- u[2:3:end] .+ p.η .+ p.Iₑ .+ p.gₛ .* u[3:3:end] .* (p.eᵣ .- u[1:3:end])
    du[2:3:end] .= p.a .* (p.b .* u[1:3:end] .- u[2:3:end])
    du[3:3:end] .= -u[3:3:end] ./ p.τₛ
end

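# affect! for the DiscreteCallback: reset every neuron at or above threshold and
# add its spike (scaled by sⱼ/N) to the synaptic variable of every neuron it targets.
function reset_izh_cc!(integrator)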
    idx = findall(x -> x >= integrator.p.vₚ, integrator.u[1:3:end])
    integrator.u[3 .*idx .- 2] .= integrator.p.vᵣ
    integrator.u[3 .*idx .- 1] .+= integrator.p.wⱼ

    N = Int(length(integrator.u)/3)
    spikes = zeros(N)
    spikes[idx] .= 1
    integrator.u[3:3:end] .+= integrator.p.sⱼ*integrator.p.adj*spikes/N
end

function run_izh_net_v2(tspan, N; p=izh_parameters_v2(N))
    threshold = DiscreteCallback(spike_discrete_izh, reset_izh_cc!)

    u₀ = zeros(3*N, 1)
    prob = ODEProblem(izh_cc_network!, u₀, tspan, p, callback=threshold)
    return solve(prob, Tsit5(), saveat=1)
end

println("DiffEq ----------------")
print("time-to-first-solve:  "); sol = @time run_izh_net_v2(tspan, N)
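The timings in this first comment come from single `@time` calls ("time-to-first-solve"), so each one includes JIT compilation as well as the solve itself, whereas the later `@btime` numbers in this thread exclude compilation. How much of each number is compilation was not measured here; since compilation should not depend on `tspan`, the growth from ~134 s to ~370 s presumably reflects solve time. A minimal plain-Julia sketch of the usual warm-up pattern, with a hypothetical `toy_workload` standing in for the `solve` call:

```julia
# Hypothetical stand-in workload; in the MWE the timed call is a full `solve`.
function toy_workload(n)
    s = 0.0
    for k in 1:n
        s += sin(k)^2
    end
    return s
end

# First call: typically includes compiling toy_workload for an Int argument.
t_first = @elapsed r1 = toy_workload(10^6)

# Second call: reuses the already-compiled method, so it measures run time only.
t_second = @elapsed r2 = toy_workload(10^6)

println("first call:  $(round(t_first; digits=4)) s (compile + run)")
println("second call: $(round(t_second; digits=4)) s (run only)")
```

Applied to the MWE, this would mean timing `solve(prob, Tsit5())` a second time (or using BenchmarkTools' `@btime`) before comparing the two packages.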
@MasonProtter

This comment was marked as outdated.

@MasonProtter
Contributor

Here's a new implementation based on the same technique as #484; it requires the new GraphDynamics v0.2.0:

using GraphDynamics, Neuroblox, OrdinaryDiffEq, Distributions
using BenchmarkTools  # provides @btime, used in the benchmarks below
using Neuroblox: AbstractNeuronBlox, GraphDynamicsInterop
using Neuroblox.GraphDynamicsInterop: BasicConnection

# I'm making this an AbstractNeuronBlox subtype even though I didn't bother doing
# any of the MTK-side setup we'd normally do for an AbstractNeuronBlox.
struct Izh2 <: AbstractNeuronBlox
    vₚ::Float64
    vᵣ::Float64
    τᵣ::Float64
    η::Float64
    α::Float64
    a::Float64
    b::Float64
    wⱼ::Float64
    sⱼ::Float64
    eᵣ::Float64
    gₛ::Float64
    τₛ::Float64
    Iₑ::Float64
    namespace::Symbol
    name::Symbol
    function Izh2(;name,
                  vₚ=200,
                  vᵣ=-200,
                  ηₘ=0.12,
                  ηₛ=0.02,
                  α=0.6215,
                  a=0.0077,
                  b=-0.0062,
                  wⱼ=0.0189,
                  sⱼ=1.2308,
                  eᵣ=1.0,
                  gₛ=1.2308,
                  τₛ=2.6,
                  Iₑ=0.0, 
                  τᵣ=nothing,
                  namespace=Symbol(""))
        τᵣ = isnothing(τᵣ) ? 2.0/vₚ : τᵣ
        η = rand(Cauchy(ηₘ, ηₛ))
        new(vₚ, vᵣ, τᵣ, η, α, a, b, wⱼ, sⱼ, eᵣ, gₛ, τₛ, Iₑ, namespace, name)
    end
end
Neuroblox.nameof((;name)::Izh2) = name

GraphDynamicsInterop.issupported(::Izh2) = true

# Teach GraphDynamics how to create a GraphDynamics.Subsystem out of an Izh2 struct
function GraphDynamicsInterop.to_subsystem((;vₚ, vᵣ, τᵣ, η, α, a, b, wⱼ, sⱼ, eᵣ, gₛ, τₛ, Iₑ)::Izh2)
    states = SubsystemStates{Izh2}((; V=0.0, w=0.0, z=0.0))
    params = SubsystemParams{Izh2}((;vₚ, vᵣ, τᵣ, η, α, a, b, wⱼ, sⱼ, eᵣ, gₛ, τₛ, Iₑ))
    Subsystem(states, params)
end

GraphDynamics.initialize_input(::Subsystem{Izh2}) = 0.0

function GraphDynamics.subsystem_differential(s::Subsystem{Izh2}, _, t)
    (;V, w, z) = s
    (;vₚ, vᵣ, τᵣ, η, α, a, b, wⱼ, sⱼ, eᵣ, gₛ, τₛ, Iₑ) = s
    dV = V * (V - α) - w + η + Iₑ + gₛ * z * (eᵣ - V)
    dw = a * (b * V - w)
    dz = -z/τₛ
    SubsystemStates{Izh2}((;V=dV, w=dw, z=dz))
end

# This lets GraphDynamics do some optimizations by promising that the differential never depends on the connections
GraphDynamics.subsystem_differential_requires_inputs(::Type{<:Izh2}) = false

# Connections between Izh2 neurons shouldn't do anything (during solving). They only exist for events.
function (c::BasicConnection)(::Subsystem{Izh2}, ::Subsystem{Izh2})
    0.0
end

# Tell GraphDynamics that Izh2 has events
GraphDynamics.has_discrete_events(::Type{Izh2}) = true

# The events trigger when V exceeds some threshold
function GraphDynamics.discrete_event_condition(n::Subsystem{Izh2}, t, _)
    n.V >= n.vₚ
end

# When the event triggers, we want to update the neuron which fired *and* every destination neuron it has a
# connection to.
function GraphDynamics.apply_discrete_event!(integrator, sview, _, neuron_src::Subsystem{Izh2}, foreach_connected_neuron::F) where {F}
    # Set the neuron to the post-firing state
    sview[:V]  = neuron_src.vᵣ
    sview[:w] += neuron_src.wⱼ

    # For each neuron which has a connection leading to it from neuron_src
    # we increment its `z` state
    foreach_connected_neuron() do conn, neuron_dst, states_view_dst, params_view_dst
        states_view_dst[:z] += conn.weight * neuron_src.sⱼ
    end
end
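When every edge has weight `1/N`, the per-source loop in `apply_discrete_event!` should produce the same `z` increments as the dense update `sⱼ*adj*spikes/N` in `reset_izh_cc!`. A small plain-Julia check of that equivalence, with no GraphDynamics involved; the adjacency-list representation and the helper names `dense_update!` and `per_source_update!` are illustrative assumptions, not GraphDynamics internals:

```julia
using Random

# Dense update, as in reset_izh_cc!: z .+= sⱼ * adj * spikes / N,
# where adj[k, i] is the connection strength from neuron i to neuron k.
function dense_update!(z, adj, spikes, sⱼ)
    N = length(z)
    z .+= sⱼ .* (adj * spikes) ./ N
    return z
end

# Per-source update in the spirit of apply_discrete_event!: every neuron that
# fired adds weight * sⱼ to the z state of each neuron it connects to.
function per_source_update!(z, out_edges, spikes, sⱼ)
    for src in eachindex(spikes)
        iszero(spikes[src]) && continue      # neurons that did not fire do no work
        for (dst, weight) in out_edges[src]
            z[dst] += weight * sⱼ
        end
    end
    return z
end

Random.seed!(42)
n, sj = 50, 1.2308
adj = rand(n, n)                             # arbitrary nonnegative strengths
# Edge src -> dst carries weight adj[dst, src] / n; an all-ones adj gives 1/n, as in the benchmark.
out_edges = [[(dst, adj[dst, src] / n) for dst in 1:n] for src in 1:n]
spikes = zeros(n)
spikes[[3, 17, 42]] .= 1                     # three neurons fire in the same step
z0 = rand(n)

z_dense = dense_update!(copy(z0), adj, spikes, sj)
z_event = per_source_update!(copy(z0), out_edges, spikes, sj)
println(maximum(abs.(z_dense .- z_event)))   # differences are floating-point round-off only
```

Cost-wise, `adj * spikes` is a full N×N matrix-vector product on every event regardless of how many neurons fired, while the per-source loop does N updates per neuron that fired. Whether that difference accounts for the timings in this thread was not measured.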
julia> let tspan=(0.0, 20.0), N = 1000
           neurons = map(1:N) do i
               Izh2(;name = Symbol(:n, i))
           end
           
           g = MetaDiGraph()
           add_blox!.((g,), neurons)
           for i ∈ 1:N, j ∈ 1:N
               add_edge!(g, i, j, Dict(:weight => 1/N))
           end
           @named sys = system_from_graph(g; graphdynamics=true)
           prob = ODEProblem(sys, [], tspan, [])
           sol = @btime solve($prob, Tsit5(), saveat=1)
       end;
  387.575 ms (42130 allocations: 66.82 MiB)

which is still faster than the pure DiffEq version, by about the same margin as before:

julia> let tspan=(0.0, 20.0), N=1000
           p=izh_parameters_v2(N; adj=(ones(N, N)))
           threshold = DiscreteCallback(spike_discrete_izh, reset_izh_cc!)

           u₀ = zeros(3*N, 1)
           prob = ODEProblem(izh_cc_network!, u₀, tspan, p, callback=threshold)
           sol = @btime solve($prob, Tsit5(), saveat=1)
       end;
  435.293 ms (757448 allocations: 1.81 GiB)
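The DiffEq run allocates 1.81 GiB versus 66.82 MiB for GraphDynamics. A likely contributor (not profiled here) is the indexing in `izh_cc_network!`: in Julia, `u[1:3:end]` on the right-hand side creates a new array every time it appears, and the right-hand side is evaluated several times per Tsit5 step. A sketch of a view-based right-hand side that computes the same derivative without those slice copies; `izh_rhs_views!` is a hypothetical name, and `p` is assumed to be a NamedTuple with the same fields as `izh_parameters_v2` returns:

```julia
# Allocating form, copied from izh_cc_network! above: every u[k:3:end] is a new array.
function izh_rhs_alloc!(du, u, p, t)
    du[1:3:end] .= u[1:3:end] .* (u[1:3:end] .- p.α) .- u[2:3:end] .+ p.η .+ p.Iₑ .+ p.gₛ .* u[3:3:end] .* (p.eᵣ .- u[1:3:end])
    du[2:3:end] .= p.a .* (p.b .* u[1:3:end] .- u[2:3:end])
    du[3:3:end] .= -u[3:3:end] ./ p.τₛ
    return nothing
end

# Hypothetical view-based form: same arithmetic, but V, w, z and their derivatives
# are strided views into u and du, so no slice copies are made.
function izh_rhs_views!(du, u, p, t)
    V, w, z = view(u, 1:3:length(u)), view(u, 2:3:length(u)), view(u, 3:3:length(u))
    dV, dw, dz = view(du, 1:3:length(du)), view(du, 2:3:length(du)), view(du, 3:3:length(du))
    @. dV = V * (V - p.α) - w + p.η + p.Iₑ + p.gₛ * z * (p.eᵣ - V)
    @. dw = p.a * (p.b * V - w)
    @. dz = -z / p.τₛ
    return nothing
end

# Demo with the parameter values from izh_parameters_v2 (η is normal here, not Cauchy;
# that does not matter for comparing the two functions).
n = 1000
p_demo = (α=0.6215, a=0.0077, b=-0.0062, η=randn(n), Iₑ=0.0, gₛ=1.2308, eᵣ=1.0, τₛ=2.6)
u_demo = randn(3n)
du_alloc, du_views = similar(u_demo), similar(u_demo)
izh_rhs_alloc!(du_alloc, u_demo, p_demo, 0.0)   # first calls also compile both methods
izh_rhs_views!(du_views, u_demo, p_demo, 0.0)
println(du_alloc ≈ du_views)                    # same derivative
println("allocating: ", @allocated(izh_rhs_alloc!(du_alloc, u_demo, p_demo, 0.0)), " bytes")
println("views:      ", @allocated(izh_rhs_views!(du_views, u_demo, p_demo, 0.0)), " bytes")
```

If it behaves the same in a full solve, `izh_rhs_views!` could be passed to `ODEProblem` in place of `izh_cc_network!` in `run_izh_net_v2`; that has not been benchmarked here, and the callback's own slicing would still allocate.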

@agchesebro, should this version of the Izh neuron be added to Neuroblox itself?

@agchesebro
Member Author

Yes, definitely! The other implementation is the default only for consistency with the other kinds of neurons. Since this one is so much more performant, we might want to restructure the IF/LIF/QIF internals along the same lines. We can discuss that on Wednesday, though. Do you want to make this block or should I?

@MasonProtter
Contributor

Yeah, discussing it on Wednesday would be good. I've got a few things I want to get to first, but if you want to open a PR with the MTK code, I could write up the GraphDynamics equivalent.
