Skip to content

Commit

Permalink
More UI fixes & improvements for IAP course (#513)
Browse files Browse the repository at this point in the history
* don't hide y axis in firing rate plot

* add `SPikeSource` abstract type alias and export relevant functions for custom sources

* generalize `generate_discrete_callbacks` for any spike source

* dispatch `system` separately for composite and non-composite bloxs to account for namespacing

* switch QIF spiking to discrete callback

* add `ContantInput` source

* revert back to continuous event for QIF spiking
  • Loading branch information
harisorgn authored Dec 30, 2024
1 parent 77c0bca commit 48c0a68
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 14 deletions.
2 changes: 0 additions & 2 deletions ext/MakieExtension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ function Makie.plot!(p::FRPlot)
ax.xlabel = p.xlabel[]
ax.ylabel = p.ylabel[]
ax.title = p.title[]

hideydecorations!(ax)

fr = firing_rate(blox, sol; win_size = p.win_size[], overlap = p.overlap[], transient = p.transient[], threshold = p.threshold[])

Expand Down
8 changes: 5 additions & 3 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ abstract type StimulusBlox <: AbstractBlox end
abstract type ObserverBlox end # not AbstractBlox since it should not show up in the GUI
abstract type AbstractPINGNeuron <: AbstractNeuronBlox end

const Neuron = AbstractNeuronBlox

# we define these in neural_mass.jl
# abstract type HarmonicOscillatorBlox <: NeuralMassBlox end
# abstract type JansenRitCBlox <: NeuralMassBlox end
Expand Down Expand Up @@ -124,6 +122,9 @@ include("GraphDynamicsInterop/GraphDynamicsInterop.jl")
include("Neurographs.jl")
include("adjacency.jl")

const Neuron = AbstractNeuronBlox
const SpikeSource = AbstractSpikeSource

function simulate(sys::ODESystem, u0, timespan, p, solver = AutoVern7(Rodas4()); kwargs...)
prob = ODEProblem(sys, u0, timespan, p)
sol = solve(prob, solver; kwargs...) #pass keyword arguments to solver
Expand Down Expand Up @@ -233,7 +234,8 @@ export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc
export HebbianPlasticity, HebbianModulationPlasticity
export Agent, ClassificationEnvironment, GreedyPolicy, reset!
export LearningBlox
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain, DBS, ProtocolDBS, detect_transitions, compute_transition_times, compute_transition_values, get_protocol_duration
export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ConstantInput, ExternalInput, SpikeSource, PoissonSpikeTrain, generate_spike_times
export DBS, ProtocolDBS, detect_transitions, compute_transition_times, compute_transition_values, get_protocol_duration
export BandPassFilterBlox
export OUBlox, OUCouplingBlox
export phase_inter, phase_sin_blox, phase_cos_blox
Expand Down
6 changes: 3 additions & 3 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ end

generate_discrete_callbacks(blox, ::Connector; t_block = missing) = []

function generate_discrete_callbacks(blox::PoissonSpikeTrain, bc::Connector; t_block = missing)
function generate_discrete_callbacks(blox::AbstractSpikeSource, bc::Connector; t_block = missing)
sa = spike_affects(bc)
name_blox = namespaced_nameof(blox)

if haskey(sa, name_blox)
eqs = sa[name_blox]

Expand All @@ -143,7 +143,7 @@ function generate_discrete_callbacks(blox::PoissonSpikeTrain, bc::Connector; t_b
t_spikes = generate_spike_times(blox)
t_spikes => to_vector(eq)
end

return cb
end
end
Expand Down
16 changes: 12 additions & 4 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ end

get_system(blox) = blox.system
get_system(sys::AbstractODESystem) = sys
get_system(stim::PoissonSpikeTrain) = System(Equation[], t, [], []; name=stim.name)
get_system(stim::AbstractSpikeSource) = System(Equation[], t, [], []; name=stim.name)

function system(blox::AbstractBlox; simplify=true)
function system(blox::CompositeBlox; simplify=true)
sys = get_system(blox)
eqs = get_input_equations(blox; namespaced=false)

Expand All @@ -103,6 +103,15 @@ function system(blox::AbstractBlox; simplify=true)
return simplify ? structural_simplify(csys) : csys
end


function system(blox::AbstractBlox; simplify=true, kwargs...)
sys = get_system(blox)
eqs = get_input_equations(blox; namespaced=true)
csys = compose(System(eqs, t, [], []; name=namespaced_nameof(blox), kwargs...), sys)

return simplify ? structural_simplify(csys) : csys
end

function get_namespaced_sys(blox)
sys = get_system(blox)

Expand Down Expand Up @@ -457,7 +466,7 @@ replace_refractory!(V, blox, sol::SciMLBase.AbstractSolution) = V
function find_spikes(x::AbstractVector{T}; threshold=zero(T)) where {T}
spike_idxs = argmaxima(x)
peakheights!(spike_idxs, x[spike_idxs]; minheight = threshold)

spikes = sparsevec(spike_idxs, ones(length(spike_idxs)), length(x))

return spikes
Expand Down Expand Up @@ -486,7 +495,6 @@ function detect_spikes(
end

V = voltage_timeseries(blox, sol; ts)

spikes = find_spikes(V; threshold = thrs_value - tolerance)

return spikes
Expand Down
3 changes: 2 additions & 1 deletion src/blox/neuron_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,8 @@ struct QIFNeuron <: AbstractNeuronBlox
D(G)~(-1/τ₂)*G + z,
D(z)~(-1/τ₁)*z
]
ev = [V~θ] => [V~Vᵣₑₛ,z~G_syn]

ev = [V~θ] => [V~Vᵣₑₛ,z~G_syn]
sys = ODESystem(eqs, t, sts, p, continuous_events=[ev]; name=name)

new(p, sys, namespace)
Expand Down
18 changes: 17 additions & 1 deletion src/blox/sources.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
abstract type AbstractSpikeSource <: StimulusBlox end

struct ConstantInput <: StimulusBlox
namespace
system

function ConstantInput(; name, namespace=nothing, I=1)
@variables u(t) [output=true, description="ext_input"]
@parameters I=I
eqs = [u ~ I]
sys = System(eqs, t, [u], [I]; name=name)

new(namespace, sys)
end
end

# Simple input blox
mutable struct ExternalInput <: StimulusBlox
namespace
Expand Down Expand Up @@ -156,7 +172,7 @@ end

increment_pixel!(stim::ImageStimulus) = stim.current_pixel = mod(stim.current_pixel, stim.N_pixels) + 1

struct PoissonSpikeTrain{N} <: StimulusBlox
struct PoissonSpikeTrain{N} <: AbstractSpikeSource
name
namespace
N_trains
Expand Down

0 comments on commit 48c0a68

Please sign in to comment.