forked from google-research/robotics_transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sequence_agent.py
171 lines (152 loc) · 6.52 KB
/
sequence_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequence policy and agent that directly output actions via actor network.
These classes are not intended to change: they are generic enough for any
all-neural, actor-based agent and policy. New features are intended to be
implemented in `actor_network` and `loss_fn`.
"""
from typing import Optional, Type
from absl import logging
import tensorflow as tf
from tf_agents.agents import data_converter
from tf_agents.agents import tf_agent
from tf_agents.networks import network
from tf_agents.policies import actor_policy
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step as ts
from tf_agents.typing import types
from tf_agents.utils import nest_utils
class SequencePolicy(actor_policy.ActorPolicy):
  """Actor policy whose actions are produced directly by its actor network.

  Apart from acting, this policy is a thin pass-through to hooks on the
  wrapped actor network: `set_actions`, `get_actor_loss` and `get_aux_info`
  are forwarded verbatim, so the network must implement them.
  """

  def __init__(self, **kwargs):
    """Initializes the policy; every keyword argument goes to `ActorPolicy`."""
    # Not read anywhere in this class; actions are stored on the actor network
    # through `set_actions`.
    self._actions = None
    super().__init__(**kwargs)

  def set_actions(self, actions):
    """Hands `actions` straight to the actor network.

    `SequenceAgent._loss` calls this with the experience actions before
    computing the loss and with `None` afterwards.
    """
    self._actor_network.set_actions(actions)

  def get_actor_loss(self):
    """Returns the actor loss as reported by the actor network."""
    return self._actor_network.get_actor_loss()

  def get_aux_info(self):
    """Returns auxiliary info as reported by the actor network."""
    return self._actor_network.get_aux_info()

  def set_training(self, training):
    """Overwrites `self._training`.

    Presumably this is the same flag that the `training` constructor argument
    of `ActorPolicy` sets; confirm against the TF-Agents version in use.
    """
    self._training = training

  def _action(self,
              time_step: ts.TimeStep,
              policy_state: types.NestedTensor,
              seed: Optional[types.Seed] = None) -> policy_step.PolicyStep:
    """Runs the actor network on the observation and returns its output.

    The network output is used as the action as-is, so `seed` is ignored and
    the returned step carries an empty `info`.
    """
    del seed  # Unused: nothing is sampled here.
    network_output, next_state = self._apply_actor_network(
        time_step.observation,
        step_type=time_step.step_type,
        policy_state=policy_state)
    return policy_step.PolicyStep(
        action=network_output, state=next_state, info=())

  def _distribution(self, time_step, policy_state):
    """Defers entirely to `ActorPolicy._distribution`."""
    return super()._distribution(time_step, policy_state)
class SequenceAgent(tf_agent.TFAgent):
  """A sequence agent that directly outputs actions via an actor network.

  A single actor network instance is shared by two policies built from
  `policy_cls`:

  * a train policy (`training=True`), used only to compute the loss;
  * a collect policy (`training=False`), also used as the eval policy.
  """

  def __init__(self,
               time_step_spec: ts.TimeStep,
               action_spec: types.NestedTensorSpec,
               actor_network: Type[network.Network],
               actor_optimizer: tf.keras.optimizers.Optimizer,
               policy_cls: Type[actor_policy.ActorPolicy] = SequencePolicy,
               time_sequence_length: int = 6,
               debug_summaries: bool = False,
               **kwargs):
    """Creates the agent.

    Args:
      time_step_spec: Spec of the time steps; its `observation` is used as the
        actor network's input spec.
      action_spec: Spec of the actions; used as the network's output spec.
      actor_network: Network class (not an instance). It is instantiated here
        with the specs, an empty policy info spec, the train step counter and
        `time_sequence_length`.
      actor_optimizer: Optimizer applied to the network's trainable weights.
      policy_cls: Policy class used for both the train and collect policies.
      time_sequence_length: Passed to the actor network and used as the
        agent's `train_sequence_length`.
      debug_summaries: Passed to the actor network's `add_summaries`.
      **kwargs: Forwarded to `tf_agent.TFAgent.__init__`. When
        `train_step_counter` is absent or None, the TF global step is used
        (TFAgent's documented default), so the actor network and the agent
        always share one counter.
    """
    # The actor network needs the counter before TFAgent would create its
    # default. Indexing kwargs directly raised KeyError when the counter was
    # omitted, and passing None left the network without a counter while
    # TFAgent silently created a different one.
    if kwargs.get('train_step_counter') is None:
      kwargs['train_step_counter'] = (
          tf.compat.v1.train.get_or_create_global_step())

    self._info_spec = ()
    self._actor_network = actor_network(  # pytype: disable=missing-parameter  # dynamic-method-lookup
        input_tensor_spec=time_step_spec.observation,
        output_tensor_spec=action_spec,
        policy_info_spec=self._info_spec,
        train_step_counter=kwargs['train_step_counter'],
        time_sequence_length=time_sequence_length)
    self._actor_optimizer = actor_optimizer
    # Train policy is only used for loss and never exported as saved_model.
    self._train_policy = policy_cls(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=self._info_spec,
        actor_network=self._actor_network,
        training=True)
    collect_policy = policy_cls(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=self._info_spec,
        actor_network=self._actor_network,
        training=False)
    super().__init__(
        time_step_spec,
        action_spec,
        collect_policy,  # We use the collect_policy as the eval policy.
        collect_policy,
        train_sequence_length=time_sequence_length,
        **kwargs)
    self._data_context = data_converter.DataContext(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        info_spec=collect_policy.info_spec,
        use_half_transition=True)
    # Keep the time dimension: the loss is computed over whole sequences.
    self.as_transition = data_converter.AsHalfTransition(
        self._data_context, squeeze_time_dim=False)
    self._debug_summaries = debug_summaries
    self._log_actor_network_size()

  def _log_actor_network_size(self):
    """Logs the element count of each trainable weight and the total (in M)."""
    num_params = 0
    for weight in self._actor_network.trainable_weights:
      weight_params = weight.shape.num_elements()
      logging.info('%s has %s params.', weight.name, weight_params)
      num_params += weight_params
    logging.info('Actor network has %sM params.', round(num_params / 1000000.,
                                                        2))

  def _train(self, experience: types.NestedTensor,
             weights: types.Tensor) -> tf_agent.LossInfo:
    """Increments the step counter, computes the loss and applies gradients."""
    self.train_step_counter.assign_add(1)
    loss_info = self._loss(experience, weights, training=True)
    self._apply_gradients(loss_info.loss)
    return loss_info

  def _apply_gradients(self, loss: types.Tensor):
    """Applies gradients of `loss` to the actor network's trainable weights.

    NaN and Inf gradient entries are replaced with zeros. Variables without
    a gradient (None) are passed to the optimizer unchanged.

    Note: `tf.gradients` works only in graph mode (e.g. inside `tf.function`);
    it raises under eager execution.
    """
    variables = self._actor_network.trainable_weights
    gradients = tf.gradients(loss, variables)
    # Skip nan and inf gradients so non-finite values never reach the
    # optimizer.
    sanitized_gradients = [
        g if g is None else tf.where(tf.math.is_finite(g), g, tf.zeros_like(g))
        for g in gradients
    ]
    grads_and_vars = list(zip(sanitized_gradients, variables))
    self._actor_optimizer.apply_gradients(grads_and_vars)

  def _loss(self, experience: types.NestedTensor, weights: types.Tensor,
            training: bool) -> tf_agent.LossInfo:
    """Computes the actor loss for a batch of experience.

    Args:
      experience: Batched experience, converted to half transitions without
        squeezing the time dimension.
      weights: Unused; accepted for the `TFAgent` interface.
      training: Whether the train policy / actor network run in training mode.

    Returns:
      A `LossInfo` whose `loss` and `extra` are both the scalar mean loss.
    """
    time_steps, policy_steps, _ = self.as_transition(experience)
    batch_size = nest_utils.get_outer_shape(time_steps, self._time_step_spec)[0]
    policy = self._train_policy
    # The actor network reads the experience actions while computing its
    # loss during `policy.action` below.
    policy.set_actions(policy_steps.action)
    policy.set_training(training=training)
    with tf.name_scope('actor_loss'):
      try:
        policy_state = policy.get_initial_state(batch_size)
        policy.action(time_steps, policy_state=policy_state)
        # Zero the loss at time steps whose step type is LAST.
        valid_mask = tf.cast(~time_steps.is_last(), tf.float32)
        loss = valid_mask * policy.get_actor_loss()
        # NOTE(review): the mean is taken over every entry, including masked
        # ones, so the loss scale depends on the fraction of LAST steps.
        loss = tf.reduce_mean(loss)
      finally:
        # Always clear the actions so a failed computation cannot leave stale
        # actions on the shared actor network.
        policy.set_actions(None)
      self._actor_network.add_summaries(time_steps.observation,
                                        policy.get_aux_info(),
                                        self._debug_summaries, training)
      return tf_agent.LossInfo(loss=loss, extra=loss)