diff --git a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
index 80a3533d6194b..a2efa8192fc40 100644
--- a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
+++ b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
@@ -47,13 +47,13 @@ def bias_dropout_add_fused_train_(
     # type: (Tensor, Tensor, Tensor, float) -> Tensor
     return bias_dropout_add(x, bias, residual, prob, True)
 
+
 @torch.jit.script
-def dropout_add_fused_train_(
-    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
-) -> torch.Tensor:
+def dropout_add_fused_train_(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor:
     # type: (Tensor, None, Tensor, float) -> Tensor
     return dropout_add(x, bias, residual, prob, True)
 
+
 def bias_dropout_add_fused_train(x, bias, residual, prob):
     # re-enable torch grad to enable fused optimization.
     with torch.enable_grad():
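
For context on this hunk: the two scripted wrappers differ only in whether a bias term is present. Below is a minimal sketch of the `bias_dropout_add` / `dropout_add` helpers they call, assuming the conventional bias-dropout-add residual pattern; the bodies are illustrative, not copied from the file.

    import torch
    from typing import Optional

    def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor,
                         prob: float, training: bool) -> torch.Tensor:
        # Assumed behavior: add the bias, apply dropout, then the residual connection.
        out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
        return residual + out

    def dropout_add(x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor,
                    prob: float, training: bool) -> torch.Tensor:
        # Assumed behavior: bias-free variant; the "(Tensor, None, Tensor, float)"
        # type comment in the diff indicates bias is always None on this path.
        assert bias is None
        out = torch.nn.functional.dropout(x, p=prob, training=training)
        return residual + out

Keeping these as free functions of tensors is what makes them eligible for `@torch.jit.script`, letting TorchScript fuse the dropout and adds; per the source comment, `bias_dropout_add_fused_train` re-enables grad around the call so the fused optimization still applies.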