Skip to content

Commit

Permalink
Dev icenet-ai#252: adding further amendments to stabilise MPI based training with horovod
Browse files Browse the repository at this point in the history
  • Loading branch information
JimCircadian committed May 24, 2024
1 parent 10f53cd commit babef09
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 9 deletions.
4 changes: 4 additions & 0 deletions icenet/model/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def add_tensorflow(self):
return self

def add_horovod(self):
self.add_argument("--no-horovod",
dest="horovod",
default=True,
action="store_false")
self.add_argument("--device-type",
default=None,
help="Choose a device type to detect, if using")
Expand Down
10 changes: 10 additions & 0 deletions icenet/model/handlers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@

def init_wandb(cli_args):
if wandb_available:
if cli_args.horovod:
try:
import horovod.tensorflow.keras as hvd
except ModuleNotFoundError:
raise RuntimeError("We're running horovod jobs without the module, eh?")

if hvd.rank() > 0:
logging.info("Not initialising wandb for rank {}".format(hvd.rank()))
return

logging.warning("Initialising WANDB for this run at user request")

run = wandb.init(
Expand Down
12 changes: 7 additions & 5 deletions icenet/model/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

class BaseNetwork:
def __init__(self,
dataset: object,
run_name: object,
dataset: object,
callbacks_additional: list = None,
callbacks_default: list = None,
network_folder: object = None,
Expand All @@ -25,10 +25,10 @@ def __init__(self,

self._model_path = os.path.join(
self._network_folder, "{}.model_{}.{}".format(run_name,
dataset.identifier,
seed))
dataset.identifier,
seed))

self._callbacks = list() if callbacks_default is None else callbacks_default
self._callbacks = self.get_default_callbacks() if callbacks_default is None else callbacks_default
self._callbacks += callbacks_additional if callbacks_additional is not None else []
self._dataset = dataset
self._run_name = run_name
Expand All @@ -50,9 +50,11 @@ def add_callback(self, callback):
logging.debug("Adding callback {}".format(callback))
self._callbacks.append(callback)

# Supplies the initial callback list that the constructor uses when no
# callbacks_default argument is given. This implementation returns a new,
# empty list on every call.
# NOTE(review): presumably meant to be overridden by concrete network
# classes that want built-in callbacks — confirm against subclasses.
def get_default_callbacks(self):
return list()

@abstractmethod
def train(self,
dataset: object,
epochs: int,
model_creator: callable,
train_dataset: object,
Expand Down
9 changes: 6 additions & 3 deletions icenet/model/networks/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ def train(self,
with open(history_path, 'w') as fh:
pd.DataFrame(model_history.history).to_json(fh)

def get_callbacks(self):
def get_default_callbacks(self):
callbacks_list = list()

if self._checkpoint_monitor is not None:
logging.info("Adding ModelCheckpoint callback")
callbacks_list.append(
ModelCheckpoint(filepath=self._weights_path,
monitor=self._checkpoint_monitor,
Expand All @@ -115,6 +116,7 @@ def get_callbacks(self):
save_best_only=True))

if self._early_stopping_patience > 0:
logging.info("Adding EarlyStopping callback")
callbacks_list.append(
EarlyStopping(monitor=self._checkpoint_monitor,
mode=self._checkpoint_mode,
Expand All @@ -123,6 +125,7 @@ def get_callbacks(self):
baseline=None))

if self._lr_decay[0] > 0:
logging.info("ADding LearningRateScheduler callback")
lr_decay = -0.1 * np.log(self._lr_decay[0])

callbacks_list.append(
Expand Down Expand Up @@ -180,11 +183,11 @@ def train(self,

logging.debug("Calling training loop")
model_history = network.fit(
train_dataset,
train_dataset.repeat(),
epochs=epochs,
verbose=1 if hvd.rank() == 0 and self._verbose else 0,
callbacks=self.callbacks,
validation_data=validation_dataset,
validation_data=validation_dataset.repeat(),
max_queue_size=self._data_queue_size,
steps_per_epoch=self.dataset.counts["train"] // (self.dataset.batch_size * hvd.size()),
)
Expand Down
6 changes: 5 additions & 1 deletion icenet/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import time

import tensorflow as tf
import horovod.tensorflow.keras as hvd

try:
import horovod.tensorflow.keras as hvd
except ModuleNotFoundError:
pass

from icenet.data.dataset import IceNetDataSet, MergedIceNetDataSet
from icenet.model.cli import TrainingArgParser
Expand Down

0 comments on commit babef09

Please sign in to comment.