
Commit

Merge branch 'canary2-optimizations' into canary2
pzelasko committed Nov 26, 2024
2 parents 2fa9bed + 6c39efc commit f8f4964
Showing 3 changed files with 13 additions and 23 deletions.
8 changes: 7 additions & 1 deletion nemo/collections/asr/models/aed_multitask_models.py
@@ -67,7 +67,8 @@

 def lens_to_mask(lens, max_length):
     batch_size = lens.shape[0]
-    mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None]
+    arange = torch.arange(max_length, device=lens.device)
+    mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1)
     return mask
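
For reference, the rewritten helper builds the position index directly on lens.device and compares it against the lengths via broadcasting; expand() returns a view, so no repeated copy of the arange is materialized. A self-contained sketch of the updated helper on a toy batch (the sample lengths are illustrative):

import torch

def lens_to_mask(lens, max_length):
    # Index built once on the same device as `lens`; expand() is a view, not a copy.
    batch_size = lens.shape[0]
    arange = torch.arange(max_length, device=lens.device)
    return arange.expand(batch_size, max_length) < lens.unsqueeze(1)

lens = torch.tensor([2, 5, 3])
print(lens_to_mask(lens, max_length=5))
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])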


@@ -680,6 +681,11 @@ def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb):
         tot_frames = torch.as_tensor(batch.audio.numel(), device=num_frames.device, dtype=torch.float)
         tot_tokens = torch.as_tensor(batch.prompted_transcript.numel(), device=num_frames.device, dtype=torch.float)
 
+        num_frames = batch.audio_lens.sum().float()
+        num_tokens = batch.prompted_transcript_lens.sum().float()
+        tot_frames = torch.as_tensor(batch.audio.numel(), device=num_frames.device, dtype=torch.float)
+        tot_tokens = torch.as_tensor(batch.prompted_transcript.numel(), device=num_frames.device, dtype=torch.float)
+
         transf_log_probs, encoded_len, enc_states, enc_mask = self.forward(
             input_signal=batch.audio,
             input_signal_length=batch.audio_lens,
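
The added lines above keep the per-batch frame/token statistics as float tensors on a single device; Tensor.numel() returns a plain Python int, hence the torch.as_tensor(...) wrapping. A rough standalone sketch of the same bookkeeping on dummy batch fields (the batch contents and the padding-ratio printout are illustrative, not taken from NeMo):

import torch

# Illustrative stand-ins for the padded mini-batch fields used above.
audio = torch.zeros(4, 16000)                               # padded audio (B, T)
audio_lens = torch.tensor([16000, 12000, 8000, 15000])      # true lengths per example
prompted_transcript = torch.zeros(4, 32, dtype=torch.long)  # padded token ids (B, U)
prompted_transcript_lens = torch.tensor([32, 20, 17, 28])

num_frames = audio_lens.sum().float()
num_tokens = prompted_transcript_lens.sum().float()
# numel() yields a Python int; wrap it so all four stats share device and dtype.
tot_frames = torch.as_tensor(audio.numel(), device=num_frames.device, dtype=torch.float)
tot_tokens = torch.as_tensor(prompted_transcript.numel(), device=num_frames.device, dtype=torch.float)

print(f"audio padding:      {1.0 - (num_frames / tot_frames).item():.1%}")
print(f"transcript padding: {1.0 - (num_tokens / tot_tokens).item():.1%}")
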
20 changes: 0 additions & 20 deletions nemo/collections/asr/modules/transformer/transformer_modules.py
@@ -58,27 +58,7 @@ def _build_pos_enc(self, hidden_size, max_sequence_length, device=None):
         self.register_buffer('pos_enc', pos_enc)
 
     def forward(self, position_ids):
-        max_pos_id = position_ids.max()
-        # update positional encoding if needed
-        if max_pos_id >= self._max_sequence_length:
-            logging.warning(
-                f'Max position id {max_pos_id} is greater than max sequence length {self._max_sequence_length}. Expanding position embeddings just for this batch. This is not expected to work very well. Consider chunking your input into smaller sequences.'
-            )
-            self._build_pos_enc(
-                hidden_size=self._hidden_size,
-                max_sequence_length=max_pos_id + 1,
-                device=position_ids.device,
-            )
-
         embeddings = torch.embedding(self.pos_enc, position_ids)
-
-        # Revert expansion of position embeddings since this will cause checkpoint size mismatches.
-        if max_pos_id >= self._max_sequence_length:
-            self._build_pos_enc(
-                hidden_size=self._hidden_size,
-                max_sequence_length=self._max_sequence_length,
-                device=position_ids.device,
-            )
         return embeddings


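With the dynamic-expansion branch gone, forward is a plain lookup into the precomputed pos_enc buffer, and position ids beyond max_sequence_length must now be handled upstream (e.g. by chunking the input). A minimal sketch of that fixed-buffer pattern; the class below is illustrative, not the NeMo module:

import math
import torch
from torch import nn

class FixedPositionalEncoding(nn.Module):
    # Illustrative fixed-size sinusoidal table; out-of-range position ids are the
    # caller's responsibility (chunk long inputs rather than growing the buffer).

    def __init__(self, hidden_size, max_sequence_length):
        super().__init__()
        position = torch.arange(max_sequence_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_size, 2, dtype=torch.float) * (-math.log(10000.0) / hidden_size)
        )
        pos_enc = torch.zeros(max_sequence_length, hidden_size)
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, position_ids):
        # Plain table lookup, no per-batch rebuild of the buffer.
        return torch.embedding(self.pos_enc, position_ids)

pe = FixedPositionalEncoding(hidden_size=8, max_sequence_length=512)
print(pe(torch.arange(16).unsqueeze(0)).shape)  # torch.Size([1, 16, 8])
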
8 changes: 6 additions & 2 deletions nemo/utils/exp_manager.py
@@ -264,7 +264,7 @@ def _on_batch_end(self, name, pl_module):
         # Set the `batch_size=1` as WAR for `dataloader_iter`, which is not used for any metric
         pl_module.log(
             name + ' in s',
-            self.timer[name],
+            torch.as_tensor(self.timer[name]),
             on_step=True,
             on_epoch=False,
             batch_size=1,
@@ -1171,7 +1171,11 @@ def configure_checkpointing(
         params.prefix = name
     if params.always_save_nemo:
         app_state = AppState()
-        if (app_state.tensor_model_parallel_size is not None and app_state.tensor_model_parallel_size > 1) or (app_state.pipeline_model_parallel_size is not None and app_state.pipeline_model_parallel_size > 1) or (app_state.context_parallel_size is not None and app_state.context_parallel_size > 1):
+        if (
+            (app_state.tensor_model_parallel_size is not None and app_state.tensor_model_parallel_size > 1)
+            or (app_state.pipeline_model_parallel_size is not None and app_state.pipeline_model_parallel_size > 1)
+            or (app_state.context_parallel_size is not None and app_state.context_parallel_size > 1)
+        ):
             raise LoggerMisconfigurationError(
                 "always_save_nemo is set to True, please ensure that model parallel is not used."
                 f"tensor_model_parallel_size: {app_state.tensor_model_parallel_size},"
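The reformatted condition simply asks whether any model-parallel dimension is configured. A hedged sketch of the same check factored into a helper; the helper name is hypothetical and not part of NeMo:

def uses_model_parallelism(app_state):
    # True if tensor, pipeline, or context parallelism is set to a size > 1.
    sizes = (
        app_state.tensor_model_parallel_size,
        app_state.pipeline_model_parallel_size,
        app_state.context_parallel_size,
    )
    return any(size is not None and size > 1 for size in sizes)

# Usage sketch mirroring the guard above:
#   if params.always_save_nemo and uses_model_parallelism(app_state):
#       raise LoggerMisconfigurationError("always_save_nemo is not supported with model parallelism ...")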
