Merge pull request swiss-ai#3 from huggingface/main

Update

AleHD authored Apr 23, 2024
2 parents c0bbcdb + b792a22 commit decfc3f
Showing 22 changed files with 1,094 additions and 164 deletions.
11 changes: 7 additions & 4 deletions examples/config_tiny_llama.py
@@ -2,6 +2,7 @@
import os

from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
@@ -62,11 +63,13 @@
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=True,
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=learning_rate,
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)

parallelism = ParallelismArgs(
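The change above replaces the flat adam_* fields on OptimizerArgs with a nested AdamWOptimizerArgs passed as optimizer_factory. Below is a minimal sketch of the new pattern, using only the nanotron.config names that appear in this diff; the scheduler values are illustrative, not taken from this file.

```python
# Sketch of the refactored optimizer configuration: AdamW-specific settings
# move into an AdamWOptimizerArgs factory instead of sitting on OptimizerArgs.
from nanotron.config import AdamWOptimizerArgs, LRSchedulerArgs, OptimizerArgs

learning_rate = LRSchedulerArgs(
    learning_rate=3e-4,  # illustrative values, not from this config
    lr_warmup_steps=2000,
    lr_warmup_style="linear",
    lr_decay_style="cosine",
    min_decay_lr=1e-5,
)

optimizer = OptimizerArgs(
    zero_stage=0,
    weight_decay=0.01,
    clip_grad=1.0,
    accumulate_grad_in_fp32=True,
    learning_rate_scheduler=learning_rate,
    optimizer_factory=AdamWOptimizerArgs(  # new in this change
        adam_eps=1e-08,
        adam_beta1=0.9,
        adam_beta2=0.95,
        torch_adam_is_fused=True,
    ),
)
```

The same factory pattern is applied to examples/mamba/create_config_mamba.py and examples/moe/config_llamoe.py further down in this commit.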
49 changes: 41 additions & 8 deletions examples/config_tiny_llama.yaml
@@ -1,3 +1,34 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
@@ -34,9 +65,6 @@ model:
vocab_size: 256
optimizer:
accumulate_grad_in_fp32: true
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 0.001
@@ -46,13 +74,18 @@ optimizer:
lr_warmup_steps: 2000 # 20% of the total steps
lr_warmup_style: linear
min_decay_lr: 1.0e-05
torch_adam_is_fused: true
weight_decay: 0.1
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
expert_parallel_size: 1
pp: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
@@ -77,7 +110,7 @@ data_stages:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_or_datasets: HuggingFaceH4/testing_codealpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
@@ -99,7 +132,7 @@ checkpoints:
checkpoint_interval: 10
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
resume_checkpoint_path: checkpoints
save_initial_state: false
profiler: null
logging:
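On the YAML side the same refactor shows up as an optimizer_factory mapping under optimizer, and the single data block becomes a data_stages list of named stages. A small sketch that inspects both sections, assuming only PyYAML and a config laid out like the example above (the path is a placeholder):

```python
# Sketch: load a nanotron-style YAML config and print the sections this
# change reshapes. Assumes PyYAML; the file path is a placeholder.
import yaml

with open("examples/config_tiny_llama.yaml") as f:
    cfg = yaml.safe_load(f)

# AdamW settings now sit under optimizer.optimizer_factory
print(cfg["optimizer"]["optimizer_factory"])
# e.g. {'adam_beta1': 0.9, 'adam_beta2': 0.95, 'adam_eps': 1e-08,
#       'name': 'adamW', 'torch_adam_is_fused': True}

# data_stages is a list of named stages, each with a start_training_step
for stage in cfg["data_stages"]:
    print(stage["name"], "->", stage["data"]["dataset"]["hf_dataset_or_datasets"],
          "from step", stage["start_training_step"])
```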
8 changes: 7 additions & 1 deletion examples/contributor-guide/debug_config_tiny_llama.py
@@ -5,6 +5,7 @@
CheckpointsArgs,
Config,
DataArgs,
DatasetStageArgs,
GeneralArgs,
LlamaConfig,
LoggingArgs,
@@ -95,7 +96,12 @@
optimizer=optimizer,
logging=LoggingArgs(),
tokens=tokens,
data=DataArgs(dataset=dataset, seed=seed),
data_stages=[
DatasetStageArgs(
name="Stable Training Stage", start_training_step=1, data=DataArgs(dataset=dataset, seed=seed)
),
DatasetStageArgs(name="Annealing Phase", start_training_step=10, data=DataArgs(dataset=dataset, seed=seed)),
],
profiler=None,
)

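In the Python configs this is the corresponding change: Config no longer takes a single data=DataArgs(...), but a data_stages list of DatasetStageArgs. A sketch of just that argument; PretrainDatasetsArgs is not shown in this diff and is an assumption inferred from the hf_dataset_* keys in the YAML above, while the original script builds its dataset object elsewhere.

```python
# Sketch of the data_stages argument introduced by this change. The dataset
# class below (PretrainDatasetsArgs) is an assumption inferred from the YAML
# keys above, not something shown in this diff.
from nanotron.config import DataArgs, DatasetStageArgs, PretrainDatasetsArgs

seed = 42
dataset = PretrainDatasetsArgs(
    hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small",
    text_column_name="completion",
)

data_stages = [
    DatasetStageArgs(
        name="Stable Training Stage",
        start_training_step=1,
        data=DataArgs(dataset=dataset, seed=seed),
    ),
    DatasetStageArgs(
        name="Annealing Phase",
        start_training_step=10,
        data=DataArgs(dataset=dataset, seed=seed),
    ),
]
```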
39 changes: 25 additions & 14 deletions examples/contributor-guide/debug_config_tiny_llama.yaml
@@ -1,23 +1,34 @@
checkpoints:
checkpoint_interval: 10
checkpoints_path: /fsx/ferdinandmom/ferdinand-hf/nanotron/examples/checkpoints
checkpoints_path: /fsx/haojun/nanotron_latest/examples/checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false

data_stages:
- name: General purpose training
start_training_step: 1
data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
hf_dataset_splits: train
text_column_name: completion
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
30 changes: 21 additions & 9 deletions examples/mamba/create_config_mamba.py
@@ -1,9 +1,11 @@
""" Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information."""
import math
import os
import uuid

from config import MambaConfig, MambaInit, MambaModelConfig
from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
DataArgs,
DatasetStageArgs,
@@ -19,6 +21,10 @@
)
from nanotron.logging import human_format

new_job_id = uuid.uuid4()
job_id = str(new_job_id)[:8]
seed = 42

ssm_cfg_dtype = "bfloat16"
ssm_cfg = {
"d_state": 16,
@@ -37,7 +43,7 @@
# https://huggingface.co/state-spaces/mamba-790m/blob/main/config.json
model_config = MambaModelConfig(
d_model=1024,
num_hidden_layers=48,
num_hidden_layers=2,
vocab_size=50278,
ssm_cfg=ssm_cfg,
rms_norm=True,
@@ -88,24 +94,28 @@

seed = 42


optimizer = OptimizerArgs(
zero_stage=0,
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=True, # NOTE(fmom): because we are using PP=TP=DP=1
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=LRSchedulerArgs(
learning_rate=0.0015,
lr_warmup_steps=30,
lr_warmup_style="linear",
lr_decay_style="cosine",
min_decay_lr=0.00015,
),
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)


parallelism = ParallelismArgs(
dp=2,
pp=2,
@@ -128,17 +138,19 @@
)
]

model = ModelArgs(
init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1),
model_config=model_config,
)

checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

config = MambaConfig(
general=GeneralArgs(project="test", run="mamba", seed=seed, ignore_sanity_checks=True),
checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=100),
parallelism=parallelism,
model=ModelArgs(
init_method=MambaInit(initializer_range=0.02, rescale_prenorm_residual=True, n_residuals_per_layer=1),
model_config=model_config,
),
model=model,
tokenizer=TokenizerArgs("gpt2"),
optimizer=optimizer,
logging=LoggingArgs(),
18 changes: 10 additions & 8 deletions examples/mamba/mamba.py
@@ -181,15 +181,13 @@ def __init__(
self.A_log = create_sharded_parameter_from_config(
parameter=A_log, pg=self.tp_pg, split_config=SplitConfig(split_dim=0)
)
self.A_log._no_weight_decay = True

# D "skip" parameter
self.D = create_sharded_parameter_from_config(
parameter=torch.ones(self.d_inner // self.tp_pg.size(), device=device),
pg=self.tp_pg,
split_config=SplitConfig(split_dim=0),
)
self.D._no_weight_decay = True

# self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.out_proj = TensorParallelRowLinear(
@@ -664,7 +662,7 @@ def get_block_compute_costs(self):

def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
"""
Get flops per second for a Mamba model.
Get flops per second for a Mamba model.
Terms such as nonlinearities, biases, and layer normalization are omitted (https://arxiv.org/pdf/2001.08361.pdf)
"""
# world_size = self.parallel_context.world_pg.size()
@@ -807,6 +805,14 @@ def forward(
label_mask=label_mask,
)["loss"]
return {"loss": loss}

def get_named_params_without_weight_decay(self):
# get full name with "A_log", "D"
named_param_without_weight_decay = []
for name, _ in self.model.named_parameters():
if "A_log" in name or "D" in name:
named_param_without_weight_decay.append(name)
return named_param_without_weight_decay

@torch.no_grad()
def init_model_randomly(self, config):
@@ -917,11 +923,7 @@ def init_model_randomly(self, config):
raise ValueError(f"Who the fuck is {param_name}?")

elif isinstance(module, Mamba):
# NOTE(fmom): nn.Parameter are initialized in Mamba __init__
# In Mamba, only those 3 parameters don't have weight decay.
if param_name in ["dt_bias", "A_log", "D"]:
param._no_weight_decay = True

pass
else:
raise Exception(f"Parameter {full_param_name} was not initialized")

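The mamba.py change drops the per-parameter _no_weight_decay flags and instead exposes get_named_params_without_weight_decay() on the model. How nanotron's trainer consumes that list is not part of this diff; the sketch below shows the kind of AdamW parameter grouping such a hook enables, in plain PyTorch.

```python
# Sketch only: build AdamW param groups from the new hook. The hook in the
# diff collects names from self.model.named_parameters(), so we match against
# the same iterator. This is not nanotron's actual trainer code.
import torch

def build_param_groups(training_model, weight_decay: float):
    no_decay_names = set(training_model.get_named_params_without_weight_decay())
    decay, no_decay = [], []
    for name, param in training_model.model.named_parameters():
        (no_decay if name in no_decay_names else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Hypothetical usage with a wrapped Mamba model:
# optimizer = torch.optim.AdamW(build_param_groups(mamba_for_training, 0.01), lr=1.5e-3)
```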
11 changes: 7 additions & 4 deletions examples/moe/config_llamoe.py
@@ -4,6 +4,7 @@
from typing import Optional

from nanotron.config import (
AdamWOptimizerArgs,
CheckpointsArgs,
Config,
DataArgs,
@@ -99,11 +100,13 @@ def __post_init__(self):
weight_decay=0.01,
clip_grad=1.0,
accumulate_grad_in_fp32=False,
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
learning_rate_scheduler=learning_rate,
optimizer_factory=AdamWOptimizerArgs(
adam_eps=1e-08,
adam_beta1=0.9,
adam_beta2=0.95,
torch_adam_is_fused=True,
),
)

parallelism = ParallelismArgs(
