# run_reacher_v2.py (forked from medipixel/rl_algorithms)
# -*- coding: utf-8 -*-
"""Train or test algorithms on Reacher-v2 of Mujoco.
- Author: Kyunghwan Kim
- Contact: [email protected]
"""
import argparse
import datetime
import gym
from rl_algorithms import build_agent
import rl_algorithms.common.env.utils as env_utils
import rl_algorithms.common.helper_functions as common_utils
from rl_algorithms.utils import YamlConfig
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the Reacher-v2 run script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before. Passing an
            explicit list lets the parser be driven programmatically.

    Returns:
        The parsed options as an ``argparse.Namespace``.
    """
    # configurations
    parser = argparse.ArgumentParser(description="Pytorch RL rl_algorithms")
    parser.add_argument(
        "--seed", type=int, default=777, help="random seed for reproducibility"
    )
    parser.add_argument("--algo", type=str, default="ddpg", help="choose an algorithm")
    parser.add_argument(
        "--cfg-path",
        type=str,
        default="./configs/reacher_v2/ddpg.yaml",
        help="config path",
    )
    parser.add_argument(
        "--integration-test",
        dest="integration_test",
        action="store_true",
        help="for integration test",
    )
    parser.add_argument(
        "--test", dest="test", action="store_true", help="test mode (no training)"
    )
    parser.add_argument(
        "--load-from",
        type=str,
        default=None,
        help="load the saved model and optimizer at the beginning",
    )
    # store_false: rendering is on by default and this flag switches it off.
    parser.add_argument(
        "--off-render", dest="render", action="store_false", help="turn off rendering"
    )
    parser.add_argument(
        "--render-after",
        type=int,
        default=0,
        help="start rendering after the input number of episode",
    )
    parser.add_argument(
        "--log", dest="log", action="store_true", help="turn on logging"
    )
    parser.add_argument(
        "--save-period", type=int, default=200, help="save model period"
    )
    parser.add_argument(
        "--episode-num", type=int, default=20000, help="total episode num"
    )
    parser.add_argument(
        "--max-episode-steps", type=int, default=-1, help="max episode step"
    )
    parser.add_argument(
        "--interim-test-num",
        type=int,
        default=10,
        help="number of test during training",
    )

    return parser.parse_args(argv)
def main():
    """Run a training or test session on Reacher-v2 from command-line options."""
    args = parse_args()

    # Environment setup: set_env hands back the env to use together with the
    # max episode step count that is later passed to the agent.
    env = gym.make("Reacher-v2")
    env, max_episode_steps = env_utils.set_env(env, args.max_episode_steps)

    # Seed before anything else is built.
    common_utils.set_random_seed(args.seed, env)

    # Timestamp of this run, passed to the agent through log_cfg.
    timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = YamlConfig({"agent": args.cfg_path}).get_config_dict()
    if args.integration_test:
        # Integration tests run on a simplified experiment configuration.
        cfg = common_utils.set_cfg_for_intergration_test(cfg)

    env_info = {
        "name": env.spec.id,
        "observation_space": env.observation_space,
        "action_space": env.action_space,
        "is_atari": False,
    }
    log_cfg = {
        "agent": cfg.agent.type,
        "curr_time": timestamp,
        "cfg_path": args.cfg_path,
    }
    build_args = {
        "env": env,
        "env_info": env_info,
        "log_cfg": log_cfg,
        "is_test": args.test,
        "load_from": args.load_from,
        "is_render": args.render,
        "render_after": args.render_after,
        "is_log": args.log,
        "save_period": args.save_period,
        "episode_num": args.episode_num,
        "max_episode_steps": max_episode_steps,
        "interim_test_num": args.interim_test_num,
    }

    agent = build_agent(cfg.agent, build_args)

    if args.test:
        agent.test()
    else:
        agent.train()
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()