From 50d650818cd9fb8d3548ce3dd6e2cfa2f42f3dd6 Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Wed, 24 Jul 2024 00:35:43 -0700
Subject: [PATCH] [NeMo-UX] Use single instance of loss reductions in GPTModel
 (#9801)

* Use single instance of loss reductions

Signed-off-by: Hemil Desai

* Apply isort and black reformatting

Signed-off-by: hemildesai

* Refactor

Signed-off-by: Hemil Desai

---------

Signed-off-by: Hemil Desai
Signed-off-by: hemildesai
Co-authored-by: hemildesai
---
 nemo/collections/llm/gpt/model/base.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index 0e4fabe020af..a8339e124564 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -160,6 +160,8 @@ def __init__(
         self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True))
         self.optim.connect(self)  # This will bind the `configure_optimizers` method
         self.model_transform = model_transform
+        self._training_loss_reduction = None
+        self._validation_loss_reduction = None
 
     def configure_model(self) -> None:
         if not hasattr(self, "module"):
@@ -200,11 +202,19 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
 
         return self.forward_step(batch)
 
+    @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
-        return MaskedTokenLossReduction()
+        if not self._training_loss_reduction:
+            self._training_loss_reduction = MaskedTokenLossReduction()
+
+        return self._training_loss_reduction
 
+    @property
     def validation_loss_reduction(self) -> MaskedTokenLossReduction:
-        return MaskedTokenLossReduction(validation_step=True)
+        if not self._validation_loss_reduction:
+            self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True)
+
+        return self._validation_loss_reduction
 
 
 def get_batch_on_this_context_parallel_rank(batch):
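
Note on the pattern: the patch turns the two loss-reduction factory methods into
lazily memoized properties. Each MaskedTokenLossReduction is constructed once, on
first access, and the same instance is then reused on every subsequent training or
validation step instead of being reallocated per call. The sketch below shows the
pattern in isolation; LossReduction and Model are hypothetical stand-ins for
MaskedTokenLossReduction and GPTModel, not NeMo's actual classes.

    # Minimal, self-contained sketch of the memoized-property pattern
    # applied in this patch (stand-in classes, not NeMo code).
    from typing import Optional


    class LossReduction:
        def __init__(self, validation_step: bool = False):
            self.validation_step = validation_step


    class Model:
        def __init__(self):
            # Backing fields start as None; instances are built on first access.
            self._training_loss_reduction: Optional[LossReduction] = None
            self._validation_loss_reduction: Optional[LossReduction] = None

        @property
        def training_loss_reduction(self) -> LossReduction:
            if not self._training_loss_reduction:
                self._training_loss_reduction = LossReduction()
            return self._training_loss_reduction

        @property
        def validation_loss_reduction(self) -> LossReduction:
            if not self._validation_loss_reduction:
                self._validation_loss_reduction = LossReduction(validation_step=True)
            return self._validation_loss_reduction


    m = Model()
    # The property returns the same object on every access.
    assert m.training_loss_reduction is m.training_loss_reduction
    assert m.validation_loss_reduction.validation_step

One design note on the guard used here: `if not self._training_loss_reduction:`
treats any falsy value as "uninitialized", so an object whose truth value is False
would be rebuilt on each access. Comparing against None
(`if self._training_loss_reduction is None:`) is the stricter form of the same check.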