Skip to content

Commit

Permalink
fix: fix tf2 test for tf 2.11 release (#626)
Browse files Browse the repository at this point in the history
* using v2 optimizer from keras

* remove func

* format change

* add version check
  • Loading branch information
yl-to authored Nov 11, 2022
1 parent 2caacc2 commit 167fa3a
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion tests/tensorflow2/test_grad_tape_tf_function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Third Party
import tensorflow as tf
import keras
from packaging import version

# First Party
import smdebug.tensorflow as smd
Expand Down Expand Up @@ -41,7 +43,11 @@ def train_step(images, labels):
dataset = dataset.shuffle(1000).batch(64)
model = create_model()
hook = create_hook(out_dir)
opt = tf.keras.optimizers.Adam()
if version.parse(tf.__version__) >= version.parse("2.11.0") or "rc" in tf.__version__:
opt = keras.optimizers.optimizer_v2.adam.Adam()
else:
opt = tf.keras.optimizers.Adam()

hook.wrap_optimizer(opt)

n_epochs = 1
Expand Down

0 comments on commit 167fa3a

Please sign in to comment.