diff --git a/tests/test_attention.py b/tests/test_attention.py
index e83060a..4b37905 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -20,7 +20,7 @@
 @pytest.mark.parametrize(
-    "num_batches, num_heads, y_size, x_size, is_causal", [(4, 8, 128, 64, True), (4, 8, 128, 64, False)]
+    "num_batches, num_heads, y_size, x_size, is_causal", [(4, 8, 128, 64, True), (4, 8, 256, 32, False)]
 )
 def test_forward(num_batches, num_heads, y_size, x_size, is_causal, device):
     query = torch.randn(num_batches, num_heads, y_size, x_size, device=device)
@@ -32,9 +32,16 @@ def test_forward(num_batches, num_heads, y_size, x_size, is_causal, device):
         trident.function.scaled_dot_product_attention(query, key, value, is_causal=is_causal),
     )
 
+    mask = torch.randn(num_batches, num_heads, y_size, y_size, device=device)
+
+    assert util.equal(
+        torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=mask),
+        trident.function.scaled_dot_product_attention(query, key, value, attn_mask=mask),
+    )
+
 
 @pytest.mark.parametrize(
-    "num_batches, num_heads, y_size, x_size, is_causal", [(4, 8, 128, 64, True), (4, 8, 128, 64, False)]
+    "num_batches, num_heads, y_size, x_size, is_causal", [(4, 8, 128, 64, True), (4, 8, 256, 32, False)]
 )
 def test_backward(num_batches, num_heads, y_size, x_size, is_causal, device):
     query = torch.rand(num_batches, num_heads, y_size, x_size, device=device)
@@ -57,6 +64,25 @@ def train(func):
     assert util.equal(y, b)
     assert util.equal(z, c)
 
+    mask = torch.randn(num_batches, num_heads, y_size, y_size, device=device)
+
+    def train(func):
+        i = query.clone()
+        j = key.clone()
+        k = value.clone()
+        l = mask.clone()
+        i.requires_grad = j.requires_grad = k.requires_grad = l.requires_grad = True
+        func(i, j, k, attn_mask=l).backward(grad_output, retain_graph=True)
+        return i.grad, j.grad, k.grad, l.grad
+
+    (w, x, y, z) = train(torch.nn.functional.scaled_dot_product_attention)
+    (a, b, c, d) = train(trident.function.scaled_dot_product_attention)
+
+    assert util.equal(w, a)
+    assert util.equal(x, b)
+    assert util.equal(y, c)
+    assert util.equal(z, d)
+
 
 @pytest.mark.parametrize(
     "num_batches, num_heads, y_size, x_size, is_causal", [(1, 1, 1, 16, True), (1, 1, 1, 16, False)]
 )
@@ -81,3 +107,10 @@ def test_attention(num_batches, num_heads, y_size, x_size, is_causal, device, dt
     assert key.grad.dtype == dtype
     assert value.grad is not None
     assert value.grad.dtype == dtype
+
+    mask = torch.randn(num_batches, num_heads, y_size, y_size, **factory_kwargs, requires_grad=True)
+    output = trident.function.scaled_dot_product_attention(query, key, value, attn_mask=mask)
+    output.backward(grad_output)
+
+    assert mask.grad is not None
+    assert mask.grad.dtype == dtype
diff --git a/trident/function/function.py b/trident/function/function.py
index 3089e32..08bb9af 100644
--- a/trident/function/function.py
+++ b/trident/function/function.py
@@ -239,6 +239,7 @@ def scaled_dot_product_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: torch.Tensor = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     use_accelerator: bool = False,
@@ -247,11 +248,8 @@
     Computes scaled dot product attention on query, key and value tensors, and applying dropout if a
     probability greater than 0.0 is specified.
     """
-    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
-        raise ValueError("The dimension of query, key and value should be 4.")
-
     return operation.Attention.apply(
-        query, key, value, dropout_p, is_causal, 1.0 / math.sqrt(key.shape[-1]), use_accelerator
+        query, key, value, attn_mask, dropout_p, is_causal, 1.0 / math.sqrt(key.shape[-1]), use_accelerator
     )
diff --git a/trident/kernel/attention.py b/trident/kernel/attention.py
index 3dc108c..b509fb6 100644
--- a/trident/kernel/attention.py
+++ b/trident/kernel/attention.py
@@ -41,6 +41,10 @@ def forward(
         head_stride: tl.int32,
         y_stride: tl.int32,
         x_stride: tl.int32,
+        mask_ptr: tl.tensor,
+        mask_head_stride: tl.int32,
+        mask_y_stride: tl.int32,
+        mask_x_stride: tl.int32,
         dropout_p: tl.float32,
         seed: tl.int32,
         is_causal: tl.constexpr,
@@ -98,6 +102,16 @@ def forward(
             order=(1, 0),
         )
 
+        if mask_ptr is not None:
+            mask_block_ptr = tl.make_block_ptr(
+                mask_ptr + head * mask_head_stride,
+                shape=(y_size, y_size),
+                strides=(mask_y_stride, mask_x_stride),
+                offsets=(y_offset, 0),
+                block_shape=(y_block_size, y_block_size),
+                order=(1, 0),
+            )
+
         query = tl.load(query_block_ptr)
         score_scale = (softmax_scale * language.log2e).to(dtype)
         query *= score_scale
@@ -118,6 +132,10 @@ def forward(
                 n_offsets = tl.arange(0, y_block_size) + n_offset
                 condition = m_offsets[:, None] >= n_offsets[None, :]
                 score = tl.where(condition, score, float("-inf"))
+            elif mask_ptr is not None:
+                mask = tl.load(mask_block_ptr)
+                mask *= language.log2e
+                score += mask
 
             key = tl.load(key_block_ptr)
             score += language.dot(query, key, use_accelerator, dtype)
@@ -132,6 +150,9 @@ def forward(
             key_block_ptr = tl.advance(key_block_ptr, (0, y_block_size))
             value_block_ptr = tl.advance(value_block_ptr, (y_block_size, 0))
 
+            if mask_ptr is not None:
+                mask_block_ptr = tl.advance(mask_block_ptr, (0, y_block_size))
+
         output /= sum[:, None].to(dtype)
 
         if dropout_p > language.eps:
@@ -149,6 +170,7 @@ def backward(
         grad_query_ptr: tl.tensor,
         grad_key_ptr: tl.tensor,
         grad_value_ptr: tl.tensor,
+        grad_mask_ptr: tl.tensor,
         grad_output_ptr: tl.tensor,
         query_ptr: tl.tensor,
         key_ptr: tl.tensor,
@@ -158,6 +180,10 @@ def backward(
         head_stride: tl.int32,
         y_stride: tl.int32,
         x_stride: tl.int32,
+        mask_ptr: tl.tensor,
+        mask_head_stride: tl.int32,
+        mask_y_stride: tl.int32,
+        mask_x_stride: tl.int32,
         output_ptr: tl.tensor,
         log2sum_ptr: tl.tensor,
         delta_ptr: tl.tensor,
@@ -236,6 +262,24 @@ def backward(
             order=(0,),
         )
 
+        if mask_ptr is not None:
+            grad_mask_block_ptr = tl.make_block_ptr(
+                grad_mask_ptr + pid * mask_head_stride,
+                shape=(y_size, y_size),
+                strides=(mask_y_stride, mask_x_stride),
+                offsets=(0, n_block * y_block_size),
+                block_shape=(y_block_size, y_block_size),
+                order=(1, 0),
+            )
+            mask_block_ptr = tl.make_block_ptr(
+                mask_ptr + pid * mask_head_stride,
+                shape=(y_size, y_size),
+                strides=(mask_y_stride, mask_x_stride),
+                offsets=(0, n_block * y_block_size),
+                block_shape=(y_block_size, y_block_size),
+                order=(1, 0),
+            )
+
         grad_value = tl.zeros((y_block_size, x_block_size), dtype)
         grad_key = tl.zeros((y_block_size, x_block_size), dtype)
         ptr_offsets = n_strides[:, None] + x_strides[None, :]
@@ -250,11 +294,14 @@ def backward(
             if is_causal:
                 condition = m_offsets[:, None] >= n_offsets[None, :]
                 score = tl.where(condition, 0.0, float("-inf"))
+            elif mask_ptr is not None:
+                mask = tl.load(mask_block_ptr)
+                mask *= language.log2e
+                score = mask
             else:
                 score = tl.zeros((y_block_size, y_block_size), dtype)
 
-            score += language.dot(query, tl.trans(key), use_accelerator, dtype)
-            score *= score_scale
+            score += language.dot(query, tl.trans(key), use_accelerator, dtype) * score_scale
             log2sum = tl.load(log2sum_block_ptr)
             alpha = tl.math.exp2(score - log2sum[:, None]).to(dtype)
             grad_output = tl.load(grad_output_block_ptr)
@@ -282,6 +329,11 @@ def backward(
             delta_block_ptr = tl.advance(delta_block_ptr, (y_block_size,))
             log2sum_block_ptr = tl.advance(log2sum_block_ptr, (y_block_size,))
 
+            if mask_ptr is not None:
+                tl.store(grad_mask_block_ptr, (grad_softmax / softmax_scale).to(dtype))
+                mask_block_ptr = tl.advance(mask_block_ptr, (y_block_size, 0))
+                grad_mask_block_ptr = tl.advance(grad_mask_block_ptr, (y_block_size, 0))
+
         tl.store(grad_key_ptr + ptr_offsets, grad_key)
         tl.store(grad_value_ptr + ptr_offsets, grad_value)
         n_strides += y_block_size * y_stride
diff --git a/trident/operation/attention.py b/trident/operation/attention.py
index 615f442..fa02549 100644
--- a/trident/operation/attention.py
+++ b/trident/operation/attention.py
@@ -23,15 +23,25 @@ class Attention(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, *args: Any, **kwargs: Any):
-        query, key, value, dropout_p, is_causal, softmax_scale, use_accelerator = args
+        query, key, value, mask, dropout_p, is_causal, softmax_scale, use_accelerator = args
+
+        if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
+            raise ValueError("The dimension of query, key and value should be 4.")
+
+        if mask is not None:
+            if is_causal:
+                raise ValueError("Error because both attn_mask and is_causal are set.")
+            if mask.dtype == torch.bool:
+                raise ValueError("Boolean mask is not supported yet.")
 
         util.push_trace("Attention.__forward")
         output, log_sum_exp = Attention.__forward(
-            query, key, value, dropout_p, is_causal, softmax_scale, use_accelerator
+            query, key, value, mask, dropout_p, is_causal, softmax_scale, use_accelerator
         )
         util.pop_trace()
 
         ctx.save_for_backward(query, key, value, output, log_sum_exp)
+        ctx.mask = mask
         ctx.dropout_p = dropout_p
         ctx.is_causal = is_causal
         ctx.softmax_scale = softmax_scale
@@ -45,13 +55,14 @@ def backward(ctx: Any, *grad_outputs: Any):
         query, key, value, output, log_sum_exp = ctx.saved_tensors
 
         util.push_trace("Attention.__backward")
-        grad_query, grad_key, grad_value = Attention.__backward(
+        grad_query, grad_key, grad_value, grad_mask = Attention.__backward(
            grad_output,
            query,
            key,
            value,
            output,
            log_sum_exp,
+           ctx.mask,
            ctx.softmax_scale,
            ctx.dropout_p,
            ctx.is_causal,
@@ -59,13 +70,14 @@ def backward(ctx: Any, *grad_outputs: Any):
         )
         util.pop_trace()
 
-        return grad_query, grad_key, grad_value, None, None, None, None
+        return grad_query, grad_key, grad_value, grad_mask, None, None, None, None
 
     @staticmethod
     def __forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        mask: torch.Tensor,
         dropout_p: torch.float32,
         is_causal: torch.bool,
         softmax_scale: torch.float32,
@@ -95,6 +107,10 @@ def grid(meta):
            query.stride(1),
            query.stride(2),
            query.stride(3),
+           mask,
+           mask.stride(1) if mask is not None else 0,
+           mask.stride(2) if mask is not None else 0,
+           mask.stride(3) if mask is not None else 0,
            dropout_p,
            torch.random.seed(),
            is_causal,
@@ -115,6 +131,7 @@ def __backward(
         value: torch.Tensor,
         output: torch.Tensor,
         log2sum: torch.Tensor,
+        mask: torch.Tensor,
         softmax_scale: torch.float32,
         dropout_p: torch.float32,
         is_causal: torch.bool,
@@ -124,6 +141,7 @@ def __backward(
         grad_query = torch.zeros_like(query)
         grad_key = torch.empty_like(key)
         grad_value = torch.empty_like(value)
+        grad_mask = torch.empty_like(mask) if mask is not None else None
         delta = torch.empty_like(log2sum)
 
         def grid(meta):
@@ -143,6 +161,7 @@ def grid(meta):
            grad_query,
            grad_key,
            grad_value,
+           grad_mask,
            grad_output,
            query,
            key,
@@ -152,6 +171,10 @@ def grid(meta):
            query.stride(1),
            query.stride(2),
            query.stride(3),
+           mask,
+           mask.stride(1) if mask is not None else 0,
+           mask.stride(2) if mask is not None else 0,
+           mask.stride(3) if mask is not None else 0,
            output,
            log2sum,
            delta,
@@ -166,4 +189,4 @@ def grid(meta):
         )
         util.pop_trace()
 
-        return grad_query, grad_key, grad_value
+        return grad_query, grad_key, grad_value, grad_mask
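
Usage sketch (not part of the patch): the snippet below mirrors the updated tests and shows how the new
attn_mask argument would be exercised once this change lands. The shapes follow the test parametrization;
the "cuda" device and the requires_grad setup are illustrative assumptions, since the Triton kernels need a
GPU device.

    import torch
    import trident

    query = torch.randn(4, 8, 128, 64, device="cuda", requires_grad=True)
    key = torch.randn(4, 8, 128, 64, device="cuda", requires_grad=True)
    value = torch.randn(4, 8, 128, 64, device="cuda", requires_grad=True)

    # Additive float mask shaped (num_batches, num_heads, y_size, y_size), as in the tests.
    # Per the new checks, attn_mask cannot be combined with is_causal and boolean masks are rejected.
    mask = torch.randn(4, 8, 128, 128, device="cuda", requires_grad=True)

    output = trident.function.scaled_dot_product_attention(query, key, value, attn_mask=mask)
    output.backward(torch.randn_like(output))

    assert mask.grad is not None  # the mask now receives a gradient as well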