train.py
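"""Training script for a DQN-style agent on Atari Ms. Pac-Man (gym's MsPacman-v4).

The script first fills the agent's replay buffer with experience collected before
any learning happens, then runs the main training loop, syncing the target network
every `agent.update_rate` steps and replaying a minibatch after each episode.

Assumptions worth flagging: storing a TD error alongside each transition suggests a
prioritized replay buffer, and Agent(state_size, action_size, 5) suggests a 5-step
return, but the actual behaviour is defined in agent.py, not here. The script is
written against the classic gym API (env.reset() returns an observation and
env.step() returns four values), i.e. gym < 0.26 with the Atari ROMs installed.
"""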
import random
import gym
import numpy as np
import cv2
from collections import deque
from tqdm import tqdm
from agent import Agent
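
# Interface this script assumes from agent.py (inferred from the calls below;
# the real signatures live in agent.py):
#   Agent(state_size, action_size, n_step)
#   agent.buffer                 -- replay buffer, supports len()
#   agent.update_rate            -- number of steps between target-network updates
#   agent.act(state)             -- choose an action for a (1, 88, 80, 1) input
#   agent.calculate_td_error(state, action, reward, next_state, done)
#   agent.store(state, action, reward, next_state, done, td_error)
#   agent.update_target_model()
#   agent.replay(batch_size)
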
def preprocess(frame):
    # Got some ideas from https://github.com/ageron/tiny-dqn
    mspacman_color = np.array([210, 164, 74]).mean()  # mean channel value of Ms. Pac-Man's sprite
    kernel = np.ones((3, 3), np.uint8)
    dilation = cv2.dilate(frame, kernel, iterations=2)  # note: the dilated frame is never used below
    img = frame[1:176:2, ::2]  # crop the playfield and downsample to 88x80
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Intended to blank out Ms. Pac-Man's sprite; note that cvtColor yields weighted-luma uint8
    # values, so this exact-equality mask may never match (tiny-dqn compares against
    # img.mean(axis=2) instead).
    img[img == mspacman_color] = 0
    img = cv2.normalize(img, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
    return img.reshape(88, 80, 1)

env = gym.make('MsPacman-v4')  # v4: no sticky actions; the plain id applies a stochastic 2-4 frame skip
state_size = (88, 80, 1)
action_size = 5  # restrict the agent to NOOP plus the four directions
agent = Agent(state_size, action_size, 5)  # 5-step return

episodes = 20000
batch_size = 32
total_time = 0
all_rewards = 0
done = False

# Fill the replay buffer with initial experience before training starts
while len(agent.buffer) < 10000:
    state = preprocess(env.reset())
    frame_stack = deque(maxlen=4)  # keep the last 4 frames and feed their mean instead of stacking them
    frame_stack.append(state)
    for skip in range(90):  # skip the game's idle intro
        env.step(0)
    for time in range(10000):
        state = sum(frame_stack) / len(frame_stack)
        action = agent.act(np.expand_dims(state.reshape(88, 80, 1), axis=0))
        next_state, reward, done, _ = env.step(action)
        next_state = preprocess(next_state)
        frame_stack.append(next_state)
        next_state = sum(frame_stack) / len(frame_stack)
        td_error = agent.calculate_td_error(state, action, reward, next_state, done)
        agent.store(state, action, reward, next_state, done, td_error)
        state = next_state
        if done:
            break
print("buffer initialized")

# Main training loop
for e in tqdm(range(episodes)):
    total_reward = 0
    state = preprocess(env.reset())
    frame_stack = deque(maxlen=4)
    frame_stack.append(state)
    for skip in range(90):  # skip the game's idle intro
        env.step(0)
    for time in range(20000):
        total_time += 1
        if total_time % agent.update_rate == 0:
            agent.update_target_model()
        state = sum(frame_stack) / len(frame_stack)
        action = agent.act(np.expand_dims(state.reshape(88, 80, 1), axis=0))
        next_state, reward, done, _ = env.step(action)
        next_state = preprocess(next_state)
        frame_stack.append(next_state)
        next_state = sum(frame_stack) / len(frame_stack)
        td_error = agent.calculate_td_error(state, action, reward, next_state, done)
        agent.store(state, action, reward, next_state, done, td_error)
        state = next_state
        total_reward += reward
        if done:
            all_rewards += total_reward
            print("episode: {}/{}, reward: {}, avg reward: {}"
                  .format(e + 1, episodes, total_reward, all_rewards / (e + 1)))
            break
    agent.replay(batch_size)
    if (e + 1) % 500 == 0:
        print("model saved on episode", e + 1)
        # agent.save("")