From e6940ed8cbd7ce3699b3d2e530dcf7bf900ed1d3 Mon Sep 17 00:00:00 2001 From: ashvinnihalani Date: Wed, 10 Jul 2024 21:19:33 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: ashvinnihalani --- .../nlp/modules/common/megatron/fused_bias_dropout_add.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py index 80a3533d6194b..a2efa8192fc40 100644 --- a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py +++ b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py @@ -47,13 +47,13 @@ def bias_dropout_add_fused_train_( # type: (Tensor, Tensor, Tensor, float) -> Tensor return bias_dropout_add(x, bias, residual, prob, True) + @torch.jit.script -def dropout_add_fused_train_( - x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float -) -> torch.Tensor: +def dropout_add_fused_train_(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: # type: (Tensor, None, Tensor, float) -> Tensor return dropout_add(x, bias, residual, prob, True) + def bias_dropout_add_fused_train(x, bias, residual, prob): # re-enable torch grad to enable fused optimization. with torch.enable_grad():