diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index c306a207e8c25..6546dcdb5190c 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -1544,6 +1544,7 @@ cc_library(
         "//xla/service/gpu/transforms:fusion_wrapper",
         "//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter",
         "//xla/service/gpu/transforms:gemm_fusion",
+        "//xla/service/gpu/transforms:gemm_fusion_swap_operands",
         "//xla/service/gpu/transforms:gemm_rewriter",
         "//xla/service/gpu/transforms:gemv_rewriter",
         "//xla/service/gpu/transforms:layout_assignment",
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
index 9e1d33bbb09f8..4444096fc83cc 100644
--- a/xla/service/gpu/gpu_compiler.cc
+++ b/xla/service/gpu/gpu_compiler.cc
@@ -203,6 +203,7 @@ limitations under the License.
 #include "xla/service/gpu/transforms/fusion_wrapper.h"
 #include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
 #include "xla/service/gpu/transforms/gemm_fusion.h"
+#include "xla/service/gpu/transforms/gemm_fusion_swap_operands.h"
 #include "xla/service/gpu/transforms/gemm_rewriter.h"
 #include "xla/service/gpu/transforms/gemv_rewriter.h"
 #include "xla/service/gpu/transforms/layout_assignment.h"
@@ -1545,6 +1546,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
                         rocm_cc != nullptr)) {
     pipeline.AddPass<GemvRewriter>();
     pipeline.AddPass<GemmFusion>(gpu_version);
+    pipeline.AddPass<GemmFusionSwapOperands>();
   } else if (cuda_cc != nullptr &&
              cuda_cc->major == se::CudaComputeCapability::VOLTA) {
     // Greedy pattern matching for custom kernel fusions.
diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD index a71dc15127748..c0231d5287526 100644 --- a/xla/service/gpu/transforms/BUILD +++ b/xla/service/gpu/transforms/BUILD @@ -1738,6 +1738,42 @@ cc_library( ], ) +cc_library( + name = "gemm_fusion_swap_operands", + srcs = ["gemm_fusion_swap_operands.cc"], + hdrs = ["gemm_fusion_swap_operands.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/service/gpu/fusions/triton:triton_support", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gemm_fusion_swap_operands_test", + srcs = ["gemm_fusion_swap_operands_test.cc"], + deps = [ + ":gemm_fusion_swap_operands", + "//xla/hlo/testlib:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", + "@tsl//tsl/platform:test", + ], +) + xla_cc_test( name = "gemm_fusion_test", srcs = ["gemm_fusion_test.cc"], diff --git a/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc b/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc new file mode 100644 index 0000000000000..1423ca1c78816 --- /dev/null +++ b/xla/service/gpu/transforms/gemm_fusion_swap_operands.cc @@ -0,0 +1,214 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_fusion_swap_operands.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_clone_context.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/hlo/utils/hlo_query.h"
+#include "xla/service/gpu/fusions/triton/triton_support.h"
+#include "xla/service/gpu/triton_fusion_analysis.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Swaps operands of a dot instruction, while keeping the same physical output
+// layout. Logically, dot output shape has lhs non-contracting dimensions
+// followed by rhs non-contracting dimensions that have to be swapped, but
+// that's countered by the layout.
+HloDotInstruction* MakeDotWithSwappedOperands(HloInstruction* dot) {
+  HloComputation* computation = dot->parent();
+
+  const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers();
+  const size_t num_batch_dims = dot_dims.lhs_batch_dimensions_size();
+  const size_t num_lhs_noncontracting_dims =
+      dot->operand(0)->shape().rank() - num_batch_dims -
+      dot_dims.lhs_contracting_dimensions_size();
+  const size_t num_rhs_noncontracting_dims =
+      dot->operand(1)->shape().rank() - num_batch_dims -
+      dot_dims.rhs_contracting_dimensions_size();
+
+  std::vector<int64_t> out_shape_permutation;
+  out_shape_permutation.reserve(dot->shape().rank());
+  // Appends `count` consecutive dimension indices starting at `start`.
+  auto fill_permutation = [&](int64_t count, int64_t start) {
+    while (count--) out_shape_permutation.push_back(start++);
+  };
+  // The output shape of a dot is batch dimensions, then lhs non-contracting,
+  // then rhs non-contracting. Batch dimensions stay where they were, and the
+  // non-contracting dimensions of lhs and rhs are swapped.
+  fill_permutation(num_batch_dims, 0);
+  fill_permutation(num_rhs_noncontracting_dims,
+                   num_batch_dims + num_lhs_noncontracting_dims);
+  fill_permutation(num_lhs_noncontracting_dims, num_batch_dims);
+  const Shape new_dot_shape =
+      ShapeUtil::ReorderLogicalDimensions(dot->shape(), out_shape_permutation);
+
+  DotDimensionNumbers new_dot_dims = dot_dims;
+  std::swap(*new_dot_dims.mutable_lhs_batch_dimensions(),
+            *new_dot_dims.mutable_rhs_batch_dimensions());
+  std::swap(*new_dot_dims.mutable_lhs_contracting_dimensions(),
+            *new_dot_dims.mutable_rhs_contracting_dimensions());
+
+  return DynCast<HloDotInstruction>(computation->AddInstruction(
+      HloInstruction::CreateDot(new_dot_shape, dot->mutable_operand(1),
+                                dot->mutable_operand(0), new_dot_dims,
+                                dot->precision_config()),
+      &dot->metadata()));
+}
+
+// Swaps operands of a dot instruction in a fusion. This is done by swapping the
+// operands of the dot instruction, which keeps the same physical output layout,
+// and then bitcasting the result to the original logical shape.
+absl::Status SwapDotOperandsInFusion(HloComputation* computation) {
+  HloInstruction* dot =
+      hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot);
+  HloDotInstruction* new_dot = MakeDotWithSwappedOperands(dot);
+  HloInstruction* new_bitcast = computation->AddInstruction(
+      HloInstruction::CreateBitcast(dot->shape(), new_dot), &dot->metadata());
+  TF_RETURN_IF_ERROR(dot->ReplaceAllUsesWith(new_bitcast));
+  TF_RETURN_IF_ERROR(computation->RemoveInstruction(dot));
+  return absl::OkStatus();
+}
+
+// Returns true if the chain of instructions feeding `instruction` contains
+// anything besides "free" address-computation-like ops; only operand 0 is
+// followed at each step.
+bool HasCodeGeneratingInstructions(const HloInstruction* instruction) {
+  while (!instruction->operands().empty()) {
+    // Skip instructions that are likely to just affect the address computation
+    // rather than result in actual computation.
+    switch (instruction->opcode()) {
+      case HloOpcode::kBitcast:
+      case HloOpcode::kBroadcast:
+      case HloOpcode::kConstant:
+      case HloOpcode::kGetTupleElement:
+      case HloOpcode::kParameter:
+      case HloOpcode::kReshape:
+      case HloOpcode::kTranspose:
+        break;
+      default:
+        // Any other instruction is considered code generating.
+        return true;
+    }
+    instruction = instruction->operand(0);
+  }
+  return false;
+}
+
+// Returns the product of the non-contracting dimension sizes of the given
+// dot operand (0 = lhs, 1 = rhs).
+absl::StatusOr<int64_t> GetNonContractingDimsNumElements(
+    const HloInstruction* dot, size_t operand_index) {
+  const Shape& shape = dot->operand(operand_index)->shape();
+  const DotDimensionNumbers& dot_dims = dot->dot_dimension_numbers();
+  const absl::Span<const int64_t> batch_dim_indices =
+      operand_index == 0 ? dot_dims.lhs_batch_dimensions()
+                         : dot_dims.rhs_batch_dimensions();
+  const absl::Span<const int64_t> contracting_dim_indices =
+      operand_index == 0 ? dot_dims.lhs_contracting_dimensions()
+                         : dot_dims.rhs_contracting_dimensions();
+  const DimensionVector noncontracting_dim_indices = GetNonContractingDims(
+      shape.rank(), batch_dim_indices, contracting_dim_indices);
+  return absl::c_accumulate(
+      noncontracting_dim_indices, int64_t{1},
+      [&](int64_t acc, int64_t dim) { return acc * shape.dimensions(dim); });
+}
+
+// There are two reasons to swap operands:
+// 1. If one side performs computation and the other doesn't, we want the
+// "computing" side to be the lhs. wgmma supports lhs in registers, and
+// computation would happen in registers too, so putting it to lhs avoids an
+// extra roundtrip to a shared memory.
+// 2. wgmma instruction only supports 64 for the M (lhs non-contracting)
+// dimension, so if it's smaller, move it to the rhs that supports smaller
+// powers of two.
+absl::StatusOr<bool> ShouldSwapOperands(const HloInstruction* instr) {
+  const HloDotInstruction* dot = DynCast<HloDotInstruction>(instr);
+  if (dot == nullptr) return false;
+  // Sparsity is generally not symmetric, so we cannot swap operands.
+  if (dot->sparse_operands()) return false;
+  const bool lhs_has_code = HasCodeGeneratingInstructions(dot->operand(0));
+  const bool rhs_has_code = HasCodeGeneratingInstructions(dot->operand(1));
+  TF_ASSIGN_OR_RETURN(const int64_t lhs_size, GetNonContractingDimsNumElements(
+                                                  dot, /*operand_index=*/0));
+  TF_ASSIGN_OR_RETURN(const int64_t rhs_size, GetNonContractingDimsNumElements(
+                                                  dot, /*operand_index=*/1));
+  if (lhs_size < 64 && rhs_size >= 64) return true;
+  if (!lhs_has_code && rhs_has_code && rhs_size >= 64) return true;
+  return false;
+}
+
+// Triton emitter is not fully symmetric, so it's not possible to emit all
+// fusions with swapped dot operands. This function checks if the emitter could
+// handle such a fusion.
+absl::StatusOr<bool> EmitterCanHandleSwappedOperands(
+    const HloInstruction* dot) {
+  // Clone the fusion computation into a throwaway module, perform the swap on
+  // the clone, and probe whether TritonFusionAnalysis accepts the result.
+  auto tmp_module = HloModule("tmp", dot->parent()->parent()->config());
+  HloCloneContext clone_context(&tmp_module);
+  std::unique_ptr<HloComputation> cloned_computation =
+      dot->parent()->CloneInContext(clone_context);
+  TF_RETURN_IF_ERROR(SwapDotOperandsInFusion(cloned_computation.get()));
+  return TritonFusionAnalysis::Execute(*cloned_computation).ok();
+}
+
+// Swaps the dot operands of `computation` if profitable and supported by the
+// emitter; returns true if the computation was changed.
+absl::StatusOr<bool> MaybeSwapOperands(HloComputation* computation) {
+  HloInstruction* dot =
+      hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot);
+  if (dot == nullptr) return false;
+  TF_ASSIGN_OR_RETURN(const bool should_swap_operands, ShouldSwapOperands(dot));
+  if (!should_swap_operands) return false;
+  TF_ASSIGN_OR_RETURN(const bool can_handle_swapped_operands,
+                      EmitterCanHandleSwappedOperands(dot));
+  if (!can_handle_swapped_operands) return false;
+  TF_RETURN_IF_ERROR(SwapDotOperandsInFusion(computation));
+  return true;
+}
+
+}  // namespace
+
+absl::StatusOr<bool> GemmFusionSwapOperands::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool any_changed = false;
+  for (HloComputation* computation :
+       module->MakeComputationPostOrder(execution_threads)) {
+    if (!IsTritonFusedComputation(*computation)) continue;
+    TF_ASSIGN_OR_RETURN(const bool changed, MaybeSwapOperands(computation));
+    any_changed |= changed;
+  }
+  return any_changed;
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/xla/service/gpu/transforms/gemm_fusion_swap_operands.h b/xla/service/gpu/transforms/gemm_fusion_swap_operands.h
new file mode 100644
index 0000000000000..1eeedef74063e
--- /dev/null
+++ b/xla/service/gpu/transforms/gemm_fusion_swap_operands.h
@@ -0,0 +1,44 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_SWAP_OPERANDS_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_SWAP_OPERANDS_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/pass/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+// Swaps the operands of dots in Triton gemm fusions when profitable (e.g. to
+// put the "computing" side or the large non-contracting dimension on the lhs),
+// keeping the physical output layout unchanged.
+class GemmFusionSwapOperands : public HloModulePass {
+ public:
+  absl::string_view name() const override {
+    return "gemm-fusion-swap-operands";
+  }
+
+  using HloPassInterface::Run;
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+};
+
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_SWAP_OPERANDS_H_
diff --git a/xla/service/gpu/transforms/gemm_fusion_swap_operands_test.cc b/xla/service/gpu/transforms/gemm_fusion_swap_operands_test.cc
new file mode 100644
index 0000000000000..9306513885ddf
--- /dev/null
+++ b/xla/service/gpu/transforms/gemm_fusion_swap_operands_test.cc
@@ -0,0 +1,217 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/gemm_fusion_swap_operands.h"
+
+#include <memory>
+#include "xla/hlo/testlib/filecheck.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class SwapOperandsTest : public HloTestBase {};
+
+TEST_F(SwapOperandsTest, CodeGeneratingMovesToLhs) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule DotLayout
+
+fcomp {
+  p0 = bf16[64,768,320]{2,1,0} parameter(0)
+
+  p1 = s4[64,448,320]{2,1,0} parameter(1)
+  p1.c = bf16[64,448,320]{2,1,0} convert(p1)
+
+  ROOT dot = bf16[64,768,448]{2,1,0} dot(p0, p1.c),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+
+ENTRY main {
+  p0 = bf16[64,768,320]{2,1,0} parameter(0)
+  p1 = s4[64,448,320]{2,1,0} parameter(1)
+  ROOT fusion = bf16[64,768,448]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=fcomp,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})");
+  EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value());
+  EXPECT_TRUE(*RunFileCheck(module->get()->ToString(), R"(
+CHECK: bf16[64,448,768]{1,2,0} dot
+CHECK-NEXT: bf16[64,768,448]{2,1,0} bitcast)"));
+}
+
+TEST_F(SwapOperandsTest, CodeGeneratingMovesToLhsMultipleNoncontracting) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule DotLayout
+
+fcomp {
+  p0 = bf16[768,96,320]{2,1,0} parameter(0)
+  p0.r = bf16[73728,320]{1,0} reshape(p0)
+
+  p1 = s4[448,320]{1,0} parameter(1)
+  p1.c = bf16[448,320]{1,0} convert(p1)
+
+  dot = bf16[73728,448]{1,0} dot(p0.r, p1.c),
+    lhs_contracting_dims={1},
+    rhs_contracting_dims={1}
+
+  ROOT res = bf16[768,96,448]{2,1,0} bitcast(dot)
+}
+
+ENTRY main {
+  p0 = bf16[768,96,320]{2,1,0} parameter(0)
+  p1 = s4[448,320]{1,0} parameter(1)
+  ROOT fusion = bf16[768,96,448]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=fcomp,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})");
+  EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value());
+  EXPECT_TRUE(*RunFileCheck(module->get()->ToString(), R"(
+CHECK: bf16[448,73728]{0,1} dot
+CHECK-NEXT: bf16[73728,448]{1,0} bitcast)"));
+}
+
+TEST_F(SwapOperandsTest, SplitNoncontractingIsKeptInLhs) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule DotLayout
+
+fcomp {
+  p0 = bf16[768,320,96]{2,1,0} parameter(0)
+
+  p1 = s4[448,320]{1,0} parameter(1)
+  p1.c = bf16[448,320]{1,0} convert(p1)
+
+  ROOT dot = bf16[768,96,448]{2,1,0} dot(p0, p1.c),
+    lhs_contracting_dims={1},
+    rhs_contracting_dims={1}
+}
+
+ENTRY main {
+  p0 = bf16[768,320,96]{2,1,0} parameter(0)
+  p1 = s4[448,320]{1,0} parameter(1)
+  ROOT fusion = bf16[768,96,448]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=fcomp,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})");
+  EXPECT_FALSE(GemmFusionSwapOperands().Run(module->get()).value());
+}
+
+TEST_F(SwapOperandsTest, DoNotSwapSmallRhsNoncontracting) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule DotLayout
+
+fcomp {
+  p0 = bf16[64,768,320]{2,1,0} parameter(0)
+
+  p1 = s4[64,32,320]{2,1,0} parameter(1)
+  p1.c = bf16[64,32,320]{2,1,0} convert(p1)
+
+  ROOT dot = bf16[64,768,32]{2,1,0} dot(p0, p1.c),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+
+ENTRY main {
+  p0 = bf16[64,768,320]{2,1,0} parameter(0)
+  p1 = s4[64,32,320]{2,1,0} parameter(1)
+  ROOT fusion = bf16[64,768,32]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=fcomp,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})");
+  EXPECT_FALSE(GemmFusionSwapOperands().Run(module->get()).value());
+}
+
+TEST_F(SwapOperandsTest, BothNonCodeGeneratingSwapSmallLhs) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule DotLayout
+
+fcomp {
+  p0 = bf16[64,32,320]{2,1,0} parameter(0)
+  p1 = bf16[64,448,320]{2,1,0} parameter(1)
+
+  ROOT dot = bf16[64,32,448]{2,1,0} dot(p0, p1),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+
+ENTRY main {
+  p0 = bf16[64,32,320]{2,1,0} parameter(0)
+  p1 = bf16[64,448,320]{2,1,0} parameter(1)
+  ROOT fusion = bf16[64,32,448]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=fcomp,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})");
+  EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value());
+  EXPECT_TRUE(*RunFileCheck(module->get()->ToString(), R"(
+CHECK: bf16[64,448,32]{1,2,0} dot
+CHECK-NEXT: bf16[64,32,448]{2,1,0} bitcast)"));
+}
+
+TEST_F(SwapOperandsTest, BothCodeGeneratingSwapSmallLhs) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule DotLayout
+
+fcomp {
+  p0 = s4[64,32,320]{2,1,0} parameter(0)
+  p0.c = bf16[64,32,320]{2,1,0} convert(p0)
+  p1 = s4[64,448,320]{2,1,0} parameter(1)
+  p1.c = bf16[64,448,320]{2,1,0} convert(p1)
+
+  ROOT dot = bf16[64,32,448]{2,1,0} dot(p0.c, p1.c),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+
+ENTRY main {
+  p0 = s4[64,32,320]{2,1,0} parameter(0)
+  p1 = s4[64,448,320]{2,1,0} parameter(1)
+  ROOT fusion = bf16[64,32,448]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=fcomp,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})");
+  EXPECT_TRUE(GemmFusionSwapOperands().Run(module->get()).value());
+  EXPECT_TRUE(*RunFileCheck(module->get()->ToString(), R"(
+CHECK: bf16[64,448,32]{1,2,0} dot
+CHECK-NEXT: bf16[64,32,448]{2,1,0} bitcast)"));
+}
+
+TEST_F(SwapOperandsTest, BothNonCodeGeneratingDoNotSwapIfBothSmall) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+HloModule DotLayout
+
+fcomp {
+  p0 = bf16[64,32,320]{2,1,0} parameter(0)
+  p1 = bf16[64,48,320]{2,1,0} parameter(1)
+
+  ROOT dot = bf16[64,32,48]{2,1,0} dot(p0, p1),
+    lhs_batch_dims={0}, lhs_contracting_dims={2},
+    rhs_batch_dims={0}, rhs_contracting_dims={2}
+}
+
+ENTRY main {
+  p0 = bf16[64,32,320]{2,1,0} parameter(0)
+  p1 = bf16[64,48,320]{2,1,0} parameter(1)
+  ROOT fusion = bf16[64,32,48]{2,1,0} fusion(p0, p1),
+    kind=kCustom, calls=fcomp,
+    backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
+})");
+  EXPECT_FALSE(GemmFusionSwapOperands().Run(module->get()).value());
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/xla/shape_util.cc b/xla/shape_util.cc
index 5e99565cc15c5..523b0d460b168 100644
--- a/xla/shape_util.cc
+++ b/xla/shape_util.cc
@@ -1757,6 +1757,26 @@ ShapeUtil::DecomposeBitcastToTrt(const Shape& input_shape,
   return output_shape_with_layout;
 }
 
+/* static */ Shape ShapeUtil::ReorderLogicalDimensions(
+    const Shape& shape, absl::Span<const int64_t> permutation) {
+  CHECK(shape.IsArray());
+  const std::vector<bool> dynamic_dimensions =
+      Permute(shape.dynamic_dimensions(), permutation);
+
+  Shape new_shape(shape.element_type(),
+                  Permute(shape.dimensions(), permutation),
+                  absl::InlinedVector<bool, InlineRank()>(
+                      dynamic_dimensions.begin(), dynamic_dimensions.end()),
+                  {});
+  if (shape.has_layout()) {
+    *new_shape.mutable_layout() = shape.layout();
+    for (int64_t& dim : *new_shape.mutable_layout()->mutable_minor_to_major()) {
+      dim = permutation[dim];
+    }
+  }
+  return new_shape;
+}
+
 /* static */ Shape ShapeUtil::DeleteDimension(int64_t dim_to_delete,
                                               Shape shape) {
   CHECK(shape.IsArray());
diff --git a/xla/shape_util.h b/xla/shape_util.h
index 73d245d80f072..76a0217484165 100644
--- a/xla/shape_util.h
+++ b/xla/shape_util.h
@@ -883,6 +883,11 @@ class ShapeUtil {
   static std::optional<Shape> AlignLayouts(const Shape& input_shape,
                                            const Shape& output_shape);
 
+  // Returns a shape with the given logical dimensions reordered, updating the
+  // layout so that physical dimensions are preserved.
+  static Shape ReorderLogicalDimensions(const Shape& shape,
+                                        absl::Span<const int64_t> permutation);
+
   // Returns a shape with the given dimension deleted.
   // For example:
   // • `DeleteDimension(1, T[m, n, k]) = T[m, k]`
diff --git a/xla/shape_util_test.cc b/xla/shape_util_test.cc
index 629fde1a2a91d..71a0c2cf5ff69 100644
--- a/xla/shape_util_test.cc
+++ b/xla/shape_util_test.cc
@@ -1415,6 +1415,15 @@ TEST(ShapeUtilTest, DecomposeBitcastToTrt) {
   EXPECT_FALSE(decomposition_trt.IsTranspose2Identity());
 }
 
+TEST(ShapeUtilTest, ReorderDimensionsTest) {
+  EXPECT_EQ(ShapeUtil::ReorderLogicalDimensions(
+                ShapeUtil::MakeShapeWithDenseLayout(F32, {16, 3, 12, 17},
+                                                    {1, 2, 0, 3}),
+                {0, 2, 1, 3})
+                .ToString(true),
+            "f32[16,12,3,17]{2,1,0,3}");
+}
+
 TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
   EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
       ShapeUtil::MakeShapeWithDenseLayout(F32, {3, 2, 2}, {0, 1, 2}),