-
Notifications
You must be signed in to change notification settings - Fork 0
/
decision_transformer.py
232 lines (191 loc) · 10.2 KB
/
decision_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
# Decision Transformer module adapted from the official Decision Transformer code
# https://github.com/kzl/decision-transformer
# @article{chen2021decisiontransformer,
# title={Decision Transformer: Reinforcement Learning via Sequence Modeling},
# author={Lili Chen and Kevin Lu and Aravind Rajeswaran and Kimin Lee and Aditya Grover and Michael Laskin and Pieter Abbeel and Aravind Srinivas and Igor Mordatch},
# journal={arXiv preprint arXiv:2106.01345},
# year={2021}
# }
# MIT License
#
# Copyright (c) 2021 Decision Transformer (Decision Transformer: Reinforcement Learning via Sequence Modeling) Authors (https://arxiv.org/abs/2106.01345)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import numpy as np
import torch
import torch.nn as nn
import transformers
from modules.model import TrajectoryModel
from modules.trajectory_gpt2 import GPT2Model
class DecisionTransformer(TrajectoryModel):
    """
    This model uses GPT to model (Return_1, state_1, action_1, Return_2, state_2, ...)

    Each timestep contributes three tokens (return-to-go, state, action). Every
    modality is embedded by its own head, a learned timestep embedding is added
    to all three, and the interleaved sequence is fed to a GPT-2 backbone. Only
    actions are predicted, from the hidden state at each state token.

    Args:
        state_dim: Size of a state vector; forwarded to ``TrajectoryModel``.
        act_dim: Size of an action vector; forwarded to ``TrajectoryModel``.
        hidden_size: Embedding width, also used as GPT-2 ``n_embd``.
        max_length: Context length used by :meth:`get_action` for truncation
            and padding; forwarded to ``TrajectoryModel``. ``None`` disables both.
        max_ep_len: Number of rows in the timestep embedding table, so timestep
            indices must be smaller than this value.
        action_tanh: If true, a ``Tanh`` is appended to the action head.
        state_encoder: ``None`` for a single linear layer, ``'mlp'`` for
            ``MLP_encoder`` or ``'cnn'`` for ``CNN_encoder``.
        in_shape: Image shape passed to ``CNN_encoder``; only read when
            ``state_encoder == 'cnn'``.
        **kwargs: Forwarded unchanged to ``transformers.GPT2Config``.

    Raises:
        NotImplementedError: If ``state_encoder`` is not one of the values above.

    NOTE(review): ``self.state_dim``, ``self.act_dim`` and ``self.max_length``
    are assumed to be set by ``TrajectoryModel.__init__`` -- confirm there.
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            hidden_size,
            max_length=None,
            max_ep_len=4096,
            action_tanh=True,
            state_encoder=None,
            in_shape=None,
            **kwargs
    ):
        super().__init__(state_dim, act_dim, max_length=max_length)

        self.hidden_size = hidden_size
        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            **kwargs
        )

        # note: the only difference between this GPT2Model and the default Huggingface version
        # is that the positional embeddings are removed (since we'll add those ourselves)
        self.transformer = GPT2Model(config)

        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_return = torch.nn.Linear(1, hidden_size)
        if state_encoder is None:
            self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        elif state_encoder == 'mlp':
            self.embed_state = MLP_encoder(state_dim, hidden_size, hidden_size)
        elif state_encoder == 'cnn':
            self.embed_state = CNN_encoder(in_shape, hidden_size, dropout=0.1)
        else:
            raise NotImplementedError("encoder not implemented")
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)

        # note: we don't predict states or returns for the paper
        self.predict_action = nn.Sequential(
            *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
        )

    def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
        """
        Predict an action for every timestep of a batch of trajectories.

        Args:
            states: Tensor whose first two dimensions are (batch, seq_length);
                the remainder is whatever ``self.embed_state`` consumes.
            actions: Tensor of shape (batch, seq_length, act_dim).
            rewards: Unused by this model; kept for interface compatibility.
            returns_to_go: Tensor of shape (batch, seq_length, 1).
            timesteps: Integer tensor of shape (batch, seq_length) indexing the
                timestep embedding table.
            attention_mask: Optional (batch, seq_length) mask, 1 where a
                timestep may be attended to and 0 where it may not. Defaults
                to all ones.

        Returns:
            ``(None, action_preds, None)`` where ``action_preds`` has shape
            (batch, seq_length, act_dim). State and return predictions are not
            produced.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # attention mask for GPT: 1 if can be attended to, 0 if not.
            # Created on the inputs' device so the default also works when the
            # model and data live on an accelerator.
            attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=states.device
            )

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        returns_embeddings = self.embed_return(returns_to_go)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        returns_embeddings = returns_embeddings + time_embeddings

        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        stacked_inputs = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask, attention_mask), dim=1
        ).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs['last_hidden_state']

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        # predict next action given state
        action_preds = self.predict_action(x[:, 1])

        return None, action_preds, None

    def _left_pad(self, tensor, dtype):
        """Zero-pad ``tensor`` on the left of dim 1 up to ``self.max_length`` and cast to ``dtype``."""
        pad_shape = (tensor.shape[0], self.max_length - tensor.shape[1], *tensor.shape[2:])
        # Padding uses the tensor's own dtype and device so the concatenation
        # never depends on implicit type promotion (matters for integer timesteps).
        padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([padding, tensor], dim=1).to(dtype=dtype)

    def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
        """
        Return the predicted action for the most recent timestep of one trajectory.

        The inputs are reshaped to a batch of one. When ``self.max_length`` is
        set, only the last ``max_length`` timesteps are kept, shorter histories
        are left-padded with zeros, and the padding is masked out.

        Args:
            states: Trajectory states, reshaped to (1, -1, state_dim).
            actions: Trajectory actions, reshaped to (1, -1, act_dim).
            rewards: Ignored; ``None`` is passed to :meth:`forward`.
            returns_to_go: Returns-to-go, reshaped to (1, -1, 1).
            timesteps: Timestep indices, reshaped to (1, -1).
            **kwargs: Forwarded to :meth:`forward`, which accepts no keyword
                arguments beyond its named parameters.

        Returns:
            The last row of the action predictions, of shape (act_dim,).
        """
        # we don't care about the past rewards in this model

        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            states = states[:, -self.max_length:]
            actions = actions[:, -self.max_length:]
            returns_to_go = returns_to_go[:, -self.max_length:]
            timesteps = timesteps[:, -self.max_length:]

            # pad all tokens to sequence length; zeros mark the padded prefix
            valid_len = states.shape[1]
            attention_mask = torch.cat([
                torch.zeros(self.max_length - valid_len, dtype=torch.long, device=states.device),
                torch.ones(valid_len, dtype=torch.long, device=states.device),
            ]).reshape(1, -1)

            states = self._left_pad(states, torch.float32)
            actions = self._left_pad(actions, torch.float32)
            returns_to_go = self._left_pad(returns_to_go, torch.float32)
            timesteps = self._left_pad(timesteps, torch.long)
        else:
            attention_mask = None

        _, action_preds, _ = self.forward(
            states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)

        return action_preds[0, -1]
class MLP_encoder(torch.nn.Module):
    """
    Three-layer perceptron mapping ``in_dim`` features to ``out_dim`` features.

    The network is two hidden blocks (Linear -> ReLU -> Dropout) of width
    ``hidden_size`` followed by a final linear projection. It is applied to
    the last dimension of the input.
    """

    def __init__(self, in_dim, out_dim, hidden_size=256, dropout=0.1):
        super().__init__()
        self.in_dim, self.out_dim, self.hidden_size = in_dim, out_dim, hidden_size

        # Build the hidden blocks from consecutive width pairs, then append
        # the output projection.
        widths = (in_dim, hidden_size, hidden_size)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)))
        layers.append(nn.Linear(hidden_size, out_dim))

        # The attribute name is part of the state-dict keys, so it stays as is.
        self.predict_state = nn.Sequential(*layers)

    def forward(self, x):
        """Return the encoding of ``x``; only the last dimension changes size."""
        encoded = self.predict_state(x)
        return encoded
class CNN_encoder(torch.nn.Module):
    """
    Convolutional encoder for flattened image observations.

    take input as 1D vector and reshape to in_shape image shape
    e.g. in_shape = (4, 128, 64) for 4-stacked grayscale images
    output is a 1D vector of length out_dim

    Args:
        in_shape: Three-element image shape whose first entry is the channel
            count; each flattened observation is reshaped to it.
        out_dim: Length of the encoding produced for every observation.
        dropout: Dropout probability used after the first two conv stages.
    """

    def __init__(self, in_shape, out_dim, dropout=0.1):
        super().__init__()
        self.in_shape = in_shape
        self.out_dim = out_dim

        # Only the channel count is needed, but unpacking three values keeps
        # the early failure for an in_shape that is not three-dimensional.
        channel, _, _ = in_shape
        self.cnn = nn.Sequential(
            nn.Conv2d(channel, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(dropout),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Probe the conv stack once to find the flattened feature size. Only
        # the shape is used, so no autograd graph needs to be recorded.
        with torch.no_grad():
            dummy_cnn_output = self.cnn(torch.zeros(1, *in_shape))
        print("CNN output shape: ", dummy_cnn_output.shape, "linearly transformed to", out_dim)
        self.linear = nn.Linear(dummy_cnn_output.shape[1], out_dim)

    def forward(self, x):
        """
        Encode a batch of observation sequences.

        Args:
            x: Tensor whose first two dimensions are (batch, context_length)
                and whose remaining elements per step reshape to ``in_shape``.

        Returns:
            Tensor of shape (batch, context_length, out_dim).
        """
        batch_size = x.shape[0]
        context_length = x.shape[1]
        # Fold batch and context together so every frame goes through the CNN
        # as an independent image.
        x = x.reshape(-1, *self.in_shape)
        x = self.cnn(x)
        x = self.linear(x)
        return x.view(batch_size, context_length, -1)