diff --git a/tensorflow/python/distribute/hvd_strategy.py b/tensorflow/python/distribute/hvd_strategy.py index 8a3ae9c3f43..977ac4a4bea 100644 --- a/tensorflow/python/distribute/hvd_strategy.py +++ b/tensorflow/python/distribute/hvd_strategy.py @@ -388,10 +388,9 @@ def wraps_optimizer(cls): HvdOptimizer ''' class HvdOptimizer(cls, optimizer.Optimizer): - def __init__(self, *args, **kwargs): - kwargs["learning_rate"] = kwargs.get("learning_rate", 0.001) *\ - HvdContext.get().world_size - super(HvdOptimizer, self).__init__(*args, **kwargs) + def __init__(self, learning_rate=0.001, *args, **kwargs): + learning_rate = learning_rate * HvdContext.get().world_size + super(HvdOptimizer, self).__init__(learning_rate, *args, **kwargs) def compute_gradients(self, loss, **kwargs): loss = hvd.allreduce(loss, op=hvd.Sum) @@ -1449,4 +1448,4 @@ def export(export_dir_base, as_text=as_text, clear_devices=clear_devices, strip_default_attrs=strip_default_attrs, - modes=[mode]) \ No newline at end of file + modes=[mode])