
Commit

formatting
haeggee committed Aug 20, 2024
1 parent 4c84227 commit cf623f9
Showing 4 changed files with 20 additions and 20 deletions.
src/nanotron/models/base.py (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ def get_embeddings_lm_head_tied_names(self) -> list[str]:
Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
"""
return []

def get_named_params_without_weight_decay(self) -> List[str]:
"""Return a list of named parameters that should not have weight decay applied to them."""
return []
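Note (not part of this commit): get_named_params_without_weight_decay is a hook that model subclasses override to exempt parameters from weight decay. A minimal, self-contained sketch of the usual convention (biases and normalization weights exempted); the toy module and name patterns below are illustrative only, not nanotron's actual model classes:

from typing import List

import torch.nn as nn


class TinyExample(nn.Module):
    """Toy stand-in; in nanotron this hook lives on the training-model class."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(16, 16)
        self.norm = nn.LayerNorm(16)

    def get_named_params_without_weight_decay(self) -> List[str]:
        # Common convention: biases and normalization weights get no weight decay.
        return [
            name
            for name, _ in self.named_parameters()
            if name.endswith(".bias") or "norm" in name
        ]


print(TinyExample().get_named_params_without_weight_decay())
# ['proj.bias', 'norm.weight', 'norm.bias']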
src/nanotron/models/llamoe.py (31 changes: 15 additions & 16 deletions)
@@ -10,8 +10,7 @@
from nanotron import distributed as dist
from nanotron.config import LlamaConfig, LlaMoEConfig, ParallelismArgs
from nanotron.models import llama
-from nanotron.models.llama import CausalSelfAttention, LlamaForTraining, LlamaModel
-from nanotron.models.llama import LlamaDecoderLayer
+from nanotron.models.llama import CausalSelfAttention, LlamaDecoderLayer, LlamaForTraining, LlamaModel
from nanotron.models.moe import (
dMoE,
)
@@ -20,7 +19,6 @@
from nanotron.parallel.pipeline_parallel.block import PipelineBlock
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear

from src.nanotron.random import RandomStates


@@ -112,34 +110,31 @@ def _core_forward(

hidden_states = hidden_states + residual

-        return hidden_states, output["sequence_mask"], aux_losses
+        return hidden_states, output["sequence_mask"], aux_losses

def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
sequence_mask: torch.Tensor,
aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]],
-    ) -> List[torch.Tensor]:
+    ) -> List[torch.Tensor]:
return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask, aux_losses)


def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:

if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
-            hidden_states, sequence_mask, aux_losses = self._checkpointed_forward(hidden_states, sequence_mask, aux_losses)
+            hidden_states, sequence_mask, aux_losses = self._checkpointed_forward(
+                hidden_states, sequence_mask, aux_losses
+            )
else:
hidden_states, sequence_mask, aux_losses = self._core_forward(hidden_states, sequence_mask, aux_losses)

-        return {
-            "hidden_states": hidden_states,
-            "sequence_mask": sequence_mask,
-            "aux_losses": aux_losses
-        }
+        return {"hidden_states": hidden_states, "sequence_mask": sequence_mask, "aux_losses": aux_losses}


class LlaMoEModel(LlamaModel):
@@ -184,8 +179,8 @@ def forward(

hidden_encoder_states = {
"hidden_states": output["input_embeds"],
"sequence_mask": input_mask,
"aux_losses": aux_losses
"sequence_mask": input_mask,
"aux_losses": aux_losses,
}
for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)
@@ -251,7 +246,11 @@ def forward(
def get_block_compute_costs(self):
"""Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
model_config = self.config
-        d_ff = model_config.intermediate_size if model_config.intermediate_size is not None else 4 * model_config.hidden_size
+        d_ff = (
+            model_config.intermediate_size
+            if model_config.intermediate_size is not None
+            else 4 * model_config.hidden_size
+        )
d_qkv = model_config.hidden_size // model_config.num_attention_heads
# active experts + routing
mlp_cost = (
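Aside (not part of this commit): the recompute_layer branch reformatted above implements per-layer activation checkpointing: _core_forward is re-run during the backward pass instead of keeping its intermediate activations in memory. A minimal sketch of the same pattern using torch.utils.checkpoint, with a toy layer standing in for the MoE decoder block; the names here are illustrative, not nanotron's API:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Toy stand-in for a decoder layer that can recompute itself in backward."""

    def __init__(self, hidden: int = 32, recompute_layer: bool = True) -> None:
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))
        self.recompute_layer = recompute_layer

    def _core_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states + self.mlp(hidden_states)  # residual block

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.recompute_layer:
            # Do not store activations from _core_forward; recompute them in backward.
            return checkpoint(self._core_forward, hidden_states, use_reentrant=False)
        return self._core_forward(hidden_states)


x = torch.randn(4, 32, requires_grad=True)
ToyLayer()(x).sum().backward()  # gradients flow through the recomputed forward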
src/nanotron/models/moe.py (6 changes: 4 additions & 2 deletions)
@@ -652,6 +652,7 @@ def forward(self, hidden_states, topo):  # [seq_length, batch_size, hidden_dim]
hidden_states = self.w2(self.act(merged_states))
return hidden_states


class GLU(MLP):
def __init__(
self,
@@ -676,11 +677,12 @@ def __init__(
expert_parallel_size=self.expert_pg_size,
)

-    def forward(self, x, topo):
+    def forward(self, hidden_states, topo):
merged_states = self.w1(hidden_states)
hidden_states = self.w2(self.act(merged_states) * self.w3(hidden_states))
return hidden_states


def inclusive_cumsum(x, dim):
scalar = ops.inclusive_cumsum(x, dim)
return scalar.view(1) if not len(scalar.size()) else scalar
@@ -718,4 +720,4 @@ def forward(self, x, topo):
x1 = self.sdd(x, self.w1.module.weight, topo)
x2 = self.sdd(x, self.w3.module.weight, topo)
x = stk.ops.mul(act_fn(x1, self.act), x2)
-        return self.dsd(x, self.w2.module.weight)
+        return self.dsd(x, self.w2.module.weight)
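For context (not part of this commit): the GLU.forward touched above computes a gated feed-forward, w2(act(w1(x)) * w3(x)), i.e. SwiGLU when the activation is SiLU. A self-contained sketch with plain nn.Linear layers; the tensor-parallel and expert-parallel wrapping used in moe.py is deliberately omitted:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleGLU(nn.Module):
    """Gated feed-forward: w2(act(w1(x)) * w3(x))."""

    def __init__(self, hidden_size: int, d_ff: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(hidden_size, d_ff, bias=False)  # gate projection
        self.w3 = nn.Linear(hidden_size, d_ff, bias=False)  # up projection
        self.w2 = nn.Linear(d_ff, hidden_size, bias=False)  # down projection
        self.act = F.silu

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        merged_states = self.w1(hidden_states)
        return self.w2(self.act(merged_states) * self.w3(hidden_states))


out = SimpleGLU(hidden_size=64, d_ff=256)(torch.randn(8, 64))  # shape (8, 64)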
src/nanotron/parallel/pipeline_parallel/engine.py (1 change: 0 additions & 1 deletion)
@@ -15,7 +15,6 @@
attach_pipeline_state_to_model,
)
from nanotron.parallel.pipeline_parallel.state import (
-    PipelineEvalBatchState,
PipelineTrainBatchState,
)
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
