Enabling bias_dropout_add_fused with no bias term
ashvinnihalani committed Jul 10, 2024
1 parent 14d42dc commit cd68086
Showing 2 changed files with 15 additions and 4 deletions.
@@ -47,13 +47,24 @@ def bias_dropout_add_fused_train_(
     # type: (Tensor, Tensor, Tensor, float) -> Tensor
     return bias_dropout_add(x, bias, residual, prob, True)
 
+@torch.jit.script
+def dropout_add_fused_train_(
+    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
+) -> torch.Tensor:
+    # type: (Tensor, None, Tensor, float) -> Tensor
+    return dropout_add(x, bias, residual, prob, True)
 
 def bias_dropout_add_fused_train(x, bias, residual, prob):
     # re-enable torch grad to enable fused optimization.
     with torch.enable_grad():
-        args = _cast_if_autocast_enabled(x, bias, residual, prob)
-        with torch.cuda.amp.autocast(enabled=False):
-            return bias_dropout_add_fused_train_(*args)
+        if bias:
+            args = _cast_if_autocast_enabled(x, bias, residual, prob)
+            with torch.cuda.amp.autocast(enabled=False):
+                return bias_dropout_add_fused_train_(*args)
+        else:
+            args = _cast_if_autocast_enabled(x, residual, prob)
+            with torch.cuda.amp.autocast(enabled=False):
+                return dropout_add_fused_train_(*args)
 
 
 @torch.jit.script
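The new `dropout_add_fused_train_` wrapper scripts a bias-free `dropout_add` helper that is defined elsewhere in this file and is not shown in the diff. For reference, here is a minimal sketch of what such helpers conventionally look like in Megatron-style code; this is an assumption about their shape, not the file's actual contents:

```python
import torch

# Hedged sketch (not taken from this commit): conventional forms of the
# elementwise helpers that the @torch.jit.script wrappers above fuse.
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor,
                     prob: float, training: bool) -> torch.Tensor:
    # dropout(x + bias) followed by the residual connection
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out


def dropout_add(x: torch.Tensor, bias, residual: torch.Tensor,
                prob: float, training: bool) -> torch.Tensor:
    # Bias-free variant: `bias` stays in the signature so both scripted
    # wrappers share a calling convention, but it is expected to be None
    # and plays no part in the computation.
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    return residual + out
```

Keeping the unused `bias` argument matches the `# type: (Tensor, None, Tensor, float) -> Tensor` annotation on the new scripted wrapper, which documents that callers are expected to pass `None` on this path.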
@@ -453,7 +453,7 @@ def _get_bias_droput_add_func(self, transformer_block_type='pre_ln', position_af
         if transformer_block_type == 'normformer' and position_after == 'attention':
             bias_dropout_add_func = get_dropout_add(self.training)
         # Bias dropout add fused kernel
-        elif self.bias and self.bias_dropout_add_fusion:
+        elif self.bias_dropout_add_fusion:
             if self.training:
                 bias_dropout_add_func = bias_dropout_add_fused_train
             else:
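Taken together, the two changes mean a layer configured with `bias=False` but `bias_dropout_add_fusion=True` now selects the fused function instead of falling through to the unfused path, with the missing bias handled inside `bias_dropout_add_fused_train` itself. Below is a rough, self-contained model of the resulting training-time behavior; it is a simplified stand-in rather than the repository's code, and it uses an explicit `bias is not None` check for the dispatch:

```python
import torch

def fused_train_dispatch(x, bias, residual, prob):
    # Simplified stand-in for the post-commit bias_dropout_add_fused_train:
    # with a bias tensor, compute dropout(x + bias) + residual; with no bias,
    # skip the addition and compute dropout(x) + residual.
    if bias is not None:
        out = torch.nn.functional.dropout(x + bias, p=prob, training=True)
    else:
        out = torch.nn.functional.dropout(x, p=prob, training=True)
    return residual + out

x = torch.randn(2, 4)
residual = torch.randn(2, 4)
with_bias = fused_train_dispatch(x, torch.randn(4), residual, prob=0.1)
without_bias = fused_train_dispatch(x, None, residual, prob=0.1)
```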
