Skip to content

Commit

Permalink
Changes to support tensorflow 2.12 (#652)
Browse files Browse the repository at this point in the history
* changes to support tensorflow 2.12

* format change

* updagrade protobuf version for tf 212
  • Loading branch information
yl-to authored Mar 11, 2023
1 parent 56fabe5 commit 6cb0d55
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 20 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
FRAMEWORKS = ["tensorflow", "pytorch", "mxnet", "xgboost"]
TESTS_PACKAGES = ["pytest", "torchvision", "pandas"]
INSTALL_REQUIRES = [
"protobuf>=3.20.0,<=3.20.2",
"protobuf>=3.20.0,<=3.20.3",
"numpy>=1.16.0",
"packaging",
"boto3>=1.10.32",
Expand Down
5 changes: 5 additions & 0 deletions smdebug/core/tfevent/proto/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ enum DataType {
DT_VARIANT = 21; // Arbitrary C++ data types
DT_UINT32 = 22;
DT_UINT64 = 23;
DT_FLOAT8_E5M2 = 24; // 5 exponent bits, 2 mantissa bits.
DT_FLOAT8_E4M3FN = 25; // 4 exponent bits, 3 mantissa bits, finite-only, with
// 2 NaNs (0bS1111111).

// TODO(josh11b): DT_GENERIC_PROTO = ??;
// TODO(jeff,josh11b): DT_UINT64? DT_UINT32?
Expand Down Expand Up @@ -67,5 +70,7 @@ enum DataType {
DT_VARIANT_REF = 121;
DT_UINT32_REF = 122;
DT_UINT64_REF = 123;
DT_FLOAT8_E5M2_REF = 124;
DT_FLOAT8_E4M3FN_REF = 125;
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go)
5 changes: 5 additions & 0 deletions smdebug/core/tfevent/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Third Party
import numpy as np
from tensorflow.python.lib.core import _pywrap_float8

# First Party
from smdebug.core.logger import get_logger
Expand All @@ -9,6 +10,7 @@
from .proto.tensor_pb2 import TensorProto
from .proto.tensor_shape_pb2 import TensorShapeProto


logger = get_logger()

# hash value of ndarray.dtype is not the same as np.float class
Expand All @@ -33,13 +35,16 @@
np.dtype([("qint16", "<i2")]): "DT_QINT16",
np.dtype([("quint16", "<u2")]): "DT_UINT16",
np.dtype([("qint32", "<i4")]): "DT_INT32",
np.dtype(_pywrap_float8.TF_float8_e5m2_type()): "DT_FLOAT8_E5M2",
np.dtype(_pywrap_float8.TF_float8_e4m3fn_type()): "DT_FLOAT8_E4M3FN",
}


def _get_proto_dtype(npdtype):
if hasattr(npdtype, "kind"):
if npdtype.kind == "U" or npdtype.kind == "O" or npdtype.kind == "S":
return False, "DT_STRING"

try:
return True, _NP_DATATYPE_TO_PROTO_DATATYPE[npdtype]
except KeyError:
Expand Down
9 changes: 3 additions & 6 deletions tests/tensorflow2/test_embedding_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,19 @@ def train(out_dir):
encoder.fit(y_train)
encoded_Y = encoder.transform(y_train)
sub_classing_model = ModelSubClassing(hook)

hook.register_model(sub_classing_model)

optimizer = Adam()
optimizer = hook.wrap_optimizer(optimizer)

sub_classing_model.compile(optimizer=optimizer, loss=tf.keras.losses.BinaryCrossentropy(),
run_eagerly=True)

sub_classing_model.fit(x_train, encoded_Y, batch_size=128, epochs=1, callbacks=[hook])


def test_embedding_grad(out_dir):
train(out_dir)
trial = smd.create_trial(path=out_dir)
output = ['gradients/model_sub_classing/dense/biasGrad',
'gradients/model_sub_classing/dense/kernelGrad',
'gradients/model_sub_classing/dense_1/biasGrad',
'gradients/model_sub_classing/dense_1/kernelGrad',
'gradients/model_sub_classing/embedding/embeddingsGrad']
assert trial.tensor_names(collection="gradients") == output
assert len(trial.tensor_names(collection="gradients")) == 5
10 changes: 7 additions & 3 deletions tests/tensorflow2/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
# Third Party
import pytest
import tensorflow.compat.v2 as tf
from tests.tensorflow2.utils import is_tf_version_greater_than_2_4_x
from tests.tensorflow2.utils import is_tf_version_greater_than_2_4_x, is_greater_than_tf_2_11
from tests.zero_code_change.tf_utils import get_estimator, get_input_fns

# First Party
import smdebug.tensorflow as smd
from smdebug.core.collection import CollectionKeys


@pytest.mark.skipif(
is_greater_than_tf_2_11(), reason="Unsupported with TF 2.12 due to breaking changes"
)
@pytest.mark.parametrize("saveall", [True, False])
def test_estimator(out_dir, tf_eager_mode, saveall):
""" Works as intended. """
Expand All @@ -19,15 +22,13 @@ def test_estimator(out_dir, tf_eager_mode, saveall):
tf.keras.backend.clear_session()
mnist_classifier = get_estimator()
train_input_fn, eval_input_fn = get_input_fns()

# Train and evaluate
train_steps, eval_steps = 8, 2
hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall)
hook.set_mode(mode=smd.modes.TRAIN)
mnist_classifier.train(input_fn=train_input_fn, steps=train_steps, hooks=[hook])
hook.set_mode(mode=smd.modes.EVAL)
mnist_classifier.evaluate(input_fn=eval_input_fn, steps=eval_steps, hooks=[hook])

# Check that hook created and tensors saved
trial = smd.create_trial(path=out_dir)
tnames = trial.tensor_names()
Expand All @@ -48,6 +49,9 @@ def test_estimator(out_dir, tf_eager_mode, saveall):
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1


@pytest.mark.skipif(
is_greater_than_tf_2_11(), reason="Unsupported with TF 2.12 due to breaking changes"
)
@pytest.mark.parametrize("saveall", [True, False])
def test_linear_classifier(out_dir, tf_eager_mode, saveall):
""" Works as intended. """
Expand Down
10 changes: 2 additions & 8 deletions tests/tensorflow2/test_grad_tape_tf_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,8 @@ def train_step(images, labels):

trial = smd.create_trial(out_dir)
assert trial.tensor_names(collection=CollectionKeys.LOSSES) == ["loss"]
assert trial.tensor_names(collection=CollectionKeys.WEIGHTS) == [
"weights/dense/kernel:0",
"weights/dense_1/kernel:0",
]
assert trial.tensor_names(collection=CollectionKeys.BIASES) == [
"weights/dense/bias:0",
"weights/dense_1/bias:0",
]
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
assert trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES) == [
"Adam/beta_1:0",
"Adam/beta_2:0",
Expand Down
5 changes: 4 additions & 1 deletion tests/tensorflow2/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tests.tensorflow2.utils import is_greater_than_tf_2_2, is_tf_2_3, is_tf_2_6
from tests.tensorflow2.utils import is_greater_than_tf_2_2, is_tf_2_3, is_tf_2_6, is_greater_than_tf_2_11
from tests.tensorflow.utils import create_trial_fast_refresh
from tests.utils import verify_shapes
from packaging import version
Expand Down Expand Up @@ -890,6 +890,9 @@ def test_save_tensors(out_dir, tf_eager_mode):
assert trial.tensor(tname).value(0) is not None


@pytest.mark.skipif(
is_greater_than_tf_2_11(), reason="Unsupported with TF 2.12 due to breaking changes"
)
def test_keras_to_estimator(out_dir, tf_eager_mode):
if not tf_eager_mode:
tf.compat.v1.disable_eager_execution()
Expand Down
2 changes: 1 addition & 1 deletion tests/tensorflow2/test_tensorflow2_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_tensorflow2_datatypes():
# _NP_TO_TF contains all the mappings
# of numpy to tf types
try:
from tensorflow.python import _pywrap_bfloat16
from tensorflow.python.lib.core import _pywrap_bfloat16

# TF 2.x.x Implements a Custom Numpy Datatype for Brain Floating Type
# Which is currently only supported on TPUs
Expand Down
7 changes: 7 additions & 0 deletions tests/tensorflow2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ def is_greater_than_tf_2_2():
return True
return False

def is_greater_than_tf_2_11():
"""
Session hook is deprecated since 2.12, so we do skipping all session hook related tests
"""
if TF_VERSION >= version.parse("2.12.0") or 'rc' in tf.__version__:
return True
return False

def is_tf_2_6():
if TF_VERSION >= version.parse("2.6.0"):
Expand Down

0 comments on commit 6cb0d55

Please sign in to comment.