Skip to content

Commit

Permalink
Merge branch 'master' into metabolic-module
Browse files Browse the repository at this point in the history
  • Loading branch information
bbantal authored Sep 6, 2024
2 parents 43c84ff + e6265ab commit 391c321
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 45 deletions.
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
Expand All @@ -44,6 +43,7 @@ Peaks = "18e31ff7-3703-566c-8e60-38913d67486b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand All @@ -60,6 +60,12 @@ ToeplitzMatrices = "c751599d-da0a-543b-9d20-d0a503d91d24"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[weakdeps]
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"

[extensions]
MakieExtension = "Makie"

[compat]
AbstractFFTs = "1"
Combinatorics = "1"
Expand Down Expand Up @@ -110,6 +116,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Expand Down
93 changes: 93 additions & 0 deletions ext/MakieExtension.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
module MakieExtension

isdefined(Base, :get_extension) ? using Makie : using ..Makie

using Neuroblox
using Neuroblox: AbstractNeuronBlox, CompositeBlox
using Neuroblox: meanfield_timeseries, voltage_timeseries, detect_spikes, get_neurons
using SciMLBase: AbstractSolution

import Neuroblox: meanfield, meanfield!, rasterplot, rasterplot!, stackplot, stackplot!, voltage_stack

@recipe(MeanField, blox, sol) do scene
Theme()
end

argument_names(::Type{<: MeanField}) = (:blox, :sol)

function Makie.plot!(p::MeanField)
sol = p.sol[]
blox = p.blox[]

V = meanfield_timeseries(blox, sol)

lines!(p, sol.t, vec(V))

return p
end

@recipe(RasterPlot, blox, sol) do scene
Theme(
color = :black
)
end

argument_names(::Type{<: RasterPlot}) = (:blox, :sol)

function Makie.plot!(p::RasterPlot)
sol = p.sol[]
t = sol.t
blox = p.blox[]
neurons = get_neurons(blox)

for (i, n) in enumerate(neurons)
spike_idxs = detect_spikes(n, sol)
scatter!(p, t[spike_idxs], fill(i, length(spike_idxs)); color=p.color[])
end

return p
end

@recipe(StackPlot, blox, sol) do scene
Theme(
color = :black
)
end

argument_names(::Type{<: StackPlot}) = (:blox, :sol)

function Makie.plot!(p::StackPlot)
sol = p.sol[]
blox = p.blox[]

V = voltage_timeseries(blox, sol)

offset = 20
for (i,V_neuron) in enumerate(eachcol(V))
lines!(p, sol.t, (i-1)*offset .+ V_neuron; color=p.color[])
end

return p
end

function Makie.convert_arguments(::Makie.PointBased, blox::AbstractNeuronBlox, sol::AbstractSolution)
V = voltage_timeseries(blox, sol)

return (sol.t, V)
end

function voltage_stack(blox::CompositeBlox, sol::AbstractSolution; N_neurons=10, fontsize=8, color=:black)
neurons = get_neurons(blox)
N_ax = min(length(neurons), N_neurons)

fig = Figure()
ax = Axis(fig[1,1], xlabel="Time", ylabel="Neurons")

hideydecorations!(ax)

stackplot!(ax, blox, sol)

display(fig)
end

end
28 changes: 23 additions & 5 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
module Neuroblox

if !isdefined(Base, :get_extension)
using Requires
end

using Reexport
@reexport using ModelingToolkit
const t = ModelingToolkit.t_nounits
Expand Down Expand Up @@ -53,8 +57,6 @@ using Peaks: argmaxima, peakproms!, peakheights!

using LogExpFunctions: logistic

using MakieCore

# define abstract types for Neuroblox
abstract type AbstractBlox end # Blox is the abstract type for Blox that are displayed in the GUI
abstract type AbstractComponent end
Expand Down Expand Up @@ -119,7 +121,6 @@ include("gui/GUI.jl")
include("blox/connections.jl")
include("blox/blox_utilities.jl")
include("Neurographs.jl")
include("./plot_recipes/composite_recipes.jl")

function simulate(sys::ODESystem, u0, timespan, p, solver = AutoVern7(Rodas4()); kwargs...)
prob = ODEProblem(sys, u0, timespan, p)
Expand Down Expand Up @@ -186,17 +187,34 @@ https://github.com/Neuroblox/NeurobloxIssues.
""")
end

function meanfield end
function meanfield! end

function rasterplot end
function rasterplot! end

function stackplot end
function stackplot! end

function voltage_stack end

function __init__()
#if Preferences.@load_preference("PrintLicense", true)
print_license()
#end

@static if !isdefined(Base, :get_extension)
@require Makie="ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" begin
include("../ext/MakieExtension.jl")
end
end
end

export JansenRitSPM12, next_generation, qif_neuron, if_neuron, hh_neuron_excitatory,
hh_neuron_inhibitory, van_der_pol, Generic2dOscillator
export HHNeuronExciBlox, HHNeuronInhibBlox, IFNeuron, LIFNeuron, QIFNeuron, IzhikevichNeuron, LIFExciNeuron, LIFInhNeuron,
CanonicalMicroCircuitBlox, WinnerTakeAllBlox, CorticalBlox, SuperCortical, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox,
HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox, LIFInhCircuitBlox
HHNeuronInhib_GPe_Adam_Blox, Striatum_MSN_Adam, Striatum_FSI_Adam, GPe_Adam, STN_Adam, LIFExciCircuitBlox, LIFInhCircuitBlox, MetabolicHHNeuron
export LinearNeuralMass, HarmonicOscillator, JansenRit, WilsonCowan, LarterBreakspear, NextGenerationBlox, NextGenerationResolvedBlox, NextGenerationEIBlox, KuramotoOscillator
export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
export HebbianPlasticity, HebbianModulationPlasticity
Expand All @@ -219,6 +237,6 @@ export run_experiment!, run_trial!
export addnontunableparams
export get_weights, get_dynamic_states, get_idx_tagged_vars, get_eqidx_tagged_vars
export BalloonModel,LeadField, boldsignal_endo_balloon
export MetabolicHHNeuron
export meanfield, meanfield!, rasterplot, rasterplot!, stackplot, stackplot!, voltage_stack

end
73 changes: 50 additions & 23 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,9 @@ function get_weights(agent::Agent, blox_out, blox_in)
end

function find_spikes(x::AbstractVector{T}; minprom=zero(T), maxprom=nothing, minheight=zero(T), maxheight=nothing) where {T}
spikes, _ = argmaxima(x)
peakproms!(spikes, x; minprom, maxheight)
peakheights!(spikes, xx[spikes]; minheight, maxheight)
spikes = argmaxima(x)
peakproms!(spikes, x; minprom, maxprom)
peakheights!(spikes, x[spikes]; minheight, maxheight)

return spikes
end
Expand All @@ -264,6 +264,27 @@ function count_spikes(x::AbstractVector{T}; minprom=zero(T), maxprom=nothing, mi
return length(spikes)
end

function detect_spikes(blox::AbstractNeuronBlox, sol::SciMLBase.AbstractSolution; tolerance = 1e-3)
namespaced_name = namespaced_nameof(blox)
reset_param_name = Symbol(namespaced_name, "₊V_reset")
threshold_param_name = Symbol(namespaced_name, "₊θ")

reset = only(@parameters $(reset_param_name))
thrs = only(@parameters $(threshold_param_name))

get_reset = getp(sol, reset)
reset_value = get_reset(sol)

get_thrs = getp(sol, thrs)
thrs_value = get_thrs(sol)

V = voltage_timeseries(blox, sol)

spikes = find_spikes(V; minheight = thrs_value - tolerance)

return spikes
end

"""
function get_dynamic_states(sys)
Expand Down Expand Up @@ -375,42 +396,48 @@ to_vector(v) = [v]

nanmean(x) = mean(filter(!isnan,x))

function voltage_timeseries(sol::SciMLBase.AbstractSolution, blox::AbstractNeuronBlox)
namespaced_name = string(namespaceof(blox), nameof(blox))
state_name = Symbol(namespaced_name, "₊V")
function replace_refractory!(V, blox::Union{LIFExciNeuron, LIFInhNeuron}, sol::SciMLBase.AbstractSolution)
namespaced_name = namespaced_nameof(blox)
reset_param_name = Symbol(namespaced_name, "₊V_reset")
p = only(@parameters $(reset_param_name))

s = only(@variables $(state_name)(t))
get_reset = getp(sol, p)
reset_value = get_reset(sol)

return sol[s]
V[V .== reset_value] .= NaN

return V
end

function replace_refractory!(V, blox::CompositeBlox, sol::SciMLBase.AbstractSolution)
neurons = get_neurons(blox)

for (i, n) in enumerate(neurons)
V[:, i] = replace_refractory!(V[:,i], n, sol)
end
end

function voltage_timeseries(sol::SciMLBase.AbstractSolution, blox::Union{LIFExciNeuron, LIFInhNeuron})
replace_refractory!(V, blox, sol::SciMLBase.AbstractSolution) = V

function voltage_timeseries(blox::AbstractNeuronBlox, sol::SciMLBase.AbstractSolution)
namespaced_name = namespaced_nameof(blox)
state_name = Symbol(namespaced_name, "₊V")
reset_param_name = Symbol(namespaced_name, "₊V_reset")


s = only(@variables $(state_name)(t))
p = only(@parameters $(reset_param_name))

get_reset = getp(sol, p)
reset_value = get_reset(sol)
V = sol[s]

V[V .== reset_value] .= NaN

return V
return sol[s]
end

function voltage_timeseries(sol::SciMLBase.AbstractSolution, cb::CompositeBlox)
function voltage_timeseries(cb::CompositeBlox, sol::SciMLBase.AbstractSolution)

return mapreduce(hcat, get_neurons(cb)) do neuron
voltage_timeseries(sol, neuron)
voltage_timeseries(neuron, sol)
end
end

function average_voltage_timeseries(sol::SciMLBase.AbstractSolution, cb::CompositeBlox)
V = voltage_timeseries(sol, cb)
function meanfield_timeseries(cb::CompositeBlox, sol::SciMLBase.AbstractSolution)
V = voltage_timeseries(cb, sol)
replace_refractory!(V, cb, sol)

return vec(mapslices(nanmean, V; dims = 2))
end
5 changes: 0 additions & 5 deletions src/plot_recipes/composite_recipes.jl

This file was deleted.

25 changes: 14 additions & 11 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,41 @@ using Statistics
prob = ODEProblem(ss, [], (0, 200.0))
sol = solve(prob, Tsit5())

@test all(sol[ss.n.V] .== Neuroblox.voltage_timeseries(sol, n))
@test all(sol[ss.n.V] .== Neuroblox.voltage_timeseries(n, sol))
end

@testset "Voltage timeseries + Composite average [LIFExciCircuitBloxz]" begin
global_ns = :g
tspan = (0, 200)
V_reset = -55

@named s = PoissonSpikeTrain(3, tspan; namespace = global_ns)
@named n = LIFExciCircuitBlox(; V_reset, namespace = global_ns, N_neurons = 3, weight=1)

neurons = [s, n]

g = MetaDiGraph()
add_blox!.(Ref(g), neurons)

add_edge!(g, 1, 2, Dict(:weight => 1))

sys = system_from_graph(g; name = global_ns)
ss = structural_simplify(sys)
prob = ODEProblem(ss, [], (0, 200.0))
sol = solve(prob, Tsit5())

V = hcat(sol[ss.n.neuron1.V], sol[ss.n.neuron2.V], sol[ss.n.neuron3.V])
V[V .== V_reset] .= NaN

@test all(isequal(V, Neuroblox.voltage_timeseries(sol, n)))


V_nb = Neuroblox.voltage_timeseries(n, sol)
Neuroblox.replace_refractory!(V_nb, n, sol)
@test all(isequal(V, V_nb))

V_filtered = map(eachrow(V)) do V_t
v = filter(!isnan, V_t)
mean(v)
end

@test all(isequal(V_filtered, Neuroblox.average_voltage_timeseries(sol, n)))
@test all(isequal(V_filtered, Neuroblox.meanfield_timeseries(n, sol)))
end

0 comments on commit 391c321

Please sign in to comment.