Skip to content

Commit

Permalink
Eagerly accumulate embedding grads into fp32 buffer (#6958)
Browse files Browse the repository at this point in the history
Signed-off-by: Tim Moon <[email protected]>
  • Loading branch information
timmoon10 authored Aug 2, 2023
1 parent d5d600d commit 2baef81
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions nemo/core/optim/distributed_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,31 +77,37 @@ def __init__(self, params, disable_distributed_parameters=False, **kwargs):
distopt_param_groups = param_groups
dtype = kwargs['dtype'] if 'dtype' in kwargs else torch.float32
grad_sync_dtype = kwargs['grad_sync_dtype'] if 'grad_sync_dtype' in kwargs else dtype
needs_fp32_optimizer = any(
getattr(param, '_with_fp32_optimizer', False)
for param in itertools.chain.from_iterable(param_group['params'] for param_group in param_groups)
)
if (dtype != torch.float32 or grad_sync_dtype != torch.float32) and needs_fp32_optimizer:
needs_fp32_optimizer = dtype != torch.float32 or grad_sync_dtype != torch.float32
if needs_fp32_optimizer:
needs_fp32_optimizer = any(
any(getattr(param, '_with_fp32_optimizer', False) for param in param_group['params'])
for param_group in param_groups
)
if needs_fp32_optimizer:

# Find params that require explicit FP32 optimizer
distopt_param_groups = []
fp32_param_groups = []
self._fp32_optim_main_params = collections.OrderedDict()
for param_group in param_groups:
distopt_param_group = {key: val for key, val in param_group.items() if key != 'params'}
distopt_param_group = param_group.copy()
distopt_param_group['params'] = []
fp32_param_group = {key: val for key, val in param_group.items() if key != 'params'}
fp32_param_group = param_group.copy()
fp32_param_group['params'] = []
for model_param in param_group['params']:
if getattr(model_param, '_with_fp32_optimizer', False):
main_param = model_param.detach().clone().float()
model_param.main_grad = main_param.grad
fp32_param_group['params'].append(main_param)
self._fp32_optim_main_params[model_param] = main_param
else:
distopt_param_group['params'].append(model_param)
distopt_param_groups.append(distopt_param_group)
fp32_param_groups.append(fp32_param_group)

# Add callback hook so grads accumulate into FP32 buffer
self._fp32_register_post_backward_hooks()

# Construct explicit FP32 optimizer
adamw_kwargs = {}
for name in ('lr', 'betas', 'eps', 'weight_decay', 'amsgrad'):
Expand All @@ -113,6 +119,30 @@ def __init__(self, params, disable_distributed_parameters=False, **kwargs):
# Construct distributed optimizer
super().__init__(distopt_param_groups, **kwargs)

def _fp32_register_post_backward_hooks(self):
"""Attach hooks for FP32 gradients"""

# Helper function to avoid issues with late binding closures
def make_post_backward_hook(param):
def post_backward_hook(*unused):
self._fp32_optim_grad_sync_needed = True
if hasattr(param, 'main_grad'):
with torch.no_grad():
if param.grad is not None:
param.main_grad += param.grad
param.grad = None

return post_backward_hook

# Construct hooks and register with params
self._fp32_grad_accs = []
for param in self._fp32_optim_main_params.keys():
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
hook = make_post_backward_hook(param)
grad_acc.register_hook(hook)
self._fp32_grad_accs.append(grad_acc)

def _make_post_backward_hook(self, param, param_group_id, param_id):
def hook(*unused):
if getattr(param, '_pre_forward_hook_is_enabled', False):
Expand Down

0 comments on commit 2baef81

Please sign in to comment.