diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index d93ea5600e49..a8339e124564 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -25,9 +25,6 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -TRAIN_LOSS_REDUCTION = MaskedTokenLossReduction() -VALIDATION_LOSS_REDUCTION = MaskedTokenLossReduction(validation_step=True) - def gpt_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: from megatron.core import parallel_state @@ -163,6 +160,8 @@ def __init__( self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True)) self.optim.connect(self) # This will bind the `configure_optimizers` method self.model_transform = model_transform + self._training_loss_reduction = None + self._validation_loss_reduction = None def configure_model(self) -> None: if not hasattr(self, "module"): @@ -203,11 +202,19 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor: return self.forward_step(batch) + @property def training_loss_reduction(self) -> MaskedTokenLossReduction: - return TRAIN_LOSS_REDUCTION + if not self._training_loss_reduction: + self._training_loss_reduction = MaskedTokenLossReduction() + + return self._training_loss_reduction + @property def validation_loss_reduction(self) -> MaskedTokenLossReduction: - return VALIDATION_LOSS_REDUCTION + if not self._validation_loss_reduction: + self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True) + + return self._validation_loss_reduction def get_batch_on_this_context_parallel_rank(batch):