From 90ea5b7a9f1eeb7ed75d0bdece658ef6730956bb Mon Sep 17 00:00:00 2001 From: Haris Orgn Date: Thu, 12 Oct 2023 15:35:04 +0300 Subject: [PATCH] improve RL tests --- test/reinforcement_learning.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/reinforcement_learning.jl b/test/reinforcement_learning.jl index 840c8c33..5ca61dde 100644 --- a/test/reinforcement_learning.jl +++ b/test/reinforcement_learning.jl @@ -44,10 +44,17 @@ add_edge!(g, d[STR_L], d[AS]) add_edge!(g, d[STR_R], d[AS]) agent = Agent(g; name=:ag) +ps = parameters(agent.odesystem) init_params = agent.problem.p +map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps)) +idxs_weight = findall(x -> occursin("w_", String(Symbol(x))), ps) +idxs_other_params = setdiff(eachindex(ps), idxs_weight) env = ClassificationEnvironment(stim; name=:env, namespace=global_ns) run_experiment!(agent, env; alg=QNDF(), reltol=1e-9,abstol=1e-9) final_params = agent.problem.p -@test any(init_params .!= final_params) \ No newline at end of file +# At least some weights need to be different. +@test any(init_params[map_idxs[idxs_weight]] .!= final_params[map_idxs[idxs_weight]]) +# All non-weight parameters need to be the same. +@test all(init_params[map_idxs[idxs_other_params]] .== final_params[map_idxs[idxs_other_params]])