forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_loop.py
157 lines (131 loc) · 5.51 KB
/
main_loop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training and evaluation loops for an experiment."""
import time
from typing import Any, Mapping, Text, Type, Union
from absl import app
from absl import flags
from absl import logging
import jax
import numpy as np
from byol import byol_experiment
from byol import eval_experiment
from byol.configs import byol as byol_config
from byol.configs import eval as eval_config
flags.DEFINE_string('experiment_mode',
'pretrain', 'The experiment, pretrain or linear-eval')
flags.DEFINE_string('worker_mode', 'train', 'The mode, train or eval')
flags.DEFINE_string('worker_tpu_driver', '', 'The tpu driver to use')
flags.DEFINE_integer('pretrain_epochs', 1000, 'Number of pre-training epochs')
flags.DEFINE_integer('batch_size', 4096, 'Total batch size')
flags.DEFINE_string('checkpoint_root', '/tmp/byol',
'The directory to save checkpoints to.')
flags.DEFINE_integer('log_tensors_interval', 60, 'Log tensors every n seconds.')
FLAGS = flags.FLAGS
Experiment = Union[
Type[byol_experiment.ByolExperiment],
Type[eval_experiment.EvalExperiment]]
def train_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
"""The main training loop.
This loop periodically saves a checkpoint to be evaluated in the eval_loop.
Args:
experiment_class: the constructor for the experiment (either byol_experiment
or eval_experiment).
config: the experiment config.
"""
experiment = experiment_class(**config)
rng = jax.random.PRNGKey(0)
step = 0
host_id = jax.host_id()
last_logging = time.time()
if config['checkpointing_config']['use_checkpointing']:
checkpoint_data = experiment.load_checkpoint()
if checkpoint_data is None:
step = 0
else:
step, rng = checkpoint_data
local_device_count = jax.local_device_count()
while step < config['max_steps']:
step_rng, rng = tuple(jax.random.split(rng))
# Broadcast the random seeds across the devices
step_rng_device = jax.random.split(step_rng, num=jax.device_count())
step_rng_device = step_rng_device[
host_id * local_device_count:(host_id + 1) * local_device_count]
step_device = np.broadcast_to(step, [local_device_count])
# Perform a training step and get scalars to log.
scalars = experiment.step(global_step=step_device, rng=step_rng_device)
# Checkpointing and logging.
if config['checkpointing_config']['use_checkpointing']:
experiment.save_checkpoint(step, rng)
current_time = time.time()
if current_time - last_logging > FLAGS.log_tensors_interval:
logging.info('Step %d: %s', step, scalars)
last_logging = current_time
step += 1
logging.info('Saving final checkpoint')
logging.info('Step %d: %s', step, scalars)
experiment.save_checkpoint(step, rng)
def eval_loop(experiment_class: Experiment, config: Mapping[Text, Any]):
"""The main evaluation loop.
This loop periodically loads a checkpoint and evaluates its performance on the
test set, by calling experiment.evaluate.
Args:
experiment_class: the constructor for the experiment (either byol_experiment
or eval_experiment).
config: the experiment config.
"""
experiment = experiment_class(**config)
last_evaluated_step = -1
while True:
checkpoint_data = experiment.load_checkpoint()
if checkpoint_data is None:
logging.info('No checkpoint found. Waiting for 10s.')
time.sleep(10)
continue
step, _ = checkpoint_data
if step <= last_evaluated_step:
logging.info('Checkpoint at step %d already evaluated, waiting.', step)
time.sleep(10)
continue
host_id = jax.host_id()
local_device_count = jax.local_device_count()
step_device = np.broadcast_to(step, [local_device_count])
scalars = experiment.evaluate(global_step=step_device)
if host_id == 0: # Only perform logging in one host.
logging.info('Evaluation at step %d: %s', step, scalars)
last_evaluated_step = step
if last_evaluated_step >= config['max_steps']:
return
def main(_):
if FLAGS.worker_tpu_driver:
jax.config.update('jax_xla_backend', 'tpu_driver')
jax.config.update('jax_backend_target', FLAGS.worker_tpu_driver)
logging.info('Backend: %s %r', FLAGS.worker_tpu_driver, jax.devices())
if FLAGS.experiment_mode == 'pretrain':
experiment_class = byol_experiment.ByolExperiment
config = byol_config.get_config(FLAGS.pretrain_epochs, FLAGS.batch_size)
elif FLAGS.experiment_mode == 'linear-eval':
experiment_class = eval_experiment.EvalExperiment
config = eval_config.get_config(f'{FLAGS.checkpoint_root}/pretrain.pkl',
FLAGS.batch_size)
else:
raise ValueError(f'Unknown experiment mode: {FLAGS.experiment_mode}')
config['checkpointing_config']['checkpoint_dir'] = FLAGS.checkpoint_root
if FLAGS.worker_mode == 'train':
train_loop(experiment_class, config)
elif FLAGS.worker_mode == 'eval':
eval_loop(experiment_class, config)
if __name__ == '__main__':
app.run(main)