Skip to content

Commit

Permalink
[XLA:GPU] Remove handling of copies as transposes.
Browse files Browse the repository at this point in the history
Layout normalization turns copies that change the layout into transposes.
Initially, this pass was behind a flag, but the flag has since been removed
and the pass now always runs. Passes that run after layout normalization don't
create additional copies with layout changes (checked manually). Therefore, the
transpose emitter does not need to handle copies with layout changes.
Update the tests that previously used layout-changing copies to use explicit
transposes instead, matching what the emitter actually sees after layout
normalization.

PiperOrigin-RevId: 666280709
  • Loading branch information
akuegel authored and copybara-github committed Aug 22, 2024
1 parent c8a26f1 commit 6fabf59
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 440 deletions.
91 changes: 15 additions & 76 deletions xla/service/gpu/ir_emission_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -540,72 +540,23 @@ absl::StatusOr<bool> CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
return true;
}

static std::optional<TransposeDescription> FindTiledTranspose(
const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kCopy) {
return std::nullopt;
}

absl::InlinedVector<int64_t, 3> permutation;
auto tr = ShapeUtil::GetNormalizedTransposeShape(instr.operand(0)->shape(),
instr.shape(), permutation);
if (!tr.has_value()) {
return std::nullopt;
}
if (permutation == absl::InlinedVector<int64_t, 3>{0, 2, 1}) {
if ((tr->at(1) >= kMinDimensionToTransposeTiled &&
tr->at(2) >= kMinDimensionToTransposeTiled) ||
(tr->at(1) >= kMinDimensionToTransposeTiled2 &&
tr->at(2) >= kMinDimensionToTransposeTiled2 &&
tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
return TransposeDescription{
&instr, *tr,
/*permutation=*/absl::InlinedVector<int64_t, 3>{0, 2, 1}};
}
} else if (permutation == absl::InlinedVector<int64_t, 3>{2, 1, 0}) {
if ((tr->at(0) >= kMinDimensionToTransposeTiled &&
tr->at(2) >= kMinDimensionToTransposeTiled) ||
(tr->at(0) >= kMinDimensionToTransposeTiled2 &&
tr->at(2) >= kMinDimensionToTransposeTiled2 &&
tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) {
return TransposeDescription{
&instr, *tr,
/*permutation=*/absl::InlinedVector<int64_t, 3>{2, 1, 0}};
}
} else if (IsMlirTransposeEmitterEnabled(instr)) {
if (permutation == absl::InlinedVector<int64_t, 3>{1, 0, 2}) {
auto byte_width = primitive_util::ByteWidth(instr.shape().element_type());
if (byte_width * tr->at(2) <= kMaxBytesInMostMinorDimension &&
byte_width * tr->at(2) * std::min(tr->at(0), tr->at(1)) >=
kMinDimensionToTransposeTiled) {
return TransposeDescription{
&instr, *tr,
/*permutation=*/absl::InlinedVector<int64_t, 3>{1, 0, 2}};
}
}
}
return std::nullopt;
}

// Find 021, 210 or 102 transpose in logical + physical transposition.
static std::optional<TransposeDescription> FindTiledLogicalTranspose(
const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kTranspose) {
std::optional<TransposeDescription> GetDescriptionForTiledTransposeEmitter(
const HloInstruction& hero) {
if (hero.opcode() != HloOpcode::kTranspose) {
return std::nullopt;
}

// We can assume that TransposeDimensionGrouper pass has run, so no need to
// call GetNormalizedLogicalTransposeShape here.
absl::InlinedVector<int64_t, 3> permutation(instr.dimensions().begin(),
instr.dimensions().end());
absl::InlinedVector<int64_t, 3> permutation(hero.dimensions().begin(),
hero.dimensions().end());
// A real transpose needs at least 2 transpose dimensions.
if (permutation.size() < 2) {
return std::nullopt;
}
absl::InlinedVector<int64_t, 3> dimensions(instr.shape().dimensions().begin(),
instr.shape().dimensions().end());
int64_t operand_most_minor_dim =
instr.operand(0)->shape().dimensions().back();
absl::InlinedVector<int64_t, 3> dimensions(hero.shape().dimensions().begin(),
hero.shape().dimensions().end());
int64_t operand_most_minor_dim = hero.operand(0)->shape().dimensions().back();
if (permutation == absl::InlinedVector<int64_t, 3>{0, 2, 1} ||
permutation == absl::InlinedVector<int64_t, 3>{2, 1, 0}) {
if ((dimensions.back() >= kMinDimensionToTransposeTiled &&
Expand All @@ -614,43 +565,32 @@ static std::optional<TransposeDescription> FindTiledLogicalTranspose(
operand_most_minor_dim >= kMinDimensionToTransposeTiled2 &&
dimensions.back() * operand_most_minor_dim >=
kMinTotalDimensionsToTransposeTiled)) {
return TransposeDescription{&instr, dimensions, permutation};
return TransposeDescription{&hero, dimensions, permutation};
}
} else if (IsMlirTransposeEmitterEnabled(instr)) {
} else if (IsMlirTransposeEmitterEnabled(hero)) {
if (permutation.back() == dimensions.size() - 1) {
operand_most_minor_dim =
instr.operand(0)->shape().dimensions(dimensions.size() - 2);
auto byte_width = primitive_util::ByteWidth(instr.shape().element_type());
hero.operand(0)->shape().dimensions(dimensions.size() - 2);
auto byte_width = primitive_util::ByteWidth(hero.shape().element_type());
if (byte_width * dimensions.back() <= kMaxBytesInMostMinorDimension &&
byte_width * dimensions.back() *
std::min(operand_most_minor_dim,
dimensions[dimensions.size() - 2]) >=
kMinDimensionToTransposeTiled) {
return TransposeDescription{&instr, dimensions, permutation};
return TransposeDescription{&hero, dimensions, permutation};
}
} else if ((operand_most_minor_dim >= kMinDimensionToTransposeTiled &&
dimensions.back() >= kMinDimensionToTransposeTiled) ||
(operand_most_minor_dim >= kMinDimensionToTransposeTiled2 &&
dimensions.back() >= kMinDimensionToTransposeTiled2 &&
operand_most_minor_dim * dimensions.back() >=
kMinTotalDimensionsToTransposeTiled)) {
return TransposeDescription{&instr, dimensions, permutation};
return TransposeDescription{&hero, dimensions, permutation};
}
}
return std::nullopt;
}

std::optional<TransposeDescription> GetDescriptionForTiledTransposeEmitter(
const HloInstruction& hero) {
if (auto d1 = FindTiledTranspose(hero)) {
return d1;
}
if (auto d2 = FindTiledLogicalTranspose(hero)) {
return d2;
}
return std::nullopt;
}

bool IsIntermediate(const HloInstruction* instr, int allowed_operand_count) {
// Number of operands should be in range [1, allowed_operand_count].
if (instr->operand_count() == 0 ||
Expand Down Expand Up @@ -738,8 +678,7 @@ HloInstructionAdaptor FindNonTrivialHero(const HloInstructionAdaptor& instr) {
// transpose and concat emitters also work if there are elementwise ops with
// more than 1 operand on the path between root and the root op.
auto is_transpose = [](const HloInstruction& node) {
return FindTiledLogicalTranspose(node).has_value() ||
FindTiledTranspose(node).has_value();
return GetDescriptionForTiledTransposeEmitter(node).has_value();
};
if (auto transpose = FindNonTrivialHero(hero, is_transpose)) {
return *transpose;
Expand Down
131 changes: 0 additions & 131 deletions xla/service/gpu/ir_emission_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,49 +157,6 @@ ENTRY entry {
EXPECT_EQ(result->permutation, InlinedVector({1, 3, 2, 0}));
}

TEST_F(IrEmissionUtilsTest, FindTiled102Transpose) {
const char* hlo = R"(
HloModule module
ENTRY entry {
p = s16[32,48,4]{2,1,0} parameter(0)
ROOT t = s16[32,48,4]{2,0,1} copy(p)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));
auto& debug_options = module->mutable_config().mutable_debug_options();
debug_options.set_xla_gpu_mlir_emitter_level(3);

HloInstruction* tr = module->entry_computation()->root_instruction();

auto result = GetDescriptionForTiledTransposeEmitter(*tr);
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, tr);
EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 4}));
EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2}));
}

TEST_F(IrEmissionUtilsTest, FindTiled102TransposeTooMuchMemoryRequired) {
const char* hlo = R"(
HloModule module
ENTRY entry {
p = s8[32,48,9]{2,1,0} parameter(0)
ROOT t = s8[32,48,9]{2,0,1} copy(p)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));
auto& debug_options = module->mutable_config().mutable_debug_options();
debug_options.set_xla_gpu_mlir_emitter_level(3);

HloInstruction* tr = module->entry_computation()->root_instruction();

auto result = GetDescriptionForTiledTransposeEmitter(*tr);
EXPECT_FALSE(result.has_value());
}

TEST_F(IrEmissionUtilsTest, FindAnyTiledTranspose) {
const char* hlo = R"(
HloModule module
Expand Down Expand Up @@ -526,40 +483,6 @@ ENTRY entry {
transpose);
}

TEST_F(IrEmissionUtilsTest, FindNonTrivialCopyHeroInsideFusion) {
const char* hlo = R"(
HloModule module
f {
p0 = f32[100,200,300]{2,1,0} parameter(0)
t = f32[100,200,300]{0,1,2} copy(p0)
ROOT add = f32[100,200,300]{0,1,2} add(t, t)
}
ENTRY entry {
p0 = f32[100,200,300]{2,1,0} parameter(0)
p1 = f32[100,200,300]{0,1,2} parameter(1)
fusion = f32[100,200,300]{0,1,2} fusion(p0), kind=kLoop, calls=f
ROOT add = f32[100,200,300]{0,1,2} add(p1, fusion)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));

HloInstruction* r = module->entry_computation()->root_instruction();
HloInstruction* copy = module->GetComputationWithName("f")
->parameter_instruction(0)
->users()
.front();
HloInstruction* fusion =
module->entry_computation()->GetInstructionWithName("fusion");
auto fusion_adaptor = HloFusionAdaptor::ForProducerConsumer(fusion, r);
EXPECT_EQ(&FindNonTrivialHero(HloInstructionAdaptor(*r, fusion_adaptor.get()))
.instruction(),
copy);
}

TEST_F(IrEmissionUtilsTest, TransposeReachableViaTrivialAndNontrivialOps) {
const char* hlo = R"(
HloModule module
Expand Down Expand Up @@ -588,33 +511,6 @@ ENTRY main {
EXPECT_EQ(&FindNonTrivialHero(*r), r);
}

TEST_F(IrEmissionUtilsTest, FindTiledTransposeOneSwapDimIsSmall) {
const char* hlo = R"(
HloModule module
fusion {
p = f32[100,11,12,8]{3,2,1,0} parameter(0)
ROOT c = f32[100,11,12,8]{1,0,2,3} copy(p)
}
ENTRY main {
param = f32[100,11,12,8]{3,2,1,0} parameter(0)
ROOT fusion = f32[100,11,12,8]{1,0,2,3} fusion(param), kind=kInput, calls=fusion
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));

HloInstruction* copy =
module->entry_computation()->root_instruction()->fused_expression_root();
auto result =
GetDescriptionForTiledTransposeEmitter(FindNonTrivialHero(*copy));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, copy);
EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100}));
EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}

TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOneSwapDimIsSmall) {
const char* hlo = R"(
HloModule module
Expand All @@ -641,33 +537,6 @@ ENTRY main {
EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}

TEST_F(IrEmissionUtilsTest, FindTiledTransposeOtherSwapDimIsSmall) {
const char* hlo = R"(
HloModule module
fusion {
p = f32[8,12,100,11]{3,2,1,0} parameter(0)
ROOT c = f32[8,12,100,11]{0,1,3,2} copy(p)
}
ENTRY main {
param = f32[8,12,100,11]{3,2,1,0} parameter(0)
ROOT fusion = f32[8,12,100,11]{0,1,3,2} fusion(param), kind=kInput, calls=fusion
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo));

HloInstruction* copy =
module->entry_computation()->root_instruction()->fused_expression_root();
auto result =
GetDescriptionForTiledTransposeEmitter(FindNonTrivialHero(*copy));
EXPECT_TRUE(result.has_value());
EXPECT_EQ(result->instr, copy);
EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8}));
EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0}));
}

TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOtherSwapDimIsSmall) {
const char* hlo = R"(
HloModule module
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ xla_test(
"//xla/service/gpu:gpu_fusible",
"//xla/service/gpu/transforms:instruction_fusion",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:test_main",
],
)
Expand Down Expand Up @@ -621,7 +622,6 @@ lit_test_suite(
[
"add_preds.hlo",
"calling_convention.hlo",
"copy.hlo",
"dot_bf16.hlo",
"dynamic_update_slice_inplace.hlo",
"fused_scatter.hlo",
Expand All @@ -645,6 +645,7 @@ lit_test_suite(
"sorting.hlo",
"transpose_021.hlo",
"transpose_021_extra_output.hlo",
"transpose_10.hlo",
"transpose_210.hlo",
"transpose_210_extra_output.hlo",
"triton_naming.hlo",
Expand Down
17 changes: 9 additions & 8 deletions xla/service/gpu/tests/gpu_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include <optional>
#include <vector>

#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand Down Expand Up @@ -109,24 +110,24 @@ class TransposeFusionTest : public GpuFusionTest {
}
};

TEST_F(TransposeFusionTest, Elementary) {
TEST_F(TransposeFusionTest, ElementaryWithTranspose) {
const char* hlo = R"(
HloModule module
ENTRY main {
p = f32[16,32]{1,0} parameter(0)
s = sqrt(p)
ROOT c = f32[16,32]{0,1} copy(s)
ROOT t = f32[32,16]{1,0} transpose(s), dimensions={1,0}
}
)";

CheckGpuFusion(hlo, R"(
// CHECK: %fused_computation (param_0.1: f32[16,32]) -> f32[16,32] {
// CHECK: %fused_computation (param_0.1: f32[16,32]) -> f32[32,16] {
// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0)
// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]])
// CHECK-NEXT: ROOT [[c_1_2:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1_1]])
// CHECK-NEXT: ROOT [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0}
// CHECK-NEXT: }
// CHECK: ROOT [[fusion_0:%[^ ]+]] = f32[16,32]{0,1} fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
// CHECK: ROOT [[fusion_0:%[^ ]+]] = f32[32,16]{1,0} fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
)");
}

Expand Down Expand Up @@ -161,7 +162,7 @@ ENTRY main {
p = f32[256,16]{1,0} parameter(0)
r = f32[16,16,16]{2,1,0} reshape(p)
s = sqrt(r)
ROOT c = f32[16,16,16]{1,2,0} copy(s)
ROOT t = f32[16,16,16]{2,1,0} transpose(s), dimensions={0,2,1}
}
)";

Expand All @@ -170,9 +171,9 @@ ENTRY main {
// CHECK-NEXT: [[param_0_2_0:%[^ ]+]] = f32[256,16]{1,0} parameter(0)
// CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[16,16,16]{2,1,0} reshape([[param_0_2_0]])
// CHECK-NEXT: [[s_1_2:%[^ ]+]] = f32[16,16,16]{2,1,0} sqrt([[r_1_1]])
// CHECK-NEXT: ROOT [[c_1_3:%[^ ]+]] = f32[16,16,16]{1,2,0} copy([[s_1_2]])
// CHECK-NEXT: ROOT [[c_1_3:%[^ ]+]] = f32[16,16,16]{2,1,0} transpose([[s_1_2]]), dimensions={0,2,1}
// CHECK-NEXT: }
// CHECK: ROOT [[fusion_0:%[^ ]+]] = f32[16,16,16]{1,2,0} fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
// CHECK: ROOT [[fusion_0:%[^ ]+]] = f32[16,16,16]{2,1,0} fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]]
)");
}

Expand Down
Loading

0 comments on commit 6fabf59

Please sign in to comment.