Skip to content

Commit

Permalink
padding mask on torch side (#12577)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored Dec 19, 2024
1 parent 47e90a3 commit e0921f8
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,80 @@ def layer_norm_forward(self, hidden_states: torch.Tensor):
hidden_states, self.normalized_shape,
self.weight, self.bias, self.eps
)


def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
max_kvs = 128
padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
if mask is None:
if is_causal:
mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
dtype=dtype, device=device)
mask.triu_(1)
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
elif seq_length != kv_length and seq_length <= 32:
mask = None
else:
mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
dtype=dtype, device=device)
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
else:
if seq_length != kv_length and seq_length <= 32:
mask = mask[..., :seq_length, :kv_length]
mask = mask.expand([bsz, n_heads, seq_length, kv_length])
elif mask.size(3) != padding_kv_length:
new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
dtype=dtype, device=device)
new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
mask.set_(new_mask) # modify `mask` inplaced
else:
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
return mask


def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
mask: torch.Tensor = None,
is_causal: bool = False, scale: float = None) -> torch.Tensor:
bsz, n_heads, seq_length, head_dim = query.shape
_, n_kv_heads, kv_length, _ = key.shape

dtype, device = query.dtype, query.device

if (
device.type == "xpu"
and dtype in [torch.float, torch.half]
and head_dim in [64, 80, 96, 128]
):
# prepare scale
scale = 1 / math.sqrt(head_dim) if scale is None else scale

# prepare mask
mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)

# compute
import xe_addons
if is_causal:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
else:
attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
elif seq_length != kv_length and seq_length <= 32:
# todo: add scale support
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
else:
attn_output = xe_addons.sdp(query, key, value, mask)
else:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
else:
attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)

return attn_output
else:
mask = mask[..., :seq_length, :kv_length] if mask is not None else None
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, mask, is_causal=is_causal, scale=scale
)

0 comments on commit e0921f8

Please sign in to comment.