Skip to content

Commit

Permalink
Changes to enable CUDA graph for LLM (#8751)
Browse files Browse the repository at this point in the history
* Use next instead of get_batch

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* CUDA graph changes

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Change to enable CG with weight caching

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Revert "Use next instead of get_batch"

This reverts commit 0021bb4.

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Copy jbaczek/mcore_parallel_state_api_change branch leaving out changes to nemo/export/quantize/quantizer.py

Signed-off-by: Jan Baczek <[email protected]>
Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Revert "Copy jbaczek/mcore_parallel_state_api_change branch leaving out changes to nemo/export/quantize/quantizer.py"

This reverts commit b4f736e.

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Remove skip_weight_update argument

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Bug fix + cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Use new TE API for FP8 Param transpose

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Change config param cuda_graph to enable_cuda_graph

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Enable TE RNGStatesTracker through config

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Change te_rng_tracker to use_te_rng_tracker

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* FP8 weight transpose handled inside TE

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Revert "Revert "Copy jbaczek/mcore_parallel_state_api_change branch leaving out changes to nemo/export/quantize/quantizer.py""

This reverts commit e318624.

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Fix merge conflicts

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Fix merge conflicts

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Fix merge conflicts

Signed-off-by: Vasudevan Rengasamy <[email protected]>

---------

Signed-off-by: Vasudevan Rengasamy <[email protected]>
Signed-off-by: Jan Baczek <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Jan Baczek <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored and web-flow committed Apr 17, 2024
1 parent 5b296e8 commit 126e27a
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
attention_mask: torch.Tensor = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[torch.Tensor] = None,
inference_params: Optional[Any] = None,
Expand All @@ -169,7 +169,7 @@ def forward(
with torch.autocast(device_type="cuda", dtype=self.dtype):
return super().forward(
hidden_states,
attention_mask,
attention_mask=attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params,
Expand Down Expand Up @@ -242,25 +242,30 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
def forward(
self,
hidden_states,
attention_mask,
is_first_microbatch=None,
attention_mask=None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None, # TODO: handle this
):
# Use is_first_microbatch argument during CUDA graph capture. Use self.is_first_microbatch otherwise.
hidden_states = super().forward(
hidden_states,
attention_mask=attention_mask,
encoder_output=context,
enc_dec_attn_mask=context_mask,
inference_params=inference_params,
is_first_microbatch=self.is_first_microbatch,
is_first_microbatch=is_first_microbatch if is_first_microbatch is not None else self.is_first_microbatch,
# checkpoint_core_attention,
)
self.is_first_microbatch = False
context = None

# CUDA graph requires returned values to be Tensors
if self.config.enable_cuda_graph and self.training:
return hidden_states
return hidden_states, context

def _get_layer_offset(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
use_te_rng_tracker=self.cfg.get('use_te_rng_tracker', False),
)

# This must be called after initialize model parallel since it needs to know the data parallel size
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def initialize_model_parallel_for_nemo(
seed=1234,
apex_transformer_log_level=30,
use_tp_pp_dp_mapping=False,
use_te_rng_tracker=False,
):

if virtual_pipeline_model_parallel_size is not None and not HAVE_INTERLEAVED:
Expand Down Expand Up @@ -128,6 +129,7 @@ def initialize_model_parallel_for_nemo(
set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)

tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
if seed is not None:
# @chcui not setting seed is for model conversion. always set seed for training/inference.
_set_random_seed(seed)
Expand Down
2 changes: 0 additions & 2 deletions nemo/core/optim/distributed_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,6 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet
for param in params:
if is_float8tensor(param):
param._reset_caches()
param.transpose(update_cache=True)
param._lazy_transpose_cache = True

@torch.no_grad()
def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None:
Expand Down

0 comments on commit 126e27a

Please sign in to comment.