Skip to content

Commit

Permalink
[Distributed] fix TypeError of multi learning_rate input.
Browse files Browse the repository at this point in the history
Signed-off-by: 泊霆 <[email protected]>
  • Loading branch information
Mesilenceki committed Apr 1, 2024
1 parent 6dae552 commit 33ade7b
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tensorflow/python/distribute/hvd_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,9 @@ def wraps_optimizer(cls):
HvdOptimizer
'''
class HvdOptimizer(cls, optimizer.Optimizer):
def __init__(self, *args, **kwargs):
kwargs["learning_rate"] = kwargs.get("learning_rate", 0.001) *\
HvdContext.get().world_size
super(HvdOptimizer, self).__init__(*args, **kwargs)
def __init__(self, learning_rate=0.001, *args, **kwargs):
learning_rate = learning_rate * HvdContext.get().world_size
super(HvdOptimizer, self).__init__(learning_rate, *args, **kwargs)

def compute_gradients(self, loss, **kwargs):
loss = hvd.allreduce(loss, op=hvd.Sum)
Expand Down Expand Up @@ -1449,4 +1448,4 @@ def export(export_dir_base,
as_text=as_text,
clear_devices=clear_devices,
strip_default_attrs=strip_default_attrs,
modes=[mode])
modes=[mode])

0 comments on commit 33ade7b

Please sign in to comment.