Skip to content

Commit

Permalink
Merge pull request #297 from Neuroblox/connecting-ascending-input
Browse files Browse the repository at this point in the history
Connecting ascending input
  • Loading branch information
anandpathak31 authored Oct 25, 2023
2 parents d554b27 + 8fe13fd commit 6ba3d3c
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 16 deletions.
5 changes: 0 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version = "0.3.0"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Catalyst = "479239e8-5488-4da2-87a7-35f2df7eef83"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
Expand All @@ -25,7 +24,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand All @@ -43,8 +41,6 @@ OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Peaks = "18e31ff7-3703-566c-8e60-38913d67486b"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PowerModelsDistribution = "d7431456-977f-11e9-2de3-97ff7677985e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -92,7 +88,6 @@ OptimizationOptimJL = "0.1"
OptimizationOptimisers = "0.1"
OrderedCollections = "1.4"
OrdinaryDiffEq = "6"
Plots = "1"
RecursiveArrayTools = "2"
Reexport = "1.0"
SafeTestsets = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ export harmonic_oscillator, jansen_ritC, jansen_ritSC, jansen_rit_spm12,
export IFNeuronBlox, LIFNeuronBlox, QIFNeuronBlox, HHNeuronExciBlox, HHNeuronInhibBlox, LinearNeuralMassBlox,
WilsonCowanBlox, HarmonicOscillatorBlox, JansenRitCBlox, JansenRitSCBlox, LarterBreakspearBlox,
CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical
export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear
export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox
export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
export HebbianPlasticity, HebbianModulationPlasticity
export Agent, ClassificationEnvironment, GreedyPolicy
Expand Down
44 changes: 44 additions & 0 deletions src/blox/connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,38 @@ function (bc::BloxConnector)(
accumulate_equation!(bc, eq)
end

function (bc::BloxConnector)(
asc_out::NextGenerationEIBlox,
HH_in::Union{HHNeuronExciBlox, HHNeuronInhibBlox};
kwargs...
)
sys_out = get_namespaced_sys(asc_out)
sys_in = get_namespaced_sys(HH_in)

w = generate_weight_param(asc_out, HH_in; kwargs...)
push!(bc.weights, w)

#Z = sys_out.Z
a = sys_out.aₑ
b = sys_out.bₑ
f = (1/(sys_out.Cₑ*π))*(1-a^2-b^2)/(1+2*a+a^2+b^2)
eq = sys_in.I_asc ~ w*f

accumulate_equation!(bc, eq)
end

function (bc::BloxConnector)(
asc_out::NextGenerationEIBlox,
cb_in::CorticalBlox;
kwargs...
)
neurons_in = get_inh_neurons(cb_in)

bc(asc_out, neurons_in[end]; kwargs...)

end


function (bc::BloxConnector)(
bloxout::NeuralMassBlox,
bloxin::NeuralMassBlox;
Expand Down Expand Up @@ -146,6 +178,18 @@ function (bc::BloxConnector)(
end
end

function (bc::BloxConnector)(
neuron_out::HHNeuronInhibBlox,
wta_in::WinnerTakeAllBlox;
kwargs...
)
neurons_in = get_exci_neurons(wta_in)

for neuron_postsyn in neurons_in
bc(neuron_out, neuron_postsyn; kwargs...)
end
end

function (bc::BloxConnector)(
cb_out::Union{CorticalBlox,STN,Thalamus},
cb_in::Union{CorticalBlox,STN,Thalamus};
Expand Down
24 changes: 18 additions & 6 deletions src/blox/cortical_blox.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct CorticalBlox{P} <: AbstractComponent
struct CorticalBlox <: AbstractComponent
namespace
parts::Vector{P}
parts
odesystem
connector
mean::Vector{Num}
Expand All @@ -13,7 +13,8 @@ struct CorticalBlox{P} <: AbstractComponent
E_syn_exci=0.0,
E_syn_inhib=-70,
G_syn_exci=3.0,
G_syn_inhib=3.0,
G_syn_inhib=4.0,
G_syn_ff_inhib=3.5,
freq=zeros(N_exci),
phase=zeros(N_exci),
τ_exci=5,
Expand All @@ -36,13 +37,24 @@ struct CorticalBlox{P} <: AbstractComponent
τ_inhib
)
end

n_ff_inh = HHNeuronInhibBlox(
name = "ff_inh",
namespace = namespaced_name(namespace, name),
E_syn = E_syn_inhib,
G_syn = G_syn_ff_inhib,
τ = τ_inhib
)



g = MetaDiGraph()
add_blox!.(Ref(g), wtas)
add_blox!.(Ref(g), vcat(wtas, n_ff_inh))

idxs = Base.OneTo(N_wta)
for i in idxs
add_edge!.(Ref(g), i, setdiff(idxs, i), Ref(Dict(kwargs)))
add_edge!(g, N_wta+1, i, Dict(:weight => 1))
end

# Construct a BloxConnector object from the graph
Expand All @@ -53,7 +65,7 @@ struct CorticalBlox{P} <: AbstractComponent
# If there is a higher namespace, construct only a subsystem containing the parts of this level
# and propagate the BloxConnector object `bc` to the higher level
# to potentially add more terms to the same connections.
sys = isnothing(namespace) ? system_from_graph(g, bc; name) : system_from_parts(wtas; name)
sys = isnothing(namespace) ? system_from_graph(g, bc; name) : system_from_parts(vcat(wtas, n_ff_inh); name)

# TO DO : m is a subset of states to be plotted in the GUI.
# This can be moved to NeurobloxGUI, maybe via plotting recipes,
Expand All @@ -68,6 +80,6 @@ struct CorticalBlox{P} <: AbstractComponent
[s for s in states.((sys_namespace,), states(sys)) if contains(string(s), "V(t)")]
end

new{eltype(wtas)}(namespace, wtas, sys, bc, m)
new(namespace, vcat(wtas, n_ff_inh), sys, bc, m)
end
end
61 changes: 58 additions & 3 deletions src/blox/neural_mass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,18 +124,73 @@ mutable struct NextGenerationBlox <: NeuralMassBlox
v_syn::Num
alpha_inv::Num
k::Num
output
connector::Num
odesystem::ODESystem
function NextGenerationBlox(;name, C=30.0, Δ=1.0, η_0=5.0, v_syn=-10.0, alpha_inv=35.0, k=0.105)
namespace
function NextGenerationBlox(;name,namespace=nothing, C=30.0, Δ=1.0, η_0=5.0, v_syn=-10.0, alpha_inv=35.0, k=0.105)
params = @parameters C=C Δ=Δ η_0=η_0 v_syn=v_syn alpha_inv=alpha_inv k=k
sts = @variables Z(t)=0.5 g(t)=1.6
sts = @variables Z(t)=0.5 [output=true] g(t)=1.6
Z = ModelingToolkit.unwrap(Z)
g = ModelingToolkit.unwrap(g)
C, Δ, η_0, v_syn, alpha_inv, k = map(ModelingToolkit.unwrap, [C, Δ, η_0, v_syn, alpha_inv, k])
eqs = [Equation(D(Z), (1/C)*(-im*((Z-1)^2)/2 + (((Z+1)^2)/2)*(-Δ + im*(η_0) + im*v_syn*g) - ((Z^2-1)/2)*g))
D(g) ~ alpha_inv*((k/(C*pi))*(1-abs(Z)^2)/(1+Z+conj(Z)+abs(Z)^2) - g)]
odesys = ODESystem(eqs, t, sts, params; name=name)
new(C, Δ, η_0, v_syn, alpha_inv, k, odesys.Z, odesys)
new(C, Δ, η_0, v_syn, alpha_inv, k, sts[1], odesys.Z, odesys, namespace)
end
end

mutable struct NextGenerationResolvedBlox <: NeuralMassBlox
C::Num
Δ::Num
η_0::Num
v_syn::Num
alpha_inv::Num
k::Num
output
connector::Num
odesystem::ODESystem
namespace
function NextGenerationResolvedBlox(;name,namespace=nothing, C=30.0, Δ=1.0, η_0=5.0, v_syn=-10.0, alpha_inv=35.0, k=0.105)
params = @parameters C=C Δ=Δ η_0=η_0 v_syn=v_syn alpha_inv=alpha_inv k=k
sts = @variables a(t)=0.5 [output=true] b(t)=0.0 [output=true] g(t)=1.6
#Z = a + ib

eqs = [ D(a) ~ (1/C)*(b*(a-1) -/2)*((a+1)^2-b^2) - η_0*b*(a+1) - v_syn*g*b*(a+1) - (g/2)*(a^2-b^2-1)),
D(b) ~ (1/C)*((b^2-(a-1)^2)/2 - Δ*b*(a+1) + (η_0/2)*((a+1)^2-b^2) + v_syn*(g/2)*((a+1)^2-b^2) - a*b*g),
D(g) ~ alpha_inv*((k/(C*pi))*((1-a^2-b^2)/(1+2*a+a^2+b^2)) - g)
]
odesys = ODESystem(eqs, t, sts, params; name=name)
new(C, Δ, η_0, v_syn, alpha_inv, k, sts[1], odesys.a, odesys, namespace)
end
end


mutable struct NextGenerationEIBlox <: NeuralMassBlox
Cₑ::Num
Cᵢ::Num
output
connector::Num
odesystem::ODESystem
namespace
function NextGenerationEIBlox(;name,namespace=nothing, Cₑ=30.0,Cᵢ=30.0, Δₑ=0.5, Δᵢ=0.5, η_0ₑ=10.0, η_0ᵢ=0.0, v_synₑₑ=10.0, v_synₑᵢ=-10.0, v_synᵢₑ=10.0, v_synᵢᵢ=-10.0, alpha_invₑₑ=10.0, alpha_invₑᵢ=0.8, alpha_invᵢₑ=10.0, alpha_invᵢᵢ=0.8, kₑₑ=0, kₑᵢ=0.5, kᵢₑ=0.65, kᵢᵢ=0)
params = @parameters Cₑ=Cₑ Cᵢ=Cᵢ Δₑ=Δₑ Δᵢ=Δᵢ η_0ₑ=η_0ₑ η_0ᵢ=η_0ᵢ v_synₑₑ=v_synₑₑ v_synₑᵢ=v_synₑᵢ v_synᵢₑ=v_synᵢₑ v_synᵢᵢ=v_synᵢᵢ alpha_invₑₑ=alpha_invₑₑ alpha_invₑᵢ=alpha_invₑᵢ alpha_invᵢₑ=alpha_invᵢₑ alpha_invᵢᵢ=alpha_invᵢᵢ kₑₑ=kₑₑ kₑᵢ=kₑᵢ kᵢₑ=kᵢₑ kᵢᵢ=kᵢᵢ
sts = @variables aₑ(t)=-0.6 [output=true] bₑ(t)=0.18 [output=true] aᵢ(t)=0.02 [output=true] bᵢ(t)=0.21 [output=true] gₑₑ(t)=0 gₑᵢ(t)=0.23 gᵢₑ(t)=0.26 gᵢᵢ(t)=0

#Z = a + ib

eqs = [ D(aₑ) ~ (1/Cₑ)*(bₑ*(aₑ-1) - (Δₑ/2)*((aₑ+1)^2-bₑ^2) - η_0ₑ*bₑ*(aₑ+1) - (v_synₑₑ*gₑₑ+v_synₑᵢ*gₑᵢ)*(bₑ*(aₑ+1)) - (gₑₑ/2+gₑᵢ/2)*(aₑ^2-bₑ^2-1)),
D(bₑ) ~ (1/Cₑ)*((bₑ^2-(aₑ-1)^2)/2 - Δₑ*bₑ*(aₑ+1) + (η_0ₑ/2)*((aₑ+1)^2-bₑ^2) + (v_synₑₑ*(gₑₑ/2)+v_synₑᵢ*(gₑᵢ/2))*((aₑ+1)^2-bₑ^2) - aₑ*bₑ*(gₑₑ+gₑᵢ)),
D(aᵢ) ~ (1/Cᵢ)*(bᵢ*(aᵢ-1) - (Δᵢ/2)*((aᵢ+1)^2-bᵢ^2) - η_0ᵢ*bᵢ*(aᵢ+1) - (v_synᵢₑ*gᵢₑ+v_synᵢᵢ*gᵢᵢ)*(bᵢ*(aᵢ+1)) - (gᵢₑ/2+gᵢᵢ/2)*(aᵢ^2-bᵢ^2-1)),
D(bᵢ) ~ (1/Cᵢ)*((bᵢ^2-(aᵢ-1)^2)/2 - Δᵢ*bᵢ*(aᵢ+1) + (η_0ᵢ/2)*((aᵢ+1)^2-bᵢ^2) + (v_synᵢₑ*(gᵢₑ/2)+v_synᵢᵢ*(gᵢᵢ/2))*((aᵢ+1)^2-bᵢ^2) - aᵢ*bᵢ*(gᵢₑ+gᵢᵢ)),
D(gₑₑ) ~ alpha_invₑₑ*((kₑₑ/(Cₑ*pi))*((1-aₑ^2-bₑ^2)/(1+2*aₑ+aₑ^2+bₑ^2)) - gₑₑ),
D(gₑᵢ) ~ alpha_invₑᵢ*((kₑᵢ/(Cᵢ*pi))*((1-aᵢ^2-bᵢ^2)/(1+2*aᵢ+aᵢ^2+bᵢ^2)) - gₑᵢ),
D(gᵢₑ) ~ alpha_invᵢₑ*((kᵢₑ/(Cₑ*pi))*((1-aₑ^2-bₑ^2)/(1+2*aₑ+aₑ^2+bₑ^2)) - gᵢₑ),
D(gᵢᵢ) ~ alpha_invᵢᵢ*((kᵢᵢ/(Cᵢ*pi))*((1-aᵢ^2-bᵢ^2)/(1+2*aᵢ+aᵢ^2+bᵢ^2)) - gᵢᵢ)
]
odesys = ODESystem(eqs, t, sts, params; name=name)
new(Cₑ, Cᵢ, sts[1], odesys.aₑ, odesys, namespace)
end
end
# this assignment is temporary until all the code is changed to the new name
Expand Down
31 changes: 30 additions & 1 deletion test/components.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,35 @@ sol = solve(prob, Vern7())
@test sol.retcode == ReturnCode.Success 
end

@testset "NextGenerationEIBlox connected to neuron" begin
global_ns = :g
@named LC = NextGenerationEIBlox(;namespace=global_ns, Cₑ=2*26,Cᵢ=1*26, Δₑ=0.5, Δᵢ=0.5, η_0ₑ=10.0, η_0ᵢ=0.0, v_synₑₑ=10.0, v_synₑᵢ=-10.0, v_synᵢₑ=10.0, v_synᵢᵢ=-10.0, alpha_invₑₑ=10.0/26, alpha_invₑᵢ=0.8/26, alpha_invᵢₑ=10.0/26, alpha_invᵢᵢ=0.8/26, kₑₑ=0.0*26, kₑᵢ=0.6*26, kᵢₑ=0.6*26, kᵢᵢ=0*26)
@named nn = HHNeuronExciBlox(;namespace=global_ns, t_spike_window=0.1)
assembly = [LC, nn]
g = MetaDiGraph()
add_blox!.(Ref(g), assembly)
add_edge!(g,1,2, :weight, 44)
neuron_net = system_from_graph(g; name=global_ns)
prob = ODEProblem(structural_simplify(neuron_net), [], (0.0, 2), [])
sol = solve(prob, Vern7())
@test neuron_net isa ODESystem
@test sol.retcode == ReturnCode.Success
end

@testset "NextGenerationEIBlox connected to CorticalBlox" begin
global_ns = :g
@named LC = NextGenerationEIBlox(;namespace=global_ns, Cₑ=2*26,Cᵢ=1*26, Δₑ=0.5, Δᵢ=0.5, η_0ₑ=10.0, η_0ᵢ=0.0, v_synₑₑ=10.0, v_synₑᵢ=-10.0, v_synᵢₑ=10.0, v_synᵢᵢ=-10.0, alpha_invₑₑ=10.0/26, alpha_invₑᵢ=0.8/26, alpha_invᵢₑ=10.0/26, alpha_invᵢᵢ=0.8/26, kₑₑ=0.0*26, kₑᵢ=0.6*26, kᵢₑ=0.6*26, kᵢᵢ=0*26)
@named cb = CorticalBlox(N_wta=2, N_exci=2, namespace=global_ns, density=0.1, weight=1)
assembly = [LC, cb]
g = MetaDiGraph()
add_blox!.(Ref(g), assembly)
add_edge!(g,1,2, :weight, 44)
neuron_net = system_from_graph(g; name=global_ns)
prob = ODEProblem(structural_simplify(neuron_net), [], (0.0, 2), [])
sol = solve(prob, Vern7())
@test sol.retcode == ReturnCode.Success
end

@testset "WinnerTakeAll" begin
N_exci = 5
@named wta= WinnerTakeAllBlox(;I_bg=5*rand(N_exci), N_exci)
Expand Down Expand Up @@ -549,7 +578,7 @@ end
g = MetaDiGraph()
add_blox!.(Ref(g), [cb1, cb2])
add_edge!(g, 1, 2, Dict(:weight => 1, :density => 0.1))
sys = system_from_graph(g; name=namespace=global_ns)
sys = system_from_graph(g; name=global_ns)
sys_simpl =structural_simplify(sys)
prob = ODEProblem(sys_simpl, [], (0,2))
sol = solve(prob, Vern7(), saveat=0.1)
Expand Down

0 comments on commit 6ba3d3c

Please sign in to comment.