diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py new file mode 100644 index 000000000000..b7e12d8fb2de --- /dev/null +++ b/examples/llm/sft/hf.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fiddle as fdl +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from torch.utils.data import DataLoader + +from nemo import lightning as nl +from nemo.collections import llm + + +class SquadDataModuleWithPthDataloader(llm.SquadDataModule): + def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + return DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + collate_fn=dataset.collate_fn, + batch_size=self.micro_batch_size, + **kwargs, + ) + + +def squad(tokenizer) -> pl.LightningDataModule: + return SquadDataModuleWithPthDataloader( + tokenizer=tokenizer, + seq_length=2048, + micro_batch_size=2, + global_batch_size=128, # assert gbs == mbs * accumulate_grad_batches + num_workers=0, + sanity_check_dist_workers=False, + ) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--model', default='meta-llama/Llama-3.2-1B') + parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp']) + parser.add_argument('--devices', default=1) + parser.add_argument('--accelerator', default='gpu', choices=['gpu']) + parser.add_argument('--max-steps', type=int, default=100) + parser.add_argument('--wandb-project', type=str, default=None) + args = parser.parse_args() + + wandb = None + if args.wandb_project is not None: + model = '_'.join(args.model.split('/')[-2:]) + wandb = WandbLogger( + project=args.wandb_project, + name=f'{model}_dev{args.devices}_strat_{args.strategy}', + ) + grad_clip = 0.5 + if args.strategy == 'fsdp': + # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81 + grad_clip = None + use_dist_samp = False + + llm.api.finetune( + model=llm.HfAutoModelForCausalLM(args.model), + data=squad(llm.HfAutoModelForCausalLM.configure_tokenizer(args.model)), + trainer=nl.Trainer( + devices=args.devices, + max_steps=args.max_steps, + accelerator=args.accelerator, + strategy=args.strategy, + log_every_n_steps=1, + limit_val_batches=0.0, + num_sanity_val_steps=0, + accumulate_grad_batches=10, + gradient_clip_val=grad_clip, + use_distributed_sampler=use_dist_samp, + logger=wandb, + ), + optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(max_lr=1e-5, clip_grad=0.5)), + log=None, + ) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 4205c401eea8..6dde88079567 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -21,6 +21,7 @@ from nemo.collections.llm.gpt.data import ( DollyDataModule, FineTuningDataModule, + HfDatasetDataModule, MockDataModule, 
PreTrainingDataModule, SquadDataModule, @@ -57,6 +58,7 @@ GPTConfig126M, GPTConfig175B, GPTModel, + HfAutoModelForCausalLM, Llama2Config7B, Llama2Config13B, Llama2Config70B, @@ -182,6 +184,7 @@ "squad", "dolly", "peft", + "HfAutoModelForCausalLM", ] diff --git a/nemo/collections/llm/gpt/data/__init__.py b/nemo/collections/llm/gpt/data/__init__.py index 45ca0788874f..f4e97d91e5cd 100644 --- a/nemo/collections/llm/gpt/data/__init__.py +++ b/nemo/collections/llm/gpt/data/__init__.py @@ -14,8 +14,16 @@ from nemo.collections.llm.gpt.data.dolly import DollyDataModule from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule +from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule from nemo.collections.llm.gpt.data.squad import SquadDataModule -__all__ = ["FineTuningDataModule", "SquadDataModule", "DollyDataModule", "MockDataModule", "PreTrainingDataModule"] +__all__ = [ + "FineTuningDataModule", + "SquadDataModule", + "DollyDataModule", + "MockDataModule", + "PreTrainingDataModule", + "HfDatasetDataModule", +] diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py index 01cf617a094d..2545bbc93f1d 100644 --- a/nemo/collections/llm/gpt/data/fine_tuning.py +++ b/nemo/collections/llm/gpt/data/fine_tuning.py @@ -70,6 +70,7 @@ def __init__( persistent_workers: bool = False, pad_to_max_length: bool = False, packed_sequence_specs: Optional["PackedSequenceSpecs"] = None, + sanity_check_dist_workers: bool = True, ): super().__init__() self.seq_length = seq_length @@ -89,6 +90,7 @@ def __init__( self.packed_sequence_specs = packed_sequence_specs self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size self.validate_batch_size_for_packed_sequence() + self._sanity_check_dist_workers = sanity_check_dist_workers def validate_batch_size_for_packed_sequence(self): if self.packed_sequence_size > 0 and self.micro_batch_size > 1: @@ -134,6 +136,7 @@ def train_dataloader(self) -> DataLoader: self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed, max_num_samples=self.max_train_samples, pad_to_max_length=self.pad_to_max_length, + sanity_check_dist_workers=self._sanity_check_dist_workers, ) ) @@ -143,6 +146,7 @@ def val_dataloader(self) -> DataLoader: self.validation_path, is_test=True, pad_to_max_length=self.pad_to_max_length, + sanity_check_dist_workers=self._sanity_check_dist_workers, ), ) @@ -153,6 +157,7 @@ def test_dataloader(self) -> DataLoader: tokens_to_generate=32, is_test=True, pad_to_max_length=self.pad_to_max_length, + sanity_check_dist_workers=self._sanity_check_dist_workers, ) ) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py new file mode 100644 index 000000000000..7e70a970913e --- /dev/null +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader + + +class HfDatasetDataModule(pl.LightningDataModule): + def __init__( + self, + dataset, + num_workers=2, + pin_memory=True, + persistent_workers=True, + micro_batch_size=2, + global_batch_size=2, + pad_token_id=0, + use_mcore_sampler=False, + mcore_dataloader_type='cyclic', + ) -> None: + super().__init__() + assert pad_token_id is not None + + self.dataset = dataset + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.pad_token_id = pad_token_id + + self.use_mcore_sampler = use_mcore_sampler + self.mcore_dataloader_type = mcore_dataloader_type + + @staticmethod + def collate_fn(batch, pad_token_id=0): + def batchify(tensor): + if tensor.ndim == 1: + return tensor.unsqueeze_(0) + return tensor + + def extract_key_from_dicts(batch, key): + return list(map(lambda x: x[key], batch)) + + def pad_within_micro(batch, pad_token_id): + max_len = max(map(len, batch)) + return [item + [pad_token_id] * (max_len - len(item)) for item in batch] + + return { + key: batchify( + torch.LongTensor( + pad_within_micro( + extract_key_from_dicts(batch, key), + pad_token_id, + ) + ) + ) + for key in ['tokens', 'labels'] + } + + def train_dataloader(self, collate_fn=None): + from nemo.lightning.data import add_megatron_sampler + + if collate_fn is None: + collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) + + dataloader = DataLoader( + self.dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + collate_fn=collate_fn, + batch_size=self.micro_batch_size, + ) + if not self.use_mcore_sampler: + return dataloader + + rank = 0 + world_size = 1 + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + return add_megatron_sampler( + dataloader, + self.micro_batch_size, + self.global_batch_size, + dataloader_type=self.mcore_dataloader_type, + rank=rank, + world_size=world_size, + ) diff --git a/nemo/collections/llm/gpt/data/squad.py b/nemo/collections/llm/gpt/data/squad.py index f872db94077d..cabbd444c0cf 100644 --- a/nemo/collections/llm/gpt/data/squad.py +++ b/nemo/collections/llm/gpt/data/squad.py @@ -56,6 +56,7 @@ def __init__( persistent_workers: bool = False, pad_to_max_length: bool = False, packed_sequence_specs: Optional["PackedSequenceSpecs"] = None, + sanity_check_dist_workers: bool = True, ): self.force_redownload = force_redownload self.delete_raw = delete_raw @@ -74,6 +75,7 @@ def __init__( persistent_workers=persistent_workers, pad_to_max_length=pad_to_max_length, packed_sequence_specs=packed_sequence_specs, + sanity_check_dist_workers=sanity_check_dist_workers, ) def prepare_data(self) -> None: diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index ebecc06140fe..26b8d67cb53d 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ 
b/nemo/collections/llm/gpt/model/__init__.py @@ -37,6 +37,7 @@ GemmaConfig7B, GemmaModel, ) +from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM from nemo.collections.llm.gpt.model.llama import ( CodeLlamaConfig7B, CodeLlamaConfig13B, @@ -166,4 +167,5 @@ "gpt_forward_step", "transformer_engine_layer_spec", "local_layer_spec", + "HfAutoModelForCausalLM", ] diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py new file mode 100644 index 000000000000..794c39738dbe --- /dev/null +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from transformers import AutoModelForCausalLM + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.lightning import io + + +def _extract_non_bias_params(model): + return list(map(lambda x: x[1], filter(lambda x: 'bias' not in x[0], model.named_parameters()))) + + +def masked_cross_entropy(logits, targets, mask=None): + if mask is not None: + loss = F.cross_entropy(logits, targets, reduction='none') + return torch.mean(loss[mask == 1]) + else: + return F.cross_entropy(logits, targets) + + +class HfAutoModelForCausalLM(pl.LightningModule, io.IOMixin): + def __init__(self, model_name='gpt2', load_pretrained_weights=True, tokenizer=None, loss_fn=masked_cross_entropy): + super().__init__() + self.save_hyperparameters() + self.model_name = model_name + self._tokenizer = None + self.model = None + self.loss_fn = loss_fn + self.load_pretrained_weights = load_pretrained_weights + self.is_hf_model = True + + @property + def tokenizer(self): + if self._tokenizer is None: + self._tokenizer = HfAutoModelForCausalLM.configure_tokenizer(self.model_name) + return self._tokenizer + + @tokenizer.setter + def tokenizer(self, value): + assert self._tokenizer is None + self._tokenizer = value + + @staticmethod + def configure_tokenizer(model_name): + return AutoTokenizer(model_name) + + def configure_model(self): + # Instantiate the HF model here, either from pretrained weights or from its config. + if self.load_pretrained_weights: + self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype='auto') + else: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(self.model_name) + self.model = AutoModelForCausalLM.from_config(config) + self.model.train() + + def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None): + outputs = self.model( + input_ids=input_ids.to(self.model.device), + attention_mask=attention_mask, + ) + labels = labels.to(self.model.device) + if loss_mask is not None: + loss_mask = loss_mask.to(self.model.device).view(-1) + n_cls = outputs.logits.shape[-1] + outputs.loss = self.loss_fn(outputs.logits.view(-1, n_cls), labels.view(-1), loss_mask) + return outputs + + def 
training_step(self, batch): + tokens = batch['tokens'] + labels = batch['labels'] + loss_mask = batch.get('loss_mask', None) + output = self.forward( + input_ids=tokens, + labels=labels, + loss_mask=loss_mask, + ) + + loss = output.loss + self.log('train_log', loss, on_step=True, on_epoch=True, prog_bar=True) + return loss + + def validation_step(self, batch, batch_idx): + tokens = batch['tokens'] + labels = batch['labels'] + output = self.forward( + input_ids=tokens, + labels=labels, + ) + + loss = output.loss + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True) diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index ff81c3b383fc..b1fc15aee07c 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -18,6 +18,7 @@ chatglm3_6b, gemma_2b, gemma_7b, + hf_auto_model_for_causal_lm, llama3_8b, llama3_8b_16k, llama3_8b_64k, @@ -73,6 +74,7 @@ "mamba2_hybrid_8b", "mistral_7b", "mistral_nemo_12b", + "hf_auto_model_for_causal_lm", "mixtral_8x7b", "mixtral_8x7b_16k", "mixtral_8x7b_64k", diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py new file mode 100644 index 000000000000..6c81bf922152 --- /dev/null +++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py @@ -0,0 +1,168 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing +from nemo.utils.exp_manager import TimingCallback + +NAME = "hf_auto_model_for_causal_lm" + + +@run.cli.factory(name=NAME) +def model(model_name) -> run.Config[pl.LightningModule]: + """ + Factory function to create HfAutoModelForCausalLM model configurations. + + Args: + model_name (str): Model id on HF. + + Returns: + run.Config[pl.LightningModule]: Configuration for the HfAutoModelForCausalLM. 
+ + Examples: + CLI usage: + $ nemo llm pretrain --factory 'HfAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' + + Python API usage: + >>> model_config = model(model_name="mistralai/Mistral-Nemo-Instruct-2407") + >>> print(model_config) + """ + return run.Config(HfAutoModelForCausalLM, model_name=model_name) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 2, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 100, + callbacks: Optional[list[run.Config[Callback]]] = None, + strategy: Optional[str] = 'ddp', + gradient_clip_val: float = 1.0, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for HfAutoModelForCausalLM. + + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + strategy (Optional[str]): Distributed training strategy, either 'ddp' or 'fsdp'. Defaults to 'ddp'. + gradient_clip_val (float): Gradient clipping value; set to None automatically when the strategy is 'fsdp'. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=HfAutoModelForCausalLM ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + """ + strategy = str(strategy).lower() + assert strategy in ['', 'ddp', 'fsdp'], strategy + if strategy == 'fsdp': + # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81 + gradient_clip_val = None + + trainer = run.Config( + nl.Trainer, + devices=num_gpus_per_node, + max_steps=max_steps, + accelerator='gpu', + strategy=strategy, + log_every_n_steps=1, + limit_val_batches=0.0, + num_sanity_val_steps=0, + accumulate_grad_batches=10, + callbacks=callbacks, + gradient_clip_val=gradient_clip_val, + use_distributed_sampler=False, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + fn=pretrain, + model_name: str = '', +) -> run.Partial: + """ + Create a pre-training recipe for a Hugging Face AutoModelForCausalLM model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + model_name (str): Hugging Face model id of the model to pre-train. + + Returns: + run.Partial: Partial configuration for pre-training. 
+ + Examples: + CLI usage: + $ nemo llm pretrain --factory 'HfAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' + + Python API usage: + >>> recipe = pretrain_recipe(name="auto_pretrain", num_nodes=2, model_name="mistralai/Mistral-Nemo-Instruct-2407") + >>> print(recipe) + """ + return run.Partial( + fn, + model=model(model_name), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=pytorch_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) diff --git a/nemo/collections/llm/recipes/optim/adam.py b/nemo/collections/llm/recipes/optim/adam.py index c6510577711d..4148d19c6635 100644 --- a/nemo/collections/llm/recipes/optim/adam.py +++ b/nemo/collections/llm/recipes/optim/adam.py @@ -17,7 +17,12 @@ import nemo_run as run from megatron.core.optimizer import OptimizerConfig -from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, MegatronOptimizerModule, OptimizerModule +from nemo.lightning.pytorch.optim import ( + CosineAnnealingScheduler, + MegatronOptimizerModule, + OptimizerModule, + PytorchOptimizerModule, +) @run.cli.factory @@ -59,3 +64,55 @@ def distributed_fused_adam_with_cosine_annealing( config=opt_cfg, lr_scheduler=sched, ) + + +@run.cli.factory +def pytorch_adam_with_cosine_annealing( + precision: str = "bf16-mixed", # or "16-mixed" + warmup_steps: int = 2000, + constant_steps: int = 0, + max_lr: float = 1e-5, + min_lr: Optional[float] = None, + clip_grad: float = 1.0, +) -> run.Config[OptimizerModule]: + from torch.optim import Adam + + return run.Config( + PytorchOptimizerModule, + optim_cls=Adam, + config=dict( + lr=max_lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-8, + ), + lr_scheduler=run.Config( + CosineAnnealingScheduler, + warmup_steps=warmup_steps, + constant_steps=constant_steps, + min_lr=min_lr or (0.1 * max_lr), + ), + ) + + +@run.cli.factory +def pytorch_adam_with_flat_lr( + precision: str = "bf16-mixed", # or "16-mixed" + warmup_steps: int = 2000, + constant_steps: int = 0, + max_lr: float = 1e-5, + min_lr: Optional[float] = None, + clip_grad: float = 1.0, +) -> run.Config[OptimizerModule]: + from torch.optim import Adam + + return run.Config( + PytorchOptimizerModule, + optim_cls=Adam, + config=dict( + lr=max_lr, + weight_decay=0.1, + betas=(0.9, 0.95), + eps=1e-8, + ), + ) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py index 17ffc01fb7f4..4ce9701e76b4 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/dataset_utils.py @@ -209,7 +209,7 @@ def create_masked_lm_predictions( # on-the-fly whole word masking is possible. token_boundary = [0] * len(tokens) skip_mask_idx = None # Store the index of token that cannot be masked. - for (i, token) in enumerate(tokens): + for i, token in enumerate(tokens): if token == skip_masking_id: skip_mask_idx = i if token == cls_id or token == sep_id: @@ -285,7 +285,10 @@ def create_masked_lm_predictions( available_ngrams = list(cand_index_set.keys()) # n - 1 because pvals is 0-indexed and available ngrams are 1-indexed. 
pvals_current = np.array([pvals[n - 1] for n in available_ngrams]) - n = np_rng.choice(available_ngrams, p=pvals_current / pvals_current.sum(keepdims=True),) + n = np_rng.choice( + available_ngrams, + p=pvals_current / pvals_current.sum(keepdims=True), + ) else: # Sampling "n" from the geometric distribution and clipping it to # the max_ngrams. Using p=0.2 default from the SpanBERT paper @@ -488,7 +491,10 @@ def create_extreme_masked_lm_predictions( if span_length_distribution == LengthDistribution.uniform: available_ngrams = list(cand_index_set.keys()) pvals_current = np.array([pvals[n] for n in available_ngrams]) - n = np_rng.choice(available_ngrams, p=pvals_current / pvals_current.sum(keepdims=True),) + n = np_rng.choice( + available_ngrams, + p=pvals_current / pvals_current.sum(keepdims=True), + ) elif span_length_distribution == LengthDistribution.geometric: # Sampling "n" from the geometric distribution and clipping it to # the max_ngrams. Using p=0.2 default from the SpanBERT paper @@ -914,7 +920,13 @@ def build_train_valid_test_datasets( seed, ) test_ds = MockT5Dataset( - cfg, tokenizer, "test", int(train_valid_test_num_samples[2]), max_seq_length, max_seq_length_dec, seed, + cfg, + tokenizer, + "test", + int(train_valid_test_num_samples[2]), + max_seq_length, + max_seq_length_dec, + seed, ) return train_ds, valid_ds, test_ds else: @@ -1257,6 +1269,7 @@ def get_samples_mapping( binary_head, index_mapping_dir: str = None, samples_mapping: Any = None, + sanity_check_dist_workers: bool = True, ): """Get a list that maps a sample index to a starting sentence index, end sentence index, and length""" @@ -1328,14 +1341,16 @@ def get_samples_mapping( logging.info( ' > elasped time to build and save samples mapping ' '(seconds): {:4f}'.format(time.time() - start_time) ) - torch.distributed.barrier() - counts = torch.cuda.LongTensor([1]) - torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True)) - torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group()) - assert counts[0].item() == ( - torch.distributed.get_world_size() - // torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group()) - ) + + if sanity_check_dist_workers: + torch.distributed.barrier() + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True)) + torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() + // torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group()) + ) # Load indexed dataset if not given externally. if samples_mapping is None: logging.info(' > loading indexed mapping from {}'.format(indexmap_filename)) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index c42249cec2f2..898ddb7d716b 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -64,6 +64,7 @@ def __init__( output_original_text: bool = False, ceil_to_power_2: bool = False, get_attention_mask_from_fusion: bool = False, + sanity_check_dist_workers: bool = True, ): """ file_path: Path to a JSONL GPT supervised fine-tuning dataset. 
Data is formatted as multiple JSON lines with each line formatted as follows. {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} @@ -89,6 +90,7 @@ def __init__( special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} is_test: Whether this dataset is the test split. output_original_text (bool): if true, will keep the original text in the output alongside the tokenized ids. + sanity_check_dist_workers (bool): if true, will run sanity check across workers when making mapping. """ self.tokenizer = tokenizer self.file_path = file_path @@ -117,6 +119,7 @@ def __init__( self.output_original_text = output_original_text self.ceil_to_power_2 = ceil_to_power_2 self.get_attention_mask_from_fusion = get_attention_mask_from_fusion + self.sanity_check_dist_workers = sanity_check_dist_workers if special_tokens is None: self.special_tokens = { @@ -196,6 +199,7 @@ def _build_samples_mapping(self): binary_head=False, index_mapping_dir=self.index_mapping_dir, samples_mapping=osm, + sanity_check_dist_workers=self.sanity_check_dist_workers, ) else: self.samples_mapping = None diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py index 0f30dfe22851..ea7d91b37214 100644 --- a/nemo/lightning/data.py +++ b/nemo/lightning/data.py @@ -139,6 +139,8 @@ def add_megatron_sampler( dataloader_type: Literal["single", "cyclic", "batch"] = "single", drop_last: bool = True, pad_samples_to_global_batch_size: bool = False, + rank: int = 0, + world_size: int = 1, # data_sharding: bool = False ) -> DataLoader: """ @@ -172,9 +174,6 @@ def add_megatron_sampler( Returns: DataLoader: A new DataLoader instance with the configured Megatron sampler. 
""" - - from megatron.core import parallel_state - if dataloader_type == 'single': batch_sampler = MegatronPretrainingSampler( total_samples=len(dataloader.dataset), @@ -182,8 +181,8 @@ def add_megatron_sampler( micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, rampup_batch_size=rampup_batch_size, - data_parallel_rank=parallel_state.get_data_parallel_rank(), - data_parallel_size=parallel_state.get_data_parallel_world_size(), + data_parallel_rank=rank, + data_parallel_size=world_size, drop_last=drop_last, pad_samples_to_global_batch_size=pad_samples_to_global_batch_size, ) @@ -192,8 +191,8 @@ def add_megatron_sampler( total_samples=len(dataloader.dataset), consumed_samples=consumed_samples, micro_batch_size=micro_batch_size, - data_parallel_rank=parallel_state.get_data_parallel_rank(), - data_parallel_size=parallel_state.get_data_parallel_world_size(), + data_parallel_rank=rank, + data_parallel_size=world_size, drop_last=drop_last, # data_sharding=data_sharding ) @@ -207,8 +206,8 @@ def add_megatron_sampler( consumed_samples=consumed_samples, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, - data_parallel_rank=parallel_state.get_data_parallel_rank(), - data_parallel_size=parallel_state.get_data_parallel_world_size(), + data_parallel_rank=rank, + data_parallel_size=world_size, drop_last=drop_last, pad_samples_to_global_batch_size=not drop_last, ) diff --git a/nemo/lightning/pytorch/optim/__init__.py b/nemo/lightning/pytorch/optim/__init__.py index 1572e95e136a..db40e5c48c1b 100644 --- a/nemo/lightning/pytorch/optim/__init__.py +++ b/nemo/lightning/pytorch/optim/__init__.py @@ -28,6 +28,7 @@ WarmupPolicyScheduler, ) from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.lightning.pytorch.optim.pytorch import PytorchOptimizerModule __all__ = [ "OptimizerModule", @@ -45,4 +46,5 @@ "PolynomialDecayAnnealingScheduler", "PolynomialHoldDecayAnnealingScheduler", "CosineAnnealingScheduler", + "PytorchOptimizerModule", ] diff --git a/nemo/lightning/pytorch/optim/pytorch.py b/nemo/lightning/pytorch/optim/pytorch.py new file mode 100644 index 000000000000..6600fc0cf0a4 --- /dev/null +++ b/nemo/lightning/pytorch/optim/pytorch.py @@ -0,0 +1,132 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List, Optional + +import pytorch_lightning as pl +from torch.optim import Optimizer + +from nemo.lightning.megatron_parallel import MegatronParallel +from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule + + +def _param_does_not_have_wd(param_name, param): + return 'bias' in param_name + + +class PytorchOptimizerModule(OptimizerModule): + """A OptimizerModule for pytorch optimizers. + + Attributes: + config (OptimizerConfig): Configuration for the optimizer. + no_weight_decay_cond (Optional[Callable]): Condition for no weight decay. + scale_lr_cond (Optional[Callable]): Condition for scaling learning rate. 
lr_mult (float): Learning rate multiplier. + + Example:: + + lr_scheduler = MyLRSchedulerModule(...) + optimizer_module = PytorchOptimizerModule(torch.optim.Adam, config={'lr': 3e-4}, lr_scheduler=lr_scheduler) + + Methods: + setup(model): Sets up the optimizer. + optimizers(model): Defines the optimizers. + """ + + def __init__( + self, + optim_cls, + config: dict = {'lr': 3e-4}, + lr_scheduler: Optional[LRSchedulerModule] = None, + no_weight_decay_cond: Optional[Callable] = _param_does_not_have_wd, + scale_lr_cond: Optional[Callable] = None, + lr_mult: float = 1.0, + ): + """Initializes the PytorchOptimizerModule. + + Args: + optim_cls: The optimizer class to instantiate, e.g. torch.optim.Adam. + config (dict): Keyword arguments passed to the optimizer constructor. + lr_scheduler (Optional[LRSchedulerModule]): The learning rate scheduler module. + no_weight_decay_cond (Optional[Callable]): Condition for no weight decay. + scale_lr_cond (Optional[Callable]): Condition for scaling learning rate. + lr_mult (float): Learning rate multiplier. + """ + + super().__init__(lr_scheduler=lr_scheduler) + self.optim_cls = optim_cls + self.config = config + self.no_weight_decay_cond = no_weight_decay_cond + self.scale_lr_cond = scale_lr_cond + self.lr_mult = lr_mult + + def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): + # Noop + pass + + def optimizers(self, model) -> List[Optimizer]: + """Defines the optimizers. + + Args: + model (nn.Module): The model for which the optimizers are being defined. + + Returns: + List[Optimizer]: The list of optimizers. + + Raises: + ValueError: If the model is an instance of MegatronParallel. + """ + + if isinstance(model, MegatronParallel): + raise ValueError("Model cannot be an instance of MegatronParallel") + + params_with_wd, params_without_wd = [], [] + if self.no_weight_decay_cond is not None: + for name, param in model.named_parameters(): + if self.no_weight_decay_cond(name, param): + params_without_wd.append(param) + else: + params_with_wd.append(param) + else: + params_with_wd = list(model.parameters()) + + optimizers = [] + if len(params_with_wd) > 0: + optimizers.append( + self.optim_cls( + params_with_wd, + **self.config, + ) + ) + + if len(params_without_wd) > 0: + # Use a copy of the optimizer kwargs with weight decay disabled for this group. + kwargs = dict(self.config) + kwargs['weight_decay'] = 0 + optimizers.append( + self.optim_cls( + params_without_wd, + **kwargs, + ) + ) + + assert len(optimizers) > 0, "Expected at least one optimizer with params" + return optimizers + + def finalize_model_grads(self, *args, **kwargs): + # Noop + pass diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index 55bafce5f71e..52ba9e3220ac 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -65,9 +65,14 @@ def setup(self, global_rank: int) -> None: setup_microbatch_calculator(global_rank, self.micro_batch_size, self.global_batch_size, self.rampup_batch_size) def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader: + from megatron.core import parallel_state + from nemo.lightning.data import add_megatron_sampler mode = getattr(dataloader, 'mode', 'train') + + data_parallel_rank = parallel_state.get_data_parallel_rank() + data_parallel_size = parallel_state.get_data_parallel_world_size() return add_megatron_sampler( dataloader, micro_batch_size=self.micro_batch_size, @@ -76,6 +81,8 @@ def transform_dataloader(self, dataloader: DataLoader, 
consumed_samples: int = 0 consumed_samples=self.init_consumed_samples if mode == 'train' else 0, dataloader_type=self.dataloader_type, drop_last=self.drop_last, + rank=data_parallel_rank, + world_size=data_parallel_size, ) def compute_consumed_samples(self, steps_since_resume=0) -> int: diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index c5195511c522..b045804044ec 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -267,6 +267,8 @@ def __init__( def connect(self, model: pl.LightningModule) -> None: super().connect(model) + assert not hasattr(model, 'is_hf_model'), "Cannot use HfAutoModelForCausalLM with MegatronParallel" + _maybe_mcore_config = _strategy_lib.set_model_parallel_attributes(model, self.parallelism) if _maybe_mcore_config: self._mcore_config = _maybe_mcore_config
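
Usage note: a minimal sketch of what the new HfDatasetDataModule.collate_fn (added in nemo/collections/llm/gpt/data/hf_dataset.py above) produces for a toy micro-batch. The sample dicts and pad_token_id=0 below are made-up illustration values, not taken from the patch.

# Each sample's 'tokens' and 'labels' lists are right-padded to the longest
# sequence in the micro-batch, then stacked into 2D LongTensors.
from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule

batch = [
    {'tokens': [5, 6, 7], 'labels': [6, 7, 8]},
    {'tokens': [9], 'labels': [10]},
]
out = HfDatasetDataModule.collate_fn(batch, pad_token_id=0)
print(out['tokens'].tolist())  # [[5, 6, 7], [9, 0, 0]]
print(out['labels'].tolist())  # [[6, 7, 8], [10, 0, 0]]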