diff --git a/Dockerfile.ci b/Dockerfile.ci
index ac36e6429475..dd8af593768f 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -32,9 +32,9 @@ EOF
 WORKDIR /workspace
 
 # Install NeMo requirements
-ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e
+ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
 ARG MODELOPT_VERSION=0.13.0
-ARG MCORE_TAG=0ab8dd4c7520408683fdb9f8ac119eff7d38fc0e
+ARG MCORE_TAG=0bc3547702464501feefeb5523b7a17e591b21fa
 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
 RUN \
   --mount=type=bind,source=requirements,target=requirements \
@@ -61,6 +61,22 @@ git checkout ${MCORE_TAG} && \
   popd && \
   popd
 export PYTHONPATH="${PYTHONPATH}:/workspace/Megatron-LM"
+
+# Mamba dependency installation
+git clone https://github.com/state-spaces/mamba.git && \
+  cd mamba && \
+  git checkout v2.0.3 && \
+  python setup.py install && \
+  cd .. && \
+  rm -rf mamba
+
+git clone https://github.com/Dao-AILab/causal-conv1d && \
+  cd causal-conv1d && \
+  git checkout v1.2.2.post1 && \
+  python setup.py install && \
+  cd .. && \
+  rm -rf causal-conv1d
+
 EOF
 
 # Copy over NeMo code
diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml
index 3684b61bb186..33498540a3d5 100644
--- a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml
+++ b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml
@@ -48,119 +48,38 @@ exp_manager:
 
 model:
-  restore_from_path: null
-  # model parallelism
-  mcore_gpt: True
-  micro_batch_size: 1
-  global_batch_size: 8
-  tensor_model_parallel_size: 1
-  pipeline_model_parallel_size: 1
-  virtual_pipeline_model_parallel_size: null
-  expert_model_parallel_size: 1 # expert model parallelism
-
-  vocab_size: 65536
-  # model architecture
-  encoder_seq_length: 4096
-  hybrid_override_pattern: null
-  max_position_embeddings: ${.encoder_seq_length}
-  position_embedding_type: 'none' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental.
-  num_layers: 64
-  gated_linear_unit: False
-  add_bias_linear: False
-  num_query_groups: 8
-  ngroups_mamba: 8
-  attention_dropout: 0.0
-  hidden_dropout: 0.0
-  hidden_size: 4096
-  ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size.
-  num_attention_heads: 32
-  transformer_block_type: pre_ln
-  init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
-  kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
-  apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
-  normalization: RMSNorm
-  layernorm_epsilon: 1e-5
-  num_moe_experts: 16
-  moe_router_topk: 2
-  moe_aux_loss_coeff: 0.001
-  make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
-  pre_process: True # add embedding
-  post_process: True # add pooler
-  megatron_legacy: False
-  persist_layer_norm: True
-
-
-  # mixed-precision
-  attention_softmax_in_fp32: False
-
-  # Distributed checkpoint setup
-  dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
-  dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
-  dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
-
-
-  tokenizer:
-    library: 'huggingface'
-    type: 'EleutherAI/gpt-neox-20b'
-    model: null
-    vocab_file: null
-    merge_file: null
-    sentencepiece_legacy: False
-    use_fast: True
-
-  # precision
-  native_amp_init_scale: 4294967296 # 2 ** 32
-  native_amp_growth_interval: 1000
-  fp32_residual_connection: False # Move residual connections to fp32
-  fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16
-
-  # Megatron O2-style half-precision
-  megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
-  grad_allreduce_chunk_size_mb: 125
-
-  # Fusion
-  grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism..
-  gradient_accumulation_fusion: True # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2.
-  bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
-  bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
-  masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask.
-  get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.
-  apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope
-
-  # miscellaneous
   seed: 1234
-  use_cpu_initialization: False # Init weights on the CPU (slow for large models)
-  onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
-  gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
-
-  ## Activation Checkpointing
-  # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
-  # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
+  tensor_model_parallel_size: 1 # intra-layer model parallelism
+  pipeline_model_parallel_size: 1 # inter-layer model parallelism
+
+  encoder_seq_length: 1024
+  global_batch_size: 8
+  micro_batch_size: 1
+  restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with
+  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
+  save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training.
+  sync_batch_comm: False
+  megatron_amp_O2: False
+
+  ## Sequence Parallelism
+  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
   # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
-  # 'full' will checkpoint the entire transformer layer.
-  activations_checkpoint_granularity: null # 'selective' or 'full'
-  activations_checkpoint_method: null # 'uniform', 'block'
+  sequence_parallel: False
+
+  ## Activation Checkpoint
+  activations_checkpoint_granularity: null # 'selective' or 'full'
+  activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
   # 'uniform' divides the total number of transformer layers and checkpoints the input activation
-  # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
+  # of each chunk at the specified granularity
   # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
-  activations_checkpoint_num_layers: null
-  # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
-  # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
-  num_micro_batches_with_partial_activation_checkpoints: null
-  # This feature is valid only when used with pipeline-model-parallelism.
-  # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
-  # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is
-  # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
-  # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'.
-  # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage.
+  activations_checkpoint_num_layers: null # not used with 'selective'
   activations_checkpoint_layers_per_pipeline: null
-  # This feature is valid only when used with pipeline-model-parallelism.
-  # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
-  # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than
-  # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage
-  # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
-  # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path.
-  sequence_parallel: False
+  answer_only_loss: True
+  gradient_as_bucket_view: False
+
+  hidden_dropout: 0.0
+  attention_dropout: 0.0
+  ffn_dropout: 0.0
 
   peft:
     peft_scheme: "lora" # can be either adapter,ia3, lora, or ptuning
diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml
index 2d34aefffc7e..fddfa16c8c09 100644
--- a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml
+++ b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml
@@ -39,113 +39,39 @@ exp_manager:
     model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
 
 model:
-  restore_from_path: null
-  # model parallelism
-  mcore_gpt: True
-  micro_batch_size: 2
-  global_batch_size: 2
-  tensor_model_parallel_size: 1
-  pipeline_model_parallel_size: 1
-  virtual_pipeline_model_parallel_size: null
-  expert_model_parallel_size: 1 # expert model parallelism
-  hybrid_override_pattern: null
-  vocab_size: 65536
-  # model architecture
-  encoder_seq_length: 4096
-  max_position_embeddings: ${.encoder_seq_length}
-  position_embedding_type: 'none' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental.
-  num_layers: 64
-  gated_linear_unit: False
-  num_query_groups: 8
-  ngroups_mamba: 8
-  attention_dropout: 0.0
-  hidden_dropout: 0.0
-  hidden_size: 4096
-  ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size.
-  num_attention_heads: 32
-  transformer_block_type: pre_ln
-  init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
-  kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
-  apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
-  normalization: RMSNorm
-  layernorm_epsilon: 1e-5
-  num_moe_experts: 16
-  moe_router_topk: 2
-  moe_aux_loss_coeff: 0.001
-  make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
-  pre_process: True # add embedding
-  post_process: True # add pooler
-  megatron_legacy: False
-  persist_layer_norm: True
-  add_bias_linear: False
-
-  answer_only_loss: True
-
-  tokenizer:
-    library: 'huggingface'
-    type: 'EleutherAI/gpt-neox-20b'
-    model: null
-    vocab_file: null
-    merge_file: null
-    sentencepiece_legacy: False
-    use_fast: True
-
-
-  # precision
-  native_amp_init_scale: 4294967296 # 2 ** 32
-  native_amp_growth_interval: 1000
-  fp32_residual_connection: False # Move residual connections to fp32
-  fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16
-
-  # Megatron O2-style half-precision
-  megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
-  grad_allreduce_chunk_size_mb: 125
-
-  # Fusion
-  grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce. Only used with O2 and no pipeline parallelism..
-  gradient_accumulation_fusion: True # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism and O2.
-  bias_activation_fusion: False # Use a kernel that fuses the bias addition from weight matrices with the subsequent activation function.
-  bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.
-  masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask.
-  get_attention_mask_from_fusion: True # When using fused softmax it will create the attention mask so we won't copy it to the pipeline stages.
-  apply_rope_fusion: True # Use a kernel to add rotary positional embeddings. Only used if position_embedding_type=rope
-
-
-  # miscellaneous
   seed: 1234
-  use_cpu_initialization: False # Init weights on the CPU (slow for large models)
-  onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
-  gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
-
-  ## Activation Checkpointing
-  # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
-  # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
+  tensor_model_parallel_size: 1 # intra-layer model parallelism
+  pipeline_model_parallel_size: 1 # inter-layer model parallelism
+
+  encoder_seq_length: 1024
+  global_batch_size: 8
+  micro_batch_size: 1
+  restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with
+  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
+  save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training.
+  sync_batch_comm: False
+  megatron_amp_O2: False
+
+  ## Sequence Parallelism
+  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
   # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
-  # 'full' will checkpoint the entire transformer layer.
-  activations_checkpoint_granularity: null # 'selective' or 'full'
-  activations_checkpoint_recurrent: False # If set to True, the checkpointing is only done for rglru and conv1d and not for attention and mlp layers
-  activations_checkpoint_method: null # 'uniform', 'block'
+  sequence_parallel: False
+
+  ## Activation Checkpoint
+  activations_checkpoint_granularity: null # 'selective' or 'full'
+  activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
   # 'uniform' divides the total number of transformer layers and checkpoints the input activation
-  # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
+  # of each chunk at the specified granularity
   # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
-  activations_checkpoint_num_layers: null
-  # when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
-  # when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
-  num_micro_batches_with_partial_activation_checkpoints: null
-  # This feature is valid only when used with pipeline-model-parallelism.
-  # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
-  # and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is
-  # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
-  # per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'.
-  # This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage.
+  activations_checkpoint_num_layers: null # not used with 'selective'
   activations_checkpoint_layers_per_pipeline: null
-  # This feature is valid only when used with pipeline-model-parallelism.
-  # When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
-  # pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than
-  # stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage
-  # uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
-  # this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path.
-  sequence_parallel: False
+  answer_only_loss: True
+  gradient_as_bucket_view: False
+
+  hidden_dropout: 0.0
+  attention_dropout: 0.0
+  ffn_dropout: 0.0
+
   peft:
     peft_scheme: null # can be either adapter,ia3, lora, or ptuning
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py
index fb8a04b947b0..5180bd12b35e 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py
@@ -13,9 +13,8 @@
 # limitations under the License.
 
 import torch
-
-# from megatron.core.models.mamba import MambaModel
-# from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
+from megatron.core.models.mamba import MambaModel
+from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
 from omegaconf.dictconfig import DictConfig
 from pytorch_lightning.trainer.trainer import Trainer
 
@@ -46,16 +45,15 @@ def model_provider_func(self, pre_process, post_process):
         self.transformer_config.layernorm_epsilon = self.cfg.get('layernorm_epsilon', 1e-5)
 
         # TODO @ataghibakhsh: add mamba_ssm_ngroups=self.cfg.get('mamba_ssm_ngroups', 8) once MLM MR merged
-        # TODO @ataghibakhsh: add the following
-        '''MambaModel(
+
+        model = MambaModel(
             config=self.transformer_config,
             max_sequence_length=self.cfg.get('encoder_seq_length', 4096),
             vocab_size=self.cfg.get('vocab_size', 65536),
+            mamba_ssm_ngroups=self.cfg.get('mamba_ssm_ngroups', 8),
             mamba_stack_spec=mamba_stack_spec,
             hybrid_override_pattern=self.hybrid_override_pattern,
-        )'''
-        # after package mismatch is resovled
-        model = None
+        )
 
         return model
diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
index 34ca175470ab..45f4af3cfbf3 100644
--- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
+++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -127,14 +127,15 @@ def _check_and_add_adapter(self, name, module, peft_name, peft_cfg, name_key_to_
                     f'model.{mcore_target}',
                     f'model.module.{mcore_target}',
                 ]:  # simple string match for now
-                    swap_mcore_mixin(module, mcore_mixin)
-                    if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types():
-                        module.add_adapter(
-                            name=peft_name,
-                            cfg=peft_cfg,
-                            base_model_cfg=self.cfg,
-                            model_parallel_config=self.model_parallel_config,
-                        )
+                    if not isinstance(module, IdentityOp):
+                        swap_mcore_mixin(module, mcore_mixin)
+                        if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types():
+                            module.add_adapter(
+                                name=peft_name,
+                                cfg=peft_cfg,
+                                base_model_cfg=self.cfg,
+                                model_parallel_config=self.model_parallel_config,
+                            )
         elif isinstance(module, AdapterModuleMixin):
             if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types():
                 module.add_adapter(
diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt
index d006ccb7ad65..a1dad5b64a8a 100644
--- a/requirements/requirements_nlp.txt
+++ b/requirements/requirements_nlp.txt
@@ -1,6 +1,5 @@
 accelerated-scan
 boto3
-causal-conv1d==1.2.0.post2
 einops
 faiss-cpu
 fasttext
@@ -10,7 +9,6 @@ gdown
 h5py
 ijson
 jieba
-mamba-ssm==1.2.0.post1
 markdown2
 matplotlib>=3.3.2
 #megatron_core>0.6.0 # add back once mcore on pypi is compatible again
diff --git a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
index 9a44f9c2c5c4..9dfd9565179d 100644
--- a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
@@ -29,8 +29,9 @@
 CUDA_VISIBLE_DEVICES="0" python /NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
     --input_name_or_path \
     --output_path \
-    --ngroups_mamba 8 \
-    --precision bf16
+    --mamba_ssm_ngroups 8 \
+    --precision bf16 \
+    --tokenizer_model_dir
 '''
@@ -49,17 +50,20 @@ def get_args():
         type=str,
         required=True,
     )
-    parser.add_argument("--ngroups_mamba", type=int, default=8, help="ngroups for Mamba model")
+    parser.add_argument("--mamba_ssm_ngroups", type=int, default=8, help="ngroups for Mamba model")
     parser.add_argument(
         "--precision", type=str, default="bf16", choices=["bf16", "32"], help="Precision for checkpoint weights saved"
     )
+    parser.add_argument(
+        "--tokenizer_model_dir", type=str, default=None, help="Path to the tokenizer.model, required for 8b models"
+    )
 
     args = parser.parse_args()
     return args
 
 
 def convert(args):
-    checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu')['model']
+    checkpoint_weights = torch.load(args.input_name_or_path, map_location='cpu')
     new_state_dict = {}
 
     if 'backbone' in list(checkpoint_weights.keys())[0]:
@@ -95,6 +99,11 @@ def convert(args):
             old_key = f'backbone.layers.{i}.{attr}'
             new_state_dict[new_key] = checkpoint_weights[old_key]
 
+        # Tokenizer settings
+        tokenizer_library = 'huggingface'
+        tokenizer_type = 'EleutherAI/gpt-neox-20b'
+        tokenizer_model = None
+
     else:
 
         layer_keys = [key for key in checkpoint_weights.keys() if re.match(r'decoder\.layers\.\d+\.', key)]
@@ -103,6 +112,11 @@ def convert(args):
 
         new_state_dict = {"model." + key: value for key, value in checkpoint_weights.items()}
 
+        # Tokenizer settings
+        tokenizer_library = 'megatron'
+        tokenizer_type = 'GPTSentencePieceTokenizer'
+        tokenizer_model = args.tokenizer_model_dir
+
     layers = defaultdict(list)
 
     for key in new_state_dict.keys():
@@ -131,7 +145,10 @@ def convert(args):
     ].shape
     nemo_config.model.num_layers = num_layers
     nemo_config.model.hybrid_override_pattern = layer_pattern
-    nemo_config.model.ngroups_mamba = args.ngroups_mamba
+    nemo_config.model.mamba_ssm_ngroups = args.mamba_ssm_ngroups
+    nemo_config.model.tokenizer.library = tokenizer_library
+    nemo_config.model.tokenizer.type = tokenizer_type
+    nemo_config.model.tokenizer.model = tokenizer_model
 
     if "-" in layer_pattern:
         nemo_config.model.ffn_hidden_size = new_state_dict[