Skip to content
This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Permalink
Implement attention with bias
Browse files Browse the repository at this point in the history
  • Loading branch information
steve.an committed Oct 5, 2023
1 parent cf768d2 commit 1b547d0
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 16 deletions.
48 changes: 43 additions & 5 deletions tests/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@


@pytest.mark.parametrize(
"num_batches, num_heads, y_size, x_size, is_causal", [(4, 8, 128, 64, True), (4, 8, 128, 64, False)]
"num_batches, num_heads, y_size, x_size, is_causal",
[(4, 8, 128, 64, True), (4, 8, 256, 32, False)],
)
def test_forward(num_batches, num_heads, y_size, x_size, is_causal, device):
query = torch.randn(num_batches, num_heads, y_size, x_size, device=device)
Expand All @@ -32,9 +33,17 @@ def test_forward(num_batches, num_heads, y_size, x_size, is_causal, device):
trident.function.scaled_dot_product_attention(query, key, value, is_causal=is_causal),
)

mask = torch.randn(num_batches, num_heads, y_size, y_size, device=device)

assert util.equal(
torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=mask),
trident.function.scaled_dot_product_attention(query, key, value, attn_mask=mask),
)


@pytest.mark.parametrize(
"num_batches, num_heads, y_size, x_size, is_causal", [(4, 8, 128, 64, True), (4, 8, 128, 64, False)]
"num_batches, num_heads, y_size, x_size, is_causal",
[(4, 8, 128, 64, True), (4, 8, 256, 32, False)],
)
def test_backward(num_batches, num_heads, y_size, x_size, is_causal, device):
query = torch.rand(num_batches, num_heads, y_size, x_size, device=device)
Expand All @@ -47,6 +56,7 @@ def train(func):
j = key.clone()
k = value.clone()
i.requires_grad = j.requires_grad = k.requires_grad = True

func(i, j, k, is_causal=is_causal).backward(grad_output, retain_graph=True)
return i.grad, j.grad, k.grad

Expand All @@ -57,18 +67,42 @@ def train(func):
assert util.equal(y, b)
assert util.equal(z, c)

mask = torch.randn(num_batches, num_heads, y_size, y_size, device=device)

def train(func):
i = query.clone()
j = key.clone()
k = value.clone()
l = mask.clone()
i.requires_grad = j.requires_grad = k.requires_grad = l.requires_grad = True

func(i, j, k, attn_mask=l).backward(grad_output, retain_graph=True)
return i.grad, j.grad, k.grad, l.grad

(w, x, y, z) = train(torch.nn.functional.scaled_dot_product_attention)
(a, b, c, d) = train(trident.function.scaled_dot_product_attention)

assert util.equal(w, a)
assert util.equal(x, b)
assert util.equal(y, c)
assert util.equal(z, d)


@pytest.mark.parametrize(
"num_batches, num_heads, y_size, x_size, is_causal", [(1, 1, 1, 16, True), (1, 1, 1, 16, False)]
"num_batches, num_heads, y_size, x_size, is_causal, with_bias",
[(1, 1, 1, 16, False, False), (1, 1, 1, 16, True, False), (1, 1, 1, 16, False, True)],
)
def test_attention(num_batches, num_heads, y_size, x_size, is_causal, device, dtype):
def test_attention(num_batches, num_heads, y_size, x_size, is_causal, with_bias, device, dtype):
factory_kwargs = {"device": device, "dtype": dtype}
query = torch.rand(num_batches, num_heads, y_size, x_size, **factory_kwargs, requires_grad=True)
key = torch.randn_like(query, requires_grad=True)
value = torch.randn_like(query, requires_grad=True)
attn_mask = (
torch.randn(num_batches, num_heads, y_size, y_size, **factory_kwargs, requires_grad=True) if with_bias else None
)
grad_output = torch.randn_like(query)

output = trident.function.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
output = trident.function.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, is_causal=is_causal)

assert output is not None
assert output.dtype == dtype
Expand All @@ -81,3 +115,7 @@ def test_attention(num_batches, num_heads, y_size, x_size, is_causal, device, dt
assert key.grad.dtype == dtype
assert value.grad is not None
assert value.grad.dtype == dtype

if attn_mask is not None:
assert attn_mask.grad is not None
assert attn_mask.grad.dtype == dtype
6 changes: 2 additions & 4 deletions trident/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor = None,
dropout_p: float = 0.0,
is_causal: bool = False,
use_accelerator: bool = False,
Expand All @@ -238,11 +239,8 @@ def scaled_dot_product_attention(
Computes scaled dot product attention on query, key and value tensors,
and applying dropout if a probability greater than 0.0 is specified.
"""
if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
raise ValueError("The dimension of query, key and value should be 4.")

return operation.Attention.apply(
query, key, value, dropout_p, is_causal, 1.0 / math.sqrt(key.shape[-1]), use_accelerator
query, key, value, attn_mask, dropout_p, is_causal, 1.0 / math.sqrt(key.shape[-1]), use_accelerator
)


Expand Down
56 changes: 54 additions & 2 deletions trident/kernel/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def forward(
head_stride: tl.int32,
y_stride: tl.int32,
x_stride: tl.int32,
mask_ptr: tl.tensor,
mask_head_stride: tl.int32,
mask_y_stride: tl.int32,
mask_x_stride: tl.int32,
dropout_p: tl.float32,
seed: tl.int32,
is_causal: tl.constexpr,
Expand Down Expand Up @@ -98,6 +102,16 @@ def forward(
order=(1, 0),
)

if mask_ptr is not None:
mask_block_ptr = tl.make_block_ptr(
mask_ptr + head * mask_head_stride,
shape=(y_size, y_size),
strides=(mask_y_stride, mask_x_stride),
offsets=(y_offset, 0),
block_shape=(y_block_size, y_block_size),
order=(1, 0),
)

query = tl.load(query_block_ptr)
score_scale = (softmax_scale * language.log2e).to(dtype)
query *= score_scale
Expand All @@ -118,6 +132,10 @@ def forward(
n_offsets = tl.arange(0, y_block_size) + n_offset
condition = m_offsets[:, None] >= n_offsets[None, :]
score = tl.where(condition, score, float("-inf"))
elif mask_ptr is not None:
mask = tl.load(mask_block_ptr)
mask *= language.log2e
score += mask

key = tl.load(key_block_ptr)
score += language.dot(query, key, use_accelerator, dtype)
Expand All @@ -132,6 +150,9 @@ def forward(
key_block_ptr = tl.advance(key_block_ptr, (0, y_block_size))
value_block_ptr = tl.advance(value_block_ptr, (y_block_size, 0))

if mask_ptr is not None:
mask_block_ptr = tl.advance(mask_block_ptr, (0, y_block_size))

output /= sum[:, None].to(dtype)

if dropout_p > language.eps:
Expand All @@ -149,6 +170,7 @@ def backward(
grad_query_ptr: tl.tensor,
grad_key_ptr: tl.tensor,
grad_value_ptr: tl.tensor,
grad_mask_ptr: tl.tensor,
grad_output_ptr: tl.tensor,
query_ptr: tl.tensor,
key_ptr: tl.tensor,
Expand All @@ -158,6 +180,10 @@ def backward(
head_stride: tl.int32,
y_stride: tl.int32,
x_stride: tl.int32,
mask_ptr: tl.tensor,
mask_head_stride: tl.int32,
mask_y_stride: tl.int32,
mask_x_stride: tl.int32,
output_ptr: tl.tensor,
log2sum_ptr: tl.tensor,
delta_ptr: tl.tensor,
Expand Down Expand Up @@ -236,6 +262,24 @@ def backward(
order=(0,),
)

if mask_ptr is not None:
grad_mask_block_ptr = tl.make_block_ptr(
grad_mask_ptr + pid * mask_head_stride,
shape=(y_size, y_size),
strides=(mask_y_stride, mask_x_stride),
offsets=(0, n_block * y_block_size),
block_shape=(y_block_size, y_block_size),
order=(1, 0),
)
mask_block_ptr = tl.make_block_ptr(
mask_ptr + pid * mask_head_stride,
shape=(y_size, y_size),
strides=(mask_y_stride, mask_x_stride),
offsets=(0, n_block * y_block_size),
block_shape=(y_block_size, y_block_size),
order=(1, 0),
)

grad_value = tl.zeros((y_block_size, x_block_size), dtype)
grad_key = tl.zeros((y_block_size, x_block_size), dtype)
ptr_offsets = n_strides[:, None] + x_strides[None, :]
Expand All @@ -250,11 +294,14 @@ def backward(
if is_causal:
condition = m_offsets[:, None] >= n_offsets[None, :]
score = tl.where(condition, 0.0, float("-inf"))
elif mask_ptr is not None:
mask = tl.load(mask_block_ptr)
mask *= language.log2e
score = mask
else:
score = tl.zeros((y_block_size, y_block_size), dtype)

score += language.dot(query, tl.trans(key), use_accelerator, dtype)
score *= score_scale
score += language.dot(query, tl.trans(key), use_accelerator, dtype) * score_scale
log2sum = tl.load(log2sum_block_ptr)
alpha = tl.math.exp2(score - log2sum[:, None]).to(dtype)
grad_output = tl.load(grad_output_block_ptr)
Expand Down Expand Up @@ -282,6 +329,11 @@ def backward(
delta_block_ptr = tl.advance(delta_block_ptr, (y_block_size,))
log2sum_block_ptr = tl.advance(log2sum_block_ptr, (y_block_size,))

if mask_ptr is not None:
tl.store(grad_mask_block_ptr, (grad_softmax / softmax_scale).to(dtype))
mask_block_ptr = tl.advance(mask_block_ptr, (y_block_size, 0))
grad_mask_block_ptr = tl.advance(grad_mask_block_ptr, (y_block_size, 0))

tl.store(grad_key_ptr + ptr_offsets, grad_key)
tl.store(grad_value_ptr + ptr_offsets, grad_value)
n_strides += y_block_size * y_stride
33 changes: 28 additions & 5 deletions trident/operation/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,25 @@
class Attention(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any):
query, key, value, dropout_p, is_causal, softmax_scale, use_accelerator = args
query, key, value, mask, dropout_p, is_causal, softmax_scale, use_accelerator = args

if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
raise ValueError("The dimension of query, key and value should be 4.")

if mask is not None:
if is_causal:
raise ValueError("Error because both attn_mask and is_causal are set.")
if mask.dtype == torch.bool:
raise ValueError("Boolean mask is not supported yet.")

util.push_trace("Attention.__forward")
output, log_sum_exp = Attention.__forward(
query, key, value, dropout_p, is_causal, softmax_scale, use_accelerator
query, key, value, mask, dropout_p, is_causal, softmax_scale, use_accelerator
)
util.pop_trace()

ctx.save_for_backward(query, key, value, output, log_sum_exp)
ctx.mask = mask
ctx.dropout_p = dropout_p
ctx.is_causal = is_causal
ctx.softmax_scale = softmax_scale
Expand All @@ -45,27 +55,29 @@ def backward(ctx: Any, *grad_outputs: Any):
query, key, value, output, log_sum_exp = ctx.saved_tensors

util.push_trace("Attention.__backward")
grad_query, grad_key, grad_value = Attention.__backward(
grad_query, grad_key, grad_value, grad_mask = Attention.__backward(
grad_output,
query,
key,
value,
output,
log_sum_exp,
ctx.mask,
ctx.softmax_scale,
ctx.dropout_p,
ctx.is_causal,
ctx.use_accelerator,
)
util.pop_trace()

return grad_query, grad_key, grad_value, None, None, None, None
return grad_query, grad_key, grad_value, grad_mask, None, None, None, None

@staticmethod
def __forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor,
dropout_p: torch.float32,
is_causal: torch.bool,
softmax_scale: torch.float32,
Expand Down Expand Up @@ -95,6 +107,10 @@ def grid(meta):
query.stride(1),
query.stride(2),
query.stride(3),
mask,
mask.stride(1) if mask is not None else 0,
mask.stride(2) if mask is not None else 0,
mask.stride(3) if mask is not None else 0,
dropout_p,
torch.random.seed(),
is_causal,
Expand All @@ -115,6 +131,7 @@ def __backward(
value: torch.Tensor,
output: torch.Tensor,
log2sum: torch.Tensor,
mask: torch.Tensor,
softmax_scale: torch.float32,
dropout_p: torch.float32,
is_causal: torch.bool,
Expand All @@ -124,6 +141,7 @@ def __backward(
grad_query = torch.zeros_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
grad_mask = torch.empty_like(mask) if mask is not None else None
delta = torch.empty_like(log2sum)

def grid(meta):
Expand All @@ -143,6 +161,7 @@ def grid(meta):
grad_query,
grad_key,
grad_value,
grad_mask,
grad_output,
query,
key,
Expand All @@ -152,6 +171,10 @@ def grid(meta):
query.stride(1),
query.stride(2),
query.stride(3),
mask,
mask.stride(1) if mask is not None else 0,
mask.stride(2) if mask is not None else 0,
mask.stride(3) if mask is not None else 0,
output,
log2sum,
delta,
Expand All @@ -166,4 +189,4 @@ def grid(meta):
)
util.pop_trace()

return grad_query, grad_key, grad_value
return grad_query, grad_key, grad_value, grad_mask

0 comments on commit 1b547d0

Please sign in to comment.