Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

XeTLA XMX colmajor #304

Open
wants to merge 3 commits into
base: xetla
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,12 @@ class gemm_t<
: is_vnni_tiled_a ? reg_layout::vnni_tiled
: reg_layout::tiled;

// reg_layout of the load result
static constexpr reg_layout reg_layout_b =
is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled;

// reg_layout required by mma
static constexpr reg_layout reg_layout_b_acc =
// fpu
compute_policy::mma_engine == mma_engine::fpu
? (is_gemv ? reg_layout::transpose_tiled : reg_layout::tiled)
Expand Down Expand Up @@ -214,7 +219,7 @@ class gemm_t<
tile_size_y_b,
block_size_x_b,
block_size_y_b,
reg_layout_b>;
reg_layout_b_acc>;
using matB_acc_t = subgroup::tile_t<dtype_mma_b, matB_acc_tile_desc_t>;

public:
Expand Down Expand Up @@ -281,7 +286,9 @@ class gemm_t<
mem_desc_scale_t,
scale_tile_desc_t,
subgroup::msg_type_v<scale_tile_desc_t, mem_desc_scale_t>,
arch_tag>;
(tile_size_x_b > 1 && arch_tag == gpu_arch::XeHpc) // TODO(Yi): PVC 2d WA
? gpu_arch::XeHpg
: arch_tag>;

// compress int4 along N dimensions
using zero_pt_tile_desc_t = subgroup::tile_desc_t<
Expand Down Expand Up @@ -629,9 +636,10 @@ class gemm_t<
matA_acc,
i == args.inner_loop_count - 1);
} else {
if constexpr (is_col_major_b) {
tile_transpose(matB_acc);
}
// The result of dequantize should always be (plain) tiled
if constexpr (
matB_acc_tile_desc_t::register_layout == reg_layout::vnni_tiled)
subgroup::vnni_convert(matB_acc);
tile_mma::mma(matC, matC, matB_acc, matA_acc);
}
if constexpr (enable_periodic_sync) {
Expand Down Expand Up @@ -696,9 +704,10 @@ class gemm_t<
tile_mma::mma(
matAcc, matAcc, matC, matB_acc, matA_acc, i == compute_stages - 1);
} else {
if constexpr (is_col_major_b) {
tile_transpose(matB_acc);
}
// The result of dequantize should always be (plain) tiled
if constexpr (
matB_acc_tile_desc_t::register_layout == reg_layout::vnni_tiled)
subgroup::vnni_convert(matB_acc);
tile_mma::mma(matC, matC, matB_acc, matA_acc);
}
if constexpr (enable_periodic_sync) {
Expand Down
43 changes: 22 additions & 21 deletions include/subgroup/tile/impl/op_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,36 +704,37 @@ layout_convert(T_dst& dst, T_src& src) {
}
}

// Debug helper: print a tile's register contents, one printed row per
// logical row. Each element is shown as hex followed by the same value in
// decimal parentheses.
// Defaults derive the printed extents from the tile type; when the tile is
// register-transposed, the x/y extents are swapped so the dump matches the
// in-register layout.
// NOTE(review): indexing assumes mat.reg is laid out row-major as
// tile_y x tile_x — confirm against the tile_t definition.
template <typename T>
void dump_mat(
    T mat,
    size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x,
    size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) {
#pragma unroll
  for (size_t row = 0; row < tile_y; row++) {
#pragma unroll
    for (size_t col = 0; col < tile_x; col++) {
      // Cast through the tile's native dtype so the raw bit pattern is
      // printed consistently regardless of the element type.
      sycl::ext::oneapi::experimental::printf(
          "%x(%d) ",
          int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])),
          int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])));
    }
    sycl::ext::oneapi::experimental::printf("\n");
  }
  sycl::ext::oneapi::experimental::printf("\n ");
}
// Debug helper: print a raw register vector as tile_y rows by tile_x
// columns. 32-bit integer containers (including packed int4x2/int4x8) are
// printed as 8 hex digits plus unsigned decimal; 64-bit integers as 16 hex
// digits plus decimal; everything else as a plain decimal after widening
// to int64_t.
// NOTE(review): the first printf ("%d ", (int)(sycl::half)...) appears to
// be a leftover deleted line rendered by the diff view — as written every
// element would be printed twice; confirm against the merged file.
template <typename T>
void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) {
#pragma unroll
  for (size_t row = 0; row < tile_y; row++) {
#pragma unroll
    for (size_t col = 0; col < tile_x; col++) {
      sycl::ext::oneapi::experimental::printf(
          "%d ", (int)(sycl::half)mat[row * tile_x + col]);
      // Widen through the element's native type so packed int4 containers
      // dump their full 32-bit payload rather than a single nibble.
      const auto&& v = int64_t(
          native_type_t<typename T::element_type>(mat[row * tile_x + col]));
      constexpr bool is_int32 =
          (std::is_same<typename T::element_type, int4x2>::value ||
           std::is_same<typename T::element_type, int4x8>::value ||
           std::is_same<typename T::element_type, uint32_t>::value ||
           std::is_same<typename T::element_type, int32_t>::value);
      constexpr bool is_int64 =
          (std::is_same<typename T::element_type, uint64_t>::value ||
           std::is_same<typename T::element_type, int64_t>::value);
      // Pick the format at compile time via the constexpr flags; the
      // ternary chain evaluates exactly one printf per element.
      is_int32 ? sycl::ext::oneapi::experimental::printf(
                     "%08x(%10u) ", int(v), int(v))
               : is_int64
          ? sycl::ext::oneapi::experimental::printf("%016llx(%20llu) ", v, v)
          : sycl::ext::oneapi::experimental::printf("%3lld ", v);
    }
    sycl::ext::oneapi::experimental::printf("\n");
  }
  sycl::ext::oneapi::experimental::printf("\n");
}

// Convenience overload: dump a tile object by forwarding its register
// storage to dump_mat_reg. Extents default to the tile's compile-time
// sizes, swapped when the tile is register-transposed so the printout
// matches the in-register layout.
template <typename T>
void dump_mat(
    T mat,
    size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x,
    size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) {
  dump_mat_reg(mat.reg, tile_x, tile_y);
}
} // namespace gpu::xetla::subgroup
21 changes: 18 additions & 3 deletions include/subgroup/tile/impl/tile_op_functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ struct dequant_int4_weight_t {
constexpr uint32_t block_size_y_b = matB_acc_t::block_size_y;
static constexpr uint32_t pack_ratio = sizeof(typename matB_t::dtype) * 2;

// Whether the dequantize result must be transposed before storing to matB_acc
constexpr bool trans_acc =
matB_t::register_layout == reg_layout::transpose_tiled &&
(matB_acc_t::register_layout == reg_layout::tiled ||
matB_acc_t::register_layout == reg_layout::vnni_tiled);

constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b;
constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b;
#pragma unroll
Expand Down Expand Up @@ -149,9 +155,18 @@ struct dequant_int4_weight_t {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
int8_t(8);
}
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];
// Scale and write back to matB_acc
if constexpr (trans_acc) {
dst_blk.xetla_select<step, block_size_x_b>(
ii * block_size_x_b + jj) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];

} else {
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];
}

// sycl::ext::oneapi::experimental::printf(
// "scale[%d] %f \n",
Expand Down
4 changes: 4 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
# Select the GPU architecture the tests are built for. The value is
# normalized to upper case and exported as a compile definition
# (e.g. TEST_GPU_ARCH_XE_LPG) which tests/utils/utils.hpp maps to a
# gpu_arch constant.
set(TEST_GPU_ARCH "xe_lpg" CACHE STRING "Set gpu_arch to test. Options: xe_lpg,xe_hpg,xe_hpc")
string(TOUPPER "${TEST_GPU_ARCH}" TEST_GPU_ARCH)
add_compile_definitions("TEST_GPU_ARCH_${TEST_GPU_ARCH}")

add_subdirectory(./integration)
add_subdirectory(./unit)
2 changes: 1 addition & 1 deletion tests/integration/fmha/fmha.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ void fmha_run_(const test_params_t& p, uint32_t iter, uint32_t warmup) {
using fmha_forward_op_t = gpu::xetla::fmha::fmha_forward_t<
policy_t,
FMHA_T,
gpu_arch::XeLpg,
TEST_GPU_ARCH,
false,
kUseBias,
false,
Expand Down
65 changes: 36 additions & 29 deletions tests/integration/gemv/int4/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <utils/utils.hpp>
#include "xetla.hpp"
// #define UT_DEBUG
#define UT_DEBUG
using namespace gpu::xetla;
using namespace gpu::xetla::group;
// The number of times the kernel is executed
Expand All @@ -31,23 +31,23 @@ template <typename scalar_t>
class test_col_major_1 {
public:
// Extract the parameters required by different test cases
static constexpr size_t mat_m = 1;
static constexpr size_t mat_m = 4096;
static constexpr size_t mat_n = 4096;
static constexpr size_t mat_k = 4096;
static constexpr size_t wg_m = 1;
static constexpr size_t wg_n = 1;
static constexpr size_t sg_m = 1;
static constexpr size_t sg_n = 1;
static constexpr size_t sg_k = 512 / sg_m;
static constexpr size_t wg_m = 64;
static constexpr size_t wg_n = 64;
static constexpr size_t sg_m = 16;
static constexpr size_t sg_n = 16;
static constexpr size_t sg_k = 32;
static constexpr size_t dequant_s = 128;
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;

static constexpr size_t local_kslicing = 1;
static constexpr size_t global_kslicing = 1;
static constexpr mem_layout layout_a = mem_layout::row_major;
static constexpr mem_layout layout_b = mem_layout::col_major;
static constexpr mma_engine mma_eng = mma_engine::fpu;
static constexpr gpu_arch arch = gpu_arch::XeLpg;
static constexpr mma_engine mma_eng =
arch_has_xmx<TEST_GPU_ARCH> ? mma_engine::xmx : mma_engine::fpu;
using data_type_a = scalar_t;
using data_type_b = int4x8;
using data_type_c = scalar_t;
Expand All @@ -72,7 +72,6 @@ class test_col_major_2 {
static constexpr mem_layout layout_a = mem_layout::row_major;
static constexpr mem_layout layout_b = mem_layout::col_major;
static constexpr mma_engine mma_eng = mma_engine::fpu;
static constexpr gpu_arch arch = gpu_arch::XeLpg;
using data_type_a = fp16;
using data_type_b = int4x8;
using data_type_c = fp16;
Expand Down Expand Up @@ -110,14 +109,17 @@ int gemm_result_validate(
bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation");

#ifdef UT_DEBUG
// for (uint32_t i = 0; i < m; i++) {
// for (uint32_t j = 0; j < n; j++) {
// std::cout << float(sycl::half(C[i * n + j])) << " ";
// }
// std::cout << std::endl;
// }
if (m * n <= 4096) {
std::cout << "result:\n";
for (uint32_t i = 0; i < m; i++) {
for (uint32_t j = 0; j < n; j++) {
std::cout << float(sycl::half(C[i * n + j])) << " ";
}
std::cout << "\n";
}
}
#endif
std::cout << (!result ? "FAILED\n" : "PASSED\n");
std::cout << (!result ? "FAILED" : "PASSED") << std::endl;
return result ? 0 : 1;
}

Expand Down Expand Up @@ -187,12 +189,15 @@ std::vector<data_type_acc_in> dequantize_weight(
}
}
#ifdef UT_DEBUG
// for (uint32_t i = 0; i < matrix_n; i++) {
// for (uint32_t j = 0; j < matrix_k; j++) {
// std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
// }
// std::cout << std::endl;
// }
if (matrix_n * matrix_k <= 4096) {
std::cout << "dequantize_weight:\n";
for (uint32_t i = 0; i < matrix_n; i++) {
for (uint32_t j = 0; j < matrix_k; j++) {
std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
}
std::cout << std::endl;
}
}
#endif
return b_out;
}
Expand Down Expand Up @@ -297,22 +302,22 @@ void dequantize_gemv_run(int iter) {
data_type_zero_pt,
quant_info,
Test::mma_eng,
Test::arch>;
TEST_GPU_ARCH>;

using gemm_t = xetla::group::
gemm_t<compute_policy, tile_shape, mem_desc_a_t, mem_desc_b_t>;

using bias_op_t =
gpu::xetla::subgroup::bias_add_op_t<mem_desc_bias_t, Test::arch>;
gpu::xetla::subgroup::bias_add_op_t<mem_desc_bias_t, TEST_GPU_ARCH>;

using tile_op_t = gpu::xetla::subgroup::chained_tile_op_t<bias_op_t>;

using epilogue_t = xetla::group::epilogue_t<
xetla::group::epilogue_policy_tile_op<tile_op_t, Test::arch>,
xetla::group::epilogue_policy_tile_op<tile_op_t, TEST_GPU_ARCH>,
tile_shape,
mem_desc_c_t>;

using group_swizzle = xetla::kernel::group_swizzle_default<Test::arch>;
using group_swizzle = xetla::kernel::group_swizzle_default<TEST_GPU_ARCH>;

using gemm_op_t = xetla::kernel::gemm_universal_t<
gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing<
Expand Down Expand Up @@ -387,12 +392,14 @@ void dequantize_gemv_run(int iter) {
if constexpr (std::is_same_v<int4x2, data_type_b>) {
B_h[i] = random_uint8();
#ifdef UT_DEBUG
B_h[i] = 0x77;
B_h[i] = ((7 + i) % 15 + 1) * 0x11;
if (i >= size_b)
B_h[i] = -1;
#endif
} else if constexpr (std::is_same_v<int4x8, data_type_b>) {
B_h[i] = random_uint32();
#ifdef UT_DEBUG
B_h[i] = 0x77777777;
B_h[i] = ((7 + i) % 15 + 1) * 0x11111111;
#endif
}
}
Expand Down
10 changes: 10 additions & 0 deletions tests/utils/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@
#include "execution.hpp"
#include "gemm_gen.hpp"
#include "profiling.hpp"

// Map the TEST_GPU_ARCH_* compile definition provided by the build system
// (tests/CMakeLists.txt) to a gpu_arch constant usable throughout the
// tests. inline constexpr keeps this header-safe (single definition).
#if defined(TEST_GPU_ARCH_XE_LPG)
inline constexpr gpu_arch TEST_GPU_ARCH = gpu_arch::XeLpg;
#elif defined(TEST_GPU_ARCH_XE_HPG)
inline constexpr gpu_arch TEST_GPU_ARCH = gpu_arch::XeHpg;
#elif defined(TEST_GPU_ARCH_XE_HPC)
inline constexpr gpu_arch TEST_GPU_ARCH = gpu_arch::XeHpc;
#else
// Hard build failure when an unsupported TEST_GPU_ARCH value is configured.
static_assert(false, "TEST_GPU_ARCH not defined");
#endif
Loading