New mcore transformer block spec (#8925)
* update package info (#8793)

Signed-off-by: eharper <[email protected]>

* update mcore (#8917)

Signed-off-by: Jan Baczek <[email protected]>

* Use new mcore transformer block config handling

Signed-off-by: Jan Baczek <[email protected]>

* API fixes

Signed-off-by: Jan Baczek <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert changes to CI and Dockerfile

Signed-off-by: Jan Baczek <[email protected]>

---------

Signed-off-by: eharper <[email protected]>
Signed-off-by: Jan Baczek <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Apr 24, 2024
1 parent e9bcaf3 commit 8f95646
Showing 2 changed files with 14 additions and 11 deletions.
First changed file (path not shown on this page):

@@ -35,7 +35,9 @@
 try:
     from megatron.core import parallel_state, tensor_parallel
+    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
     from megatron.core.transformer.spec_utils import ModuleSpec
+    from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build
     from megatron.core.transformer.transformer_layer import BaseTransformerLayer
     from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

@@ -322,8 +324,10 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), meta


 # Use this spec to use the full Transformer layer from Transformer Engine
-def get_gpt_full_te_layer_autocast_spec() -> ModuleSpec:
+def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec:
     if not HAVE_MEGATRON_CORE or not HAVE_TE:
         raise ImportError(IMPORT_ERROR)
-
-    return ModuleSpec(module=TETransformerLayerAutocast)
+    num_layers = get_num_layers_to_build(transformer_config)
+    return TransformerBlockSubmodules(
+        layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers, layer_norm=FusedLayerNorm
+    )
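For readers of the hunks above: the helper now returns a TransformerBlockSubmodules covering the whole block, with one ModuleSpec per layer built on this pipeline rank plus the block's final FusedLayerNorm, instead of a single ModuleSpec. The sketch below only illustrates the shape of that return value; the TransformerConfig values and the placeholder layer class are assumptions made for the example, not taken from the commit.

from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig


class TETransformerLayerAutocast:
    """Stand-in for NeMo's Transformer Engine autocast layer used in the real spec."""


# Hypothetical config; only num_layers matters for the shape of the spec.
config = TransformerConfig(num_layers=12, hidden_size=768, num_attention_heads=12)

# The real helper calls get_num_layers_to_build(config), which splits
# config.num_layers across pipeline-parallel ranks and needs an initialized
# parallel_state; without pipeline parallelism it equals config.num_layers.
num_layers = config.num_layers

block_spec = TransformerBlockSubmodules(
    layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers,
    layer_norm=FusedLayerNorm,  # final norm applied after the stacked layers
)

Handing Megatron-core a block-level spec is also why the first hunk imports FusedLayerNorm: the final norm is now carried by the spec rather than by the layer itself.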
Second changed file (path not shown on this page):

@@ -135,7 +135,11 @@ def mcore_supports_moe() -> bool:
         return False


-def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True):
+def get_specs(spec_name, transformer_config=None, use_te=True):
+    # else cases for backwards compatibility with neva
+    num_experts = transformer_config.num_moe_experts if transformer_config else None
+    moe_grouped_gemm = transformer_config.moe_grouped_gemm if transformer_config else False
+
     if num_experts is not None:
         assert mcore_supports_moe(), "Megatron-core >= v0.5.0 is required for MoE"

@@ -145,7 +149,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True):
         "": get_gpt_layer_local_spec(num_experts, moe_grouped_gemm),
         "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm),
         "megatron_falcon_gpt": get_falcon_layer_spec(),
-        "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(),
+        "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(transformer_config),
         "ammo": get_gpt_layer_ammo_spec(),
     }
     if spec_name not in name_spec_dict:

@@ -391,12 +395,7 @@ def model_provider_func(self, pre_process, post_process):
         if self.mcore_gpt:
             model = MCoreGPTModel(
                 config=self.transformer_config,
-                transformer_layer_spec=get_specs(
-                    self.spec_name,
-                    self.transformer_config.num_moe_experts,
-                    self.transformer_config.moe_grouped_gemm,
-                    self.transformer_engine,
-                ),
+                transformer_layer_spec=get_specs(self.spec_name, self.transformer_config, self.transformer_engine,),
                 vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
                 max_sequence_length=self.cfg.get('encoder_seq_length', 512),
                 pre_process=pre_process,
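The get_specs change above replaces the explicit num_experts / moe_grouped_gemm parameters with a single transformer_config, and the two fallback lines keep older call sites working when no config is passed (the in-diff comment cites neva). A self-contained sketch of just that fallback behaviour, with types.SimpleNamespace standing in for the real TransformerConfig:

from types import SimpleNamespace


def derive_moe_args(transformer_config=None):
    # Mirrors the two fallback lines added to get_specs in this commit.
    num_experts = transformer_config.num_moe_experts if transformer_config else None
    moe_grouped_gemm = transformer_config.moe_grouped_gemm if transformer_config else False
    return num_experts, moe_grouped_gemm


# With a config (stand-in object here), MoE settings come from its fields.
cfg = SimpleNamespace(num_moe_experts=8, moe_grouped_gemm=True)
print(derive_moe_args(cfg))   # -> (8, True)

# Without a config, behaviour matches the old defaults.
print(derive_moe_args(None))  # -> (None, False)

Because the check is plain truthiness, any non-None config object is used; only callers that omit the argument fall back to the former defaults.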
