Merge pull request #496 from yahoo/leewyang_gpu_fix
detect TF version w/o importing
leewyang authored Jan 21, 2020
2 parents e989e53 + 6ca75e0 commit 28c03b0
Showing 6 changed files with 33 additions and 17 deletions.
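The heart of the change: each module now reads the installed TensorFlow version from package metadata via `pkg_resources`, instead of importing `tensorflow` just to inspect `tf.__version__`. Importing TensorFlow can initialize GPU state before `CUDA_VISIBLE_DEVICES` is set, so the import is deferred until after GPU allocation. A minimal sketch of the version-check pattern used throughout this diff:

```python
import pkg_resources
from packaging import version

# Reads the version from package metadata; this does NOT import tensorflow,
# so no GPU/CUDA initialization happens as a side effect.
TF_VERSION = pkg_resources.get_distribution('tensorflow').version

if version.parse(TF_VERSION) >= version.parse("2.0.0"):
    print("TensorFlow 2.x detected")
else:
    print("TensorFlow 1.x detected")
```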
8 changes: 6 additions & 2 deletions tensorflowonspark/TFNode.py
@@ -16,12 +16,14 @@

import getpass
import logging
import pkg_resources

from packaging import version
from six.moves.queue import Empty
from . import compat, marker

logger = logging.getLogger(__name__)
TF_VERSION = pkg_resources.get_distribution('tensorflow').version


def hdfs_path(ctx, path):
@@ -79,11 +81,10 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
A tuple of (cluster_spec, server)
"""
import os
import tensorflow as tf
import time
from . import gpu_info

if version.parse(tf.__version__) >= version.parse("2.0.0"):
if version.parse(TF_VERSION) >= version.parse("2.0.0"):
raise Exception("DEPRECATED: Use higher-level APIs like `tf.keras` or `tf.estimator`")

logging.info("{0}: ======== {1}:{2} ========".format(ctx.worker_num, ctx.job_name, ctx.task_index))
@@ -115,6 +116,9 @@ def start_cluster_server(ctx, num_gpus=1, rdma=False):
# Set GPU device to use for TensorFlow
os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use

# Import tensorflow after gpu allocation
import tensorflow as tf

# Create a cluster from the parameter server and worker hosts.
cluster = tf.train.ClusterSpec(cluster_spec)

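Note the ordering above: `start_cluster_server` reserves GPUs and sets `CUDA_VISIBLE_DEVICES` first, and only then imports TensorFlow, so the TF runtime binds to just the devices reserved for this worker. A standalone sketch of that pattern (assuming a host where `nvidia-smi` is available):

```python
import os
from tensorflowonspark import gpu_info

# Reserve free GPU(s) and restrict visibility *before* TensorFlow loads.
gpus_to_use = gpu_info.get_gpus(1)
os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use

# Deferred import: TensorFlow now only sees the reserved GPU(s).
import tensorflow as tf
```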
13 changes: 6 additions & 7 deletions tensorflowonspark/TFSparkNode.py
@@ -12,24 +12,26 @@
import logging
import multiprocessing
import os
import pkg_resources
import platform
import socket
import subprocess
import sys
import uuid
import time
import traceback
from packaging import version
from threading import Thread

from . import TFManager
from . import TFNode
from . import compat
from . import gpu_info
from . import marker
from . import reservation
from . import util

logger = logging.getLogger(__name__)
TF_VERSION = pkg_resources.get_distribution('tensorflow').version


class TFNodeContext:
@@ -137,15 +139,12 @@ def run(fn, tf_args, cluster_meta, tensorboard, log_dir, queues, background):
A nodeRDD.mapPartitions() function.
"""
def _mapfn(iter):
import tensorflow as tf
from packaging import version

# Note: consuming the input iterator helps Pyspark re-use this worker,
for i in iter:
executor_id = i

# check that there are enough available GPUs (if using tensorflow-gpu) before committing reservation on this node
if compat.is_gpu_available():
if gpu_info.is_gpu_available():
num_gpus = tf_args.num_gpus if 'num_gpus' in tf_args else 1
gpus_to_use = gpu_info.get_gpus(num_gpus)

@@ -227,7 +226,7 @@ def _mapfn(iter):
raise Exception("Unable to find 'tensorboard' in: {}".format(search_path))

# launch tensorboard
if version.parse(tf.__version__) >= version.parse('2.0.0'):
if version.parse(TF_VERSION) >= version.parse('2.0.0'):
tb_proc = subprocess.Popen([pypath, tb_path, "--reload_multifile=True", "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
else:
tb_proc = subprocess.Popen([pypath, tb_path, "--logdir=%s" % logdir, "--port=%d" % tb_port], env=os.environ)
@@ -296,7 +295,7 @@ def _mapfn(iter):
os.environ['TF_CONFIG'] = tf_config

# reserve GPU(s) again, just before launching TF process (in case situation has changed)
if compat.is_gpu_available():
if gpu_info.is_gpu_available():
# compute my index relative to other nodes on the same host (for GPU allocation)
my_addr = cluster_spec[job_name][task_index]
my_host = my_addr.split(':')[0]
9 changes: 9 additions & 0 deletions tensorflowonspark/gpu_info.py
@@ -42,6 +42,15 @@ def _get_gpu():
return gpu


def is_gpu_available():
"""Determine if GPUs are available on the host"""
try:
subprocess.check_output(["nvidia-smi", "--list-gpus"])
return True
except Exception:
return False


def get_gpus(num_gpu=1, worker_index=-1):
"""Get list of free GPUs according to nvidia-smi.
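The new `gpu_info.is_gpu_available()` probes `nvidia-smi` directly, so GPU detection (like the version check) no longer requires TensorFlow to be loaded; it simply returns False on hosts where the command is missing or fails. A small usage sketch:

```python
from tensorflowonspark import gpu_info

if gpu_info.is_gpu_available():
    # nvidia-smi responded, so a free device can be reserved for this process
    print("GPUs to use: {}".format(gpu_info.get_gpus(1)))
else:
    print("No GPUs detected (or nvidia-smi unavailable); running on CPU")
```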
14 changes: 9 additions & 5 deletions tensorflowonspark/pipeline.py
@@ -23,14 +23,15 @@
import argparse
import copy
import logging
import pkg_resources
import sys
import tensorflow as tf

from . import TFCluster, util
from packaging import version


logger = logging.getLogger(__name__)
TF_VERSION = pkg_resources.get_distribution('tensorflow').version


# TensorFlowOnSpark Params
@@ -370,7 +371,7 @@ def __init__(self, train_fn, tf_args, export_fn=None):
self.train_fn = train_fn
self.args = Namespace(tf_args)

master_node = 'chief' if version.parse(tf.__version__) >= version.parse("2.0.0") else None
master_node = 'chief' if version.parse(TF_VERSION) >= version.parse("2.0.0") else None
self._setDefault(input_mapping={},
cluster_size=1,
num_ps=0,
@@ -413,7 +414,7 @@ def _fit(self, dataset):
cluster.shutdown(grace_secs=self.getGraceSecs())

if self.export_fn:
if version.parse(tf.__version__) < version.parse("2.0.0"):
if version.parse(TF_VERSION) < version.parse("2.0.0"):
# For TF1.x, run export function, if provided
assert local_args.export_dir, "Export function requires --export_dir to be set"
logging.info("Exporting saved_model (via export_fn) to: {}".format(local_args.export_dir))
@@ -480,7 +481,7 @@ def _transform(self, dataset):

tf_args = self.args.argv if self.args.argv else local_args

_run_model = _run_model_tf1 if version.parse(tf.__version__) < version.parse("2.0.0") else _run_model_tf2
_run_model = _run_model_tf1 if version.parse(TF_VERSION) < version.parse("2.0.0") else _run_model_tf2
rdd_out = dataset.select(input_cols).rdd.mapPartitions(lambda it: _run_model(it, local_args, tf_args))

# convert to a DataFrame-friendly format
@@ -516,7 +517,7 @@ def _run_model_tf1(iterator, args, tf_args):
output_tensor_names = [tensor for tensor, col in sorted(args.output_mapping.items())]

# if using a signature_def_key, get input/output tensor info from the requested signature
if version.parse(tf.__version__) < version.parse("2.0.0") and args.signature_def_key:
if version.parse(TF_VERSION) < version.parse("2.0.0") and args.signature_def_key:
assert args.export_dir, "Inferencing with signature_def_key requires --export_dir argument"
logging.info("===== loading meta_graph_def for tag_set ({0}) from saved_model: {1}".format(args.tag_set, args.export_dir))
meta_graph_def = get_meta_graph_def(args.export_dir, args.tag_set)
@@ -534,6 +535,7 @@ def _run_model_tf1(iterator, args, tf_args):
sess = global_sess
else:
# otherwise, create new session and load graph from disk
import tensorflow as tf
tf.reset_default_graph()
sess = tf.Session(graph=tf.get_default_graph())
if args.export_dir:
@@ -584,6 +586,8 @@ def _run_model_tf2(iterator, args, tf_args):
"""mapPartitions function (for TF2.x) to run single-node inferencing from a saved_model, using input/output mappings."""
single_node_env(tf_args)

import tensorflow as tf

logger.info("===== input_mapping: {}".format(args.input_mapping))
logger.info("===== output_mapping: {}".format(args.output_mapping))
input_tensor_names = [tensor for col, tensor in sorted(args.input_mapping.items())]
4 changes: 2 additions & 2 deletions tensorflowonspark/util.py
@@ -13,7 +13,7 @@
import subprocess
import errno
from socket import error as socket_error
from . import compat, gpu_info
from . import gpu_info

logger = logging.getLogger(__name__)

@@ -28,7 +28,7 @@ def single_node_env(num_gpus=1, worker_index=-1, nodes=[]):
os.environ['CLASSPATH'] = classpath + os.pathsep + hadoop_classpath
os.environ['TFOS_CLASSPATH_UPDATED'] = '1'

if compat.is_gpu_available() and num_gpus > 0:
if gpu_info.is_gpu_available() and num_gpus > 0:
# reserve GPU(s), if requested
if worker_index >= 0 and len(nodes) > 0:
# compute my index relative to other nodes on the same host, if known
2 changes: 1 addition & 1 deletion test/README.md
@@ -16,7 +16,7 @@ Note: the tests that use Spark will require a local Spark Standalone cluster (vs
export SPARK_HOME=<path_to_Spark>
export TFoS_HOME=<path_to_TFoS>
export PYTHONPATH=${SPARK_HOME}/python
export SPARK_CLASSPATH=<path_to_tensorflow-hadoop-*.jar>
export SPARK_CLASSPATH=${TFoS_HOME}/lib/tensorflow-hadoop-1.0-SNAPSHOT.jar
```
2a. Run script to automatically start Spark Standalone cluster, run all tests, and shutdown the cluster, OR
```
