Skip to content

Commit

Permalink
[WIP] Change splitM heuristic
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 18, 2024
1 parent 7e04427 commit 1b22709
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 23 deletions.
30 changes: 15 additions & 15 deletions src/common/snippets/src/pass/split_dimension_m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,23 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>&
// Chooses how to split dimension M into a pair {batch part, kernel part}.
// Every return path first asserts that first * second == m_dim.
// NOTE(review): this span is a commit-diff rendering without +/- markers and with indentation stripped,
// so two variants of the heuristic appear back to back and the braces do not balance as shown.
// NOTE(review): no guard against batch_dim == 0 or optimal_parallelism_work_amount == 0 is visible here,
// although both are used as a divisor / modulus below — confirm callers exclude zero.
std::pair<size_t, size_t> SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
// Default result: no split (batch part 1, kernel part m_dim).
std::pair<size_t, size_t> splited = { 1, m_dim };

// --- Presumably the removed (pre-commit) heuristic starts here; confirm against the repository. ---
// Exact fit: batch_dim * lower_bound equals the optimal work amount and lower_bound divides m_dim.
const size_t lower_bound = optimal_parallelism_work_amount / batch_dim;
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) {
splited.first = lower_bound;
splited.second = m_dim / lower_bound;
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}

// Otherwise scan candidates downward from upper_bound - 1 to 2 and take the first one that divides m_dim.
const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim);
for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) {
size_t divisor_1 = m_dim / divisor_0;
// Integer division: the product matches m_dim only when divisor_0 divides m_dim exactly.
if (divisor_1 * divisor_0 == m_dim) {
splited.first = divisor_0;
splited.second = divisor_1;
break;
// --- Presumably the added (post-commit) heuristic starts here; confirm against the repository. ---
// TODO: should we limit minimal kernel_m?
const size_t min_kernel_m = 4;
// Strategy 1: Find a combination such that (batch_dim * splited.first) % optimal_parallelism_work_amount == 0
// Divisors of m_dim are visited in increasing order, so the batch part grows and the kernel part shrinks.
for (size_t divisor = 1; divisor <= m_dim; ++divisor) {
if (m_dim % divisor == 0) {
const auto m_batch = divisor;
const auto m_kernel = m_dim / divisor;
// The kernel part only gets smaller from here on, so stop once it drops below min_kernel_m.
if (m_kernel < min_kernel_m)
break;
// Record the candidate before testing it: if no candidate passes the check below,
// the last one recorded here is the value returned after the loop.
splited = { m_batch, m_kernel };
if ((batch_dim * splited.first) % optimal_parallelism_work_amount == 0) {
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
}
}

// Fallback return: whatever split was last stored in `splited` (the default if none was stored).
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
Expand Down
15 changes: 7 additions & 8 deletions src/common/snippets/tests/src/utils/split_dim_m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,15 @@ TEST_P(SplitDimensionMTest, SplitDimensionM) {
namespace SplitDimensionMInstantiation {
// Parameter sets for the SplitDimensionM test. Per the inline labels, each entry pairs
// InputData{cur_batch, cur_m, concurrency} with ReferenceData{is_split[, batch_m, kernel_m]}.
// NOTE(review): this is a diff rendering without +/- markers; the rows with concurrency 40 are presumably
// the removed cases and the rows with concurrency 32 the added ones — confirm against the repository.
const std::vector<SplitDimensionMParams> split_dimension_cases = {
// Negative test cases: split is not needed
{InputData{40 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{false /*is_split*/}},
{InputData{65, 32, 40}, ReferenceData{false}},
{InputData{32 /*cur_batch*/, 32 /*cur_m*/, 32 /*concurrency*/}, ReferenceData{false /*is_split*/}},
{InputData{50, 32, 32}, ReferenceData{false}},

// Positive test cases
{InputData{20 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{true /*is_split*/, 2 /*batch_m*/, 16 /*kernel_m*/}},
{InputData{30, 60, 40}, ReferenceData{true, 2, 30}},
{InputData{10, 100, 40}, ReferenceData{true, 4, 25}},
{InputData{15, 45, 40}, ReferenceData{true, 5, 9}},
{InputData{25, 50, 40}, ReferenceData{true, 2, 25}},
{InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}},
// In every concurrency-32 row below, batch_m * kernel_m equals cur_m.
{InputData{20 /*cur_batch*/, 32 /*cur_m*/, 32 /*concurrency*/}, ReferenceData{true /*is_split*/, 8 /*batch_m*/, 4 /*kernel_m*/}},
{InputData{16, 60, 32}, ReferenceData{true, 2, 30}},
// For the next two rows no batch_m makes cur_batch * batch_m a multiple of 32;
// the expected split has the smallest kernel_m that is still >= 4.
{InputData{10, 100, 32}, ReferenceData{true, 25, 4}},
{InputData{25, 50, 32}, ReferenceData{true, 10, 5}},
{InputData{5, 16384, 32}, ReferenceData{true, 32, 512}},
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM,
Expand Down

0 comments on commit 1b22709

Please sign in to comment.