Skip to content

Commit

Permalink
Changes to enable CUDA graph for LLM (#8955)
Browse files Browse the repository at this point in the history
* Changes to enable CUDA graph for LLM (#8751)

* Use next instead of get_batch

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* CUDA graph changes

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Change to enable CG with weight caching

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Revert "Use next instead of get_batch"

This reverts commit 0021bb4.

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Copy jbaczek/mcore_parallel_state_api_change branch leaving out changes to nemo/export/quantize/quantizer.py

Signed-off-by: Jan Baczek <[email protected]>
Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Revert "Copy jbaczek/mcore_parallel_state_api_change branch leaving out changes to nemo/export/quantize/quantizer.py"

This reverts commit b4f736e.

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Remove skip_weight_update argument

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Bug fix + cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Use new TE API for FP8 Param transpose

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Change config param cuda_graph to enable_cuda_graph

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Enable TE RNGStatesTracker through config

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Change te_rng_tracker to use_te_rng_tracker

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* FP8 weight transpose handled inside TE

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Cleanup

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Revert "Revert "Copy jbaczek/mcore_parallel_state_api_change branch leaving out changes to nemo/export/quantize/quantizer.py""

This reverts commit e318624.

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Fix merge conflicts

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Fix merge conflicts

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Fix merge conflicts

Signed-off-by: Vasudevan Rengasamy <[email protected]>

---------

Signed-off-by: Vasudevan Rengasamy <[email protected]>
Signed-off-by: Jan Baczek <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Jan Baczek <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: ericharper <[email protected]>

---------

Signed-off-by: Vasudevan Rengasamy <[email protected]>
Signed-off-by: Jan Baczek <[email protected]>
Signed-off-by: ericharper <[email protected]>
Co-authored-by: vasunvidia <[email protected]>
Co-authored-by: Jaemin Choi <[email protected]>
Co-authored-by: Jan Baczek <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: ericharper <[email protected]>
Signed-off-by: Jan Lasek <[email protected]>
  • Loading branch information
7 people authored and janekl committed Jun 12, 2024
1 parent 4694b07 commit 72dcde7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
attention_mask: torch.Tensor = None,
encoder_output: Optional[torch.Tensor] = None,
enc_dec_attn_mask: Optional[torch.Tensor] = None,
inference_params: Optional[Any] = None,
Expand All @@ -169,7 +169,7 @@ def forward(
with torch.autocast(device_type="cuda", dtype=self.dtype):
return super().forward(
hidden_states,
attention_mask,
attention_mask=attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params,
Expand Down Expand Up @@ -242,25 +242,30 @@ def __init__(self, config, layer_number=1, hidden_dropout=None):
def forward(
self,
hidden_states,
attention_mask,
is_first_microbatch=None,
attention_mask=None,
context=None,
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None, # TODO: handle this
):
# Use is_first_microbatch argument during CUDA graph capture. Use self.is_first_microbatch otherwise.
hidden_states = super().forward(
hidden_states,
attention_mask=attention_mask,
encoder_output=context,
enc_dec_attn_mask=context_mask,
inference_params=inference_params,
is_first_microbatch=self.is_first_microbatch,
is_first_microbatch=is_first_microbatch if is_first_microbatch is not None else self.is_first_microbatch,
# checkpoint_core_attention,
)
self.is_first_microbatch = False
context = None

# CUDA graph requires returned values to be Tensors
if self.config.enable_cuda_graph and self.training:
return hidden_states
return hidden_states, context

def _get_layer_offset(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
use_te_rng_tracker=self.cfg.get('use_te_rng_tracker', False),
)

# This must be called after initialize model parallel since it needs to know the data parallel size
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def initialize_model_parallel_for_nemo(
seed=1234,
apex_transformer_log_level=30,
use_tp_pp_dp_mapping=False,
use_te_rng_tracker=False,
):

if virtual_pipeline_model_parallel_size is not None and not HAVE_INTERLEAVED:
Expand Down Expand Up @@ -128,6 +129,7 @@ def initialize_model_parallel_for_nemo(
set_pipeline_model_parallel_world_size(app_state.pipeline_model_parallel_size)
set_pipeline_model_parallel_split_rank(app_state.pipeline_model_parallel_split_rank)

tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
if seed is not None:
# @chcui not setting seed is for model conversion. always set seed for training/inference.
_set_random_seed(seed)
Expand Down
84 changes: 63 additions & 21 deletions nemo/core/optim/distributed_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def hook(*unused):
self._grad_copy(param)
if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
self._try_start_bucket_grad_sync(
params=[param], ignore_last_bucket=need_to_initialize,
params=[param],
ignore_last_bucket=need_to_initialize,
)

return hook
Expand Down Expand Up @@ -167,10 +168,14 @@ def init_params(
# Initialize FP8 and non-FP8 tensors separately
if any(is_float8tensor(param) for param in params):
super().init_params(
filter(is_float8tensor, params), param_sync_dtype=torch.uint8, **kwargs,
filter(is_float8tensor, params),
param_sync_dtype=torch.uint8,
**kwargs,
)
super().init_params(
params, param_sync_dtype=param_sync_dtype, **kwargs,
params,
param_sync_dtype=param_sync_dtype,
**kwargs,
)

def init_params_bucket(
Expand Down Expand Up @@ -200,7 +205,10 @@ def init_params_bucket(
params = remaining_params
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
fp32_params, grad_sync_dtype=torch.float32, param_sync_dtype=param_sync_dtype, **kwargs,
fp32_params,
grad_sync_dtype=torch.float32,
param_sync_dtype=param_sync_dtype,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
Expand All @@ -216,7 +224,10 @@ def init_params_bucket(
params = remaining_params
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
fp8_params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=torch.uint8, **kwargs,
fp8_params,
grad_sync_dtype=grad_sync_dtype,
param_sync_dtype=torch.uint8,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
Expand All @@ -225,12 +236,18 @@ def init_params_bucket(
normal_buckets = []
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=param_sync_dtype, **kwargs,
params,
grad_sync_dtype=grad_sync_dtype,
param_sync_dtype=param_sync_dtype,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]

def add_param_to_bucket(param: torch.nn.Parameter, bucket: self.StateBucket,) -> None:
def add_param_to_bucket(
param: torch.nn.Parameter,
bucket: self.StateBucket,
) -> None:
"""Add trivial param fragment to bucket"""
param_fragments = self.state[param]["fragments"]
param_group_id = param_fragments[0].param_group_id
Expand Down Expand Up @@ -283,7 +300,11 @@ def _init_param_state(
# Initialize non-FP8 params as usual
if not is_float8tensor(param):
super()._init_param_state(
param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs,
param,
param_group_id,
param_id,
param_sync_dtype=param_sync_dtype,
**kwargs,
)

# Return immediately if already initialized
Expand All @@ -293,7 +314,11 @@ def _init_param_state(
# Initialize with FP32 copy of param
fp32_param = param.float()
super()._init_param_state(
fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs,
fp32_param,
param_group_id,
param_id,
param_sync_dtype=torch.uint8,
**kwargs,
)
self.state[param].update(self.state[fp32_param])
del self.state[fp32_param]
Expand Down Expand Up @@ -360,7 +385,9 @@ def init_param_buffer(self) -> None:

# Copy values into param buffer
_multi_tensor_copy(
param_flat_views, param_buffer_views, dummy_overflow_buf=self._dummy_overflow_buf,
param_flat_views,
param_buffer_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)

# Make all params a view into the param buffer
Expand Down Expand Up @@ -393,7 +420,10 @@ def zero_grad(self, *args, **kwargs) -> None:
param.main_grad = self.grad_buffer_view(param)

def grad_norm(
self, parameters: Optional[Iterable[torch.nn.Parameter]] = None, norm_type: float = 2.0, force: bool = False,
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None,
norm_type: float = 2.0,
force: bool = False,
) -> torch.Tensor:
assert norm_type == 2

Expand All @@ -411,7 +441,8 @@ def grad_norm(

# Sum over all procs to get grad norm
torch.distributed.all_reduce(
grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
grad_norm_sq,
op=torch.distributed.ReduceOp.SUM,
)
self._grad_norm = grad_norm_sq.sqrt()

Expand Down Expand Up @@ -479,16 +510,16 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet

# Copy data from parameter buckets to parameters
_multi_tensor_copy(
buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf,
buffers_in,
buffers_out,
dummy_overflow_buf=self._dummy_overflow_buf,
)

# Update transpose caches
params = set(self.parameter(fragment) for fragment in fragments)
for param in params:
if is_float8tensor(param):
param._reset_caches()
param.transpose(update_cache=True)
param._lazy_transpose_cache = True

@torch.no_grad()
def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedAdam.ParameterBucket]) -> None:
Expand Down Expand Up @@ -570,25 +601,35 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)]
_multi_tensor_copy(
scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf,
scales,
packed_scale_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)
torch.reciprocal(packed_scales, out=packed_scales)
_multi_tensor_copy(
packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf,
packed_scale_views,
scale_invs,
dummy_overflow_buf=self._dummy_overflow_buf,
)

# Reduce amaxes
# Note: Assume each param has a separate amax
packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)]
_multi_tensor_copy(
amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf,
amaxes,
packed_amax_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)
torch.distributed.all_reduce(
packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group,
packed_amaxes,
op=torch.distributed.ReduceOp.MAX,
group=self.distributed_process_group,
)
_multi_tensor_copy(
packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf,
packed_amax_views,
amaxes,
dummy_overflow_buf=self._dummy_overflow_buf,
)

# Reset
Expand All @@ -602,7 +643,8 @@ def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None
optimizer_state_dict = self.state_dict()

id_to_sharded_param_map = get_param_id_to_sharded_param_map(
model_sharded_state_dict=model_sharded_state_dict, optim_params_iter=self.parameters(),
model_sharded_state_dict=model_sharded_state_dict,
optim_params_iter=self.parameters(),
)
# Convert state
step = optimizer_state_dict['state'].pop('step')
Expand Down

0 comments on commit 72dcde7

Please sign in to comment.