Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix circular import for MM dataprep notebook #9287

Merged
merged 2 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 59 additions & 43 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
# since PyTorch 2.3 the path has changed
from torch.amp.grad_scaler import _refresh_per_optimizer_state

from nemo.collections.multimodal.modules.stable_diffusion.attention import BasicTransformerBlock
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.transformer import AutocastTransformerLayer, ParallelTransformerLayer
from nemo.collections.nlp.parts import utils_funcs
Expand Down Expand Up @@ -120,7 +119,7 @@
def init_model_parallel(
sharp: bool, nccl_communicator_config_path: str = None, distributed_timeout_minutes: int = 30
) -> None:
""" Initializes Megatron-LM model parallel if using model parallelism.
"""Initializes Megatron-LM model parallel if using model parallelism.

Args:
sharp: Apply SHARP to NCCL data-parallel communication.
Expand Down Expand Up @@ -164,7 +163,7 @@ def init_model_parallel(


class NLPDDPStrategy(DDPStrategy):
""" DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models.
"""DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models.

Args:
no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2
Expand Down Expand Up @@ -231,8 +230,8 @@ def setup_distributed(self, global_rank: int = None, world_size: int = None) ->
)

def configure_ddp(self):
""" Override LightningModule ddp if using model parallel.
Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
"""Override LightningModule ddp if using model parallel.
Sets find_unused_parameters to False to use activation-checkpoint-recomputation.
"""

if (hasattr(self.model, 'megatron_amp_O2') and self.model.megatron_amp_O2) or (
Expand Down Expand Up @@ -406,7 +405,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
self.lightning_module.load_state_dict(checkpoint["state_dict"], strict=strict)

def _fix_tensors_device(self, ckpt: Dict) -> Dict:
""" Ensure checkpoint tensors are on the correct device."""
"""Ensure checkpoint tensors are on the correct device."""
assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized())
cur_dev = torch.device("cuda", index=torch.cuda.current_device())

Expand All @@ -418,10 +417,10 @@ def _fix_device(t):
return dict_list_map_outplace(_fix_device, ckpt)

def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
""" PTL method which we override to integrate distributed checkpoints for model parallel models.
In order to load distributed checkpoints we need to provide the sharded_state_dict to
the distributed load function. We get the sharded_state_dict from self.lightning_module
which makes it convenient to have the loading logic happen at the strategy level.
"""PTL method which we override to integrate distributed checkpoints for model parallel models.
In order to load distributed checkpoints we need to provide the sharded_state_dict to
the distributed load function. We get the sharded_state_dict from self.lightning_module
which makes it convenient to have the loading logic happen at the strategy level.
"""

fs = get_filesystem(checkpoint_path)
Expand Down Expand Up @@ -500,15 +499,15 @@ def distributed_sampler_kwargs(self):

@property
def restore_checkpoint_after_setup(self) -> bool:
""" This needs to be True for distributed checkpointing because
we require the model to have configured the optimizer before
deserializing the checkpoint.
"""This needs to be True for distributed checkpointing because
we require the model to have configured the optimizer before
deserializing the checkpoint.
"""
return True


class NLPDDPStrategyNotebook(NLPDDPStrategy):
""" Version of NLPDDPStrategy to be used in a Jupyter Notebook
"""Version of NLPDDPStrategy to be used in a Jupyter Notebook
A large portion of Megatron code has DDP dependency, so it has been necessary to use NLPDDPStrategy even for
single-GPU training (e.g. in a Jupyter notebook)
A PTL 2.0 changes has prevented DDPStrategy to be used in a notebook.
Expand Down Expand Up @@ -546,7 +545,7 @@ def _get_full_state_dict_context(module: torch.nn.Module, rank0_only: bool = Fal


class NLPFSDPStrategy(FSDPStrategy):
""" FSDP plugin for Pytorch Lightning with the support for tensor-parallelism.
"""FSDP plugin for Pytorch Lightning with the support for tensor-parallelism.

Args:
sharding_strategy: FSDP parameter sharding strategy.
Expand Down Expand Up @@ -583,6 +582,9 @@ def __init__(
# Use the default FSDP backward-prefetch policy for proper communication overlap.
kwargs['backward_prefetch'] = BackwardPrefetch.BACKWARD_PRE

# import here to prevent circular imports
from nemo.collections.multimodal.modules.stable_diffusion.attention import BasicTransformerBlock

# Set FSDP wrapping policy: use Transformer layer module as the FSDP sharding granularity.
self.fsdp_wrap_module = {
MCoreTransformerLayer,
Expand Down Expand Up @@ -639,7 +641,11 @@ def _set_mixed_precision_recipe(
reduce_dtype = utils_funcs.torch_dtype_from_precision(grad_reduce_dtype, None)
if set_buffer_dtype is not None:
buffer_dtype = utils_funcs.torch_dtype_from_precision(buffer_dtype, None)
return MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype,)
return MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
buffer_dtype=buffer_dtype,
)

def setup_environment(self) -> None:
"""
Expand Down Expand Up @@ -750,15 +756,19 @@ def _get_osd(opt_state):
with FSDP.summon_full_params(self.model, writeback=True, rank0_only=False):
# rekey the osd stored from non-FSDP model
rekeyed_osd = FSDP.rekey_optim_state_dict(
temp_osd, OptimStateKeyType.PARAM_NAME, self.model,
temp_osd,
OptimStateKeyType.PARAM_NAME,
self.model,
)
temp_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, self.model)
except Exception as e:
print(f"Failed to load optimzier state dicts. Errored with {e}")
exit(1)
# Shard optimizer state dict
sharded_osd = FSDP.optim_state_dict_to_load(
optim_state_dict=temp_osd, model=self.model, optim=optimizer,
optim_state_dict=temp_osd,
model=self.model,
optim=optimizer,
)

optimizer.load_state_dict(sharded_osd)
Expand All @@ -767,9 +777,9 @@ def _get_osd(opt_state):
def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
) -> None:
""" Store checkpoints
1. In case of sharded checkpoint, all ranks store unique checkpoints.
2. In case of non-sharded checkpoint, all data-parallel rank 0 store checkpoints.
"""Store checkpoints
1. In case of sharded checkpoint, all ranks store unique checkpoints.
2. In case of non-sharded checkpoint, all data-parallel rank 0 store checkpoints.
"""
app_state = AppState()
filepath = inject_model_parallel_rank(filepath, fsdp_sharded_ckpt=self.sharded_checkpoint)
Expand All @@ -780,8 +790,7 @@ def save_checkpoint(
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
""" Load checkpoints
"""
"""Load checkpoints"""
# 1. Load normal or FSDP-sharded checkpoints.
fs = get_filesystem(checkpoint_path)

Expand All @@ -798,8 +807,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
return checkpoint

def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
""" Remove checkpoints
"""
"""Remove checkpoints"""
# legacy checkpoint logic, does not use megatron core
app_state = AppState()
# PTL override to accomodate model parallel checkpoints
Expand All @@ -814,9 +822,9 @@ def remove_checkpoint(self, filepath: Union[str, Path]) -> None:

@property
def restore_checkpoint_after_setup(self) -> bool:
""" When loading FSDP-sharded checkpoint, need to restore checkpoint after configuring
FSDP sharding to match FSDP-sharded format between the checkpoint and the current
model and optimizer.
"""When loading FSDP-sharded checkpoint, need to restore checkpoint after configuring
FSDP sharding to match FSDP-sharded format between the checkpoint and the current
model and optimizer.
"""
return True

Expand Down Expand Up @@ -915,7 +923,8 @@ def dummy():
else:
# move weights to the tmpdir
for tp_rank, pp_rank in itertools.product(
range(app_state.tensor_model_parallel_size), range(app_state.pipeline_model_parallel_size),
range(app_state.tensor_model_parallel_size),
range(app_state.pipeline_model_parallel_size),
):
os.makedirs(os.path.join(tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}'))
mp_model_weights = os.path.join(
Expand Down Expand Up @@ -1000,6 +1009,7 @@ def modify_state_dict(self, conf, state_dict):
loaded_keys = state_dict.keys()
if 'model.model.diffusion_model.input_blocks.1.0.in_layers.2.weight' in loaded_keys:
new_state_dict = {}

# GroupNormOpt fuses activation function to one layer, thus the indexing of weights are shifted for following
def should_process(key):
base_str = "model.model.diffusion_model."
Expand Down Expand Up @@ -1110,7 +1120,13 @@ def restore_from(
# Get path where the command is executed - the artifacts will be "retrieved" there
# (original .nemo behavior)
loaded_params = super().load_config_and_state_dict(
calling_cls, restore_path, override_config_path, map_location, strict, return_config, trainer,
calling_cls,
restore_path,
override_config_path,
map_location,
strict,
return_config,
trainer,
)
if not isinstance(loaded_params, tuple) or return_config is True:
return loaded_params
Expand Down Expand Up @@ -1165,12 +1181,12 @@ def dummy():


class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin):
""" Overrides PTL autocasting to not wrap training/val/test_step.
We do this because we have the megatron-core fwd/bwd functions in training_step.
This means .backward is being called in training_step so we do not want the whole
step wrapped in autocast.
"""Overrides PTL autocasting to not wrap training/val/test_step.
We do this because we have the megatron-core fwd/bwd functions in training_step.
This means .backward is being called in training_step so we do not want the whole
step wrapped in autocast.

We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
"""

def __init__(
Expand Down Expand Up @@ -1206,12 +1222,12 @@ def forward_context(self) -> Generator[None, None, None]:


class FSDPMixedPrecisionPlugin(FSDPPrecision):
""" Overrides PTL autocasting to not wrap training/val/test_step.
We do this because we have the megatron-core fwd/bwd functions in training_step.
This means .backward is being called in training_step so we do not want the whole
step wrapped in autocast.
"""Overrides PTL autocasting to not wrap training/val/test_step.
We do this because we have the megatron-core fwd/bwd functions in training_step.
This means .backward is being called in training_step so we do not want the whole
step wrapped in autocast.

We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
"""

def __init__(
Expand Down Expand Up @@ -1246,7 +1262,7 @@ class GradScaler(torch.cuda.amp.GradScaler):

def __init__(
self,
init_scale=2.0 ** 16,
init_scale=2.0**16,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000,
Expand Down Expand Up @@ -1500,15 +1516,15 @@ def optimizer_step(

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
""" No explicit precision casting. Inputs are supposed to be manually casted """
"""No explicit precision casting. Inputs are supposed to be manually casted"""
try:
yield
finally:
pass


class GlobalBatchDataFetcher(_DataFetcher):
""" Overrides PTL DataFetcher. Used to fetch global batches."""
"""Overrides PTL DataFetcher. Used to fetch global batches."""

def __init__(self, prefetch_batches: int = 0, store_on_device: bool = False) -> None:

Expand Down
21 changes: 6 additions & 15 deletions tutorials/multimodal/Multimodal Data Preparation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
"\n",
"This notebook will show you how to prepare an image-text dataset into the [WebDataset](https://github.com/webdataset/webdataset) format. The Webdataset format is required to train all multimodal models in NeMo, such as Stable Diffusion and Imagen. \n",
"\n",
"This notebook is designed to demonstrate the different stages of multimodal dataset preparation. It is not meant to be used to process large-scale datasets since many stages are too time-consuming to run without parallelism. For large workloads, we recommend running the multimodal dataset preparation pipeline with the NeMo-Megatron-Launcher on multiple processors/GPUs. NeMo-Megatron-Launcher packs the same 5 scripts in this notebook into one runnable command and one config file to enable a smooth and a streamlined workflow.\n",
"This notebook is designed to demonstrate the different stages of multimodal dataset preparation. It is not meant to be used to process large-scale datasets since many stages are too time-consuming to run without parallelism. For large workloads, we recommend running the multimodal dataset preparation pipeline with the NeMo-Framework-Launcher on multiple processors/GPUs. NeMo-Framework-Launcher packs the same 5 scripts in this notebook into one runnable command and one config file to enable a smooth and a streamlined workflow.\n",
"\n",
"Depending on your use case, not all 5 stages need to be run. Please go to [NeMo Multimodal Documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/multimodal/text2img/datasets.html) for an overview of the 5 stages.\n",
" \n",
Expand All @@ -85,7 +85,7 @@
"source": [
"import os\n",
"\n",
"LAUNCHER_DIR = \"/opt/NeMo-Megatron-Launcher\"\n",
"LAUNCHER_DIR = \"/opt/NeMo-Framework-Launcher\" # formerly NeMo-Megatron-Launcher\n",
"SCRIPT_DIR = os.path.join(LAUNCHER_DIR, \"launcher_scripts/nemo_launcher/collections/dataprep_scripts/multimodal_dataprep\")\n",
"CONF_DIR = \"conf\"\n",
"DATA_DIR = \"dummy_data\"\n",
Expand Down Expand Up @@ -168,7 +168,7 @@
"\n",
"Script: download_images.py\n",
"\n",
"Environment variables (automatically set by SLURM if running with NeMo-Megatron-Launcher):\n",
"Environment variables (automatically set by SLURM if running with NeMo-Framework-Launcher):\n",
"- `SLURM_ARRAY_TASK_COUNT`: total number of tasks, should be set to the number of parquet files in `$DATA_DIR/parquet/dummy_dataset50000.parquet_parts`. (i.e. `parquet_subpartitions` x `num_parquets_downloaded`)\n",
"- `SLURM_ARRAY_TASK_ID`: id of the current task (0 <= SLURM_ARRAY_TASK_ID < SLURM_ARRAY_TASK_COUNT)\n",
"\n",
Expand Down Expand Up @@ -266,7 +266,7 @@
"\n",
"Script: reorganize_tar.py\n",
"\n",
"Environment variables (automatically set by SLURM if running with NeMo-Megatron-Launcher):\n",
"Environment variables (automatically set by SLURM if running with NeMo-Framework-Launcher):\n",
"- `SLURM_ARRAY_TASK_COUNT`: total number of tasks, should be set to parquet_subpartitions x num_parquets_downloaded\n",
"- `SLURM_ARRAY_TASK_ID`: id of the current task (0 <= `SLURM_ARRAY_TASK_ID` < `SLURM_ARRAY_TASK_COUNT`)\n",
"\n",
Expand Down Expand Up @@ -430,7 +430,7 @@
},
"outputs": [],
"source": [
"! wget https://raw.githubusercontent.com/NVIDIA/NeMo-Megatron-Launcher/master/launcher_scripts/conf/data_preparation/multimodal/precache_sd.yaml -P $CONF_DIR/"
"! wget https://raw.githubusercontent.com/NVIDIA/NeMo-Framework-Launcher/master/launcher_scripts/conf/data_preparation/multimodal/precache_sd.yaml -P $CONF_DIR/"
]
},
{
Expand Down Expand Up @@ -506,7 +506,7 @@
"\n",
"Script: precache_encodings.py\n",
"\n",
"Environment variables (automatically set by SLURM if running with NeMo-Megatron-Launcher):\n",
"Environment variables (automatically set by SLURM if running with NeMo-Framework-Launcher):\n",
"- `SLURM_ARRAY_TASK_COUNT`: total number of tasks, should be set to parquet_subpartitions x num_parquets_downloaded\n",
"- `SLURM_ARRAY_TASK_ID`: id of the current task (0 <= `SLURM_ARRAY_TASK_ID` < `SLURM_ARRAY_TASK_COUNT`)\n",
"\n",
Expand All @@ -533,15 +533,6 @@
" precache_config_path=$CONF_DIR/precache_sd_example.yaml"
]
},
{
"cell_type": "markdown",
"source": [
"If you encounter a nemo import problem with the cell above, please also running it in the terminal directly."
],
"metadata": {
"collapsed": false
}
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down
Loading