From 42fb8582df34c1f6497f0b4bcac707b3f43367c8 Mon Sep 17 00:00:00 2001 From: Chris Wendler Date: Thu, 28 Nov 2024 14:41:20 +0100 Subject: [PATCH 1/5] Update README.md --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 7a22f12a..271f244d 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,11 @@ pip install triton "flash-attn>=2.5.0" --no-build-isolation > [!NOTE] > If you get `undefined symbol: ncclCommRegister` error you should install torch 2.1.2 instead: `pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121` +> [!NOTE] +> If you get `TypeError: ~ (operator.invert) is only implemented on integer and Boolean-type tensors` error you should also install torch 2.1.2 instead: `pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121` +> In order to not break flash-attn, it needs to be reinstalled as well with `pip uninstall flash-attn; pip cache purge; pip install triton flash-attn --no-build-isolation` (at least for me) +> Also make sure to use GPUs ">=" A100 (for V100 it this also fails). + > [!TIP] > We log to wandb automatically if it's installed. For that you can use `pip install wandb`. If you don't want to use wandb, you can run `wandb disabled`. From 5b1a16faef08ea626502b9854298716cc17fcefb Mon Sep 17 00:00:00 2001 From: Chris Wendler Date: Thu, 28 Nov 2024 15:05:24 +0100 Subject: [PATCH 2/5] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 271f244d..f36f8403 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ pip install -e . # Install dependencies if you want to use the example scripts pip install datasets transformers +pip install ninja pip install triton "flash-attn>=2.5.0" --no-build-isolation ``` > [!NOTE] From 9060b9f39501a0b9cd0bd8f4a73c20dd411260ca Mon Sep 17 00:00:00 2001 From: Chris Wendler Date: Thu, 28 Nov 2024 15:58:28 +0100 Subject: [PATCH 3/5] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f36f8403..42fbceee 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ pip install triton "flash-attn>=2.5.0" --no-build-isolation > If you get `undefined symbol: ncclCommRegister` error you should install torch 2.1.2 instead: `pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121` > [!NOTE] -> If you get `TypeError: ~ (operator.invert) is only implemented on integer and Boolean-type tensors` error you should also install torch 2.1.2 instead: `pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121` +> If you get `TypeError: ~ (operator.invert) is only implemented on integer and Boolean-type tensors` error you should also install torch 2.1.2 instead: `pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121; pip install numpy==1.26.4` > In order to not break flash-attn, it needs to be reinstalled as well with `pip uninstall flash-attn; pip cache purge; pip install triton flash-attn --no-build-isolation` (at least for me) > Also make sure to use GPUs ">=" A100 (for V100 it this also fails). From c5817f6823c0d692030f7a9499898d2c479c232a Mon Sep 17 00:00:00 2001 From: Chris Wendler Date: Wed, 4 Dec 2024 11:02:23 +0000 Subject: [PATCH 4/5] continued pretraining example --- README.md | 1 + examples/continued-pretraining/README.md | 39 ++++++ .../config_1gpu_tiny_llama.yaml | 112 ++++++++++++++++++ .../hf_example_generations.py | 26 ++++ src/nanotron/models/llama.py | 6 +- 5 files changed, 181 insertions(+), 3 deletions(-) create mode 100644 examples/continued-pretraining/README.md create mode 100644 examples/continued-pretraining/config_1gpu_tiny_llama.yaml create mode 100644 examples/continued-pretraining/hf_example_generations.py diff --git a/README.md b/README.md index 42fbceee..5d5e3c57 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ You can find more examples in the [`/examples`](/examples) directory: | Example | Description | | --- | --- | +| `continued-pretraining` | A minimal continued pretraining example using an TinyLlama_v1.1 (1B). | | `custom-dataloader` | Plug a custom dataloader to nanotron | | `datatrove` | Use the datatrove library to load data | | `doremi` | Use DoReMi to speed up training | diff --git a/examples/continued-pretraining/README.md b/examples/continued-pretraining/README.md new file mode 100644 index 00000000..21f768b5 --- /dev/null +++ b/examples/continued-pretraining/README.md @@ -0,0 +1,39 @@ +# Continued pretraining of TinyLlama_v1.1 1B on a single GPU + +I am assuming that you run these commands from the root folder of the nanotron repo. + +1. Download the model from huggingface +`pip install huggingface_hub[cli]` +`huggingface-cli download TinyLlama/TinyLlama_v1.1` +This will download the model into your huggingface cache to a path like `~/.cache/huggingface/hub/models--TinyLlama--TinyLlama_v1.1/snapshots/ff3c701f2424c7625fdefb9dd470f45ef18b02d6` (the output of the command). Move the model somewhere else or note down this path, e.g., by doing `export HF_MODEL=~/.cache/huggingface/hub/models--TinyLlama--TinyLlama_v1.1/snapshots/ff3c701f2424c7625fdefb9dd470f45ef18b02d6` or download the model via git lfs instead. + +2. Convert the model into nanotron format +`torchrun --nproc_per_node=1 examples/llama/convert_hf_to_nanotron.py --checkpoint_path=$HF_MODEL --save_path=models/llama1b` + +3. Create the training config file. The main thing to consider here is to initialise the model from the checkpoint we just created. This can be done in the training config by changing +``` +model: + init_method: + std: 0.025 +``` +to +``` +model: + init_method: + path: models/llama1b +``` +Additionally, the remaining model hyperparameters need to be updated to match the ones in `models/llama1b/model_config.json`. This file should look like this +``` +{"bos_token_id": 1, "eos_token_id": 2, "hidden_act": "silu", "hidden_size": 2048, "initializer_range": 0.02, "intermediate_size": 5632, "is_llama_config": true, "max_position_embeddings": 2048, "num_attention_heads": 32, "num_hidden_layers": 22, "num_key_value_heads": 4, "pad_token_id": null, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 10000.0, "rope_interleaved": true, "tie_word_embeddings": false, "use_cache": true, "vocab_size": 32000} +``` +We provide an updated training config yaml in `examples/continued-pretraining/config_1gpu_tiny_llama.yaml`. We also updated the maximum learning rate to 5% of the learning rate and the minimum learning rate to 10% of the learning rate used during pretraining TinyLlama. + +4. Launch your training +``` +CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 run_train.py --config-file examples/continued-pretraining/config_1gpu_tiny_llama.yaml +``` + +5. We configured our training in a way that also serializes the snapshot before training. We can use this snapshot to compare nanotron checkpoint generations with the ones from the corresponding huggingface checkpoint. To do so, run +`pip install accelerate` + +`torchrun --nproc_per_node=1 run_generate.py --ckpt-path checkpoints/0/ &> nano-gen.log` and `python examples/continued-pretraining/hf_example_generations.py &> hf-gen.log` and compare the resulting log files containing their outputs. \ No newline at end of file diff --git a/examples/continued-pretraining/config_1gpu_tiny_llama.yaml b/examples/continued-pretraining/config_1gpu_tiny_llama.yaml new file mode 100644 index 00000000..37223ee5 --- /dev/null +++ b/examples/continued-pretraining/config_1gpu_tiny_llama.yaml @@ -0,0 +1,112 @@ +checkpoints: + checkpoint_interval: 100 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: true + save_final_state: true +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: models/llama1b + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 2048 + initializer_range: 0.02 + intermediate_size: 5632 + is_llama_config: true + max_position_embeddings: 2048 + num_attention_heads: 32 + num_hidden_layers: 22 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + rope_theta: 10000.0 + rope_interleaved: true + tie_word_embeddings: false + use_cache: true + vocab_size: 32000 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 2.0e-5 + lr_decay_starting_step: null + lr_decay_steps: 950 + lr_decay_style: cosine + lr_warmup_steps: 50 + lr_warmup_style: linear + min_decay_lr: 4.0e-06 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: TinyLlama/TinyLlama_v1.1 + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 2048 + train_steps: 1000 + val_check_interval: -1 diff --git a/examples/continued-pretraining/hf_example_generations.py b/examples/continued-pretraining/hf_example_generations.py new file mode 100644 index 00000000..0dc4b8e4 --- /dev/null +++ b/examples/continued-pretraining/hf_example_generations.py @@ -0,0 +1,26 @@ +from transformers import pipeline +from transformers import AutoTokenizer, AutoModelForCausalLM + + +example_generations = ["The future of AI is", + # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + "def fib(n)"] + +generator = pipeline('text-generation', model='TinyLlama/TinyLlama_v1.1', device=0) + +""" +# It's a mystery to me why the text generation pipeline code does not produce the same outputs as the one below. +# The one below agreed for me with the nanotron generation script. +for idx, prompt in enumerate(example_generations): + generated_text = generator(prompt, max_new_tokens=50, num_return_sequences=1, do_sample=False, num_beams=1) + print(f"Example {idx}: {generated_text[0]['generated_text']}") +""" +tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1") +model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama_v1.1", device_map="auto") + +for idx, prompt in enumerate(example_generations): + inputs = tokenizer(prompt, return_tensors="pt") + outputs = model.generate(inputs['input_ids'].to(model.device), max_new_tokens=50, num_return_sequences=1) + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"Example {idx}: {generated_text}") + print() \ No newline at end of file diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 28a2e30f..0a2e5825 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -423,14 +423,14 @@ def forward( ) # Remove pad tokens from key_states and concatenate samples in key_unpad # cu_seqlens_k is the cumulative sequence lengths of key_states - (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _) = bert_padding.unpad_input( query_states, sequence_mask, ) - (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( + (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _) = bert_padding.unpad_input( key_states, sequence_mask ) - (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask) + (value_unpad, _, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask) # NOTE: this scale is for µTransfer, # in SP, we use sqrt(1/d_h) From 0328d7c38354ee17efed7e257c0078384b068c65 Mon Sep 17 00:00:00 2001 From: Chris Wendler Date: Wed, 4 Dec 2024 12:04:46 +0100 Subject: [PATCH 5/5] Update README.md --- examples/continued-pretraining/README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/continued-pretraining/README.md b/examples/continued-pretraining/README.md index 21f768b5..edf50aa5 100644 --- a/examples/continued-pretraining/README.md +++ b/examples/continued-pretraining/README.md @@ -34,6 +34,4 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 run_train.py --config- ``` 5. We configured our training in a way that also serializes the snapshot before training. We can use this snapshot to compare nanotron checkpoint generations with the ones from the corresponding huggingface checkpoint. To do so, run -`pip install accelerate` - -`torchrun --nproc_per_node=1 run_generate.py --ckpt-path checkpoints/0/ &> nano-gen.log` and `python examples/continued-pretraining/hf_example_generations.py &> hf-gen.log` and compare the resulting log files containing their outputs. \ No newline at end of file +`pip install accelerate`. Then you can generate using `torchrun --nproc_per_node=1 run_generate.py --ckpt-path checkpoints/0/ &> nano-gen.log` and `python examples/continued-pretraining/hf_example_generations.py &> hf-gen.log` and compare the resulting log files containing their outputs.