From 224cbf7167186fdaff74c3765c93b536c6010ba7 Mon Sep 17 00:00:00 2001
From: Greg Olechwierowicz
Date: Fri, 29 Nov 2024 07:07:42 -0800
Subject: [PATCH] [XLA:GPU] Do not compute suggested combiner threshold if
 there are no pipelined collectives in IR.

PiperOrigin-RevId: 701281462
---
 xla/service/gpu/BUILD                          |   2 +-
 xla/service/gpu/all_gather_combiner.cc         |   4 +
 xla/service/gpu/all_gather_combiner_test.cc    | 102 +++++++++++++++++-
 xla/service/gpu/all_reduce_combiner.cc         |   4 +
 xla/service/gpu/all_reduce_combiner_test.cc    |  95 ++++++++++++++++
 .../gpu/gpu_collective_combiner_utils.cc       |  16 +++
 .../gpu/gpu_collective_combiner_utils.h        |   3 +
 .../gpu/gpu_collective_combiner_utils_test.cc  |  54 ++++++++++
 xla/service/gpu/reduce_scatter_combiner.cc     |   4 +
 .../gpu/reduce_scatter_combiner_test.cc        |  96 +++++++++++++++++
 10 files changed, 378 insertions(+), 2 deletions(-)

diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index 488f54b927f8f..da60cc32eab6d 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -3158,9 +3158,9 @@ xla_cc_test(
     deps = [
         ":gpu_all_gather_combiner",
         "//xla/hlo/ir:hlo",
+        "//xla/hlo/testlib:filecheck",
         "//xla/service:collective_utils",
         "//xla/stream_executor:device_description",
-        "//xla/tests:filecheck",
         "//xla/tests:hlo_test_base",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/strings:string_view",
diff --git a/xla/service/gpu/all_gather_combiner.cc b/xla/service/gpu/all_gather_combiner.cc
index de32d4c40da37..d96e752783685 100644
--- a/xla/service/gpu/all_gather_combiner.cc
+++ b/xla/service/gpu/all_gather_combiner.cc
@@ -68,6 +68,10 @@ absl::StatusOr<bool> GpuAllGatherCombiner::Run(
     return AllGatherCombiner::Run(module, execution_threads);
   }
 
+  if (!ContainsPipelinedInstruction(*module)) {
+    return AllGatherCombiner::Run(module, execution_threads);
+  }
+
   // Combine as much as possible for pipelined collectives.
   int previous_combiner_threshold = combine_threshold_in_bytes_;
   combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(
diff --git a/xla/service/gpu/all_gather_combiner_test.cc b/xla/service/gpu/all_gather_combiner_test.cc
index 7ff76b5e6ea5b..d254c4572904e 100644
--- a/xla/service/gpu/all_gather_combiner_test.cc
+++ b/xla/service/gpu/all_gather_combiner_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
 #include "absl/log/log.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/testlib/filecheck.h"
 #include "xla/service/collective_utils.h"
 #include "xla/stream_executor/device_description.h"
-#include "xla/tests/filecheck.h"
 #include "xla/tests/hlo_test_base.h"
 #include "tsl/platform/statusor.h"
 #include "tsl/platform/test.h"
@@ -137,6 +137,106 @@ ENTRY entry {
       kExpected));
 }
 
+TEST_F(GpuAllGatherCombinerTest,
+       CombinesNonPipelinedCollectivesWithAFallbackCombiner) {
+  // The IR is the minimal valid example of a while loop with AG inside.
+  // None of the collectives are pipelined, so the default threshold is used.
+  constexpr absl::string_view kHloString = R"(
+HloModule module
+
+add {
+  lhs = bf16[] parameter(0)
+  rhs = bf16[] parameter(1)
+  ROOT add = bf16[] add(lhs, rhs)
+}
+
+while_cond {
+  param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
+    bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
+  gte = s32[] get-tuple-element(param), index=0
+  constant.1 = s32[] constant(8)
+  ROOT cmp = pred[] compare(gte, constant.1), direction=LT
+}
+
+while_body {
+  param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
+    bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
+  param.0 = s32[] get-tuple-element(param), index=0
+  param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1
+  param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2
+  param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3
+  param.nonpipelined.3 = bf16[6,8,128] get-tuple-element(param), index=4
+  param.nonpipelined.4 = bf16[6,8,128] get-tuple-element(param), index=5
+  param.nonpipelined.5 = bf16[6,8,128] get-tuple-element(param), index=6
+  param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7
+  zero = bf16[] constant(0)
+  one = s32[] constant(1)
+  it = s32[] add(param.0, one)
+  ag.nonpipelined.0 = bf16[6,8,128] all-gather(param.nonpipelined.0), dimensions={0}
+  ag.nonpipelined.1 = bf16[6,8,128] all-gather(param.nonpipelined.1), dimensions={0}
+  ag.nonpipelined.2 = bf16[6,8,128] all-gather(param.nonpipelined.2), dimensions={0}
+  ag.nonpipelined.3 = bf16[6,8,128] all-gather(param.nonpipelined.3),
+    dimensions={0}
+  ag.nonpipelined.4 = bf16[6,8,128] all-gather(param.nonpipelined.4),
+    dimensions={0}
+  ag.nonpipelined.5 = bf16[6,8,128] all-gather(param.nonpipelined.5),
+    dimensions={0}
+  ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, ag.nonpipelined.3, ag.nonpipelined.4, ag.nonpipelined.5, param.7)
+}
+
+ENTRY entry {
+  c0 = s32[] constant(0)
+  p0 = bf16[6,8,128] parameter(0)
+  p1 = bf16[3,1,2,128] parameter(1)
+  tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1)
+  while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
+  ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1
+}
+)";
+  auto config =
+      GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2);
+
+  DeviceDescription device_info;
+  // A memory budget that would suggest combining at most 2 collectives.
+  int collective_size = 2 * 6 * 8 * 128;
+  int threshold_bytes = 2 * collective_size;
+  int current_peak_mem = 90604;
+  int pointer_size = 4;
+  device_info.set_device_memory_size(current_peak_mem + threshold_bytes * 4);
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kHloString, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      GpuAllGatherCombiner(
+          device_info, /*default_combine_threshold_in_bytes=*/
+          kDefaultAllGatherCombineThreshold,
+          /*combine_threshold_in_bytes=*/kDefaultAllGatherCombineThreshold,
+          /*combine_threshold_count=*/256,
+          /*combine_by_dim=*/false,
+          /*combine_different_dtypes=*/true, pointer_size)
+          .Run(module.get()));
+
+  VLOG(1) << module->ToString();
+  EXPECT_TRUE(changed);
+  const absl::string_view kExpected = R"(
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=1
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=2
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=3
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_3:.*]] = {{.*}} index=4
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_4:.*]] = {{.*}} index=5
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_5:.*]] = {{.*}} index=6
+    // CHECK: all-gather(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]
+    // CHECK-SAME: %[[NONPIPELINED_PARAM_3]], %[[NONPIPELINED_PARAM_4]], %[[NONPIPELINED_PARAM_5]])
+  )";
+  EXPECT_TRUE(*RunFileCheck(
+      module->ToString(HloPrintOptions()
+                           .set_print_operand_shape(false)
+                           .set_print_result_shape(false)
+                           .set_print_operand_index_annotation_interval(10)),
+      kExpected));
+}
+
 TEST_F(GpuAllGatherCombinerTest, CombinesCollectivesUpToSpecifiedThreshold) {
   // The IR is the minimal valid example of a while loop with AG inside. Three
   // are annotated as pipelined and three are not. Various configurations of the
diff --git a/xla/service/gpu/all_reduce_combiner.cc b/xla/service/gpu/all_reduce_combiner.cc
index 3c1df55abe702..a48d13b0d1dbe 100644
--- a/xla/service/gpu/all_reduce_combiner.cc
+++ b/xla/service/gpu/all_reduce_combiner.cc
@@ -66,6 +66,10 @@ absl::StatusOr<bool> GpuAllReduceCombiner::Run(
     return AllReduceCombiner::Run(module, execution_threads);
   }
 
+  if (!ContainsPipelinedInstruction(*module)) {
+    return AllReduceCombiner::Run(module, execution_threads);
+  }
+
   // Combine as much as possible for pipelined collectives.
   int previous_combiner_threshold = combine_threshold_in_bytes_;
   combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(
diff --git a/xla/service/gpu/all_reduce_combiner_test.cc b/xla/service/gpu/all_reduce_combiner_test.cc
index b9733d35d84b4..4d821d1b84ffd 100644
--- a/xla/service/gpu/all_reduce_combiner_test.cc
+++ b/xla/service/gpu/all_reduce_combiner_test.cc
@@ -137,6 +137,101 @@ ENTRY entry {
       kExpected));
 }
 
+TEST_F(GpuAllReduceCombinerTest,
+       CombinesNonPipelinedCollectivesWithAFallbackCombiner) {
+  // The IR is the minimal valid example of a while loop with AR inside.
+  // None of the collectives are pipelined, so the default threshold is used.
+  constexpr absl::string_view kHloString = R"(
+HloModule module
+
+add {
+  lhs = bf16[] parameter(0)
+  rhs = bf16[] parameter(1)
+  ROOT add = bf16[] add(lhs, rhs)
+}
+
+while_cond {
+  param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
+    bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
+  gte = s32[] get-tuple-element(param), index=0
+  constant.1 = s32[] constant(8)
+  ROOT cmp = pred[] compare(gte, constant.1), direction=LT
+}
+
+while_body {
+  param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
+    bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
+  param.0 = s32[] get-tuple-element(param), index=0
+  param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1
+  param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2
+  param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3
+  param.nonpipelined.3 = bf16[6,8,128] get-tuple-element(param), index=4
+  param.nonpipelined.4 = bf16[6,8,128] get-tuple-element(param), index=5
+  param.nonpipelined.5 = bf16[6,8,128] get-tuple-element(param), index=6
+  param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7
+  zero = bf16[] constant(0)
+  one = s32[] constant(1)
+  it = s32[] add(param.0, one)
+  ar.nonpipelined.0 = bf16[6,8,128] all-reduce(param.nonpipelined.0),
+    to_apply=add
+  ar.nonpipelined.1 = bf16[6,8,128] all-reduce(param.nonpipelined.1),
+    to_apply=add
+  ar.nonpipelined.2 = bf16[6,8,128] all-reduce(param.nonpipelined.2),
+    to_apply=add
+  ar.nonpipelined.3 = bf16[6,8,128] all-reduce(param.nonpipelined.3),
+    to_apply=add
+  ar.nonpipelined.4 = bf16[6,8,128] all-reduce(param.nonpipelined.4),
+    to_apply=add
+  ar.nonpipelined.5 = bf16[6,8,128] all-reduce(param.nonpipelined.5),
+    to_apply=add
+  ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ar.nonpipelined.0, ar.nonpipelined.1, ar.nonpipelined.2, ar.nonpipelined.3, ar.nonpipelined.4, ar.nonpipelined.5, param.7)
+}
+
+ENTRY entry {
+  c0 = s32[] constant(0)
+  p0 = bf16[6,8,128] parameter(0)
+  p1 = bf16[3,1,2,128] parameter(1)
+  tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1)
+  while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
+  ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1
+}
+)";
+  auto config =
+      GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2);
+  DeviceDescription device_info;
+  int pointer_size = 4;
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kHloString, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      GpuAllReduceCombiner(
+          device_info, /*default_combine_threshold_in_bytes=*/
+          kDefaultAllReduceCombineThreshold,
+          /*combine_threshold_in_bytes=*/kDefaultAllReduceCombineThreshold,
+          /*combine_threshold_count=*/256, pointer_size)
+          .Run(module.get()));
+
+  VLOG(1) << module->ToString();
+  EXPECT_TRUE(changed);
+  const absl::string_view kExpected = R"(
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=1
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=2
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=3
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_3:.*]] = {{.*}} index=4
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_4:.*]] = {{.*}} index=5
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_5:.*]] = {{.*}} index=6
+    // CHECK: all-reduce(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]
+    // CHECK-SAME: %[[NONPIPELINED_PARAM_3]], %[[NONPIPELINED_PARAM_4]], %[[NONPIPELINED_PARAM_5]])
+  )";
+  EXPECT_TRUE(*RunFileCheck(
+      module->ToString(HloPrintOptions()
+                           .set_print_operand_shape(false)
+                           .set_print_result_shape(false)
+                           .set_print_operand_index_annotation_interval(10)),
+      kExpected));
+}
+
 TEST_F(GpuAllReduceCombinerTest, CombinesCollectivesUpToSpecifiedThreshold) {
   // The IR is the minimal valid example of a while loop with AR inside. Three
   // are annotated as pipelined and three are not. Various configurations of the
diff --git a/xla/service/gpu/gpu_collective_combiner_utils.cc b/xla/service/gpu/gpu_collective_combiner_utils.cc
index 8e1844d33c10a..d789b652df6d4 100644
--- a/xla/service/gpu/gpu_collective_combiner_utils.cc
+++ b/xla/service/gpu/gpu_collective_combiner_utils.cc
@@ -80,4 +80,20 @@ absl::Status AppendPipelinedInstruction(HloInstruction* instr) {
   return instr->set_backend_config(config);
 }
 
+bool ContainsPipelinedInstruction(const HloModule& module) {
+  for (const HloComputation* computation : module.computations()) {
+    for (const HloInstruction* instr : computation->instructions()) {
+      auto backend_config = instr->backend_config<GpuBackendConfig>();
+      if (!backend_config.ok()) {
+        VLOG(2) << "Cannot read backend config for: " << instr->ToString();
+        continue;
+      }
+      if (backend_config->collective_backend_config().is_pipelined()) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 }  // namespace xla::gpu
diff --git a/xla/service/gpu/gpu_collective_combiner_utils.h b/xla/service/gpu/gpu_collective_combiner_utils.h
index 171f132dae541..38a7890decb59 100644
--- a/xla/service/gpu/gpu_collective_combiner_utils.h
+++ b/xla/service/gpu/gpu_collective_combiner_utils.h
@@ -46,6 +46,9 @@ int64_t ComputeSuggestedCombinerThreshold(
 // this.
 absl::Status AppendPipelinedInstruction(HloInstruction* instr);
 
+// Returns true if module contains any pipelined instruction. False otherwise.
+bool ContainsPipelinedInstruction(const HloModule& module);
+
 }  // namespace xla::gpu
 
 #endif  // XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_
diff --git a/xla/service/gpu/gpu_collective_combiner_utils_test.cc b/xla/service/gpu/gpu_collective_combiner_utils_test.cc
index 7ebf8baf271b2..f0b213f343e58 100644
--- a/xla/service/gpu/gpu_collective_combiner_utils_test.cc
+++ b/xla/service/gpu/gpu_collective_combiner_utils_test.cc
@@ -504,5 +504,59 @@ TEST_F(CollectiveCombinerUtilsTest,
   });
 }
 
+TEST_F(CollectiveCombinerUtilsTest,
+       ContainsPipelinedInstructionReturnsTrueForPipelinedInstructions) {
+  // A single all-reduce in the entry computation is annotated as pipelined
+  // via its backend config, so ContainsPipelinedInstruction is expected to
+  // return true.
+  constexpr absl::string_view kHloText = R"(
+  HloModule module
+
+  add {
+    lhs = bf16[] parameter(0)
+    rhs = bf16[] parameter(1)
+    ROOT add = bf16[] add(lhs, rhs)
+  }
+
+  ENTRY entry {
+    p0 = bf16[1] parameter(0)
+    ROOT ar.pipelined.1 = bf16[1] all-reduce(p0),
+      to_apply=add,
+      backend_config={"collective_backend_config": {"is_pipelined": true}}
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  EXPECT_TRUE(ContainsPipelinedInstruction(*module));
+}
+
+TEST_F(CollectiveCombinerUtilsTest,
+       ContainsPipelinedInstructionReturnsFalseForNonPipelinedInstructions) {
+  // Neither of the two all-reduces is annotated as pipelined (the second one
+  // explicitly sets is_pipelined to false), so ContainsPipelinedInstruction
+  // is expected to return false.
+  constexpr absl::string_view kHloText = R"(
+  HloModule module
+
+  add {
+    lhs = bf16[] parameter(0)
+    rhs = bf16[] parameter(1)
+    ROOT add = bf16[] add(lhs, rhs)
+  }
+
+  ENTRY entry {
+    p0 = bf16[1] parameter(0)
+    ar.0 = bf16[1] all-reduce(p0),
+      to_apply=add
+    ROOT ar.1 = bf16[1] all-reduce(ar.0),
+      to_apply=add,
+      backend_config={"collective_backend_config": {"is_pipelined": false}}
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
+  EXPECT_FALSE(ContainsPipelinedInstruction(*module));
+}
+
 }  // namespace
 }  // namespace xla::gpu
diff --git a/xla/service/gpu/reduce_scatter_combiner.cc b/xla/service/gpu/reduce_scatter_combiner.cc
index 7a5d509f6df62..969a68319356a 100644
--- a/xla/service/gpu/reduce_scatter_combiner.cc
+++ b/xla/service/gpu/reduce_scatter_combiner.cc
@@ -66,6 +66,10 @@ absl::StatusOr<bool> GpuReduceScatterCombiner::Run(
     return ReduceScatterCombiner::Run(module, execution_threads);
   }
 
+  if (!ContainsPipelinedInstruction(*module)) {
+    return ReduceScatterCombiner::Run(module, execution_threads);
+  }
+
   // Combine as much as possible for pipelined collectives.
   int previous_combiner_threshold = combine_threshold_in_bytes_;
   combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(
diff --git a/xla/service/gpu/reduce_scatter_combiner_test.cc b/xla/service/gpu/reduce_scatter_combiner_test.cc
index d03701ec810ed..8e73cd4e6f41b 100644
--- a/xla/service/gpu/reduce_scatter_combiner_test.cc
+++ b/xla/service/gpu/reduce_scatter_combiner_test.cc
@@ -138,6 +138,102 @@ ENTRY entry {
       kExpected));
 }
 
+TEST_F(GpuReduceScatterCombinerTest,
+       CombinesNonPipelinedCollectivesWithAFallbackCombiner) {
+  // The IR is the minimal valid example of a while loop with RS inside.
+  // None of the collectives are pipelined, so the default threshold is used.
+  constexpr absl::string_view kHloString = R"(
+HloModule module
+
+add {
+  lhs = bf16[] parameter(0)
+  rhs = bf16[] parameter(1)
+  ROOT add = bf16[] add(lhs, rhs)
+}
+
+while_cond {
+  param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
+    bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
+  gte = s32[] get-tuple-element(param), index=0
+  constant.1 = s32[] constant(8)
+  ROOT cmp = pred[] compare(gte, constant.1), direction=LT
+}
+
+while_body {
+  param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
+    bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
+  param.0 = s32[] get-tuple-element(param), index=0
+  param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1
+  param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2
+  param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3
+  param.nonpipelined.3 = bf16[6,8,128] get-tuple-element(param), index=4
+  param.nonpipelined.4 = bf16[6,8,128] get-tuple-element(param), index=5
+  param.nonpipelined.5 = bf16[6,8,128] get-tuple-element(param), index=6
+  param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7
+  zero = bf16[] constant(0)
+  one = s32[] constant(1)
+  it = s32[] add(param.0, one)
+  rs.nonpipelined.0 = bf16[6,8,128] reduce-scatter(param.nonpipelined.0),
+    dimensions={0}, to_apply=add
+  rs.nonpipelined.1 = bf16[6,8,128] reduce-scatter(param.nonpipelined.1),
+    dimensions={0}, to_apply=add
+  rs.nonpipelined.2 = bf16[6,8,128] reduce-scatter(param.nonpipelined.2),
+    dimensions={0}, to_apply=add
+  rs.nonpipelined.3 = bf16[6,8,128] reduce-scatter(param.nonpipelined.3),
+    dimensions={0}, to_apply=add
+  rs.nonpipelined.4 = bf16[6,8,128] reduce-scatter(param.nonpipelined.4),
+    dimensions={0}, to_apply=add
+  rs.nonpipelined.5 = bf16[6,8,128] reduce-scatter(param.nonpipelined.5),
+    dimensions={0}, to_apply=add
+  ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, rs.nonpipelined.0, rs.nonpipelined.1, rs.nonpipelined.2, rs.nonpipelined.3, rs.nonpipelined.4, rs.nonpipelined.5, param.7)
+}
+
+ENTRY entry {
+  c0 = s32[] constant(0)
+  p0 = bf16[6,8,128] parameter(0)
+  p1 = bf16[3,1,2,128] parameter(1)
+  tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1)
+  while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
+  ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1
+}
+)";
+  auto config =
+      GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2);
+  DeviceDescription device_info;
+  int pointer_size = 4;
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kHloString, config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool changed,
+      GpuReduceScatterCombiner(
+          device_info, /*default_combine_threshold_in_bytes=*/
+          kDefaultReduceScatterCombineThreshold,
+          /*combine_threshold_in_bytes=*/kDefaultReduceScatterCombineThreshold,
+          /*combine_threshold_count=*/256,
+          /*combine_by_dim=*/false, pointer_size)
+          .Run(module.get()));
+
+  VLOG(1) << module->ToString();
+  EXPECT_TRUE(changed);
+  const absl::string_view kExpected = R"(
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=1
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=2
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=3
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_3:.*]] = {{.*}} index=4
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_4:.*]] = {{.*}} index=5
+    // CHECK-DAG: %[[NONPIPELINED_PARAM_5:.*]] = {{.*}} index=6
+    // CHECK: reduce-scatter(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]
+    // CHECK-SAME: %[[NONPIPELINED_PARAM_3]], %[[NONPIPELINED_PARAM_4]], %[[NONPIPELINED_PARAM_5]])
+  )";
+  EXPECT_TRUE(*RunFileCheck(
+      module->ToString(HloPrintOptions()
+                           .set_print_operand_shape(false)
+                           .set_print_result_shape(false)
+                           .set_print_operand_index_annotation_interval(10)),
+      kExpected));
+}
+
 TEST_F(GpuReduceScatterCombinerTest,
        CombinesCollectivesUpToSpecifiedThreshold) {
   // The IR is the minimal valid example of a while loop with RS inside. Three