Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving spDCM tutorial #500

Merged
merged 63 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
8571333
updated CMC code (#474)
david-hofmann Nov 13, 2024
cac1cfb
removed some wrong lines that resulted from a rebase
david-hofmann Nov 13, 2024
94f9528
Changing edges to standard nomenclature.
david-hofmann Dec 10, 2024
c50a013
change of time scale from seconds to milliseconds
david-hofmann Jan 2, 2025
94cced3
bump GraphDynamics compat (#491)
MasonProtter Nov 15, 2024
b76346f
Add decision making example (#492)
harisorgn Nov 15, 2024
f05b77c
Update README.md
hstrey Nov 20, 2024
a290c65
Update README.md
hstrey Nov 20, 2024
e530554
remove `jcn=0` pattern from discrete blox
MasonProtter Nov 21, 2024
821655d
Fix CI errors from `jcn` variables being set with initial conditions …
agchesebro Nov 21, 2024
50cab67
Update README.md
helmutstrey Nov 21, 2024
91cd7af
added plots to GUI's plotdetail
helmutstrey Dec 5, 2024
2a9e1fc
added more regions
helmutstrey Dec 5, 2024
7e114ad
add plotdetail properly
helmutstrey Dec 5, 2024
014e485
Bump MTK compat (#497)
MasonProtter Dec 6, 2024
34b2cc8
Firing rate cleaning (#496)
gabrevaya Dec 6, 2024
95df12d
address Scott’s feedback on the basal ganglia tutorial (#498)
gabrevaya Dec 6, 2024
87344a3
Plot recipes improvements (#495)
gabrevaya Dec 9, 2024
0b321a9
Change `BloxConnector` to `Connector` (#502)
harisorgn Dec 16, 2024
b8778b4
Improve `show` for `Connector` objects (#503)
harisorgn Dec 17, 2024
c00d823
Switch `Connector` fields from `Vector{Vector{...}}` to `Vector{...}`…
harisorgn Dec 20, 2024
888f692
Generic `Connector` dispatch (#505)
harisorgn Dec 22, 2024
fc3af35
Replace `output` field with `MTK.outputs` function and cleanup (#507)
harisorgn Dec 24, 2024
3461c2f
filter out empty spike affects (#508)
harisorgn Dec 24, 2024
358d4f7
Add `PoissonSpikeTrain` to `Connector.spike_affects` collection and i…
harisorgn Dec 27, 2024
b9d93e2
add simple stimulus example (#511)
gabrevaya Dec 27, 2024
bd035af
More UI fixes & improvements for IAP course (#513)
harisorgn Dec 30, 2024
4d32499
added Plotsettings tab in GUI.jl
helmutstrey Dec 26, 2024
157e57c
moved plot options to Sim
helmutstrey Dec 27, 2024
5b0d781
Switch QIF neurons to use discrete events (#514)
MasonProtter Dec 30, 2024
988c183
print to the passed `IO` explicitly (#515)
harisorgn Dec 30, 2024
65401c1
added adj to special plot options
helmutstrey Dec 30, 2024
2f70c55
Making code consistent with SPM25 (#501)
david-hofmann Jan 2, 2025
37f24ee
updated CMC code (#474)
david-hofmann Nov 13, 2024
cee545e
removed some wrong lines that resulted from a rebase
david-hofmann Nov 13, 2024
ccfa6f6
fixed effective connectivity recipe
david-hofmann Jan 2, 2025
f684484
added some more explanator text
david-hofmann Jan 2, 2025
3c99dfa
fixed readme
david-hofmann Jan 2, 2025
dd54144
Update README.md
hstrey Nov 20, 2024
fb4a2c2
remove `jcn=0` pattern from discrete blox
MasonProtter Nov 21, 2024
f5004c0
Fix CI errors from `jcn` variables being set with initial conditions …
agchesebro Nov 21, 2024
babdf2e
added plots to GUI's plotdetail
helmutstrey Dec 5, 2024
9c5dc36
added more regions
helmutstrey Dec 5, 2024
a25362e
Bump MTK compat (#497)
MasonProtter Dec 6, 2024
e2fab04
Change `BloxConnector` to `Connector` (#502)
harisorgn Dec 16, 2024
25245b3
Improve `show` for `Connector` objects (#503)
harisorgn Dec 17, 2024
9b24eac
Switch `Connector` fields from `Vector{Vector{...}}` to `Vector{...}`…
harisorgn Dec 20, 2024
eb95919
Generic `Connector` dispatch (#505)
harisorgn Dec 22, 2024
a997406
Replace `output` field with `MTK.outputs` function and cleanup (#507)
harisorgn Dec 24, 2024
a77fd99
filter out empty spike affects (#508)
harisorgn Dec 24, 2024
b335382
Add `PoissonSpikeTrain` to `Connector.spike_affects` collection and i…
harisorgn Dec 27, 2024
a09dfa7
More UI fixes & improvements for IAP course (#513)
harisorgn Dec 30, 2024
bb8b537
added Plotsettings tab in GUI.jl
helmutstrey Dec 26, 2024
062545c
moved plot options to Sim
helmutstrey Dec 27, 2024
60de89f
Switch QIF neurons to use discrete events (#514)
MasonProtter Dec 30, 2024
09c7298
print to the passed `IO` explicitly (#515)
harisorgn Dec 30, 2024
86c06db
added adj to special plot options
helmutstrey Dec 30, 2024
69f5644
Making code consistent with SPM25 (#501)
david-hofmann Jan 2, 2025
ce4a67f
start learning tutorial (#480)
anandpathak31 Jan 2, 2025
cc84a44
added an option for Neuron selection
helmutstrey Jan 2, 2025
a0b45b8
updated to 0.5.2
helmutstrey Jan 2, 2025
0d79b0d
Added more comments to the tutorial documentation
david-hofmann Jan 3, 2025
0b5dc39
Merge branch 'master' into spDCM_tutorial_revisions
david-hofmann Jan 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions docs/src/tutorials/spectralDCM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# inferring model parameters, that is solving inverse problems, from time series.
# The method of choice is one of the most widely spread in imaging neuroscience, spectral Dynamic Causal Modeling (spDCM)[1,2].
# In this tutorial we will introduce how to perform a spDCM analysis on simulated data.
# To do so we roughly reproduce the procedure in the [SPM12](https://www.fil.ion.ucl.ac.uk/spm/software/spm12/) script `DEM_demo_induced_fMRI.m` in [Neuroblox](https://www.neuroblox.org/).
# To do so we roughly reproduce the procedure in the [SPM](https://www.fil.ion.ucl.ac.uk/spm/software/spm12/) script `DEM_demo_induced_fMRI.m` in [Neuroblox](https://www.neuroblox.org/).
# This work was also presented in Hofmann et al.[2]
#
# In this tutorial we will define a circuit of three linear neuronal mass models, all driven by an Ornstein-Uhlenbeck process.
Expand All @@ -30,6 +30,7 @@ using DataFrames
using OrderedCollections
using CairoMakie
using ModelingToolkit
using Random

# # Model simulation
# ## Define the model
Expand All @@ -41,13 +42,16 @@ using ModelingToolkit
# We want to simulate fMRI signals thus we will need to also add a BalloonModel per region.
# Note that the Ornstein-Uhlenbeck block will feed into the linear neural mass which in turn will feed into the BalloonModel blox.
# This needs to be represented by the way we define the edges.
Random.seed!(17) # set seed for reproducibility

nr = 3 # number of regions
g = MetaDiGraph()
regions = []; # list of neural mass blocks to then connect them to each other with an adjacency matrix `A_true`
# Now add the different blocks to each region and connect the blocks within each region:
regions = []; # list of neural mass blocks to then connect them to each other with an adjacency matrix `A_true`
# Now add the different blocks to each region and connect the blocks within each region.
# For convenience we use a for loop since the type of blocks belonging to a each region repeat over regions but you could also approach building the system the same way as was shown in previous tutorials:
for i = 1:nr
region = LinearNeuralMass(;name=Symbol("r$(i)₊lm"))
push!(regions, region) # store neural mass model for connection of regions
push!(regions, region) # store neural mass model in list. We need this list below. If you haven't seen the Julia command `push!` before [see here](http://jlhub.com/julia/manual/en/function/push-exclamation).

## add Ornstein-Uhlenbeck block as noisy input to the current region
input = OUBlox(;name=Symbol("r$(i)₊ou"), σ=0.1)
Expand All @@ -57,10 +61,7 @@ for i = 1:nr
measurement = BalloonModel(;name=Symbol("r$(i)₊bm"))
add_edge!(g, region => measurement, weight=1.0)
end
# Next we define the between-region connectivity matrix and make sure that it is diagonally dominant to guarantee numerical stability (see Gershgorin theorem).
A_true = 0.1*randn(nr, nr)
A_true -= diagm(map(a -> sum(abs, a), eachrow(A_true))) # ensure diagonal dominance of matrix
# Instead of a random matrix use the same matrix as is defined in [3]
# Next we define the between-region connectivity matrix and connect regions; we use the same matrix as is defined in [3]
A_true = [[-0.5 -2 0]; [0.4 -0.5 -0.3]; [0 0.2 -0.5]]
for idx in CartesianIndices(A_true)
add_edge!(g, regions[idx[1]] => regions[idx[2]], weight=A_true[idx[1], idx[2]])
Expand All @@ -74,15 +75,18 @@ end
# setup simulation of the model, time in seconds
tspan = (0.0, 512.0)
prob = SDEProblem(simmodel, [], tspan)
dt = 2.0 # two seconds as measurement interval for fMRI
dt = 2 # 2 seconds (units are milliseconds) as measurement interval for fMRI
sol = solve(prob, ImplicitRKMil(), saveat=dt);

# plot bold signal time series
# we now want to extract all the variables in our model which carry the tag "measurement". For this purpose we can use the Neuroblox function `get_idx_tagged_vars`
# the observable quantity in our model is the BOLD signal, the variable of the Blox `BalloonModel` that represents the BOLD signal is tagged with "measurement" tag.
# other tags that are defined are "input" which denotes variables representing a stimulus, like for instance an `OUBlox`.
idx_m = get_idx_tagged_vars(simmodel, "measurement") # get index of bold signal
# plot bold signal time series
f = Figure()
ax = Axis(f[1, 1],
title = "fMRI time series",
xlabel = "Time [s]",
xlabel = "Time [ms]",
ylabel = "BOLD",
)
lines!(ax, sol, idxs=idx_m)
Expand Down Expand Up @@ -117,10 +121,14 @@ fig
# Note that parameters are tunable by default.
g = MetaDiGraph()
regions = []; # list of neural mass blocks to then connect them to each other with an adjacency matrix `A`
# The following parameters are shared accross regions, which is why we define them here.
# Note that parameters are typically defined within a Blox and thus not immediately visible to the user.
# Since we want some parameters to be shared across several regions we define them outside of the regions.
# For this purpose use the ModelingToolkit macro `@parameters` which is used to define symbolic parameters for models.
# Note that we can set the tunable flag right away thereby defining whether we will include this parameter in the optimization procedure or rather keep it fixed to its predefined value.
@parameters lnκ=0.0 [tunable=false] lnϵ=0.0 [tunable=false] lnτ=0.0 [tunable=false] # lnκ: decay parameter for hemodynamics; lnϵ: ratio of intra- to extra-vascular components, lnτ: transit time scale
@parameters C=1/16 [tunable=false] # note that C=1/16 is taken from SPM12 and stabilizes the balloon model simulation. See also comment above.

# We now define a similar model as above for the simulation but instead of using an actual stimulus Blox we here add ExternalInput which represents a simple linear external input that is not specified any further.
# We simply say that our model gets some input with a proportional factor $C$. This is mostly only to make sure that our results are consistent with those produced by SPM
for i = 1:nr
region = LinearNeuralMass(;name=Symbol("r$(i)₊lm"))
push!(regions, region)
Expand All @@ -132,6 +140,7 @@ for i = 1:nr
add_edge!(g, region => measurement, weight=1.0)
end

# Here we define the prior expectation values of the effective connectivity matrix we wish to infer:
A_prior = 0.01*randn(nr, nr)
A_prior -= diagm(diag(A_prior)) # remove the diagonal
# Since we want to optimize these weights we turn them into symbolic parameters:
Expand All @@ -149,21 +158,21 @@ for (i, idx) in enumerate(CartesianIndices(A_prior))
add_edge!(g, regions[idx[2]] => regions[idx[1]], weight=A[i])
end
end
# we avoid simplification of the model in order to exclude some parameters from fitting
# Avoid simplification of the model in order to be able to exclude some parameters from fitting
@named fitmodel = system_from_graph(g, simplify=false)
# With the function `changetune`` we can provide a dictionary of parameters whose tunable flag should be changed, for instance set to false to exclude them from the optimizatoin procedure.
# For instance the the effective connections that are set to zero in the simulation:
untune = Dict(A[3] => false, A[7] => false)
fitmodel = changetune(fitmodel, untune) # 3 and 7 are not present in the simulation model
fitmodel = structural_simplify(fitmodel, split=false) # and now simplify the euqations
fitmodel = structural_simplify(fitmodel, split=false) # and now simplify the euqations; the `split` parameter is necessary for some ModelingToolkit peculiarities and will soon be removed. So don't lose time with it ;)

# ## Setup spectral DCM
max_iter = 128; # maximum number of iterations
## attribute initial conditions to states
sts, _ = get_dynamic_states(fitmodel);
# the following step is needed if the model's Jacobian would give degenerate eigenvalues if expanded around 0 (which is the default expansion)
perturbedfp = Dict(sts .=> abs.(0.001*rand(length(sts)))) # slight noise to avoid issues with Automatic Differentiation. TODO: find different solution, this is hacky.
# We can use the default prior function to use standardized prior values as given in SPM12.
# the following step is needed if the model's Jacobian would give degenerate eigenvalues when expanded around the fixed point 0 (which is the default expansion). We simply add small random values to avoid this degeneracy:
perturbedfp = Dict(sts .=> abs.(0.001*rand(length(sts)))) # slight noise to avoid issues with Automatic Differentiation.
# For convenience we can use the default prior function to use standardized prior values as given in SPM:
pmean, pcovariance, indices = defaultprior(fitmodel, nr)

priors = (μθ_pr = pmean,
Expand All @@ -175,16 +184,18 @@ hyperpriors = Dict(:Πλ_pr => 128.0*ones(1, 1), # prior metaparameter precisi
);
# To compute the cross spectral densities we need to provide the sampling interval of the time series, the frequency axis and the order of the multivariate autoregressive model:
csdsetup = (mar_order = p, freq = freq, dt = dt);

# earlier we used the function `get_idx_tagged_vars` to get the indices of tagged variables. Here we don't want to get the indices but rather the symbolic variable names themselves.
# and in particular we need to get the measurement variables in the same ordering as the model equations are defined.
_, s_bold = get_eqidx_tagged_vars(fitmodel, "measurement"); # get bold signal variables
# Prepare the DCM:
# Prepare the DCM. This function will setup the computation of the Dynamic Causal Model. The last parameter specifies that wer are using fMRI time series as opposed to LFPs.
(state, setup) = setup_sDCM(dfsol[:, String.(Symbol.(s_bold))], fitmodel, perturbedfp, csdsetup, priors, hyperpriors, indices, pmean, "fMRI");

## HACK: on machines with very small amounts of RAM, Julia can run out of stack space while compiling the code called in this loop
## this should be rewritten to abuse the compiler less, but for now, an easy solution is just to run it with more allocated stack space.
with_stack(f, n) = fetch(schedule(Task(f, n)));

# We are ready to run the optimization procedure! :)
# We are now ready to run the optimization procedure! :)
# That is we loop over run_sDCM_iteration! which will alter `state` after each optimization iteration. It essentially computes the Variational Laplace estimation of expectation and variance of the tunable parameters.
with_stack(5_000_000) do # 5MB of stack space
for iter in 1:max_iter
state.iter = iter
Expand All @@ -201,7 +212,8 @@ with_stack(5_000_000) do # 5MB of stack space
end

# # Plot Results
# Plot the free energy evolution over optimization iterations:
# Free energy is the objective function of the optimization scheme of spectral DCM. Note that in the machine learning literature this it is called Evidence Lower Bound (ELBO).
# Plot the free energy evolution over optimization iterations to see how the algorithm converges towards a (potentially local) optimum:
freeenergy(state)

# Plot the estimated posterior of the effective connectivity and compare that to the true parameter values.
Expand Down
27 changes: 14 additions & 13 deletions ext/MakieExtension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,30 +82,31 @@ end
argument_names(::Type{<: ECBarPlot}) = (:spDCMresults, :spDCMsetup, :groundtruth)

function Makie.plot!(p::ECBarPlot)
nr = p.spDCMsetup[].systemnums[1] # number of regions
diagidx = 1:(nr+1):nr^2
modelparam = p.spDCMsetup[].modelparam
idx = collect(1:nr^2)
deleteat!(idx, diagidx)
xlabels = string.(collect(keys(modelparam))[idx])
xlabels = string.(collect(keys(modelparam)))
idx = []
for l in xlabels
if l[1] == 'A'
push!(idx, parse(Int64, l[2:end]))
end
end
np = length(idx)

ax = current_axis()
ax.xticks = (1:(nr^2-nr), xlabels)
ax.xticks = (1:np, xlabels[1:np])
ax.xlabel = p.xlabel[]
ax.ylabel = p.ylabel[]
ax.title = p.title[]

gt = copy(vec(p.groundtruth[])) # get ground truth values
deleteat!(gt, diagidx)
state = p.spDCMresults[]
μA = state.μθ_po[1:nr^2] # get estimated means of effective connectivity
deleteat!(μA, diagidx)
var_A = diag(state.Σθ_po[1:nr^2, 1:nr^2]) # get variance of effective connectivity
deleteat!(var_A, diagidx)
μA = state.μθ_po[1:length(idx)] # get estimated means of effective connectivity
var_A = diag(state.Σθ_po[1:np, 1:np]) # get variance of effective connectivity

x = 1:(nr^2-nr)
x = 1:np
barplot!(p, x, μA)
errorbars!(p, x, μA, sqrt.(var_A), color = :red)
scatter!(p, x, gt)
scatter!(p, x, gt[idx])
return p
end

Expand Down
Loading