diff --git a/src/common/snippets/src/pass/split_dimension_m.cpp b/src/common/snippets/src/pass/split_dimension_m.cpp index 0f50ad27931e04..a263fb8de0a87a 100644 --- a/src/common/snippets/src/pass/split_dimension_m.cpp +++ b/src/common/snippets/src/pass/split_dimension_m.cpp @@ -34,23 +34,23 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr& std::pair SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { std::pair splited = { 1, m_dim }; - const size_t lower_bound = optimal_parallelism_work_amount / batch_dim; - if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) { - splited.first = lower_bound; - splited.second = m_dim / lower_bound; - OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); - return splited; - } - - const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim); - for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) { - size_t divisor_1 = m_dim / divisor_0; - if (divisor_1 * divisor_0 == m_dim) { - splited.first = divisor_0; - splited.second = divisor_1; - break; + // TODO: should we limit minimal kernel_m? + const size_t min_kernel_m = 4; + // Strategy 1: Find a combination such that (batch_dim * splited.first) % optimal_parallelism_work_amount == 0 + for (size_t divisor = 1; divisor <= m_dim; ++divisor) { + if (m_dim % divisor == 0) { + const auto m_batch = divisor; + const auto m_kernel = m_dim / divisor; + if (m_kernel < min_kernel_m) + break; + splited = { m_batch, m_kernel }; + if ((batch_dim * splited.first) % optimal_parallelism_work_amount == 0) { + OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); + return splited; + } } } + OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); return splited; } diff --git a/src/common/snippets/tests/src/utils/split_dim_m.cpp b/src/common/snippets/tests/src/utils/split_dim_m.cpp index 69a04da6f1263f..db574a38f54685 100644 --- a/src/common/snippets/tests/src/utils/split_dim_m.cpp +++ b/src/common/snippets/tests/src/utils/split_dim_m.cpp @@ -48,16 +48,15 @@ TEST_P(SplitDimensionMTest, SplitDimensionM) { namespace SplitDimensionMInstantiation { const std::vector split_dimension_cases = { // Negative test cases: split is not needed - {InputData{40 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{false /*is_split*/}}, - {InputData{65, 32, 40}, ReferenceData{false}}, + {InputData{32 /*cur_batch*/, 32 /*cur_m*/, 32 /*concurrency*/}, ReferenceData{false /*is_split*/}}, + {InputData{50, 32, 32}, ReferenceData{false}}, // Positive test cases - {InputData{20 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{true /*is_split*/, 2 /*batch_m*/, 16 /*kernel_m*/}}, - {InputData{30, 60, 40}, ReferenceData{true, 2, 30}}, - {InputData{10, 100, 40}, ReferenceData{true, 4, 25}}, - {InputData{15, 45, 40}, ReferenceData{true, 5, 9}}, - {InputData{25, 50, 40}, ReferenceData{true, 2, 25}}, - {InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}}, + {InputData{20 /*cur_batch*/, 32 /*cur_m*/, 32 /*concurrency*/}, ReferenceData{true /*is_split*/, 8 /*batch_m*/, 4 /*kernel_m*/}}, + {InputData{16, 60, 32}, ReferenceData{true, 2, 30}}, + {InputData{10, 100, 32}, ReferenceData{true, 25, 4}}, + {InputData{25, 50, 32}, ReferenceData{true, 10, 5}}, + {InputData{5, 16384, 32}, ReferenceData{true, 32, 512}}, }; INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM,