diff --git a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp index 7e566038a..ed49de72c 100644 --- a/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp +++ b/include/experimental/group/gemm/impl/int4_dequantize_xe.hpp @@ -157,6 +157,9 @@ class gemm_t< : reg_layout::tiled; static constexpr reg_layout reg_layout_b = + is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled; + + static constexpr reg_layout reg_layout_b_acc = // fpu compute_policy::mma_engine == mma_engine::fpu ? (is_gemv ? reg_layout::transpose_tiled : reg_layout::tiled) @@ -214,7 +217,7 @@ class gemm_t< tile_size_y_b, block_size_x_b, block_size_y_b, - reg_layout_b>; + reg_layout_b_acc>; using matB_acc_t = subgroup::tile_t; public: @@ -696,9 +699,9 @@ class gemm_t< tile_mma::mma( matAcc, matAcc, matC, matB_acc, matA_acc, i == compute_stages - 1); } else { - if constexpr (is_col_major_b) { - tile_transpose(matB_acc); - } + if constexpr ( + matB_acc_tile_desc_t::register_layout == reg_layout::vnni_tiled) + subgroup::vnni_convert(matB_acc); tile_mma::mma(matC, matC, matB_acc, matA_acc); } if constexpr (enable_periodic_sync) { diff --git a/include/subgroup/tile/impl/op_function.hpp b/include/subgroup/tile/impl/op_function.hpp index 44d2f6569..9ac7e8e15 100644 --- a/include/subgroup/tile/impl/op_function.hpp +++ b/include/subgroup/tile/impl/op_function.hpp @@ -704,36 +704,34 @@ layout_convert(T_dst& dst, T_src& src) { } } -template -void dump_mat( - T mat, - size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x, - size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) { -#pragma unroll - for (size_t row = 0; row < tile_y; row++) { -#pragma unroll - for (size_t col = 0; col < tile_x; col++) { - sycl::ext::oneapi::experimental::printf( - "%x(%d) ", - int(native_type_t(mat.reg[row * tile_x + col])), - int(native_type_t(mat.reg[row * tile_x + col]))); - } - sycl::ext::oneapi::experimental::printf("\n"); - } - sycl::ext::oneapi::experimental::printf("\n "); -} template void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) { #pragma unroll for (size_t row = 0; row < tile_y; row++) { #pragma unroll for (size_t col = 0; col < tile_x; col++) { - sycl::ext::oneapi::experimental::printf( - "%d ", (int)(sycl::half)mat[row * tile_x + col]); + const auto&& v = int64_t( + native_type_t(mat[row * tile_x + col])); + (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) + ? sycl::ext::oneapi::experimental::printf( + "%08x(%10u) ", int(v), int(v)) + : (std::is_same::value || + std::is_same::value) + ? sycl::ext::oneapi::experimental::printf("%016llx(%20llu) ", v, v) + : sycl::ext::oneapi::experimental::printf("%3lld ", v); } sycl::ext::oneapi::experimental::printf("\n"); } sycl::ext::oneapi::experimental::printf("\n"); } - +template +void dump_mat( + T mat, + size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x, + size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) { + dump_mat_reg(mat.reg, tile_x, tile_y); +} } // namespace gpu::xetla::subgroup diff --git a/include/subgroup/tile/impl/tile_op_functor.hpp b/include/subgroup/tile/impl/tile_op_functor.hpp index ab1f0038e..7cd83c151 100644 --- a/include/subgroup/tile/impl/tile_op_functor.hpp +++ b/include/subgroup/tile/impl/tile_op_functor.hpp @@ -83,6 +83,11 @@ struct dequant_int4_weight_t { constexpr uint32_t block_size_y_b = matB_acc_t::block_size_y; static constexpr uint32_t pack_ratio = sizeof(typename matB_t::dtype) * 2; + constexpr bool trans_acc = + matB_t::register_layout == reg_layout::transpose_tiled && + (matB_acc_t::register_layout == reg_layout::tiled || + matB_acc_t::register_layout == reg_layout::vnni_tiled); + constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b; constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b; #pragma unroll @@ -149,9 +154,18 @@ struct dequant_int4_weight_t { cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) - int8_t(8); } - dst_blk.xetla_select(jj * block_size_y_b + ii) = - cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * - scale.reg[scale_idx]; + // Scale and write back to matB_acc + if constexpr (trans_acc) { + dst_blk.xetla_select( + ii * block_size_x_b + jj) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * + scale.reg[scale_idx]; + + } else { + dst_blk.xetla_select(jj * block_size_y_b + ii) = + cvt_blk_i8.xetla_select(jj * block_size_y_b + ii) * + scale.reg[scale_idx]; + } // sycl::ext::oneapi::experimental::printf( // "scale[%d] %f \n", diff --git a/tests/integration/gemv/int4/main.cpp b/tests/integration/gemv/int4/main.cpp index 93803441d..b9162e3f2 100644 --- a/tests/integration/gemv/int4/main.cpp +++ b/tests/integration/gemv/int4/main.cpp @@ -31,14 +31,14 @@ template class test_col_major_1 { public: // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; + static constexpr size_t mat_m = 4096; static constexpr size_t mat_n = 4096; static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 1; - static constexpr size_t wg_n = 1; - static constexpr size_t sg_m = 1; - static constexpr size_t sg_n = 1; - static constexpr size_t sg_k = 512 / sg_m; + static constexpr size_t wg_m = 64; + static constexpr size_t wg_n = 32; + static constexpr size_t sg_m = 16; + static constexpr size_t sg_n = 8; + static constexpr size_t sg_k = 32; static constexpr size_t dequant_s = 128; static constexpr quant_mode quant_mode = quant_mode::I4_SYM; @@ -46,8 +46,8 @@ class test_col_major_1 { static constexpr size_t global_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::col_major; - static constexpr mma_engine mma_eng = mma_engine::fpu; - static constexpr gpu_arch arch = gpu_arch::XeLpg; + static constexpr mma_engine mma_eng = mma_engine::xmx; + static constexpr gpu_arch arch = gpu_arch::XeHpg; using data_type_a = scalar_t; using data_type_b = int4x8; using data_type_c = scalar_t; @@ -110,14 +110,17 @@ int gemm_result_validate( bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation"); #ifdef UT_DEBUG - // for (uint32_t i = 0; i < m; i++) { - // for (uint32_t j = 0; j < n; j++) { - // std::cout << float(sycl::half(C[i * n + j])) << " "; - // } - // std::cout << std::endl; - // } + if (m * n <= 4096) { + std::cout << "result:\n"; + for (uint32_t i = 0; i < m; i++) { + for (uint32_t j = 0; j < n; j++) { + std::cout << float(sycl::half(C[i * n + j])) << " "; + } + std::cout << "\n"; + } + } #endif - std::cout << (!result ? "FAILED\n" : "PASSED\n"); + std::cout << (!result ? "FAILED" : "PASSED") << std::endl; return result ? 0 : 1; } @@ -187,12 +190,15 @@ std::vector dequantize_weight( } } #ifdef UT_DEBUG - // for (uint32_t i = 0; i < matrix_n; i++) { - // for (uint32_t j = 0; j < matrix_k; j++) { - // std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; - // } - // std::cout << std::endl; - // } + if (matrix_n * matrix_k <= 4096) { + std::cout << "dequantize_weight:\n"; + for (uint32_t i = 0; i < matrix_n; i++) { + for (uint32_t j = 0; j < matrix_k; j++) { + std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " "; + } + std::cout << std::endl; + } + } #endif return b_out; } @@ -387,12 +393,14 @@ void dequantize_gemv_run(int iter) { if constexpr (std::is_same_v) { B_h[i] = random_uint8(); #ifdef UT_DEBUG - B_h[i] = 0x77; + B_h[i] = ((7 + i) % 15 + 1) * 0x11; + if (i >= size_b) + B_h[i] = -1; #endif } else if constexpr (std::is_same_v) { B_h[i] = random_uint32(); #ifdef UT_DEBUG - B_h[i] = 0x77777777; + B_h[i] = ((7 + i) % 15 + 1) * 0x11111111; #endif } }