Skip to content

Commit

Permalink
add tdl-small policy in evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin18 committed Oct 24, 2023
1 parent 2b21b1a commit 9eb4cf0
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
NTupleNetworkBasePolicy,
NTupleNetworkQLearningPolicy,
NTupleNetworkTDPolicy,
NTupleNetworkTDPolicySmall,
)


Expand All @@ -20,7 +21,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--algo",
help="RL Algorithm",
choices=["ql", "tdl"],
choices=["ql", "tdl", "tdl-small"],
)
parser.add_argument(
"--env",
Expand Down Expand Up @@ -59,6 +60,7 @@ def make_policy(algo: str, trained_agent: str) -> NTupleNetworkBasePolicy:
algo_policy_map = {
"ql": NTupleNetworkQLearningPolicy,
"tdl": NTupleNetworkTDPolicy,
"tdl-small": NTupleNetworkTDPolicySmall,
}
policy = algo_policy_map[algo]
return policy.load(trained_agent)
Expand Down

0 comments on commit 9eb4cf0

Please sign in to comment.