Skip to content

Commit

Permalink
[Snippets][CPU] Applied Ivan comments
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Nov 27, 2024
1 parent 4a524b8 commit 96842c1
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 86 deletions.
3 changes: 2 additions & 1 deletion src/common/snippets/src/runtime_configurator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ void RuntimeConfigurator::update_loop_info(const lowered::LinearIRCPtr& linear_i

void RuntimeConfigurator::update_buffer_scratchpad_size(const lowered::LinearIRCPtr& linear_ir) const {
const auto& loop_manager = linear_ir->get_loop_manager();
m_config->buffer_scratchpad_size = linear_ir->get_static_buffer_scratchpad_size();
// Align initial buffer scratchpad size with cache line size
m_config->buffer_scratchpad_size = utils::rnd_up(linear_ir->get_static_buffer_scratchpad_size(), lowered::pass::SolveBufferMemory::byte_alignment);

auto is_not_executed = [&loop_manager](const lowered::ExpressionPtr& buffer_expr) {
const auto& loop_ids = buffer_expr->get_loop_ids();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,32 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
if (in.size() > 2)
mem_ptrs_idxs.emplace_back(in[2]);

if (std::dynamic_pointer_cast<BrgemmAMXKernelExecutor>(m_kernel_executor))
emit_call<BrgemmAMXKernelExecutor>(mem_ptrs_idxs);
else if (std::dynamic_pointer_cast<BrgemmKernelExecutor>(m_kernel_executor))
emit_call<BrgemmKernelExecutor>(mem_ptrs_idxs);
else
OV_CPU_JIT_EMITTER_THROW("uknown execuor type");
}

template<typename T,
typename std::enable_if<std::is_base_of<BrgemmBaseKernelExecutor, T>::value, bool>::type>
void jit_brgemm_emitter::emit_call(const std::vector<size_t>& mem_ptrs_idxs) const {
EmitABIRegSpills spill(h);
spill.preamble();

h->mov(h->rbp, m_is_with_amx ? reinterpret_cast<uint64_t>(BrgemmAMXKernelExecutor::execute)
: reinterpret_cast<uint64_t>(BrgemmKernelExecutor::execute));
auto reserved_stack_size = m_is_with_amx ? sizeof(BrgemmAMXKernelExecutor::call_args) : sizeof(BrgemmKernelExecutor::call_args);
h->mov(h->rbp, reinterpret_cast<uint64_t>(T::execute));
auto reserved_stack_size = sizeof(typename T::call_args);
// Reserve memory on the stack
h->sub(h->rsp, reserved_stack_size);

const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value<size_t>);
Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64();

std::vector<size_t> brgemm_args_offsets;
if (m_is_with_amx) {
brgemm_args_offsets = {GET_OFF_BRGEMM_AMX_ARGS(A), GET_OFF_BRGEMM_AMX_ARGS(B), GET_OFF_BRGEMM_AMX_ARGS(C), GET_OFF_BRGEMM_AMX_ARGS(scratch)};
} else {
brgemm_args_offsets = {GET_OFF_BRGEMM_ARGS(A), GET_OFF_BRGEMM_ARGS(B), GET_OFF_BRGEMM_ARGS(C), GET_OFF_BRGEMM_ARGS(scratch)};
}
#define GET_OFF_CALL_ARGS(field) offsetof(typename T::call_args, field)
const std::vector<size_t> brgemm_args_offsets = { GET_OFF_CALL_ARGS(A), GET_OFF_CALL_ARGS(B), GET_OFF_CALL_ARGS(C), GET_OFF_CALL_ARGS(scratch) };
#undef GET_OFF_CALL_ARGS

const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs);
for (size_t i = 0; i < mem_ptrs.size(); i++) {
if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]))
Expand All @@ -120,7 +128,7 @@ void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vec
h->mov(h->qword[h->rsp + brgemm_args_offsets.back()], reinterpret_cast<uintptr_t>(nullptr));

// abi_param1 always contains jit_snippets_call_args which has amx tile config for each thread
if (m_is_with_amx) {
if (std::is_same<T, BrgemmAMXKernelExecutor>()) {
h->lea(h->r10, h->ptr[abi_param1 + GET_OFF(amx_tile_config)]);
h->mov(h->qword[h->rsp + GET_OFF_BRGEMM_AMX_ARGS(amx_tile_config)], h->r10);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ class jit_brgemm_emitter : public jit_emitter {
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

// Emits the code that calls T::execute for the given kernel executor type T.
// The enable_if constraint only admits T that is BrgemmBaseKernelExecutor or a class derived from it.
// mem_ptrs_idxs: indices of the registers that hold the memory pointers passed to the kernel.
template <typename T,
typename std::enable_if<std::is_base_of<BrgemmBaseKernelExecutor, T>::value, bool>::type = true>
void emit_call(const std::vector<size_t>& mem_ptrs_idxs) const;

// Note: offsets order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if offset is calculated in runtime
std::vector<size_t> m_memory_offsets{};
// Note: cluster ids order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if there is no buffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace intel_cpu {

BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype,
bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa)
: BrgemmBaseKernelConfig(std::make_shared<StaticParams>(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) {
: BrgemmBaseKernelConfig(), m_static_params(std::make_shared<StaticParams>(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) {
m_hash = compute_hash();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,17 @@
namespace ov {
namespace intel_cpu {

struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig, BrgemmBaseKernelConfig {
struct BrgemmKernelConfig : public BrgemmBaseKernelConfig {
public:
BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype,
bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa);
BrgemmKernelConfig() = delete;

std::unique_ptr<GenericConfig> get_clone_ptr() const override {
return std::unique_ptr<BrgemmKernelConfig>( new BrgemmKernelConfig(*this));
std::unique_ptr<snippets::KernelExecutorBase::GenericConfig> get_clone_ptr() const override {
return std::unique_ptr<BrgemmKernelConfig>(new BrgemmKernelConfig(*this));
}

bool is_completed() const override { return BrgemmBaseKernelConfig::is_completed(); }
size_t hash() const override { return BrgemmBaseKernelConfig::hash(); }

bool is_with_comp() const { return std::static_pointer_cast<StaticParams>(m_static_params)->is_with_comp; }

#ifdef SNIPPETS_DEBUG_CAPS
std::string to_string() const override { return BrgemmBaseKernelConfig::to_string(); }
#endif
bool is_with_comp() const { return m_static_params->is_with_comp; }

private:
struct StaticParams : StaticBaseParams {
Expand All @@ -46,6 +39,10 @@ struct BrgemmKernelConfig : public snippets::KernelExecutorBase::GenericConfig,

const size_t m_hash {0};
};

// Exposes this config's static parameters through the StaticBaseParams base type.
// Returns by value, so every call makes a copy of the shared_ptr.
std::shared_ptr<StaticBaseParams> get_static_params() const override { return m_static_params; }

std::shared_ptr<StaticParams> m_static_params {nullptr};
};

struct BrgemmCompiledKernel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace ov {
namespace intel_cpu {

BrgemmAMXKernelConfig::BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa)
: BrgemmBaseKernelConfig(std::make_shared<StaticParams>(in0_dtype, in1_dtype, primitive_isa)) {
: BrgemmBaseKernelConfig(), m_static_params(std::make_shared<StaticParams>(in0_dtype, in1_dtype, primitive_isa)) {
m_hash = compute_hash();
}

Expand All @@ -40,6 +40,10 @@ size_t BrgemmAMXKernelConfig::StaticParams::compute_hash() const {
return hash_combine(seed, vnni_factor);
}

// Reports whether K leaves a positive remainder when divided by the VNNI factor.
bool BrgemmAMXKernelConfig::need_copy_a(dnnl_dim_t K) const {
    const auto remainder = K % get_vnni_factor();
    return remainder > 0;
}

#ifdef SNIPPETS_DEBUG_CAPS
std::string BrgemmAMXKernelConfig::StaticParams::to_string() const {
std::stringstream ss;
Expand Down Expand Up @@ -73,13 +77,15 @@ std::shared_ptr<BrgemmAMXCompiledKernel> BrgemmAMXKernelExecutor::compile_kernel
}

if (K_tail != 0) {
K_tail = ov::snippets::utils::rnd_up(K_tail, config.get_vnni_factor());

const auto copy_A_src_stride = config.get_LDA() * dnnl_data_type_size(config.get_dt_in0());
const auto LDA = config.get_inner_K_blk();
auto LDA = config.get_LDA();
if (config.need_copy_a(K_tail)) {
const auto copy_A_src_stride = LDA * dnnl_data_type_size(config.get_dt_in0());
K_tail = ov::snippets::utils::rnd_up(K_tail, config.get_vnni_factor());
LDA = K_tail;

create_brgemm_copy_a_kernel(compiled_kernel->brgemm_copy_a_kernel, config.get_isa(), config.get_dt_in0(),
config.get_K(), config.get_inner_K_blk(), K_tail, copy_A_src_stride, LDA);
create_brgemm_copy_a_kernel(compiled_kernel->brgemm_copy_a_kernel, config.get_isa(), config.get_dt_in0(),
config.get_K(), config.get_inner_K_blk(), K_tail, copy_A_src_stride, LDA);
}

create_brgemm_kernel(compiled_kernel->brgemm_kernel_k_tail, config.get_dt_in0(), config.get_dt_in1(), config.get_isa(),
config.get_M(), config.get_N(), K_tail, LDA, config.get_LDB(), config.get_LDC(), beta,
Expand Down Expand Up @@ -155,9 +161,9 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c
const auto& config = static_cast<const BrgemmAMXKernelConfig&>(executor->get_config());
OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr compiler kernel or invalid config");

const uint8_t* src_ptr = reinterpret_cast<const uint8_t*>(args->A);
const uint8_t* wei_ptr = reinterpret_cast<const uint8_t*>(args->B);
uint8_t* scratch = reinterpret_cast<uint8_t*>(args->scratch);
const auto* src_ptr = args->A;
const auto* wei_ptr = args->B;
auto* scratch = args->scratch;

const auto K_tail = config.get_K() % config.get_inner_K_blk();
const auto K_body = config.get_K() - K_tail;
Expand All @@ -171,12 +177,15 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c
}

if (K_tail != 0) {
uint8_t* tr_src = scratch + BrgemmCPU::SCRATCH_BYTE_SIZE;
if (config.need_copy_a(K_tail)) {
auto* tr_src = scratch + BrgemmCPU::SCRATCH_BYTE_SIZE;

execute_brgemm_copy_a_kernel(kernel->brgemm_copy_a_kernel, src_ptr, tr_src, config.get_M(), K_tail);
execute_brgemm_copy_a_kernel(kernel->brgemm_copy_a_kernel, src_ptr, tr_src, config.get_M(), K_tail);
src_ptr = tr_src;
}

configure_tiles_if_needed(args->amx_tile_config, kernel->palette_tail, config.get_M(), config.get_N(), K_tail);
execute_brgemm_kernel(kernel->brgemm_kernel_k_tail, tr_src, wei_ptr, args->C, scratch, false);
execute_brgemm_kernel(kernel->brgemm_kernel_k_tail, src_ptr, wei_ptr, args->C, scratch, false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
namespace ov {
namespace intel_cpu {

struct BrgemmAMXKernelConfig : public snippets::KernelExecutorBase::GenericConfig, public BrgemmBaseKernelConfig {
struct BrgemmAMXKernelConfig : public BrgemmBaseKernelConfig {
public:
BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa);
BrgemmAMXKernelConfig() = delete;
Expand All @@ -26,17 +26,10 @@ struct BrgemmAMXKernelConfig : public snippets::KernelExecutorBase::GenericConfi
return std::unique_ptr<BrgemmAMXKernelConfig>(new BrgemmAMXKernelConfig(*this));
}

bool is_completed() const override { return BrgemmBaseKernelConfig::is_completed(); }
size_t hash() const override { return BrgemmBaseKernelConfig::hash(); }
dnnl_dim_t get_inner_K_blk() const { return m_static_params->inner_k_blk; }
dnnl_dim_t get_vnni_factor() const { return m_static_params->vnni_factor; }

void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta);

dnnl_dim_t get_inner_K_blk() const { return std::static_pointer_cast<StaticParams>(m_static_params)->inner_k_blk; }
dnnl_dim_t get_vnni_factor() const { return std::static_pointer_cast<StaticParams>(m_static_params)->vnni_factor; }

#ifdef SNIPPETS_DEBUG_CAPS
std::string to_string() const override { return BrgemmBaseKernelConfig::to_string(); }
#endif
// Returns true if K % get_vnni_factor() is positive.
bool need_copy_a(dnnl_dim_t K) const;

private:
struct StaticParams : StaticBaseParams {
Expand All @@ -57,6 +50,10 @@ struct BrgemmAMXKernelConfig : public snippets::KernelExecutorBase::GenericConfi

const size_t m_hash {0};
};

// Exposes this config's static parameters through the StaticBaseParams base type.
// Returns by value, so every call makes a copy of the shared_ptr.
std::shared_ptr<StaticBaseParams> get_static_params() const override { return m_static_params; }

std::shared_ptr<StaticParams> m_static_params {nullptr};
};

struct BrgemmAMXCompiledKernel {
Expand All @@ -73,10 +70,10 @@ class BrgemmAMXKernelExecutor : public BrgemmBaseKernelExecutor,
public CPUKernelExecutor<BrgemmAMXKernelConfig, BrgemmAMXCompiledKernel> {
public:
struct call_args {
const void* A = nullptr;
const void* B = nullptr;
const uint8_t* A = nullptr;
const uint8_t* B = nullptr;
void* C = nullptr;
void* scratch = nullptr;
uint8_t* scratch = nullptr;
amx_tile_config_t* amx_tile_config = nullptr;
};
BrgemmAMXKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmAMXKernelConfig config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@ using namespace dnnl::impl::cpu::x64;
namespace ov {
namespace intel_cpu {

BrgemmBaseKernelConfig::BrgemmBaseKernelConfig(std::shared_ptr<StaticBaseParams> static_params)
: m_static_params(std::move(static_params)) {
m_hash = compute_hash();
}

// A config is complete when none of M, N, K, LDA, LDB, LDC is zero, or when it is empty.
bool BrgemmBaseKernelConfig::is_completed() const {
    const bool has_zero_dim = utils::one_of(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC);
    return !has_zero_dim || is_empty();
}

// An empty config has M, N, K, LDA, LDB, LDC and beta all equal to zero.
bool BrgemmBaseKernelConfig::is_empty() const {
    const bool all_zero = everyone_is(0, m_M, m_N, m_K, m_LDA, m_LDB, m_LDC, m_beta);
    return all_zero;
}
Expand All @@ -36,7 +32,7 @@ bool BrgemmBaseKernelConfig::operator==(const BrgemmBaseKernelConfig& rhs) const
return EQ(m_hash) && EQ(m_beta) &&
EQ(m_M) && EQ(m_N) && EQ(m_K) &&
EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) &&
(EQ(m_static_params.get()) || *m_static_params == *(rhs.m_static_params));
(EQ(get_static_params().get()) || *get_static_params() == *(rhs.get_static_params()));
#undef EQ
}

Expand All @@ -56,7 +52,7 @@ void BrgemmBaseKernelConfig::update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dn
}

size_t BrgemmBaseKernelConfig::compute_hash() const {
size_t seed = m_static_params->hash();
size_t seed = get_static_params()->hash();
#define HASH(X) seed = hash_combine(seed, X)
HASH(m_M); HASH(m_N); HASH(m_K);
HASH(m_LDA); HASH(m_LDB); HASH(m_LDC);
Expand Down Expand Up @@ -94,7 +90,7 @@ std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const {

std::string BrgemmBaseKernelConfig::to_string() const {
std::stringstream ss;
ss << m_static_params->to_string() << "\n";
ss << get_static_params()->to_string() << "\n";
PRINT(m_M); PRINT(m_N); PRINT(m_K);
PRINT(m_LDA); PRINT(m_LDB); PRINT(m_LDC);
PRINT(m_beta);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@
namespace ov {
namespace intel_cpu {

struct BrgemmBaseKernelConfig {
struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConfig {
public:
virtual ~BrgemmBaseKernelConfig() = default;
BrgemmBaseKernelConfig() = default;

bool is_completed() const override;
size_t hash() const override { return m_hash; }

bool is_completed() const;
bool is_empty() const;
void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta);

bool operator==(const BrgemmBaseKernelConfig& rhs) const;
bool operator!=(const BrgemmBaseKernelConfig& rhs) const {return !(*this == rhs);}

size_t hash() const { return m_hash; }

dnnl_data_type_t get_dt_in0() const { return m_static_params->dt_in0; }
dnnl_data_type_t get_dt_in1() const { return m_static_params->dt_in1; }
dnnl_data_type_t get_dt_in0() const { return get_static_params()->dt_in0; }
dnnl_data_type_t get_dt_in1() const { return get_static_params()->dt_in1; }

dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return m_static_params->isa; }
dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return get_static_params()->isa; }
float get_beta() const { return m_beta; }

dnnl_dim_t get_M() const { return m_M; }
Expand All @@ -47,7 +47,7 @@ struct BrgemmBaseKernelConfig {
dnnl_dim_t get_LDC() const { return m_LDC; }

#ifdef SNIPPETS_DEBUG_CAPS
std::string to_string() const;
std::string to_string() const override;
#endif

protected:
Expand All @@ -69,12 +69,9 @@ struct BrgemmBaseKernelConfig {
virtual size_t compute_hash() const;
};

BrgemmBaseKernelConfig(std::shared_ptr<StaticBaseParams> static_params);
BrgemmBaseKernelConfig() = delete;

// Implemented by derived configs to provide their static parameters.
// This class reads them in compute_hash(), operator==, to_string() and the dt/isa getters.
virtual std::shared_ptr<StaticBaseParams> get_static_params() const = 0;
size_t compute_hash() const;

std::shared_ptr<StaticBaseParams> m_static_params;
dnnl_dim_t m_M {0}, m_N {0}, m_K {0}, m_LDA {0}, m_LDB {0}, m_LDC {0};
float m_beta {0};
size_t m_hash {SIZE_MAX};
Expand Down
Loading

0 comments on commit 96842c1

Please sign in to comment.