Commit

t5 changes based on mcore changes
Signed-off-by: Pablo Garay <[email protected]>
pablo-garay committed Jul 22, 2024
1 parent a52ee73 commit b3d6286
Showing 3 changed files with 10 additions and 7 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/cicd-main.yml
@@ -3430,7 +3430,8 @@ jobs:
 trainer.limit_val_batches=2 \
 trainer.accumulate_grad_batches=1 \
 trainer.max_steps=10 \
-trainer.precision=16 \
+trainer.precision=bf16 \
+model.megatron_amp_O2=True \
 trainer.gradient_clip_val=1.0 \
 exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \
 model.tensor_model_parallel_size=2 \
@@ -3472,7 +3473,8 @@ jobs:
 trainer.limit_val_batches=2 \
 trainer.accumulate_grad_batches=1 \
 trainer.max_steps=10 \
-trainer.precision=16 \
+trainer.precision=bf16 \
+model.megatron_amp_O2=True \
 trainer.gradient_clip_val=1.0 \
 exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \
 exp_manager.resume_if_exists=True \
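
Both CI hunks above move the T5 pretraining jobs from trainer.precision=16 to trainer.precision=bf16 and turn on model.megatron_amp_O2=True, which routes the model through the O2 half-precision wrapping path touched in the _wrap_model_for_O2 change below. A minimal sketch of how these overrides could be assembled and sanity-checked with OmegaConf follows; the key names mirror the command line above, while the check itself is a hypothetical helper, not part of this commit.

# Illustrative only: express the CI overrides above as an OmegaConf config and check
# that megatron_amp_O2 is paired with a half-precision trainer setting. Key names
# mirror the Hydra overrides in the workflow; the helper itself is hypothetical.
from omegaconf import OmegaConf

overrides = OmegaConf.create(
    {
        "trainer": {"precision": "bf16", "max_steps": 10, "gradient_clip_val": 1.0},
        "model": {"megatron_amp_O2": True, "tensor_model_parallel_size": 2},
    }
)

def check_half_precision_overrides(cfg) -> None:
    # megatron_amp_O2 wraps the model in a half-precision module, so it is only
    # meaningful when the trainer runs in fp16 or bf16.
    if cfg.model.megatron_amp_O2 and str(cfg.trainer.precision) not in ("16", "bf16", "16-mixed", "bf16-mixed"):
        raise ValueError("model.megatron_amp_O2=True expects trainer.precision=16 or bf16")

check_half_precision_overrides(overrides)
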
@@ -294,9 +294,12 @@ def _wrap_model_for_O2(self):
         if type(self).__name__ == 'MegatronGPTModel':
             nemo_args['share_token_embeddings'] = self.cfg.get('share_embeddings_and_output_weights', True)

-        mcore_args = {
-            'config': self.transformer_config,
-        }
+        if is_mcore_model:
+            mcore_args = {
+                'config': self.transformer_config,
+            }
+        else:
+            mcore_args = None

         args = mcore_args if is_mcore_model else nemo_args
         # Model wrapper to convert both model and inputs to half precision
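
The _wrap_model_for_O2 hunk stops building mcore_args unconditionally: only Megatron-Core models need the TransformerConfig, so the kwargs are now picked per branch and legacy models keep using nemo_args. Below is a minimal sketch of the resulting selection logic, assuming the wrapper classes named elsewhere in this diff; how they are imported and invoked here is an assumption, not code from the commit.

# Sketch of the post-change selection logic (assumed wiring, not copied from the file).
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from megatron.core.transformer.module import Float16Module as MCoreFloat16Module

def wrap_for_O2(model, is_mcore_model, transformer_config, nemo_args):
    # Only mcore models carry a TransformerConfig; legacy NeMo modules take their own kwargs.
    if is_mcore_model:
        mcore_args = {'config': transformer_config}
    else:
        mcore_args = None
    args = mcore_args if is_mcore_model else nemo_args
    wrapper_cls = MCoreFloat16Module if is_mcore_model else Float16Module
    # Wrap the model so that both the module and its inputs run in half precision.
    return wrapper_cls(**args, module=model)

Building the dict only on the mcore branch avoids touching transformer_config for models that never construct one, which is presumably why the unconditional version was dropped.
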
@@ -427,8 +427,6 @@ def training_step(self, dataloader_iter):
             for module in modules:
                 if isinstance(module, (Float16Module, MCoreFloat16Module)):
                     module = module.module
-                if not self.mcore_t5:
-                    module = module.language_model
                 if hasattr(module, 'embedding'):
                     for param in module.embedding.parameters():
                         param.data_ptr()
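
The training_step hunk removes the non-mcore detour through module.language_model: after unwrapping the half-precision wrapper, the (mcore) T5 module is expected to expose embedding directly. A hedged sketch of the traversal after the change; the helper and its generator shape are illustrative, while the attribute and class names follow the hunk above.

# Illustrative traversal after this change; the helper is hypothetical, the attribute
# and class names follow the hunk above.
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from megatron.core.transformer.module import Float16Module as MCoreFloat16Module

def iter_embedding_params(modules):
    for module in modules:
        if isinstance(module, (Float16Module, MCoreFloat16Module)):
            module = module.module  # peel off the O2 half-precision wrapper
        # The pre-change code additionally did `module = module.language_model` when
        # self.mcore_t5 was False; the mcore T5 module exposes `embedding` directly.
        if hasattr(module, 'embedding'):
            yield from module.embedding.parameters()

In the actual training_step these parameters are only touched via param.data_ptr(); that surrounding logic is unchanged by this commit.
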
