Commit

Cache FP8 weight and transpose only at the first micro-batch in each validation and test routine (#7470) (#7483)

* Cache the weight and its transpose only at the first micro-batch in all training, validation, and test runs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Sangkug Lym <[email protected]>
Co-authored-by: Sangkug Lym <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Oct 11, 2023
1 parent 7e5bce4 commit d8238cf
Showing 1 changed file with 21 additions and 5 deletions.
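To make the diff below easier to follow, here is the decision rule the commit introduces, pulled out as a minimal standalone sketch. The function name and parameter list are hypothetical, not NeMo identifiers; only the boolean expression mirrors the committed code. TransformerEngine re-caches the FP8 weight and its transpose whenever a layer is called with is_first_microbatch=True, so the flag should fire exactly once per global-batch in training and once at the start of each validation or test routine.

def should_refresh_fp8_cache(is_first_train_microbatch, is_prev_microbatch_training, training):
    # Training: refresh only at the first micro-batch of each global-batch,
    # i.e. right after the optimizer may have updated the weights.
    # Validation/test: refresh only at the first micro-batch after leaving
    # training mode; later eval micro-batches reuse the cached copy.
    return (is_first_train_microbatch and training) or (
        is_prev_microbatch_training and not training
    )
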
nemo/collections/nlp/modules/common/megatron/transformer.py (26 changes: 21 additions & 5 deletions)
@@ -1033,7 +1033,10 @@ def __init__(
             reduce_amax=reduce_amax,
         )
 
-        self.is_first_microbatch = True
+        self.is_first_train_microbatch = (
+            True  # Is the current micro-batch the first micro-batch in a global-batch in training
+        )
+        self.is_prev_microbatch_training = True  # Is the previous micro-batch in training mode
         self.microbatch_count = 0  # transformer engine forward needs to know if it is working on the first microbatch
         self.checkpoint_core_attention = (
             activations_checkpoint_granularity == 'selective'
@@ -1244,6 +1247,12 @@ def custom_forward(*inputs):
                         attention_mask = inputs[1]
                         encoder_output = inputs[2]
                         enc_dec_attn_mask = inputs[3]
+                        # Cache FP8 weight and transpose at (1) the first micro-batch in each global-batch
+                        # in training, (2) the first micro-batch in each validation and test routine.
+                        # The caching happens in TransformerEngine when passing `is_first_microbatch=True`.
+                        is_first_microbatch = (self.is_first_train_microbatch and self.training) or (
+                            self.is_prev_microbatch_training and not self.training
+                        )
                         for index in range(start, end):
                             layer = self._get_layer(index)
                             hidden_states = layer(
@@ -1252,7 +1261,7 @@
                                 encoder_output=encoder_output,
                                 enc_dec_attn_mask=enc_dec_attn_mask,
                                 inference_params=None,
-                                is_first_microbatch=self.is_first_microbatch,
+                                is_first_microbatch=is_first_microbatch,
                                 checkpoint_core_attention=False,
                             )

@@ -1528,14 +1537,20 @@ def forward(
                 else:
                     checkpoint_core_attention = False
 
+                # Cache FP8 weight and transpose at (1) the first micro-batch in each global-batch
+                # in training, (2) the first micro-batch in each validation and test routine.
+                # The caching happens in TransformerEngine when passing `is_first_microbatch=True`.
+                is_first_microbatch = (self.is_first_train_microbatch and self.training) or (
+                    self.is_prev_microbatch_training and not self.training
+                )
                 if self.transformer_engine:
                     hidden_states = layer(
                         hidden_states,
                         attention_mask,
                         encoder_output=encoder_output,
                         enc_dec_attn_mask=enc_dec_attn_mask,
                         inference_params=self.inference_params,
-                        is_first_microbatch=self.is_first_microbatch,
+                        is_first_microbatch=is_first_microbatch,
                         checkpoint_core_attention=checkpoint_core_attention,
                     )
                 else:
@@ -1562,9 +1577,10 @@
             self.microbatch_count += 1
             if self.microbatch_count % num_micro_batches == 0:
                 self.microbatch_count = 0
-                self.is_first_microbatch = True
+                self.is_first_train_microbatch = True
             else:
-                self.is_first_microbatch = False
+                self.is_first_train_microbatch = False
+        self.is_prev_microbatch_training = self.training
 
         output = hidden_states

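To see how the bookkeeping in the last hunk behaves end to end, the sketch below simulates the two flags across one training global-batch followed by a short validation run. It is illustrative only: the FlagTracker class, the training-only guard on the counter update, and num_micro_batches=4 are assumptions for the demo, not code from this commit.

class FlagTracker:
    """Toy reimplementation of the commit's micro-batch bookkeeping."""

    def __init__(self, num_micro_batches):
        self.num_micro_batches = num_micro_batches
        self.is_first_train_microbatch = True
        self.is_prev_microbatch_training = True
        self.microbatch_count = 0

    def step(self, training):
        # Same expression the commit evaluates before each layer call.
        refresh = (self.is_first_train_microbatch and training) or (
            self.is_prev_microbatch_training and not training
        )
        if training:  # assumed guard; the diff does not show the enclosing condition
            self.microbatch_count += 1
            if self.microbatch_count % self.num_micro_batches == 0:
                self.microbatch_count = 0
                self.is_first_train_microbatch = True
            else:
                self.is_first_train_microbatch = False
        self.is_prev_microbatch_training = training
        return refresh

tracker = FlagTracker(num_micro_batches=4)
print([tracker.step(training=True) for _ in range(4)])   # [True, False, False, False]
print([tracker.step(training=False) for _ in range(2)])  # [True, False]

The two printed lists show one cache refresh per training global-batch and one per validation routine, which is exactly the invariant named in the commit title.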
