Add dist ckpt support for regular optimizers (#7749) (#8293)
* Add dist ckpt support for regular optimizers

* [tutorial] fixed missing RIR scripts file. (#8257)

* fix imports

* imports fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ci imports fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert asr notebook

* revert asr notebook

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
Co-authored-by: mikolajblaz <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: Xuesong Yang <[email protected]>
Co-authored-by: Dmytro Pykhtar <[email protected]>
Co-authored-by: dimapihtar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Pablo Garay <[email protected]>
7 people authored and pablo-garay committed Mar 19, 2024
1 parent e596a5d commit 597de1b
Showing 3 changed files with 32 additions and 15 deletions.
19 changes: 17 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -60,12 +60,14 @@
 from nemo.collections.nlp.parts import utils_funcs
 from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
 from nemo.core.optim import MainParamsOptimizerWrapper
+from nemo.core.optim.optimizers import init_optimizer_states
 from nemo.utils import AppState, logging
 from nemo.utils.get_rank import is_global_rank_zero
 from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank
 
 try:
     from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+    from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
 
     HAVE_APEX = True
 
@@ -259,7 +261,7 @@ def optimizer_sharded_state_dict(self):
             ValueError: If a parameter ID does not match any model sharded parameter.
         """
 
-        optimizer = self.lightning_module.optimizers(use_pl_optimizer=False)  # MainParamsOptimizerWrapper
+        optimizer = self.lightning_module.optimizers(use_pl_optimizer=False)
 
         model_sharded_state_dict = self.lightning_module.sharded_state_dict()
 
@@ -268,8 +270,21 @@ def optimizer_sharded_state_dict(self):
             key: value for key, value in model_sharded_state_dict.items() if not key.endswith('_extra_state')
         }
 
-        if not isinstance(optimizer, MainParamsOptimizerWrapper):
+        if isinstance(optimizer, MegatronDistributedFusedAdam):
             return optimizer.sharded_state_dict(model_sharded_state_dict)
+        elif not isinstance(optimizer, MainParamsOptimizerWrapper):
+            # Regular optimizer, e.g. Adam or FusedAdam
+            init_optimizer_states(optimizer)
+            optimizer_state_dict = optimizer.state_dict()
+            id_to_sharded_param_map = get_param_id_to_sharded_param_map(
+                model_sharded_state_dict=model_sharded_state_dict,
+                optim_params_iter=itertools.chain.from_iterable(g['params'] for g in optimizer.param_groups),
+            )
+            optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)
+            return optimizer_state_dict
+
+        # MainParamsOptimizerWrapper
+        init_optimizer_states(optimizer.optimizer)
 
         optimizer_state_dict = optimizer.state_dict()
 
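For context on the new elif branch: a torch optimizer's state_dict() keys per-parameter state by an integer id that enumerates the parameters in param-group order, which is exactly the order reproduced by the itertools.chain.from_iterable(...) iterator passed to get_param_id_to_sharded_param_map. The snippet below is a minimal, standalone illustration of that convention in plain PyTorch; the toy model and names are made up for the example and none of this is NeMo or Megatron-Core code.

# Standalone sketch (plain PyTorch, hypothetical toy model): shows that
# optimizer.state_dict()['state'] is keyed by integer ids following the chained
# param-group order, which is what the id-to-sharded-param mapping above walks.
import itertools

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One step so Adam has per-parameter state to inspect.
model(torch.randn(3, 4)).sum().backward()
optimizer.step()

# Rebuild the id -> parameter correspondence by chaining the param groups in order.
chained = itertools.chain.from_iterable(g['params'] for g in optimizer.param_groups)
id_to_param = dict(enumerate(chained))

# Resolve parameter names via object identity (the real helper matches optimizer
# params against the model's sharded state rather than against names).
param_to_name = {p: name for name, p in model.named_parameters()}
for param_id, state in optimizer.state_dict()['state'].items():
    print(param_id, param_to_name[id_to_param[param_id]], sorted(state))
    # e.g. 0 0.weight ['exp_avg', 'exp_avg_sq', 'step']

With that correspondence established, optim_state_to_sharding_state can attach each per-parameter state tensor to the matching sharded model parameter before the distributed checkpoint is written.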
13 changes: 0 additions & 13 deletions nemo/core/optim/optimizer_with_main_params.py
@@ -312,9 +312,6 @@ def __init__(
             self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
 
-        # init exp_avg and exp_avg_sq before loading optimizer state, needed for dist checkpointing
-        self._init_opt_state()
-
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())
@@ -543,13 +540,3 @@ def _set_defaults(self, value):
         self.optimizer.defaults = value
 
     defaults = property(_get_defaults, _set_defaults)
-
-    def _init_opt_state(self):
-        """
-        Initialize the optimizer state with zero tensors for 'exp_avg' and 'exp_avg_sq' of each parameter.
-        """
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                if len(self.optimizer.state[p]) == 0:
-                    self.optimizer.state[p]['exp_avg'] = torch.zeros_like(p.data)
-                    self.optimizer.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
15 changes: 15 additions & 0 deletions nemo/core/optim/optimizers.py
@@ -200,3 +200,18 @@ def get_optimizer(name: str, **kwargs: Optional[Dict[str, Any]]) -> Optimizer:
     optimizer = AVAILABLE_OPTIMIZERS[name]
     optimizer = partial(optimizer, **kwargs)
     return optimizer
+
+
+def init_optimizer_states(optimizer: Optimizer):
+    adam_nondist_optims = (optim.Adam, optim.AdamW)
+    if HAVE_APEX:
+        adam_nondist_optims += (FusedAdam,)
+    if isinstance(optimizer, adam_nondist_optims):
+        for group in optimizer.param_groups:
+            for p in group['params']:
+                state = optimizer.state[p]
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    state['exp_avg_sq'] = torch.zeros_like(p.data, memory_format=torch.preserve_format)
+                    if group.get('amsgrad'):
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
