[GPU] Enable output transposed gemm for onednn
kelvinchoi-intel committed Nov 28, 2024
1 parent 79493c2 commit e9a9996
Showing 2 changed files with 101 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/plugins/intel_gpu/src/graph/impls/onednn/gemm_onednn.cpp
@@ -186,8 +186,15 @@ struct gemm_onednn : typed_primitive_onednn_impl<gemm> {
         if (ret) {
             tag = convert_data_format(transposed_format);
             dnnl::memory::dims original_dims = dims;
-            for (size_t i = 0; i < original_dims.size(); ++i) {
-                dims[i] = original_dims[order[i]];
+            if (is_input) {
+                for (size_t i = 0; i < original_dims.size(); ++i) {
+                    dims[i] = original_dims[order[i]];
+                }
+            } else {
+                // Get non-transposed dims for output dims
+                for (size_t i = 0; i < original_dims.size(); ++i) {
+                    dims[order[i]] = original_dims[i];
+                }
             }
         } else {
             std::ostringstream ostream;
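The input path gathers through `order` (dims[i] = original_dims[order[i]]), which applies the permutation; the new output path scatters through the same `order` (dims[order[i]] = original_dims[i]), which applies the inverse permutation and so produces the non-transposed dims the code comment refers to. A minimal standalone sketch of that distinction, not part of the patch: the `order` value is borrowed from the new test below, the dims are illustrative, and `dims_t` merely stands in for dnnl::memory::dims.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    using dims_t = std::vector<int64_t>;              // stand-in for dnnl::memory::dims
    const std::vector<size_t> order = {2, 0, 1, 3};   // output transpose order used in the new test
    const dims_t plain = {2, 4, 3, 16};               // illustrative non-transposed dims

    // Input path (gather): applies the permutation -> transposed dims {3, 2, 4, 16}
    dims_t transposed(plain.size());
    for (size_t i = 0; i < plain.size(); ++i)
        transposed[i] = plain[order[i]];

    // Output path (scatter): applies the inverse permutation -> recovers {2, 4, 3, 16}
    dims_t recovered(transposed.size());
    for (size_t i = 0; i < transposed.size(); ++i)
        recovered[order[i]] = transposed[i];

    for (auto d : recovered)
        std::printf("%lld ", static_cast<long long>(d));   // prints: 2 4 3 16
    std::printf("\n");
    return 0;
}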
@@ -96,3 +96,95 @@ TEST_P(TransposeMatMulFusionOnGPU, CompareWithRefs){
};

} // namespace


//=================================================================================
// Transpose + MatMul + Transpose pattern fusion (TransposeMatMulTransposeMatcher)
//=================================================================================
namespace ov {
namespace test {

using MatMulTransposeFusionParams = std::tuple<ov::PartialShape,   // input A shapes
                                               ov::PartialShape,   // input B shapes
                                               ov::PartialShape>;  // input C shapes

class MatMulTransposeFusionOnGPU : public testing::WithParamInterface<MatMulTransposeFusionParams>,
                                   virtual public ov::test::SubgraphBaseTest {
public:
    static std::string getTestCaseName(testing::TestParamInfo<MatMulTransposeFusionParams> obj) {
        ov::PartialShape input0;
        ov::PartialShape input1;
        ov::PartialShape input2;

        std::tie(input0, input1, input2) = obj.param;

        std::ostringstream result;
        result << "device=(" << std::string(utils::DEVICE_GPU) << ")_";
        result << ov::test::utils::partialShape2str({input0}) << "_";
        result << ov::test::utils::partialShape2str({input1}) << "_";
        result << ov::test::utils::partialShape2str({input2}) << "_";
        return result.str();
    }

protected:
    void SetUp() override {
        targetDevice = ov::test::utils::DEVICE_GPU;

        ov::PartialShape shape1;
        ov::PartialShape shape2;
        ov::PartialShape shape3;

        std::tie(shape1, shape2, shape3) = GetParam();

        InputShape input_shape1 = {shape1, {shape1.get_shape()}};
        InputShape input_shape2 = {shape2, {shape2.get_shape()}};
        InputShape input_shape3 = {shape3, {shape3.get_shape()}};
        init_input_shapes({input_shape1, input_shape2, input_shape3});

        const auto param1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape1);
        const auto param2 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape2);
        const auto param3 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, shape3);

        auto input2_shape = shape2.get_shape();

        // input0: transpose param1 and reshape it to input B's shape
        const auto input0_order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {1, 0, 2, 3});
        const auto input0_transpose = std::make_shared<ov::op::v1::Transpose>(param1, input0_order);
        const auto input0_shape_pattern = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, input2_shape);
        const auto input0_reshape = std::make_shared<ov::op::v1::Reshape>(input0_transpose, input0_shape_pattern, false);

        // input1: swap the two innermost dimensions of param2
        const auto input1_order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {0, 1, 3, 2});
        const auto input1_transpose = std::make_shared<ov::op::v1::Transpose>(param2, input1_order);

        // matmul & softmax
        const auto matmul1 = std::make_shared<ov::op::v0::MatMul>(input0_reshape, input1_transpose, false, false);
        const auto softmax = std::make_shared<ov::op::v8::Softmax>(matmul1, -1);

        // input3: same transpose + reshape as input0, applied to param3
        const auto input3_transpose = std::make_shared<ov::op::v1::Transpose>(param3, input0_order);
        const auto input3_shape_pattern = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, input2_shape);
        const auto input3_reshape = std::make_shared<ov::op::v1::Reshape>(input3_transpose, input3_shape_pattern, false);

        // target matmul, followed by the output transpose to be fused
        const auto matmul2 = std::make_shared<ov::op::v0::MatMul>(softmax, input3_reshape, false, false);
        const auto order = ov::op::v0::Constant::create(ov::element::i32, Shape{4}, {2, 0, 1, 3});
        const auto transpose = std::make_shared<ov::op::v1::Transpose>(matmul2, order);

        function = std::make_shared<ov::Model>(transpose, ov::ParameterVector{param1, param2, param3});
    }
};


} // namespace test
} // namespace ov


namespace {

INSTANTIATE_TEST_SUITE_P(smoke_MatMulTransposeFusion, MatMulTransposeFusionOnGPU,
                         ::testing::Values(
                             MatMulTransposeFusionParams({3, 8, 16, 1}, {2, 4, 3, 16}, {3, 8, 16, 1})),
                         MatMulTransposeFusionOnGPU::getTestCaseName);

TEST_P(MatMulTransposeFusionOnGPU, CompareWithRefs) {
    run();
};

} // namespace
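For the single smoke configuration ({3, 8, 16, 1}, {2, 4, 3, 16}, {3, 8, 16, 1}), the shapes propagate as follows (worked out from the graph above, not taken from test output): param1 {3, 8, 16, 1} → Transpose{1, 0, 2, 3} → {8, 3, 16, 1} → Reshape → {2, 4, 3, 16}; param2 {2, 4, 3, 16} → Transpose{0, 1, 3, 2} → {2, 4, 16, 3}; matmul1 → {2, 4, 3, 3}, unchanged by Softmax; param3 follows the same path as param1 to {2, 4, 3, 16}; matmul2 → {2, 4, 3, 16}; the final Transpose{2, 0, 1, 3} yields {3, 2, 4, 16}, exercising the transposed-output GEMM path this commit enables. The new tests register under the smoke_MatMulTransposeFusion prefix, so they can be selected with gtest's --gtest_filter; the exact GPU functional test binary name depends on the build.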
