# traj_buffer.py — trajectory replay buffer (147 lines).
# (Web-scrape navigation text and line-number gutter removed.)
import numpy as np
import random
class TrajBuffer(object):
    """
    A buffer for observations, actions, rewards, mu's and dones values.

    Each buffer slot holds one rollout of ``n_steps`` transitions per parallel
    environment, so the buffer stores ``size // n_steps`` slots in total.

    :param env: (Gym environment) The environment to learn from
        (must expose ``num_envs`` and ``observation_space``)
    :param n_steps: (int) The number of steps to run for each environment
    :param size: (int) The buffer size in number of steps
    :param raw_pixels: (bool) If True, observations are raw uint8 images of
        shape (height, width, channels); otherwise flat float32 vectors.
    """

    def __init__(self, env, n_steps, size=50000, raw_pixels=False):
        self.n_env = env.num_envs
        self.n_steps = n_steps
        self.n_batch = self.n_env * self.n_steps
        # Each slot contains n_env * n_steps frames, thus total buffer is n_env * size frames
        self.size = size // self.n_steps
        self.raw_pixels = raw_pixels
        if self.raw_pixels:
            self.height, self.width, self.n_channels = env.observation_space.shape
            self.obs_dtype = np.uint8
        else:
            self.obs_dim = env.observation_space.shape[-1]
            self.obs_dtype = np.float32
        # Memory: allocated lazily on the first put() so the array shapes can
        # be inferred from the actual rollout arrays.
        self.enc_obs = None
        self.actions = None
        self.rewards = None
        self.mus = None
        self.dones = None
        # Ring-buffer indexes
        self.next_idx = 0
        self.num_in_buffer = 0

    def has_atleast(self, frames):
        """
        Check to see if the buffer has at least the asked number of frames

        :param frames: (int) The number of frames checked
        :return: (bool) number of frames in buffer >= number asked
        """
        # Frames per env, so total (n_env * frames) frames needed
        # Each buffer slot has n_env * n_steps frames
        return self.num_in_buffer >= (frames // self.n_steps)

    def can_sample(self):
        """
        Check if the buffer has at least one frame

        :return: (bool) if the buffer has at least one frame
        """
        return self.num_in_buffer > 0

    def decode(self, enc_obs):
        """
        Get the stacked frames of an observation

        :param enc_obs: ([float]) the encoded observation
        :return: ([float]) the decoded observation
        """
        # enc_obs has shape [n_envs, n_steps + 1, nh, nw, nc]
        # returns stacked obs of shape [n_env, (n_steps + 1), nh, nw, nc]
        n_env, n_steps = self.n_env, self.n_steps
        if self.raw_pixels:
            obs_dim = [self.height, self.width, self.n_channels]
        else:
            obs_dim = [self.obs_dim]

        # NOTE(review): the extra leading axis of length 1 looks like a vestige
        # of frame-stacking support (n_stack frames); with a single frame the
        # transpose/reshape below reduces to [n_env, n_steps + 1] + obs_dim.
        obs = np.zeros([1, n_steps + 1, n_env] + obs_dim, dtype=self.obs_dtype)
        # [n_steps + 1, n_env] + obs_dim
        x_var = np.reshape(enc_obs, [n_env, n_steps + 1] + obs_dim).swapaxes(1, 0)
        obs[-1, :] = x_var
        if self.raw_pixels:
            obs = obs.transpose((2, 1, 3, 4, 0, 5))
        else:
            obs = obs.transpose((2, 1, 3, 0))
        return np.reshape(obs, [n_env, (n_steps + 1)] + obs_dim[:-1] + [obs_dim[-1]])

    def put(self, enc_obs, actions, rewards, mus, dones):
        """
        Adds a rollout to the buffer

        :param enc_obs: ([float]) the encoded observation
        :param actions: ([float]) the actions
        :param rewards: ([float]) the rewards
        :param mus: ([float]) the policy probability for the actions
        :param dones: ([bool]) the episode-termination flags
        """
        # enc_obs [n_env, (n_steps + n_stack), nh, nw, nc]
        # actions, rewards, dones [n_env, n_steps]
        # mus [n_env, n_steps, n_act]
        if self.enc_obs is None:
            # Lazily allocate the backing arrays using the incoming shapes.
            self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=self.obs_dtype)
            self.actions = np.empty([self.size] + list(actions.shape), dtype=np.int32)
            self.rewards = np.empty([self.size] + list(rewards.shape), dtype=np.float32)
            self.mus = np.empty([self.size] + list(mus.shape), dtype=np.float32)
            # `bool`, not the removed `np.bool` alias (gone since NumPy 1.24)
            self.dones = np.empty([self.size] + list(dones.shape), dtype=bool)

        self.enc_obs[self.next_idx] = enc_obs
        self.actions[self.next_idx] = actions
        self.rewards[self.next_idx] = rewards
        self.mus[self.next_idx] = mus
        self.dones[self.next_idx] = dones

        # Advance the ring-buffer write head, overwriting oldest slots when full.
        self.next_idx = (self.next_idx + 1) % self.size
        self.num_in_buffer = min(self.size, self.num_in_buffer + 1)

    def take(self, arr, idx, envx):
        """
        Reads a frame from a list and index for the asked environment ids

        :param arr: (np.ndarray) the array that is read
        :param idx: ([int]) the idx that are read
        :param envx: ([int]) the idx for the environments
        :return: ([float]) the asked frames from the list
        """
        n_env = self.n_env
        out = np.empty(arr.shape[1:], dtype=arr.dtype)
        for i in range(n_env):
            # Slot idx[i], all steps, for environment envx[i]
            out[:, i] = arr[idx[i], :, envx[i]]
        return out

    def get(self):
        """
        Randomly sample one stored rollout slot per environment

        :return: ([float], [float], [float], [float], [bool])
            observations, actions, rewards, mus, dones
        """
        # returns
        # obs [n_env, (n_steps + 1), nh, nw, n_stack*nc]
        # actions, rewards, dones [n_env, n_steps]
        # mus [n_env, n_steps, n_act]
        n_env = self.n_env
        assert self.can_sample()

        # Sample exactly one id per env. If you sample across envs, then higher correlation in samples from same env.
        idx = np.random.randint(0, self.num_in_buffer, n_env)
        envx = np.arange(n_env)

        dones = self.take(self.dones, idx, envx)
        obs = self.take(self.enc_obs, idx, envx)
        actions = self.take(self.actions, idx, envx)
        rewards = self.take(self.rewards, idx, envx)
        mus = self.take(self.mus, idx, envx)
        return obs, actions, rewards, mus, dones