Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Support callable decay rates in Adafactor
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 421333767
  • Loading branch information
T2T Team authored and copybara-github committed Jan 12, 2022
1 parent 81c2b2e commit 2a33b15
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tensor2tensor/utils/adafactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def _resource_apply_dense(self, grad, handle):
grad = tf.to_float(grad)
grad_squared = tf.square(grad) + self._epsilon1
grad_squared_mean = tf.reduce_mean(grad_squared)
decay_rate = self._decay_rate
decay_rate = self._call_if_callable(self._decay_rate)
update_scale = self._call_if_callable(self._learning_rate)
update_scale = tf.convert_to_tensor(update_scale, name="update_scale")
update_scale = tf.cast(update_scale, grad_squared_mean.dtype.base_dtype)
Expand Down

0 comments on commit 2a33b15

Please sign in to comment.