train.py
# for parallel computing
import ray
# the master agent that runs the training (from the local es module)
from es import Master

# number of CPUs to use
num_cpus = 3
# initialize Ray with the given number of CPUs
ray.init(num_cpus=num_cpus)

# configuration for the training run
config = {'sigma': 0.1, 'pop_size': 50, 'l2_coeff': 0.00005,
          'num_workers': num_cpus, 'max_steps': 200, 'learning_rate': 0.001,
          'envName': 'CartPole-v0', 'early_stop_reward': 200}
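# config key meanings (inferred from standard evolution-strategies setups;
# the authoritative definitions live in the es module):
#   sigma             - stddev of the parameter noise used to perturb the policy
#   pop_size          - number of perturbed candidates evaluated per iteration
#   l2_coeff          - L2 regularization coefficient on the policy weights
#   num_workers       - number of parallel Ray workers
#   max_steps         - maximum environment steps per episode
#   learning_rate     - step size for the parameter update
#   envName           - gym environment id to train on
#   early_stop_reward - average reward at which training counts as solved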
# run the training procedure once
def train():
    # define the agent
    agent = Master(config)
    # run the agent for a total of num_iters iterations
    agent.train(num_iters=1000, print_step=10)
    # run the agent on the environment, render, and average the reward
    avg_reward = agent.play(episodes=10)
    # filename to save the weights in
    filename = config['envName'] + '_weights_' + str(avg_reward)
    # save the weights
    agent.save(filename)
# retrain from scratch until the agent reaches the early-stop reward
def trainR():
    avg_reward = 0.0
    agent = None
    # use < rather than != so a float average cannot skip past the target
    while avg_reward < config['early_stop_reward']:
        # release the previous agent before retraining
        del agent
        # define the agent
        agent = Master(config)
        # run the agent for a total of num_iters iterations
        agent.train(num_iters=1000, print_step=10)
        # run the agent on the environment, render, and average the reward
        avg_reward = agent.play(episodes=10)
        # filename to save the weights in
        filename = config['envName'] + '_weights_' + str(avg_reward)
        # save the weights
        agent.save(filename)
if __name__ == '__main__':
trainR()
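
# usage: python train.py
# (requires ray and the local es module; the environment id suggests gym is also needed)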