-
Notifications
You must be signed in to change notification settings - Fork 0
/
replay_buffer.py
48 lines (42 loc) · 2.16 KB
/
replay_buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import torch
class ReplayBuffer:
    """Episode-granular on-policy buffer: accumulates `batch_size` whole
    episodes, then hands everything to the learner as torch tensors."""

    def __init__(self, args):
        self.N = args.N                          # number of agents
        self.obs_dim = args.obs_dim
        self.state_dim = args.state_dim
        self.episode_limit = args.episode_limit  # max steps per episode
        self.batch_size = args.batch_size        # episodes per training batch
        self.episode_num = 0                     # index of the episode being filled
        self.buffer = None
        self.reset_buffer()

    def reset_buffer(self):
        """Allocate fresh (uninitialized) arrays for every key and rewind the
        episode counter. Entries are expected to be overwritten before use."""
        limit, n = self.episode_limit, self.N
        trailing_dims = {
            'obs_n': (limit, n, self.obs_dim),
            's': (limit, self.state_dim),
            # one extra slot so the bootstrap value of the terminal state fits
            'v_n': (limit + 1, n),
            'a_n': (limit, n),
            'a_logprob_n': (limit, n),
            'r_n': (limit, n),
            'done_n': (limit, n),
        }
        self.buffer = {key: np.empty((self.batch_size,) + dims)
                       for key, dims in trailing_dims.items()}
        self.episode_num = 0

    def store_transition(self, episode_step, obs_n, s, v_n, a_n, a_logprob_n, r_n, done_n):
        """Record one timestep of the episode currently being filled."""
        ep = self.episode_num
        store = self.buffer
        store['obs_n'][ep][episode_step] = obs_n
        store['s'][ep][episode_step] = s
        store['v_n'][ep][episode_step] = v_n
        store['a_n'][ep][episode_step] = a_n
        store['a_logprob_n'][ep][episode_step] = a_logprob_n
        store['r_n'][ep][episode_step] = r_n
        store['done_n'][ep][episode_step] = done_n

    def store_last_value(self, episode_step, v_n):
        """Write the bootstrap value for the terminal state and advance to the
        next episode slot."""
        self.buffer['v_n'][self.episode_num][episode_step] = v_n
        self.episode_num += 1

    def get_training_data(self):
        """Return the whole buffer as torch tensors: actions as long (they are
        discrete indices), everything else as float32."""
        return {key: torch.tensor(data,
                                  dtype=torch.long if key == 'a_n' else torch.float32)
                for key, data in self.buffer.items()}