diff --git a/ext/MakieExtension.jl b/ext/MakieExtension.jl index 1e898a62..f8fa3cd8 100644 --- a/ext/MakieExtension.jl +++ b/ext/MakieExtension.jl @@ -241,8 +241,6 @@ function Makie.plot!(p::FRPlot) ax.xlabel = p.xlabel[] ax.ylabel = p.ylabel[] ax.title = p.title[] - - hideydecorations!(ax) fr = firing_rate(blox, sol; win_size = p.win_size[], overlap = p.overlap[], transient = p.transient[], threshold = p.threshold[]) diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index c43bd4c0..093808fa 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -67,8 +67,6 @@ abstract type StimulusBlox <: AbstractBlox end abstract type ObserverBlox end # not AbstractBlox since it should not show up in the GUI abstract type AbstractPINGNeuron <: AbstractNeuronBlox end -const Neuron = AbstractNeuronBlox - # we define these in neural_mass.jl # abstract type HarmonicOscillatorBlox <: NeuralMassBlox end # abstract type JansenRitCBlox <: NeuralMassBlox end @@ -124,6 +122,9 @@ include("GraphDynamicsInterop/GraphDynamicsInterop.jl") include("Neurographs.jl") include("adjacency.jl") +const Neuron = AbstractNeuronBlox +const SpikeSource = AbstractSpikeSource + function simulate(sys::ODESystem, u0, timespan, p, solver = AutoVern7(Rodas4()); kwargs...) prob = ODEProblem(sys, u0, timespan, p) sol = solve(prob, solver; kwargs...) #pass keyword arguments to solver @@ -233,7 +234,8 @@ export Matrisome, Striosome, Striatum, GPi, GPe, Thalamus, STN, TAN, SNc export HebbianPlasticity, HebbianModulationPlasticity export Agent, ClassificationEnvironment, GreedyPolicy, reset! export LearningBlox -export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ExternalInput, PoissonSpikeTrain, DBS, ProtocolDBS, detect_transitions, compute_transition_times, compute_transition_values, get_protocol_duration +export CosineSource, CosineBlox, NoisyCosineBlox, PhaseBlox, ImageStimulus, ConstantInput, ExternalInput, SpikeSource, PoissonSpikeTrain, generate_spike_times +export DBS, ProtocolDBS, detect_transitions, compute_transition_times, compute_transition_values, get_protocol_duration export BandPassFilterBlox export OUBlox, OUCouplingBlox export phase_inter, phase_sin_blox, phase_cos_blox diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 81c7634f..5d38a806 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -127,10 +127,10 @@ end generate_discrete_callbacks(blox, ::Connector; t_block = missing) = [] -function generate_discrete_callbacks(blox::PoissonSpikeTrain, bc::Connector; t_block = missing) +function generate_discrete_callbacks(blox::AbstractSpikeSource, bc::Connector; t_block = missing) sa = spike_affects(bc) name_blox = namespaced_nameof(blox) - + if haskey(sa, name_blox) eqs = sa[name_blox] @@ -143,7 +143,7 @@ function generate_discrete_callbacks(blox::PoissonSpikeTrain, bc::Connector; t_b t_spikes = generate_spike_times(blox) t_spikes => to_vector(eq) end - + return cb end end diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index d907226b..a00fcd4c 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -92,9 +92,9 @@ end get_system(blox) = blox.system get_system(sys::AbstractODESystem) = sys -get_system(stim::PoissonSpikeTrain) = System(Equation[], t, [], []; name=stim.name) +get_system(stim::AbstractSpikeSource) = System(Equation[], t, [], []; name=stim.name) -function system(blox::AbstractBlox; simplify=true) +function system(blox::CompositeBlox; simplify=true) sys = get_system(blox) eqs = get_input_equations(blox; namespaced=false) @@ -103,6 +103,15 @@ function system(blox::AbstractBlox; simplify=true) return simplify ? structural_simplify(csys) : csys end + +function system(blox::AbstractBlox; simplify=true, kwargs...) + sys = get_system(blox) + eqs = get_input_equations(blox; namespaced=true) + csys = compose(System(eqs, t, [], []; name=namespaced_nameof(blox), kwargs...), sys) + + return simplify ? structural_simplify(csys) : csys +end + function get_namespaced_sys(blox) sys = get_system(blox) @@ -457,7 +466,7 @@ replace_refractory!(V, blox, sol::SciMLBase.AbstractSolution) = V function find_spikes(x::AbstractVector{T}; threshold=zero(T)) where {T} spike_idxs = argmaxima(x) peakheights!(spike_idxs, x[spike_idxs]; minheight = threshold) - + spikes = sparsevec(spike_idxs, ones(length(spike_idxs)), length(x)) return spikes @@ -486,7 +495,6 @@ function detect_spikes( end V = voltage_timeseries(blox, sol; ts) - spikes = find_spikes(V; threshold = thrs_value - tolerance) return spikes diff --git a/src/blox/neuron_models.jl b/src/blox/neuron_models.jl index f89db1f4..3785574f 100644 --- a/src/blox/neuron_models.jl +++ b/src/blox/neuron_models.jl @@ -795,7 +795,8 @@ struct QIFNeuron <: AbstractNeuronBlox D(G)~(-1/τ₂)*G + z, D(z)~(-1/τ₁)*z ] - ev = [V~θ] => [V~Vᵣₑₛ,z~G_syn] + + ev = [V~θ] => [V~Vᵣₑₛ,z~G_syn] sys = ODESystem(eqs, t, sts, p, continuous_events=[ev]; name=name) new(p, sys, namespace) diff --git a/src/blox/sources.jl b/src/blox/sources.jl index ed169d25..24496ee4 100644 --- a/src/blox/sources.jl +++ b/src/blox/sources.jl @@ -1,3 +1,19 @@ +abstract type AbstractSpikeSource <: StimulusBlox end + +struct ConstantInput <: StimulusBlox + namespace + system + + function ConstantInput(; name, namespace=nothing, I=1) + @variables u(t) [output=true, description="ext_input"] + @parameters I=I + eqs = [u ~ I] + sys = System(eqs, t, [u], [I]; name=name) + + new(namespace, sys) + end +end + # Simple input blox mutable struct ExternalInput <: StimulusBlox namespace @@ -156,7 +172,7 @@ end increment_pixel!(stim::ImageStimulus) = stim.current_pixel = mod(stim.current_pixel, stim.N_pixels) + 1 -struct PoissonSpikeTrain{N} <: StimulusBlox +struct PoissonSpikeTrain{N} <: AbstractSpikeSource name namespace N_trains