Merge branch 'r1.21.0' into sft_mp_fix_r21
aklife97 authored Sep 29, 2023
2 parents ac856f3 + 923fc0f commit 2bbbc74
Showing 33 changed files with 242 additions and 105 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -38,7 +38,7 @@ RUN apt-get update && \
     libsndfile1 sox \
     libfreetype6 \
     swig \
-    ffmpeg \
+    ffmpeg=ffmpeg_5.1.2-3ubuntu1 \
     libavdevice-dev && \
     rm -rf /var/lib/apt/lists/*

10 changes: 10 additions & 0 deletions nemo/collections/asr/models/enhancement_models.py
@@ -434,6 +434,16 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
         # Log global step
         self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32), sync_dist=True)

+        if tag == 'val':
+            if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
+                self.validation_step_outputs[dataloader_idx].append(output_dict)
+            else:
+                self.validation_step_outputs.append(output_dict)
+        else:
+            if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
+                self.test_step_outputs[dataloader_idx].append(output_dict)
+            else:
+                self.test_step_outputs.append(output_dict)
         return output_dict

     @classmethod
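This hunk, and several like it below in label_models.py, rnnt_models.py, and modelPT.py, follow the same PyTorch Lightning 2.0 migration pattern: the epoch-end hooks no longer receive the step outputs, so each validation/test step must append its results to a module-level cache, with one sub-list per dataloader when several are configured. A minimal illustrative sketch of the pattern; the class name, the placeholder metric, and the logging call are mine, not NeMo's:

```python
import torch
import pytorch_lightning as pl


class CachingModel(pl.LightningModule):
    """Illustrative only: caches validation outputs the PTL 2.0 way."""

    def __init__(self):
        super().__init__()
        # Flat list for one val dataloader; a list of lists (one sub-list
        # per dataloader) when several are configured.
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        output = {'val_loss': batch.float().mean()}  # placeholder metric
        if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
            self.validation_step_outputs[dataloader_idx].append(output)
        else:
            self.validation_step_outputs.append(output)
        return output

    def on_validation_epoch_end(self):
        # Single-dataloader case shown here; with multiple dataloaders NeMo
        # aggregates each sub-list separately in multi_validation_epoch_end.
        if self.validation_step_outputs:
            avg = torch.stack([o['val_loss'] for o in self.validation_step_outputs]).mean()
            self.log('val_loss', avg, sync_dist=True)
        self.validation_step_outputs.clear()  # free the cache for the next epoch
```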
14 changes: 13 additions & 1 deletion nemo/collections/asr/models/label_models.py
@@ -373,13 +373,25 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
         self._macro_accuracy.update(preds=logits, target=labels)
         stats = self._macro_accuracy._final_state()

-        return {
+        output = {
             f'{tag}_loss': loss_value,
             f'{tag}_correct_counts': correct_counts,
             f'{tag}_total_counts': total_counts,
             f'{tag}_acc_micro_top_k': acc_top_k,
             f'{tag}_acc_macro_stats': stats,
         }
+        if tag == 'val':
+            if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
+                self.validation_step_outputs[dataloader_idx].append(output)
+            else:
+                self.validation_step_outputs.append(output)
+        else:
+            if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
+                self.test_step_outputs[dataloader_idx].append(output)
+            else:
+                self.test_step_outputs.append(output)
+
+        return output

     def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
         loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
24 changes: 15 additions & 9 deletions nemo/collections/asr/models/rnnt_models.py
@@ -772,7 +772,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         sample_id = sample_id.cpu().detach().numpy()
         return list(zip(sample_id, best_hyp_text))

-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+    def validation_pass(self, batch, batch_idx, dataloader_idx=0):
         signal, signal_len, transcript, transcript_len = batch

         # forward() only performs encoder forward
@@ -835,15 +835,21 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):

return tensorboard_logs

def validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
return metrics

def test_step(self, batch, batch_idx, dataloader_idx=0):
logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {
'test_wer_num': logs['val_wer_num'],
'test_wer_denom': logs['val_wer_denom'],
# 'test_wer': logs['val_wer'],
}
if 'val_loss' in logs:
test_logs['test_loss'] = logs['val_loss']
logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(test_logs)
else:
self.test_step_outputs.append(test_logs)
return test_logs

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
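Splitting the metric computation into validation_pass lets test_step reuse it without also triggering the validation-side output caching; the dict comprehension then renames every val_* key to its test_* counterpart. A quick illustrative check of that renaming (the metric values here are made up):

```python
logs = {'val_loss': 0.42, 'val_wer_num': 7.0, 'val_wer_denom': 50.0}
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
assert test_logs == {'test_loss': 0.42, 'test_wer_num': 7.0, 'test_wer_denom': 50.0}
# Note: str.replace substitutes every occurrence, so this relies on "val_"
# appearing only as the key prefix in these metric names.
```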
@@ -272,7 +272,10 @@ def _multiple_truncation(self, template_ids: List[List[int]], template_ids_keys:
         for i, (ids, key) in enumerate(zip(template_ids, template_ids_keys)):
             if key in self.truncation_fields:
                 truncation_length = truncation_length_list.pop()
-                assert len(ids) >= truncation_length, f'{key} is not long enough to truncate.'
+                if len(ids) < truncation_length:
+                    logging.warning(f'{key} is not long enough to truncate.')
+                    truncation_length = len(ids)

                 if self.truncation_method == 'left':
                     window_offset = truncation_length
                 elif self.truncation_method == 'right':
@@ -328,6 +331,7 @@ def _process_example(self, example):
         if len(input_ids) > self.max_seq_length:
             logging.warning(f'Input ids length {len(input_ids)} exceed max sequence length {self.max_seq_length}')
             input_ids = input_ids[: self.max_seq_length]
+            answer_ids = input_ids[answer_start_idx:]

         # store metadata in dataset, in case user may have keys required in the prediction json files
         metadata = {k: v for k, v in example.items() if k not in self.prompt_template_keys}
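Clamping truncation_length to len(ids) matters because the downstream slicing would otherwise compute a negative window. A simplified sketch of the surrounding logic under my reading that truncation_length is the number of tokens to remove (the multi-field bookkeeping from the diff is omitted):

```python
import logging


def truncate(ids, truncation_length, truncation_method='right'):
    # Replaces the old hard assert: warn and drop the whole field instead
    # of aborting dataset construction.
    if len(ids) < truncation_length:
        logging.warning('field is not long enough to truncate.')
        truncation_length = len(ids)
    if truncation_method == 'left':
        return ids[truncation_length:]               # drop tokens from the front
    elif truncation_method == 'right':
        return ids[: len(ids) - truncation_length]   # drop tokens from the back
    raise ValueError(f'Unknown truncation_method: {truncation_method}')


# With the clamp, removing more tokens than exist yields an empty field
# after a warning rather than raising an AssertionError:
print(truncate([1, 2, 3], 5))  # []
```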
@@ -173,16 +173,18 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         model_output = torch.argmax(model_output, 1)

         eval_tensors = {'preds': model_output, 'labels': labels}
-        return {'val_loss': val_loss, 'eval_tensors': eval_tensors}
+        output = {'val_loss': val_loss, 'eval_tensors': eval_tensors}
+        self.validation_step_outputs.append(output)
+        return output

     def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
         """
         Called at the end of validation to aggregate outputs.
         outputs: list of individual outputs of each validation step.
         """
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-        preds = torch.cat([x['eval_tensors']['preds'] for x in outputs])
-        labels = torch.cat([x['eval_tensors']['labels'] for x in outputs])
+        avg_loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]).mean()
+        preds = torch.cat([x['eval_tensors']['preds'] for x in self.validation_step_outputs])
+        labels = torch.cat([x['eval_tensors']['labels'] for x in self.validation_step_outputs])

         all_preds = []
         all_labels = []
@@ -1483,10 +1483,14 @@ def build_transformer_config(self) -> TransformerConfig:
         activation_func = activation_to_func(activation)

         normalization = self.cfg.get('normalization', 'layernorm')
+        layernorm_zero_centered_gamma = self.cfg.get('normalization', 'layernorm') == 'layernorm1p'
         if normalization == 'layernorm':
             normalization = 'LayerNorm'
         elif normalization == 'rmsnorm':
             normalization = 'RMSNorm'
+        elif normalization == 'layernorm1p':
+            normalization = 'LayerNorm'
+            layernorm_zero_centered_gamma = True
         else:
             logging.warning(
                 f"The normalization type: {normalization} might not be supported in megatron core."
@@ -1530,7 +1534,7 @@ def build_transformer_config(self) -> TransformerConfig:
         # any configs that are not in the nemo model config will be added here
         config_mapping = {
             'apply_residual_connection_post_layernorm': False,  # we don't use this in NeMo
-            'layernorm_zero_centered_gamma': False,  # not currently used in NeMo
+            'layernorm_zero_centered_gamma': layernorm_zero_centered_gamma,
             'add_bias_linear': add_bias_linear,
             'gated_linear_unit': gated_linear_unit,
             'activation_func': activation_func,
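For context: layernorm1p is the zero-centered-gamma LayerNorm variant, where, as I understand the convention, the learnable gain is stored centered at zero and applied as 1 + gamma, so a freshly initialized layer behaves like a plain LayerNorm and the parameter plays nicely with weight decay. A minimal sketch of that reading (this is my illustration, not NeMo's implementation):

```python
import torch
import torch.nn.functional as F


def layernorm1p(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5):
    # Zero-centered gamma: parameters initialize at 0 and the effective
    # scale is (1 + gamma), i.e. identity scale at initialization.
    return F.layer_norm(x, x.shape[-1:], weight=gamma + 1.0, bias=beta, eps=eps)


x = torch.randn(2, 8)
gamma = torch.zeros(8)  # fresh parameters
beta = torch.zeros(8)
out = layernorm1p(x, gamma, beta)  # matches plain LayerNorm at init
```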
52 changes: 39 additions & 13 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -291,6 +291,10 @@ def save_checkpoint(
         checkpoint_dir = ckpt_to_dir(filepath)

         fs = get_filesystem(checkpoint_dir)
+        if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
+            logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
+            return
+
         if is_global_rank_zero():
             fs.makedirs(checkpoint_dir, exist_ok=True)
@@ -465,19 +469,24 @@ def save_to(self, model, save_path: str):
             # model weights is a directory
             dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))
             fs = get_filesystem(dist_ckpt_dir)
-            if is_global_rank_zero():
-                fs.makedirs(dist_ckpt_dir, exist_ok=True)
-            sharded_state_dict = model.sharded_state_dict()
-            # dist checkpoint needs torch.distributed to save the checkpoint
-            if parallel_state.is_unitialized():
-
-                def dummy():
-                    return
-
-                if model.trainer.strategy.launcher is not None:
-                    model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
-                model.trainer.strategy.setup_environment()
-            dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=dist_ckpt_dir)
+
+            if fs.isdir(dist_ckpt_dir) and dist_checkpointing.check_is_distributed_checkpoint(dist_ckpt_dir):
+                logging.info(f'Distributed checkpoint at path {dist_ckpt_dir} already exists, skipping saving')
+            else:
+                if is_global_rank_zero():
+                    fs.makedirs(dist_ckpt_dir, exist_ok=True)
+
+                sharded_state_dict = model.sharded_state_dict()
+                # dist checkpoint needs torch.distributed to save the checkpoint
+                if parallel_state.is_unitialized():
+
+                    def dummy():
+                        return
+
+                    if model.trainer.strategy.launcher is not None:
+                        model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
+                    model.trainer.strategy.setup_environment()
+                dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=dist_ckpt_dir)

         else:
@@ -1132,6 +1141,12 @@ class CustomProgressBar(TQDMProgressBar):
     for megatron models
     """

+    def get_current_epoch_step(self, trainer):
+        """
+        Get the value of step within an epoch
+        """
+        return trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.current.completed
+
     def init_train_tqdm(self):
"""
Override bar_format to not have 's/it'
Expand All @@ -1140,11 +1155,22 @@ def init_train_tqdm(self):
self.bar.bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]"
return self.bar

+    def on_train_epoch_start(self, trainer, *_):
+        if trainer.max_steps > 0 and (trainer.ckpt_path is not None):
+            # while resuming from a ckpt use trainer.max_steps as the total for progress bar as trainer.num_training_batches
+            # is truncated to max_steps - step being resumed at
+            num_training_batches = trainer.max_steps
+        else:
+            num_training_batches = trainer.num_training_batches
+        self.train_progress_bar.reset(num_training_batches)
+        self.train_progress_bar.initial = 0
+        self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

     def on_train_batch_end(self, trainer, pl_module, *_, **__):
         """
-        Override parent class on_train_batch_end to update progress bar per global_step instead of per microbatch
+        Override parent class on_train_batch_end to update progress bar per global batch instead of per microbatch
         """
-        n = trainer.global_step
+        n = self.get_current_epoch_step(trainer)
         if self._should_update(n, self.train_progress_bar.total):
             _update_n(self.train_progress_bar, n)
             self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
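The resume fix is easiest to see with numbers. Suppose max_steps is 1000 and training resumes from a checkpoint saved at step 600: PyTorch Lightning truncates trainer.num_training_batches to the 400 remaining steps, so using it as the tqdm total would render something like 601/400. A small sketch of the intended totals (the values are illustrative only):

```python
def progress_bar_total(max_steps: int, num_training_batches: int, resuming: bool) -> int:
    # Mirrors the on_train_epoch_start logic above: when resuming a run that
    # has a step budget, show progress against max_steps rather than the
    # truncated remaining-batch count.
    if max_steps > 0 and resuming:
        return max_steps
    return num_training_batches


assert progress_bar_total(1000, 400, resuming=True) == 1000    # resumed run
assert progress_bar_total(1000, 1000, resuming=False) == 1000  # fresh run
```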
113 changes: 96 additions & 17 deletions nemo/core/classes/modelPT.py
@@ -179,18 +179,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

         # Create list of lists for val and test outputs to support multiple dataloaders
         # Initialize an empty list as sometimes self._validation_dl can be None at this stage
-        self.validation_step_outputs = []
-        # Check len(self._validation_dl) > 1 as sometimes single dataloader can be in a list: [<Dataloader obj>] when ds_item in
-        # config has 1 item passed in a list
-        if self._validation_dl and type(self._validation_dl) == list and len(self._validation_dl) > 1:
-            for _ in range(len(self._validation_dl)):
-                self.validation_step_outputs.append([])
+        self._validation_step_outputs = None

         # Initialize an empty list as sometimes self._test_dl can be None at this stage
-        self.test_step_outputs = []
-        if self._test_dl and type(self._test_dl) == list and len(self._test_dl) > 1:
-            for _ in range(len(self._test_dl)):
-                self.test_step_outputs.append([])
+        self._test_step_outputs = None

         # ModelPT wrappers over subclass implementations
         self.training_step = model_utils.wrap_training_step(self.training_step)

@@ -856,12 +849,18 @@ def train_dataloader(self):
         return self._train_dl

     def val_dataloader(self):
-        if self._validation_dl is not None:
-            return self._validation_dl
+        if self._validation_dl is None:
+            # None dataloader no longer supported in PTL2.0
+            self._validation_dl = []
+
+        return self._validation_dl

     def test_dataloader(self):
-        if self._test_dl is not None:
-            return self._test_dl
+        if self._test_dl is None:
+            # None dataloader no longer supported in PTL2.0
+            self._test_dl = []
+
+        return self._test_dl

     def on_validation_epoch_end(self) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
         """
@@ -1567,6 +1566,61 @@ def cfg(self, cfg):
         if hasattr(self, '_hparams_initial') and 'cfg' in self._hparams_initial:
             self._hparams_initial['cfg'] = OmegaConf.to_object(self._cfg)

+    @property
+    def validation_step_outputs(self):
+        """
+        Cached outputs of validation_step. It can be a list of items (for single data loader) or a list of lists
+        (for multiple data loaders).
+        Returns:
+            List of outputs of validation_step.
+        """
+        if self._validation_step_outputs is not None:
+            return self._validation_step_outputs
+
+        # Initialize new output list
+        self._validation_step_outputs = []
+        # Check len(self._validation_dl) > 1 as sometimes single dataloader can be in a list: [<Dataloader obj>] when ds_item in
+        # config has 1 item passed in a list
+        if (
+            self._validation_dl is not None
+            and isinstance(self._validation_dl, (list, tuple))
+            and len(self._validation_dl) > 1
+        ):
+            for _ in range(len(self._validation_dl)):
+                self._validation_step_outputs.append([])
+
+        return self._validation_step_outputs
+
+    @validation_step_outputs.setter
+    def validation_step_outputs(self, value):
+        self._validation_step_outputs = value
+
+    @property
+    def test_step_outputs(self):
+        """
+        Cached outputs of test_step. It can be a list of items (for single data loader) or a list of lists (for multiple data loaders).
+        Returns:
+            List of outputs of test_step.
+        """
+        if self._test_step_outputs is not None:
+            return self._test_step_outputs
+
+        # Initialize new output list
+        self._test_step_outputs = []
+        # Check len(self._test_dl) > 1 as sometimes single dataloader can be in a list: [<Dataloader obj>] when ds_item in
+        # config has 1 item passed in a list
+        if self._test_dl is not None and isinstance(self._test_dl, (list, tuple)) and len(self._test_dl) > 1:
+            for _ in range(len(self._test_dl)):
+                self._test_step_outputs.append([])
+
+        return self._test_step_outputs
+
+    @test_step_outputs.setter
+    def test_step_outputs(self, value):
+        self._test_step_outputs = value
+
     @staticmethod
     def _is_model_being_restored() -> bool:
         app_state = AppState()
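The property/setter pair makes the output caches lazy: __init__ now stores only None, and the list (or list of lists) is materialized on first access, once the dataloaders are actually known. A stripped-down sketch of the same idiom outside NeMo (all names here are illustrative):

```python
class LazyOutputs:
    def __init__(self, num_dataloaders: int = 1):
        self.num_dataloaders = num_dataloaders
        self._outputs = None  # deferred until first access

    @property
    def outputs(self):
        if self._outputs is None:
            # One sub-list per dataloader when there are several,
            # otherwise a single flat list.
            self._outputs = (
                [[] for _ in range(self.num_dataloaders)]
                if self.num_dataloaders > 1
                else []
            )
        return self._outputs

    @outputs.setter
    def outputs(self, value):
        self._outputs = value


cache = LazyOutputs(num_dataloaders=2)
cache.outputs[0].append({'val_loss': 0.1})  # list of lists, built on demand
cache.outputs = None                        # reset; rebuilt on next access
```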
@@ -1708,15 +1762,40 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =
logging.info("====== End nsys profiling ======")
torch.cuda.cudart().cudaProfilerStop()

def _cleanup_on_execution_end(self):
"""
Utility function to clean up the module state at the end of execution.
"""

# dynamic freezing cleanup
if hasattr(self, '_freeze_cfg'):
delattr(self, '_freeze_cfg')

# Clear up the val and test output caches
self._validation_step_outputs = None
self._test_step_outputs = None

def on_train_end(self):
""" PyTorch Lightning hook:
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
We use it here to cleanup the dynamic freezing config.
"""

# dynamic freezing cleanup
if hasattr(self, '_freeze_cfg'):
delattr(self, '_freeze_cfg')
self._cleanup_on_execution_end()

def on_test_end(self):
""" PyTorch Lightning hook:
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-test-end
"""

self._cleanup_on_execution_end()

def on_predict_end(self):
""" PyTorch Lightning hook:
https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-test-end
"""

self._cleanup_on_execution_end()

# TODO: Remove in PTL 1.7.2
def cuda(self, device=None):
(Diffs for the remaining changed files are not shown.)