Commit: work
kylematoba committed Oct 22, 2024
1 parent 6216fab commit 7e99150
Showing 7 changed files with 166 additions and 136 deletions.
149 changes: 130 additions & 19 deletions run_train.py
@@ -7,14 +7,18 @@
torchrun --nproc_per_node=8 run_train.py --config-file examples/config_tiny_llama.yaml
```
"""
import time
import argparse
from typing import Dict, cast

import numpy as np

from nanotron import logging
from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
from nanotron.config import (
DataArgs,
DatasetStageArgs,
MultilingualNanosetDatasetsArgs,
NanosetDatasetsArgs,
PretrainDatasetsArgs,
)
from nanotron.data.dataloader_builder import build_nanoset_dataloader
from nanotron.dataloader import (
clm_process,
@@ -29,13 +33,12 @@
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.trainer import DistributedTrainer
import nanotron.trainer
from nanotron.utils import main_rank_first
from torch.utils.data import DataLoader

try:
from huggingface_hub import __version__ as hf_hub_version
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
from transformers import __version__ as tf_version
except ImportError:
hf_hub_version = None
@@ -63,10 +66,6 @@ def get_dataloader_from_data_stage(
# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

print("--" * 40)
print(data.dataset)
print(type(data.dataset))
print("--" * 40)
# Case 1: Dummy data generator
if data.dataset is None:
log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
@@ -149,13 +148,6 @@ def get_dataloader_from_data_stage(
# Case 3: Nanosets
elif isinstance(data.dataset, NanosetDatasetsArgs):
# Get tokenizer cardinality
# sleep_seconds = 600
# print(f"Sleeping for {sleep_seconds} seconds")
# time.sleep(sleep_seconds)

print(trainer.config.tokenizer.tokenizer_name_or_path)
# model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-Nemo-Base-2407")
# del model
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
@@ -185,13 +177,94 @@ def get_dataloader_from_data_stage(
dataloader_drop_last=True,
)

return train_dataloader
# Case 4: MultilingualNanosets
elif isinstance(data.dataset, MultilingualNanosetDatasetsArgs):
# Get tokenizer cardinality
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
# Create Nanoset
from nanotron.data.multilingual_nanoset import MultilingualNanoset

with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = MultilingualNanoset(
dataset_folders=data.dataset.training_folder,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
random_seed=data.seed,
)

# Prepare dataloader
train_dataloader = build_nanoset_dataloader(
train_dataset,
trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples=consumed_train_samples,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
)

return train_dataloader
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")

return dataloader


def get_valid_dataloader_from_data_stage(
trainer: DistributedTrainer,
data: DataArgs,
# consumed_train_samples: int, We never use this because each validation iteration consumes all the samples
):

# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

# Validation is only supported with MultilingualNanosets
if isinstance(data.dataset, MultilingualNanosetDatasetsArgs):
# Get tokenizer cardinality
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
# Create Multilingual Nanoset
from nanotron.data.multilingual_nanoset import MultilingualNanoset

with main_rank_first(trainer.parallel_context.world_pg):
valid_dataset = MultilingualNanoset(
dataset_folders=data.dataset.validation_folder,
sequence_length=trainer.sequence_length,
token_size=token_size,
is_valid=True,
random_seed=data.seed,
)

# Prepare dataloader
valid_dataloader = build_nanoset_dataloader(
valid_dataset,
trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
shuffle=True,
is_multilingual=True,
)

return valid_dataloader
else:
raise ValueError(
f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}. Validation is currently just supported for MultilingualNanoset"
)


def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
dataloaders = {}

@@ -233,6 +306,40 @@ def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
return dataloaders


def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
dataloaders = {}

for stage_idx, stage in enumerate(trainer.config.data_stages):
# NOTE: we only create the dataloader for the first stage,
# then we lazily initialize the dataloaders for the other stages
stage = cast(DatasetStageArgs, stage)

log_rank(
f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set",
logger=logger,
level=logging.INFO,
rank=0,
)

dataloader = (
get_valid_dataloader_from_data_stage(trainer, stage.data)
if stage_idx == 0
else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data)
)
# TODO(tj.solergibert) As we create the valid dataloader again in every validation stage, we print the validation
# MultilingualNanoset info (number of samples, etc.) multiple times [UPDATE: ]. To solve that, we could get rid of these lambda
# funcs and directly create all dataloaders.
#
# These lambda funcs (used in training too) create the DataLoaders lazily in order to 1. start training faster, instead
# of creating multiple DataLoaders up front, and 2. consume less memory, as the lambda func is lighter than the DataLoader object with
# the Dataset, collator, etc.
# BUT 1. the Nanoset creation process is very fast and 2. Nanosets don't consume any memory at all until we start sampling
# from them. Also, the DataLoader is later transformed into an Iterator object, so it's impossible to retrieve
# the DataLoader object again to delete it (more comments in trainer.py)
dataloaders[stage.name] = dataloader
return dataloaders


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
@@ -245,8 +352,12 @@ def get_args():

# Load trainer and data
trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer)
train_dataloader = get_dataloader(trainer)

config = nanotron.trainer.get_config_from_file(config_file)
trainer.train(dataloader, validation_args=config.validation)
# NOTE(tj.solergibert) Build validation dataloaders only if necessary
valid_dataloader = None
if trainer.config.tokens.val_check_interval != -1:
valid_dataloader = get_valid_dataloader(trainer)

# Train
trainer.train(train_dataloader, valid_dataloader)
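
A note on the token_size selection that appears in the Nanoset, MultilingualNanoset, and validation branches above: it picks the smallest integer width that can hold every token id produced by the tokenizer. A minimal sketch of the rule, with a hypothetical helper name (token_bytes is not part of the repo):

import numpy as np

def token_bytes(vocab_size: int) -> int:
    # np.uint16 covers ids 0..65535, i.e. vocabularies of up to
    # np.iinfo(np.uint16).max + 1 == 65536 entries; anything larger
    # needs 4 bytes per token id.
    return 4 if vocab_size > np.iinfo(np.uint16).max + 1 else 2

print(token_bytes(50257))   # GPT-2-sized vocabulary -> 2 bytes per token
print(token_bytes(131072))  # 131072-entry vocabulary -> 4 bytes per token
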
5 changes: 4 additions & 1 deletion src/nanotron/config/config.py
@@ -389,7 +389,10 @@ def __post_init__(self):
self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step
for i in range(len(self.data_stages) - 1)
), "The stages are not sorted by start_training_step in increasing order"

if self.tokens.val_check_interval % self.logging.iteration_step_info_interval != 0:
raise ValueError(
f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
)
# # if lighteval, we need tokenizer to be defined
# if self.checkpoints.lighteval is not None:
# assert self.tokenizer.tokenizer_name_or_path is not None
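
The guard added to __post_init__ above only accepts configurations where the validation step lands on a logging step. A minimal, standalone sketch of that constraint with the same error message; the helper name check_val_on_logging_step is illustrative and not part of the repo:

def check_val_on_logging_step(val_check_interval: int, iteration_step_info_interval: int) -> None:
    # Validation results are reported through the logging machinery, so the
    # validation interval must be a whole multiple of the logging interval.
    if val_check_interval % iteration_step_info_interval != 0:
        raise ValueError(
            "It is necessary to run the validation stage during a logging step. "
            f"Validation interval: {val_check_interval}, Logging interval: {iteration_step_info_interval}"
        )

check_val_on_logging_step(500, 50)    # OK: 500 is a multiple of 50
# check_val_on_logging_step(500, 30)  # would raise ValueError
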
8 changes: 4 additions & 4 deletions src/nanotron/distributed.py
@@ -51,10 +51,10 @@ def all_gather_into_tensor( # pylint: disable=function-redefined
) -> Optional[Work]:
if group is None:
group = dist.torch_dist.distributed_c10d._get_default_group()

assert (
group.size() > 1
), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over"
#
# assert (
# group.size() > 1
# ), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over"

if torch_version_above_1_13:
return dist.all_gather_into_tensor(
20 changes: 13 additions & 7 deletions src/nanotron/parallel/pipeline_parallel/engine.py
@@ -2,18 +2,19 @@
from typing import Dict, Iterable, Optional, Union

import torch
from torch import nn as torch_nn
from torch.nn.parallel import DistributedDataParallel

from nanotron import distributed as dist
from nanotron import logging
from nanotron.distributed import ProcessGroup
from nanotron.logging import log_rank
from nanotron.optim.gradient_accumulator import GradientAccumulator
from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd
from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model
from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState, PipelineTrainBatchState
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
from nanotron.utils import ContextManagers
from torch import nn as torch_nn
from torch.nn.parallel import DistributedDataParallel

logger = logging.get_logger(__name__)

@@ -29,6 +30,7 @@ def forward(
state: PipelineTrainBatchState,
micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]],
model: torch_nn.Module,
is_validation: bool = False,
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Increment the number of forwards
state.nb_forwards += 1
@@ -52,7 +54,7 @@ def forward(
output["loss"] = output["loss"] / self.nb_microbatches

# Add output as activations that require backward pass
if not isinstance(output["loss"], TensorPointer):
if not isinstance(output["loss"], TensorPointer) and not is_validation:
assert output["loss"].requires_grad
state.register_activation_requiring_backward(output["loss"])
return output
@@ -134,7 +136,7 @@ def validate_batch_iter(
nb_microbatches: int,
) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]:
# Assign a new state for the current batch
state = PipelineTrainBatchState() # TODO: do i need state?
state = PipelineEvalBatchState()
self.nb_microbatches = nb_microbatches

outputs = []
@@ -143,7 +145,9 @@
# All forward
for micro_batch in batch:
context = self._get_fwd_context(model=model)
output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model)
output = self.forward(
context=context, state=state, micro_batch=micro_batch, model=model, is_validation=True
)
# TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage"
for _ in range(len(state.microbatches_activations_to_send)):
send_activation = state.microbatches_activations_to_send.popleft()
@@ -157,8 +161,10 @@
# Store the loss for each microbatch
if not isinstance(output["loss"], TensorPointer):
output = {k: v.detach() for k, v in output.items()}
outputs.append(output)

outputs.extend(
list(output["sample_loss"])
) # NOTE(tj.solergibert) Yes, it might look useless to do list + extend but it's necessary to split the output["sample_loss"] tensor into multiple tensors
return outputs


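
A toy illustration (not the engine code) of why validate_batch_iter now extends outputs with list(output["sample_loss"]) instead of appending the whole output dict: iterating a 1-D loss tensor yields one 0-D tensor per sample, so the returned list holds per-sample losses rather than one entry per microbatch. The tensors below are made up for the example:

import torch

outputs = []
# Pretend two microbatches, each producing a 1-D tensor of per-sample losses.
for sample_loss in (torch.tensor([0.9, 1.1]), torch.tensor([1.3, 0.7])):
    outputs.extend(list(sample_loss))  # adds four scalar (0-D) tensors overall

print(len(outputs))                 # -> 4
print(torch.stack(outputs).mean())  # -> 1.0, the mean per-sample loss
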
2 changes: 2 additions & 0 deletions src/nanotron/serialize/metadata.py
@@ -46,6 +46,8 @@ class TrainingMetadata:
last_stage_idx: Optional[int] = None
data_stages: Optional[List[DataStageMetadata]] = None

last_validation_stage_idx: Optional[int] = None

def __post_init__(self):
# NOTE: this is a sanity check after loading a trained checkpoint
total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages)
(Diffs for the remaining 2 changed files are not shown.)