Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
add xmx colmajor
Browse files Browse the repository at this point in the history
  • Loading branch information
DDEle committed Aug 28, 2024
1 parent b0efdf4 commit d002978
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 51 deletions.
11 changes: 7 additions & 4 deletions include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,9 @@ class gemm_t<
: reg_layout::tiled;

static constexpr reg_layout reg_layout_b =
is_col_major_b ? reg_layout::transpose_tiled : reg_layout::tiled;

static constexpr reg_layout reg_layout_b_acc =
// fpu
compute_policy::mma_engine == mma_engine::fpu
? (is_gemv ? reg_layout::transpose_tiled : reg_layout::tiled)
Expand Down Expand Up @@ -214,7 +217,7 @@ class gemm_t<
tile_size_y_b,
block_size_x_b,
block_size_y_b,
reg_layout_b>;
reg_layout_b_acc>;
using matB_acc_t = subgroup::tile_t<dtype_mma_b, matB_acc_tile_desc_t>;

public:
Expand Down Expand Up @@ -696,9 +699,9 @@ class gemm_t<
tile_mma::mma(
matAcc, matAcc, matC, matB_acc, matA_acc, i == compute_stages - 1);
} else {
if constexpr (is_col_major_b) {
tile_transpose(matB_acc);
}
if constexpr (
matB_acc_tile_desc_t::register_layout == reg_layout::vnni_tiled)
subgroup::vnni_convert(matB_acc);
tile_mma::mma(matC, matC, matB_acc, matA_acc);
}
if constexpr (enable_periodic_sync) {
Expand Down
40 changes: 19 additions & 21 deletions include/subgroup/tile/impl/op_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -704,36 +704,34 @@ layout_convert(T_dst& dst, T_src& src) {
}
}

template <typename T>
void dump_mat(
T mat,
size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x,
size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) {
#pragma unroll
for (size_t row = 0; row < tile_y; row++) {
#pragma unroll
for (size_t col = 0; col < tile_x; col++) {
sycl::ext::oneapi::experimental::printf(
"%x(%d) ",
int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])),
int(native_type_t<typename T::dtype>(mat.reg[row * tile_x + col])));
}
sycl::ext::oneapi::experimental::printf("\n");
}
sycl::ext::oneapi::experimental::printf("\n ");
}
template <typename T>
void dump_mat_reg(T mat, size_t tile_x, size_t tile_y) {
#pragma unroll
for (size_t row = 0; row < tile_y; row++) {
#pragma unroll
for (size_t col = 0; col < tile_x; col++) {
sycl::ext::oneapi::experimental::printf(
"%d ", (int)(sycl::half)mat[row * tile_x + col]);
const auto&& v = int64_t(
native_type_t<typename T::element_type>(mat[row * tile_x + col]));
(std::is_same<typename T::element_type, int4x2>::value ||
std::is_same<typename T::element_type, int4x8>::value ||
std::is_same<typename T::element_type, uint32_t>::value ||
std::is_same<typename T::element_type, int32_t>::value)
? sycl::ext::oneapi::experimental::printf(
"%08x(%10u) ", int(v), int(v))
: (std::is_same<typename T::element_type, uint64_t>::value ||
std::is_same<typename T::element_type, int64_t>::value)
? sycl::ext::oneapi::experimental::printf("%016llx(%20llu) ", v, v)
: sycl::ext::oneapi::experimental::printf("%3lld ", v);
}
sycl::ext::oneapi::experimental::printf("\n");
}
sycl::ext::oneapi::experimental::printf("\n");
}

/// Debug helper: dumps a tile by forwarding its `reg` storage to dump_mat_reg.
///
/// @param mat    Tile to print; only `mat.reg` is used here.
/// @param tile_x Elements per printed row; defaults to T::tile_size_y when
///               T::reg_transpose is set, otherwise T::tile_size_x.
/// @param tile_y Number of printed rows; defaults to the remaining dimension.
///
/// NOTE(review): another dump_mat with the same signature is visible earlier
/// in this diff view; presumably only one of the two exists in the real
/// header - confirm.
template <typename T>
void dump_mat(
T mat,
size_t tile_x = T::reg_transpose ? T::tile_size_y : T::tile_size_x,
size_t tile_y = T::reg_transpose ? T::tile_size_x : T::tile_size_y) {
// Element formatting is delegated entirely to dump_mat_reg.
dump_mat_reg(mat.reg, tile_x, tile_y);
}
} // namespace gpu::xetla::subgroup
20 changes: 17 additions & 3 deletions include/subgroup/tile/impl/tile_op_functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ struct dequant_int4_weight_t {
constexpr uint32_t block_size_y_b = matB_acc_t::block_size_y;
static constexpr uint32_t pack_ratio = sizeof(typename matB_t::dtype) * 2;

constexpr bool trans_acc =
matB_t::register_layout == reg_layout::transpose_tiled &&
(matB_acc_t::register_layout == reg_layout::tiled ||
matB_acc_t::register_layout == reg_layout::vnni_tiled);

constexpr uint32_t num_block_x = tile_size_x_b / block_size_x_b;
constexpr uint32_t num_block_y = tile_size_y_b / block_size_y_b;
#pragma unroll
Expand Down Expand Up @@ -149,9 +154,18 @@ struct dequant_int4_weight_t {
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) -
int8_t(8);
}
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];
// Scale and write back to matB_acc
if constexpr (trans_acc) {
dst_blk.xetla_select<step, block_size_x_b>(
ii * block_size_x_b + jj) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];

} else {
dst_blk.xetla_select<step, 1>(jj * block_size_y_b + ii) =
cvt_blk_i8.xetla_select<step, 1>(jj * block_size_y_b + ii) *
scale.reg[scale_idx];
}

// sycl::ext::oneapi::experimental::printf(
// "scale[%d] %f \n",
Expand Down
54 changes: 31 additions & 23 deletions tests/integration/gemv/int4/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ template <typename scalar_t>
class test_col_major_1 {
public:
// Extract the parameters required by different test cases
static constexpr size_t mat_m = 1;
static constexpr size_t mat_m = 4096;
static constexpr size_t mat_n = 4096;
static constexpr size_t mat_k = 4096;
static constexpr size_t wg_m = 1;
static constexpr size_t wg_n = 1;
static constexpr size_t sg_m = 1;
static constexpr size_t sg_n = 1;
static constexpr size_t sg_k = 512 / sg_m;
static constexpr size_t wg_m = 64;
static constexpr size_t wg_n = 32;
static constexpr size_t sg_m = 16;
static constexpr size_t sg_n = 8;
static constexpr size_t sg_k = 32;
static constexpr size_t dequant_s = 128;
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;

static constexpr size_t local_kslicing = 1;
static constexpr size_t global_kslicing = 1;
static constexpr mem_layout layout_a = mem_layout::row_major;
static constexpr mem_layout layout_b = mem_layout::col_major;
static constexpr mma_engine mma_eng = mma_engine::fpu;
static constexpr gpu_arch arch = gpu_arch::XeLpg;
static constexpr mma_engine mma_eng = mma_engine::xmx;
static constexpr gpu_arch arch = gpu_arch::XeHpg;
using data_type_a = scalar_t;
using data_type_b = int4x8;
using data_type_c = scalar_t;
Expand Down Expand Up @@ -110,14 +110,17 @@ int gemm_result_validate(
bool result = buff_cmp::xetla_buff_cmp(data, other, "gemv validation");

#ifdef UT_DEBUG
// for (uint32_t i = 0; i < m; i++) {
// for (uint32_t j = 0; j < n; j++) {
// std::cout << float(sycl::half(C[i * n + j])) << " ";
// }
// std::cout << std::endl;
// }
if (m * n <= 4096) {
std::cout << "result:\n";
for (uint32_t i = 0; i < m; i++) {
for (uint32_t j = 0; j < n; j++) {
std::cout << float(sycl::half(C[i * n + j])) << " ";
}
std::cout << "\n";
}
}
#endif
std::cout << (!result ? "FAILED\n" : "PASSED\n");
std::cout << (!result ? "FAILED" : "PASSED") << std::endl;
return result ? 0 : 1;
}

Expand Down Expand Up @@ -187,12 +190,15 @@ std::vector<data_type_acc_in> dequantize_weight(
}
}
#ifdef UT_DEBUG
// for (uint32_t i = 0; i < matrix_n; i++) {
// for (uint32_t j = 0; j < matrix_k; j++) {
// std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
// }
// std::cout << std::endl;
// }
if (matrix_n * matrix_k <= 4096) {
std::cout << "dequantize_weight:\n";
for (uint32_t i = 0; i < matrix_n; i++) {
for (uint32_t j = 0; j < matrix_k; j++) {
std::cout << float(sycl::half(b_out[i * matrix_k + j])) << " ";
}
std::cout << std::endl;
}
}
#endif
return b_out;
}
Expand Down Expand Up @@ -387,12 +393,14 @@ void dequantize_gemv_run(int iter) {
if constexpr (std::is_same_v<int4x2, data_type_b>) {
B_h[i] = random_uint8();
#ifdef UT_DEBUG
B_h[i] = 0x77;
B_h[i] = ((7 + i) % 15 + 1) * 0x11;
if (i >= size_b)
B_h[i] = -1;
#endif
} else if constexpr (std::is_same_v<int4x8, data_type_b>) {
B_h[i] = random_uint32();
#ifdef UT_DEBUG
B_h[i] = 0x77777777;
B_h[i] = ((7 + i) % 15 + 1) * 0x11111111;
#endif
}
}
Expand Down

0 comments on commit d002978

Please sign in to comment.