Skip to content

Commit

Permalink
improve RL tests
Browse files Browse the repository at this point in the history
  • Loading branch information
harisorgn committed Oct 12, 2023
1 parent 87f62bc commit 90ea5b7
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion test/reinforcement_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,17 @@ add_edge!(g, d[STR_L], d[AS])
add_edge!(g, d[STR_R], d[AS])

agent = Agent(g; name=:ag)
ps = parameters(agent.odesystem)
init_params = agent.problem.p
map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps))
idxs_weight = findall(x -> occursin("w_", String(Symbol(x))), ps)
idxs_other_params = setdiff(eachindex(ps), idxs_weight)

env = ClassificationEnvironment(stim; name=:env, namespace=global_ns)
run_experiment!(agent, env; alg=QNDF(), reltol=1e-9,abstol=1e-9)

final_params = agent.problem.p
@test any(init_params .!= final_params)
# At least some weights need to be different.
@test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]])
# All non-weight parameters need to be the same.
@test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]])

0 comments on commit 90ea5b7

Please sign in to comment.