From b7267d13a8ee9559c0d201d9b4546047f0df360e Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Wed, 4 Dec 2024 02:19:38 -0800 Subject: [PATCH] [XLA:GPU] Swap dot operands in certain cases. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit On H100, the `wgmma` instruction supports lhs operand in registers but not rhs operand. That means that there is any prologue for one of the sides (e.g. conversion), it's better to put it into the lhs -- as elements will be in registers after computing the prologue. Additionally, minimal lhs non-contracting dimension size (M dimension) is 64, while for rhs smaller sizes are supported. Because of that, after the Triton fusion is known, swap lhs and rhs ops in the following cases: - If rhs has prologue, and lhs doesn't, and N-size (rhs non-contracting dim) is ≥64, swap. - If M<64 and N≥64, swap. PiperOrigin-RevId: 702643867 --- xla/service/gpu/BUILD | 1 + xla/service/gpu/gpu_compiler.cc | 2 + xla/service/gpu/transforms/BUILD | 36 +++ .../transforms/gemm_fusion_swap_operands.cc | 214 +++++++++++++++++ .../transforms/gemm_fusion_swap_operands.h | 44 ++++ .../gemm_fusion_swap_operands_test.cc | 217 ++++++++++++++++++ xla/shape_util.cc | 20 ++ xla/shape_util.h | 5 + xla/shape_util_test.cc | 9 + 9 files changed, 548 insertions(+) create mode 100644 xla/service/gpu/transforms/gemm_fusion_swap_operands.cc create mode 100644 xla/service/gpu/transforms/gemm_fusion_swap_operands.h create mode 100644 xla/service/gpu/transforms/gemm_fusion_swap_operands_test.cc diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index c306a207e8c25..6546dcdb5190c 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1544,6 +1544,7 @@ cc_library( "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter", "//xla/service/gpu/transforms:gemm_fusion", + "//xla/service/gpu/transforms:gemm_fusion_swap_operands", "//xla/service/gpu/transforms:gemm_rewriter", "//xla/service/gpu/transforms:gemv_rewriter", "//xla/service/gpu/transforms:layout_assignment", diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 9e1d33bbb09f8..4444096fc83cc 100644 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -203,6 +203,7 @@ limitations under the License. #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h" #include "xla/service/gpu/transforms/gemm_fusion.h" +#include "xla/service/gpu/transforms/gemm_fusion_swap_operands.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/transforms/gemv_rewriter.h" #include "xla/service/gpu/transforms/layout_assignment.h" @@ -1545,6 +1546,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( rocm_cc != nullptr)) { pipeline.AddPass(); pipeline.AddPass(gpu_version); + pipeline.AddPass(); } else if (cuda_cc != nullptr && cuda_cc->major == se::CudaComputeCapability::VOLTA) { // Greedy pattern matching for custom kernel fusions. diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index a71dc15127748..c0231d5287526 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -1738,6 +1738,42 @@ cc_library( ], ) +cc_library( + name = "gemm_fusion_swap_operands", + srcs = ["gemm_fusion_swap_operands.cc"], + hdrs = ["gemm_fusion_swap_operands.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/service/gpu/fusions/triton:triton_support", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gemm_fusion_swap_operands_test", + srcs = ["gemm_fusion_swap_operands_test.cc"], + deps = [ + ":gemm_fusion_swap_operands", + "//xla/hlo/testlib:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:test", + ], +) + xla_cc_test( name = "gemm_fusion_test", srcs = ["gemm_fusion_test.cc"], diff --git a/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc b/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc new file mode 100644 index 0000000000000..1423ca1c78816 --- /dev/null +++ b/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc @@ -0,0 +1,214 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/gemm_fusion_swap_operands.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +// Swaps operands of a dot instruction, while keeping the same physical output +// layout. Logically, dot output shape has lhs non-contracting dimensions +// followed by rhs non-contracting dimensions that have to be swapped, but +// that's countered by the layout. +HloDotInstruction* MakeDotWithSwappedOperands(HloInstruction* dot) { + HloComputation* computation = dot->parent(); + + const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers(); + const size_t num_batch_dims = dot_dims.lhs_batch_dimensions_size(); + const size_t num_lhs_noncontracting_dims = + dot->operand(0)->shape().rank() - num_batch_dims - + dot_dims.lhs_contracting_dimensions_size(); + const size_t num_rhs_noncontracting_dims = + dot->operand(1)->shape().rank() - num_batch_dims - + dot_dims.rhs_contracting_dimensions_size(); + + std::vector out_shape_permutation; + out_shape_permutation.reserve(dot->shape().rank()); + auto fill_permutation = [&](int64_t count, int64_t start) { + while (count--) out_shape_permutation.push_back(start++); + }; + // The output shape of a dot is batch dimensions, then lhs non-contracting, + // then rhs non-contracting. Batch dimensions stay where they were. and + // contracting dimensions of lhs and rhs swapped. + fill_permutation(num_batch_dims, 0); + fill_permutation(num_rhs_noncontracting_dims, + num_batch_dims + num_lhs_noncontracting_dims); + fill_permutation(num_lhs_noncontracting_dims, num_batch_dims); + const Shape new_dot_shape = + ShapeUtil::ReorderLogicalDimensions(dot->shape(), out_shape_permutation); + + DotDimensionNumbers new_dot_dims = dot_dims; + std::swap(*new_dot_dims.mutable_lhs_batch_dimensions(), + *new_dot_dims.mutable_rhs_batch_dimensions()); + std::swap(*new_dot_dims.mutable_lhs_contracting_dimensions(), + *new_dot_dims.mutable_rhs_contracting_dimensions()); + + return DynCast(computation->AddInstruction( + HloInstruction::CreateDot(new_dot_shape, dot->mutable_operand(1), + dot->mutable_operand(0), new_dot_dims, + dot->precision_config()), + &dot->metadata())); +} + +// Swaps operands of a dot instruction in a fusion. This is done by swapping the +// operands of the dot instruction, which keeps the same physical output layout, +// and then bitcasting the result to the original logical shape. +absl::Status SwapDotOperandsInFusion(HloComputation* computation) { + HloInstruction* dot = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + HloDotInstruction* new_dot = MakeDotWithSwappedOperands(dot); + HloInstruction* new_bitcast = computation->AddInstruction( + HloInstruction::CreateBitcast(dot->shape(), new_dot), &dot->metadata()); + TF_RETURN_IF_ERROR(dot->ReplaceAllUsesWith(new_bitcast)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(dot)); + return absl::OkStatus(); +} + +bool HasCodeGeneratingInstructions(const HloInstruction* instruction) { + while (!instruction->operands().empty()) { + // Skip instruction that are likely to just affect the address computation + // rather than result in actual computation. + switch (instruction->opcode()) { + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kConstant: + case HloOpcode::kGetTupleElement: + case HloOpcode::kParameter: + case HloOpcode::kReshape: + case HloOpcode::kTranspose: + break; + default: + // Any other instruction is considered code generating. + return true; + } + instruction = instruction->operand(0); + } + return false; +} + +absl::StatusOr GetNonContractingDimsNumElements( + const HloInstruction* dot, size_t operand_index) { + const Shape& shape = dot->operand(operand_index)->shape(); + const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers(); + const absl::Span batch_dim_indices = + operand_index == 0 ? dot_dims.lhs_batch_dimensions() + : dot_dims.rhs_batch_dimensions(); + const absl::Span contracting_dim_indices = + operand_index == 0 ? dot_dims.lhs_contracting_dimensions() + : dot_dims.rhs_contracting_dimensions(); + const DimensionVector noncontracting_dim_indices = GetNonContractingDims( + shape.rank(), batch_dim_indices, contracting_dim_indices); + return absl::c_accumulate( + noncontracting_dim_indices, int64_t{1}, + [&](int64_t acc, int64_t dim) { return acc * shape.dimensions(dim); }); +} + +// There are two reasons to swap operands: +// 1. If one side performs computation and the other doesn't, we want the +// "computing" side to be the lhs. wgmma supports lhs in registers, and +// computation would happen in registers too, so putting it to lhs avoids an +// extra roundtrip to a shared memory. +// 2. wgmma instruction only supports 64 for the M (lhs non-contracting) +// dimension, so if it's smaller, move it to the rhs that supports smaller +// powers of two. +absl::StatusOr ShouldSwapOperands(const HloInstruction* instr) { + const HloDotInstruction* dot = DynCast(instr); + if (dot == nullptr) return false; + // Sparsity is generally not symmetric, so we cannot swap operands. + if (dot->sparse_operands()) return false; + const bool lhs_has_code = HasCodeGeneratingInstructions(dot->operand(0)); + const bool rhs_has_code = HasCodeGeneratingInstructions(dot->operand(1)); + TF_ASSIGN_OR_RETURN(const int64_t lhs_size, GetNonContractingDimsNumElements( + dot, /*operand_index=*/0)); + TF_ASSIGN_OR_RETURN(const int64_t rhs_size, GetNonContractingDimsNumElements( + dot, /*operand_index=*/1)); + if (lhs_size < 64 && rhs_size >= 64) return true; + if (!lhs_has_code && rhs_has_code && rhs_size >= 64) return true; + return false; +} + +// Triton emitter is not fully symmetric, so it's not possible to emit all +// fusions with swapped dot operands. This function checks if the emitter could +// handle such a fusion. +absl::StatusOr EmitterCanHandleSwappedOperands( + const HloInstruction* dot) { + auto tmp_module = HloModule("tmp", dot->parent()->parent()->config()); + HloCloneContext clone_context(&tmp_module); + std::unique_ptr cloned_computation = + dot->parent()->CloneInContext(clone_context); + TF_RETURN_IF_ERROR(SwapDotOperandsInFusion(cloned_computation.get())); + return TritonFusionAnalysis::Execute(*cloned_computation).ok(); +} + +absl::StatusOr MaybeSwapOperands(HloComputation* computation) { + HloInstruction* dot = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + if (dot == nullptr) return false; + TF_ASSIGN_OR_RETURN(const bool should_swap_operands, ShouldSwapOperands(dot)); + if (!should_swap_operands) return false; + TF_ASSIGN_OR_RETURN(const bool can_handle_swapped_operands, + EmitterCanHandleSwappedOperands(dot)); + if (!can_handle_swapped_operands) return false; + TF_RETURN_IF_ERROR(SwapDotOperandsInFusion(computation)); + return true; +} + +} // namespace + +absl::StatusOr GemmFusionSwapOperands::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool any_changed = false; + for (HloComputation* computation : + module->MakeComputationPostOrder(execution_threads)) { + if (!IsTritonFusedComputation(*computation)) continue; + TF_ASSIGN_OR_RETURN(const bool changed, MaybeSwapOperands(computation)); + any_changed |= changed; + } + return any_changed; +} + +} // namespace gpu +} // namespace xla diff --git a/xla/service/gpu/transforms/gemm_fusion_swap_operands.h b/xla/service/gpu/transforms/gemm_fusion_swap_operands.h new file mode 100644 index 0000000000000..1eeedef74063e --- /dev/null +++ b/xla/service/gpu/transforms/gemm_fusion_swap_operands.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_SWAP_OPERANDS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_SWAP_OPERANDS_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +class GemmFusionSwapOperands : public HloModulePass { + public: + absl::string_view name() const override { + return "gemm-fusion-swap-operands"; + } + + public: + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_SWAP_OPERANDS_H_ diff --git a/xla/service/gpu/transforms/gemm_fusion_swap_operands_test.cc b/xla/service/gpu/transforms/gemm_fusion_swap_operands_test.cc new file mode 100644 index 0000000000000..9306513885ddf --- /dev/null +++ b/xla/service/gpu/transforms/gemm_fusion_swap_operands_test.cc @@ -0,0 +1,217 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/gemm_fusion_swap_operands.h" + +#include +#include "xla/hlo/testlib/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class SwapOperandsTest : public HloTestBase {}; + +TEST_F(SwapOperandsTest, CodeGeneratingMovesToLhs) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule DotLayout + +fcomp { + p0 = bf16[64,768,320]{2,1,0} parameter(0) + + p1 = s4[64,448,320]{2,1,0} parameter(1) + p1.c = bf16[64,448,320]{2,1,0} convert(p1) + + ROOT dot = bf16[64,768,448]{2,1,0} dot(p0, p1.c), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} +} + +ENTRY main { + p0 = bf16[64,768,320]{2,1,0} parameter(0) + p1 = s4[64,448,320]{2,1,0} parameter(1) + ROOT fusion = bf16[64,768,448]{2,1,0} fusion(p0, p1), + kind=kCustom, calls=fcomp, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"); + EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value()); + EXPECT_TRUE(*RunFileCheck(module->get()->ToString(), R"( +CHECK: bf16[64,448,768]{1,2,0} dot +CHECK-NEXT: bf16[64,768,448]{2,1,0} bitcast)")); +} + +TEST_F(SwapOperandsTest, CodeGeneratingMovesToLhsMultipleNoncontracting) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule DotLayout + +fcomp { + p0 = bf16[768,96,320]{2,1,0} parameter(0) + p0.r = bf16[73728,320]{1,0} reshape(p0) + + p1 = s4[448,320]{1,0} parameter(1) + p1.c = bf16[448,320]{1,0} convert(p1) + + dot = bf16[73728,448]{1,0} dot(p0.r, p1.c), + lhs_contracting_dims={1}, + rhs_contracting_dims={1} + + ROOT res = bf16[768,96,448]{2,1,0} bitcast(dot) +} + +ENTRY main { + p0 = bf16[768,96,320]{2,1,0} parameter(0) + p1 = s4[448,320]{1,0} parameter(1) + ROOT fusion = bf16[768,96,448]{2,1,0} fusion(p0, p1), + kind=kCustom, calls=fcomp, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"); + EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value()); + EXPECT_TRUE(*RunFileCheck(module->get()->ToString(), R"( +CHECK: bf16[448,73728]{0,1} dot +CHECK-NEXT: bf16[73728,448]{1,0} bitcast)")); +} + +TEST_F(SwapOperandsTest, SplitNoncontractingIsKeptInLhs) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule DotLayout + +fcomp { + p0 = bf16[768,320,96]{2,1,0} parameter(0) + + p1 = s4[448,320]{1,0} parameter(1) + p1.c = bf16[448,320]{1,0} convert(p1) + + ROOT dot = bf16[768,96,448]{2,1,0} dot(p0, p1.c), + lhs_contracting_dims={1}, + rhs_contracting_dims={1} +} + +ENTRY main { + p0 = bf16[768,320,96]{2,1,0} parameter(0) + p1 = s4[448,320]{1,0} parameter(1) + ROOT fusion = bf16[768,96,448]{2,1,0} fusion(p0, p1), + kind=kCustom, calls=fcomp, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"); + EXPECT_FALSE(GemmFusionSwapOperands().Run(module->get()).value()); +} + +TEST_F(SwapOperandsTest, DoNotSwapSmallRhsNoncontracting) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule DotLayout + +fcomp { + p0 = bf16[64,768,320]{2,1,0} parameter(0) + + p1 = s4[64,32,320]{2,1,0} parameter(1) + p1.c = bf16[64,32,320]{2,1,0} convert(p1) + + ROOT dot = bf16[64,768,32]{2,1,0} dot(p0, p1.c), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} +} + +ENTRY main { + p0 = bf16[64,768,320]{2,1,0} parameter(0) + p1 = s4[64,32,320]{2,1,0} parameter(1) + ROOT fusion = bf16[64,768,32]{2,1,0} fusion(p0, p1), + kind=kCustom, calls=fcomp, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"); + EXPECT_FALSE(GemmFusionSwapOperands().Run(module->get()).value()); +} + +TEST_F(SwapOperandsTest, BothNonCodeGeneratingSwapSmallLhs) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule DotLayout + +fcomp { + p0 = bf16[64,32,320]{2,1,0} parameter(0) + p1 = bf16[64,448,320]{2,1,0} parameter(1) + + ROOT dot = bf16[64,32,448]{2,1,0} dot(p0, p1), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} +} + +ENTRY main { + p0 = bf16[64,32,320]{2,1,0} parameter(0) + p1 = bf16[64,448,320]{2,1,0} parameter(1) + ROOT fusion = bf16[64,32,448]{2,1,0} fusion(p0, p1), + kind=kCustom, calls=fcomp, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"); + EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value()); + EXPECT_TRUE(*RunFileCheck(module->get()->ToString(), R"( +CHECK: bf16[64,448,32]{1,2,0} dot +CHECK-NEXT: bf16[64,32,448]{2,1,0} bitcast)")); +} + +TEST_F(SwapOperandsTest, BothCodeGeneratingSwapSmallLhs) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule DotLayout + +fcomp { + p0 = s4[64,32,320]{2,1,0} parameter(0) + p0.c = bf16[64,32,320]{2,1,0} convert(p0) + p1 = s4[64,448,320]{2,1,0} parameter(1) + p1.c = bf16[64,448,320]{2,1,0} convert(p1) + + ROOT dot = bf16[64,32,448]{2,1,0} dot(p0.c, p1.c), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} +} + +ENTRY main { + p0 = s4[64,32,320]{2,1,0} parameter(0) + p1 = s4[64,448,320]{2,1,0} parameter(1) + ROOT fusion = bf16[64,32,448]{2,1,0} fusion(p0, p1), + kind=kCustom, calls=fcomp, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"); + EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value()); + EXPECT_TRUE(*RunFileCheck(module->get()->ToString(), R"( +CHECK: bf16[64,448,32]{1,2,0} dot +CHECK-NEXT: bf16[64,32,448]{2,1,0} bitcast)")); +} + +TEST_F(SwapOperandsTest, BothNonCodeGeneratingDoNotSwapIfBothSmall) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule DotLayout + +fcomp { + p0 = bf16[64,32,320]{2,1,0} parameter(0) + p1 = bf16[64,48,320]{2,1,0} parameter(1) + + ROOT dot = bf16[64,32,48]{2,1,0} dot(p0, p1), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} +} + +ENTRY main { + p0 = bf16[64,32,320]{2,1,0} parameter(0) + p1 = bf16[64,48,320]{2,1,0} parameter(1) + ROOT fusion = bf16[64,32,48]{2,1,0} fusion(p0, p1), + kind=kCustom, calls=fcomp, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} +})"); + EXPECT_FALSE(GemmFusionSwapOperands().Run(module->get()).value()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/xla/shape_util.cc b/xla/shape_util.cc index 5e99565cc15c5..523b0d460b168 100644 --- a/xla/shape_util.cc +++ b/xla/shape_util.cc @@ -1757,6 +1757,26 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape, return output_shape_with_layout; } +/* static */ Shape ShapeUtil::ReorderLogicalDimensions( + const Shape& shape, absl::Span permutation) { + CHECK(shape.IsArray()); + const std::vector dynamic_dimensions = + Permute(shape.dynamic_dimensions(), permutation); + + Shape new_shape(shape.element_type(), + Permute(shape.dimensions(), permutation), + absl::InlinedVector(dynamic_dimensions.begin(), + dynamic_dimensions.end()), + {}); + if (shape.has_layout()) { + *new_shape.mutable_layout() = shape.layout(); + for (int64_t& dim : *new_shape.mutable_layout()->mutable_minor_to_major()) { + dim = permutation[dim]; + } + } + return new_shape; +} + /* static */ Shape ShapeUtil::DeleteDimension(int64_t dim_to_delete, Shape shape) { CHECK(shape.IsArray()); diff --git a/xla/shape_util.h b/xla/shape_util.h index 73d245d80f072..76a0217484165 100644 --- a/xla/shape_util.h +++ b/xla/shape_util.h @@ -883,6 +883,11 @@ class ShapeUtil { static std::optional AlignLayouts(const Shape& input_shape, const Shape& output_shape); + // Returns a shape with the given logical dimensions reordered, updating the + // layout so that physical dimensions are preserved. + static Shape ReorderLogicalDimensions(const Shape& shape, + absl::Span permutation); + // Returns a shape with the given dimension deleted. // For example: // • `DeleteDimension(1, T[m, n, k]) = T[m, k]` diff --git a/xla/shape_util_test.cc b/xla/shape_util_test.cc index 629fde1a2a91d..71a0c2cf5ff69 100644 --- a/xla/shape_util_test.cc +++ b/xla/shape_util_test.cc @@ -1415,6 +1415,15 @@ TEST(ShapeUtilTest, DecomposeBitcastToTrt) { EXPECT_FALSE(decomposition_trt.IsTranspose2Identity()); } +TEST(ShapeUtilTest, ReorderDimensionsTest) { + EXPECT_EQ(ShapeUtil::ReorderLogicalDimensions( + ShapeUtil::MakeShapeWithDenseLayout(F32, {16, 3, 12, 17}, + {1, 2, 0, 3}), + {0, 2, 1, 3}) + .ToString(true), + "f32[16,12,3,17]{2,1,0,3}"); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithDenseLayout(F32, {3, 2, 2}, {0, 1, 2}),