
Commit

[XLA:GPU] Do not compute suggested combiner threshold if there are no pipelined collectives in IR.

PiperOrigin-RevId: 702266560
golechwierowicz authored and Google-ML-Automation committed Dec 3, 2024
1 parent f0ca2e2 commit e9947dd
Showing 10 changed files with 387 additions and 2 deletions.
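
At its core, the change adds the same early-exit guard to all three GPU combiner passes (all-gather, all-reduce, reduce-scatter). Below is a minimal sketch of the resulting control flow in GpuAllGatherCombiner::Run; the argument list of ComputeSuggestedCombinerThreshold, the member names device_info_ and pointer_size_, and the threshold restore after combining are truncated in the hunks that follow and are assumed here:

absl::StatusOr<bool> GpuAllGatherCombiner::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // New early exit: with no pipelined collectives in the module, the
  // pipelined-specific threshold logic never applies, so skip the expensive
  // scheduling dry run and defer to the base combiner directly.
  if (!ContainsPipelinedInstruction(*module)) {
    return AllGatherCombiner::Run(module, execution_threads);
  }
  // Pre-existing path: temporarily raise the combiner threshold for
  // pipelined collectives; ComputeSuggestedCombinerThreshold performs the
  // scheduling dry run.
  int previous_combiner_threshold = combine_threshold_in_bytes_;
  combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(
      *module, device_info_, pointer_size_);  // argument list assumed
  TF_ASSIGN_OR_RETURN(bool changed,
                      AllGatherCombiner::Run(module, execution_threads));
  combine_threshold_in_bytes_ = previous_combiner_threshold;  // assumed restore
  return changed;
}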
2 changes: 1 addition & 1 deletion xla/service/gpu/BUILD
@@ -3162,9 +3162,9 @@ xla_cc_test(
deps = [
":gpu_all_gather_combiner",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:filecheck",
"//xla/service:collective_utils",
"//xla/stream_executor:device_description",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/log",
"@com_google_absl//absl/strings:string_view",
7 changes: 7 additions & 0 deletions xla/service/gpu/all_gather_combiner.cc
@@ -68,6 +68,13 @@ absl::StatusOr<bool> GpuAllGatherCombiner::Run(
return AllGatherCombiner::Run(module, execution_threads);
}

// If there are no pipelined instructions in the IR, the optimizations below
// do not kick in anyway. Exit early so we do not perform the expensive
// scheduling dry run below.
if (!ContainsPipelinedInstruction(*module)) {
return AllGatherCombiner::Run(module, execution_threads);
}

// Combine as much as possible for pipelined collectives.
int previous_combiner_threshold = combine_threshold_in_bytes_;
combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(
102 changes: 101 additions & 1 deletion xla/service/gpu/all_gather_combiner_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/testlib/filecheck.h"
#include "xla/service/collective_utils.h"
#include "xla/stream_executor/device_description.h"
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"
@@ -137,6 +137,106 @@ ENTRY entry {
kExpected));
}

TEST_F(GpuAllGatherCombinerTest,
CombinesNonPipelinedCollectivesWithAFallbackCombiner) {
// The IR is the minimal valid example of a while loop with AG inside.
// None of the collectives are pipelined.
constexpr absl::string_view kHloString = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(8)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
param.0 = s32[] get-tuple-element(param), index=0
param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1
param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2
param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3
param.nonpipelined.3 = bf16[6,8,128] get-tuple-element(param), index=4
param.nonpipelined.4 = bf16[6,8,128] get-tuple-element(param), index=5
param.nonpipelined.5 = bf16[6,8,128] get-tuple-element(param), index=6
param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7
zero = bf16[] constant(0)
one = s32[] constant(1)
it = s32[] add(param.0, one)
ag.nonpipelined.0 = bf16[6,8,128] all-gather(param.nonpipelined.0), dimensions={0}
ag.nonpipelined.1 = bf16[6,8,128] all-gather(param.nonpipelined.1), dimensions={0}
ag.nonpipelined.2 = bf16[6,8,128] all-gather(param.nonpipelined.2), dimensions={0}
ag.nonpipelined.3 = bf16[6,8,128] all-gather(param.nonpipelined.3),
dimensions={0}
ag.nonpipelined.4 = bf16[6,8,128] all-gather(param.nonpipelined.4),
dimensions={0}
ag.nonpipelined.6 = bf16[6,8,128] all-gather(param.nonpipelined.5),
dimensions={0}
ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, ag.nonpipelined.3, ag.nonpipelined.4, ag.nonpipelined.6, param.7)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[6,8,128] parameter(0)
p1 = bf16[3,1,2,128] parameter(1)
tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1)
while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1
}
)";
auto config =
GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2);

DeviceDescription device_info;
// Combine at most 2 collectives.
int collective_size = 2 * 6 * 8 * 128;
int threshold_bytes = 2 * collective_size;
int current_peak_mem = 90604;
int pointer_size = 4;
device_info.set_device_memory_size(current_peak_mem + threshold_bytes * 4);
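// For reference: bf16 is 2 bytes per element, so each [6,8,128] operand is
// 2 * 6 * 8 * 128 = 12288 bytes; threshold_bytes is therefore 24576 and the
// configured device memory is 90604 + 4 * 24576 = 188908 bytes.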

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloString, config));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
GpuAllGatherCombiner(
device_info, /*default_combine_threshold_in_bytes=*/
kDefaultAllGatherCombineThreshold,
/*combine_threshold_in_bytes=*/kDefaultAllGatherCombineThreshold,
/*combine_threshold_count=*/256,
/*combine_by_dim=*/false,
/*combine_different_dtypes=*/true, pointer_size)
.Run(module.get()));

VLOG(1) << module->ToString();
EXPECT_TRUE(changed);
const absl::string_view kExpected = R"(
// CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=1
// CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=2
// CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=3
// CHECK-DAG: %[[NONPIPELINED_PARAM_3:.*]] = {{.*}} index=4
// CHECK-DAG: %[[NONPIPELINED_PARAM_4:.*]] = {{.*}} index=5
// CHECK-DAG: %[[NONPIPELINED_PARAM_5:.*]] = {{.*}} index=6
// CHECK: all-gather(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]
// CHECK-SAME: %[[NONPIPELINED_PARAM_3]], %[[NONPIPELINED_PARAM_4]], %[[NONPIPELINED_PARAM_5]])
)";
EXPECT_TRUE(*RunFileCheck(
module->ToString(HloPrintOptions()
.set_print_operand_shape(false)
.set_print_result_shape(false)
.set_print_operand_index_annotation_interval(10)),
kExpected));
}

TEST_F(GpuAllGatherCombinerTest, CombinesCollectivesUpToSpecifiedThreshold) {
// The IR is the minimal valid example of a while loop with AG inside. Three
// are annotated as pipelined and three are not. Various configurations of the
7 changes: 7 additions & 0 deletions xla/service/gpu/all_reduce_combiner.cc
@@ -66,6 +66,13 @@ absl::StatusOr<bool> GpuAllReduceCombiner::Run(
return AllReduceCombiner::Run(module, execution_threads);
}

// If there are no pipelined instructions in the IR, the optimizations below
// do not kick in anyway. Exit early so we do not perform the expensive
// scheduling dry run below.
if (!ContainsPipelinedInstruction(*module)) {
return AllReduceCombiner::Run(module, execution_threads);
}

// Combine as much as possible for pipelined collectives.
int previous_combiner_threshold = combine_threshold_in_bytes_;
combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(
95 changes: 95 additions & 0 deletions xla/service/gpu/all_reduce_combiner_test.cc
@@ -137,6 +137,101 @@ ENTRY entry {
kExpected));
}

TEST_F(GpuAllReduceCombinerTest,
CombinesNonPipelinedCollectivesWithAFallbackCombiner) {
// The IR is the minimal valid example of a while loop with AR inside.
// None of the collectives are pipelined.
constexpr absl::string_view kHloString = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
while_cond {
param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
gte = s32[] get-tuple-element(param), index=0
constant.1 = s32[] constant(8)
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
}
while_body {
param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128],
bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0)
param.0 = s32[] get-tuple-element(param), index=0
param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1
param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2
param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3
param.nonpipelined.3 = bf16[6,8,128] get-tuple-element(param), index=4
param.nonpipelined.4 = bf16[6,8,128] get-tuple-element(param), index=5
param.nonpipelined.5 = bf16[6,8,128] get-tuple-element(param), index=6
param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7
zero = bf16[] constant(0)
one = s32[] constant(1)
it = s32[] add(param.0, one)
ar.nonpipelined.0 = bf16[6,8,128] all-reduce(param.nonpipelined.0),
to_apply=add
ar.nonpipelined.1 = bf16[6,8,128] all-reduce(param.nonpipelined.1),
to_apply=add
ar.nonpipelined.2 = bf16[6,8,128] all-reduce(param.nonpipelined.2),
to_apply=add
ar.nonpipelined.3 = bf16[6,8,128] all-reduce(param.nonpipelined.3),
to_apply=add
ar.nonpipelined.4 = bf16[6,8,128] all-reduce(param.nonpipelined.4),
to_apply=add
ar.nonpipelined.5 = bf16[6,8,128] all-reduce(param.nonpipelined.5),
to_apply=add
ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ar.nonpipelined.0, ar.nonpipelined.1, ar.nonpipelined.2, ar.nonpipelined.3, ar.nonpipelined.4, ar.nonpipelined.5, param.7)
}
ENTRY entry {
c0 = s32[] constant(0)
p0 = bf16[6,8,128] parameter(0)
p1 = bf16[3,1,2,128] parameter(1)
tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1)
while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body
ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1
}
)";
auto config =
GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2);
DeviceDescription device_info;
int pointer_size = 4;

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloString, config));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
GpuAllReduceCombiner(
device_info, /*default_combine_threshold_in_bytes=*/
kDefaultAllReduceCombineThreshold,
/*combine_threshold_in_bytes=*/kDefaultAllReduceCombineThreshold,
/*combine_threshold_count=*/256, pointer_size)
.Run(module.get()));

VLOG(1) << module->ToString();
EXPECT_TRUE(changed);
const absl::string_view kExpected = R"(
// CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=1
// CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=2
// CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=3
// CHECK-DAG: %[[NONPIPELINED_PARAM_3:.*]] = {{.*}} index=4
// CHECK-DAG: %[[NONPIPELINED_PARAM_4:.*]] = {{.*}} index=5
// CHECK-DAG: %[[NONPIPELINED_PARAM_5:.*]] = {{.*}} index=6
// CHECK: all-reduce(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]
// CHECK-SAME: %[[NONPIPELINED_PARAM_3]], %[[NONPIPELINED_PARAM_4]], %[[NONPIPELINED_PARAM_5]])
)";
EXPECT_TRUE(*RunFileCheck(
module->ToString(HloPrintOptions()
.set_print_operand_shape(false)
.set_print_result_shape(false)
.set_print_operand_index_annotation_interval(10)),
kExpected));
}

TEST_F(GpuAllReduceCombinerTest, CombinesCollectivesUpToSpecifiedThreshold) {
// The IR is the minimal valid example of a while loop with AR inside. Three
// are annotated as pipelined and three are not. Various configurations of the
16 changes: 16 additions & 0 deletions xla/service/gpu/gpu_collective_combiner_utils.cc
@@ -80,4 +80,20 @@ absl::Status AppendPipelinedInstruction(HloInstruction* instr) {
return instr->set_backend_config(config);
}

bool ContainsPipelinedInstruction(const HloModule& module) {
for (const HloComputation* computation : module.computations()) {
for (const HloInstruction* instr : computation->instructions()) {
auto backend_config = instr->backend_config<GpuBackendConfig>();
if (!backend_config.ok()) {
VLOG(2) << "Cannot read backend config for: " << instr->ToString();
continue;
}
if (backend_config->collective_backend_config().is_pipelined()) {
return true;
}
}
}
return false;
}

} // namespace xla::gpu
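
For context, the is_pipelined bit read above is the one AppendPipelinedInstruction writes. A hypothetical sketch of the write side follows, assuming the standard protobuf-generated mutators on GpuBackendConfig (only the read accessors and set_backend_config appear in this diff):

// Hypothetical helper: mark a collective so ContainsPipelinedInstruction()
// returns true for its module.
absl::Status MarkPipelined(HloInstruction* instr) {
  TF_ASSIGN_OR_RETURN(GpuBackendConfig config,
                      instr->backend_config<GpuBackendConfig>());
  config.mutable_collective_backend_config()->set_is_pipelined(true);
  return instr->set_backend_config(config);
}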
3 changes: 3 additions & 0 deletions xla/service/gpu/gpu_collective_combiner_utils.h
@@ -46,6 +46,9 @@ int64_t ComputeSuggestedCombinerThreshold(
// this.
absl::Status AppendPipelinedInstruction(HloInstruction* instr);

// Returns true if the module contains any pipelined instruction, false
// otherwise.
bool ContainsPipelinedInstruction(const HloModule& module);

} // namespace xla::gpu

#endif // XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_
54 changes: 54 additions & 0 deletions xla/service/gpu/gpu_collective_combiner_utils_test.cc
@@ -504,5 +504,59 @@ TEST_F(CollectiveCombinerUtilsTest,
});
}

TEST_F(CollectiveCombinerUtilsTest,
ContainsPipelinedInstructionReturnsTrueForPipelinedInstructions) {
// A minimal example containing a single all-reduce that is explicitly
// annotated as pipelined via its collective backend config.
constexpr absl::string_view kHloText = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
ENTRY entry {
p0 = bf16[1] parameter(0)
ROOT ar.pipelined.1 = bf16[1] all-reduce(p0),
to_apply=add,
backend_config={"collective_backend_config": {"is_pipelined": true}}
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
EXPECT_TRUE(ContainsPipelinedInstruction(*module));
}

TEST_F(CollectiveCombinerUtilsTest,
ContainsPipelinedInstructionReturnsFalseForNonPipelinedInstructions) {
// A minimal example in which no instruction is pipelined: one all-reduce
// carries no annotation and the other is explicitly annotated with
// is_pipelined=false.
constexpr absl::string_view kHloText = R"(
HloModule module
add {
lhs = bf16[] parameter(0)
rhs = bf16[] parameter(1)
ROOT add = bf16[] add(lhs, rhs)
}
ENTRY entry {
p0 = bf16[1] parameter(0)
ar.0 = bf16[1] all-reduce(p0),
to_apply=add
ROOT ar.1 = bf16[1] all-reduce(ar.0),
to_apply=add,
backend_config={"collective_backend_config": {"is_pipelined": false}}
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
EXPECT_FALSE(ContainsPipelinedInstruction(*module));
}

} // namespace
} // namespace xla::gpu
7 changes: 7 additions & 0 deletions xla/service/gpu/reduce_scatter_combiner.cc
@@ -66,6 +66,13 @@ absl::StatusOr<bool> GpuReduceScatterCombiner::Run(
return ReduceScatterCombiner::Run(module, execution_threads);
}

// If there are no pipelined instructions in the IR, the optimizations below
// do not kick in anyway. Exit early so we do not perform the expensive
// scheduling dry run below.
if (!ContainsPipelinedInstruction(*module)) {
return ReduceScatterCombiner::Run(module, execution_threads);
}

// Combine as much as possible for pipelined collectives.
int previous_combiner_threshold = combine_threshold_in_bytes_;
combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold(