From c795a1e9dfa8ac6a17d2016f54a5d0eec3565d73 Mon Sep 17 00:00:00 2001
From: Mathieu Guillame-Bert
Date: Mon, 1 Nov 2021 08:38:59 -0700
Subject: [PATCH] Internal change

PiperOrigin-RevId: 406833837
---
 CHANGELOG.md                                  |  8 +-
 WORKSPACE                                     |  6 ++
 configure/MANIFEST.in                         |  1 +
 configure/setup.py                            |  4 +-
 documentation/distributed_training.md         | 58 ++++++++++---
 documentation/known_issues.md                 |  1 +
 tensorflow_decision_forests/BUILD             | 11 +++
 tensorflow_decision_forests/__init__.py       |  2 +-
 tensorflow_decision_forests/keras/BUILD       | 20 ++++-
 tensorflow_decision_forests/keras/core.py     |  1 +
 .../keras/keras_distributed_test.py           | 82 +++++++++++++++++--
 .../keras/keras_test.py                       | 49 +----------
 tensorflow_decision_forests/tensorflow/BUILD  | 38 +++++----
 .../tensorflow/core.py                        | 18 +++-
 .../tensorflow/distribute/BUILD               | 26 ++++--
 .../tensorflow/distribute/api.py              |  1 +
 .../tensorflow/ops/training/BUILD             |  2 +
 .../tensorflow/ops/training/feature_on_file.h | 14 +++-
 .../tensorflow/ops/training/kernel_on_file.cc |  2 +-
 .../yggdrasil_decision_forests/workspace.bzl  | 20 +++--
 tools/build_pip_package.sh                    |  9 +-
 tools/test_bazel.sh                           | 24 ++----
 22 files changed, 267 insertions(+), 130 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 932b2934..d6526ced 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,6 @@
 # Changelog
 
-## 0.2.0 - ????
+## 0.2.0 - 2021-10-29
 
 ### Features
 
@@ -11,8 +11,10 @@
 - Add support for permutation variable importance in the GBT learner with the
   `compute_permutation_variable_importance` parameter.
 - Support for tf.int8 and tf.int16 values.
-- Support for distributed gradient boosted trees learning using the
-  ParameterServerStrategy distribution strategy.
+- Support for distributed gradient boosted trees learning. Currently, the TF
+  ParameterServerStrategy distribution strategy is only available in
+  monolithic TF-DF builds. The Yggdrasil Decision Forests GRPC distribute
+  strategy can be used instead.
 - Support for training from dataset stored on disk in CSV and RecordIO format
   (instead of creating a tensorflow dataset). This option is currently more
   efficient for distributed training (until the ParameterServerStrategy
diff --git a/WORKSPACE b/WORKSPACE
index 55883824..8c3dd55e 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -8,6 +8,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 # absl used by tensorflow.
 http_archive(
     name = "org_tensorflow",
+
+    # sha256 = "4896b49c4088030f62b98264441475c09569ea6e49cfb270e2e1f3ef0f743a2f",
+    # strip_prefix = "tensorflow-2.7.0-rc1",
+    # urls = ["https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.7.0-rc1.zip"],
+
     sha256 = "40d3203ab5f246d83bae328288a24209a2b85794f1b3e2cd0329458d8e7c1985",
     strip_prefix = "tensorflow-2.6.0",
     urls = ["https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.6.0.zip"],
@@ -58,6 +63,7 @@ ydf_load_deps(
         "absl",
         "protobuf",
         "zlib",
+        "farmhash",
     ],
     repo_name = "@ydf",
 )
diff --git a/configure/MANIFEST.in b/configure/MANIFEST.in
index 05b76a08..3acf4b02 100644
--- a/configure/MANIFEST.in
+++ b/configure/MANIFEST.in
@@ -5,3 +5,4 @@ recursive-include * *.so
 recursive-include * *.so.[0-9]
 recursive-include * *.dylib
 recursive-include * *.dll
+recursive-include * grpc_worker_main
diff --git a/configure/setup.py b/configure/setup.py
index 3fc30266..ef5a3eff 100644
--- a/configure/setup.py
+++ b/configure/setup.py
@@ -20,7 +20,7 @@ from setuptools.command.install import install
 from setuptools.dist import Distribution
 
-_VERSION = "0.1.9"
+_VERSION = "0.2.0"
 
 with open("README.md", "r", encoding="utf-8") as fh:
   long_description = fh.read()
@@ -28,7 +28,7 @@ REQUIRED_PACKAGES = [
     "numpy",
     "pandas",
-    "tensorflow~=2.6",
+    "tensorflow~=2.6",  # "tensorflow >= 2.7.0rc0, < 2.8",
     "six",
     "absl_py",
     "wheel",
diff --git a/documentation/distributed_training.md b/documentation/distributed_training.md
index e059fbb0..98a2512c 100644
--- a/documentation/distributed_training.md
+++ b/documentation/distributed_training.md
@@ -16,15 +16,19 @@
 Distributed training makes it possible to train models quickly on larger
 datasets. Distributed training in TF-DF relies on the TensorFlow
-ParameterServerV2 distribution strategy. Only some of the TF-DF models support
-distributed training.
+ParameterServerV2 distribution strategy or the Yggdrasil Decision Forests GRPC
+distribute strategy. Only some of the TF-DF models support distributed training.
 
 See the
 [distributed training](https://github.com/google/yggdrasil-decision-forests/documentation/user_manual.md?#distributed-training)
 section in the Yggdrasil Decision Forests user manual for details about the
-available distributed training algorithms. When using distributed training in
-TF-DF, Yggdrasil Decision Forests is effectively running the `TF_DIST distribute
-implementation`.
+available distributed training algorithms. When using distributed training with
+TF Parameter Server in TF-DF, Yggdrasil Decision Forests is effectively running
+the `TF_DIST` distribute implementation.
+
+**Note:** Currently (Oct. 2021), the shared (i.e. != monolithic) OSS build of
+TF-DF does not support the TF ParameterServer distribution strategy. Please use
+the Yggdrasil DF GRPC distribute strategy instead.
 
 ## Dataset
 
@@ -40,7 +44,8 @@ As of today (Oct 2021), the following solutions are available for TF-DF:
    solution is the fastest and the one that gives the best results as it is
    currently the only one that guarantees that each example is read only once.
    The downside is that this solution does not support TensorFlow
-   pre-processing.
+   pre-processing. The "Yggdrasil DF GRPC distribute strategy" only supports
+   this option for dataset reading.
 
 2. To use **ParameterServerV2 distributed dataset** with dataset file sharding
    using the TF-DF worker index. This solution is the most natural for TF
    users (see the sketch below).
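
For illustration, here is a minimal sketch of option 2: a `dataset_fn` that shards the dataset *files* by TF-DF worker index. The file paths and CSV schema are hypothetical, and the sketch assumes a helper such as `tfdf.keras.get_worker_idx_and_num_workers` to query the worker index; it is not part of this patch:

```python
import tensorflow as tf
import tensorflow_decision_forests as tfdf


def dataset_fn(context: tf.distribute.InputContext):
  """Builds the shard of the training dataset owned by one worker."""
  # Hypothetical list of CSV files composing the dataset.
  paths = [f"/path/to/dataset/train-{i:05d}.csv" for i in range(100)]
  ds_path = tf.data.Dataset.from_tensor_slices(paths)

  # Shard the files (not the examples) by worker so that each worker reads a
  # disjoint subset of the dataset.
  current_worker = tfdf.keras.get_worker_idx_and_num_workers(context)
  ds_path = ds_path.shard(
      num_shards=current_worker.num_workers,
      index=current_worker.worker_idx)

  def read_csv_file(path):
    # Toy schema: one numerical feature and an integer label.
    columns = tf.data.experimental.CsvDataset(
        path, record_defaults=[tf.float32, tf.int32], header=True)
    return columns.map(lambda feature, label: ({"feature": feature}, label))

  return ds_path.interleave(read_csv_file).batch(64)
```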
@@ -48,13 +53,11 @@ As of today (Oct 2021), the following solutions are available for TF-DF:
 Currently, using ParameterServerV2 distributed dataset with context or
 tf.data.service is not compatible with TF-DF.
 
-Note that in all cases, ParameterServerV2 is used to distribute the computation.
-
 ## Examples
 
 Following are some examples of distributed training.
 
-### Distribution with Yggdrasil distributed dataset reading
+### Distribution with Yggdrasil distributed dataset reading and TF ParameterServerV2 strategy
 
 ```python
 import tensorflow_decision_forests as tfdf
@@ -78,7 +81,7 @@ See Yggdrasil Decision Forests
 [supported formats](https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#dataset-path-and-format)
 for the possible values of `dataset_format`.
 
-### Distribution with ParameterServerV2 distributed dataset
+### Distribution with ParameterServerV2 distributed dataset and TF ParameterServerV2 strategy
 
 ```python
 import tensorflow_decision_forests as tfdf
 import tensorflow as tf
@@ -149,3 +152,38 @@ model.fit(
 print("Trained model")
 model.summary()
 ```
+
+### Distribution with Yggdrasil distributed dataset reading and Yggdrasil DF GRPC distribute strategy
+
+```python
+import tensorflow_decision_forests as tfdf
+import tensorflow as tf
+
+deployment_config = tfdf.keras.core.YggdrasilDeploymentConfig()
+deployment_config.try_resume_training = True
+deployment_config.distribute.implementation_key = "GRPC"
+socket_addresses = deployment_config.distribute.Extensions[
+    tfdf.keras.core.grpc_pb2.grpc].socket_addresses
+
+# Socket addresses of ":grpc_worker_main" running instances.
+socket_addresses.addresses.add(ip="127.0.0.1", port=2001)
+socket_addresses.addresses.add(ip="127.0.0.2", port=2001)
+socket_addresses.addresses.add(ip="127.0.0.3", port=2001)
+socket_addresses.addresses.add(ip="127.0.0.4", port=2001)
+
+model = tfdf.keras.DistributedGradientBoostedTreesModel(
+    advanced_arguments=tfdf.keras.AdvancedArguments(
+        yggdrasil_deployment_config=deployment_config))
+
+model.fit_on_dataset_path(
+    train_path="/path/to/dataset@100000",
+    label_key="label_key",
+    dataset_format="tfrecord+tfe")
+
+print("Trained model")
+model.summary()
+```
+
+See Yggdrasil Decision Forests
+[supported formats](https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#dataset-path-and-format)
+for the possible values of `dataset_format`.
diff --git a/documentation/known_issues.md b/documentation/known_issues.md
index ba8334e8..80852486 100644
--- a/documentation/known_issues.md
+++ b/documentation/known_issues.md
@@ -47,6 +47,7 @@ The following table shows the compatibility between
 tensorflow_decision_forests | tensorflow
 --------------------------- | ----------
+0.2.0                       | 2.6
 0.1.9                       | 2.6
 0.1.1 - 0.1.8               | 2.5
 0.1.0                       | 2.4
diff --git a/tensorflow_decision_forests/BUILD b/tensorflow_decision_forests/BUILD
index 6f578c24..b8cb4030 100644
--- a/tensorflow_decision_forests/BUILD
+++ b/tensorflow_decision_forests/BUILD
@@ -31,3 +31,14 @@ config_setting(
     name = "stop_training_on_interrupt",
     values = {"define": "stop_training_on_interrupt=1"},
 )
+
+# If "disable_tf_ps_distribution_strategy" is true, the TF Parameter Server
+# distribution strategy is not available for distributed training.
+#
+# Distribution with TF PS is currently NOT supported for OSS TF-DF with a
+# shared build (monolithic builds work, however) and TF<2.7. In this case, the
+# GRPC Worker Server can be used instead.
+config_setting(
+    name = "disable_tf_ps_distribution_strategy",
+    values = {"define": "tf_ps_distribution_strategy=0"},
+)
diff --git a/tensorflow_decision_forests/__init__.py b/tensorflow_decision_forests/__init__.py
index 25222dff..d3dd6b44 100644
--- a/tensorflow_decision_forests/__init__.py
+++ b/tensorflow_decision_forests/__init__.py
@@ -45,7 +45,7 @@
 """
 
-__version__ = "0.1.9"
+__version__ = "0.2.0"
 __author__ = "Mathieu Guillame-Bert"
 
 from tensorflow_decision_forests import keras
diff --git a/tensorflow_decision_forests/keras/BUILD b/tensorflow_decision_forests/keras/BUILD
index 5553cbe5..34fd8b4c 100644
--- a/tensorflow_decision_forests/keras/BUILD
+++ b/tensorflow_decision_forests/keras/BUILD
@@ -75,6 +75,7 @@ py_library(
         "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto",
         "@ydf//yggdrasil_decision_forests/learner:abstract_learner_py_proto",
         "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto",
+        "@ydf//yggdrasil_decision_forests/utils/distribute/implementations/grpc:grpc_py_proto",
     ],
 )
 
@@ -112,13 +113,15 @@ py_test(
 # This test relies on the support of TF PS distribution strategy and TF-DF.
 # Note: TF PS distribution strategy and TF-DF are currently not compatible in non-monolithic build of TensorFlow+TFDF (e.g. OSS TFDF).
+#
+# This test is expected to fail if TF PS distributed training is disabled (i.e.
+# when the ":disable_tf_ps_distribution_strategy" rule is enabled).
 py_test(
     name = "keras_distributed_test",
     size = "large",
     srcs = ["keras_distributed_test.py"],
     data = [
-        ":synthetic_dataset",
-        ":test_runner",
+        ":grpc_worker_main",
         "@ydf//yggdrasil_decision_forests/test_data",
     ],
     python_version = "PY3",
@@ -132,10 +135,10 @@ py_test(
         # absl/testing:parameterized dep,
         # numpy dep,
         # pandas dep,
-        "//third_party/py/portpicker",
+        # portpicker dep,
         "@org_tensorflow//tensorflow/python",
         "@org_tensorflow//tensorflow/python/distribute:distribute_lib",
-        "//third_party/tensorflow_decision_forests",
+        "//tensorflow_decision_forests",
     ],
 )
 
@@ -164,3 +167,12 @@ tf_cc_binary(
         "@ydf//yggdrasil_decision_forests/cli/utils:synthetic_dataset_lib_with_main",
     ],
 )
+
+tf_cc_binary(
+    name = "grpc_worker_main",
+    deps = [
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@ydf//yggdrasil_decision_forests/utils/distribute/implementations/grpc:grpc_worker_lib_with_main",
+    ],
+)
diff --git a/tensorflow_decision_forests/keras/core.py b/tensorflow_decision_forests/keras/core.py
index 6d2eefc2..a2c21b14 100644
--- a/tensorflow_decision_forests/keras/core.py
+++ b/tensorflow_decision_forests/keras/core.py
@@ -63,6 +63,7 @@
 from yggdrasil_decision_forests.dataset import data_spec_pb2
 from yggdrasil_decision_forests.learner import abstract_learner_pb2
 from yggdrasil_decision_forests.model import abstract_model_pb2  # pylint: disable=unused-import
+from yggdrasil_decision_forests.utils.distribute.implementations.grpc import grpc_pb2  # pylint: disable=unused-import
 
 layers = tf.keras.layers
 models = tf.keras.models
diff --git a/tensorflow_decision_forests/keras/keras_distributed_test.py b/tensorflow_decision_forests/keras/keras_distributed_test.py
index d4fdab0a..20d8718b 100644
--- a/tensorflow_decision_forests/keras/keras_distributed_test.py
+++ b/tensorflow_decision_forests/keras/keras_distributed_test.py
@@ -17,7 +17,8 @@
 from __future__ import print_function
 
 import os
-from typing import List
+import subprocess
+from typing import List, Tuple
 
 from absl import flags
 from absl import logging
@@ -44,7 +45,7 @@ def tmp_path() -> str:
   return flags.FLAGS.test_tmpdir
 
 
-def _create_in_process_cluster(num_workers, num_ps):
+def _create_in_process_tf_ps_cluster(num_workers, num_ps):
   """Creates a cluster of TF workers and returns their addresses.
 
   Such a cluster simulates the behavior of multiple TF parameter servers.
@@ -85,6 +86,32 @@ def _create_in_process_cluster(num_workers, num_ps):
       cluster_spec, rpc_layer="grpc")
 
 
+def _create_in_process_grpc_worker_cluster(
+    num_workers) -> List[Tuple[str, int]]:
+  """Creates a cluster of GRPC workers and returns their addresses.
+
+  Args:
+    num_workers: Number of workers.
+
+  Returns:
+    List of socket addresses.
+  """
+
+  worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+  worker_ip = "localhost"
+  worker_addresses = []
+
+  for i in range(num_workers):
+    worker_addresses.append((worker_ip, worker_ports[i]))
+    args = [
+        "tensorflow_decision_forests/keras/grpc_worker_main",
+        "--alsologtostderr", "--port",
+        str(worker_ports[i])
+    ]
+    subprocess.Popen(args, stdout=subprocess.PIPE)
+  return worker_addresses
+
+
 class TFDFDistributedTest(parameterized.TestCase, tf.test.TestCase):
 
   def test_distributed_training_synthetic(self):
@@ -124,7 +151,7 @@ def dataset_fn(context: distribute_lib.InputContext, seed: int):
       return dataset
 
     # Create the workers
-    cluster_resolver = _create_in_process_cluster(num_workers=4, num_ps=1)
+    cluster_resolver = _create_in_process_tf_ps_cluster(num_workers=4, num_ps=1)
 
     # Configure the model and datasets
     strategy = tf.distribute.experimental.ParameterServerStrategy(
@@ -191,7 +218,7 @@ def dataset_fn(input_context):
 
     dataset_creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
 
-    cluster_resolver = _create_in_process_cluster(num_workers=2, num_ps=1)
+    cluster_resolver = _create_in_process_tf_ps_cluster(num_workers=2, num_ps=1)
 
     strategy = tf.distribute.experimental.ParameterServerStrategy(
         cluster_resolver)
@@ -288,7 +315,7 @@ def extract_label(*columns):
       return ds_dataset
 
     # Create the workers
-    cluster_resolver = _create_in_process_cluster(num_workers=5, num_ps=1)
+    cluster_resolver = _create_in_process_tf_ps_cluster(num_workers=5, num_ps=1)
 
     # Configure the model and datasets
     strategy = tf.distribute.experimental.ParameterServerStrategy(
@@ -335,7 +362,7 @@ def extract_label(*columns):
     # at different speed, some examples can be repeated.
     self.assertAlmostEqual(evaluation["accuracy"], 0.8603476, delta=0.02)
 
-  def test_distributed_training_adult_from_disk(self):
+  def test_distributed_training_adult_from_file(self):
     # Path to dataset.
     dataset_directory = os.path.join(test_data_path(), "dataset")
     train_path = os.path.join(dataset_directory, "adult_train.csv")
@@ -344,7 +371,7 @@ def test_distributed_training_adult_from_disk(self):
     label = "income"
 
     # Create the workers
-    cluster_resolver = _create_in_process_cluster(num_workers=5, num_ps=1)
+    cluster_resolver = _create_in_process_tf_ps_cluster(num_workers=5, num_ps=1)
 
     # Configure the model and datasets
     strategy = tf.distribute.experimental.ParameterServerStrategy(
@@ -378,6 +405,47 @@ def test_distributed_training_adult_from_disk(self):
         "capital_gain", "capital_loss", "hours_per_week", "native_country"
     ])
 
+  def test_distributed_training_adult_from_file_with_grpc_worker(self):
+    # Path to dataset.
+    dataset_directory = os.path.join(test_data_path(), "dataset")
+    train_path = os.path.join(dataset_directory, "adult_train.csv")
+    test_path = os.path.join(dataset_directory, "adult_test.csv")
+
+    label = "income"
+
+    # Create GRPC Yggdrasil DF workers
+    worker_addresses = _create_in_process_grpc_worker_cluster(5)
+
+    # Specify the socket addresses of the workers to the manager.
+    deployment_config = tfdf.keras.core.YggdrasilDeploymentConfig()
+    deployment_config.try_resume_training = True
+    deployment_config.distribute.implementation_key = "GRPC"
+    socket_addresses = deployment_config.distribute.Extensions[
+        tfdf.keras.core.grpc_pb2.grpc].socket_addresses
+    for worker_ip, worker_port in worker_addresses:
+      socket_addresses.addresses.add(ip=worker_ip, port=worker_port)
+
+    model = tfdf.keras.DistributedGradientBoostedTreesModel(
+        advanced_arguments=tfdf.keras.AdvancedArguments(
+            yggdrasil_deployment_config=deployment_config))
+    model.compile(metrics=["accuracy"])
+
+    training_history = model.fit_on_dataset_path(
+        train_path=train_path,
+        label_key=label,
+        dataset_format="csv",
+        valid_path=test_path)
+    logging.info("Training history: %s", training_history.history)
+
+    logging.info("Trained model:")
+    model.summary()
+
+    test_df = pd.read_csv(test_path)
+    tf_test = tfdf.keras.pd_dataframe_to_tf_dataset(test_df, label)
+    evaluation = model.evaluate(tf_test, return_dict=True)
+    logging.info("Evaluation: %s", evaluation)
+    self.assertAlmostEqual(evaluation["accuracy"], 0.8703476, delta=0.01)
+
   def test_in_memory_not_supported(self):
     dataframe = pd.DataFrame({
diff --git a/tensorflow_decision_forests/keras/keras_test.py b/tensorflow_decision_forests/keras/keras_test.py
index 3e0f5b2d..453d0052 100644
--- a/tensorflow_decision_forests/keras/keras_test.py
+++ b/tensorflow_decision_forests/keras/keras_test.py
@@ -1225,51 +1225,6 @@ def f(a=1, b=2, c=3, explicit_args=None):
     f(b=6, c=7)
     self.assertEqual(f.last_explicit_args, set(["b", "c"]))
 
-  def test_rank1_preprocessing(self):
-    """Test the limitation on rank1 preprocessing."""
-
-    def experiment(infer_prediction_signature, save_model):
-      x_train = [1.0, 2.0, 3.0, 4.0]  # Dataset with a single feature.
-      y_train = [0, 1, 0, 1]
-
-      @tf.function(
-          input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),))
-      def processor(x):
-        return x + 1
-
-      model = keras.RandomForestModel(
-          preprocessing=processor,
-          advanced_arguments=keras.AdvancedArguments(
-              infer_prediction_signature=infer_prediction_signature))
-      model.fit(x=x_train, y=y_train)
-
-      if save_model:
-        # Fails if the model is not build before.
-        model.save(os.path.join(self.get_temp_dir(), "saved_model"))
-
-    experiment(infer_prediction_signature=False, save_model=False)
-
-    with self.assertRaises(ValueError):
-      experiment(infer_prediction_signature=False, save_model=True)
-
-    # Starting with tf2.7, Keras does not expends automatically rank 1 tensors.
-    keras_expends = False
-    try:
-
-      def versiontuple(v):
-        return tuple(map(int, (v.split("."))))
-
-      keras_expends = versiontuple(tf.__version__) < (2, 7)
-    except Exception:  # pylint: disable=broad-except
-      pass
-
-    if not keras_expends:
-      # Does not expect an exception.
-      experiment(infer_prediction_signature=True, save_model=False)
-    else:
-      with self.assertRaises(ValueError):
-        experiment(infer_prediction_signature=True, save_model=False)
-
   def test_get_all_models(self):
     print(keras.get_all_models())
 
@@ -1326,7 +1281,7 @@ def create_ds(feature_name):
     with self.assertRaises(ValueError):
       model.save(os.path.join(self.get_temp_dir(), "model"))
 
-  def test_training_adult_from_disk(self):
+  def test_training_adult_from_file(self):
     # Path to dataset.
     dataset_directory = os.path.join(test_data_path(), "dataset")
     train_path = os.path.join(dataset_directory, "adult_train.csv")
@@ -1361,7 +1316,7 @@ def test_training_adult_from_disk(self):
         "capital_gain", "capital_loss", "hours_per_week", "native_country"
     ])
 
-  def test_training_adult_from_disk_with_features(self):
+  def test_training_adult_from_file_with_features(self):
     # Path to dataset.
     dataset_directory = os.path.join(test_data_path(), "dataset")
     train_path = os.path.join(dataset_directory, "adult_train.csv")
diff --git a/tensorflow_decision_forests/tensorflow/BUILD b/tensorflow_decision_forests/tensorflow/BUILD
index 9f62119f..ba6b4662 100644
--- a/tensorflow_decision_forests/tensorflow/BUILD
+++ b/tensorflow_decision_forests/tensorflow/BUILD
@@ -34,18 +34,22 @@ cc_library(
 # Available engines for distributed training.
 cc_library(
     name = "distribution_engines",
-    deps = [
-        # Distributed training with TF Parameter Server.
-        #
-        # If not registered, the following error will be raised
-        # when using distributed training with the TF PS and TF-DF:
-        # "Unknown item TF_DIST in class pool".
-        #
-        # Currently, this distribution engine is not available in custom
-        # non-core c++ ops in TF OSS (i.e. when TF-DF is compiled as a shared library).
-        # Either use monolithic build for TF+TF-DF, or use the GRPC distribution strategy.
-        "//tensorflow_decision_forests/tensorflow/distribute:tf_distribution",
-    ],
+    deps = ["@ydf//yggdrasil_decision_forests/utils/distribute"] +
+           select({
+               "//tensorflow_decision_forests:disable_tf_ps_distribution_strategy": [],
+               "//conditions:default": [
+                   # Distributed training with TF Parameter Server.
+                   #
+                   # If not registered, the following error will be raised
+                   # when using distributed training with the TF PS and TF-DF:
+                   # "Unknown item TF_DIST in class pool".
+                   #
+                   # Currently, this distribution engine is not available in custom
+                   # non-core c++ ops in TF OSS (i.e. when TF-DF is compiled as a shared library).
+                   # Either use monolithic build for TF+TF-DF, or use the GRPC distribution strategy.
+                   "//tensorflow_decision_forests/tensorflow/distribute:tf_distribution",
+               ],
+           }),
     alwayslink = 1,
 )
 
@@ -71,14 +75,18 @@ py_library(
         "@org_tensorflow//tensorflow/python/distribute:distribute_lib",
         "@org_tensorflow//tensorflow/python/distribute:parameter_server_strategy_v2",
         "@org_tensorflow//tensorflow/python/distribute/coordinator:cluster_coordinator",
-        "@org_tensorflow//tensorflow/python/distribute/coordinator:coordinator_context",
-        "//tensorflow_decision_forests/tensorflow/distribute:api_py",  # Compatibility with TF Parameter Server for distribution.
"//tensorflow_decision_forests/tensorflow/distribute:tf_distribution_py_proto", "//tensorflow_decision_forests/tensorflow/ops/training:api_py", "@ydf//yggdrasil_decision_forests/dataset:data_spec_py_proto", "@ydf//yggdrasil_decision_forests/learner:abstract_learner_py_proto", "@ydf//yggdrasil_decision_forests/model:abstract_model_py_proto", - ], + ] + select({ + "//tensorflow_decision_forests:disable_tf_ps_distribution_strategy": [], + "//conditions:default": [ + "@org_tensorflow//tensorflow/python/distribute/coordinator:coordinator_context", + "//tensorflow_decision_forests/tensorflow/distribute:api_py", # Compatibility with TF Parameter Server for distribution. + ], + }), ) # Tests diff --git a/tensorflow_decision_forests/tensorflow/core.py b/tensorflow_decision_forests/tensorflow/core.py index 515c8aa2..8983c31b 100644 --- a/tensorflow_decision_forests/tensorflow/core.py +++ b/tensorflow_decision_forests/tensorflow/core.py @@ -31,14 +31,20 @@ from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import parameter_server_strategy_v2 -from tensorflow.python.distribute.coordinator import coordinator_context from tensorflow_decision_forests.tensorflow.distribute import tf_distribution_pb2 from tensorflow_decision_forests.tensorflow.ops.training import api as training_op -from tensorflow_decision_forests.tensorflow.distribute import api # pylint: disable=unused-import from yggdrasil_decision_forests.dataset import data_spec_pb2 from yggdrasil_decision_forests.learner import abstract_learner_pb2 from yggdrasil_decision_forests.model import abstract_model_pb2 +try: + from tensorflow_decision_forests.tensorflow.distribute import api as distributed_api # pytype: disable=import-error + from tensorflow.python.distribute.coordinator import coordinator_context # pytype: disable=import-error +except Exception as e: + distributed_api = None + coordinator_context = None + logging.warning("TF Parameter Server distributed training not available.") + class Semantic(enum.Enum): """Semantic (e.g. @@ -288,6 +294,14 @@ def extract_label(*columns): (integer). """ + if coordinator_context is None: + raise ValueError( + "The library was compile without Parameter Server distributed training " + " support i.e. tf_ps_distribution_strategy=0. Either recompile it with " + "tf_ps_distribution_strategy=1,disable distributed training, or use " + "the Grpc Server distribution strategy. Note: TF-DF OSS release it " + "currently compiled without PS support.") + # Not used for now. del context diff --git a/tensorflow_decision_forests/tensorflow/distribute/BUILD b/tensorflow_decision_forests/tensorflow/distribute/BUILD index b7448f51..4714565c 100644 --- a/tensorflow_decision_forests/tensorflow/distribute/BUILD +++ b/tensorflow_decision_forests/tensorflow/distribute/BUILD @@ -1,9 +1,10 @@ # Implementation of the Yggdrasil Distribute API using TensorFlow Distribution Strategies. # + load("@ydf//yggdrasil_decision_forests/utils:compile.bzl", "all_proto_library") load("//tensorflow_decision_forests/tensorflow:utils.bzl", "tf_custom_op_library_external") -load("@org_tensorflow//tensorflow:tensorflow.bzl","tf_cc_test") +load("@org_tensorflow//tensorflow:tensorflow.bzl","tf_cc_test", "cc_header_only_library") package( default_visibility = ["//visibility:public"], @@ -54,14 +55,21 @@ all_proto_library( # ======= # TODO(gbm): Split the functions exported and non-exported in tensorflow.so. 
-TF_CC_API_DEP = [
-    "@org_tensorflow//tensorflow/cc:cc_ops",
-    "@org_tensorflow//tensorflow/cc:client_session",
-    "@org_tensorflow//tensorflow/cc:ops",
-    "@org_tensorflow//tensorflow/cc:scope",
-    "@org_tensorflow//tensorflow/core:protos_all_cc",
-    "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_session",
-]
+
+cc_header_only_library(
+    name = "tf_ops_header_lib",
+    deps = [
+        "@org_tensorflow//tensorflow/cc:client_session",
+        "@org_tensorflow//tensorflow/cc:const_op",
+        "@org_tensorflow//tensorflow/cc:scope",
+        "@org_tensorflow//tensorflow/cc:cc_ops",
+        "@org_tensorflow//tensorflow/cc:ops",
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+        "@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_session",
+    ],
+)
+
+TF_CC_API_DEP = [":tf_ops_header_lib"]
 
 cc_library(
     name = "tf_distribution",
diff --git a/tensorflow_decision_forests/tensorflow/distribute/api.py b/tensorflow_decision_forests/tensorflow/distribute/api.py
index 6f340ca8..fb674bae 100644
--- a/tensorflow_decision_forests/tensorflow/distribute/api.py
+++ b/tensorflow_decision_forests/tensorflow/distribute/api.py
@@ -16,6 +16,7 @@
 from __future__ import division
 from __future__ import print_function
 
+
 import tensorflow as tf
 from tensorflow.python.framework import load_library
 from tensorflow.python.platform import resource_loader
diff --git a/tensorflow_decision_forests/tensorflow/ops/training/BUILD b/tensorflow_decision_forests/tensorflow/ops/training/BUILD
index 44331cf0..8dda2b0b 100644
--- a/tensorflow_decision_forests/tensorflow/ops/training/BUILD
+++ b/tensorflow_decision_forests/tensorflow/ops/training/BUILD
@@ -33,6 +33,8 @@ tf_custom_op_library_external(
 cc_library(
     name = "dataset_formats",
     deps = [
+        "@ydf//yggdrasil_decision_forests/dataset:csv_example_reader",
+        "@ydf//yggdrasil_decision_forests/dataset:tf_example_io_tfrecord",
         "@ydf//yggdrasil_decision_forests/learner/distributed_decision_tree/dataset_cache:dataset_cache_reader",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow_decision_forests/tensorflow/ops/training/feature_on_file.h b/tensorflow_decision_forests/tensorflow/ops/training/feature_on_file.h
index 6898e6fe..f59e1b99 100644
--- a/tensorflow_decision_forests/tensorflow/ops/training/feature_on_file.h
+++ b/tensorflow_decision_forests/tensorflow/ops/training/feature_on_file.h
@@ -31,6 +31,7 @@
 #ifndef TENSORFLOW_DECISION_FORESTS_TENSORFLOW_OPS_TRAINING_FEATURE_ON_FILE_H_
 #define TENSORFLOW_DECISION_FORESTS_TENSORFLOW_OPS_TRAINING_FEATURE_ON_FILE_H_
 
+#include "tensorflow/core/framework/device.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/platform/path.h"
@@ -205,7 +206,18 @@ class FeatureOnFileOp : public tensorflow::OpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_id", &resource_id_));
     dataset_already_on_disk_ = HasDoneFile(dataset_path_);
 
-    worker_idx_ = ctx->device()->parsed_name().task;
+
+    // TODO(gbm): Use the following code when tf2.7 is released.
+    // worker_idx_ = ctx->device()->parsed_name().task;
+
+    auto* device = dynamic_cast<tensorflow::Device*>(ctx->device());
+    if (device == nullptr) {
+      OP_REQUIRES_OK(ctx,
+                     tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
+                                        "Cannot find the worker idx"));
+    }
+    worker_idx_ = device->parsed_name().task;
+
     if (dataset_already_on_disk_) {
       LOG(INFO) << "Already existing dataset cache for worker #" << worker_idx_
                 << " on device " << ctx->device()->name();
diff --git a/tensorflow_decision_forests/tensorflow/ops/training/kernel_on_file.cc b/tensorflow_decision_forests/tensorflow/ops/training/kernel_on_file.cc
index ffb50021..b954c6cb 100644
--- a/tensorflow_decision_forests/tensorflow/ops/training/kernel_on_file.cc
+++ b/tensorflow_decision_forests/tensorflow/ops/training/kernel_on_file.cc
@@ -148,7 +148,7 @@ class SimpleMLModelTrainerOnFile : public tensorflow::OpKernel {
 #endif
 
     LOG(INFO) << "Train model";
-    std::optional<std::string> valid_dataset_path;
+    absl::optional<std::string> valid_dataset_path;
     if (!valid_dataset_path_.empty()) {
       valid_dataset_path = valid_dataset_path_;
     }
diff --git a/third_party/yggdrasil_decision_forests/workspace.bzl b/third_party/yggdrasil_decision_forests/workspace.bzl
index 644e82be..e4984b15 100644
--- a/third_party/yggdrasil_decision_forests/workspace.bzl
+++ b/third_party/yggdrasil_decision_forests/workspace.bzl
@@ -1,14 +1,16 @@
 """Yggdrasil Decision Forests project."""
 
-def deps():
-    #http_archive(
-    #    name = "ydf",
-    #    urls = ["https://github.com/google/yggdrasil-decision-forests/archive/refs/heads/main.zip"],
-    #    strip_prefix = "yggdrasil-decision-forests-main",
-    #)
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
 
-    # You can also clone the YDF repository manually.
-    native.local_repository(
+def deps():
+    http_archive(
         name = "ydf",
-        path = "../yggdrasil_decision_forests_bazel",
+        urls = ["https://github.com/google/yggdrasil-decision-forests/archive/refs/heads/main.zip"],
+        strip_prefix = "yggdrasil-decision-forests-main",
     )
+
+    # # You can also clone the YDF repository manually.
+    # native.local_repository(
+    #     name = "ydf",
+    #     path = "../yggdrasil_decision_forests_bazel",
+    # )
diff --git a/tools/build_pip_package.sh b/tools/build_pip_package.sh
index 11ec5e2b..b8988a32 100755
--- a/tools/build_pip_package.sh
+++ b/tools/build_pip_package.sh
@@ -64,8 +64,15 @@ function assemble_files() {
   cp ${SRCBIN}/tensorflow/ops/inference/op.py ${SRCPK}/tensorflow_decision_forests/tensorflow/ops/inference/
   cp ${SRCBIN}/tensorflow/ops/training/training.so ${SRCPK}/tensorflow_decision_forests/tensorflow/ops/training/
   cp ${SRCBIN}/tensorflow/ops/training/op.py ${SRCPK}/tensorflow_decision_forests/tensorflow/ops/training/
+  cp ${SRCBIN}/tensorflow/distribute/distribute.so ${SRCPK}/tensorflow_decision_forests/tensorflow/distribute/
   cp ${SRCBIN}/keras/wrappers.py ${SRCPK}/tensorflow_decision_forests/keras/
 
+  # TFDF's proto wrappers.
+  cp ${SRCBIN}/tensorflow/distribute/tf_distribution_pb2.py ${SRCPK}/tensorflow_decision_forests/tensorflow/distribute/
+
+  # Distribution server binaries.
+  cp ${SRCBIN}/keras/grpc_worker_main ${SRCPK}/tensorflow_decision_forests/keras/
+
   # YDF's proto wrappers.
   YDFSRCBIN="bazel-bin/external/ydf/yggdrasil_decision_forests"
   mkdir -p ${SRCPK}/yggdrasil_decision_forests
   find -name \*.py -exec cp --parents -prv {} ${SRCPK}/yggdrasil_decision_forests \;
   popd
 
-  # Add __init__.py to all exported Yggdrqasil sub-directories.
+  # Add __init__.py to all exported Yggdrasil sub-directories.
   find ${SRCPK}/yggdrasil_decision_forests -type d -exec touch {}/__init__.py \;
 }
diff --git a/tools/test_bazel.sh b/tools/test_bazel.sh
index 84227af1..8800e393 100755
--- a/tools/test_bazel.sh
+++ b/tools/test_bazel.sh
@@ -28,32 +28,20 @@ BAZEL=bazel-3.7.2
 # TENSORFLOW_BAZELRC="${HOME}/git/tf_bazelrc"
 # Alternatively, download bazelrc:
-# .bazelrc of TF 2.6.0. This value should match the TF version in the "WORKSPACE" file.
+# .bazelrc of TF v2.7.0-rc1. This value should match the TF version in the "WORKSPACE" file.
 TENSORFLOW_BAZELRC="tensorflow_bazelrc"
-wget https://raw.githubusercontent.com/tensorflow/tensorflow/r2.6/.bazelrc -O ${TENSORFLOW_BAZELRC}
+wget https://raw.githubusercontent.com/tensorflow/tensorflow/v2.7.0-rc1/.bazelrc -O ${TENSORFLOW_BAZELRC}
 
 # copybara:strip_begin
 # First follow the instruction: go/tf-rbe-guide
 # copybara:strip_end
 
-FLAGS="--config=linux --config=rbe_cpu_linux --config=tensorflow_testing_rbe_linux --config=rbe_linux_py3"
-
-# Uncomment the following line to generate a sharable pip package.
-# You will also need to install the dockers described in:
-# https://github.com/tensorflow/custom-op
-#
-# For >=TF2.7.0
-# FLAGS="${FLAGS} --crosstool_top=@ubuntu18.04-gcc7_manylinux2010-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain"
-
+FLAGS="--config=linux --config=rbe_cpu_linux --config=tensorflow_testing_rbe_linux --config=rbe_linux_py3 --define tf_ps_distribution_strategy=0"
 
 ${BAZEL} --bazelrc=${TENSORFLOW_BAZELRC} build \
   //tensorflow_decision_forests/...:all \
   ${FLAGS}
 
-# TEMPORARY: Tests do not pass with Cloud RBE because the wrong version of
-# pandas is installed.
-# ${BAZEL} --bazelrc=${TENSORFLOW_BAZELRC} test \
-#   //tensorflow_decision_forests/...:all \
-#   ${FLAGS}
+${BAZEL} --bazelrc=${TENSORFLOW_BAZELRC} test \
+  //tensorflow_decision_forests/...:all \
+  ${FLAGS}
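
For context, the GRPC distribute strategy shipped by this change expects `:grpc_worker_main` instances to already be running before `fit_on_dataset_path` is called. Below is a minimal sketch of launching them locally, mirroring the `subprocess` logic of `keras_distributed_test.py`; the binary path and ports are hypothetical:

```python
import subprocess

# Hypothetical ports; in practice, run one ":grpc_worker_main" instance per
# worker machine and list its address in the deployment config.
ports = [2001, 2002, 2003, 2004]
workers = [
    subprocess.Popen([
        "tensorflow_decision_forests/keras/grpc_worker_main",
        "--alsologtostderr",
        "--port",
        str(port),
    ])
    for port in ports
]

# ... configure "socket_addresses" with these ports and train the model ...

# Stop the workers once training is finished.
for worker in workers:
  worker.terminate()
```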