Add multilingual validation (#3)
Add multilingual validation step.
TJ-Solergibert authored and Negar Foroutan Eghlidi committed Sep 8, 2024
1 parent 93546b3 commit c41730d
Showing 12 changed files with 438 additions and 87 deletions.
77 changes: 39 additions & 38 deletions examples/config_multilingual_nanoset.yaml
@@ -1,62 +1,63 @@
checkpoints:
checkpoint_interval: 1000
checkpoint_interval: 1000000
checkpoints_path: checkpoints/
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
save_initial_state: false
data_stages:
- data:
dataset:
training_folder: datasets/c4-es/train
validation_folder: datasets/c4-es/validation
lang_to_ids:
es: 128002
training_folder:
- datasets/c4-es/train
- datasets/c4-en/train
- datasets/c4-fr/train
validation_folder:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
languages:
- es
- en
- fr
num_loading_workers: 1
seed: 42
name: General purpose training (Single dataset)
name: General purpose training (Blended dataset)
start_training_step: 1
- data:
dataset:
training_folder:
- datasets/c4-es/train
- datasets/c4-en/train
- datasets/c4-fr/train
validation_folder:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
lang_to_ids:
es: 128002
en: 128003
fr: 128004
languages:
- es
num_loading_workers: 1
seed: 42
name: Second purpose training (> 1 dataset)
start_training_step: 15
name: Second purpose training (Single dataset)
start_training_step: 1000
- data:
dataset:
training_folder:
datasets/c4-es/train: 0.6
datasets/c4-en/train: 0.3
datasets/c4-fr/train: 0.1
- datasets/c4-es/train
- datasets/c4-en/train
- datasets/c4-fr/train
validation_folder:
- datasets/c4-es/validation
- datasets/c4-en/validation
- datasets/c4-fr/validation
lang_to_ids:
es: 128002
en: 128003
fr: 128004

languages:
- es
- en
- fr
num_loading_workers: 1
seed: 42
name: Third purpose training (Blended dataset)
start_training_step: 25
name: Third purpose training (>1 dataset)
start_training_step: 2000
general:
benchmark_csv_path: null
consumed_train_samples: null
ignore_sanity_checks: true
project: Nanoset
project: MultilingualV2
run: llama
seed: 42
step: null
@@ -75,12 +76,12 @@ model:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 512
hidden_size: 4096
initializer_range: 0.02
intermediate_size: 512
intermediate_size: 14336
is_llama_config: true
max_position_embeddings: 1024
num_hidden_layers: 2
max_position_embeddings: 4096
num_hidden_layers: 32
num_attention_heads: 32
num_key_value_heads: 8
pad_token_id: null
@@ -89,7 +90,7 @@ model:
rope_theta: 500000.0
rms_norm_eps: 1.0e-06
rope_scaling: null
tie_word_embeddings: true
tie_word_embeddings: false
use_cache: true
vocab_size: 128256
optimizer:
@@ -112,11 +113,11 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 1
dp: 2
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 1
tp: 4
tp_linear_async_communication: false
tp_mode: REDUCE_SCATTER
profiler: null
@@ -128,7 +129,7 @@ tokens:
batch_accumulation_per_replica: 1
limit_test_batches: 0
limit_val_batches: 10
micro_batch_size: 4
sequence_length: 1024
train_steps: 200
val_check_interval: -1
micro_batch_size: 3
sequence_length: 4096
train_steps: 500
val_check_interval: 100
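For orientation, here is the rough throughput implied by the updated parallelism and token settings above. This is a back-of-the-envelope sketch only: the global batch size is assumed to be the usual dp × micro_batch_size × batch_accumulation_per_replica product, which is not read from nanotron source here.

# Back-of-the-envelope numbers for the updated example config.
# The global_batch_size formula is an assumption (standard data-parallel accounting).
dp = 2
micro_batch_size = 3
batch_accumulation_per_replica = 1
sequence_length = 4096
train_steps = 500

global_batch_size = dp * micro_batch_size * batch_accumulation_per_replica  # 6 samples per step
tokens_per_step = global_batch_size * sequence_length                       # 24_576 tokens per step
total_train_tokens = tokens_per_step * train_steps                          # 12_288_000 tokens over the run

print(global_batch_size, tokens_per_step, total_train_tokens)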
19 changes: 15 additions & 4 deletions run_train.py
@@ -194,7 +194,6 @@ def get_dataloader_from_data_stage(
sequence_length=trainer.sequence_length,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
dataset_tokens=data.dataset.dataset_tokens,
random_seed=data.seed,
)

@@ -209,6 +208,7 @@ def get_dataloader_from_data_stage(
consumed_train_samples=consumed_train_samples,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
is_multilingual=True,
)

return train_dataloader
@@ -241,7 +241,6 @@ def get_valid_dataloader_from_data_stage(
dataset_folders=data.dataset.validation_folder,
sequence_length=trainer.sequence_length,
token_size=token_size,
dataset_tokens=data.dataset.dataset_tokens,
is_valid=True,
random_seed=data.seed,
)
@@ -256,6 +255,8 @@ def get_valid_dataloader_from_data_stage(
micro_batch_size=trainer.micro_batch_size,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
shuffle=True,
is_multilingual=True,
)

return valid_dataloader
@@ -315,7 +316,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
stage = cast(DatasetStageArgs, stage)

log_rank(
f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set",
f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set",
logger=logger,
level=logging.INFO,
rank=0,
@@ -324,8 +325,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
dataloader = (
get_valid_dataloader_from_data_stage(trainer, stage.data)
if stage_idx == 0
else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data)
else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data)
)
# TODO(tj.solergibert) Because we recreate the valid dataloader at every validation stage, we print the validation
# MultilingualNanoset info (number of samples, etc.) multiple times. To solve that, we could get rid of these lambda
# funcs and directly create all dataloaders.
#
# These lambda funcs (used in training too) create the DataLoaders lazily in order to 1. start training faster, instead
# of creating multiple DataLoaders up front, and 2. consume less memory, as the lambda func is lighter than the DataLoader
# object with the Dataset, collator, etc.
# BUT 1. the Nanoset creation process is very fast and 2. Nanosets don't consume any memory at all until we start sampling
# from them. Also, the DataLoader is later transformed into an Iterator object, so it's impossible to retrieve
# the DataLoader object again to delete it (more comments in trainer.py).
dataloaders[stage.name] = dataloader
return dataloaders

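The lambda-based lazy construction that the TODO comment above discusses can be sketched as follows. This is a minimal illustration, not nanotron code: build_stage_dataloader and stage_names are hypothetical placeholders for the per-stage construction done in get_valid_dataloader_from_data_stage.

from typing import Callable, Dict, List, Union

from torch.utils.data import DataLoader


def build_stage_dataloader(stage_name: str) -> DataLoader:
    """Hypothetical stand-in for building the Nanoset + collator + DataLoader of one stage."""
    ...


def build_lazy_dataloaders(stage_names: List[str]) -> Dict[str, Union[DataLoader, Callable[[], DataLoader]]]:
    dataloaders = {}
    for idx, name in enumerate(stage_names):
        if idx == 0:
            # The first stage is built eagerly because training/validation starts with it.
            dataloaders[name] = build_stage_dataloader(name)
        else:
            # Later stages are deferred; binding `name` as a default argument avoids the
            # late-binding closure pitfall (the same trick as `lambda stage=stage: ...` above).
            dataloaders[name] = lambda name=name: build_stage_dataloader(name)
    return dataloaders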
17 changes: 12 additions & 5 deletions src/nanotron/config/config.py
@@ -111,7 +111,7 @@ def __post_init__(self):
class MultilingualNanosetDatasetsArgs:
training_folder: Union[str, dict, List[str]]
validation_folder: Union[str, List[str]]
lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order
languages: List[str]  # NOTE(tj.solergibert) Required for 1. aggregating the results and 2. reporting to WANDB

def __post_init__(self):
if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder
@@ -125,13 +125,13 @@ def __post_init__(self):
self.training_folder = list(tmp_training_folder.keys())
self.dataset_weights = list(tmp_training_folder.values())

self.dataset_tokens = list(self.lang_to_ids.values())
assert len(self.training_folder) == len(
self.languages
), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"

assert len(self.training_folder) == len(
self.validation_folder
), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
assert len(self.training_folder) == len(
self.dataset_tokens
), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})"


@dataclass
@@ -406,6 +406,13 @@ def __post_init__(self):
for i in range(len(self.data_stages) - 1)
), "The stages are not sorted by start_training_step in increasing order"

# NOTE(tj.solergibert) As we report the training & validation metrics together, we
# must ensure that val_check_interval % iteration_step_info_interval == 0
if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0:
raise ValueError(
f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
)

# # if lighteval, we need tokenizer to be defined
# if self.checkpoints.lighteval is not None:
# assert self.tokenizer.tokenizer_name_or_path is not None
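To make the accepted shapes concrete, here is a rough, self-contained sketch of how the three training_folder forms (single path, weighted dict, plain list) get normalized and validated. It mirrors the assertions shown above but is not the actual MultilingualNanosetDatasetsArgs class; the class name, paths, and weights below are placeholders.

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class DatasetArgsSketch:
    training_folder: Union[str, dict, List[str]]
    validation_folder: Union[str, List[str]]
    languages: List[str]
    dataset_weights: Optional[List[float]] = None

    def __post_init__(self):
        if isinstance(self.training_folder, str):  # Case 1: single dataset folder
            self.training_folder = [self.training_folder]
            self.validation_folder = [self.validation_folder]
        elif isinstance(self.training_folder, dict):  # Case 2: weighted blend of folders
            tmp = self.training_folder
            self.training_folder = list(tmp.keys())
            self.dataset_weights = list(tmp.values())
        # Case 3 (plain list of folders) needs no normalization.
        assert len(self.training_folder) == len(self.languages), "training_folder and languages sizes mismatch"
        assert len(self.training_folder) == len(self.validation_folder), "training_folder and validation_folder sizes mismatch"


# Weighted-blend form, as in the old third data stage of the example config:
args = DatasetArgsSketch(
    training_folder={"datasets/c4-es/train": 0.6, "datasets/c4-en/train": 0.3, "datasets/c4-fr/train": 0.1},
    validation_folder=["datasets/c4-es/validation", "datasets/c4-en/validation", "datasets/c4-fr/validation"],
    languages=["es", "en", "fr"],
)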
73 changes: 73 additions & 0 deletions src/nanotron/data/collator.py
@@ -78,3 +78,76 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni
)

return result


@dataclasses.dataclass
class MultilingualNanosetDataCollatorForCLM:
"""
Data collator used for causal language modeling with Nanosets dataset.
- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""

sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext

def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"lang_code": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}

# TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
lang_code = torch.vstack([examples[i]["lang_code"] for i in range(len(examples))]) # (b, 1)
batch_size, expanded_input_length = input_ids.shape

result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {}

result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["lang_code"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)

assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"

# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
result["lang_code"] = lang_code

# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)

return result
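As a toy illustration of the shapes this collator works with (made-up token ids, sequence_length = 4, and a single pipeline rank holding both inputs and outputs):

import torch

sequence_length = 4

# Each sample carries sequence_length + 1 input ids plus a scalar lang_code,
# matching what MultilingualNanoset.__getitem__ returns.
examples = [
    {"input_ids": torch.tensor([10, 11, 12, 13, 14]), "lang_code": torch.tensor(0)},
    {"input_ids": torch.tensor([20, 21, 22, 23, 24]), "lang_code": torch.tensor(1)},
]

input_ids = torch.vstack([ex["input_ids"] for ex in examples])  # (2, 5)
lang_code = torch.vstack([ex["lang_code"] for ex in examples])  # (2, 1)

batch = {
    "input_ids": input_ids[:, :-1],  # (2, 4): last token dropped
    "label_ids": input_ids[:, 1:],   # (2, 4): labels shifted one position to the left
    "lang_code": lang_code,          # (2, 1): one dataset/language index per sample
    "input_mask": torch.ones((2, sequence_length), dtype=torch.bool),
    "label_mask": torch.ones((2, sequence_length), dtype=torch.bool),
}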
14 changes: 12 additions & 2 deletions src/nanotron/data/dataloader_builder.py
@@ -1,6 +1,6 @@
import nanotron.distributed as dist
from nanotron import logging
from nanotron.data.collator import NanosetDataCollatorForCLM
from nanotron.data.collator import MultilingualNanosetDataCollatorForCLM, NanosetDataCollatorForCLM
from nanotron.dataloader import (
EmptyInfiniteDataset,
get_dataloader_worker_init,
@@ -20,9 +20,11 @@ def build_nanoset_dataloader(
output_pp_rank: int,
micro_batch_size: int,
dataloader_num_workers: int,
is_multilingual: bool = False,
consumed_train_samples: int = 0,
dataloader_drop_last: bool = True,
dataloader_pin_memory: bool = True,
shuffle: bool = False,
) -> DataLoader:

# Case of ranks not requiring data. We give them a dummy dataset, then the collator will do its job
@@ -39,6 +41,14 @@ def build_nanoset_dataloader(
parallel_context=parallel_context,
)

if is_multilingual:
data_collator = MultilingualNanosetDataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
)

# Compute size and rank of dataloader workers
dp_ranks_size = parallel_context.dp_pg.size()
dp_rank = parallel_context.dp_pg.rank()
@@ -49,7 +59,7 @@ def build_nanoset_dataloader(
dl_rank=dp_rank,
drop_last=dataloader_drop_last,
consumed_train_samples=consumed_train_samples,
shuffle=False,
shuffle=shuffle,
)

return DataLoader(
4 changes: 1 addition & 3 deletions src/nanotron/data/multilingual_nanoset.py
@@ -30,7 +30,6 @@ def __init__(
dataset_folders: List[str],
sequence_length: int,
token_size: int,
dataset_tokens: List[int],
train_split_num_samples: int = None,
is_valid: bool = False,
dataset_weights: Union[List[float], None] = None,
@@ -47,7 +46,6 @@ def __init__(
self.sequence_length = sequence_length
self.token_size = token_size
self.train_split_num_samples = train_split_num_samples
self.dataset_tokens = dataset_tokens
self.is_valid = is_valid
self.random_seed = random_seed
self.datatrove_datasets = []
@@ -107,7 +105,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
dataset_sample = self.dataset_sample_index[idx]

tokens = self.datatrove_datasets[dataset][dataset_sample]
tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token
tokens["lang_code"] = torch.tensor(dataset, dtype=torch.long)

return tokens
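Conceptually, after this change a sample no longer has a language token written into position 0; the dataset index travels separately as lang_code. A hypothetical sample is shown below (random ids, Llama3-sized vocabulary, and the 4096 sequence length from the example config).

import torch

sample = {
    "input_ids": torch.randint(0, 128256, (4097,)),  # sequence_length + 1 tokens, left untouched
    "lang_code": torch.tensor(2, dtype=torch.long),  # index of the source dataset / language
}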
