fix FA, start to make validation step look like train step
kylematoba committed Nov 19, 2024
1 parent f587ee8 commit d0facd6
Showing 2 changed files with 18 additions and 1 deletion.
src/nanotron/models/llama.py (3 changes: 2 additions & 1 deletion)
@@ -213,13 +213,14 @@ def forward(
         return_attn_probs=False,
     )
 else:
+    assert not self.is_using_mup, "have not tested this"
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query_states.permute(0, 2, 1, 3),
         key_states.permute(0, 2, 1, 3),
         value_states.permute(0, 2, 1, 3),
         dropout_p=0.0,
         is_causal=True,
-    )  # [batch, q_length, q_heads, head_dim]
+    ).permute(0, 2, 1, 3)
 return attn_output


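For context on the llama.py fix (an illustrative sketch, not part of the commit): the surrounding code keeps tensors in the flash-attn layout [batch, q_length, n_heads, head_dim], while torch.nn.functional.scaled_dot_product_attention consumes and produces [batch, n_heads, seq_len, head_dim]. The old trailing comment claimed the SDPA output was already [batch, q_length, q_heads, head_dim]; the added .permute(0, 2, 1, 3) is what actually restores that layout. A minimal standalone check, with made-up sizes:

import torch

# Made-up sizes for illustration only.
batch, q_length, n_heads, head_dim = 2, 16, 8, 64

# Flash-attn-style layout assumed by the caller: [batch, q_length, n_heads, head_dim].
query_states = torch.randn(batch, q_length, n_heads, head_dim)
key_states = torch.randn(batch, q_length, n_heads, head_dim)
value_states = torch.randn(batch, q_length, n_heads, head_dim)

# SDPA expects [batch, n_heads, seq_len, head_dim], so the inputs are permuted in...
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states.permute(0, 2, 1, 3),
    key_states.permute(0, 2, 1, 3),
    value_states.permute(0, 2, 1, 3),
    dropout_p=0.0,
    is_causal=True,
)  # ...and the result comes back as [batch, n_heads, q_length, head_dim].

# The added .permute(0, 2, 1, 3) restores the caller's layout.
attn_output = attn_output.permute(0, 2, 1, 3)
assert attn_output.shape == (batch, q_length, n_heads, head_dim)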
src/nanotron/trainer.py (16 changes: 16 additions & 0 deletions)
@@ -523,6 +523,15 @@ def train(
         log_entries = [LogItem("validation_loss_avg", loss_avg, "human_format")]
         self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step)
 
+        # NOTE: only one rank writes to wandb
+        if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
+            wandb.log(
+                {
+                    **{log_item.tag: log_item.scalar_value for log_item in log_entries},
+                    "iteration_step": self.iteration_step,
+                }
+            )
+
     # Checkpoint
     if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0:
         self.save_checkpoint()
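The new wandb block follows the usual pattern for distributed logging: only one rank (the first entry of logger_ranks) writes, so a multi-process run produces a single wandb stream, and the `wandb is not None` check covers the case where wandb is not installed. A minimal standalone sketch of that pattern, with hypothetical names (world_pg, logger_ranks, and LogItem-like objects with .tag and .scalar_value stand in for the trainer's attributes):

import torch.distributed as dist

try:
    import wandb  # optional dependency; fall back to None when missing
except ImportError:
    wandb = None


def log_to_wandb_on_one_rank(log_entries, iteration_step, world_pg, logger_ranks):
    # Only the first logger rank writes, so distributed runs yield one wandb stream.
    if wandb is None or dist.get_rank(world_pg) != logger_ranks[0]:
        return
    wandb.log(
        {
            **{item.tag: item.scalar_value for item in log_entries},
            "iteration_step": iteration_step,
        }
    )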
@@ -636,6 +645,13 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten
         )
         return outputs
 
+    def valid_step_logs(self,
+                        outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
+                        loss_avg: Optional[torch.Tensor],
+                        ) -> None:
+        pass
+
+
     def train_step_logs(
         self,
         outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],

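valid_step_logs is deliberately left as a stub in this commit (the message only says it "starts" to make the validation step look like the train step). Purely as a hypothetical sketch of where this seems headed, not the author's implementation, the body might eventually mirror train_step_logs by reusing the same loggerwriter and wandb path added to train() above:

    # Hypothetical sketch only; the commit leaves this method as `pass`.
    def valid_step_logs(self,
                        outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]],
                        loss_avg: Optional[torch.Tensor],
                        ) -> None:
        if loss_avg is None:
            return
        log_entries = [LogItem("validation_loss_avg", loss_avg, "human_format")]
        self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step)
        # Same single-rank wandb guard as in train() above.
        if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None:
            wandb.log(
                {
                    **{log_item.tag: log_item.scalar_value for log_item in log_entries},
                    "iteration_step": self.iteration_step,
                }
            )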