
Commit

remove unused comments
xczhai committed Sep 24, 2024
1 parent ceac9d5 commit 63700c7
Showing 3 changed files with 23 additions and 29 deletions.
@@ -180,8 +180,6 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
// For compatibility, all input_kvs are permuted to BHLS
size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3];
// Internal LBHS layout has strides[L] > strides[B]
- // TODO: T2 == u8, assertion failed.
- // assert(k_src.m_strides[2] > k_src.m_strides[0]);
parallel_for3d(L1, B, H, [&](size_t m, size_t b, size_t h) {
auto p_k = k_scale_zp.ptr<float>(m, b, h);
auto p_v = v_scale_zp.ptr<float>(m, b, h);
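The LBHS-vs-BHLS layout note kept in the hunk above can be checked by hand. A minimal standalone sketch with made-up shapes (not values from the plugin), showing why a contiguous L,B,H,S buffer addressed through B,H,L,S indices has strides[L] > strides[B]:

#include <cassert>
#include <cstddef>

int main() {
    // Illustrative shapes only: B batches, H heads, S head size.
    std::size_t B = 2, H = 8, S = 64;
    // For contiguous L,B,H,S storage viewed through B,H,L,S indices:
    std::size_t stride_B = H * S;      // step between consecutive b -> 512
    std::size_t stride_L = B * H * S;  // step between consecutive l -> 1024
    assert(stride_L > stride_B);       // the relation the layout comment states
    return 0;
}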
48 changes: 21 additions & 27 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp
@@ -51,21 +51,20 @@ BrgemmKernel::BrgemmKernel(size_t M,
if (is_f16 && !mayiuse(avx512_core_fp16))
THROW_ERROR("brgemm f16 kernel could only be used above avx512_f16");

- is_avx_f16_only = inType == ov::element::f16 && mayiuse(avx512_core_fp16) && !mayiuse(avx512_core_amx_fp16);
- // TODO: AMX_FP16
srcType = weiType = inType;
- // f16 is supported by upconverted to f32
+ // If isa is avx512_core_fp16, f16 is supported by upconverted to f32
+ is_avx_f16_only = inType == ov::element::f16 && mayiuse(avx512_core_fp16) && !mayiuse(avx512_core_amx_fp16);
if (is_avx_f16_only) {
srcType = ov::element::f32;
weiType = ov::element::f32;
}
brgVnniFactor = 4 / weiType.size();

/*
- AVX? AMX?
- fp32 Y N
- bf16 Y Y
- fp16 Y N(TODO: AMX_FP16)
+ AVX AMX
+ fp32 Y N
+ bf16 Y Y
+ fp16 Y Y
*/
bool isAMXSupported = (is_bf16 && mayiuse(avx512_core_amx)) || (is_f16 && mayiuse(avx512_core_amx_fp16));
bool isBrgWithAMX = isAMXSupported && !is_avx_f16_only;
@@ -101,8 +100,7 @@ BrgemmKernel::BrgemmKernel(size_t M,
brgemmCtx.M = M_;
brgemmCtx.N = N_;
brgemmCtx.K = K_;
- // brgemmCtx.LDA = k ? K_blk : lda;
- brgemmCtx.LDA = k ? K_blk : (is_avx_f16_only ? K : lda); // TODO: f16 use f32 internally
+ brgemmCtx.LDA = k ? K_blk : (is_avx_f16_only ? K : lda); // f16 use f32 internally
brgemmCtx.LDB = (!is_f32 || b_transposed) ? rnd_up(N, N_blk) : ldb; // bf16/fp16/b_transposed needs copy
brgemmCtx.LDC = ldc;
brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(srcType));
@@ -122,13 +120,11 @@ BrgemmKernel::BrgemmKernel(size_t M,
auto& brgemmCtx0 = brgCtxs[brg0BaseIdx];

if ((brgemmCtx0.is_with_amx && K_tail) || is_avx_f16_only) {
- // AMX_BF16 needs to copy tail
- // TODO: fp16 needs to copy all
init_brgemm_copy_a(brgCopyAKernel,
K,
K_blk,
K_tail,
- is_avx_f16_only ? K : K_blk, // TODO: AMX_FP16
+ is_avx_f16_only ? K : K_blk,
brgemmCtx0.dt_in0,
false,
lda * inType.size());
@@ -171,12 +167,18 @@ void BrgemmKernel::init_brgemm(brgemmCtx& ctx,
cpu_isa_t isa;
if (use_amx) {
isa = isa_undef;
- } else if (inType == ov::element::f16) {
- // TODO: AMX_FP16
- isa = avx512_core_fp16;
} else if (mayiuse(avx512_core)){
- isa = ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16
- : (is_int8 ? avx512_core_vnni : avx512_core);
+ if (ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 && mayiuse(avx512_core_bf16)) {
+ isa = avx512_core_bf16;
+ } else if (ctx.dt_in0 == dnnl_data_type_t::dnnl_f16 && mayiuse(avx512_core_fp16)) {
+ isa = avx512_core_fp16;
+ } else {
+ if (is_int8) {
+ isa = avx512_core_vnni;
+ } else {
+ isa = avx512_core;
+ }
+ }
} else {
isa = cpu_isa_t::avx2;
}
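Read on its own, the new branch order above amounts to a small decision function. The sketch below is illustrative only: pick_brgemm_isa is a made-up name, the include paths assume an OpenVINO/oneDNN build tree, and mayiuse() plus the cpu_isa_t values are oneDNN's.

#include <cpu/x64/cpu_isa_traits.hpp>   // oneDNN: mayiuse(), cpu_isa_t (path depends on the build tree)
#include <oneapi/dnnl/dnnl_types.h>     // dnnl_data_type_t

using namespace dnnl::impl::cpu::x64;

// Hypothetical helper mirroring the branch order introduced above.
static cpu_isa_t pick_brgemm_isa(bool use_amx, dnnl_data_type_t dt_in0, bool is_int8) {
    if (use_amx)
        return isa_undef;                        // brgemm picks the AMX isa itself
    if (mayiuse(avx512_core)) {
        if (dt_in0 == dnnl_bf16 && mayiuse(avx512_core_bf16))
            return avx512_core_bf16;
        if (dt_in0 == dnnl_f16 && mayiuse(avx512_core_fp16))
            return avx512_core_fp16;
        return is_int8 ? avx512_core_vnni : avx512_core;
    }
    return cpu_isa_t::avx2;                      // pre-AVX512 fallback
}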
@@ -252,18 +254,14 @@ void BrgemmKernel::init_brgemm_copy_a(
brgCopyKernelConf.s8s8_compensation_required = false;
brgCopyKernelConf.wei_zp_type = dnnl::impl::cpu::x64::none;
brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none;
- // TODO: AMX_FP16
brgCopyKernelConf.src_dt = is_avx_f16_only ? dnnl_data_type_t::dnnl_f32 : dt_in0;
brgCopyKernelConf.copy_A_src_stride = copy_A_src_stride;
- // TODO: AMX_FP16
// copy_a_kernel assumes that in/out tensor has same data type except f16
// copy_a_kernel has special path for f16: assuming input(f16) -> output(f32)
brgCopyKernelConf.a_dt_sz = is_avx_f16_only ? sizeof(ov::float16) : DnnlExtensionUtils::sizeOfDataType(static_cast<dnnl::memory::data_type>(dt_in0));
// copied A has the same precision of original
- // TODO: AMX_FP16
brgCopyKernelConf.tr_a_dt_sz = is_avx_f16_only ? sizeof(float) : DnnlExtensionUtils::sizeOfDataType(static_cast<dnnl::memory::data_type>(dt_in0));
brgCopyKernelConf.transposed_A = transpose;
- // TODO: AMX_FP16
brgCopyKernelConf.isa = is_avx_f16_only ? avx512_core_fp16 : avx512_core_amx;

create_brgemm_matmul_copy_a(brgCopyKernel, &brgCopyKernelConf);
@@ -308,11 +306,10 @@ void BrgemmKernel::init_brgemm_copy_b(
brgCopyKernelConf.req_wei_vnni_downconvert = false;

if (is_with_amx) {
- brgCopyKernelConf.isa = inType == ov::element::f16 ? avx512_core_amx_fp16 : avx512_core_amx;
+ brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_f16 ? avx512_core_amx_fp16 : avx512_core_amx;
brgCopyKernelConf.s8s8_compensation_required = false;
} else {
- if (inType == ov::element::f16) {
- // TODO: AMX_FP16
+ if (dt_in0 == dnnl_data_type_t::dnnl_f16) {
brgCopyKernelConf.isa = mayiuse(avx512_core_fp16) ? avx512_core_fp16 : avx2_vnni_2;
} else {
brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : avx512_core_vnni;
@@ -365,9 +362,6 @@ void BrgemmKernel::executeGemm(bool is_M_tail, void* a, void* b, void* c, void*
size_t K0_step0 = brgCtxs[brgIdx0].K;
auto cur_M_blk = is_M_tail ? M_tail : M_blk;
if (brgCopyAKernel) {
- // TODO: AMX_FP16
- // bf16 only copy tailed data;
- // f16 copy all data
size_t K_offset = is_avx_f16_only ? 0 : (K < K_blk ? 0 : K0_step0 * srcType.size());
auto pCopyKernelIn = ptr_A + K_offset;
auto pCopyKernelOut = ptr_scartch_a;
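One thread running through these hunks is the f16-on-AVX512 path: when is_avx_f16_only is set, the copy-A kernel widens the A panel from f16 to f32 before the brgemm consumes it, which is why a_dt_sz stays sizeof(ov::float16) while tr_a_dt_sz is sizeof(float). A conceptual sketch of that widening, ignoring blocking, tails and transposition; upconvert_panel is a made-up name, not an OpenVINO API.

#include <openvino/core/type/float16.hpp>
#include <cstddef>
#include <vector>

// Widens a row-major f16 panel to f32, roughly what the copy-A kernel produces
// when is_avx_f16_only is set (the real kernel also handles blocking/transposition).
std::vector<float> upconvert_panel(const ov::float16* src,
                                   std::size_t rows,
                                   std::size_t cols,
                                   std::size_t src_stride /* elements per source row */) {
    std::vector<float> dst(rows * cols);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c)
            dst[r * cols + c] = static_cast<float>(src[r * src_stride + c]);
    return dst;
}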
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes/paged_attn.cpp
@@ -190,6 +190,8 @@ ov::element::Type PagedAttention::getRuntimePrecision() const {
// bf16 should be enabled only when platform supports
if (rtPrecision == ov::element::bf16 && ov::with_cpu_x86_bfloat16()) {
rtPrecision = ov::element::bf16;
+ } else if (rtPrecision == ov::element::f16 && ov::with_cpu_x86_avx512_core_fp16()) {
+ rtPrecision = ov::element::f16;
} else {
rtPrecision = ov::element::f32;
}
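The paged_attn.cpp change extends the existing bf16 capability check with an analogous f16 one. A minimal sketch of the resulting fallback rule; pick_runtime_precision is an illustrative free function, not the node's interface, while the two capability checks are the same ones used in the hunk.

#include <openvino/core/type/element_type.hpp>
#include <openvino/runtime/system_conf.hpp>  // ov::with_cpu_x86_* capability checks

// Keep bf16/f16 only when the host CPU can actually run it, otherwise fall back to f32.
ov::element::Type pick_runtime_precision(ov::element::Type requested) {
    if (requested == ov::element::bf16 && ov::with_cpu_x86_bfloat16())
        return ov::element::bf16;
    if (requested == ov::element::f16 && ov::with_cpu_x86_avx512_core_fp16())
        return ov::element::f16;
    return ov::element::f32;
}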
