Merge branch 'r2.0.0rc1' into dpykhtar/torch_dist_as_default
dimapihtar authored Jul 24, 2024
2 parents 37d5d0b + 0ba1991 commit 4e4901a
Showing 7 changed files with 110 additions and 101 deletions.
1 change: 0 additions & 1 deletion examples/llm/megatron_gpt_pretraining.py
@@ -65,7 +65,6 @@ def get_args():
     checkpoint_callback = ModelCheckpoint(
         every_n_train_steps=5000,
         enable_nemo_ckpt_io=False,
-        async_save=False,
     )
     callbacks = [checkpoint_callback]
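
Note: the example drops the callback-level async_save flag; asynchronous saving is now requested on the strategy instead (see the nemo/lightning/pytorch/strategies.py diff below). A minimal sketch of the updated call site, assuming only import paths visible in this commit and that MegatronStrategy is constructible with defaults:

    from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
    from nemo.lightning.pytorch.strategies import MegatronStrategy

    # Scheduling options stay on the callback.
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=5000,
        enable_nemo_ckpt_io=False,
    )
    callbacks = [checkpoint_callback]

    # Async checkpointing is now a strategy option.
    strategy = MegatronStrategy(ckpt_async_save=True)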
14 changes: 12 additions & 2 deletions nemo/collections/llm/gpt/model/base.py
@@ -160,6 +160,8 @@ def __init__(
         self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True))
         self.optim.connect(self)  # This will bind the `configure_optimizers` method
         self.model_transform = model_transform
+        self._training_loss_reduction = None
+        self._validation_loss_reduction = None

     def configure_model(self) -> None:
         if not hasattr(self, "module"):
@@ -200,11 +202,19 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:

         return self.forward_step(batch)

     @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
-        return MaskedTokenLossReduction()
+        if not self._training_loss_reduction:
+            self._training_loss_reduction = MaskedTokenLossReduction()
+
+        return self._training_loss_reduction

     @property
     def validation_loss_reduction(self) -> MaskedTokenLossReduction:
-        return MaskedTokenLossReduction(validation_step=True)
+        if not self._validation_loss_reduction:
+            self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True)
+
+        return self._validation_loss_reduction


 def get_batch_on_this_context_parallel_rank(batch):
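
The two properties above switch from constructing a fresh MaskedTokenLossReduction on every call to building it once and returning the cached instance. The same memoization idiom in isolation (illustrative names, not NeMo API):

    class LossHolder:
        def __init__(self):
            self._loss_fn = None  # filled in lazily on first access

        @property
        def loss_fn(self):
            if self._loss_fn is None:
                self._loss_fn = {"reduction": "masked_token"}  # stand-in for MaskedTokenLossReduction()
            return self._loss_fn

    holder = LossHolder()
    assert holder.loss_fn is holder.loss_fn  # repeated access reuses one object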
88 changes: 32 additions & 56 deletions nemo/core/classes/mixins/hf_io_mixin.py
@@ -14,9 +14,9 @@

 from abc import ABC
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union

-from huggingface_hub import HfApi, ModelCard, ModelCardData, ModelFilter
+from huggingface_hub import HfApi, ModelCard, ModelCardData
 from huggingface_hub import get_token as get_hf_token
 from huggingface_hub.hf_api import ModelInfo
 from huggingface_hub.utils import SoftTemporaryDirectory
@@ -35,31 +35,35 @@ class HuggingFaceFileIO(ABC):
     """

     @classmethod
-    def get_hf_model_filter(cls) -> ModelFilter:
+    def get_hf_model_filter(cls) -> Dict[str, Any]:
         """
         Generates a filter for HuggingFace models.

-        Additionally includes default values of some metadata about results returned by the Hub.
+        Additionaly includes default values of some metadata about results returned by the Hub.

         Metadata:
             resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False.
             limit_results: Optional int, limits the number of results returned.

         Returns:
-            A Hugging Face Hub ModelFilter object.
+            A dict representing the arguments passable to huggingface list_models().
         """
-        model_filter = ModelFilter(library='nemo')
-
-        # Attach some additional info
-        model_filter.resolve_card_info = False
-        model_filter.limit_results = None
+        model_filter = dict(
+            author=None,
+            library='nemo',
+            language=None,
+            model_name=None,
+            task=None,
+            tags=None,
+            limit=None,
+            full=None,
+            cardData=False,
+        )

         return model_filter

     @classmethod
-    def search_huggingface_models(
-        cls, model_filter: Optional[Union[ModelFilter, List[ModelFilter]]] = None
-    ) -> List['ModelInfo']:
+    def search_huggingface_models(cls, model_filter: Optional[Dict[str, Any]] = None) -> Iterable['ModelInfo']:
         """
         Should list all pre-trained models available via Hugging Face Hub.
@@ -75,16 +79,16 @@ def search_huggingface_models(
             # You can replace <DomainSubclass> with any subclass of ModelPT.
             from nemo.core import ModelPT

-            # Get default ModelFilter
+            # Get default filter dict
             filt = <DomainSubclass>.get_hf_model_filter()

             # Make any modifications to the filter as necessary
-            filt.language = [...]
-            filt.task = ...
-            filt.tags = [...]
+            filt['language'] = [...]
+            filt['task'] = ...
+            filt['tags'] = [...]

-            # Add any metadata to the filter as needed
-            filt.limit_results = 5
+            # Add any metadata to the filter as needed (kwargs to list_models)
+            filt['limit'] = 5

             # Obtain model info
             model_infos = <DomainSubclass>.search_huggingface_models(model_filter=filt)
@@ -96,10 +100,9 @@ def search_huggingface_models(
             model = ModelPT.from_pretrained(card.modelId)

         Args:
-            model_filter: Optional ModelFilter or List[ModelFilter] (from Hugging Face Hub)
+            model_filter: Optional Dictionary (for Hugging Face Hub kwargs)
                 that filters the returned list of compatible model cards, and selects all results from each filter.
                 Users can then use `model_card.modelId` in `from_pretrained()` to restore a NeMo Model.
-                If no ModelFilter is provided, uses the classes default filter as defined by `get_hf_model_filter()`.

         Returns:
             A list of ModelInfo entries.
@@ -108,23 +111,6 @@ def search_huggingface_models(
         if model_filter is None:
             model_filter = cls.get_hf_model_filter()

-        # If single model filter, wrap into list
-        if not isinstance(model_filter, Iterable):
-            model_filter = [model_filter]
-
-        # Inject `nemo` library filter
-        for mfilter in model_filter:
-            if isinstance(mfilter.library, str) and mfilter.library != 'nemo':
-                logging.warning(f"Model filter's `library` tag updated be `nemo`. Original value: {mfilter.library}")
-                mfilter.library = "nemo"
-
-            elif isinstance(mfilter, Iterable) and 'nemo' not in mfilter.library:
-                logging.warning(
-                    f"Model filter's `library` list updated to include `nemo`. Original value: {mfilter.library}"
-                )
-                mfilter.library = list(mfilter)
-                mfilter.library.append('nemo')
-
         # Check if api token exists, use if it does
         hf_token = get_hf_token()
@@ -134,24 +120,11 @@ def search_huggingface_models(
         # Setup extra arguments for model filtering
         all_results = []  # type: List[ModelInfo]

-        for mfilter in model_filter:
-            cardData = None
-            limit = None
-
-            if hasattr(mfilter, 'resolve_card_info') and mfilter.resolve_card_info is True:
-                cardData = True
-
-            if hasattr(mfilter, 'limit_results') and mfilter.limit_results is not None:
-                limit = mfilter.limit_results
-
-            results = api.list_models(
-                filter=mfilter, token=hf_token, sort="lastModified", direction=-1, cardData=cardData, limit=limit,
-            )  # type: Iterable[ModelInfo]
-
-            for result in results:
-                all_results.append(result)
+        results = api.list_models(
+            token=hf_token, sort="lastModified", direction=-1, **model_filter
+        )  # type: Iterable[ModelInfo]

-        return all_results
+        return results

     def push_to_hf_hub(
         self,
@@ -284,7 +257,10 @@ def _get_hf_model_card(self, template: str, template_kwargs: Optional[Dict[str,
             A HuggingFace ModelCard object that can be converted to a model card string.
         """
         card_data = ModelCardData(
-            library_name='nemo', tags=['pytorch', 'NeMo'], license='cc-by-4.0', ignore_metadata_errors=True,
+            library_name='nemo',
+            tags=['pytorch', 'NeMo'],
+            license='cc-by-4.0',
+            ignore_metadata_errors=True,
         )

         if 'card_data' not in template_kwargs:
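
Context for the rewrite above: huggingface_hub 0.24 removed the ModelFilter class (hence the requirements.txt bump at the end of this commit), so the default filter is now a plain dict whose keys are forwarded as keyword arguments to HfApi.list_models(). A usage sketch following the updated docstring; going through ModelPT directly is an assumption here, and any subclass exposing these classmethods works the same way:

    from nemo.core import ModelPT

    filt = ModelPT.get_hf_model_filter()  # defaults: library='nemo', cardData=False, ...
    filt['language'] = ['en']  # keys map one-to-one onto list_models() kwargs
    filt['limit'] = 5          # replaces the old limit_results metadata field

    for info in ModelPT.search_huggingface_models(model_filter=filt):
        print(info.modelId)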
8 changes: 4 additions & 4 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -51,14 +51,12 @@ def __init__(
         save_best_model: bool = False,
         save_on_train_epoch_end: Optional[bool] = False,  # Save after training, not after validation
         enable_nemo_ckpt_io: bool = True,
-        async_save: bool = False,
         try_restore_best_ckpt: bool = True,
         **kwargs,
     ):
         self.save_best_model = save_best_model
         self.previous_best_path = ""
         self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
-        self.async_save = async_save
         # Checkpoints which removal is deferred until async save is done.
         # Each element of `deferred_ckpts_to_remove` is a growing list
         # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
@@ -221,7 +219,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         super().load_state_dict(state_dict)
         self._remove_invalid_entries_from_topk()

-    def setup(self, *args, **kwargs) -> None:
+    def setup(self, trainer, *args, **kwargs) -> None:
         from nemo.utils.get_rank import is_global_rank_zero

         if is_global_rank_zero():
@@ -230,7 +228,9 @@ def setup(self, *args, **kwargs) -> None:
         # Ensure that all ranks continue with unfinished checkpoints removed
         if torch.distributed.is_initialized():
             torch.distributed.barrier()
-        super().setup(*args, **kwargs)
+
+        self.async_save = getattr(trainer.strategy, "async_save", False)
+        super().setup(trainer, *args, **kwargs)

     def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
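
Because setup() now reads the flag off trainer.strategy with a False default, the callback stays synchronous when paired with a strategy that never defines async_save. The fallback in isolation (hypothetical stand-in class):

    class LegacyStrategy:  # any strategy without the new attribute
        pass

    assert getattr(LegacyStrategy(), "async_save", False) is False  # falls back to synchronous saves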
25 changes: 14 additions & 11 deletions nemo/lightning/pytorch/strategies.py
@@ -105,6 +105,7 @@ def __init__(
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
         save_ckpt_format='torch_dist',
+        ckpt_async_save=False,
         ckpt_torch_dist_multiproc=None,  ## TODO(ashors): put elsewhere?
         ckpt_assume_constant_structure=False,
         ckpt_parallel_save=True,
@@ -142,6 +143,7 @@ def __init__(
         self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))

         self.save_ckpt_format = save_ckpt_format
+        self.async_save = ckpt_async_save
         self.torch_dist_multiproc = ckpt_torch_dist_multiproc
         self.assume_constant_structure = ckpt_assume_constant_structure
         self.parallel_save = ckpt_parallel_save
@@ -253,6 +255,16 @@ def setup(self, trainer: pl.Trainer) -> None:
             assert self.model is not None
             _sync_module_states(self.model)

+        ## add AsyncFinalizerCallback if using async
+        if self.async_save:
+            have_async_callback = False
+            for callback in self.trainer.callbacks:
+                if isinstance(callback, AsyncFinalizerCallback):
+                    have_async_callback = True
+                    break
+            if not have_async_callback:
+                self.trainer.callbacks.append(AsyncFinalizerCallback())
+
     @override
     def setup_distributed(self) -> None:
         self._setup_parallel_ranks()
@@ -577,27 +589,18 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
     @override
     def checkpoint_io(self) -> CheckpointIO:
         if self._checkpoint_io is None:
-            checkpoint_callback = self.trainer.checkpoint_callback
-            async_save = getattr(checkpoint_callback, "async_save", False)
             self._checkpoint_io = MegatronCheckpointIO(
                 save_ckpt_format=self.save_ckpt_format,
-                async_save=async_save,
+                async_save=self.async_save,
                 torch_dist_multiproc=self.torch_dist_multiproc,
                 assume_constant_structure=self.assume_constant_structure,
                 parallel_save=self.parallel_save,
                 parallel_save_within_dp=self.parallel_save_within_dp,
                 parallel_load=self.parallel_load,
                 load_directly_on_device=self.load_directly_on_device,
             )
-            if async_save:
+            if self.async_save:
                 self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io)
-                have_async_callback = False
-                for callback in self.trainer.callbacks:
-                    if isinstance(callback, AsyncFinalizerCallback):
-                        have_async_callback = True
-                        break
-                if not have_async_callback:
-                    self.trainer.callbacks.append(AsyncFinalizerCallback())
         elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
             self._checkpoint_io.checkpoint_io = MegatronCheckpointIO()
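
Taken together with the checkpoint-callback change above, enabling asynchronous distributed checkpointing is now a single strategy flag. A configuration sketch using only arguments visible in this diff:

    from nemo.lightning.pytorch.strategies import MegatronStrategy

    strategy = MegatronStrategy(
        save_ckpt_format='torch_dist',
        ckpt_async_save=True,  # checkpoint_io() wraps MegatronCheckpointIO in AsyncFinalizableCheckpointIO
    )
    # setup() appends an AsyncFinalizerCallback to trainer.callbacks if one is
    # not already registered, so no manual callback wiring is needed.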
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,5 +1,5 @@
 fiddle
-huggingface_hub>=0.20.3,<0.24.0
+huggingface_hub>=0.24
 numba
 numpy>=1.22
 onnx>=1.7.0
