Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
pvc wa for 2d-load
Browse files Browse the repository at this point in the history
  • Loading branch information
DDEle committed Aug 28, 2024
1 parent d002978 commit bfaa4f3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 3 additions & 1 deletion include/experimental/group/gemm/impl/int4_dequantize_xe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ class gemm_t<
mem_desc_scale_t,
scale_tile_desc_t,
subgroup::msg_type_v<scale_tile_desc_t, mem_desc_scale_t>,
arch_tag>;
(tile_size_x_b > 1 && arch_tag == gpu_arch::XeHpc) // TODO(Yi): PVC 2d WA
? gpu_arch::XeHpg
: arch_tag>;

// compress int4 along N dimensions
using zero_pt_tile_desc_t = subgroup::tile_desc_t<
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/gemv/int4/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include <utils/utils.hpp>
#include "xetla.hpp"
// #define UT_DEBUG
#define UT_DEBUG
using namespace gpu::xetla;
using namespace gpu::xetla::group;
// The number of times the kernel is executed
Expand All @@ -35,9 +35,9 @@ class test_col_major_1 {
static constexpr size_t mat_n = 4096;
static constexpr size_t mat_k = 4096;
static constexpr size_t wg_m = 64;
static constexpr size_t wg_n = 32;
static constexpr size_t wg_n = 64;
static constexpr size_t sg_m = 16;
static constexpr size_t sg_n = 8;
static constexpr size_t sg_n = 16;
static constexpr size_t sg_k = 32;
static constexpr size_t dequant_s = 128;
static constexpr quant_mode quant_mode = quant_mode::I4_SYM;
Expand All @@ -47,7 +47,7 @@ class test_col_major_1 {
static constexpr mem_layout layout_a = mem_layout::row_major;
static constexpr mem_layout layout_b = mem_layout::col_major;
static constexpr mma_engine mma_eng = mma_engine::xmx;
static constexpr gpu_arch arch = gpu_arch::XeHpg;
static constexpr gpu_arch arch = gpu_arch::XeHpc;
using data_type_a = scalar_t;
using data_type_b = int4x8;
using data_type_c = scalar_t;
Expand Down

0 comments on commit bfaa4f3

Please sign in to comment.