diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 7fc6cc708c31..4cd7edde2e3d 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -5225,8 +5225,6 @@ jobs:
           --pp_size 1 \
           --mbs 1

-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}

   L2_NeMo_2_GPT_SFT_TP1PP1_MBS2:
     needs: [cicd-test-container-setup]
@@ -5256,8 +5254,6 @@ jobs:
           --pp_size 1 \
           --mbs 2

-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}

   L2_NeMo_2_GPT_SFT_TP1PP2_MBS2:
     needs: [cicd-test-container-setup]
@@ -5287,8 +5283,6 @@ jobs:
           --pp_size 2 \
           --mbs 2

-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}

   L2_NeMo_2_GPT_SFT_TP2PP1_MBS2:
     needs: [cicd-test-container-setup]
@@ -5318,8 +5312,35 @@ jobs:
           --pp_size 1 \
           --mbs 2

-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
+
+  L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+
+        python tests/collections/llm/gpt_finetuning.py \
+          --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+          --devices 2 \
+          --max_steps 3 \
+          --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
+          --peft none \
+          --tp_size 1 \
+          --pp_size 1 \
+          --mbs 1 --packed
+
+        python tests/collections/llm/gpt_finetuning.py \
+          --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+          --devices 2 \
+          --max_steps 6 \
+          --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
+          --peft none \
+          --tp_size 1 \
+          --pp_size 1 \
+          --mbs 1 --packed
+

   L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1:
     needs: [cicd-test-container-setup]
@@ -5349,8 +5370,6 @@ jobs:
           --pp_size 1 \
           --mbs 1

-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}

   L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2:
     needs: [cicd-test-container-setup]
@@ -5380,8 +5399,6 @@ jobs:
           --pp_size 1 \
           --mbs 2

-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}

   L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2:
     needs: [cicd-test-container-setup]
@@ -5411,8 +5428,6 @@ jobs:
           --pp_size 2 \
           --mbs 2

-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}

   L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2:
     needs: [cicd-test-container-setup]
@@ -5442,8 +5457,33 @@ jobs:
           --pp_size 1 \
           --mbs 2

-      AFTER_SCRIPT: |
-        rm -rf /tmp/nemo2_gpt_finetune/${{ github.run_id }}
+  L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+
+        python tests/collections/llm/gpt_finetuning.py \
+          --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+          --devices 2 \
+          --max_steps 3 \
+          --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
+          --peft lora \
+          --tp_size 1 \
+          --pp_size 1 \
+          --mbs 1 --packed
+
+        python tests/collections/llm/gpt_finetuning.py \
+          --restore_path /home/TestData/nemo2_ckpt/llama_68M \
+          --devices 2 \
+          --max_steps 6 \
+          --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
+          --peft lora \
+          --tp_size 1 \
+          --pp_size 1 \
+          --mbs 1 --packed

   L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact:
     needs: [cicd-test-container-setup]
@@ -5597,10 +5637,12 @@ jobs:
       - L2_NeMo_2_GPT_SFT_TP1PP1_MBS2
       - L2_NeMo_2_GPT_SFT_TP1PP2_MBS2
       - L2_NeMo_2_GPT_SFT_TP2PP1_MBS2
+      - L2_NeMo_2_GPT_SFT_TP1PP1_MBS1_PACKED
       - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1
       - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2
       - L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2
       - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2
+      - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED
       - L2_NeMo_2_Mixtral_Pretraining
       - L2_PTQ_Llama2_INT8_SQ
       - L2_PTQ_Llama2_FP8
diff --git a/nemo/collections/llm/gpt/data/dolly.py b/nemo/collections/llm/gpt/data/dolly.py
index 78751d60cdb0..fb8cf9fd5da0 100644
--- a/nemo/collections/llm/gpt/data/dolly.py
+++ b/nemo/collections/llm/gpt/data/dolly.py
@@ -26,6 +26,7 @@

 if TYPE_CHECKING:
     from nemo.collections.common.tokenizers import TokenizerSpec
+    from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


 class DollyDataModule(FineTuningDataModule, IOMixin):
@@ -56,7 +57,7 @@ def __init__(
         pin_memory: bool = True,
         persistent_workers: bool = False,
         pad_to_max_length: bool = False,
-        packed_sequence_size: int = -1,
+        packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
     ):
         self.force_redownload = force_redownload
         self.delete_raw = delete_raw
@@ -74,7 +75,7 @@ def __init__(
             pin_memory=pin_memory,
             persistent_workers=persistent_workers,
             pad_to_max_length=pad_to_max_length,
-            packed_sequence_size=packed_sequence_size,
+            packed_sequence_specs=packed_sequence_specs,
         )

     def prepare_data(self) -> None:
diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py
index 3e4dba7ec89c..01cf617a094d 100644
--- a/nemo/collections/llm/gpt/data/fine_tuning.py
+++ b/nemo/collections/llm/gpt/data/fine_tuning.py
@@ -20,12 +20,14 @@
 import pytorch_lightning as pl
 from torch.utils.data import DataLoader

+from nemo.collections.common.tokenizers import AutoTokenizer
 from nemo.collections.llm.gpt.data.core import create_sft_dataset
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
 from nemo.utils import logging

 if TYPE_CHECKING:
     from nemo.collections.common.tokenizers import TokenizerSpec
+    from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


 class FineTuningDataModule(pl.LightningDataModule):
@@ -50,10 +52,7 @@ class FineTuningDataModule(pl.LightningDataModule):
         persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. Defaults to False.
         max_train_steps (int, optional): Maximum number of steps to train. Used to calculate samples mapping for the mmap dataset
         pad_to_max_length (bool, optional): Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch.
-        packed_sequence_size (int, optional): If a positive integer, this arg enables training with sequence packing and specifies the pack size
-            If less than or equal to 0, sequence packing is disabled. Defaults to -1.
-            Note: This arg is distinct from `seq_length` because `seq_length` specifies the maximum length of the original sequence
-            (i.e. the length to truncate long sequences in the input data).
+        packed_sequence_specs (PackedSequenceSpecs, optional): See PackedSequenceSpecs for details
     """

     def __init__(
@@ -70,7 +69,7 @@ def __init__(
         pin_memory: bool = True,
         persistent_workers: bool = False,
         pad_to_max_length: bool = False,
-        packed_sequence_size: int = -1,
+        packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
     ):
         super().__init__()
         self.seq_length = seq_length
@@ -87,22 +86,21 @@ def __init__(
         self.data_sampler = None
         self.max_train_samples = None
         self.pad_to_max_length = pad_to_max_length
-        self.packed_sequence_size = packed_sequence_size
-        self._adjust_batch_sizes_for_packed_sequence()
+        self.packed_sequence_specs = packed_sequence_specs
+        self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
+        self.validate_batch_size_for_packed_sequence()

-    def _adjust_batch_sizes_for_packed_sequence(self):
+    def validate_batch_size_for_packed_sequence(self):
         if self.packed_sequence_size > 0 and self.micro_batch_size > 1:
-            logging.warning(
+            raise ValueError(
                 "Micro batch size should be 1 when training with packed sequence, but your micro batch size "
-                f"is {self.micro_batch_size}. Your config will be automatically updated to the following: "
-                f"MBS will be set to 1 (from {self.micro_batch_size}), "
-                f"GBS will be set to {self.global_batch_size // self.micro_batch_size} (from {self.global_batch_size}), "
-                f"packed sequence length will be set to {self.packed_sequence_size*self.micro_batch_size} (from {self.packed_sequence_size}). "
+                f"is {self.micro_batch_size}. \nThe following config is equivalent to your current setting for "
+                f"a packed dataset. Please update your config to the following: \n"
+                f"Set micro batch size to 1 (currently {self.micro_batch_size})\n"
+                f"Set global batch size to {self.global_batch_size // self.micro_batch_size} (currently {self.global_batch_size}) \n"
+                f"Set packed sequence length to {self.packed_sequence_size*self.micro_batch_size} (currently {self.packed_sequence_size}) \n"
                 f"For details please visit https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/sequence_packing.html"
             )
-            self.global_batch_size //= self.micro_batch_size
-            self.packed_sequence_size *= self.micro_batch_size
-            self.micro_batch_size = 1

     def prepare_data(self) -> None:
         if self.packed_sequence_size > 0 and not self.train_path_packed.is_file():
@@ -187,7 +185,12 @@ def train_path(self) -> Path:
     @property
     def train_path_packed(self) -> Path:
         if self.packed_sequence_size > 0:
-            return self.dataset_root / f"training_packed{self.packed_sequence_size}.npy"
+            if self.packed_sequence_specs.packed_data_path is not None:
+                return self.packed_sequence_specs.packed_data_path
+            tokenizer_model_name = self._extract_tokenizer_model_name()
+            folder_name = self.dataset_root / "packed" / tokenizer_model_name
+            folder_name.mkdir(parents=True, exist_ok=True)
+            return folder_name / f"training_{self.packed_sequence_size}.npy"
         else:
             raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.")

@@ -198,3 +201,18 @@ def validation_path(self) -> Path:
     @property
     def test_path(self) -> Path:
         return self.dataset_root / "test.jsonl"
+
+    def _extract_tokenizer_model_name(self) -> str:
+        if self.packed_sequence_specs.tokenizer_model_name is not None:
+            tokenizer_model_name = self.packed_sequence_specs.tokenizer_model_name
+        elif isinstance(self.tokenizer, AutoTokenizer):
+            name = self.tokenizer.tokenizer.name_or_path
+            if name.endswith("nemo_tokenizer"):
+                # NEMO_HOME/hf_org/hf_model/nemo_tokenizer => hf_org--hf_model
+                tokenizer_model_name = '--'.join(name.split("/")[-3:-1])
+            else:
+                # hf_org/hf_model => hf_org--hf_model
+                tokenizer_model_name = name.replace("/", "--")
+        else:
+            tokenizer_model_name = f"unknown_tokenizer_{hash(self.tokenizer)}"
+        return tokenizer_model_name
diff --git a/nemo/collections/llm/gpt/data/packed_sequence.py b/nemo/collections/llm/gpt/data/packed_sequence.py
index 4675b3fbb398..372e851da7cd 100644
--- a/nemo/collections/llm/gpt/data/packed_sequence.py
+++ b/nemo/collections/llm/gpt/data/packed_sequence.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional

@@ -83,3 +83,32 @@ def prepare_packed_sequence_data(
     # save output data
     np.save(output_path, output_data)
     logging.info(f"Packed sequence is prepared and saved to {output_path}")
+
+
+@dataclass
+class PackedSequenceSpecs:
+    packed_sequence_size: int = -1
+    """
+    If a positive integer, this arg enables training with sequence packing and specifies the pack size
+    If less than or equal to 0, sequence packing is disabled. Defaults to -1.
+    Note: This arg is distinct from `seq_length` because `seq_length` specifies the maximum length of the original sequence
+    (i.e. the length to truncate long sequences in the input data).
+    """
+
+    tokenizer_model_name: str = None
+    """
+    Keep track of tokenizer model name, since each tokenizer produces a different packed sequence dataset file.
+    This field is set by llm.finetune api.
+    """
+
+    packed_data_path: Path = None
+    """
+    If specified, use the packed dataset from this file instead of the default path.
+ """ + + def __post_init__(self): + if self.packed_data_path is not None: + assert ( + self.packed_data_path.suffix == ".npy" + ), f"packed data file must be a .npy file: {self.packed_data_path}" + assert self.packed_data_path.exists(), f"packed data file does not exist: {self.packed_data_path}" diff --git a/nemo/collections/llm/gpt/data/squad.py b/nemo/collections/llm/gpt/data/squad.py index ec0fc1aad02c..f872db94077d 100644 --- a/nemo/collections/llm/gpt/data/squad.py +++ b/nemo/collections/llm/gpt/data/squad.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from nemo.collections.common.tokenizers import TokenizerSpec + from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs class SquadDataModule(FineTuningDataModule, IOMixin): @@ -54,7 +55,7 @@ def __init__( pin_memory: bool = True, persistent_workers: bool = False, pad_to_max_length: bool = False, - packed_sequence_size: int = -1, + packed_sequence_specs: Optional["PackedSequenceSpecs"] = None, ): self.force_redownload = force_redownload self.delete_raw = delete_raw @@ -72,7 +73,7 @@ def __init__( pin_memory=pin_memory, persistent_workers=persistent_workers, pad_to_max_length=pad_to_max_length, - packed_sequence_size=packed_sequence_size, + packed_sequence_specs=packed_sequence_specs, ) def prepare_data(self) -> None: diff --git a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py index dc4fb8ececc5..f62613db891b 100644 --- a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py @@ -127,7 +127,7 @@ def __init__( index_mapping_dir=index_mapping_dir, ) - if is_distributed: + if is_distributed and not _lightning_prepare_data(): torch.distributed.barrier() if is_distributed and AppState().local_rank == 0: @@ -152,7 +152,7 @@ def __init__( index_mapping_dir=index_mapping_dir, ) - if is_distributed: + if is_distributed and not _lightning_prepare_data(): torch.distributed.barrier() logging.info(f"Loading data files") @@ -749,3 +749,19 @@ def get_sample_block(self, block_idx: int) -> np.ndarray: sample_block = sample_block % self.dataset_size return sample_block + + +def _lightning_prepare_data(): + """ + This function checks whether it is invoked in lightning's hook "prepare_data", which is run only on rank 0. + TextMemMapDataset contains a torch.distributed.barrier operation, so when run inside the single-process hook + prepare_data, the barrier operation would hang forever. + """ + import inspect + + return any( + [ + frame.function == 'prepare_data' and 'prepare_packed_sequence_data' in frame.code_context[0] + for frame in inspect.stack() + ] + ) diff --git a/tests/collections/llm/gpt_finetuning.py b/tests/collections/llm/gpt_finetuning.py index 9eca287669cd..7eaa7744729c 100644 --- a/tests/collections/llm/gpt_finetuning.py +++ b/tests/collections/llm/gpt_finetuning.py @@ -19,6 +19,7 @@ from nemo import lightning as nl from nemo.collections import llm +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer ## NOTE: This script is present for github-actions testing only. 
@@ -43,6 +44,7 @@ def get_args():
     parser.add_argument('--mbs', type=int, default=1, help="micro batch size")
     parser.add_argument('--tp_size', type=int, default=1, help="tensor parallel size")
    parser.add_argument('--pp_size', type=int, default=1, help="pipeline parallel size")
+    parser.add_argument('--packed', action='store_true', help="use packed sequence dataset")

     return parser.parse_args()

@@ -97,7 +99,16 @@ def get_args():
     else:
         peft = None

-    squad = llm.SquadDataModule(seq_length=2048, micro_batch_size=args.mbs, global_batch_size=8, num_workers=0)
+    packed_sequence_specs = (
+        PackedSequenceSpecs(packed_sequence_size=2048, tokenizer_model_name="dummy_tokenizer") if args.packed else None
+    )
+    dolly = llm.DollyDataModule(
+        seq_length=2048,
+        micro_batch_size=args.mbs,
+        global_batch_size=8,
+        num_workers=0,
+        packed_sequence_specs=packed_sequence_specs,
+    )

     tokenizer = get_nmt_tokenizer(tokenizer_model=os.path.join(args.restore_path, "dummy_tokenizer.model"))
     llama3_8b = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer)
@@ -109,7 +120,7 @@ def get_args():

     llm.finetune(
         model=llama3_8b,
-        data=squad,
+        data=dolly,
         trainer=trainer,
         peft=peft,
         log=logger,
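
Below is a minimal usage sketch (not part of the patch) of the new `packed_sequence_specs` argument introduced above. It mirrors `tests/collections/llm/gpt_finetuning.py`; the concrete sizes and tokenizer name are illustrative.

```python
# Illustrative sketch: enabling sequence packing on a fine-tuning data module
# via the new PackedSequenceSpecs API added in this PR.
from nemo.collections import llm
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs

# Pack size of 2048 tokens; micro batch size must be 1, otherwise
# FineTuningDataModule.validate_batch_size_for_packed_sequence raises ValueError.
packed_specs = PackedSequenceSpecs(
    packed_sequence_size=2048,
    tokenizer_model_name="dummy_tokenizer",  # names the cache subfolder under <dataset_root>/packed/
)

data = llm.DollyDataModule(
    seq_length=2048,
    micro_batch_size=1,
    global_batch_size=8,
    num_workers=0,
    packed_sequence_specs=packed_specs,  # pass None to keep the unpacked code path
)
```

Passing `packed_sequence_specs=None` keeps the unpacked behavior, so callers that never used the old `packed_sequence_size` keyword are unaffected.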
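The new `validate_batch_size_for_packed_sequence` check replaces the previous silent config rewrite with a `ValueError` that tells the user what equivalent settings to adopt. A hypothetical helper (not in the patch) that computes exactly the configuration the error message describes:

```python
# Hypothetical helper: derive the packed-sequence config suggested by the
# ValueError when micro_batch_size > 1 is combined with packing.
def equivalent_packed_config(micro_batch_size: int, global_batch_size: int, packed_sequence_size: int):
    """Return (mbs, gbs, pack_size) with the micro batch folded into the pack size."""
    return (
        1,                                        # micro batch size must be 1 for packed sequences
        global_batch_size // micro_batch_size,    # keeps the number of packs per step unchanged
        packed_sequence_size * micro_batch_size,  # fold the micro batch into a larger pack
    )

# Example: mbs=2, gbs=8, pack=2048 -> (1, 4, 4096)
assert equivalent_packed_config(2, 8, 2048) == (1, 4, 4096)
```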
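`_extract_tokenizer_model_name` keys the packed-dataset cache directory (`<dataset_root>/packed/<tokenizer_model_name>/`) off the tokenizer, since each tokenizer produces a different packed file. A standalone sketch of that name mapping, under the assumption that the tokenizer's `name_or_path` is either a Hugging Face id or a NeMo checkpoint path ending in `nemo_tokenizer`:

```python
# Standalone sketch mirroring the mapping used by _extract_tokenizer_model_name.
def tokenizer_folder_name(name_or_path: str) -> str:
    if name_or_path.endswith("nemo_tokenizer"):
        # NEMO_HOME/hf_org/hf_model/nemo_tokenizer => hf_org--hf_model
        return "--".join(name_or_path.split("/")[-3:-1])
    # hf_org/hf_model => hf_org--hf_model
    return name_or_path.replace("/", "--")

# Example inputs are illustrative.
assert tokenizer_folder_name("meta-llama/Meta-Llama-3-8B") == "meta-llama--Meta-Llama-3-8B"
assert tokenizer_folder_name("/ckpts/meta-llama/Meta-Llama-3-8B/nemo_tokenizer") == "meta-llama--Meta-Llama-3-8B"
```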