Skip to content

Commit

Permalink
[CPU]Fix mha_single_token for ACL
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangYiIntel committed Nov 25, 2024
1 parent 5f56f34 commit 202e406
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -563,7 +563,7 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float*
}

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-static ov::float16 dot_product_fp16(ov::float16* a, ov::float16* b, size_t n, float* scale, float* zp, float* head_sum, size_t group_size = 0) {
+static ov::float16 dot_product_fp16(ov::float16* a, ov::float16* b, size_t n, float* scale, float* zp, float* head_sum) {
size_t i = 0;
ov::float16 sum = 0.0f;
auto vsum0 = vdupq_n_f16(0.0f);
Expand Down
8 changes: 4 additions & 4 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
@@ -869,7 +869,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
MHAKernel<KType, T> kernel;
MHASingleToken kernel_single_token;

-explicit AttentionExecutor(GraphContext::CPtr ctx, size_t k_group_size = 0, size_t v_group_size = 0)
+explicit AttentionExecutor(GraphContext::CPtr ctx, size_t k_group_size, size_t v_group_size)
: context(ctx),
kernel(context),
kernel_single_token(k_group_size, v_group_size) {}
@@ -1148,12 +1148,12 @@ void ScaledDotProductAttention::createPrimitive() {
}
#elif defined(OV_CPU_WITH_ACL)
if (rtPrecision == ov::element::f16) {
-executor = std::make_shared<AttentionExecutor<KT_ACL, ov::float16>>(context);
+executor = std::make_shared<AttentionExecutor<KT_ACL, ov::float16>>(context, m_key_group_size, m_value_group_size);
} else {
-executor = std::make_shared<AttentionExecutor<KT_ACL, float>>(context);
+executor = std::make_shared<AttentionExecutor<KT_ACL, float>>(context, m_key_group_size, m_value_group_size);
}
#else
-executor = std::make_shared<AttentionExecutor<KT_REF, float>>(context);
+executor = std::make_shared<AttentionExecutor<KT_REF, float>>(context, m_key_group_size, m_value_group_size);
#endif
return executor;
};
Expand Down

0 comments on commit 202e406

Please sign in to comment.