-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
72 lines (63 loc) · 2.39 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import gym, helpers, argparse, os
from models import fully_connected
import tensorflow as tf
import environments
import matplotlib.pyplot as plt
import datetime
def main(args):
    """Build an environment and model, then train and checkpoint the agent.

    The environment is a Gym env when ``args.env_name`` is ``'MountainCar-v0'``;
    otherwise it comes from the local ``environments`` module. A
    ``fully_connected.Model`` (two hidden layers of 10 units), a
    ``helpers.Memory(1000)`` and a ``helpers.GameRunner`` are created. Then
    ``num_episodes`` episodes are run inside a CPU-only TF1 session.

    Args:
        args: Namespace with ``batch_size``, ``max_epsilon``, ``min_epsilon``,
            ``decay``, ``gamma`` and ``env_name`` (the last five are forwarded
            to ``helpers.GameRunner`` / env construction). Two optional
            attributes are also read:
            ``num_episodes`` (default 300) is the number of episodes to run.
            ``save_every`` (default 50) controls how often progress is
            printed and a checkpoint is written.

    Checkpoints go to ``./saved_models/<YYYYmmdd_HHMMSS>/model_<NNNNN>.ckpt``,
    where ``NNNNN`` is the 0-based index of the episode just completed. One is
    written every ``save_every`` episodes and always after the final episode,
    so the fully trained weights are never lost.

    Raises:
        ValueError: if ``save_every`` is smaller than 1.
    """
    batch_size = args.batch_size
    max_epsilon = args.max_epsilon
    min_epsilon = args.min_epsilon
    decay = args.decay
    gamma = args.gamma
    env_name = args.env_name
    # Optional knobs; fall back to the previously hard-coded values so older
    # argument namespaces keep working.
    num_episodes = getattr(args, 'num_episodes', 300)
    save_every = getattr(args, 'save_every', 50)
    if save_every < 1:
        raise ValueError('save_every must be >= 1, got {}'.format(save_every))

    if env_name in ('MountainCar-v0',):
        env = gym.make(env_name)
        num_states = env.env.observation_space.shape[0]
        num_actions = env.env.action_space.n
    else:
        env = environments.make(env_name)
        num_states = env.get_num_states()
        num_actions = env.get_num_actions()

    try:
        model = fully_connected.Model(num_states, num_actions, batch_size,
                                      layer_sizes=[10, 10])
        mem = helpers.Memory(1000)
        # Hide all GPUs so the session runs on CPU only.
        config = tf.ConfigProto(device_count={'GPU': 0})
        # Created after the model so it captures the model's variables.
        saver = tf.train.Saver()
        model_save_dir = os.path.join(
            '.', 'saved_models',
            datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
        os.makedirs(model_save_dir, exist_ok=True)

        with tf.Session(config=config) as sess:
            sess.run(model.var_init)
            gr = helpers.GameRunner(sess, model, env, mem, max_epsilon,
                                    min_epsilon, decay, gamma)
            last_episode = num_episodes - 1
            for episode in range(num_episodes):
                periodic = episode % save_every == 0
                if periodic:
                    print('Episode {} of {}'.format(episode + 1, num_episodes))
                # Every episode is rendered (both branches of the original
                # loop set this before each run).
                gr._render = True
                gr.run()
                # Also save after the last episode so the final weights persist.
                if periodic or episode == last_episode:
                    save_path = saver.save(
                        sess,
                        os.path.join(model_save_dir,
                                     "model_{:05d}.ckpt".format(episode)))
                    print("Model saved in path: %s" % save_path)
            # Optional diagnostics (plt is imported at module level):
            # plt.plot(gr.reward_store)
            # plt.show()
            # plt.close("all")
            # plt.plot(gr.max_x_store)
            # plt.show()
    finally:
        # Release the environment (e.g. a render window). Custom environments
        # may not define close(), so only call it when present.
        close = getattr(env, 'close', None)
        if callable(close):
            close()
def parse_args(argv=None):
    """Parse the command-line options for the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list makes the
            parser usable from tests or other code.

    Returns:
        argparse.Namespace with ``batch_size``, ``max_epsilon``,
        ``min_epsilon``, ``decay``, ``gamma``, ``env_name``, ``num_episodes``
        and ``save_every``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=3, type=int,
                        help='batch size passed to the model')
    # Float literal: argparse only applies `type` to string defaults, so an
    # int default would otherwise leak through as an int.
    parser.add_argument('--max_epsilon', default=1.0, type=float,
                        help='upper epsilon bound passed to GameRunner')
    parser.add_argument('--min_epsilon', default=0.01, type=float,
                        help='lower epsilon bound passed to GameRunner')
    parser.add_argument('--decay', default=0.001, type=float,
                        help='decay value passed to GameRunner')
    parser.add_argument('--gamma', default=0.99, type=float,
                        help='gamma value passed to GameRunner')
    parser.add_argument('--env_name', default='MountainCar-v0', type=str,
                        help="environment name; 'MountainCar-v0' uses gym, "
                             'anything else the local environments module')
    parser.add_argument('--num_episodes', default=300, type=int,
                        help='number of training episodes to run')
    parser.add_argument('--save_every', default=50, type=int,
                        help='print progress and checkpoint every N episodes')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: read the CLI flags, then start training.
    args = parse_args()
    main(args)