diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
index 6cce2b42be9c..f3299d488fd0 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py
@@ -35,7 +35,9 @@ try:
     from megatron.core import parallel_state, tensor_parallel
+    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
     from megatron.core.transformer.spec_utils import ModuleSpec
+    from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build
     from megatron.core.transformer.transformer_layer import BaseTransformerLayer
     from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -322,8 +324,10 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), meta
 # Use this spec to use the full Transformer layer from Transformer Engine
-def get_gpt_full_te_layer_autocast_spec() -> ModuleSpec:
+def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec:
     if not HAVE_MEGATRON_CORE or not HAVE_TE:
         raise ImportError(IMPORT_ERROR)
-
-    return ModuleSpec(module=TETransformerLayerAutocast)
+    num_layers = get_num_layers_to_build(transformer_config)
+    return TransformerBlockSubmodules(
+        layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers, layer_norm=FusedLayerNorm
+    )
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 997235e639d2..de9620e2d79f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -138,7 +138,11 @@ def mcore_supports_moe() -> bool:
 ## TODO: This function will not work if TE is not installed
-def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True, hyena_cfg: Dict = None):
+def get_specs(spec_name, transformer_config=None, use_te=True, hyena_cfg: Dict = None):
+    # else cases for backwards compatibility with neva
+    num_experts = transformer_config.num_moe_experts if transformer_config else None
+    moe_grouped_gemm = transformer_config.moe_grouped_gemm if transformer_config else False
+
     if num_experts is not None:
         assert mcore_supports_moe(), "Megatron-core >= v0.5.0 is required for MoE"
@@ -148,7 +152,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True,
         "": get_gpt_layer_local_spec(num_experts, moe_grouped_gemm),
         "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm),
         "megatron_falcon_gpt": get_falcon_layer_spec(),
-        "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(),
+        "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(transformer_config),
         "modelopt": get_gpt_layer_modelopt_spec(num_experts),
         "te_gpt_hyena": get_gpt_layer_with_te_and_hyena_spec(hyena_cfg),
     }
@@ -415,8 +419,7 @@ def model_provider_func(self, pre_process, post_process):
                 config=self.transformer_config,
                 transformer_layer_spec=get_specs(
                     self.spec_name,
-                    self.transformer_config.num_moe_experts,
-                    self.transformer_config.moe_grouped_gemm,
+                    self.transformer_config,
                     self.transformer_engine,
                     self.cfg.get('hyena', None),
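
As a minimal sketch (not part of the patch), the snippet below illustrates the structure that get_gpt_full_te_layer_autocast_spec now produces: a TransformerBlockSubmodules with one TETransformerLayerAutocast spec per layer to build and FusedLayerNorm as the block's final norm. It assumes megatron-core and Transformer Engine are installed; the fixed num_layers = 12 is a hypothetical stand-in for get_num_layers_to_build(transformer_config), which depends on the pipeline-parallel layout and so is not called here.

# Sketch only: mirrors the construction in the patched spec function, with a
# hypothetical layer count instead of get_num_layers_to_build(transformer_config).
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules

from nemo.collections.nlp.models.language_modeling.megatron.gpt_full_te_layer_autocast_spec import (
    TETransformerLayerAutocast,
)

num_layers = 12  # hypothetical; normally derived from the pipeline-parallel split
spec = TransformerBlockSubmodules(
    layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers,
    layer_norm=FusedLayerNorm,
)

# Each locally built layer gets the full TE autocast layer spec, so TransformerBlock
# can construct the stack without a separate per-layer spec lookup.
assert len(spec.layer_specs) == num_layers
assert all(s.module is TETransformerLayerAutocast for s in spec.layer_specs)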