[CPU] attn supports f16 (#26487)
### Details:
 - *rebase from #22939*
 - *enable avx512 fp16 for attention*
 - *enable amx fp16 for attention*
 - *update PagedAttentionExtension so that the precision of PagedAttention's second output can be specified explicitly (see the usage sketch after the paged_attention.cpp diff below)*

### Tickets:
 - *128183*
xczhai authored Oct 15, 2024
1 parent fe2f67b commit 9486b7d
Showing 25 changed files with 605 additions and 281 deletions.
5 changes: 5 additions & 0 deletions src/core/dev_api/openvino/op/paged_attention.hpp
@@ -17,6 +17,11 @@ class OPENVINO_API PagedAttentionExtension : public ov::op::Op {
PagedAttentionExtension(const ov::OutputVector& args);
void validate_and_infer_types() override;
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

void set_out_type(int index, const ov::element::Type& output_type);

protected:
std::vector<ov::element::Type> m_output_type = {ov::element::undefined, ov::element::undefined};
};

} // namespace op
18 changes: 16 additions & 2 deletions src/core/src/op/paged_attention.cpp
@@ -146,13 +146,27 @@ void PagedAttentionExtension::validate_and_infer_types() {
get_input_element_type(12),
".");

if (m_output_type[0] == ov::element::undefined) {
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} else {
set_output_type(0, m_output_type[0], get_input_partial_shape(0));
}

if (m_output_type[1] == ov::element::undefined) {
set_output_type(1, get_input_element_type(0), {Dimension::dynamic()});
} else {
set_output_type(1, m_output_type[1], {Dimension::dynamic()});
}
}

std::shared_ptr<ov::Node> PagedAttentionExtension::clone_with_new_inputs(const ov::OutputVector& new_args) const {
return std::make_shared<PagedAttentionExtension>(new_args);
}

void PagedAttentionExtension::set_out_type(int index, const ov::element::Type& output_type) {
OPENVINO_ASSERT(index < 2, "Output index should be 0 or 1, but got " + std::to_string(index));
m_output_type[index] = output_type;
}

} // namespace op
} // namespace ov
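The set_out_type hook added above is what lets a plugin request a specific precision (e.g. f16) for PagedAttention's second output instead of inheriting the type of input 0. A minimal usage sketch; the helper name and the way the node is obtained are illustrative, not code from this commit:

#include <memory>

#include "openvino/op/paged_attention.hpp"

// Hypothetical helper: ask PagedAttention to emit its second output (index 1)
// in f16. Indices other than 0 or 1 trip the OPENVINO_ASSERT in set_out_type.
void force_f16_second_output(const std::shared_ptr<ov::op::PagedAttentionExtension>& pa) {
    pa->set_out_type(1, ov::element::f16);
    pa->validate_and_infer_types();  // re-run type inference so output 1 reports f16
}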
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/graph.cpp
@@ -194,8 +194,8 @@ void Graph::Replicate(const std::shared_ptr<const ov::Model> &model,
const auto port = unusedOutput.get_index();
const auto nodeName = std::string("stub_") + std::to_string(unusedOutput.get_index()) + "_" + parentNode->getName();
const NodePtr outNode = std::make_shared<node::Input>(parentNode->outputShapes[port],
parentNode->getOriginalOutputPrecisionAtPort(port),
nodeName, "Result", m_context);
CreateEdge(parentNode, outNode, port, 0);
AddNode(outNode);
}
Expand Down
@@ -180,7 +180,6 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
// For compatibility, all input_kvs are permuted to BHLS
size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3];
// Internal LBHS layout has strides[L] > strides[B]
assert(k_src.m_strides[2] > k_src.m_strides[0]);
parallel_for3d(L1, B, H, [&](size_t m, size_t b, size_t h) {
auto p_k = k_scale_zp.ptr<float>(m, b, h);
auto p_v = v_scale_zp.ptr<float>(m, b, h);
@@ -238,6 +237,8 @@ void attn_quantkv(const ov::intel_cpu::PlainTensor& k_src,
attn_quant_mt<float, uint8_t>(k_src, v_src, k_dst, v_dst, k_scale_zp, v_scale_zp);
} else if (k_src.get_precision() == ov::element::bf16 && k_dst.get_precision() == ov::element::u8) {
attn_quant_mt<ov::bfloat16, uint8_t>(k_src, v_src, k_dst, v_dst, k_scale_zp, v_scale_zp);
} else if (k_src.get_precision() == ov::element::f16 && k_dst.get_precision() == ov::element::u8) {
attn_quant_mt<ov::float16, uint8_t>(k_src, v_src, k_dst, v_dst, k_scale_zp, v_scale_zp);
} else {
OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in attn_quantkv");
}
@@ -252,6 +253,8 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src,
paged_attn_quant_mt<float, uint8_t>(k_src, v_src, k_dst, v_dst, slot_mapping);
} else if (k_src.get_precision() == ov::element::bf16 && k_dst.get_precision() == ov::element::u8) {
paged_attn_quant_mt<ov::bfloat16, uint8_t>(k_src, v_src, k_dst, v_dst, slot_mapping);
} else if (k_src.get_precision() == ov::element::f16 && k_dst.get_precision() == ov::element::u8) {
paged_attn_quant_mt<ov::float16, uint8_t>(k_src, v_src, k_dst, v_dst, slot_mapping);
} else {
OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in paged_attn_quantkv");
}
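For reference, the dispatch above only selects the source element type; the quantization itself produces a u8 tensor plus a per-row (scale, zero-point) pair. A standalone sketch of that scheme (illustrative only; the real attn_quant_mt body is not shown in this hunk, and the exact scale/zero-point encoding is an assumption):

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Illustrative per-row asymmetric u8 quantization; T plays the role of
// float, ov::bfloat16, or ov::float16 in the kernels above.
template <typename T>
void quantize_row_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zero_point) {
    float min_v = static_cast<float>(src[0]);
    float max_v = min_v;
    for (size_t i = 1; i < n; ++i) {
        float v = static_cast<float>(src[i]);
        min_v = std::min(min_v, v);
        max_v = std::max(max_v, v);
    }
    scale = (max_v - min_v) / 255.0f;  // map [min_v, max_v] onto [0, 255]
    if (scale == 0.0f)
        scale = 1.0f;                  // constant row: any non-zero scale works
    zero_point = min_v;
    for (size_t i = 0; i < n; ++i) {
        float q = (static_cast<float>(src[i]) - zero_point) / scale + 0.5f;  // round to nearest
        dst[i] = static_cast<uint8_t>(std::max(0.0f, std::min(255.0f, q)));
    }
}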
95 changes: 70 additions & 25 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp
@@ -37,15 +37,22 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
return _mm512_castsi512_ps(_mm512_slli_epi32(y, 16));
}

// load addr to __m512 reg
inline __m512 mm512_uni_loadu_ps(const float* a) {
return _mm512_loadu_ps(a);
}

inline __m512 mm512_uni_loadu_ps(const ov::bfloat16* a) {
auto vec_bf16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a));
return cvt_bf16_to_fp32(vec_bf16);
}

inline __m512 mm512_uni_loadu_ps(const ov::float16* a) {
auto vec_f16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a));
return _mm512_cvtph_ps(vec_f16);
}

// load addr to __m512 reg
inline __m512 mm512_uni_loadu_tail_ps(const float* a, size_t count) {
__mmask16 mask = (1 << count) - 1;
return _mm512_maskz_loadu_ps(mask, a);
@@ -57,6 +64,13 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
return cvt_bf16_to_fp32(bf16_vec);
}

inline __m512 mm512_uni_loadu_tail_ps(const ov::float16* a, size_t count) {
auto mask = (1 << count) - 1;
auto f16_vec = _mm256_maskz_loadu_epi16(mask, a);
return _mm512_cvtph_ps(f16_vec);
}

// store __m512 reg to addr
inline void mm512_uni_storeu_ps(float* a, __m512 v) {
_mm512_storeu_ps(a, v);
}
@@ -72,6 +86,13 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16
_mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), _mm512_cvtepi32_epi16(x));
}

inline void mm512_uni_storeu_ps(ov::float16* addr, __m512 v) {
__m256i vec_f16 = _mm512_cvtps_ph(v, 0);
_mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), vec_f16);
}

// store __m512 reg to addr
inline void mm512_uni_mask_storeu_ps(ov::bfloat16 *addr, __mmask16 mask_addr, __m512 xps) {
__m512i xpi32 = _mm512_castps_si512(xps);
__m512i nan = _mm512_set1_epi32(0xffff);
@@ -85,18 +106,29 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
_mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x);
}

inline void mm512_uni_storeu_tail_ps(float *addr, __m512 v, size_t count) {
__mmask16 mask_addr = (1 << count) - 1;
_mm512_mask_storeu_ps(addr, mask_addr, v);
}

inline void mm512_uni_storeu_tail_ps(ov::bfloat16 *addr, __m512 v, size_t count) {
__mmask16 mask_addr = (1 << count) - 1;
__m512i xpi32 = _mm512_castps_si512(v);
__m512i nan = _mm512_set1_epi32(0xffff);
auto mask = _mm512_cmp_ps_mask(v, v, _CMP_ORD_Q);
__m512i ones = _mm512_set1_epi32(0x1);
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
auto x = _mm512_and_si512(_mm512_srli_epi32(xpi32, 16), ones); // LSB = x[16]
x = _mm512_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB
x = _mm512_srli_epi32(_mm512_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16;
x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16
_mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x);
}
inline void mm512_uni_storeu_tail_ps(ov::float16 *addr, __m512 v, size_t count) {
__mmask16 mask_addr = (1 << count) - 1;
__m256i vec_f16 = _mm512_cvtps_ph(v, 0);
_mm256_mask_storeu_epi16(reinterpret_cast<__m256i *>(addr), mask_addr, vec_f16);
}
#endif
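
Both bf16 store helpers above (and their AVX2 counterparts below) rely on the same bias trick; as a scalar restatement of what each lane computes (illustrative only, not code from this header):

#include <cmath>
#include <cstdint>
#include <cstring>

// Scalar restatement of the SIMD round-to-nearest-even f32 -> bf16 store above.
uint16_t f32_to_bf16_rne(float f) {
    if (std::isnan(f))
        return 0xffff;                       // NaN lanes are blended to 0xffff
    uint32_t x;
    std::memcpy(&x, &f, sizeof(x));
    uint32_t lsb = (x >> 16) & 1;            // LSB of the surviving mantissa bits
    x += 0x7fff + lsb;                       // rounding_bias = 0x7fff + LSB (ties to even)
    return static_cast<uint16_t>(x >> 16);   // keep sign, exponent, top 7 mantissa bits
}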

@@ -115,19 +147,25 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
};
return _mm256_loadu_si256(&mask[N7]);
}

// load addr to __m256 reg
inline __m256 mm256_uni_loadu_ps(const float* a) {
return _mm256_loadu_ps(a);
}

inline __m256 mm256_uni_loadu_ps(const ov::bfloat16* a) {
auto vec_bf16 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a));
auto o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(vec_bf16), 16));
return o;
}

inline __m256 mm256_uni_loadu_ps(const ov::float16* a) {
auto vec_f16 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(a));
auto o = _mm256_cvtph_ps(vec_f16);
return o;
}

// load addr tail to __m256 reg
inline __m256 mm256_uni_loadu_tail_ps(const float* a, const size_t count) {
auto mask = get_mask(count);
return _mm256_maskload_ps(a, mask);
@@ -140,6 +178,17 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
return mm256_uni_loadu_ps(tmp_values);
}

inline __m256 mm256_uni_loadu_tail_ps(const ov::float16* a, const size_t count) {
ov::float16 tmp_values[8] = {0};
std::memcpy(tmp_values, a, count * sizeof(ov::float16));
return mm256_uni_loadu_ps(tmp_values);
}

// store __m256 reg to addr
inline void mm256_uni_storeu_ps(float* a, __m256 v) {
_mm256_storeu_ps(a, v);
}

inline void mm256_uni_storeu_ps(ov::bfloat16 *addr, __m256 xps) {
__m256i xpi32 = _mm256_castps_si256(xps);
__m256i nan = _mm256_set1_epi32(0xffff);
@@ -156,21 +205,17 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
_mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o);
}

inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) {
__m128i vec_f16 = _mm256_cvtps_ph(v, 0);
_mm_storeu_si128(reinterpret_cast<__m128i *>(a), vec_f16);
}

// store __m256 to addr
inline void mm256_uni_storeu_tail_ps(float *addr, __m256 v, size_t count) {
const auto mask = get_mask(count);
return _mm256_maskstore_ps(addr, mask, v);
}

inline void hsum(__m256& x) {
__m256 y; // x: 0 1 2 3 4 5 6 7
y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4
@@ -292,4 +337,4 @@ static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16);
} // namespace XARCH
} // namespace Cpu
} // namespace Extensions
} // namespace ov
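
Together, these overloads let one kernel template run unchanged over f32, bf16, and f16 buffers, which is how the attention kernels pick up f16 support. A minimal sketch of the pattern (illustrative; scale_row is not in this commit, and the vec_len_f32_avx512 constant and HAVE_AVX512F guard are assumed from elsewhere in this header):

#if defined(HAVE_AVX512F)
// Illustrative only: scale a row in place for T in {float, ov::bfloat16, ov::float16}.
// The mm512_uni_* overloads above widen each type to f32 lanes and narrow on store.
template <typename T>
void scale_row(T* data, float alpha, size_t n) {
    const auto v_alpha = _mm512_set1_ps(alpha);
    size_t i = 0;
    for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) {
        __m512 v = mm512_uni_loadu_ps(data + i);                   // load + convert to f32
        mm512_uni_storeu_ps(data + i, _mm512_mul_ps(v, v_alpha));  // convert back + store
    }
    if (i < n) {  // masked tail: no out-of-bounds read or write
        __m512 v = mm512_uni_loadu_tail_ps(data + i, n - i);
        mm512_uni_storeu_tail_ps(data + i, _mm512_mul_ps(v, v_alpha), n - i);
    }
}
#endif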