Commit
fix layer replacement (#11356)
Signed-off-by: Onur Yilmaz <[email protected]>
oyilmaz-nvidia authored Nov 21, 2024
1 parent 5e5fc4a commit dabd47f
Showing 1 changed file with 6 additions and 4 deletions.
nemo/lightning/pytorch/accelerate/transformer_engine.py (6 additions, 4 deletions)
@@ -39,7 +39,7 @@ def te_accelerate(model, fp8_autocast=False):
 
 @torch.no_grad
 def _apply_basic_module_replacement(model):
-    for name, module in model.named_modules():
+    for name, module in model.named_children():
         if isinstance(module, torch.nn.Linear):
             has_bias = module.bias is not None
             if any(p % 16 != 0 for p in module.weight.shape):
@@ -51,17 +51,19 @@ def _apply_basic_module_replacement(model):
                 if has_bias:
                     te_module.bias.copy_(module.bias)
 
-            setattr(module, name.split(".")[-1], te_module)
+            setattr(model, name, te_module)
         elif isinstance(module, torch.nn.LayerNorm):
             te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
             te_module.weight.copy_(module.weight)
             te_module.bias.copy_(module.bias)
-            setattr(module, name.split(".")[-1], te_module)
+            setattr(model, name, te_module)
         elif isinstance(module, torch.nn.RMSNorm):
             te_module = te.RMSNorm(module.normalized_shape[0], eps=module.eps, dtype=module.weight.dtype)
             te_module.weight.copy_(module.weight)
             te_module.bias.copy_(module.bias)
-            setattr(module, name.split(".")[-1], te_module)
+            setattr(model, name, te_module)
+        else:
+            _apply_basic_module_replacement(module)
 
 
 def is_te_accelerated(model):
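
Why the change works: model.named_modules() yields every submodule with its fully qualified dotted name (e.g. "decoder.proj"), so the old setattr(module, name.split(".")[-1], te_module) assigned the replacement onto the layer being replaced rather than onto its parent, and the model itself was never actually modified. Iterating model.named_children() and recursing swaps each layer directly on its own parent. Below is a minimal sketch of that pattern using plain torch.nn only; replace_linears_with and the Identity factory are illustrative stand-ins, not NeMo or Transformer Engine APIs.

import torch


@torch.no_grad()
def replace_linears_with(model: torch.nn.Module, factory):
    """Recursively swap every nn.Linear child for factory(child)."""
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Linear):
            # Replace on the direct parent, mirroring setattr(model, name, te_module).
            setattr(model, name, factory(child))
        else:
            # Recurse into non-Linear submodules, mirroring the new else branch.
            replace_linears_with(child, factory)


if __name__ == "__main__":
    net = torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()),
        torch.nn.Linear(16, 16),
    )
    # Stand-in factory; the real code builds a te.Linear and copies weights/bias.
    replace_linears_with(net, lambda m: torch.nn.Identity())
    print(net)  # Linear layers at every depth are now Identity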
