From 9eb4cf07cb8114f852b366b1736f228ad1f07cfb Mon Sep 17 00:00:00 2001 From: Quentin18 Date: Tue, 24 Oct 2023 18:59:26 +0200 Subject: [PATCH] add tdl-small policy in evaluate --- scripts/evaluate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 68ffbc9..4d323cb 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -9,6 +9,7 @@ NTupleNetworkBasePolicy, NTupleNetworkQLearningPolicy, NTupleNetworkTDPolicy, + NTupleNetworkTDPolicySmall, ) @@ -20,7 +21,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--algo", help="RL Algorithm", - choices=["ql", "tdl"], + choices=["ql", "tdl", "tdl-small"], ) parser.add_argument( "--env", @@ -59,6 +60,7 @@ def make_policy(algo: str, trained_agent: str) -> NTupleNetworkBasePolicy: algo_policy_map = { "ql": NTupleNetworkQLearningPolicy, "tdl": NTupleNetworkTDPolicy, + "tdl-small": NTupleNetworkTDPolicySmall, } policy = algo_policy_map[algo] return policy.load(trained_agent)