Commit 72f4b53

Dev icenet-ai#252: trying to solve some kind of deadlocking?
JimCircadian committed May 23, 2024
1 parent 39ccc6c commit 72f4b53
Showing 3 changed files with 16 additions and 14 deletions.
6 changes: 2 additions & 4 deletions icenet/model/networks/base.py
@@ -28,7 +28,7 @@ def __init__(self,
                              dataset.identifier,
                              seed))
 
-        self._callbacks = list() if callbacks_default is None else self.get_callbacks()
+        self._callbacks = list() if callbacks_default is None else callbacks_default
         self._callbacks += callbacks_additional if callbacks_additional is not None else []
         self._dataset = dataset
         self._run_name = run_name
@@ -47,11 +47,9 @@ def _attempt_seed_setup(self):
         random.seed(self._seed)
 
     def add_callback(self, callback):
+        logging.debug("Adding callback {}".format(callback))
         self._callbacks.append(callback)
 
-    def get_callbacks(self):
-        return list()
-
     @abstractmethod
     def train(self,
               dataset: object,
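
Note on the change above: the constructor previously discarded a caller-supplied callbacks_default and re-derived the list via self.get_callbacks(), which this commit removes; the supplied list is now used directly. A minimal sketch of the corrected wiring (the surrounding class is truncated on this page, so the class body here is illustrative):

import logging

class Network:  # stand-in for the truncated base class
    def __init__(self, callbacks_default=None, callbacks_additional=None):
        # use the caller-supplied defaults; the old code ignored them
        # and substituted the result of self.get_callbacks()
        self._callbacks = list() if callbacks_default is None else callbacks_default
        self._callbacks += callbacks_additional if callbacks_additional is not None else []

    def add_callback(self, callback):
        logging.debug("Adding callback {}".format(callback))
        self._callbacks.append(callback)
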
22 changes: 12 additions & 10 deletions icenet/model/networks/tensorflow.py
@@ -8,6 +8,7 @@
 import numpy as np
 import pandas as pd
 import tensorflow as tf
+import horovod.tensorflow.keras as hvd
 
 from tensorflow.keras.callbacks import \
     EarlyStopping, ModelCheckpoint, LearningRateScheduler
@@ -150,10 +151,8 @@ def __init__(self,
                  **kwargs):
         super().__init__(*args, **kwargs)
 
-        import horovod.tensorflow.keras as hvd
-        hvd.init()
-
         if device_type in ("XPU", "GPU"):
+            logging.debug("Setting up {} devices".format(device_type))
             devices = tf.config.list_physical_devices(device_type)
             logging.info("{} count is {}".format(device_type, len(devices)))
 
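Note on the change above: hvd.init() no longer runs inside this constructor. Horovod must be initialised in every process before workers configure devices or enter their first collective operation; if ranks initialise at different times, the early ranks can block indefinitely, which is consistent with the deadlocking this commit targets. A sketch of the conventional ordering from the Horovod documentation (not icenet code):

import horovod.tensorflow.keras as hvd
import tensorflow as tf

hvd.init()  # first, in every process, before any device setup

gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    # pin each worker to one local GPU so ranks do not contend for device 0
    tf.config.set_visible_devices(gpus[hvd.local_rank()], "GPU")
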
@@ -166,7 +165,6 @@ def __init__(self,
         self.add_callback(
             hvd.callbacks.BroadcastGlobalVariablesCallback(0)
         )
-        self._horovod = hvd
 
     def train(self,
               epochs: int,
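
Note on the change above: the BroadcastGlobalVariablesCallback(0) registration stays, ensuring every worker starts from rank 0's initial weights, while the self._horovod attribute is dropped in favour of the module-level hvd import. A common companion pattern from the Horovod documentation (not part of this diff) gates checkpoint writing to rank 0 for the same reason:

import horovod.tensorflow.keras as hvd
import tensorflow as tf

callbacks = [
    # synchronise initial variable state from rank 0 to all workers
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
]
if hvd.rank() == 0:
    # only one process writes checkpoints, avoiding clobbered files
    callbacks.append(tf.keras.callbacks.ModelCheckpoint("checkpoint-{epoch}.h5"))
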
@@ -181,25 +179,29 @@ def train(self,
                                              self.run_name, self.seed))
 
         # TODO: this is totally assuming the structure of model_creator :(
+        logging.debug("Calling {} to create our model".format(model_creator))
         network = model_creator(**model_creator_kwargs,
-                                custom_optimizer=self._horovod.DistributedOptimizer(Adam(learning_rate)),
+                                custom_optimizer=hvd.DistributedOptimizer(Adam(model_creator_kwargs["learning_rate"])),
                                 experimental_run_tf_function=False)
 
+        logging.debug("Created model for rank {}".format(hvd.rank()))
+
         if self._pre_load_path and os.path.exists(self._pre_load_path):
             logging.warning("Automagically loading network weights from {}".format(
                 self._pre_load_path))
             network.load_weights(self._pre_load_path)
 
-        network.summary()
+        if model_creator_kwargs["horovod"].rank() == 0:
+            network.summary()
 
+        logging.debug("Calling training loop")
         model_history = network.fit(
             train_dataset,
             epochs=epochs,
-            verbose=1 if self._horovod.rank() == 0 and self._verbose else 0,
+            verbose=1 if hvd.rank() == 0 and self._verbose else 0,
             callbacks=self.callbacks,
             validation_data=validation_dataset,
             max_queue_size=self._data_queue_size,
-            steps_per_epoch=self.dataset.counts["train"] // (self.dataset.batch_size * self._horovod.size()),
+            steps_per_epoch=self.dataset.counts["train"] // (self.dataset.batch_size * hvd.size()),
         )
 
         if save:
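
Note on the changes above: three pieces of data-parallel bookkeeping are involved. The optimizer is wrapped in hvd.DistributedOptimizer so gradients are averaged across workers each step, console output is restricted to rank 0, and steps_per_epoch is divided by hvd.size() because each worker consumes its own shard of the data. A self-contained sketch (counts and batch_size are illustrative values here, not taken from the dataset):

import horovod.tensorflow.keras as hvd
from tensorflow.keras.optimizers import Adam

hvd.init()

# gradients are allreduce-averaged across all workers each step
optimizer = hvd.DistributedOptimizer(Adam(learning_rate=1e-4))

counts = {"train": 10000}   # illustrative sample count
batch_size = 4              # illustrative per-worker batch size
# each worker sees batch_size samples per step across size() workers,
# so one epoch is counts // (batch_size * size) steps per worker
steps_per_epoch = counts["train"] // (batch_size * hvd.size())

verbose = 1 if hvd.rank() == 0 else 0  # progress bar only on rank 0
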
@@ -222,7 +224,7 @@ def unet_batchnorm(input_shape: object,
                    filter_size: float = 3,
                    n_filters_factor: float = 1,
                    n_forecast_days: int = 1,
-                   legacy_rounding: bool = False) -> object:
+                   legacy_rounding: bool = True) -> object:
     """
 
     :param input_shape:
2 changes: 2 additions & 0 deletions icenet/model/train.py
@@ -3,6 +3,8 @@
 import time
 
 import tensorflow as tf
+import horovod.tensorflow.keras as hvd
+hvd.init()
 
 from icenet.data.dataset import IceNetDataSet, MergedIceNetDataSet
 from icenet.model.cli import TrainingArgParser
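
Note on the change above: initialising Horovod at import time in the entry-point module means every spawned process calls hvd.init() before any TensorFlow device setup or graph construction happens elsewhere in the codebase. Under the standard launcher each rank executes the same script, e.g. (a generic horovodrun invocation, not a documented icenet command):

horovodrun -np 4 python train.py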
