From f776d7f56216b179d5d17153c81785a5ed291e5b Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Wed, 17 Jul 2024 15:46:07 -0700
Subject: [PATCH 1/3] Use single instance of loss reductions

Signed-off-by: Hemil Desai
---
 nemo/collections/llm/gpt/model/base.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index 0e4fabe020af..8c75615d8606 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -25,6 +25,8 @@
 
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 
+TRAIN_LOSS_REDUCTION = MaskedTokenLossReduction()
+VALIDATION_LOSS_REDUCTION = MaskedTokenLossReduction(validation_step=True)
 
 def gpt_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
     from megatron.core import parallel_state
@@ -201,10 +203,10 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
         return self.forward_step(batch)
 
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
-        return MaskedTokenLossReduction()
+        return TRAIN_LOSS_REDUCTION
 
     def validation_loss_reduction(self) -> MaskedTokenLossReduction:
-        return MaskedTokenLossReduction(validation_step=True)
+        return VALIDATION_LOSS_REDUCTION
 
 
 def get_batch_on_this_context_parallel_rank(batch):

From 09d493e6c994c38f444ba9a92f312eb96496d16c Mon Sep 17 00:00:00 2001
From: hemildesai
Date: Thu, 18 Jul 2024 22:19:39 +0000
Subject: [PATCH 2/3] Apply isort and black reformatting

Signed-off-by: hemildesai
---
 nemo/collections/llm/gpt/model/base.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index 8c75615d8606..d93ea5600e49 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -28,6 +28,7 @@
 TRAIN_LOSS_REDUCTION = MaskedTokenLossReduction()
 VALIDATION_LOSS_REDUCTION = MaskedTokenLossReduction(validation_step=True)
 
+
 def gpt_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
     from megatron.core import parallel_state
 

From eded32215342a19151e27c5181143238eb480a76 Mon Sep 17 00:00:00 2001
From: Hemil Desai
Date: Fri, 19 Jul 2024 14:41:31 -0700
Subject: [PATCH 3/3] Refactor

Signed-off-by: Hemil Desai
---
 nemo/collections/llm/gpt/model/base.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py
index d93ea5600e49..a8339e124564 100644
--- a/nemo/collections/llm/gpt/model/base.py
+++ b/nemo/collections/llm/gpt/model/base.py
@@ -25,9 +25,6 @@
 
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 
-TRAIN_LOSS_REDUCTION = MaskedTokenLossReduction()
-VALIDATION_LOSS_REDUCTION = MaskedTokenLossReduction(validation_step=True)
-
 
 def gpt_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
     from megatron.core import parallel_state
@@ -163,6 +160,8 @@ def __init__(
         self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True))
         self.optim.connect(self)  # This will bind the `configure_optimizers` method
         self.model_transform = model_transform
+        self._training_loss_reduction = None
+        self._validation_loss_reduction = None
 
     def configure_model(self) -> None:
         if not hasattr(self, "module"):
@@ -203,11 +202,19 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
 
         return self.forward_step(batch)
 
+    @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
-        return TRAIN_LOSS_REDUCTION
+        if not self._training_loss_reduction:
+            self._training_loss_reduction = MaskedTokenLossReduction()
+
+        return self._training_loss_reduction
 
+    @property
     def validation_loss_reduction(self) -> MaskedTokenLossReduction:
-        return VALIDATION_LOSS_REDUCTION
+        if not self._validation_loss_reduction:
+            self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True)
+
+        return self._validation_loss_reduction
 
 
 def get_batch_on_this_context_parallel_rank(batch):
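
Note: the following is a minimal standalone sketch of the lazy-initialization property pattern the final patch settles on, not the actual NeMo class. `LossReduction` and `Model` are hypothetical stand-ins for `MaskedTokenLossReduction` and the GPT model class; the point is that each model instance caches its own reduction object rather than sharing module-level singletons, while repeated accesses still return the same object.

class LossReduction:
    # Stand-in for MaskedTokenLossReduction, which takes a `validation_step` flag.
    def __init__(self, validation_step: bool = False) -> None:
        self.validation_step = validation_step


class Model:
    def __init__(self) -> None:
        # Placeholders so each instance owns at most one reduction of each kind.
        self._training_loss_reduction = None
        self._validation_loss_reduction = None

    @property
    def training_loss_reduction(self) -> LossReduction:
        # Create lazily on first access, then reuse the cached instance.
        if not self._training_loss_reduction:
            self._training_loss_reduction = LossReduction()
        return self._training_loss_reduction

    @property
    def validation_loss_reduction(self) -> LossReduction:
        if not self._validation_loss_reduction:
            self._validation_loss_reduction = LossReduction(validation_step=True)
        return self._validation_loss_reduction


model = Model()
# Repeated access returns the same object for this instance...
assert model.training_loss_reduction is model.training_loss_reduction
# ...but separate instances do not share state, unlike module-level constants.
assert Model().training_loss_reduction is not model.training_loss_reduction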