Merge pull request google#1696 from lukaszkaiser:lk
PiperOrigin-RevId: 396062436
copybara-github committed Sep 11, 2021
2 parents a24ed37 + b6f5b06 commit 6151599
Showing 5 changed files with 75 additions and 26 deletions.
6 changes: 3 additions & 3 deletions trax/data/inputs.py
@@ -609,7 +609,7 @@ def UniformlySeek(name=None, host_id=None, n_hosts=None, dataset_size=None): #
logging.error(
'No dataset size given to Uniformly seek, assuming: %d', dataset_size)
assert name
-  host_id = jax.host_id() if host_id is None else host_id
+  host_id = jax.process_index() if host_id is None else host_id
n_hosts = n_hosts or jax.host_count()
each_host = int(dataset_size / n_hosts)
def _f(generator):
@@ -1095,7 +1095,7 @@ def count_and_skip(generator, name):
def save_data_counters(output_dir, host_id=None):
"""Checkpoint data counters."""
global data_counters
-  host_id = jax.host_id() if host_id is None else host_id
+  host_id = jax.process_index() if host_id is None else host_id
fname = os.path.join(output_dir, 'data_counters%d.pkl' % host_id)
with tf.io.gfile.GFile(fname, 'wb') as f:
pickle.dump(data_counters, f)
@@ -1104,7 +1104,7 @@ def save_data_counters(output_dir, host_id=None):
def load_data_counters(output_dir, host_id=None):
"""Checkpoint data counters."""
global data_counters
-  host_id = jax.host_id() if host_id is None else host_id
+  host_id = jax.process_index() if host_id is None else host_id
fname = os.path.join(output_dir, 'data_counters%d.pkl' % host_id)
if not tf.io.gfile.exists(fname):
logging.info('Did not load data counters as %s does not exist.', fname)
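The three hunks above are one mechanical rename: jax.host_id() is the older, deprecated spelling of jax.process_index(), which returns the index of the current process in a multi-process (multi-host) run. A minimal sketch of the pattern, using a hypothetical helper name rather than anything from this file:

import jax

def resolve_host_id(host_id=None):
  # jax.process_index() is the successor of the deprecated jax.host_id();
  # both return 0 in an ordinary single-process run.
  return jax.process_index() if host_id is None else host_id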
53 changes: 32 additions & 21 deletions trax/data/tf_inputs.py
@@ -29,7 +29,6 @@
import jax
import numpy as np
import scipy
-import t5.data
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
@@ -47,6 +46,18 @@
_MAX_SKIP_EXAMPLES = 1e5


+def t5_data():
+  """Get the T5 data module if available."""
+  module = None
+  try:
+    import t5.data  # pylint: disable=g-import-not-at-top
+    module = t5.data
+  except AttributeError as e:
+    logging.error('pip install t5')
+    raise e
+  return module


def no_preprocess(dataset, training):
del training
return dataset
@@ -315,7 +326,7 @@ def TFDS( # pylint: disable=invalid-name
"""
data_dir = download_and_prepare(dataset_name, data_dir)

-  host_id = jax.host_id() if host_id is None else host_id
+  host_id = jax.process_index() if host_id is None else host_id
n_hosts = n_hosts or jax.host_count()
if n_hosts > 1:
subsplit = (host_id / n_hosts, (host_id + 1) / n_hosts)
@@ -589,8 +600,8 @@ def _get_vocab(vocab_type='subword', vocab_file=None, vocab_dir=None,
return text_encoder.BertEncoder(path, do_lower_case=True)

assert vocab_type == 'sentencepiece'
-  return t5.data.SentencePieceVocabulary(sentencepiece_model_file=path,
-                                         extra_ids=extra_ids)
+  return t5_data().SentencePieceVocabulary(sentencepiece_model_file=path,
+                                           extra_ids=extra_ids)


# Makes the function accessible in gin configs, even with all args denylisted.
@@ -849,7 +860,7 @@ def concat_and_add_mask(features, targets):

def sentencepiece_tokenize(stream, spm_path=None, extra_ids=0):
"""Sentencepiece tokenization."""
-  spm_path = spm_path or t5.data.DEFAULT_SPM_PATH
+  spm_path = spm_path or t5_data().DEFAULT_SPM_PATH
vocab_file = os.path.basename(spm_path)
vocab_dir = os.path.dirname(spm_path)
vocab = _get_vocab(vocab_type='sentencepiece',
@@ -898,7 +909,7 @@ def spc_tokenize(tokenizer, features, targets):
return features, features['targets']

if tokenization == 'spc':
-    spm_path = spm_path or t5.data.DEFAULT_SPM_PATH
+    spm_path = spm_path or t5_data().DEFAULT_SPM_PATH
with tf.compat.v1.gfile.GFile(spm_path, 'rb') as f:
spc_model = f.read()
tokenizer = tf_text.SentencepieceTokenizer(model=spc_model)
@@ -923,18 +934,18 @@ def c4_bare_preprocess_fn(dataset,
sequence_length=None):
"""Returns a dataset that contains 'inputs' and 'targets' from C4."""
# Set target key to be equal to the text content.
-  dataset = t5.data.preprocessors.rekey(
+  dataset = t5_data().preprocessors.rekey(
dataset, key_map={
'targets': 'text',
'inputs': None
})

# Vocabulary for tokenization.
extra_ids = 0
-  vocab = t5.data.SentencePieceVocabulary(
-      sentencepiece_model_file=spm_path or t5.data.DEFAULT_SPM_PATH,
+  vocab = t5_data().SentencePieceVocabulary(
+      sentencepiece_model_file=spm_path or t5_data().DEFAULT_SPM_PATH,
extra_ids=extra_ids)
-  feature = t5.data.Feature(vocab)
+  feature = t5_data().Feature(vocab)
output_features = {'targets': feature, 'inputs': feature}

# Tokenize the targets.
@@ -963,7 +974,7 @@ def encode_string_features_fn(features):
num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Preprocess the tokens - the exact preprocessors are set via gin.
-  dataset = t5.data.preprocessors.unsupervised(
+  dataset = t5_data().preprocessors.unsupervised(
dataset, sequence_length=sequence_length, output_features=output_features)

# Add EOS.
@@ -1140,14 +1151,14 @@ def print_examples(x):

# Vocabulary for tokenization.
extra_ids = 0
-  vocab = t5.data.SentencePieceVocabulary(
-      sentencepiece_model_file=spm_path or t5.data.DEFAULT_SPM_PATH,
+  vocab = t5_data().SentencePieceVocabulary(
+      sentencepiece_model_file=spm_path or t5_data().DEFAULT_SPM_PATH,
extra_ids=extra_ids)
-  feature = t5.data.Feature(vocab)
+  feature = t5_data().Feature(vocab)
output_features = {'targets': feature, 'inputs': feature}

# Tokenize the inputs and targets.
-  dataset = t5.data.preprocessors.tokenize(
+  dataset = t5_data().preprocessors.tokenize(
dataset, output_features, copy_pretokenized=copy_pretokenized)

# Apply the token-preprocessors.
@@ -1196,7 +1207,7 @@ def get_t5_preprocessor_by_name(name=None, fn_kwargs=None):
"""

assert name is not None
-  f = getattr(t5.data.preprocessors, name)
+  f = getattr(t5_data().preprocessors, name)
if fn_kwargs is not None:
f = functools.partial(f, **fn_kwargs)
return lambda ds, unused_training: f(ds)
@@ -1383,7 +1394,7 @@ def BertNextSentencePredictionInputs(dataset_name, # pylint: disable=invalid-na
dataset_name,
data_dir=data_dir,
tfds_preprocess_fn=functools.partial(
-          t5.data.preprocessors.next_sentence_prediction,
+          t5_data().preprocessors.next_sentence_prediction,
text_key=text_key,
label_sentences=True,
buffer_size=shuffle_size),
@@ -1409,7 +1420,7 @@ def CorpusToRandomChunks(dataset_name, num_tokens=512, train=True): # pylint: d
return TFDS(
dataset_name,
tfds_preprocess_fn=functools.partial(
-          t5.data.preprocessors.random_split_text,
+          t5_data().preprocessors.random_split_text,
max_words_per_segment=num_tokens),
train=train,
keys=['text'])
@@ -1644,7 +1655,7 @@ def _t5_glue_data_split_no_token(benchmark_id):
"""Returns a GLUE data split prepared with the standard T5 preprocessor."""
benchmark, split = _t5_glue_benchmark_and_split(benchmark_id)
dataset = tfds.load(name=f'glue/{benchmark}', split=split)
-  processed_dataset = t5.data.preprocessors.glue(  # pylint: disable=g-long-lambda
+  processed_dataset = t5_data().preprocessors.glue(  # pylint: disable=g-long-lambda
dataset,
benchmark_name=benchmark,
label_names=_GLUE_LABELS[benchmark])
@@ -1668,9 +1679,9 @@ def _t5_glue_data_split(benchmark_id):
dataset = tfds.load(name=f'glue/{benchmark}', split=split)
processed_dataset = generic_text_dataset_preprocess_fn(
dataset,
-      spm_path=t5.data.DEFAULT_SPM_PATH,
+      spm_path=t5_data().DEFAULT_SPM_PATH,
text_preprocess_fns=[
-          lambda ds, training: t5.data.preprocessors.glue(  # pylint: disable=g-long-lambda
+          lambda ds, training: t5_data().preprocessors.glue(  # pylint: disable=g-long-lambda
ds,
benchmark_name=benchmark,
label_names=_GLUE_LABELS[benchmark])
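All direct uses of t5.data in this file now go through the new t5_data() helper, which turns t5 into a call-time dependency instead of an import-time one: importing trax.data.tf_inputs no longer requires the t5 package unless a T5-based pipeline is actually used. A standalone sketch of the same lazy-import idea, with a hypothetical name and catching ImportError for a missing package:

import logging

def lazy_t5_data():
  # Defer the heavy, optional t5 dependency until it is actually needed.
  try:
    import t5.data  # pylint: disable=g-import-not-at-top
    return t5.data
  except ImportError:
    logging.error('The t5 package is required here: pip install t5')
    raise

Call sites then read, for example, lazy_t5_data().SentencePieceVocabulary(...) in place of a module-level t5.data reference.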
2 changes: 1 addition & 1 deletion trax/supervised/training.py
@@ -1292,7 +1292,7 @@ def init_host_and_devices(n_devices=None, random_seed=None):
random_seed: The passed in value of random_seed or a computed default.
"""
if fastmath.is_backend(fastmath.Backend.JAX):
-    host_id = jax.host_id()
+    host_id = jax.process_index()
host_count = jax.host_count()
else:
host_id = 0
33 changes: 32 additions & 1 deletion trax/trainer.py
@@ -14,7 +14,9 @@
# limitations under the License.

"""Trax trainer."""
+import atexit
import datetime
+import functools
import os

from absl import app
@@ -23,6 +25,7 @@

import gin
import jax
+from jax.lib import xla_extension as xc
import tensorflow.compat.v2 as tf
from trax import fastmath
from trax import trainer_flags # pylint: disable=unused-import
@@ -95,7 +98,6 @@ def _output_dir_or_default():

# TODO(afrozm): Share between trainer.py and rl_trainer.py
def _jax_and_tf_configure_for_devices(): # pylint: disable=missing-function-docstring
-  jax.config.enable_omnistaging()
if FLAGS.use_tpu:
jax.config.update('jax_platform_name', 'tpu')
jax.config.update('jax_xla_backend', FLAGS.jax_xla_backend)
@@ -143,13 +145,42 @@ def tf_init_tpu(worker='', protocol=None):
return '/job:worker'


+def _make_jax_gpu_cluster(host_id, server_ip, n_hosts, server_port=5005):
+  """Make JAX GPU Cluster."""
+
+  addr = f'{server_ip}:{server_port}'
+  if host_id == 0:
+    logging.info('starting service on %s', addr)
+    service = xc.get_distributed_runtime_service(addr, n_hosts)
+    # We add an explicit call to shutdown the service via atexit as Python
+    # interpreter may not call the service destructor on process termination.
+    atexit.register(service.shutdown)
+
+  logging.info('connecting to service on %s', addr)
+  dist_client = xc.get_distributed_runtime_client(addr, host_id)
+  dist_client.connect()
+  atexit.register(dist_client.shutdown)
+
+  # register dist gpu backend
+  factory = functools.partial(jax.lib.xla_client.make_gpu_client,
+                              dist_client, host_id)
+  jax.lib.xla_bridge.register_backend_factory('gpu', factory, priority=300)


def main(_):
logging.set_verbosity(FLAGS.log_level)

_tf_setup_from_flags()
_gin_parse_configs()
_jax_and_tf_configure_for_devices()

+  # Create a JAX GPU cluster if using JAX and given a chief IP.
+  if fastmath.is_backend(Backend.JAX) and FLAGS.gpu_cluster_chief_ip:
+    _make_jax_gpu_cluster(FLAGS.gpu_cluster_host_id,
+                          FLAGS.gpu_cluster_chief_ip,
+                          FLAGS.gpu_cluster_n_hosts,
+                          FLAGS.gpu_cluster_port)

if FLAGS.disable_jit:
fastmath.disable_jit()
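The new _make_jax_gpu_cluster() path above bootstraps multi-host GPU training through JAX's internal xla_extension distributed-runtime service and client, driven by the gpu_cluster_* flags added in trax/trainer_flags.py. For comparison only: newer JAX releases expose the same wiring through a public API, and a sketch using jax.distributed.initialize (availability and exact behavior depend on the installed JAX version; the address and counts are example values) looks like:

import jax

# Run once per host, before any JAX computation. The arguments correspond to
# the chief address, the number of hosts, and this host's id.
jax.distributed.initialize(
    coordinator_address='10.0.0.1:5005',
    num_processes=2,
    process_id=0)

print('process', jax.process_index(), 'of', jax.process_count())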

7 changes: 7 additions & 0 deletions trax/trainer_flags.py
@@ -50,6 +50,13 @@
flags.DEFINE_string('data_dir', None, 'Path to the directory with data.')
flags.DEFINE_integer('log_level', logging.INFO, 'Log level.')

+# JAX/XLA GPU cluster flags.
+flags.DEFINE_string('gpu_cluster_chief_ip', '', 'IP of GPU cluster chief.')
+flags.DEFINE_integer('gpu_cluster_n_hosts', 1,
+                     'Number of hosts in GPU cluster.')
+flags.DEFINE_integer('gpu_cluster_host_id', 0, 'Host id inside GPU cluster.')
+flags.DEFINE_integer('gpu_cluster_port', 5005, 'Port to use in GPU cluster.')
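These flags feed _make_jax_gpu_cluster() in trax/trainer.py; the empty default for gpu_cluster_chief_ip keeps single-host runs on the existing code path, because the cluster is only created when that flag is non-empty. A small, self-contained absl sketch of that gating (hypothetical script, illustrative flag values):

from absl import app, flags

# A two-host run might pass, on host 0:
#   --gpu_cluster_chief_ip=10.0.0.1 --gpu_cluster_n_hosts=2 --gpu_cluster_host_id=0
# and on host 1 the same chief IP with --gpu_cluster_host_id=1.
flags.DEFINE_string('gpu_cluster_chief_ip', '', 'IP of GPU cluster chief.')
flags.DEFINE_integer('gpu_cluster_n_hosts', 1, 'Number of hosts in GPU cluster.')
flags.DEFINE_integer('gpu_cluster_host_id', 0, 'Host id inside GPU cluster.')
flags.DEFINE_integer('gpu_cluster_port', 5005, 'Port to use in GPU cluster.')

FLAGS = flags.FLAGS

def main(_):
  # The empty-string default is falsy, so single-host runs skip cluster setup.
  if FLAGS.gpu_cluster_chief_ip:
    print('host %d of %d: joining chief at %s:%d' %
          (FLAGS.gpu_cluster_host_id, FLAGS.gpu_cluster_n_hosts,
           FLAGS.gpu_cluster_chief_ip, FLAGS.gpu_cluster_port))
  else:
    print('no chief IP given: running single-host')

if __name__ == '__main__':
  app.run(main)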

# TensorFlow Flags
flags.DEFINE_bool('enable_eager_execution',
True,
