From 6bf0f92eae1d43d85baaa751b5d52a39a93cef37 Mon Sep 17 00:00:00 2001
From: Ashvin Nihalani
Date: Wed, 10 Jul 2024 21:08:03 +0000
Subject: [PATCH 1/2] Enabling bias_dropout_add_fused with no bias term

Signed-off-by: Ashvin Nihalani
---
 .../common/megatron/fused_bias_dropout_add.py   | 17 ++++++++++++++---
 .../nlp/modules/common/megatron/transformer.py  |  2 +-
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
index d42f14d03e05..80a3533d6194 100644
--- a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
+++ b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
@@ -47,13 +47,24 @@ def bias_dropout_add_fused_train_(
     # type: (Tensor, Tensor, Tensor, float) -> Tensor
     return bias_dropout_add(x, bias, residual, prob, True)
 
+@torch.jit.script
+def dropout_add_fused_train_(
+    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
+) -> torch.Tensor:
+    # type: (Tensor, None, Tensor, float) -> Tensor
+    return dropout_add(x, bias, residual, prob, True)
 
 def bias_dropout_add_fused_train(x, bias, residual, prob):
     # re-enable torch grad to enable fused optimization.
     with torch.enable_grad():
-        args = _cast_if_autocast_enabled(x, bias, residual, prob)
-        with torch.cuda.amp.autocast(enabled=False):
-            return bias_dropout_add_fused_train_(*args)
+        if bias:
+            args = _cast_if_autocast_enabled(x, bias, residual, prob)
+            with torch.cuda.amp.autocast(enabled=False):
+                return bias_dropout_add_fused_train_(*args)
+        else:
+            args = _cast_if_autocast_enabled(x, residual, prob)
+            with torch.cuda.amp.autocast(enabled=False):
+                return dropout_add_fused_train_(*args)
 
 
 @torch.jit.script
diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index cb23c4a6b1fd..3c884d3057b5 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -453,7 +453,7 @@ def _get_bias_droput_add_func(self, transformer_block_type='pre_ln', position_af
         if transformer_block_type == 'normformer' and position_after == 'attention':
             bias_dropout_add_func = get_dropout_add(self.training)
         # Bias dropout add fused kernel
-        elif self.bias and self.bias_dropout_add_fusion:
+        elif self.bias_dropout_add_fusion:
             if self.training:
                 bias_dropout_add_func = bias_dropout_add_fused_train
             else:

From 021cebea8fe6d12e4af828e84d6eb1ac83b7d4b9 Mon Sep 17 00:00:00 2001
From: ashvinnihalani
Date: Wed, 10 Jul 2024 21:32:00 +0000
Subject: [PATCH 2/2] Apply isort and black reformatting

Signed-off-by: ashvinnihalani
---
 .../nlp/modules/common/megatron/fused_bias_dropout_add.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
index 80a3533d6194..a2efa8192fc4 100644
--- a/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
+++ b/nemo/collections/nlp/modules/common/megatron/fused_bias_dropout_add.py
@@ -47,13 +47,13 @@ def bias_dropout_add_fused_train_(
     # type: (Tensor, Tensor, Tensor, float) -> Tensor
     return bias_dropout_add(x, bias, residual, prob, True)
 
+
 @torch.jit.script
-def dropout_add_fused_train_(
-    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
-) -> torch.Tensor:
+def dropout_add_fused_train_(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor:
     # type: (Tensor, None, Tensor, float) -> Tensor
     return dropout_add(x, bias, residual, prob, True)
 
+
 def bias_dropout_add_fused_train(x, bias, residual, prob):
     # re-enable torch grad to enable fused optimization.
     with torch.enable_grad():
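
Note: as a quick reference for the behaviour these patches enable, the sketch below shows one way a bias-optional dropout-add dispatch can be written. It is an illustrative, self-contained example, not the merged NeMo code: the helper names mirror those in the patch, but the signatures are simplified (the patch's dropout_add is still called with a bias argument, and the real code routes arguments through _cast_if_autocast_enabled and runs the fused paths under torch.jit.script, which this sketch omits).

# Minimal sketch (assumption: simplified helper signatures, not NeMo's actual API)
from typing import Optional

import torch


def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Add the projection bias, apply dropout, then add the residual stream.
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out


def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    # Bias-free variant: dropout on x, then add the residual.
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    return residual + out


def dropout_add_dispatch(
    x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float, training: bool = True
) -> torch.Tensor:
    # Testing `bias is not None` (rather than truth-testing the tensor itself)
    # avoids the "Boolean value of Tensor with more than one element is
    # ambiguous" error that `if bias:` raises for multi-element tensors.
    if bias is not None:
        return bias_dropout_add(x, bias, residual, prob, training)
    return dropout_add(x, residual, prob, training)


if __name__ == "__main__":
    x = torch.randn(4, 8)
    residual = torch.randn(4, 8)
    with_bias = dropout_add_dispatch(x, torch.zeros(8), residual, prob=0.1)
    without_bias = dropout_add_dispatch(x, None, residual, prob=0.1)
    print(with_bias.shape, without_bias.shape)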