Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into moe
Browse files Browse the repository at this point in the history
  • Loading branch information
haeggee committed Oct 23, 2024
2 parents 3bce1f4 + 4d6c1d3 commit 0507c7f
Show file tree
Hide file tree
Showing 16 changed files with 380 additions and 96 deletions.
2 changes: 1 addition & 1 deletion docs/nanoset.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument:

Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
```shell
torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
```

## Under the hood
Expand Down
114 changes: 97 additions & 17 deletions src/nanotron/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,16 @@ class NanosetDatasetsArgs:
dataset_folder: Union[str, dict, List[str]]

def __post_init__(self):
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder
self.dataset_folder = [self.dataset_folder]
self.dataset_weights = [1]
elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file
self.dataset_weights = None # Set to None so we consume all the samples randomly
elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights
elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder
self.dataset_weights = (
None # Set to None so we consume all the samples randomly
)
elif isinstance(
self.dataset_folder, dict
): # Case 3: dict with > 1 dataset_folder and weights
tmp_dataset_folder = self.dataset_folder.copy()
self.dataset_folder = list(tmp_dataset_folder.keys())
self.dataset_weights = list(tmp_dataset_folder.values())
Expand All @@ -116,16 +120,55 @@ def __post_init__(self):
class MultilingualNanosetDatasetsArgs:
training_folder: Union[str, dict, List[str]]
validation_folder: Union[str, List[str]]
languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
languages: List[
str
] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB

def __post_init__(self):
if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder
self.training_folder = [self.training_folder]
self.validation_folder = [self.validation_folder]
self.dataset_weights = [1]
elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder
self.dataset_weights = None # Set to None so we consume all the samples randomly
elif isinstance(self.training_folder, dict): # Case 3: dict with > 1 training_folder and weights
self.dataset_weights = (
None # Set to None so we consume all the samples randomly
)
elif isinstance(
self.training_folder, dict
): # Case 3: dict with > 1 training_folder and weights
tmp_training_folder = self.training_folder.copy()
self.training_folder = list(tmp_training_folder.keys())
self.dataset_weights = list(tmp_training_folder.values())

assert len(self.training_folder) == len(
self.languages
), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"

assert len(self.training_folder) == len(
self.validation_folder
), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"


@dataclass
class MultilingualNanosetDatasetsArgs:
training_folder: Union[str, dict, List[str]]
validation_folder: Union[str, List[str]]
languages: List[
str
] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB

def __post_init__(self):
if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder
self.training_folder = [self.training_folder]
self.validation_folder = [self.validation_folder]
self.dataset_weights = [1]
elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder
self.dataset_weights = (
None # Set to None so we consume all the samples randomly
)
elif isinstance(
self.training_folder, dict
): # Case 3: dict with > 1 training_folder and weights
tmp_training_folder = self.training_folder.copy()
self.training_folder = list(tmp_training_folder.keys())
self.dataset_weights = list(tmp_training_folder.values())
Expand Down Expand Up @@ -167,7 +210,9 @@ class DatasetStageArgs:

def __post_init__(self):
if self.start_training_step < 0:
raise ValueError(f"training_steps should be a positive integer and not {self.start_training_step}")
raise ValueError(
f"training_steps should be a positive integer and not {self.start_training_step}"
)


@dataclass
Expand All @@ -182,6 +227,7 @@ class CheckpointsArgs:
checkpoints_path: Path
checkpoint_interval: int
save_initial_state: Optional[bool] = False
save_final_state: Optional[bool] = False
resume_checkpoint_path: Optional[Path] = None
checkpoints_path_is_shared_file_system: Optional[bool] = False

Expand Down Expand Up @@ -387,13 +433,19 @@ def __post_init__(self):
if self.profiler is not None and self.profiler.profiler_export_path is not None:
assert self.tokens.train_steps < 10

if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None:
if (
self.optimizer is not None
and self.optimizer.learning_rate_scheduler.lr_decay_steps is None
):
self.optimizer.learning_rate_scheduler.lr_decay_steps = (
self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps
self.tokens.train_steps
- self.optimizer.learning_rate_scheduler.lr_warmup_steps
)

if self.data_stages is not None:
self.data_stages = sorted(self.data_stages, key=lambda stage: stage.start_training_step)
self.data_stages = sorted(
self.data_stages, key=lambda stage: stage.start_training_step
)
names = [stage.name for stage in self.data_stages]
training_steps = [stage.start_training_step for stage in self.data_stages]
assert any(
Expand All @@ -402,7 +454,9 @@ def __post_init__(self):

for stage in self.data_stages:
if names.count(stage.name) > 1:
raise ValueError(f"Each stage should have unique names and not {names}")
raise ValueError(
f"Each stage should have unique names and not {names}"
)

if training_steps.count(stage.start_training_step) > 1:
raise ValueError(
Expand All @@ -411,13 +465,29 @@ def __post_init__(self):

# NOTE: must order the stages by start_training_step from lowest to highest
assert all(
self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step
self.data_stages[i].start_training_step
< self.data_stages[i + 1].start_training_step
for i in range(len(self.data_stages) - 1)
), "The stages are not sorted by start_training_step in increasing order"

# NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we
# must comply with val_check_interval % iteration_step_info_interval = 0
if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0:
if (
not self.tokens.val_check_interval
% self.logging.iteration_step_info_interval
== 0
):
raise ValueError(
f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
)

# NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we
# must comply with val_check_interval % iteration_step_info_interval = 0
if (
not self.tokens.val_check_interval
% self.logging.iteration_step_info_interval
== 0
):
raise ValueError(
f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
)
Expand All @@ -428,7 +498,11 @@ def __post_init__(self):

@property
def global_batch_size(self):
return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp
return (
self.tokens.micro_batch_size
* self.tokens.batch_accumulation_per_replica
* self.parallelism.dp
)

def save_as_yaml(self, file_path: str):
config_dict = serialize(self)
Expand Down Expand Up @@ -460,12 +534,18 @@ def get_config_from_dict(
if skip_unused_config_keys:
logger.warning("skip_unused_config_keys set")
config_dict = {
field.name: config_dict[field.name] for field in fields(config_class) if field.name in config_dict
field.name: config_dict[field.name]
for field in fields(config_class)
if field.name in config_dict
}
if skip_null_keys:
logger.warning("Skip_null_keys set")
config_dict = {
k: ({kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v)
k: (
{kk: vv for kk, vv in v.items() if vv is not None}
if isinstance(v, dict)
else v
)
for k, v in config_dict.items()
if v is not None
}
Expand Down
1 change: 1 addition & 0 deletions src/nanotron/config/lighteval_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __post_init__(self):
class LightEvalTasksArgs:
"""Arguments related to tasks for LightEval"""

langs: Optional[str] = None
tasks: Optional[str] = None
custom_tasks: Optional[str] = None
max_samples: Optional[int] = None
Expand Down
2 changes: 2 additions & 0 deletions src/nanotron/config/parallelism_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class ParallelismArgs:
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False

tp_recompute_allgather: bool = True

expert_parallel_size: int = 1

def __post_init__(self):
Expand Down
36 changes: 27 additions & 9 deletions src/nanotron/data/multilingual_nanoset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def __init__(

# Checks
if isinstance(dataset_folders, str):
warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
warnings.warn(
"dataset_folders should be of type List[str] but str was provided. Converting to List[str]"
)
dataset_folders = [dataset_folders]

# Init
Expand All @@ -63,7 +65,9 @@ def __init__(

# Build Nanoset Index
## To build the index we need the length of each dataset
self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets]
self.dataset_lengths = [
len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets
]
## Set dataset weights
if (
dataset_weights is None
Expand All @@ -76,10 +80,14 @@ def __init__(
), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided."
## Build dataset index and dataset sample index
if is_valid: # Valid MultilingualNanoset
self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths)
self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(
self.dataset_lengths
)

else: # Train MultilingualNanoset
self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index()
self.dataset_index, self.dataset_sample_index = (
self.build_train_nanoset_index()
)

self.print_nanoset_info()

Expand Down Expand Up @@ -129,7 +137,9 @@ def build_train_nanoset_index(self) -> np.ndarray:
numpy_random_state.shuffle(dataset_sample_index)
# Concatenate num_epochs the shuffled indexes
dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)])
dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)])
dataset_sample_index = np.concatenate(
[dataset_sample_index for _ in range(num_epochs)]
)
# Just keep the necessary samples
dataset_index = dataset_index[: self.train_split_num_samples]
dataset_sample_index = dataset_sample_index[: self.train_split_num_samples]
Expand All @@ -152,7 +162,9 @@ def print_nanoset_info(self):
)

# Print samples from each dataset + weight
dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders))
dataset_sample_count = count_dataset_indexes(
self.dataset_index, len(self.dataset_folders)
)
for index, sample_count in enumerate(dataset_sample_count):
log_rank(
f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})",
Expand All @@ -174,7 +186,9 @@ def build_train_nanoset_index_helper(
"""
# Create empty arrays for dataset indices and dataset sample indices
dataset_index = np.empty((n_samples,), dtype="uint")
dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports dataset with up to 2**64 samples
dataset_sample_index = np.empty(
(n_samples,), dtype="long"
) # Supports dataset with up to 2**64 samples

# Initialize buffer for number of samples used for each dataset
current_samples = np.zeros((len(weights),), dtype="long")
Expand All @@ -191,7 +205,9 @@ def build_train_nanoset_index_helper(

# Assign the dataset index and update the sample index
dataset_index[sample_idx] = max_error_index
dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index]
dataset_sample_index[sample_idx] = (
current_samples[max_error_index] % dataset_sizes[max_error_index]
)

# Update the total samples for the selected dataset
current_samples[max_error_index] += 1
Expand All @@ -211,4 +227,6 @@ def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray:
dataset_index.extend([i] * length)
dataset_sample_index.extend(range(length))

return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long")
return np.array(dataset_index, dtype="uint"), np.array(
dataset_sample_index, dtype="long"
)
Loading

0 comments on commit 0507c7f

Please sign in to comment.