From 1c4d99500c297f97b0477cdfea8a9501f805d81d Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Tue, 29 Oct 2024 14:26:45 +0200 Subject: [PATCH 1/3] Add tests for AdjacencyMatrix (#477) * add tests for AdjacencyMatrix * fix import and typo * fix names test --- test/graphs.jl | 61 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/test/graphs.jl b/test/graphs.jl index 304b8ab1..eec7f0cb 100644 --- a/test/graphs.jl +++ b/test/graphs.jl @@ -1,8 +1,65 @@ using Neuroblox -using Test -using SparseArrays +using Neuroblox: get_adjacency using Graphs using MetaGraphs +using Test +using SparseArrays + +@testset "AdjacencyMatrix [HH Neurons]" begin + @named n1 = HHNeuronExciBlox() + @named n2 = HHNeuronExciBlox() + @named n3 = HHNeuronInhibBlox() + + g = MetaDiGraph() + add_edge!(g, n1 => n2 , weight = 1) + add_edge!(g, n1 => n3 , weight = 1) + add_edge!(g, n3 => n2 , weight = 1) + add_edge!(g, n2 => n2 , weight = 1) + + adj = get_adjacency(g) + + A = [0 1 1 ; 0 1 0; 0 1 0] + + @test all(A .== adj.matrix) + @test all([:n1, :n2, :n3] .== adj.names) +end + +@testset "AdjacencyMatrix [CorticalBlox]" begin + global_ns = :g + + A = Matrix{Matrix{Bool}}(undef, 2, 2) + A[2,1] = [0 1 ; 1 1] + A[1,2] = [0 1 ; 1 1] + + @named cb1 = CorticalBlox(namespace = global_ns, N_wta=2, N_exci=2, connection_matrices=A, weight=1) + + adj = get_adjacency(cb1) + + adj_wta_11 = [0 1 1; 1 0 0; 1 0 0] + adj_wta_12 = [[0 0 0]; hcat([0, 0], A[1,2])] + adj_wta_21 = [[0 0 0]; hcat([0, 0], A[2,1])] + + A_wta = [adj_wta_11 adj_wta_12 ; adj_wta_21 adj_wta_11] + + A = [ + hcat(A_wta, [0, 0, 0, 0, 0, 0]); + [0 1 1 0 1 1 0] + ] + + @test all(A .== adj.matrix) + + nms = [ + :cb1₊wta1₊inh, + :cb1₊wta1₊exci1, + :cb1₊wta1₊exci2, + :cb1₊wta2₊inh, + :cb1₊wta2₊exci1, + :cb1₊wta2₊exci2, + :cb1₊ff_inh + ] + + @test all(nms .== adj.names) +end @testset "Graph to adjacency matrix" begin # testing whether creating a simple graph results in the correct adjacency matrix From 1eba98e2fcb474b010122f8f43f53006027b274c Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Wed, 30 Oct 2024 20:07:00 +0200 Subject: [PATCH 2/3] More LIF fixes for the decision making tutorial (#472) * more terms from connections to neurons * accumulate both states and parameters (values) for spike affects * match state with the correct parameter value in spike affect * allow for duplicate parameters to be passed in functional affect using Pairs * update comment & fix typo * rename variable for clarity * synchronize GraphDynamicsInterop with changes to the LIFExci / LIFInh neurons --------- Co-authored-by: Mason Protter --- Project.toml | 2 +- .../GraphDynamicsInterop.jl | 3 +- .../connection_interop.jl | 55 ++++++++++++------- src/Neurographs.jl | 32 +++++++++-- src/blox/blox_utilities.jl | 6 +- src/blox/connections.jl | 40 ++++++-------- src/blox/neuron_models.jl | 8 ++- 7 files changed, 91 insertions(+), 55 deletions(-) diff --git a/Project.toml b/Project.toml index f0159e2c..1658f078 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ DataFrames = "1.3" Distributions = "0.25.102" ExponentialUtilities = "1" ForwardDiff = "0.10" -GraphDynamics = "0.1.4" +GraphDynamics = "0.1.5" Graphs = "1" Interpolations = "0.14, 0.15" MetaGraphs = "0.7" diff --git a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl index 473c518a..66cbcd53 100644 --- a/src/GraphDynamicsInterop/GraphDynamicsInterop.jl +++ b/src/GraphDynamicsInterop/GraphDynamicsInterop.jl @@ -72,7 +72,8 @@ using GraphDynamics: StateIndex, ParamIndex, event_times, - calculate_inputs + calculate_inputs, + connection_index using Random: Random, diff --git a/src/GraphDynamicsInterop/connection_interop.jl b/src/GraphDynamicsInterop/connection_interop.jl index 460b482c..9b8e2af2 100644 --- a/src/GraphDynamicsInterop/connection_interop.jl +++ b/src/GraphDynamicsInterop/connection_interop.jl @@ -260,17 +260,12 @@ end function (c::BasicConnection)(sys_src::Subsystem{LIFExciNeuron}, sys_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}}) - w = c.weight - (; S_AMPA, g_AMPA, V, V_E, g_NMDA, Mg) = sys_dst - (; S_NMDA) = sys_src - (; jcn = w * (S_AMPA * g_AMPA * (V - V_E) + S_NMDA * g_NMDA * (V - V_E) / (1 + Mg * exp(-0.062 * V) / 3.57))) + (; jcn = 0.0) end -function (c::BasicConnection)(sys_src::Subsystem{LIFInhNeuron}, - sys_dst::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}}) - w = c.weight - (; S_GABA, g_GABA, V, V_I) = sys_dst - (;jcn = w * S_GABA * g_GABA * (V - V_I)) +function (c::BasicConnection)(::Subsystem{LIFInhNeuron}, + ::Union{Subsystem{LIFExciNeuron}, Subsystem{LIFInhNeuron}}) + (; jcn = 0.0) end struct SpikeAffectEventBuilder @@ -284,6 +279,7 @@ struct SpikeAffectEvent{i_src, i_LIFInh, i_LIFExci} j_dsts_inh::Vector{Int} j_dsts_exci::Vector{Int} end + function (ev::SpikeAffectEventBuilder)(index_map) (i_src, j_src) = index_map[ev.idx_src] i_inh, j_dsts_inh = let v = ev.idx_dsts_inh @@ -315,14 +311,18 @@ end + function GraphDynamics.apply_discrete_event!(integrator, states::NTuple{Len, Any}, params::NTuple{Len, Any}, - _, + connection_matrices, t, ev::SpikeAffectEvent{i_src, i_dst_inh, i_dst_exci} ) where {i_src, i_dst_inh, i_dst_exci, Len} (; j_src, j_dsts_inh, j_dsts_exci) = ev + + nc = connection_index(BasicConnection, connection_matrices) + params_src = params[i_src][j_src] @reset params_src.t_refract_end = t + params_src.t_refract_duration @reset params_src.is_refractory = 1 @@ -334,22 +334,39 @@ function GraphDynamics.apply_discrete_event!(integrator, states[i_src][:V, j_src] = params_src.V_reset if (states_src isa SubsystemStates{LIFExciNeuron}) && (j_src ∈ j_dsts_exci) # x is the rise variable for NMDA synapses and it only applies to self-recurrent connections - states[i_src][:x, j_src] += 1 + w = connection_matrices[nc][i_src, i_src][j_src, j_src].weight + states[i_src][:x, j_src] += w end if states_src isa SubsystemStates{LIFExciNeuron} - !isnothing(i_dst_inh) && for j_dst ∈ j_dsts_inh - states[i_dst_inh][:S_AMPA, j_dst] += 1 + if !isnothing(i_dst_inh) + M = connection_matrices[nc][i_src, i_dst_inh] + for j_dst ∈ j_dsts_inh + w = M[j_src, j_dst].weight + states[i_dst_inh][:S_AMPA, j_dst] += w + end end - !isnothing(i_dst_exci) && for j_dst ∈ j_dsts_exci - states[i_dst_exci][:S_AMPA, j_dst] += 1 + if !isnothing(i_dst_exci) + M = connection_matrices[nc][i_src, i_dst_exci] + for j_dst ∈ j_dsts_exci + w = M[j_src, j_dst].weight + states[i_dst_exci][:S_AMPA, j_dst] += w + end end elseif states_src isa SubsystemStates{LIFInhNeuron} - !isnothing(i_dst_inh) && for j_dst ∈ j_dsts_inh - states[i_dst_inh][:S_GABA, j_dst] += 1 + if !isnothing(i_dst_inh) + M = connection_matrices[nc][i_src, i_dst_inh] + for j_dst ∈ j_dsts_inh + w = M[j_src, j_dst].weight + states[i_dst_inh][:S_GABA, j_dst] += w + end end - !isnothing(i_dst_exci) && for j_dst ∈ j_dsts_exci - states[i_dst_exci][:S_GABA, j_dst] += 1 + if !isnothing(i_dst_exci) + M = connection_matrices[nc][i_src, i_dst_exci] + for j_dst ∈ j_dsts_exci + w = M[j_src, j_dst].weight + states[i_dst_exci][:S_GABA, j_dst] += w + end end else error("this should be unreachable") diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 56723128..96a1b782 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -76,17 +76,39 @@ end generate_discrete_callbacks(blox, ::BloxConnector; t_block = missing) = [] function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::BloxConnector; t_block = missing) - spike_affect_states = get_spike_affect_states(bc) + spike_affects = get_spike_affects(bc) name_blox = namespaced_nameof(blox) + sys = get_namespaced_sys(blox) - states_dest = get(spike_affect_states, name_blox, Num[]) + states_affect, params_affect = get(spike_affects, name_blox, (Num[], Num[])) - sys = get_namespaced_sys(blox) + # HACK : MTK will complain if the parameter vector passed to a functional affect + # contains non-unique parameters. Here we sometimes need to pass duplicate parameters that + # affect states in the loop in LIF_spike_affect! . + # Passing parameters with Symbol aliases bypasses this issue and allows for duplicates. + affect_pairs = if unique(params_affect) == length(params_affect) + [p => Symbol(p) for p in params_affect] + else + map(params_affect) do p + if count(pi -> Symbol(pi) == Symbol(p), params_affect) > 1 + p => Symbol(p, "_$(rand(1:1000))") + else + p => Symbol(p) + end + end + end + + ps = vcat([ + sys.V_reset => Symbol(sys.V_reset), + sys.t_refract_duration => Symbol(sys.t_refract_duration), + sys.t_refract_end => Symbol(sys.t_refract_end), + sys.is_refractory => Symbol(sys.is_refractory) + ], affect_pairs) cb = (sys.V > sys.θ) => ( LIF_spike_affect!, - vcat(sys.V, states_dest), - [sys.V_reset, sys.t_refract_duration, sys.t_refract_end, sys.is_refractory], + vcat(sys.V, states_affect), + ps, [], nothing ) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index c385d347..97abd899 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -179,9 +179,9 @@ get_discrete_callbacks(bc::BloxConnector) = bc.discrete_callbacks get_discrete_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_discrete_callbacks ∘ get_connector)(blox) get_discrete_callbacks(blox) = [] -get_spike_affect_states(bc::BloxConnector) = bc.spike_affect_states -get_spike_affect_states(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affect_states ∘ get_connector)(blox) -get_spike_affect_states(blox) = Dict{Symbol, Vector{Num}}() +get_spike_affects(bc::BloxConnector) = bc.spike_affects +get_spike_affects(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affects ∘ get_connector)(blox) +get_spike_affects(blox) = Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}}() get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules ∘ get_connector)(blox) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 71f80f07..059dbb26 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -3,7 +3,7 @@ mutable struct BloxConnector weights::Vector{Num} delays::Vector{Num} discrete_callbacks - spike_affect_states::Dict{Symbol, Vector{Num}} + spike_affects::Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}} learning_rules adjacency @@ -14,16 +14,17 @@ mutable struct BloxConnector weights = mapreduce(get_weight_parameters, vcat, bloxs) delays = mapreduce(get_delay_parameters, vcat, bloxs) discrete_callbacks = mapreduce(get_discrete_callbacks, vcat, bloxs) - # spike_affect_states holds a Dictionary that maps - # the name of a source Blox to the states of a destination Blox - # that are affected by a continuous callback of the source Blox. + # spike_affects holds a Dictionary that maps + # the name of a source Blox to a Tuple of (states, parameters) of a destination Blox. + # The states are affected by a discrete callback of the source Blox + # and the parameters determine the amount of this affect like `states .+= parameters`. # Typically this is used when a source Blox spikes, so its Voltage state crosses a threshold, # and this spike affects synaptic parameters of every destination Blox that it connects to. - spike_affect_states = mapreduce(get_spike_affect_states, merge, bloxs) + spike_affects = mapreduce(get_spike_affects, merge, bloxs) learning_rules = mapreduce(get_weight_learning_rules, merge, bloxs) adjacency = mapreduce(get_adjacency, merge, bloxs) - new(eqs, weights, delays, discrete_callbacks, spike_affect_states, learning_rules, adjacency) + new(eqs, weights, delays, discrete_callbacks, spike_affects, learning_rules, adjacency) end end @@ -33,11 +34,13 @@ function accumulate_equation!(bc::BloxConnector, eq) bc.eqs[idx] = bc.eqs[idx].lhs ~ bc.eqs[idx].rhs + eq.rhs end -function accumulate_spike_affect_states!(bc::BloxConnector, name_blox_src, states_dst) - if haskey(bc.spike_affect_states, name_blox_src) - append!(bc.spike_affect_states[name_blox_src], states_dst) +function accumulate_spike_affects!(bc::BloxConnector, name_blox_src, states_affect, params_affect) + if haskey(bc.spike_affects, name_blox_src) + spike_affects = bc.spike_affects[name_blox_src] + append!(spike_affects[1], states_affect) + append!(spike_affects[2], params_affect) else - bc.spike_affect_states[name_blox_src] = states_dst + bc.spike_affects[name_blox_src] = (states_affect, params_affect) end end @@ -891,18 +894,16 @@ function (bc::BloxConnector)( w = generate_weight_param(bloxout, bloxin; kwargs...) push!(bc.weights, w) - eq = sys_in.jcn ~ w * sys_in.S_AMPA * sys_in.g_AMPA * (sys_in.V - sys_in.V_E) + - w * sys_out.S_NMDA * sys_in.g_NMDA * (sys_in.V - sys_in.V_E) / + eq = sys_in.jcn ~ w * sys_out.S_NMDA * sys_in.g_NMDA * (sys_in.V - sys_in.V_E) / (1 + sys_in.Mg * exp(-0.062 * sys_in.V) / 3.57) - accumulate_equation!(bc, eq) # Compare the unique namespaced names of both systems if nameof(sys_out) == nameof(sys_in) # x is the rise variable for NMDA synapses and it only applies to self-recurrent connections - accumulate_spike_affect_states!(bc, nameof(sys_out), [sys_in.S_AMPA, sys_in.x]) + accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_AMPA, sys_in.x], [w, w]) else - accumulate_spike_affect_states!(bc, nameof(sys_out), [sys_in.S_AMPA]) + accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_AMPA], [w]) end end @@ -917,11 +918,7 @@ function (bc::BloxConnector)( w = generate_weight_param(bloxout, bloxin; kwargs...) push!(bc.weights, w) - eq = sys_in.jcn ~ w * sys_in.S_GABA * sys_in.g_GABA * (sys_in.V - sys_in.V_I) - - accumulate_equation!(bc, eq) - - accumulate_spike_affect_states!(bc, nameof(sys_out), [sys_in.S_GABA]) + accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_GABA], [w]) end function (bc::BloxConnector)( @@ -931,9 +928,6 @@ function (bc::BloxConnector)( ) sys_in = get_namespaced_sys(neuron) - w = generate_weight_param(stim, neuron; kwargs...) - push!(bc.weights, w) - t_spikes = generate_spike_times(stim) cb = t_spikes => [sys_in.S_AMPA_ext ~ sys_in.S_AMPA_ext + 1] diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index b4d67ef3..b1162246 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -638,8 +638,10 @@ function LIF_spike_affect!(integ, u, p, ctx) SciMLBase.add_tstop!(integ, t_refract_end) + c = 1 for i in eachindex(u)[2:end] - integ.u[u[i]] += 1 + integ.u[u[i]] += integ.p[p[c + 4]] + c += 1 end end @@ -693,7 +695,7 @@ struct LIFInhNeuron <: AbstractInhNeuronBlox sts = @variables V(t)=-52 S_AMPA(t)=0 S_GABA(t)=0 S_AMPA_ext(t)=0 jcn(t) [input=true] jcn_external(t) [input=true] eqs = [ - D(V) ~ (1 - is_refractory) * (- g_L * (V - V_L) - S_AMPA_ext * g_AMPA_ext * (V - V_E) - jcn) / C, + D(V) ~ (1 - is_refractory) * (- g_L * (V - V_L) - S_AMPA_ext * g_AMPA_ext * (V - V_E) - S_GABA * g_GABA * (V - V_I) - S_AMPA * g_AMPA * (V - V_E) - jcn) / C, D(S_AMPA) ~ - S_AMPA / τ_AMPA, D(S_GABA) ~ - S_GABA / τ_GABA, D(S_AMPA_ext) ~ - S_AMPA_ext / τ_AMPA @@ -761,7 +763,7 @@ struct LIFExciNeuron <: AbstractExciNeuronBlox sts = @variables V(t)=-52 S_AMPA(t)=0 S_GABA(t)=0 S_NMDA(t)=0 x(t)=0 S_AMPA_ext(t)=0 jcn(t) [input=true] eqs = [ - D(V) ~ (1 - is_refractory) * (- g_L * (V - V_L) - S_AMPA_ext * g_AMPA_ext * (V - V_E) - jcn) / C, + D(V) ~ (1 - is_refractory) * (- g_L * (V - V_L) - S_AMPA_ext * g_AMPA_ext * (V - V_E) - S_GABA * g_GABA * (V - V_I) - S_AMPA * g_AMPA * (V - V_E) - jcn) / C, D(S_AMPA) ~ - S_AMPA / τ_AMPA, D(S_GABA) ~ - S_GABA / τ_GABA, D(S_NMDA) ~ - S_NMDA / τ_NMDA_decay + α * x * (1 - S_NMDA), From de84dac625e63a3675beb0c11676c43c9cf28376 Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Thu, 31 Oct 2024 14:51:19 +0200 Subject: [PATCH 3/3] normalize colormap range to data range (#479) --- ext/MakieExtension.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/ext/MakieExtension.jl b/ext/MakieExtension.jl index 5cc537e9..bf81ef72 100644 --- a/ext/MakieExtension.jl +++ b/ext/MakieExtension.jl @@ -42,13 +42,7 @@ function Makie.plot!(p::Adjacency) X, Y, D = findnz(adj.matrix) - heatmap!(p, X, Y, D; colormap = p.colormap[]) - - idxs = Tuple.(findall(iszero, adj.matrix)) - x_zero = first.(idxs) - y_zero = last.(idxs) - - #heatmap!(p, x_zero, y_zero, fill(0, length(x_zero)); color=:black) + heatmap!(p, X, Y, D; colormap = p.colormap[], colorrange = (minimum(D), maximum(D))) return p end