diff --git a/gym/decision_transformer/models/decision_transformer.py b/gym/decision_transformer/models/decision_transformer.py index f76f8f3b..2b63722e 100644 --- a/gym/decision_transformer/models/decision_transformer.py +++ b/gym/decision_transformer/models/decision_transformer.py @@ -134,7 +134,7 @@ def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwarg else: attention_mask = None - _, action_preds, return_preds = self.forward( + _, action_preds, _ = self.forward( states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs) return action_preds[0,-1]