
Commit

get rid of multiprocessing
fotstrt committed Nov 8, 2024
1 parent 7c69b82 commit a9da429
Showing 2 changed files with 255 additions and 234 deletions.
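
At its core, the change swaps the in-process multiprocessing.Process launch for a shell command that starts a separate run_train_custom.py process. Below is a minimal Python sketch of the two launch paths, using only names that appear in the diff (the run target and the command string are taken from it); it is an illustration, not part of the commit.

import os
from torch import multiprocessing

def launch_before(run, config_file, world_size, node_rank, master_addr):
    # Old path: fork the trainer inside the agent with torch.multiprocessing.
    process = multiprocessing.Process(
        target=run, args=(config_file, world_size, node_rank, master_addr)
    )
    process.start()
    return process  # later stopped via process.terminate()

def launch_after(config_file, world_size, node_rank, master_addr):
    # New path: start the trainer as an independent OS process via a shell command.
    start_cmd = (
        f"python run_train_custom.py --config_file {config_file} "
        f"--world_size {world_size} --rank {node_rank} --master_ip {master_addr}"
    )
    os.system(start_cmd)  # later stopped via pkill -f run_train_custom.py
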
241 changes: 7 additions & 234 deletions sailor/run_ft_train.py
@@ -10,43 +10,13 @@

from torch import multiprocessing

import numpy as np
from nanotron import logging
from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs
from nanotron.data.dataloader_builder import build_nanoset_dataloader
from nanotron.dataloader import (
clm_process,
dummy_infinite_data_generator,
get_datasets,
get_train_dataloader,
)
from nanotron.helpers import (
compute_remain_train_steps_of_a_data_stage_from_ckp,
get_consumed_train_samples_of_a_data_stage_from_ckp,
)
from nanotron.logging import log_rank
from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
from nanotron.trainer import DistributedTrainer
from nanotron.utils import main_rank_first
from torch.utils.data import DataLoader

try:
from huggingface_hub import __version__ as hf_hub_version
from transformers import AutoTokenizer
from transformers import __version__ as tf_version
except ImportError:
hf_hub_version = None
tf_version = None

logger = logging.get_logger(__name__)

from orchestration_pb2_grpc import WorkerAgentServicer, add_WorkerAgentServicer_to_server
from orchestration_pb2 import CheckReadyResponse, KillResponse, WorkerConfigurationResponse


class ElasticWorkerAgent(WorkerAgentServicer):
def __init__(self, script_args):
self.training_process = None
self.training_process_alive = False
self.hostname = socket.gethostname()
self.world_size = 0
self.node_rank = -1
@@ -59,9 +29,9 @@ def CheckReady(self, request, context):

def Kill(self, request, context):
print(f"Killing local process ...")
if self.training_process:
self.training_process.terminate()
self.training_process = None
if self.training_process_alive:
os.system("pkill -f run_train_custom.py") # TODO: check cleanup
self.training_process_alive = False
# TODO: check abort
return KillResponse()

@@ -73,9 +43,9 @@ def ConfigurationChange(self, request, context):
topology_list = list(request.topology)
if self.is_in_topo(topology_list):
print(f"Starting new process, node rank is {self.node_rank}")
self.training_process = multiprocessing.Process(target=run, args=(args.config_file, self.world_size, self.node_rank, self.master_addr))
self.training_process.start()

start_cmd = f"python run_train_custom.py --config_file {self.script_args.config_file} --world_size {self.world_size} --rank {self.node_rank} --master_ip {self.master_addr}"
os.system(start_cmd)
self.training_process_alive = True
return WorkerConfigurationResponse()

def is_in_topo(self, topology):
@@ -86,203 +56,6 @@ def is_in_topo(self, topology):
self.master_addr = topology[0]
return True


def get_dataloader_from_data_stage(
trainer: DistributedTrainer,
data: DataArgs,
consumed_train_samples: int,
num_remaining_train_steps: int,
):
"""
Returns a dataloader for a given data stage.
data: The data configuration for the current stage.
consumed_train_samples: The number of samples consumed by the model in this stage (each stage starts from zero).
num_remaining_train_steps: The number of remaining training steps for this stage.
"""
assert consumed_train_samples >= 0, "consumed_train_samples should be greater than or equal to 0"
assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than or equal to 0"

# First, we need to know which ranks to feed the dataloader to
input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)

# Case 1: Dummy data generator
if data.dataset is None:
log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
dataloader = dummy_infinite_data_generator(
micro_batch_size=trainer.micro_batch_size,
sequence_length=trainer.sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
vocab_size=trainer.model_config.vocab_size,
seed=data.seed,
parallel_context=trainer.parallel_context,
)()

# Case 2: HuggingFace datasets
elif isinstance(data.dataset, PretrainDatasetsArgs):
log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
log_rank(
f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
logger=logger,
level=logging.INFO,
rank=0,
)

# We need the 1st device to process the dataset and cache it, then the other devices load from the cache
with main_rank_first(trainer.parallel_context.world_pg):
# TODO @nouamanetazi: this may timeout before 1st device finishes processing dataset. Can we have a ctxmanager to modify timeout?
# TODO: generalise to include for validation/test splits

# We load the raw dataset
raw_dataset = get_datasets(
hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets,
hf_dataset_config_name=data.dataset.hf_dataset_config_name,
splits=data.dataset.hf_dataset_splits,
)["train"]

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Check that the tokenizer's vocab size is not larger than the model's vocab size
assert (
tokenizer.vocab_size <= trainer.model_config.vocab_size
), f"Tokenizer's vocab size ({tokenizer.vocab_size}) is larger than the model's vocab size ({trainer.model_config.vocab_size})"

# We apply the Causal Language Modeling preprocessing
train_dataset = clm_process(
raw_dataset=raw_dataset,
tokenizer=tokenizer,
text_column_name=data.dataset.text_column_name,
dataset_processing_num_proc_per_process=data.dataset.dataset_processing_num_proc_per_process,
dataset_overwrite_cache=data.dataset.dataset_overwrite_cache,
sequence_length=trainer.sequence_length,
)

# We load the processed dataset on the ranks requiring it
dataloader = get_train_dataloader(
train_dataset=train_dataset,
sequence_length=trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples=consumed_train_samples,
dataloader_num_workers=data.num_loading_workers,
seed_worker=data.seed,
dataloader_drop_last=True,
)

# Check if we have enough samples for train_steps
total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
num_tokens_needed_for_training = (
num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
)
assert num_tokens_needed_for_training <= total_tokens_dataset, (
f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
)

# Case 3: Nanosets
elif isinstance(data.dataset, NanosetDatasetsArgs):
# Get tokenizer cardinality
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
# Create Nanoset
from nanotron.data.nanoset import Nanoset

with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = Nanoset(
dataset_folders=data.dataset.dataset_folder,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
random_seed=data.seed,
)

# Prepare dataloader
train_dataloader = build_nanoset_dataloader(
train_dataset,
trainer.sequence_length,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
micro_batch_size=trainer.micro_batch_size,
consumed_train_samples=consumed_train_samples,
dataloader_num_workers=data.num_loading_workers,
dataloader_drop_last=True,
)

return train_dataloader
else:
raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}")

return dataloader


def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
dataloaders = {}

for stage_idx, stage in enumerate(trainer.config.data_stages):
# NOTE: we only create the dataloader for the first stage,
# then we lazily initialize the dataloaders for the other stages
stage = cast(DatasetStageArgs, stage)
consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata)
assert (
consumed_train_samples is not None
), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint"

num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
stage, trainer.config, trainer.metadata
)
log_rank(
f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples",
logger=logger,
level=logging.INFO,
rank=0,
)

dataloader = (
get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
)
if stage_idx == 0
else lambda stage=stage: get_dataloader_from_data_stage(
trainer,
stage.data,
consumed_train_samples=consumed_train_samples,
num_remaining_train_steps=num_remaining_train_steps,
)
)
dataloaders[stage.name] = dataloader
return dataloaders


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
return parser.parse_args()

def run(config_file, world_size, rank, master_addr):

os.environ['WORLD_SIZE'] = str(world_size)
os.environ['RANK'] = str(rank)
os.environ['MASTER_ADDR'] = master_addr
os.environ['MASTER_PORT'] = "1234" # TODO

trainer = DistributedTrainer(config_file)
dataloader = get_dataloader(trainer)

# Train
trainer.train(dataloader)

if __name__ == "__main__":
multiprocessing.set_start_method('spawn')

(Diff for the second changed file did not load.)
