Skip to content

Commit

Permalink
Dev icenet-ai#252: testing with horovod
Browse files Browse the repository at this point in the history
  • Loading branch information
JimCircadian committed May 23, 2024
1 parent 5f56a1f commit 39ccc6c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 19 deletions.
6 changes: 1 addition & 5 deletions icenet/model/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,9 @@ def add_tensorflow(self):
return self

def add_horovod(self):
self.add_argument("-hv",
"--horovod",
default=False,
action="store_true")
self.add_argument("--device-type",
default=None,
help="Choose a device type for distribution, if using")
help="Choose a device type to detect, if using")
return self

def add_wandb(self):
Expand Down
25 changes: 13 additions & 12 deletions icenet/model/networks/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def train(self,
if save:
logging.info("Saving network to: {}".format(self._weights_path))
network.save_weights(self._weights_path)
logging.info("Saving model to: {}".format(self.model_path))
save_model(network, self.model_path)

with open(history_path, 'w') as fh:
Expand Down Expand Up @@ -145,19 +146,22 @@ def get_callbacks(self):
class HorovodNetwork(TensorflowNetwork):
def __init__(self,
*args,
device_type="XPU",
device_type: str = None,
**kwargs):
super().__init__(*args, **kwargs)

import horovod.tensorflow.keras as hvd
hvd.init()

gpus = tf.config.list_physical_devices(device_type)
logging.info("{} count is {}".format(device_type, len(gpus)))
if device_type in ("XPU", "GPU"):
devices = tf.config.list_physical_devices(device_type)
logging.info("{} count is {}".format(device_type, len(devices)))

for dev in devices:
tf.config.experimental.set_memory_growth(dev, True)

for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'XPU')
if devices:
tf.config.experimental.set_visible_devices(devices[hvd.local_rank()], device_type)

self.add_callback(
hvd.callbacks.BroadcastGlobalVariablesCallback(0)
Expand All @@ -166,12 +170,9 @@ def __init__(self,

def train(self,
epochs: int,
learning_rate: float,
loss: object,
metrics: object,
model_creator: callable,
train_dataset: object,
model_creator_args: dict = None,
model_creator_kwargs: dict = None,
save: bool = True,
validation_dataset: object = None):

Expand All @@ -180,7 +181,7 @@ def train(self,
self.run_name, self.seed))

# TODO: this is totally assuming the structure of model_creator :(
network = model_creator(**model_creator_args,
network = model_creator(**model_creator_kwargs,
custom_optimizer=self._horovod.DistributedOptimizer(Adam(learning_rate)),
experimental_run_tf_function=False)

Expand Down
18 changes: 16 additions & 2 deletions icenet/model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,22 @@ def get_datasets(args):

def horovod_main():
args = TrainingArgParser().add_unet().add_horovod().add_wandb().parse_args()
dataset = get_datasets()
network = HorovodNetwork()
dataset = get_datasets(args)
network = HorovodNetwork(dataset,
args.run_name,
checkpoint_mode=args.checkpoint_mode,
checkpoint_monitor=args.checkpoint_monitor,
device_type=args.device_type,
early_stopping_patience=args.early_stopping,
data_queue_size=args.max_queue_size,
lr_decay=(
args.lr_10e_decay_fac,
args.lr_decay_start,
args.lr_decay_end,
),
pre_load_path=args.preload,
seed=args.seed,
verbose=args.verbose)
execute_tf_training(args, dataset, network)


Expand Down

0 comments on commit 39ccc6c

Please sign in to comment.