Skip to content

Commit

Permalink
Merge branch 'master' into make-discrete-state-reducible
Browse files Browse the repository at this point in the history
  • Loading branch information
anandpathak31 authored Dec 8, 2023
2 parents 7deb773 + f1fe424 commit 6af1085
Show file tree
Hide file tree
Showing 14 changed files with 263 additions and 103 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SignalAnalysis = "df1fea92-c066-49dd-8b36-eace3378ea47"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand Down
7 changes: 5 additions & 2 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ using Distributions
using ModelingToolkit: get_namespace, get_systems, isparameter,
renamespace, namespace_equation, namespace_parameters, namespace_expr,
AbstractODESystem
import ModelingToolkit: inputs, nameof
import ModelingToolkit: inputs, nameof, outputs, getdescription

using Symbolics: @register_symbolic, getdefaultval
using IfElse
Expand All @@ -52,6 +52,7 @@ abstract type AbstractNeuronBlox <: AbstractBlox end
abstract type NeuralMassBlox <: AbstractBlox end
abstract type CompositeBlox <: AbstractBlox end
abstract type StimulusBlox <: AbstractBlox end
abstract type ObserverBlox <: AbstractBlox end

# we define these in neural_mass.jl
# abstract type HarmonicOscillatorBlox <: NeuralMassBlox end
Expand Down Expand Up @@ -172,12 +173,14 @@ export LinearConnections, SynapticConnections, ODEfromGraph, ODEfromGraphNeuron,
export add_blox!
export powerspectrum, complexwavelet, bandpassfilter, hilberttransform, phaseangle, mar2csd, csd2mar, mar_ml
export learningrate, ControlError
export Hemodynamics, LinHemo, boldsignal
export boldsignal, BalloonModel
export vecparam, unvecparam, csd_Q, spectralVI
export simulate, random_initials
export system_from_graph, graph_delays
export create_adjacency_edges!
export get_namespaced_sys, nameof
export run_experiment!, run_experiment_open_loop!
export addnontunableparams, get_hemodynamic_observers


end
74 changes: 51 additions & 23 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
function progress_scope(params; lvl=0)
para_list = []
for p in params
pp = ModelingToolkit.unwrap(p)
if ModelingToolkit.hasdefault(pp)
d = ModelingToolkit.getdefault(pp)
if typeof(d)==SymbolicUtils.BasicSymbolic{Real}
if lvl==0
pp = ParentScope(pp)
else
pp = DelayParentScope(pp,lvl)
end
end
end
push!(para_list,ModelingToolkit.wrap(pp))
end
return para_list
end
# function progress_scope(params; lvl=0)
# para_list = []
# for p in params
# pp = ModelingToolkit.unwrap(p)
# if ModelingToolkit.hasdefault(pp)
# d = ModelingToolkit.getdefault(pp)
# if typeof(d)==SymbolicUtils.BasicSymbolic{Real}
# if lvl==0
# pp = ParentScope(pp)
# else
# pp = DelayParentScope(pp,lvl)
# end
# end
# end
# push!(para_list,ModelingToolkit.wrap(pp))
# end
# return para_list
# end

"""
This function progresses the scope of parameters and leaves floating point values untouched
"""
function progress_scope(args...)
paramlist = []
for p in args
if p isa Float64
push!(paramlist, p)
else
if p isa Num
p = ParentScope(p)
# pp = ModelingToolkit.unwrap(p)
# if ModelingToolkit.hasdefault(pp)
Expand All @@ -36,6 +34,8 @@ function progress_scope(args...)
# end
# push!(para_list,ModelingToolkit.wrap(pp))
push!(paramlist, p)
else
push!(paramlist, p)
end
end
return paramlist
Expand All @@ -48,8 +48,8 @@ end
function compileparameterlist(;kwargs...)
paramlist = []
for (kw, v) in kwargs
if v isa Float64
paramlist = vcat(paramlist, @parameters $kw = v)
if v isa Union{Float64, Int} # note that Num is also subtype of Real. Thus union of types seems to be the solution.
paramlist = vcat(paramlist, @parameters $kw = v [tunable=true])
else
paramlist = vcat(paramlist, v)
end
Expand Down Expand Up @@ -248,3 +248,31 @@ function count_spikes(x::AbstractVector{T}; minprom=zero(T), maxprom=nothing, mi

return length(spikes)
end

function get_hemodynamic_observers(sys_from_graph, nr)
obs_idx = Dict([k => [] for k in 1:nr])
obs_states = Dict([k => [] for k in 1:nr])
for (i, s) in enumerate(states(sys_from_graph))
if isequal(getdescription(s), "hemodynamic_observer")
regionidx = parse(Int64, split(string(s), "")[1][end])
push!(obs_idx[regionidx], i)
push!(obs_states[regionidx], s)
end
end
return (obs_idx, obs_states)
end

function addnontunableparams(param, model)
newparam = []
k = 0
for p in parameters(model)
if istunable(p)
k += 1
push!(newparam, param[k])
else
push!(newparam, Symbolics.getdefaultval(p))
end
end
append!(newparam, param[k+1:end])
return newparam
end
85 changes: 82 additions & 3 deletions src/blox/connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ function generate_weight_param(blox_out, blox_in; kwargs...)

weight = get_weight(kwargs, name_out, name_in)
w_name = Symbol("w_$(name_out)_$(name_in)")
w = only(@parameters $(w_name)=weight)
if typeof(weight) == Num # Symbol
w = weight
else
w = only(@parameters $(w_name)=weight)
end

return w
end
Expand All @@ -98,7 +102,15 @@ end
Helper to merge delays and weights into a single vector
"""
function params(bc::BloxConnector)
return vcat(bc.weights, bc.delays)
weights = []
for w in bc.weights
append!(weights, Symbolics.get_variables(w))
end
if isempty(weights)
return vcat(weights, bc.delays)
else
return vcat(reduce(vcat, weights), bc.delays)
end
end

function (bc::BloxConnector)(
Expand Down Expand Up @@ -157,7 +169,6 @@ function (bc::BloxConnector)(
neurons_in = get_inh_neurons(cb_in)

bc(asc_out, neurons_in[end]; kwargs...)

end

function (bc::BloxConnector)(
Expand Down Expand Up @@ -193,6 +204,74 @@ function (bc::BloxConnector)(
accumulate_equation!(bc, eq)
end

# additional dispatch to connect to hemodynamic observer blox
function (bc::BloxConnector)(
bloxout::NeuralMassBlox,
bloxin::ObserverBlox;
weight=1,
delay=0,
density=0.1
)
# Need t for the delay term
@variables t

sys_out = get_namespaced_sys(bloxout)
sys_in = get_namespaced_sys(bloxin)

if typeof(bloxout.output) == Num
w_name = Symbol("w_$(nameof(sys_out))_$(nameof(sys_in))")
if typeof(weight) == Num # Symbol
w = weight
else
w = only(@parameters $(w_name)=weight)
end
push!(bc.weights, w)
x = namespace_expr(bloxout.output, sys_out, nameof(sys_out))
eq = sys_in.jcn ~ x*w
else
# Define & accumulate delay parameter
# Don't accumulate if zero
τ_name = Symbol("τ_$(nameof(sys_out))_$(nameof(sys_in))")
τ = only(@parameters $(τ_name)=delay)
push!(bc.delays, τ)

w_name = Symbol("w_$(nameof(sys_out))_$(nameof(sys_in))")
w = only(@parameters $(w_name)=weight)
push!(bc.weights, w)

x = namespace_expr(bloxout.output, sys_out, nameof(sys_out))
eq = sys_in.jcn ~ x(t-τ)*w
end

accumulate_equation!(bc, eq)
end

# # Ok yes this is a bad dispatch but the whole compound blocks implementation is hacky and needs fixing @@
# # Opening an issue to loop back to this during clean up week
# function (bc::BloxConnector)(
# bloxout::CompoundNOBlox,
# bloxin::CompoundNOBlox;
# weight=1,
# delay=0,
# density=0.1
# )

# sys_out = get_namespaced_sys(bloxout)
# sys_in = get_namespaced_sys(bloxin)

# w_name = Symbol("w_$(nameof(sys_out))_$(nameof(sys_in))")
# if typeof(weight) == Num # Symbol
# w = weight
# else
# w = only(@parameters $(w_name)=weight)
# end
# push!(bc.weights, w)
# x = namespace_expr(bloxout.output, sys_out, nameof(sys_out))
# eq = sys_in.nmm₊jcn ~ x*w

# accumulate_equation!(bc, eq)
# end

function (bc::BloxConnector)(
wta_out::WinnerTakeAllBlox,
wta_in::WinnerTakeAllBlox;
Expand Down
12 changes: 6 additions & 6 deletions src/blox/cortical_blox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ struct CorticalBlox <: CompositeBlox
τ_inhib
)
end

n_ff_inh = HHNeuronInhibBlox(
name = "ff_inh",
namespace = namespaced_name(namespace, name),
E_syn = E_syn_inhib,
G_syn = G_syn_ff_inhib,
namespace = namespaced_name(namespace, name),
E_syn = E_syn_inhib,
G_syn = G_syn_ff_inhib,
τ = τ_inhib
)
)

g = MetaDiGraph()
add_blox!.(Ref(g), vcat(wtas, n_ff_inh))

Expand Down
24 changes: 15 additions & 9 deletions src/blox/neural_mass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ mutable struct HarmonicOscillatorBlox <: NeuralMassBlox
initial::Dict{Num, Tuple{Float64, Float64}}
odesystem::ODESystem
function HarmonicOscillatorBlox(;name, ω=25*(2*pi), ζ=1.0, k=625*(2*pi), h=35.0)
params = progress_scope(@parameters ω=ω ζ=ζ k=k h=h)
params = progress_scope(ω, ζ, k, h)
params = compileparameterlist=params[1], ζ=params[2], k=params[3], h=params[4])
sts = @variables x(t)=1.0 y(t)=1.0 jcn(t)=0.0
ω, ζ, k, h = params
eqs = [D(x) ~ y-(2*ω*ζ*x)+ k*(2/π)*(atan((jcn)/h))
Expand All @@ -45,7 +46,8 @@ mutable struct JansenRitCBlox <: NeuralMassBlox
initial::Dict{Num, Tuple{Float64, Float64}}
odesystem::ODESystem
function JansenRitCBlox(;name, τ=0.001, H=20.0, λ=5.0, r=0.15)
params = progress_scope(@parameters τ=τ H=H λ=λ r=r)
params = progress_scope(τ, H, λ, r)
params = compileparameterlist=params[1], H=params[2], λ=params[3], r=params[4])
sts = @variables x(t)=1.0 y(t)=1.0 jcn(t)=0.0
τ, H, λ, r = params
eqs = [D(x) ~ y - ((2/τ)*x),
Expand All @@ -66,7 +68,8 @@ mutable struct JansenRitSCBlox <: NeuralMassBlox
initial::Dict{Num, Tuple{Float64, Float64}}
odesystem::ODESystem
function JansenRitSCBlox(;name, τ=0.014, H=20.0, λ=400.0, r=0.1)
params = progress_scope(@parameters τ=τ H=H λ=λ r=r)
params = progress_scope(τ, H, λ, r)
params = compileparameterlist=params[1], H=params[2], λ=params[3], r=params[4])
sts = @variables x(t)=1.0 y(t)=1.0 jcn(t)=0.0
τ, H, λ, r = params
eqs = [D(x) ~ y - ((2/τ)*x),
Expand Down Expand Up @@ -301,7 +304,7 @@ struct LinearNeuralMass <: NeuralMassBlox
odesystem
namespace
function LinearNeuralMass(;name, namespace=nothing)
sts = @variables x(t) [output=true] jcn(t) [input=true]
sts = @variables x(t)=0.0 [output=true] jcn(t)=0.0 [input=true]
eqs = [D(x) ~ jcn]
sys = System(eqs, name=name)
new(sts[1], sts[2], sys, namespace)
Expand All @@ -319,7 +322,8 @@ struct HarmonicOscillator <: NeuralMassBlox
odesystem
namespace
function HarmonicOscillator(;name, namespace=nothing, ω=25*(2*pi)*0.001, ζ=1.0, k=625*(2*pi), h=35.0)
p = progress_scope(@parameters ω=ω ζ=ζ k=k h=h)
p = progress_scope(ω, ζ, k, h)
p = compileparameterlist=p[1], ζ=p[2], k=p[3], h=p[4])
sts = @variables x(t)=1.0 [output=true] y(t)=1.0 jcn(t)=0.0 [input=true]
ω, ζ, k, h = p
eqs = [D(x) ~ y-(2*ω*ζ*x)+ k*(2/π)*(atan((jcn)/h))
Expand Down Expand Up @@ -353,7 +357,8 @@ struct JansenRit <: NeuralMassBlox
λ = isnothing(λ) ? (cortical ? 5.0 : 400.0) : λ
r = isnothing(r) ? (cortical ? 0.15 : 0.1) : r

p = progress_scope(@parameters τ=τ H=H λ=λ r=r)
p = progress_scope(τ, H, λ, r)
p = compileparameterlist=p[1], H=p[2], λ=p[3], r=p[4])
τ, H, λ, r = p
sts = @variables x(..)=1.0 [output=true] y(t)=1.0 jcn(t)=0.0 [input=true]
eqs = [D(x(t)) ~ y - ((2/τ)*x(t)),
Expand Down Expand Up @@ -389,8 +394,8 @@ struct WilsonCowan <: NeuralMassBlox
θ_I=3.5,
η=1.0
)
p = progress_scope(@parameters τ_E=τ_E τ_I=τ_I a_E=a_E a_I=a_I c_EE=c_EE c_IE=c_IE c_EI=c_EI c_II=c_II θ_E=θ_E θ_I=θ_I η=η)

p = progress_scope(τ_E, τ_I, a_E, a_I, c_EE, c_IE, c_EI, c_II, θ_E, θ_I, η)
p = compileparameterlist(τ_E=p[1], τ_I=p[2], a_E=p[3], a_I=p[4], c_EE=p[5], c_IE=p[6], c_EI=p[7], c_II=p[8], θ_E=p[9], θ_I=p[10], η=p[11])
τ_E, τ_I, a_E, a_I, c_EE, c_IE, c_EI, c_II, θ_E, θ_I, η = p
sts = @variables E(t)=1.0 [output=true] I(t)=1.0 jcn(t)=0.0 [input=true] #P(t)=0.0
eqs = [D(E) ~ -E/τ_E + 1/(1 + exp(-a_E*(c_EE*E - c_IE*I - θ_E + η*(jcn)))), #old form: D(E) ~ -E/τ_E + 1/(1 + exp(-a_E*(c_EE*E - c_IE*I - θ_E + P + η*(jcn)))),
Expand Down Expand Up @@ -443,7 +448,8 @@ struct LarterBreakspear <: NeuralMassBlox
r_NMDA=0.25,
C=0.35
)
p = progress_scope(@parameters C=C δ_VZ=δ_VZ T_Ca=T_Ca δ_Ca=δ_Ca g_Ca=g_Ca V_Ca=V_Ca T_K=T_K δ_K=δ_K g_K=g_K V_K=V_K T_Na=T_Na δ_Na=δ_Na g_Na=g_Na V_Na=V_Na V_L=V_L g_L=g_L V_T=V_T Z_T=Z_T Q_Vmax=Q_Vmax Q_Zmax=Q_Zmax IS=IS a_ee=a_ee a_ei=a_ei a_ie=a_ie a_ne=a_ne a_ni=a_ni b=b τ_K=τ_K ϕ=ϕ r_NMDA=r_NMDA)
p = progress_scope(C, δ_VZ, T_Ca, δ_Ca, g_Ca, V_Ca, T_K, δ_K, g_K, V_K, T_Na, δ_Na, g_Na, V_Na, V_L, g_L, V_T, Z_T, Q_Vmax, Q_Zmax, IS, a_ee, a_ei, a_ie, a_ne, a_ni, b, τ_K, ϕ,r_NMDA)
p = compileparameterlist(C=p[1], δ_VZ=p[2], T_Ca=p[3], δ_Ca=p[4], g_Ca=p[5], V_Ca=p[6], T_K=p[7], δ_K=p[8], g_K=p[9], V_K=p[10], T_Na=p[11], δ_Na=p[12], g_Na=p[13],V_Na=p[14], V_L=p[15], g_L=p[16], V_T=p[17], Z_T=p[18], Q_Vmax=p[19], Q_Zmax=p[20], IS=p[21], a_ee=p[22], a_ei=p[23], a_ie=p[24], a_ne=p[25], a_ni=p[26], b=p[27], τ_K=p[28], ϕ=p[29], r_NMDA=p[30])
C, δ_VZ, T_Ca, δ_Ca, g_Ca, V_Ca, T_K, δ_K, g_K, V_K, T_Na, δ_Na, g_Na,V_Na, V_L, g_L, V_T, Z_T, Q_Vmax, Q_Zmax, IS, a_ee, a_ei, a_ie, a_ne, a_ni, b, τ_K, ϕ, r_NMDA = p

sts = @variables V(t)=0.5 Z(t)=0.5 W(t)=0.5 jcn(t)=0.0 [input=true] Q_V(t) [output=true] Q_Z(t) m_Ca(t) m_Na(t) m_K(t)
Expand Down
Loading

0 comments on commit 6af1085

Please sign in to comment.