New mcore transformer block spec (#8925)
* update package info (#8793)

Signed-off-by: eharper <[email protected]>

* update mcore (#8917)

Signed-off-by: Jan Baczek <[email protected]>

* Use new mcore transformer block config handling

Signed-off-by: Jan Baczek <[email protected]>

* API fixes

Signed-off-by: Jan Baczek <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert changes to CI and Dockerfile

Signed-off-by: Jan Baczek <[email protected]>

---------

Signed-off-by: eharper <[email protected]>
Signed-off-by: Jan Baczek <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Jan Baczek <[email protected]>
4 people committed Jul 19, 2024
1 parent ab8988e commit a7ae71d
Showing 2 changed files with 14 additions and 7 deletions.
@@ -35,7 +35,9 @@

try:
    from megatron.core import parallel_state, tensor_parallel
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
    from megatron.core.transformer.spec_utils import ModuleSpec
    from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build
    from megatron.core.transformer.transformer_layer import BaseTransformerLayer
    from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
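For context, the imports in the hunk above sit inside NeMo's usual optional-dependency guard. The sketch below is an assumption about that surrounding pattern, not the verbatim file: the HAVE_MEGATRON_CORE flag is referenced later in this diff, and ApexGuardDefaults (named in the CodeQL note further down) is replaced here by a simplified stand-in.

class ApexGuardDefaults:
    # Simplified stand-in for NeMo's placeholder class that is swapped in when an
    # optional dependency is missing. Its __init__ takes no keyword arguments,
    # which is what the CodeQL finding below is reacting to.
    def __init__(self):
        super().__init__()


try:
    from megatron.core import parallel_state, tensor_parallel
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
    from megatron.core.transformer.spec_utils import ModuleSpec
    from megatron.core.transformer.transformer_block import (
        TransformerBlockSubmodules,
        get_num_layers_to_build,
    )
    from megatron.core.transformer.transformer_layer import BaseTransformerLayer
    from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    # Assumed fallback: bind the missing names to the placeholder so the module
    # still imports, and record that megatron.core is unavailable.
    ModuleSpec = FusedLayerNorm = TransformerBlockSubmodules = ApexGuardDefaults
    HAVE_MEGATRON_CORE = False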

@@ -322,8 +324,10 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), meta


# Use this spec to use the full Transformer layer from Transformer Engine
def get_gpt_full_te_layer_autocast_spec() -> ModuleSpec:
def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec:
    if not HAVE_MEGATRON_CORE or not HAVE_TE:
        raise ImportError(IMPORT_ERROR)

    return ModuleSpec(module=TETransformerLayerAutocast)
    num_layers = get_num_layers_to_build(transformer_config)
    return TransformerBlockSubmodules(
        layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers, layer_norm=FusedLayerNorm

    )

Check failure (Code scanning / CodeQL): Wrong name for an argument in a class instantiation (Error). Keyword argument 'module' is not a supported parameter name of ApexGuardDefaults.__init__.
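Assembled from the hunk above for readability, the updated builder now returns a block-level spec sized to the current pipeline stage. This is a sketch, not the verbatim file: IMPORT_ERROR, HAVE_TE, HAVE_MEGATRON_CORE and TETransformerLayerAutocast are defined elsewhere in this file and are not reproduced here.

def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec:
    # Spec for the full Transformer Engine autocast layer, now built per block.
    if not HAVE_MEGATRON_CORE or not HAVE_TE:
        raise ImportError(IMPORT_ERROR)
    # Build one layer spec per layer that this pipeline stage actually owns.
    num_layers = get_num_layers_to_build(transformer_config)
    # Note: the return annotation is unchanged from the diff even though the
    # function now returns a TransformerBlockSubmodules rather than a ModuleSpec.
    return TransformerBlockSubmodules(
        layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers,
        layer_norm=FusedLayerNorm,
    )

The CodeQL finding above appears to be a static false positive tied to the import fallback: when megatron.core is missing, ModuleSpec resolves to the ApexGuardDefaults placeholder, whose __init__ accepts no module keyword, but that branch is unreachable at runtime because the HAVE_MEGATRON_CORE check raises ImportError first.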
@@ -138,7 +138,11 @@ def mcore_supports_moe() -> bool:


## TODO: This function will not work if TE is not installed
def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True, hyena_cfg: Dict = None):
def get_specs(spec_name, transformer_config=None, use_te=True, hyena_cfg: Dict = None):
    # else cases for backwards compatibility with neva
    num_experts = transformer_config.num_moe_experts if transformer_config else None
    moe_grouped_gemm = transformer_config.moe_grouped_gemm if transformer_config else False

    if num_experts is not None:
        assert mcore_supports_moe(), "Megatron-core >= v0.5.0 is required for MoE"

@@ -148,7 +152,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True,
"": get_gpt_layer_local_spec(num_experts, moe_grouped_gemm),
"te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm),
"megatron_falcon_gpt": get_falcon_layer_spec(),
"megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(),
"megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(transformer_config),
"modelopt": get_gpt_layer_modelopt_spec(num_experts),
"te_gpt_hyena": get_gpt_layer_with_te_and_hyena_spec(hyena_cfg),
}
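As a usage sketch of the new signature (the variable name below is illustrative, not taken from the diff), a caller now hands the whole transformer config to get_specs and lets it derive the MoE settings:

# Hypothetical call: `my_transformer_config` stands in for a megatron.core
# TransformerConfig carrying num_moe_experts and moe_grouped_gemm, which
# get_specs now reads itself instead of taking them as separate arguments.
layer_spec = get_specs(
    "megatron_gpt_full_te_layer_autocast",  # spec name from the table above
    my_transformer_config,
    use_te=True,
)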
@@ -415,8 +419,7 @@ def model_provider_func(self, pre_process, post_process):
            config=self.transformer_config,
            transformer_layer_spec=get_specs(
                self.spec_name,
                self.transformer_config.num_moe_experts,
                self.transformer_config.moe_grouped_gemm,
                self.transformer_config,
                self.transformer_engine,
                self.cfg.get('hyena', None),
            ),
