From b901138416c01ae5c5a5cafb2dace55bf0e07b2a Mon Sep 17 00:00:00 2001
From: Anna Shors <71393111+ashors1@users.noreply.github.com>
Date: Tue, 23 Jul 2024 10:21:57 -0700
Subject: [PATCH 1/3] [NeMo-UX] Set async_save from strategy rather than
 ModelCheckpoint (#9800)

* set async_save from strategy to make checkpoint_io more robust

Signed-off-by: ashors1

* fix 2.0 test

Signed-off-by: ashors1

---------

Signed-off-by: ashors1
---
 examples/llm/megatron_gpt_pretraining.py  |  1 -
 .../pytorch/callbacks/model_checkpoint.py |  8 +++---
 nemo/lightning/pytorch/strategies.py      | 25 +++++++++++--------
 3 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/examples/llm/megatron_gpt_pretraining.py b/examples/llm/megatron_gpt_pretraining.py
index 180561d03bac..d3d049e4296e 100644
--- a/examples/llm/megatron_gpt_pretraining.py
+++ b/examples/llm/megatron_gpt_pretraining.py
@@ -65,7 +65,6 @@ def get_args():
     checkpoint_callback = ModelCheckpoint(
         every_n_train_steps=5000,
         enable_nemo_ckpt_io=False,
-        async_save=False,
     )
 
     callbacks = [checkpoint_callback]
diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py
index ed8ac25185f3..eee3850dfb37 100644
--- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -51,14 +51,12 @@ def __init__(
         save_best_model: bool = False,
         save_on_train_epoch_end: Optional[bool] = False,  # Save after training, not after validation
         enable_nemo_ckpt_io: bool = True,
-        async_save: bool = False,
         try_restore_best_ckpt: bool = True,
         **kwargs,
     ):
         self.save_best_model = save_best_model
         self.previous_best_path = ""
         self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
-        self.async_save = async_save
         # Checkpoints which removal is deferred until async save is done.
         # Each element of `deferred_ckpts_to_remove` is a growing list
         # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
@@ -221,7 +219,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         super().load_state_dict(state_dict)
         self._remove_invalid_entries_from_topk()
 
-    def setup(self, *args, **kwargs) -> None:
+    def setup(self, trainer, *args, **kwargs) -> None:
         from nemo.utils.get_rank import is_global_rank_zero
 
         if is_global_rank_zero():
@@ -230,7 +228,9 @@ def setup(self, *args, **kwargs) -> None:
         # Ensure that all ranks continue with unfinished checkpoints removed
         if torch.distributed.is_initialized():
             torch.distributed.barrier()
-        super().setup(*args, **kwargs)
+
+        self.async_save = getattr(trainer.strategy, "async_save", False)
+        super().setup(trainer, *args, **kwargs)
 
     def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index 2219324f6b67..9adfb7801f2f 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -105,6 +105,7 @@ def __init__(
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
         save_ckpt_format='torch_dist',
+        ckpt_async_save=False,
         ckpt_torch_dist_multiproc=None,  ## TODO(ashors): put elsewhere?
         ckpt_assume_constant_structure=False,
         ckpt_parallel_save=True,
@@ -142,6 +143,7 @@ def __init__(
         self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
 
         self.save_ckpt_format = save_ckpt_format
+        self.async_save = ckpt_async_save
         self.torch_dist_multiproc = ckpt_torch_dist_multiproc
         self.assume_constant_structure = ckpt_assume_constant_structure
         self.parallel_save = ckpt_parallel_save
@@ -253,6 +255,16 @@ def setup(self, trainer: pl.Trainer) -> None:
             assert self.model is not None
             _sync_module_states(self.model)
 
+        ## add AsyncFinalizerCallback if using async
+        if self.async_save:
+            have_async_callback = False
+            for callback in self.trainer.callbacks:
+                if isinstance(callback, AsyncFinalizerCallback):
+                    have_async_callback = True
+                    break
+            if not have_async_callback:
+                self.trainer.callbacks.append(AsyncFinalizerCallback())
+
     @override
     def setup_distributed(self) -> None:
         self._setup_parallel_ranks()
@@ -577,11 +589,9 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
     @override
     def checkpoint_io(self) -> CheckpointIO:
         if self._checkpoint_io is None:
-            checkpoint_callback = self.trainer.checkpoint_callback
-            async_save = getattr(checkpoint_callback, "async_save", False)
             self._checkpoint_io = MegatronCheckpointIO(
                 save_ckpt_format=self.save_ckpt_format,
-                async_save=async_save,
+                async_save=self.async_save,
                 torch_dist_multiproc=self.torch_dist_multiproc,
                 assume_constant_structure=self.assume_constant_structure,
                 parallel_save=self.parallel_save,
@@ -589,15 +599,8 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
                 parallel_load=self.parallel_load,
                 load_directly_on_device=self.load_directly_on_device,
             )
-            if async_save:
+            if self.async_save:
                 self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io)
-                have_async_callback = False
-                for callback in self.trainer.callbacks:
-                    if isinstance(callback, AsyncFinalizerCallback):
-                        have_async_callback = True
-                        break
-                if not have_async_callback:
-                    self.trainer.callbacks.append(AsyncFinalizerCallback())
         elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
             self._checkpoint_io.checkpoint_io = MegatronCheckpointIO()
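
With this patch applied, async checkpoint saving is configured once on the strategy, and the ModelCheckpoint callback discovers the flag via trainer.strategy during setup(). The sketch below illustrates the resulting call pattern; it is not part of the patch, and the import alias plus the extra constructor arguments are assumptions based on the NeMo 2.0 pretraining example.

    # Hypothetical usage sketch; only `ckpt_async_save` (added) and the
    # removed ModelCheckpoint(async_save=...) argument come from the patch.
    from nemo import lightning as nl

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,  # placeholder parallelism setting
        ckpt_async_save=True,  # replaces ModelCheckpoint(async_save=True)
    )
    checkpoint_callback = nl.ModelCheckpoint(
        every_n_train_steps=5000,
        enable_nemo_ckpt_io=False,
    )
    # The callback reads trainer.strategy.async_save in setup(), and the
    # strategy appends an AsyncFinalizerCallback automatically when enabled.
    trainer = nl.Trainer(strategy=strategy, callbacks=[checkpoint_callback])
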
From 6ff5bce31eecefa42d49b2c81fc57b1e1533fa7f Mon Sep 17 00:00:00 2001
From: Somshubra Majumdar
Date: Tue, 23 Jul 2024 19:34:42 -0400
Subject: [PATCH 2/3] Fix hf hub for 0.24+ (#9806)

* Update Huggingface Hub support

Signed-off-by: smajumdar

* Update hf hub

Signed-off-by: smajumdar

* Update hf hub

Signed-off-by: smajumdar

* Apply isort and black reformatting

Signed-off-by: titu1994

---------

Signed-off-by: smajumdar
Signed-off-by: Somshubra Majumdar
Signed-off-by: titu1994
---
 nemo/core/classes/mixins/hf_io_mixin.py | 88 +++++++++----------------
 requirements/requirements.txt           |  2 +-
 tests/core/test_save_restore.py         | 73 ++++++++++++--------
 3 files changed, 80 insertions(+), 83 deletions(-)

diff --git a/nemo/core/classes/mixins/hf_io_mixin.py b/nemo/core/classes/mixins/hf_io_mixin.py
index b101cbabe749..543d6c6fccda 100644
--- a/nemo/core/classes/mixins/hf_io_mixin.py
+++ b/nemo/core/classes/mixins/hf_io_mixin.py
@@ -14,9 +14,9 @@
 
 from abc import ABC
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
-from huggingface_hub import HfApi, ModelCard, ModelCardData, ModelFilter
+from huggingface_hub import HfApi, ModelCard, ModelCardData
 from huggingface_hub import get_token as get_hf_token
 from huggingface_hub.hf_api import ModelInfo
 from huggingface_hub.utils import SoftTemporaryDirectory
@@ -35,31 +35,35 @@ class HuggingFaceFileIO(ABC):
     """
 
     @classmethod
-    def get_hf_model_filter(cls) -> ModelFilter:
+    def get_hf_model_filter(cls) -> Dict[str, Any]:
         """
         Generates a filter for HuggingFace models.
 
         Additionally includes default values of some metadata about results returned by the Hub.
 
         Metadata:
             resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False.
             limit_results: Optional int, limits the number of results returned.
 
         Returns:
-            A Hugging Face Hub ModelFilter object.
+            A dict representing the arguments passable to huggingface list_models().
         """
-        model_filter = ModelFilter(library='nemo')
-
-        # Attach some additional info
-        model_filter.resolve_card_info = False
-        model_filter.limit_results = None
+        model_filter = dict(
+            author=None,
+            library='nemo',
+            language=None,
+            model_name=None,
+            task=None,
+            tags=None,
+            limit=None,
+            full=None,
+            cardData=False,
+        )
 
         return model_filter
 
     @classmethod
-    def search_huggingface_models(
-        cls, model_filter: Optional[Union[ModelFilter, List[ModelFilter]]] = None
-    ) -> List['ModelInfo']:
+    def search_huggingface_models(cls, model_filter: Optional[Dict[str, Any]] = None) -> Iterable['ModelInfo']:
         """
         Should list all pre-trained models available via Hugging Face Hub.
 
@@ -75,16 +79,16 @@
             # You can replace with any subclass of ModelPT.
             from nemo.core import ModelPT
 
-            # Get default ModelFilter
+            # Get default filter dict
             filt = <class_name>.get_hf_model_filter()
 
             # Make any modifications to the filter as necessary
-            filt.language = [...]
-            filt.task = ...
-            filt.tags = [...]
+            filt['language'] = [...]
+            filt['task'] = ...
+            filt['tags'] = [...]
 
-            # Add any metadata to the filter as needed
-            filt.limit_results = 5
+            # Add any metadata to the filter as needed (kwargs to list_models)
+            filt['limit'] = 5
 
             # Obtain model info
             model_infos = <class_name>.search_huggingface_models(model_filter=filt)
@@ -96,10 +100,9 @@
             model = ModelPT.from_pretrained(card.modelId)
 
         Args:
-            model_filter: Optional ModelFilter or List[ModelFilter] (from Hugging Face Hub)
+            model_filter: Optional Dictionary (for Hugging Face Hub kwargs)
                 that filters the returned list of compatible model cards, and selects all results from each filter.
                 Users can then use `model_card.modelId` in `from_pretrained()` to restore a NeMo Model.
-                If no ModelFilter is provided, uses the classes default filter as defined by `get_hf_model_filter()`.
 
         Returns:
             A list of ModelInfo entries.
@@ -108,23 +111,6 @@
         """
         if model_filter is None:
             model_filter = cls.get_hf_model_filter()
 
-        # If single model filter, wrap into list
-        if not isinstance(model_filter, Iterable):
-            model_filter = [model_filter]
-
-        # Inject `nemo` library filter
-        for mfilter in model_filter:
-            if isinstance(mfilter.library, str) and mfilter.library != 'nemo':
-                logging.warning(f"Model filter's `library` tag updated be `nemo`. Original value: {mfilter.library}")
-                mfilter.library = "nemo"
-
-            elif isinstance(mfilter, Iterable) and 'nemo' not in mfilter.library:
-                logging.warning(
-                    f"Model filter's `library` list updated to include `nemo`. Original value: {mfilter.library}"
-                )
-                mfilter.library = list(mfilter)
-                mfilter.library.append('nemo')
-
         # Check if api token exists, use if it does
         hf_token = get_hf_token()
@@ -134,24 +120,11 @@
 
         # Setup extra arguments for model filtering
         all_results = []  # type: List[ModelInfo]
 
-        for mfilter in model_filter:
-            cardData = None
-            limit = None
-
-            if hasattr(mfilter, 'resolve_card_info') and mfilter.resolve_card_info is True:
-                cardData = True
-
-            if hasattr(mfilter, 'limit_results') and mfilter.limit_results is not None:
-                limit = mfilter.limit_results
-
-            results = api.list_models(
-                filter=mfilter, token=hf_token, sort="lastModified", direction=-1, cardData=cardData, limit=limit,
-            )  # type: Iterable[ModelInfo]
-
-            for result in results:
-                all_results.append(result)
+        results = api.list_models(
+            token=hf_token, sort="lastModified", direction=-1, **model_filter
+        )  # type: Iterable[ModelInfo]
 
-        return all_results
+        return results
 
     def push_to_hf_hub(
         self,
@@ -284,7 +257,10 @@ def _get_hf_model_card(self, template: str, template_kwargs: Optional[Dict[str,
         A HuggingFace ModelCard object that can be converted to a model card string.
         """
         card_data = ModelCardData(
-            library_name='nemo', tags=['pytorch', 'NeMo'], license='cc-by-4.0', ignore_metadata_errors=True,
+            library_name='nemo',
+            tags=['pytorch', 'NeMo'],
+            license='cc-by-4.0',
+            ignore_metadata_errors=True,
         )
 
         if 'card_data' not in template_kwargs:
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 7706aa58b267..3169d31dbeed 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,5 +1,5 @@
 fiddle
-huggingface_hub>=0.20.3,<0.24.0
+huggingface_hub>=0.24
 numba
 numpy>=1.22
 onnx>=1.7.0
diff --git a/tests/core/test_save_restore.py b/tests/core/test_save_restore.py
index 740c75f20f47..e117d54982fb 100644
--- a/tests/core/test_save_restore.py
+++ b/tests/core/test_save_restore.py
@@ -19,7 +19,6 @@
 
 import pytest
 import torch
-from huggingface_hub.hf_api import ModelFilter
 from omegaconf import DictConfig, OmegaConf, open_dict
 
 from nemo.collections.asr.models import EncDecCTCModel, EncDecCTCModelBPE
@@ -126,11 +125,15 @@ def __init__(self, cfg, trainer=None):
         self.child1_model: Optional[MockModel]  # annotate type for IDE autocompletion and type checking
         if cfg.get("child1_model") is not None:
             self.register_nemo_submodule(
-                "child1_model", config_field="child1_model", model=MockModel(self.cfg.child1_model),
+                "child1_model",
+                config_field="child1_model",
+                model=MockModel(self.cfg.child1_model),
             )
         elif cfg.get("child1_model_path") is not None:
             self.register_nemo_submodule(
-                "child1_model", config_field="child1_model", model=MockModel.restore_from(self.cfg.child1_model_path),
+                "child1_model",
+                config_field="child1_model",
+                model=MockModel.restore_from(self.cfg.child1_model_path),
             )
         else:
             self.child1_model = None
@@ -140,7 +143,9 @@ def __init__(self, cfg, trainer=None):
         self.child2_model: Optional[MockModelWithChildren]  # annotate type for IDE autocompletion and type checking
         if cfg.get("child2_model") is not None:
             self.register_nemo_submodule(
-                "child2_model", config_field="child2_model", model=MockModelWithChildren(self.cfg.child2_model),
+                "child2_model",
+                config_field="child2_model",
+                model=MockModelWithChildren(self.cfg.child2_model),
             )
         elif cfg.get("child2_model_path") is not None:
             self.register_nemo_submodule(
@@ -169,7 +174,9 @@ def __init__(self, cfg, trainer=None):
 
         if cfg.get("ctc_model", None) is not None:
             self.register_nemo_submodule(
-                "ctc_model", config_field="ctc_model", model=EncDecCTCModelBPE(self.cfg.ctc_model),
+                "ctc_model",
+                config_field="ctc_model",
+                model=EncDecCTCModelBPE(self.cfg.ctc_model),
             )
         else:
             # model is mandatory
@@ -196,7 +203,9 @@ def __init__(self, cfg, trainer=None):
         self.child1_model: Optional[MockModel]  # annotate type for IDE autocompletion and type checking
         if cfg.get("child1_model_config") is not None:
             self.register_nemo_submodule(
-                "child1_model", config_field="child1_model_config", model=MockModel(self.cfg.child1_model_config),
+                "child1_model",
+                config_field="child1_model_config",
+                model=MockModel(self.cfg.child1_model_config),
             )
         else:
             self.child1_model = None
@@ -900,11 +909,12 @@ def test_mock_model_nested_with_resources(self, change_child_resource: bool, chi
             child2_model_from_path: if child2_model_from_path is True, child2 model is restored from .nemo checkpoint,
                 otherwise constructed directly from config. Child1 model always loaded from checkpoint.
         """
-        with tempfile.NamedTemporaryFile('w') as file_child1, tempfile.NamedTemporaryFile(
-            'w'
-        ) as file_child2, tempfile.NamedTemporaryFile('w') as file_child2_other, tempfile.NamedTemporaryFile(
-            'w'
-        ) as file_parent:
+        with (
+            tempfile.NamedTemporaryFile('w') as file_child1,
+            tempfile.NamedTemporaryFile('w') as file_child2,
+            tempfile.NamedTemporaryFile('w') as file_child2_other,
+            tempfile.NamedTemporaryFile('w') as file_parent,
+        ):
             # write text data, use these files as resources
             parent_data = ["*****\n"]
             child1_data = ["+++++\n"]
@@ -988,11 +998,12 @@ def test_mock_model_nested_with_resources_multiple_passes(self):
         Test nested model with 2 children: multiple save-restore passes
         child models and parent model itself contain resources
         """
-        with tempfile.NamedTemporaryFile('w') as file_child1, tempfile.NamedTemporaryFile(
-            'w'
-        ) as file_child2, tempfile.NamedTemporaryFile('w') as file_child2_other, tempfile.NamedTemporaryFile(
-            'w'
-        ) as file_parent:
+        with (
+            tempfile.NamedTemporaryFile('w') as file_child1,
+            tempfile.NamedTemporaryFile('w') as file_child2,
+            tempfile.NamedTemporaryFile('w') as file_child2_other,
+            tempfile.NamedTemporaryFile('w') as file_parent,
+        ):
             # write text data, use these files as resources
             parent_data = ["*****\n"]
             child1_data = ["+++++\n"]
@@ -1019,7 +1030,12 @@ def test_mock_model_nested_with_resources_multiple_passes(self):
             child2 = MockModelWithChildren(cfg=cfg_child2.model, trainer=None)
             child2 = child2.to('cpu')
 
-        with tempfile.TemporaryDirectory() as tmpdir_parent1, tempfile.TemporaryDirectory() as tmpdir_parent2, tempfile.TemporaryDirectory() as tmpdir_parent3, tempfile.TemporaryDirectory() as tmpdir_parent4:
+        with (
+            tempfile.TemporaryDirectory() as tmpdir_parent1,
+            tempfile.TemporaryDirectory() as tmpdir_parent2,
+            tempfile.TemporaryDirectory() as tmpdir_parent3,
+            tempfile.TemporaryDirectory() as tmpdir_parent4,
+        ):
             parent_path1 = os.path.join(tmpdir_parent1, "parent.nemo")
             parent_path2 = os.path.join(tmpdir_parent2, "parent.nemo")
             with tempfile.TemporaryDirectory() as tmpdir_child:
@@ -1074,9 +1090,11 @@ def test_mock_model_nested_double_with_resources(self):
         test nested model: parent -> child_with_child -> child; model and each child can be saved/restored separately
         all models can contain resources
         """
-        with tempfile.NamedTemporaryFile('w') as file_child, tempfile.NamedTemporaryFile(
-            'w'
-        ) as file_child_with_child, tempfile.NamedTemporaryFile('w') as file_parent:
+        with (
+            tempfile.NamedTemporaryFile('w') as file_child,
+            tempfile.NamedTemporaryFile('w') as file_child_with_child,
+            tempfile.NamedTemporaryFile('w') as file_parent,
+        ):
             # write text data, use these files as resources
             parent_data = ["*****\n"]
             child_with_child_data = ["+++++\n"]
@@ -1302,8 +1320,8 @@ class MockModelV2(MockModel):
     @pytest.mark.unit
     def test_hf_model_filter(self):
         filt = ModelPT.get_hf_model_filter()
-        assert isinstance(filt, ModelFilter)
-        assert filt.library == 'nemo'
+        assert isinstance(filt, dict)
+        assert filt['library'] == 'nemo'
 
     @pytest.mark.with_downloads()
     @pytest.mark.unit
@@ -1312,10 +1330,12 @@ def test_hf_model_info(self):
 
         # check no override results
         model_infos = ModelPT.search_huggingface_models(model_filter=None)
+        model_infos = [next(model_infos) for _ in range(5)]
         assert len(model_infos) > 0
 
         # check with default override results (should match above)
         default_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
+        default_model_infos = [next(default_model_infos) for _ in range(5)]
         assert len(model_infos) == len(default_model_infos)
 
     @pytest.mark.pleasefixme()
@@ -1326,13 +1346,12 @@ def test_hf_model_info_with_card_data(self):
 
         # check no override results
         model_infos = ModelPT.search_huggingface_models(model_filter=filt)
+        model_infos = [next(model_infos) for _ in range(5)]
         assert len(model_infos) > 0
-        assert not hasattr(model_infos[0], 'cardData')
 
         # check overriden defaults
-        filt.resolve_card_info = True
+        filt['cardData'] = True
         model_infos = ModelPT.search_huggingface_models(model_filter=filt)
-        assert len(model_infos) > 0
 
         for info in model_infos:
             if hasattr(info, 'cardData'):
@@ -1346,10 +1365,12 @@ def test_hf_model_info_with_limited_results(self):
 
         # check no override results
         model_infos = ModelPT.search_huggingface_models(model_filter=filt)
+        model_infos = [next(model_infos) for _ in range(6)]
         assert len(model_infos) > 0
 
         # check overriden defaults
-        filt.limit_results = 5
+        filt['limit'] = 5
         new_model_infos = ModelPT.search_huggingface_models(model_filter=filt)
+        new_model_infos = list(new_model_infos)
         assert len(new_model_infos) <= 5
         assert len(new_model_infos) < len(model_infos)
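
With huggingface_hub 0.24+, the filter is a plain dict of HfApi.list_models() keyword arguments, and search_huggingface_models() now returns the Hub's lazy iterable rather than a list. A minimal sketch of the new call pattern, assuming only what the diff above shows (the 'en' language and the limit of 5 are placeholder values):

    from nemo.core import ModelPT

    filt = ModelPT.get_hf_model_filter()  # now a dict, not a ModelFilter
    filt['language'] = ['en']  # placeholder filter value
    filt['limit'] = 5  # forwarded to list_models() as a kwarg

    # Results are a lazy iterable of ModelInfo; wrap in list(...) to materialize.
    for info in ModelPT.search_huggingface_models(model_filter=filt):
        print(info.modelId)

This is also why the updated tests above drain the returned generator with next(...) or list(...) before asserting on lengths.
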
From 0ba1991e7a080c339774948f739146206e2f2f2b Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Wed, 24 Jul 2024 00:35:43 -0700
Subject: [PATCH 3/3] [NeMo-UX] Use single instance of loss reductions in
 GPTModel (#9801)

* Use single instance of loss reductions

Signed-off-by: Hemil Desai

* Apply isort and black reformatting

Signed-off-by: hemildesai

* Refactor

Signed-off-by: Hemil Desai

---------

Signed-off-by: Hemil Desai
Signed-off-by: hemildesai
Co-authored-by: hemildesai
---
 nemo/collections/llm/gpt/model/base.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index 0e4fabe020af..a8339e124564 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -160,6 +160,8 @@ def __init__(
         self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True))
         self.optim.connect(self)  # This will bind the `configure_optimizers` method
         self.model_transform = model_transform
+        self._training_loss_reduction = None
+        self._validation_loss_reduction = None
 
     def configure_model(self) -> None:
         if not hasattr(self, "module"):
@@ -200,11 +202,19 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
 
         return self.forward_step(batch)
 
+    @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
-        return MaskedTokenLossReduction()
+        if not self._training_loss_reduction:
+            self._training_loss_reduction = MaskedTokenLossReduction()
+        return self._training_loss_reduction
+
+    @property
     def validation_loss_reduction(self) -> MaskedTokenLossReduction:
-        return MaskedTokenLossReduction(validation_step=True)
+        if not self._validation_loss_reduction:
+            self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True)
+
+        return self._validation_loss_reduction
 
 
 def get_batch_on_this_context_parallel_rank(batch):
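
The net effect of this patch: the loss reductions become lazily cached properties, so each model instance constructs MaskedTokenLossReduction once instead of on every step, and call sites use attribute access. A small sketch (not from the patch; `model` stands for any constructed GPTModel):

    # Property access, no parentheses; both reads hit the cached instance.
    loss_a = model.training_loss_reduction
    loss_b = model.training_loss_reduction
    assert loss_a is loss_b  # one MaskedTokenLossReduction per model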