diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 4c2e830cbdd..7f6dbd75939 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -184,3 +184,80 @@ def layer_norm_forward(self, hidden_states: torch.Tensor):
         hidden_states, self.normalized_shape, self.weight, self.bias, self.eps
     )
 
+
+
+def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
+    max_kvs = 128
+    padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
+    if mask is None:
+        if is_causal:
+            mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
+                              dtype=dtype, device=device)
+            mask.triu_(1)
+            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
+        elif seq_length != kv_length and seq_length <= 32:
+            mask = None
+        else:
+            mask = torch.zeros([1, 1, 1, padding_kv_length], dtype=dtype, device=device)
+            mask[..., kv_length:] = torch.finfo(dtype).min
+            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
+    else:
+        if seq_length != kv_length and seq_length <= 32:
+            mask = mask[..., :seq_length, :kv_length]
+            mask = mask.expand([bsz, n_heads, seq_length, kv_length])
+        elif mask.size(3) != padding_kv_length:
+            new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
+                                   dtype=dtype, device=device)
+            new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
+            new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
+            new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
+            mask.set_(new_mask)     # modify `mask` in place
+        else:
+            mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
+    return mask
+
+
+def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+                                 mask: torch.Tensor = None,
+                                 is_causal: bool = False, scale: float = None) -> torch.Tensor:
+    bsz, n_heads, seq_length, head_dim = query.shape
+    _, n_kv_heads, kv_length, _ = key.shape
+
+    dtype, device = query.dtype, query.device
+
+    if (
+        device.type == "xpu"
+        and dtype in [torch.float, torch.half]
+        and head_dim in [64, 80, 96, 128]
+    ):
+        # prepare scale
+        scale = 1 / math.sqrt(head_dim) if scale is None else scale
+
+        # prepare mask
+        mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
+
+        # compute
+        import xe_addons
+        if is_causal:
+            if key.dtype == torch.uint8:
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+            else:
+                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+        elif seq_length != kv_length and seq_length <= 32:
+            # todo: add scale support
+            if key.dtype == torch.uint8:
+                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+            else:
+                attn_output = xe_addons.sdp(query, key, value, mask)
+        else:
+            if key.dtype == torch.uint8:
+                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
+            else:
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+
+        return attn_output
+    else:
+        mask = mask[..., :seq_length, :kv_length] if mask is not None else None
+        return torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, mask, is_causal=is_causal, scale=scale
+        )
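
Usage sketch (not part of the diff): a minimal, hypothetical call to the new
scaled_dot_product_attention helper, assuming an Intel GPU is exposed as device "xpu" and that
common.py already imports `torch` and `math` at module level. With fp16 tensors and head_dim 128
the XPU fast path is taken; on any other device, dtype, or head size the call falls back to
torch.nn.functional.scaled_dot_product_attention.

    import torch
    from ipex_llm.transformers.models.common import scaled_dot_product_attention

    # hypothetical prefill-style shapes: batch 1, 32 heads, 1024 tokens, head_dim 128
    query = torch.randn(1, 32, 1024, 128, dtype=torch.half, device="xpu")
    key = torch.randn(1, 32, 1024, 128, dtype=torch.half, device="xpu")
    value = torch.randn(1, 32, 1024, 128, dtype=torch.half, device="xpu")

    # mask=None with is_causal=True builds a causal mask via prepare_mask (kv length padded
    # to a multiple of 128) and dispatches to xe_addons.sdp_causal for fp16 keys
    attn_output = scaled_dot_product_attention(query, key, value, mask=None, is_causal=True)
    print(attn_output.shape)  # expected: torch.Size([1, 32, 1024, 128])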