-
Notifications
You must be signed in to change notification settings - Fork 0
/
Train.py
36 lines (29 loc) · 1.45 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from DQN import *
# Training entry point: builds the Q and target estimators, then iterates
# deep_q_learning and prints the latest episode reward on every iteration.
# NOTE(review): `tf`, `os`, `env`, `Estimator`, `StateProcessor` and
# `deep_q_learning` all arrive via `from DQN import *` -- confirm that DQN
# exports each of them.
tf.reset_default_graph()

# Where we save our checkpoints and graphs
experiment_dir = os.path.abspath("./experiments/{}".format(env.spec.id))

# Create a global step variable
global_step = tf.Variable(0, name='global_step', trainable=False)

# Create estimators
q_estimator = Estimator(scope="q", summaries_dir=experiment_dir)
target_estimator = Estimator(scope="target_q")

# State processor
state_processor = StateProcessor()

# tf.initialize_all_variables() is deprecated in favour of
# tf.global_variables_initializer(). Prefer the new name and fall back to the
# old one only on TensorFlow releases that predate it.
init_variables = getattr(tf, "global_variables_initializer", None)
if init_variables is None:
    init_variables = tf.initialize_all_variables

with tf.Session() as sess:
    # The init op is built here, after every variable above has been created,
    # so that it covers all of them.
    sess.run(init_variables())

    # deep_q_learning is consumed as an iterator of (t, stats) pairs.
    for t, stats in deep_q_learning(sess,
                                    env,
                                    q_estimator=q_estimator,
                                    target_estimator=target_estimator,
                                    state_processor=state_processor,
                                    experiment_dir=experiment_dir,
                                    num_episodes=10000,
                                    replay_memory_size=500000,
                                    replay_memory_init_size=50000,
                                    update_target_estimator_every=10000,
                                    epsilon_start=1.0,
                                    epsilon_end=0.1,
                                    epsilon_decay_steps=50000,
                                    discount_factor=0.99,
                                    batch_size=32):
        # Report the most recent entry of the running episode-reward record.
        print("\nEpisode Reward: {}".format(stats.episode_rewards[-1]))