Skip to content

Commit

Permalink
[NeMo-UX] Use single instance of loss reductions in GPTModel (#9801) (#…
Browse files Browse the repository at this point in the history
…9861)

* Use single instance of loss reductions



* Apply isort and black reformatting



* Refactor



---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: Hemil Desai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
  • Loading branch information
3 people authored and monica-sekoyan committed Oct 11, 2024
1 parent 0686d10 commit b2f77e4
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions nemo/collections/llm/gpt/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def __init__(
self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True))
self.optim.connect(self) # This will bind the `configure_optimizers` method
self.model_transform = model_transform
self._training_loss_reduction = None
self._validation_loss_reduction = None

def configure_model(self) -> None:
if not hasattr(self, "module"):
Expand Down Expand Up @@ -200,11 +202,19 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:

return self.forward_step(batch)

@property
def training_loss_reduction(self) -> MaskedTokenLossReduction:
return MaskedTokenLossReduction()
if not self._training_loss_reduction:
self._training_loss_reduction = MaskedTokenLossReduction()

return self._training_loss_reduction

@property
def validation_loss_reduction(self) -> MaskedTokenLossReduction:
return MaskedTokenLossReduction(validation_step=True)
if not self._validation_loss_reduction:
self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True)

return self._validation_loss_reduction


def get_batch_on_this_context_parallel_rank(batch):
Expand Down

0 comments on commit b2f77e4

Please sign in to comment.