diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml
new file mode 100644
index 00000000..599bff6c
--- /dev/null
+++ b/examples/config_multilingual_nanoset.yaml
@@ -0,0 +1,134 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints/
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+data_stages:
+- data:
+    dataset:
+      training_folder: datasets/c4-es/train
+      validation_folder: datasets/c4-es/validation
+      lang_to_ids:
+        es: 128002
+    num_loading_workers: 1
+    seed: 42
+  name: General purpose training (Single dataset)
+  start_training_step: 1
+- data:
+    dataset:
+      training_folder:
+      - datasets/c4-es/train
+      - datasets/c4-en/train
+      - datasets/c4-fr/train
+      validation_folder:
+      - datasets/c4-es/validation
+      - datasets/c4-en/validation
+      - datasets/c4-fr/validation
+      lang_to_ids:
+        es: 128002
+        en: 128003
+        fr: 128004
+    num_loading_workers: 1
+    seed: 42
+  name: Second purpose training (> 1 dataset)
+  start_training_step: 15
+- data:
+    dataset:
+      training_folder:
+        datasets/c4-es/train: 0.6
+        datasets/c4-en/train: 0.3
+        datasets/c4-fr/train: 0.1
+      validation_folder:
+      - datasets/c4-es/validation
+      - datasets/c4-en/validation
+      - datasets/c4-fr/validation
+      lang_to_ids:
+        es: 128002
+        en: 128003
+        fr: 128004
+
+    num_loading_workers: 1
+    seed: 42
+  name: Third purpose training (Blended dataset)
+  start_training_step: 25
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: true
+  project: Nanoset
+  run: llama
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 25
+  dtype: bfloat16
+  init_method:
+    std: 0.025
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    hidden_size: 512
+    initializer_range: 0.02
+    intermediate_size: 512
+    is_llama_config: true
+    max_position_embeddings: 1024
+    num_hidden_layers: 2
+    num_attention_heads: 32
+    num_key_value_heads: 8
+    pad_token_id: null
+    pretraining_tp: 1
+    rope_interleaved: false
+    rope_theta: 500000.0
+    rms_norm_eps: 1.0e-06
+    rope_scaling: null
+    tie_word_embeddings: true
+    use_cache: true
+    vocab_size: 128256
+optimizer:
+  accumulate_grad_in_fp32: true
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.0003
+    lr_decay_starting_step: null
+    lr_decay_steps: 98
+    lr_decay_style: cosine
+    lr_warmup_steps: 2
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  optimizer_factory:
+    adam_beta1: 0.9
+    adam_beta2: 0.95
+    adam_eps: 1.0e-08
+    name: adamW
+    torch_adam_is_fused: true
+  weight_decay: 0.01
+  zero_stage: 0
+parallelism:
+  dp: 1
+  expert_parallel_size: 1
+  pp: 1
+  pp_engine: 1f1b
+  tp: 1
+  tp_linear_async_communication: false
+  tp_mode: REDUCE_SCATTER
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 10
+  micro_batch_size: 4
+  sequence_length: 1024
+  train_steps: 200
+  val_check_interval: -1
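Note: the three data_stages above exercise the three accepted training_folder forms (a single path, a list, and a weighted dict). A minimal sanity check of the example config, assuming only that the file above exists and that pyyaml is installed (this script is illustrative and not part of the PR):

import yaml

# Load the example config and check that every stage maps each training folder to a language token
with open("examples/config_multilingual_nanoset.yaml") as f:
    config = yaml.safe_load(f)

for stage in config["data_stages"]:
    dataset = stage["data"]["dataset"]
    folders = dataset["training_folder"]
    num_folders = 1 if isinstance(folders, str) else len(folders)
    assert num_folders == len(dataset["lang_to_ids"]), stage["name"]
    print(f"{stage['name']}: {num_folders} training folder(s), starts at step {stage['start_training_step']}")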
diff --git a/run_train.py b/run_train.py
index 021d955d..39cda23b 100644
--- a/run_train.py
+++ b/run_train.py
@@ -12,7 +12,13 @@
 import numpy as np
 
 from nanotron import logging
-from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
+from nanotron.config import (
+    DataArgs,
+    DatasetStageArgs,
+    MultilingualNanosetDatasetsArgs,
+    NanosetDatasetsArgs,
+    PretrainDatasetsArgs,
+)
 from nanotron.data.dataloader_builder import build_nanoset_dataloader
 from nanotron.dataloader import (
     clm_process,
@@ -171,6 +177,40 @@ def get_dataloader_from_data_stage(
             dataloader_drop_last=True,
         )
 
+        return train_dataloader
+    # Case 4: MultilingualNanosets
+    elif isinstance(data.dataset, MultilingualNanosetDatasetsArgs):
+        # Get tokenizer cardinality
+        tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
+        token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
+        del tokenizer
+        # Create Nanoset
+        from nanotron.data.multilingual_nanoset import MultilingualNanoset
+
+        with main_rank_first(trainer.parallel_context.world_pg):
+            train_dataset = MultilingualNanoset(
+                dataset_folders=data.dataset.training_folder,
+                dataset_weights=data.dataset.dataset_weights,
+                sequence_length=trainer.sequence_length,
+                token_size=token_size,
+                train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
+                dataset_tokens=data.dataset.dataset_tokens,
+                random_seed=data.seed,
+            )
+
+        # Prepare dataloader
+        train_dataloader = build_nanoset_dataloader(
+            train_dataset,
+            trainer.sequence_length,
+            parallel_context=trainer.parallel_context,
+            input_pp_rank=input_pp_rank,
+            output_pp_rank=output_pp_rank,
+            micro_batch_size=trainer.micro_batch_size,
+            consumed_train_samples=consumed_train_samples,
+            dataloader_num_workers=data.num_loading_workers,
+            dataloader_drop_last=True,
+        )
+
+        return train_dataloader
     else:
         raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")
 
@@ -178,6 +218,53 @@ def get_dataloader_from_data_stage(
     return dataloader
 
 
+def get_valid_dataloader_from_data_stage(
+    trainer: DistributedTrainer,
+    data: DataArgs,
+    # consumed_train_samples: int,  # Not needed: each validation iteration consumes all of the samples
+):
+
+    # First, we need to know which ranks to feed the dataloader to
+    input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
+
+    # Only support validation with MultilingualNanosets
+    if isinstance(data.dataset, MultilingualNanosetDatasetsArgs):
+        # Get tokenizer cardinality
+        tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
+        token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
+        del tokenizer
+        # Create Multilingual Nanoset
+        from nanotron.data.multilingual_nanoset import MultilingualNanoset
+
+        with main_rank_first(trainer.parallel_context.world_pg):
+            valid_dataset = MultilingualNanoset(
+                dataset_folders=data.dataset.validation_folder,
+                sequence_length=trainer.sequence_length,
+                token_size=token_size,
+                dataset_tokens=data.dataset.dataset_tokens,
+                is_valid=True,
+                random_seed=data.seed,
+            )
+
+        # Prepare dataloader
+        valid_dataloader = build_nanoset_dataloader(
+            valid_dataset,
+            trainer.sequence_length,
+            parallel_context=trainer.parallel_context,
+            input_pp_rank=input_pp_rank,
+            output_pp_rank=output_pp_rank,
+            micro_batch_size=trainer.micro_batch_size,
+            dataloader_num_workers=data.num_loading_workers,
+            dataloader_drop_last=True,
+        )
+
+        return valid_dataloader
+    else:
+        raise ValueError(
+            f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}. Validation is currently only supported with MultilingualNanoset"
+        )
+
+
 def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
     dataloaders = {}
 
@@ -219,6 +306,30 @@ def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
     return dataloaders
 
 
+def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
+    dataloaders = {}
+
+    for stage_idx, stage in enumerate(trainer.config.data_stages):
+        # NOTE: we only create the dataloader for the first stage,
+        # then we lazy initialize the dataloaders for the other stages
+        stage = cast(DatasetStageArgs, stage)
+
+        log_rank(
+            f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set",
+            logger=logger,
+            level=logging.INFO,
+            rank=0,
+        )
+
+        dataloader = (
+            get_valid_dataloader_from_data_stage(trainer, stage.data)
+            if stage_idx == 0
+            else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data)
+        )
+        dataloaders[stage.name] = dataloader
+    return dataloaders
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
@@ -231,7 +342,8 @@ def get_args():
 
     # Load trainer and data
     trainer = DistributedTrainer(config_file)
-    dataloader = get_dataloader(trainer)
+    train_dataloader = get_dataloader(trainer)
+    valid_dataloader = get_valid_dataloader(trainer)
 
     # Train
-    trainer.train(dataloader)
+    trainer.train(train_dataloader, valid_dataloader)
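Both dataloader builders above derive the on-disk token width from the tokenizer cardinality. A standalone sketch of that rule (only numpy assumed; the vocabulary sizes are illustrative):

import numpy as np

def bytes_per_token(vocab_size: int) -> int:
    # Same rule as in run_train.py: token ids are stored as uint16 while the
    # vocabulary still fits in 16 bits, and as uint32 otherwise.
    return 4 if vocab_size > np.iinfo(np.uint16).max + 1 else 2

assert bytes_per_token(32_000) == 2    # e.g. a Llama-2-sized vocabulary
assert bytes_per_token(128_256) == 4   # the Llama-3 vocabulary used in the example config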
diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index 05b49955..dd2c157d 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -107,11 +107,38 @@ def __post_init__(self):
         self.dataset_weights = list(tmp_dataset_folder.values())
 
 
+@dataclass
+class MultilingualNanosetDatasetsArgs:
+    training_folder: Union[str, dict, List[str]]
+    validation_folder: Union[str, List[str]]
+    lang_to_ids: dict  # Mapping from each language to its token id; must follow the same order as the dataset folders
+
+    def __post_init__(self):
+        if isinstance(self.training_folder, str):  # Case 1: 1 Dataset folder
+            self.training_folder = [self.training_folder]
+            self.validation_folder = [self.validation_folder]
+            self.dataset_weights = [1]
+        elif isinstance(self.training_folder, List):  # Case 2: > 1 Dataset folder
+            self.dataset_weights = None  # Set to None so we consume all the samples randomly
+        elif isinstance(self.training_folder, dict):  # Case 3: dict with > 1 training_folder and weights
+            tmp_training_folder = self.training_folder.copy()
+            self.training_folder = list(tmp_training_folder.keys())
+            self.dataset_weights = list(tmp_training_folder.values())
+
+        self.dataset_tokens = list(self.lang_to_ids.values())
+        assert len(self.training_folder) == len(
+            self.validation_folder
+        ), f"The sizes of training_folder and validation_folder do not match ({len(self.training_folder)} vs {len(self.validation_folder)})"
+        assert len(self.training_folder) == len(
+            self.dataset_tokens
+        ), f"The sizes of training_folder and lang_to_ids do not match ({len(self.training_folder)} vs {len(self.dataset_tokens)})"
+
+
 @dataclass
 class DataArgs:
     """Arguments related to the data and data files processing"""
 
-    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
+    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs]
     seed: Optional[int]
     num_loading_workers: Optional[int] = 1
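A rough illustration of how MultilingualNanosetDatasetsArgs.__post_init__ resolves the weighted-dict form from the example config. The import path is the one added to nanotron.config in this PR, and the printed values follow from the code above:

from nanotron.config import MultilingualNanosetDatasetsArgs

blended = MultilingualNanosetDatasetsArgs(
    training_folder={"datasets/c4-es/train": 0.6, "datasets/c4-en/train": 0.3, "datasets/c4-fr/train": 0.1},
    validation_folder=["datasets/c4-es/validation", "datasets/c4-en/validation", "datasets/c4-fr/validation"],
    lang_to_ids={"es": 128002, "en": 128003, "fr": 128004},
)
print(blended.training_folder)  # ['datasets/c4-es/train', 'datasets/c4-en/train', 'datasets/c4-fr/train']
print(blended.dataset_weights)  # [0.6, 0.3, 0.1]
print(blended.dataset_tokens)   # [128002, 128003, 128004]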
diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py
new file mode 100644
index 00000000..7af57448
--- /dev/null
+++ b/src/nanotron/data/multilingual_nanoset.py
@@ -0,0 +1,214 @@
+import os
+import warnings
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import torch
+from datatrove.utils.dataset import DatatroveFolderDataset
+from nanotron import logging
+from nanotron.data.utils import count_dataset_indexes, normalize
+from nanotron.logging import log_rank
+from numba import jit
+
+logger = logging.get_logger(__name__)
+
+
+class MultilingualNanoset(torch.utils.data.Dataset):
+    """
+    The Multilingual Nanoset dataset
+
+    Args:
+        dataset_folders (List[str]): List of folders with tokenized datasets
+        dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__
+        sequence_length (int): Sequence length of the built samples
+        token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 if the tokenizer vocabulary fits in an unsigned 16-bit integer, 4 otherwise
+        train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size
+    """
+
+    def __init__(
+        self,
+        dataset_folders: List[str],
+        sequence_length: int,
+        token_size: int,
+        dataset_tokens: List[int],
+        train_split_num_samples: int = None,
+        is_valid: bool = False,
+        dataset_weights: Union[List[float], None] = None,
+        random_seed: int = 1234,
+    ) -> None:
+
+        # Checks
+        if isinstance(dataset_folders, str):
+            warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
+            dataset_folders = [dataset_folders]
+
+        # Init
+        self.dataset_folders = dataset_folders
+        self.sequence_length = sequence_length
+        self.token_size = token_size
+        self.train_split_num_samples = train_split_num_samples
+        self.dataset_tokens = dataset_tokens
+        self.is_valid = is_valid
+        self.random_seed = random_seed
+        self.datatrove_datasets = []
+        for dataset_folder in self.dataset_folders:
+            self.datatrove_datasets.append(
+                DatatroveFolderDataset(
+                    folder_path=dataset_folder,
+                    filename_pattern=os.path.join(dataset_folder, "*.ds"),
+                    seq_len=sequence_length,
+                    recursive=False,
+                    token_size=token_size,
+                    shuffle=True,
+                )
+            )
+
+        # Build Nanoset Index
+        ## To build the index we need the length of each dataset
+        self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets]
+        ## Set dataset weights
+        if (
+            dataset_weights is None
+        ):  # Case of training with > 1 datasets without weighting them: consume all datasets entirely on each epoch
+            self.dataset_weights = normalize(self.dataset_lengths)
+        else:
+            self.dataset_weights = normalize(dataset_weights)
+        assert len(dataset_folders) == len(
+            self.dataset_weights
+        ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
+        ## Build dataset index and dataset sample index
+        if is_valid:  # Valid MultilingualNanoset
+            self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths)
+
+        else:  # Train MultilingualNanoset
+            self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index()
+
+        self.print_nanoset_info()
+
+    def __len__(self) -> int:
+        """
+        Returns:
+            int: The number of samples of the Nanoset
+        """
+
+        return len(self.dataset_index)
+
+    def __getitem__(self, idx: int) -> Dict[str, np.ndarray]:
+        """
+        Returns sequence_length + 1 tokens from the memmap dataset
+
+        Args:
+            idx (int): The index into the dataset
+
+        Returns:
+            Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary
+        """
+        dataset = self.dataset_index[idx]
+        dataset_sample = self.dataset_sample_index[idx]
+
+        tokens = self.datatrove_datasets[dataset][dataset_sample]
+        tokens["input_ids"][0] = self.dataset_tokens[dataset]  # Prepend language token (overwrites position 0)
+
+        return tokens
+
+    def build_train_nanoset_index(self) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Build train dataset index and dataset sample index
+        """
+        # Compute samples per epoch and number of epochs
+        samples_per_epoch = sum(self.dataset_lengths)
+        num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1
+        # Build the dataset indexes for 1 epoch
+        dataset_index, dataset_sample_index = build_train_nanoset_index_helper(
+            n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths
+        )
+        # Shuffle both indexes with the same seed so they stay aligned
+        numpy_random_state = np.random.RandomState(self.random_seed)
+        numpy_random_state.shuffle(dataset_index)
+        numpy_random_state = np.random.RandomState(self.random_seed)
+        numpy_random_state.shuffle(dataset_sample_index)
+        # Concatenate the shuffled indexes num_epochs times
+        dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)])
+        dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)])
+        # Just keep the necessary samples
+        dataset_index = dataset_index[: self.train_split_num_samples]
+        dataset_sample_index = dataset_sample_index[: self.train_split_num_samples]
+
+        return dataset_index, dataset_sample_index
+
+    def print_nanoset_info(self):
+
+        log_rank(
+            f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of samples: {len(self)}",
+            logger=logger,
+            level=logging.INFO,
+            rank=0,
+        )
+        log_rank(
+            f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of tokens: {len(self) * self.sequence_length}",
+            logger=logger,
+            level=logging.INFO,
+            rank=0,
+        )
+
+        # Print samples from each dataset + weight
+        dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders))
+        for index, sample_count in enumerate(dataset_sample_count):
+            log_rank(
+                f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
+                logger=logger,
+                level=logging.INFO,
+                rank=0,
+            )
+
+
+@jit(nopython=True, cache=True)
+def build_train_nanoset_index_helper(
+    n_samples: int, weights: np.ndarray, dataset_sizes: List[int]
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Given multiple datasets and a weighting array, build sample indexes such that
+    the resulting mix of samples follows those weights. Validation samples come from
+    separate validation folders (see build_valid_nanoset_index), so no train/valid
+    offsets are needed here.
+    """
+    # Create empty arrays for dataset indices and dataset sample indices
+    dataset_index = np.empty((n_samples,), dtype="uint")
+    dataset_sample_index = np.empty((n_samples,), dtype="long")  # Supports datasets with up to 2**63 samples
+
+    # Initialize buffer for number of samples used for each dataset
+    current_samples = np.zeros((len(weights),), dtype="long")
+
+    # Iterate over all samples
+    for sample_idx in range(n_samples):
+
+        # Convert sample index to float for comparison against weights
+        sample_idx_float = max(sample_idx, 1.0)
+
+        # Find the dataset with the highest error
+        errors = weights * sample_idx_float - current_samples
+        max_error_index = np.argmax(errors)
+
+        # Assign the dataset index and update the sample index
+        dataset_index[sample_idx] = max_error_index
+        dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index]
+
+        # Update the total samples for the selected dataset
+        current_samples[max_error_index] += 1
+
+    return dataset_index, dataset_sample_index
+
+
+@jit(nopython=True, cache=True)
+def build_valid_nanoset_index(dataset_lengths: List[int]) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Build valid dataset index and dataset sample index
+    """
+    dataset_index = []
+    dataset_sample_index = []
+
+    for i, length in enumerate(dataset_lengths):
+        dataset_index.extend([i] * length)
+        dataset_sample_index.extend(range(length))
+
+    return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long")
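The weighted index construction in build_train_nanoset_index_helper greedily assigns each sample to the dataset whose realized share lags its target weight the most. A dependency-free toy version of the same rule, with illustrative weights (not part of the PR):

from typing import List

import numpy as np

def toy_weighted_index(n_samples: int, weights: List[float]) -> np.ndarray:
    weights = np.asarray(weights, dtype=np.float64)
    current = np.zeros(len(weights), dtype=np.int64)
    dataset_index = np.empty(n_samples, dtype=np.int64)
    for i in range(n_samples):
        errors = weights * max(i, 1.0) - current  # how far behind each dataset is
        choice = int(np.argmax(errors))           # serve the most-lagging dataset
        dataset_index[i] = choice
        current[choice] += 1
    return dataset_index

print(np.bincount(toy_weighted_index(10, [0.6, 0.3, 0.1]), minlength=3))  # -> [6 3 1]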
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index bc81e326..3f4c5189 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -393,7 +393,10 @@ def find_stage_idx_to_resume():
 
     def train(
         self,
-        dataloader_or_dls: Dict[
+        train_dataloader_or_dls: Dict[
+            str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]]
+        ],
+        valid_dataloader_or_dls: Dict[
             str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]]
         ],
         **kwargs,
@@ -424,7 +427,7 @@ def train(
                 prof.step()
                 self.iteration_start_time = time.time()
 
-                self._update_dataloader_based_on_training_stages(dataloader_or_dls)
+                self._update_dataloader_based_on_training_stages(train_dataloader_or_dls)
 
                 # Training step
                 outputs, loss_avg = self.training_step(dataloader=self.current_dataloader)
diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index c668aa58..8383ba38 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -98,7 +98,9 @@ def main(args):
             dataset_options={"split": args.split},
         )
     elif args.readers == "parquet":
-        datatrove_reader = ParquetReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)
+        datatrove_reader = ParquetReader(
+            data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern
+        )
     else:
         datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)
 
@@ -107,6 +109,7 @@ def main(args):
         datatrove_reader,
         DocumentTokenizer(
             output_folder=args.output_folder,
+            shuffle=False,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
             eos_token=args.eos_token,
             max_tokens_per_file=1e9,
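Assuming the pieces above are in place, a single-process launch that exercises the new validation path could look like the following; the torchrun invocation is an illustration only (the --config-file flag comes from run_train.py, and dp=tp=pp=1 in the example config implies a single process):

torchrun --nproc_per_node=1 run_train.py --config-file examples/config_multilingual_nanoset.yaml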