Deleting cross attention between documents during pretraining #16

Status: Open. Wants to merge 23 commits into base: main.

Commits (23)
- 419b33a: First prototype, let's jump padding free (TJ-Solergibert, Jul 27, 2024)
- 03f4308: This mess produces sames generations as hf (TJ-Solergibert, Jul 29, 2024)
- c57533d: Added SFT generations check script (TJ-Solergibert, Jul 29, 2024)
- a66b0c6: Added masked LOSS check (TJ-Solergibert, Jul 29, 2024)
- 06af8cf: Getting ready (TJ-Solergibert, Jul 30, 2024)
- a8f979d: RCP Working (TJ-Solergibert, Jul 30, 2024)
- c026422: Added todi scripts (TJ-Solergibert, Jul 30, 2024)
- 38f3815: Added SFT docs (TJ-Solergibert, Aug 2, 2024)
- 2d882db: Merge branch 'sft' of https://github.com/tj-solergibert/nanotron into… (TJ-Solergibert, Aug 2, 2024)
- d5228bb: Dont predict EOText token (TJ-Solergibert, Aug 2, 2024)
- 71122d3: first commit (TJ-Solergibert, Aug 22, 2024)
- efd168f: Little hack to fix first length (TJ-Solergibert, Aug 22, 2024)
- 5157392: Lets move to todi (TJ-Solergibert, Aug 26, 2024)
- a185c50: Adding liger kernels and modifyng conversion scripts (TJ-Solergibert, Aug 26, 2024)
- a7051d1: Compatibility with llama.py checkpoints (TJ-Solergibert, Sep 16, 2024)
- a750b45: Added script to count total number of tokens of SFT datasets (TJ-Solergibert, Sep 16, 2024)
- ed51183: Merge branch 'main' into document_xattention (TJ-Solergibert, Sep 16, 2024)
- cd81111: Only load model parameters on SFT (TJ-Solergibert, Sep 16, 2024)
- 81fdb3a: Fixed metadata issue (TJ-Solergibert, Sep 16, 2024)
- f3bf21d: read datasets locally (TJ-Solergibert, Sep 17, 2024)
- 06553df: create cos and sin in each decoder layer and check_sft working on todi (TJ-Solergibert, Sep 17, 2024)
- 3969aa2: No more NaN losses (TJ-Solergibert, Sep 26, 2024)
- df3ef9d: Bringing liger kernels back (TJ-Solergibert, Sep 26, 2024)
56 changes: 56 additions & 0 deletions docs/sft.md
@@ -0,0 +1,56 @@
# LlamaSFT
## Introduction
We have incorporated the ability to perform SFT in nanotron with the following features:
1. Packing multiple samples to fill the sequence length of the model
2. Training on completions only: The model learns from the answers, not from the user prompt & chat templates
3. Removing cross-attention between the packed samples

In the following sections, we will delve into more detail about these features and how we have implemented them.

### Feature 1: Packing
To train the models efficiently, we will pack multiple conversations into the same sample until the sequence length is filled. Because we are packing multiple sequences and want to avoid introducing padding tokens, [we will flatten the batch size](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/trainer.py#L259), so the effective `sequence_length` becomes `micro_batch_size * sequence_length` and `micro_batch_size = 1`.
![](sft_feature1.png)
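
To make the packing concrete, here is a minimal, illustrative sketch (not the actual nanotron code path, and with made-up token ids) of how several tokenized conversations are concatenated into a single padding-free row, together with the per-conversation `position_ids` that the later features rely on:

```python
import torch

# Hypothetical example: three tokenized conversations of different lengths.
conversations = [
    torch.tensor([101, 7, 8, 9]),
    torch.tensor([101, 4, 5]),
    torch.tensor([101, 2, 3, 6, 11]),
]

# Instead of padding each conversation to the longest one, concatenate them
# into a single packed row of shape (1, total_tokens), with no padding tokens.
input_ids = torch.cat(conversations).unsqueeze(0)

# Per-conversation position ids restart at 0; they are reused later for RoPE
# and for building the attention mask of Feature 3.
position_ids = torch.cat([torch.arange(len(c)) for c in conversations]).unsqueeze(0)

print(input_ids.shape)  # torch.Size([1, 12])
print(position_ids[0])  # tensor([0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4])
```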

### Feature 2: Training only on completions
Conversations consist of user messages, which are usually questions or inquiries, and the model's responses. The ultimate goal is for the model to improve the quality of its responses, not so much to learn about user questions or other aspects such as the chat template. Therefore, during training we compute the loss only on the tokens that belong to the answers produced by the model.

To achieve this, when tokenizing the conversations, we will [store the role of each token](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_tokenizer.py#L59) and create a mask that the model will use in the loss computation [[1]](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L617), [[2]](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L603).
![](sft_feature2.png)
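
As an illustration of the mechanism (a simplified sketch with invented shapes and names, not the exact nanotron loss code), a per-token completion mask can be applied to a standard cross-entropy loss so that only completion tokens contribute:

```python
import torch
import torch.nn.functional as F

def masked_lm_loss(logits, labels, label_mask):
    """Cross-entropy averaged only over tokens whose mask is True.

    logits:     (batch, seq_len, vocab_size)
    labels:     (batch, seq_len) next-token targets
    label_mask: (batch, seq_len) True for completion tokens, False for
                prompt / chat-template tokens
    """
    per_token_loss = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(), reduction="none"
    ).view_as(labels)
    # Zero out prompt tokens and normalize by the number of completion tokens.
    return (per_token_loss * label_mask).sum() / label_mask.sum().clamp(min=1)

# Toy usage with random data.
logits = torch.randn(1, 6, 128)
labels = torch.randint(0, 128, (1, 6))
label_mask = torch.tensor([[False, False, True, True, True, False]])
print(masked_lm_loss(logits, labels, label_mask))
```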

### Feature 3: Removing cross-attention
Finally, as we are packing multiple conversations together, we do not want the tokens of one conversation to attend to those of other conversations.
To do this, we will store the `position_ids` of each token in the packed sequence in order to:
1. Apply the RoPE embeddings correctly to each conversation
2. [Create the attention mask](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L346) needed by [`flash_attn_varlen_func`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/models/llama_sft.py#L352) to compute the attention without cross-contamination between different conversations, as sketched below
![](sft_feature3.png)
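
The following sketch shows the core idea under the assumption that the `flash_attn` package is used (the helper name `get_cu_seqlens` is ours, not nanotron's): the conversation boundaries are recovered from the points where `position_ids` reset to 0 and turned into the cumulative sequence lengths (`cu_seqlens`) that `flash_attn_varlen_func` expects, so each conversation attends only to itself.

```python
import torch

def get_cu_seqlens(position_ids: torch.Tensor) -> torch.Tensor:
    """Build cumulative sequence lengths from packed position_ids.

    position_ids: (1, total_tokens) where each packed conversation restarts at 0,
    e.g. [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4].
    Returns an int32 tensor [0, len_1, len_1 + len_2, ..., total_tokens],
    which is the format expected by flash_attn's varlen kernels.
    """
    pos = position_ids.flatten()
    # A new conversation starts wherever the position id resets to 0.
    start_indices = torch.nonzero(pos == 0, as_tuple=True)[0]
    return torch.cat(
        [start_indices, torch.tensor([pos.numel()], device=pos.device)]
    ).to(torch.int32)

position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4]])
cu_seqlens = get_cu_seqlens(position_ids)
print(cu_seqlens)  # tensor([ 0,  4,  7, 12], dtype=torch.int32)
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())

# With q, k, v packed as (total_tokens, n_heads, head_dim) CUDA tensors, the
# attention call would then look roughly like:
# out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
#                              max_seqlen, max_seqlen, causal=True)
```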

## Internals
### Config file
For SFT, we need to set up the config file as follows:
```yaml
- data:
    dataset:
      hf_dataset: Magpie-Align/Magpie-Pro-300K-Filtered
      hf_dataset_split: train
      conversation_column_name: conversations
      train_on_completions_only: true
      remove_cross_attention: true
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Single dataset)
  start_training_step: 1
```
The `hf_dataset` should be a dataset from the HuggingFace Hub with the same structure as `Magpie-Align/Magpie-Pro-300K-Filtered`; that is, each conversation is a list of dictionaries, each with the keys `from` (either `gpt` or `human`) and `value`. We can select a split with `hf_dataset_split` and the dataset column with `conversation_column_name`. `train_on_completions_only` & `remove_cross_attention` toggle Features 2 and 3 on/off, but we will remove them for the final release.
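
For reference, a single row of such a dataset looks roughly like this (the conversation text is invented for illustration):

```python
# One example row of the `conversations` column (values are made up).
sample = {
    "conversations": [
        {"from": "human", "value": "What is the capital of Switzerland?"},
        {"from": "gpt", "value": "The capital of Switzerland is Bern."},
        {"from": "human", "value": "And its largest city?"},
        {"from": "gpt", "value": "Zurich is the largest city in Switzerland."},
    ]
}
```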

### Iterable Dataset
For SFT training, we have developed a new dataset, [`ChatDataset`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_dataset.py#L17), responsible for producing data batches during training. Unlike `Nanosets`, this new `ChatDataset` is an [`IterableDataset`](https://pytorch.org/docs/stable/data.html#iterable-style-datasets). The advantage of this kind of dataset is that it does not require preprocessing the data before training: everything is done on the fly, which saves us the preprocessing step and the space occupied by the preprocessed data. The downside is that it is not trivial to recover the state of the DataLoader when restarting training. For this, we are developing a solution based on `torchdata`'s [`StatefulDataLoader`](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader) that we will incorporate soon.

For now, we allow splitting the dataset between the different data parallel ranks and plan to support interleaved datasets.
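
The per-rank split can be pictured with a toy `IterableDataset` (a simplified sketch of the approach, not the actual `ChatDataset` implementation): each data parallel rank keeps every `dp_ranks_size`-th conversation, offset by its own `dp_rank`.

```python
from torch.utils.data import IterableDataset

class ShardedConversations(IterableDataset):
    """Toy iterable dataset that shards a list of conversations across DP ranks."""

    def __init__(self, conversations, dp_rank: int, dp_ranks_size: int):
        self.conversations = conversations
        self.dp_rank = dp_rank
        self.dp_ranks_size = dp_ranks_size

    def __iter__(self):
        # Each data parallel rank sees a disjoint, interleaved slice of the data.
        for idx, conversation in enumerate(self.conversations):
            if idx % self.dp_ranks_size == self.dp_rank:
                yield conversation

# With 4 DP ranks, rank 1 yields conversations 1, 5, 9, ...
dataset = ShardedConversations(list(range(20)), dp_rank=1, dp_ranks_size=4)
print(list(dataset))  # [1, 5, 9, 13, 17]
```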

### ChatTokenizer
To apply the chat template, tokenize the conversations, and store the role of each token, we have developed the [`ChatTokenizer`](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_tokenizer.py#L6). Based on the one included in [`meta/llama3`](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), [this tokenizer will return](https://github.com/swiss-ai/nanotron/blob/c026422e5bf0bc1086c039e65d8f7bbe75dc9728/src/nanotron/data/chat_dataset.py#L92) the `tokens` of the conversation and a list of booleans, `is_completions`, indicating whether each token belongs to the model's responses, which is necessary for Feature 2.

For now, we only support the Llama3 tokenizer along with the official chat template of this model.
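
To make the expected interface concrete, here is a toy sketch (the real `ChatTokenizer` applies the Llama3 chat template and its special tokens, which are omitted here; `toy_chat_tokenize` is a made-up helper): for each conversation it returns the token ids together with a parallel `is_completions` list marking the tokens on which the loss should be computed.

```python
from transformers import AutoTokenizer

def toy_chat_tokenize(conversation, tokenizer):
    """Tokenize a conversation and flag completion tokens.

    Returns (tokens, is_completions), where is_completions[i] is True iff
    tokens[i] belongs to a `gpt` turn. Chat-template special tokens are
    ignored in this sketch.
    """
    tokens, is_completions = [], []
    for message in conversation:
        ids = tokenizer.encode(message["value"], add_special_tokens=False)
        tokens.extend(ids)
        is_completions.extend([message["from"] == "gpt"] * len(ids))
    return tokens, is_completions

# Example usage (any HF tokenizer works for the sketch).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
conversation = [
    {"from": "human", "value": "Hi, who are you?"},
    {"from": "gpt", "value": "I am a helpful assistant."},
]
tokens, is_completions = toy_chat_tokenize(conversation, tokenizer)
assert len(tokens) == len(is_completions)
```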

### Recover DataLoader
Pending development
Binary file added docs/sft_feature1.png
Binary file added docs/sft_feature2.png
Binary file added docs/sft_feature3.png
98 changes: 98 additions & 0 deletions examples/config_llama8b_sft.yaml
@@ -0,0 +1,98 @@
checkpoints:
  checkpoint_interval: 1000
  checkpoints_path: checkpoints/
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  save_initial_state: false
data_stages:
- data:
    dataset:
      hf_dataset: Magpie-Align/Magpie-Pro-300K-Filtered
      hf_dataset_split: train
      conversation_column_name: conversations
      train_on_completions_only: true
      remove_cross_attention: true
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Single dataset)
  start_training_step: 1
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: SFT-Todi
  run: Llama3-8B
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    path: /store/swissai/a06/models/nanotron_checkpoints/Meta-Llama-3.1-8B-Instruct
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 131072
    num_hidden_layers: 32
    num_attention_heads: 32
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rope_interleaved: false
    rope_theta: 500000.0
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 2.0e-5
    lr_decay_starting_step: null
    lr_decay_steps: 98
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 4
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 4
  tp_linear_async_communication: false
  tp_mode: ALL_REDUCE
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: /store/swissai/a06/models/nanotron_checkpoints/Meta-Llama-3.1-8B-Instruct
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 4
  sequence_length: 4096
  train_steps: 250
  val_check_interval: -1
113 changes: 113 additions & 0 deletions examples/config_nanoset_dev.yaml
@@ -0,0 +1,113 @@
checkpoints:
  checkpoint_interval: 100000000
  checkpoints_path: checkpoints/
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  save_initial_state: false
data_stages:
- data:
    dataset:
      dataset_folder: /mloscratch/homes/solergib/SFT/nanotron/datasets/SlimPajama-6B
      remove_document_xattention: True
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Single dataset)
  start_training_step: 1
- data:
    dataset:
      dataset_folder:
      - datasets/SlimPajama-6B/tokenized
      - datasets/c4-es/tokenized
    num_loading_workers: 1
    seed: 42
  name: Second purpose training (> 1 dataset)
  start_training_step: 150000
- data:
    dataset:
      dataset_folder:
        datasets/SlimPajama-6B/tokenized: 0.8
        datasets/c4-es/tokenized: 0.2
    num_loading_workers: 1
    seed: 42
  name: Third purpose training (Blended dataset)
  start_training_step: 150001
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: Nanoset
  run: llama
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 128000
    eos_token_id: 128001
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 4096
    num_hidden_layers: 2
    num_attention_heads: 32
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rope_interleaved: false
    rope_theta: 500000.0
    rms_norm_eps: 1.0e-05
    rope_scaling: null
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 98
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 1
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 1
  tp_linear_async_communication: true
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Meta-Llama-3.1-8B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 0
  micro_batch_size: 1
  sequence_length: 4096
  train_steps: 200000
  val_check_interval: -1
1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
     "safetensors",
     "dacite",
     "tqdm",
+    "wandb",
 ]

[tool.setuptools.packages.find]
30 changes: 28 additions & 2 deletions run_train.py
@@ -12,8 +12,9 @@
 
 import numpy as np
 from nanotron import logging
-from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
-from nanotron.data.dataloader_builder import build_nanoset_dataloader
+from nanotron.config import ChatDatasetsArgs, DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
+from nanotron.data.chat_dataset import ChatDataset
+from nanotron.data.dataloader_builder import build_chat_dataloader, build_nanoset_dataloader
 from nanotron.dataloader import (
     clm_process,
     dummy_infinite_data_generator,
@@ -162,6 +163,7 @@ def get_dataloader_from_data_stage(
         train_dataloader = build_nanoset_dataloader(
             train_dataset,
             trainer.sequence_length,
+            remove_document_xattention=data.dataset.remove_document_xattention,
             parallel_context=trainer.parallel_context,
             input_pp_rank=input_pp_rank,
             output_pp_rank=output_pp_rank,
@@ -171,6 +173,30 @@
             dataloader_drop_last=True,
         )
 
+        return train_dataloader
+    # Case 4: Chat Datasets
+    elif isinstance(data.dataset, ChatDatasetsArgs):
+        with main_rank_first(trainer.parallel_context.world_pg):
+            train_dataset = ChatDataset(
+                dataset_path=data.dataset.hf_dataset,
+                tokenizer_name_or_path=trainer.config.tokenizer.tokenizer_name_or_path,
+                sequence_length=trainer.sequence_length,
+                train_on_completions_only=data.dataset.train_on_completions_only,
+                remove_cross_attention=data.dataset.remove_cross_attention,
+                split=data.dataset.hf_dataset_split,
+                conversation_column_name=data.dataset.conversation_column_name,
+                dp_rank=trainer.parallel_context.dp_pg.rank(),
+                dp_ranks_size=trainer.parallel_context.dp_pg.size(),
+            )
+
+        # Prepare dataloader
+        train_dataloader = build_chat_dataloader(
+            dataset=train_dataset,
+            parallel_context=trainer.parallel_context,
+            input_pp_rank=input_pp_rank,
+            output_pp_rank=output_pp_rank,
+        )
+
         return train_dataloader
     else:
         raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")
20 changes: 19 additions & 1 deletion src/nanotron/config/config.py
@@ -94,6 +94,7 @@ def __post_init__(self):
 @dataclass
 class NanosetDatasetsArgs:
     dataset_folder: Union[str, dict, List[str]]
+    remove_document_xattention: bool = False
 
     def __post_init__(self):
         if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset folder
@@ -107,11 +108,27 @@ def __post_init__(self):
             self.dataset_weights = list(tmp_dataset_folder.values())
 
+
+@dataclass
+class ChatDatasetsArgs:
+    hf_dataset: str
+    hf_dataset_split: str
+    conversation_column_name: str
+    # Debug
+    train_on_completions_only: bool = True
+    remove_cross_attention: bool = True
+
+    def __post_init__(self):
+        if self.hf_dataset_split is None:
+            self.hf_dataset_split = "train"
+        if self.conversation_column_name is None:
+            self.conversation_column_name = "conversations"
+
 
 @dataclass
 class DataArgs:
     """Arguments related to the data and data files processing"""
 
-    dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
+    dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs]]
     seed: Optional[int]
     num_loading_workers: Optional[int] = 1
 
@@ -305,6 +322,7 @@ class OptimizerArgs:
     clip_grad: Optional[float]
     accumulate_grad_in_fp32: bool
     learning_rate_scheduler: LRSchedulerArgs
+    sft: bool = False
 
 
 @dataclass