Skip to content

Commit

Permalink
Merge pull request #312 from Neuroblox/combined_callbacks_for_spike_w…
Browse files Browse the repository at this point in the history
…indow

add callbacks for spike_window resetting
  • Loading branch information
anandpathak31 authored Dec 1, 2023
2 parents 24cc6e5 + 63c7945 commit fa4c57b
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector; name, t_affect=mis
blox_syss = get_sys(g)

connection_eqs = get_equations_with_state_lhs(bc)
cbs = get_callbacks(bc, t_affect)
cbs = get_callbacks(g, bc, t_affect)

return compose(ODESystem(connection_eqs, t, [], params(bc); name, discrete_events = cbs), blox_syss)
end
Expand All @@ -207,7 +207,7 @@ function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}; na

connection_eqs = get_equations_with_state_lhs(bc)

cbs = get_callbacks(bc, t_affect)
cbs = get_callbacks(g, bc, t_affect)

return compose(ODESystem(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = cbs), blox_syss)
end
Expand Down
4 changes: 4 additions & 0 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ function compileparameterlist(;kwargs...)
return paramlist
end

"""
    get_exci_neurons(g::MetaDiGraph)

Collect the excitatory neuron blox from every blox stored in graph `g`,
recursing into composite components via the `get_exci_neurons` methods
defined on the individual blox types. Returns the concatenation of all
per-blox results.
"""
function get_exci_neurons(g::MetaDiGraph)
    # Pass the function directly instead of wrapping it in a redundant
    # anonymous closure (`x -> get_exci_neurons(x)`).
    return mapreduce(get_exci_neurons, vcat, get_blox(g))
end

"""
    get_exci_neurons(b::AbstractComponent)

Collect the excitatory neuron blox contained in composite component `b`
by recursing over its `parts`. Returns the concatenation of all
per-part results.
"""
function get_exci_neurons(b::AbstractComponent)
    # Pass the function directly instead of wrapping it in a redundant
    # anonymous closure (`x -> get_exci_neurons(x)`).
    return mapreduce(get_exci_neurons, vcat, b.parts)
end
Expand Down
13 changes: 11 additions & 2 deletions src/blox/connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,20 @@ get_equations_with_parameter_lhs(bc) = filter(eq -> isparameter(eq.lhs), bc.eqs)

get_equations_with_state_lhs(bc) = filter(eq -> !isparameter(eq.lhs), bc.eqs)

function get_callbacks(bc, t_affect=missing)
function get_callbacks(g, bc, t_affect=missing)
if !ismissing(t_affect)
cbs_params = t_affect => get_equations_with_parameter_lhs(bc)

return vcat(cbs_params, bc.events)
neurons_exci = get_exci_neurons(g)
eqs = Equation[]

for neurons in neurons_exci
nn = get_namespaced_sys(neurons)
push!(eqs,nn.spikes_window ~ 0)

end
cb = (t_affect + eps(float(t_affect))) => eqs
return vcat(cbs_params, bc.events, cb)
else
return bc.events
end
Expand Down
18 changes: 9 additions & 9 deletions src/blox/neuron_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,12 @@ struct HHNeuronExciBlox <: AbstractExciNeuronBlox
D(spikes_window) ~ spk_const*G_asymp(V,G_syn)
]

spike_reset_cb = [(t_spike_window + eps(float(t_spike_window))) => [spikes_window ~ 0]]
# spike_reset_cb = [(t_spike_window + eps(float(t_spike_window))) => [spikes_window ~ 0]]

sys = ODESystem(
eqs, t, sts, ps;
name = Symbol(name),discrete_events = spike_reset_cb
#name = Symbol(name)
#name = Symbol(name),discrete_events = spike_reset_cb
name = Symbol(name)
)

new(sys, spikes, namespace)
Expand Down Expand Up @@ -306,17 +306,17 @@ struct HHNeuronInhibBlox <: AbstractInhNeuronBlox
D(m)~ϕ*(αₘ(V)*(1-m)-βₘ(V)*m),
D(h)~ϕ*(αₕ(V)*(1-h)-βₕ(V)*h),
D(G)~(-1/τ₂)*G + z,
D(z)~(-1/τ₁)*z + G_asymp(V,G_syn),
D(spikes_cumulative) ~ spk_const*G_asymp(V,G_syn),
D(spikes_window) ~ spk_const*G_asymp(V,G_syn)
D(z)~(-1/τ₁)*z + G_asymp(V,G_syn)
#D(spikes_cumulative) ~ spk_const*G_asymp(V,G_syn),
#D(spikes_window) ~ spk_const*G_asymp(V,G_syn)
]

spike_reset_cb = [(t_spike_window + eps(float(t_spike_window))) => [spikes_window ~ 0]]
# spike_reset_cb = [(t_spike_window + eps(float(t_spike_window))) => [spikes_window ~ 0]]

sys = ODESystem(
eqs, t, sts, ps;
name = Symbol(name), discrete_events = spike_reset_cb
#name = Symbol(name)
# name = Symbol(name), discrete_events = spike_reset_cb
name = Symbol(name)
)

new(sys, namespace)
Expand Down
40 changes: 40 additions & 0 deletions src/blox/reinforcement_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,43 @@ function run_experiment!(agent::Agent, env::ClassificationEnvironment; kwargs...

agent.problem = prob
end

"""
    run_experiment_open_loop!(agent::Agent, env::ClassificationEnvironment; kwargs...)

Run `env.N_trials` trials of the classification experiment in open-loop
mode: the feedback signal is fixed to `1` on every trial instead of being
derived from the agent's selected action, while the learnable weights are
still updated after each trial via the agent's learning rules. Mutates
`agent.problem`, leaving it remade with the final weights and the last
trial's stimulus.

Keyword arguments are forwarded to `solve`; if `:alg` is present it is
used as the solver algorithm, otherwise `alg_hints = [:stiff]` is used.

NOTE(review): `action_selection` is read but never used below — presumably
intentional for the open-loop variant; confirm.
"""
function run_experiment_open_loop!(agent::Agent, env::ClassificationEnvironment; kwargs...)
    N_trials = env.N_trials
    t_trial = env.t_trial
    # Every trial integrates over the same time span.
    tspan = (0, t_trial)

    sys = get_sys(agent)
    prob = agent.problem
    prob = remake(prob; tspan)

    action_selection = agent.action_selection
    learning_rules = agent.learning_rules

    # Seed the mutable weight table from the system's default parameter values;
    # only parameters that have a learning rule attached are tracked.
    defs = ModelingToolkit.get_defaults(sys)
    weights = Dict{Num, Float64}()
    for w in keys(learning_rules)
        weights[w] = defs[w]
    end

    for _ in Base.OneTo(N_trials)
        if haskey(kwargs, :alg)
            # NOTE(review): :alg is passed positionally here AND re-splatted via
            # `kwargs...`, so solve also receives it as a keyword — confirm that
            # solve tolerates the duplicate.
            sol = solve(prob, kwargs[:alg]; kwargs...)
        else
            sol = solve(prob; alg_hints = [:stiff], kwargs...)
        end
        # Open loop: feedback is clamped to 1 regardless of the agent's output.
        feedback = 1

        # Apply each learning rule's weight gradient for this trial's solution.
        for (w, rule) in learning_rules
            w_val = weights[w]
            Δw = weight_gradient(rule, sol, w_val, feedback)
            weights[w] += Δw
        end

        # Advance to the next trial and rebuild the problem with the updated
        # weights and the new stimulus.
        increment_trial!(env)
        stim_params = get_trial_stimulus(env)
        # NOTE(review): the trial stimulus is supplied through `u0` — confirm
        # that `get_trial_stimulus` returns initial-state values rather than
        # parameter values here.
        prob = remake(prob; p = weights, u0 = stim_params)
    end

    agent.problem = prob
end

0 comments on commit fa4c57b

Please sign in to comment.