
Commit: Refactor
Signed-off-by: Hemil Desai <[email protected]>
hemildesai committed Jul 19, 2024
1 parent 6306228 commit 61f36a3
Showing 1 changed file with 12 additions and 5 deletions.
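In brief: the module-level TRAIN_LOSS_REDUCTION and VALIDATION_LOSS_REDUCTION singletons are removed; training_loss_reduction and validation_loss_reduction become properties that lazily construct a per-instance MaskedTokenLossReduction on first access.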
nemo/collections/llm/gpt/model/base.py: 17 changes (12 additions, 5 deletions)
@@ -25,9 +25,6 @@
 
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 
-TRAIN_LOSS_REDUCTION = MaskedTokenLossReduction()
-VALIDATION_LOSS_REDUCTION = MaskedTokenLossReduction(validation_step=True)
-
 
 def gpt_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
     from megatron.core import parallel_state
@@ -163,6 +160,8 @@ def __init__(
         self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True))
         self.optim.connect(self)  # This will bind the `configure_optimizers` method
         self.model_transform = model_transform
+        self._training_loss_reduction = None
+        self._validation_loss_reduction = None
 
     def configure_model(self) -> None:
         if not hasattr(self, "module"):
@@ -203,11 +202,19 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
 
         return self.forward_step(batch)
 
+    @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
-        return TRAIN_LOSS_REDUCTION
+        if not self._training_loss_reduction:
+            self._training_loss_reduction = MaskedTokenLossReduction()
+
+        return self._training_loss_reduction
 
+    @property
     def validation_loss_reduction(self) -> MaskedTokenLossReduction:
-        return VALIDATION_LOSS_REDUCTION
+        if not self._validation_loss_reduction:
+            self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True)
+
+        return self._validation_loss_reduction
 
 
 def get_batch_on_this_context_parallel_rank(batch):
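
A note on the pattern (not part of the commit itself): the old module-level singletons were constructed at import time and shared across every model instance, while the refactored code gives each instance its own cache that is filled in lazily on first access. A minimal standalone sketch of this lazy-property pattern, using a hypothetical LossReduction stand-in for MaskedTokenLossReduction:

    from typing import Optional


    class LossReduction:
        """Stand-in for MaskedTokenLossReduction; hypothetical, for illustration only."""

        def __init__(self, validation_step: bool = False) -> None:
            self.validation_step = validation_step


    class Model:
        def __init__(self) -> None:
            # Caches start empty; nothing is constructed at import or init time.
            self._training_loss_reduction: Optional[LossReduction] = None
            self._validation_loss_reduction: Optional[LossReduction] = None

        @property
        def training_loss_reduction(self) -> LossReduction:
            # Construct on first access, then reuse the per-instance object.
            if not self._training_loss_reduction:
                self._training_loss_reduction = LossReduction()
            return self._training_loss_reduction

        @property
        def validation_loss_reduction(self) -> LossReduction:
            if not self._validation_loss_reduction:
                self._validation_loss_reduction = LossReduction(validation_step=True)
            return self._validation_loss_reduction


    m = Model()
    assert m.training_loss_reduction is m.training_loss_reduction  # cached per instance
    assert m.validation_loss_reduction.validation_step is True

One practical effect of the change: each model instance now owns its reduction objects, and a subclass can override either property to supply a different loss reduction without touching shared module state.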
