diff --git a/tests/tensorflow2/test_grad_tape_tf_function.py b/tests/tensorflow2/test_grad_tape_tf_function.py index 681287b14..f35c8d50d 100644 --- a/tests/tensorflow2/test_grad_tape_tf_function.py +++ b/tests/tensorflow2/test_grad_tape_tf_function.py @@ -1,5 +1,7 @@ # Third Party import tensorflow as tf +import keras +from packaging import version # First Party import smdebug.tensorflow as smd @@ -41,7 +43,11 @@ def train_step(images, labels): dataset = dataset.shuffle(1000).batch(64) model = create_model() hook = create_hook(out_dir) - opt = tf.keras.optimizers.Adam() + if version.parse(tf.__version__) >= version.parse("2.11.0") or "rc" in tf.__version__: + opt = keras.optimizers.optimizer_v2.adam.Adam() + else: + opt = tf.keras.optimizers.Adam() + hook.wrap_optimizer(opt) n_epochs = 1