Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Continued pretraining example #23

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,17 @@ pip install -e .

# Install dependencies if you want to use the example scripts
pip install datasets transformers
pip install ninja
pip install triton "flash-attn>=2.5.0" --no-build-isolation
```
> [!NOTE]
> If you get `undefined symbol: ncclCommRegister` error you should install torch 2.1.2 instead: `pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121`

> [!NOTE]
> If you get `TypeError: ~ (operator.invert) is only implemented on integer and Boolean-type tensors` error you should also install torch 2.1.2 instead: `pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121; pip install numpy==1.26.4`
> In order to not break flash-attn, it needs to be reinstalled as well with `pip uninstall flash-attn; pip cache purge; pip install triton flash-attn --no-build-isolation` (at least for me)
> Also make sure to use GPUs ">=" A100 (for V100 it this also fails).

> [!TIP]
> We log to wandb automatically if it's installed. For that you can use `pip install wandb`. If you don't want to use wandb, you can run `wandb disabled`.

Expand All @@ -69,6 +75,7 @@ You can find more examples in the [`/examples`](/examples) directory:
<!-- Make a table of the examples we support -->
| Example | Description |
| --- | --- |
| `continued-pretraining` | A minimal continued pretraining example using an TinyLlama_v1.1 (1B). |
| `custom-dataloader` | Plug a custom dataloader to nanotron |
| `datatrove` | Use the datatrove library to load data |
| `doremi` | Use DoReMi to speed up training |
Expand Down
37 changes: 37 additions & 0 deletions examples/continued-pretraining/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Continued pretraining of TinyLlama_v1.1 1B on a single GPU

I am assuming that you run these commands from the root folder of the nanotron repo.

1. Download the model from huggingface
`pip install huggingface_hub[cli]`
`huggingface-cli download TinyLlama/TinyLlama_v1.1`
This will download the model into your huggingface cache to a path like `~/.cache/huggingface/hub/models--TinyLlama--TinyLlama_v1.1/snapshots/ff3c701f2424c7625fdefb9dd470f45ef18b02d6` (the output of the command). Move the model somewhere else or note down this path, e.g., by doing `export HF_MODEL=~/.cache/huggingface/hub/models--TinyLlama--TinyLlama_v1.1/snapshots/ff3c701f2424c7625fdefb9dd470f45ef18b02d6` or download the model via git lfs instead.

2. Convert the model into nanotron format
`torchrun --nproc_per_node=1 examples/llama/convert_hf_to_nanotron.py --checkpoint_path=$HF_MODEL --save_path=models/llama1b`

3. Create the training config file. The main thing to consider here is to initialise the model from the checkpoint we just created. This can be done in the training config by changing
```
model:
init_method:
std: 0.025
```
to
```
model:
init_method:
path: models/llama1b
```
Additionally, the remaining model hyperparameters need to be updated to match the ones in `models/llama1b/model_config.json`. This file should look like this
```
{"bos_token_id": 1, "eos_token_id": 2, "hidden_act": "silu", "hidden_size": 2048, "initializer_range": 0.02, "intermediate_size": 5632, "is_llama_config": true, "max_position_embeddings": 2048, "num_attention_heads": 32, "num_hidden_layers": 22, "num_key_value_heads": 4, "pad_token_id": null, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 10000.0, "rope_interleaved": true, "tie_word_embeddings": false, "use_cache": true, "vocab_size": 32000}
```
We provide an updated training config yaml in `examples/continued-pretraining/config_1gpu_tiny_llama.yaml`. We also updated the maximum learning rate to 5% of the learning rate and the minimum learning rate to 10% of the learning rate used during pretraining TinyLlama.

4. Launch your training
```
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 run_train.py --config-file examples/continued-pretraining/config_1gpu_tiny_llama.yaml
```

5. We configured our training in a way that also serializes the snapshot before training. We can use this snapshot to compare nanotron checkpoint generations with the ones from the corresponding huggingface checkpoint. To do so, run
`pip install accelerate`. Then you can generate using `torchrun --nproc_per_node=1 run_generate.py --ckpt-path checkpoints/0/ &> nano-gen.log` and `python examples/continued-pretraining/hf_example_generations.py &> hf-gen.log` and compare the resulting log files containing their outputs.
112 changes: 112 additions & 0 deletions examples/continued-pretraining/config_1gpu_tiny_llama.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
checkpoints:
checkpoint_interval: 100
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: true
save_final_state: true
data_stages:
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: debug
run: tiny_llama_%date_%jobid
seed: 42
step: null
lighteval: null
logging:
iteration_step_info_interval: 1
log_level: info
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
init_method:
path: models/llama1b
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 2048
initializer_range: 0.02
intermediate_size: 5632
is_llama_config: true
max_position_embeddings: 2048
num_attention_heads: 32
num_hidden_layers: 22
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
rope_theta: 10000.0
rope_interleaved: true
tie_word_embeddings: false
use_cache: true
vocab_size: 32000
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
learning_rate_scheduler:
learning_rate: 2.0e-5
lr_decay_starting_step: null
lr_decay_steps: 950
lr_decay_style: cosine
lr_warmup_steps: 50
lr_warmup_style: linear
min_decay_lr: 4.0e-06
optimizer_factory:
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.0e-08
name: adamW
torch_adam_is_fused: true
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
tokenizer_max_length: null
tokenizer_name_or_path: TinyLlama/TinyLlama_v1.1
tokenizer_revision: null
tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 2048
train_steps: 1000
val_check_interval: -1
26 changes: 26 additions & 0 deletions examples/continued-pretraining/hf_example_generations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM


example_generations = ["The future of AI is",
# "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:",
"def fib(n)"]

generator = pipeline('text-generation', model='TinyLlama/TinyLlama_v1.1', device=0)

"""
# It's a mystery to me why the text generation pipeline code does not produce the same outputs as the one below.
# The one below agreed for me with the nanotron generation script.
for idx, prompt in enumerate(example_generations):
generated_text = generator(prompt, max_new_tokens=50, num_return_sequences=1, do_sample=False, num_beams=1)
print(f"Example {idx}: {generated_text[0]['generated_text']}")
"""
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama_v1.1", device_map="auto")

for idx, prompt in enumerate(example_generations):
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(inputs['input_ids'].to(model.device), max_new_tokens=50, num_return_sequences=1)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Example {idx}: {generated_text}")
print()
6 changes: 3 additions & 3 deletions src/nanotron/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,14 +423,14 @@ def forward(
)
# Remove pad tokens from key_states and concatenate samples in key_unpad
# cu_seqlens_k is the cumulative sequence lengths of key_states
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
(query_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _) = bert_padding.unpad_input(
query_states,
sequence_mask,
)
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
(key_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _) = bert_padding.unpad_input(
key_states, sequence_mask
)
(value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
(value_unpad, _, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)

# NOTE: this scale is for µTransfer,
# in SP, we use sqrt(1/d_h)
Expand Down
Loading