Skip to content

Commit

Permalink
[XLA:GPU] Share common HLO computations between all of the pipeline t…
Browse files Browse the repository at this point in the history
…ests

PiperOrigin-RevId: 660552175
  • Loading branch information
frgossen authored and copybara-github committed Aug 7, 2024
1 parent 0a4d157 commit be3181c
Showing 1 changed file with 104 additions and 127 deletions.
231 changes: 104 additions & 127 deletions xla/tests/collective_pipeline_parallelism_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,26 +117,59 @@ XLA_TEST_F(CollectivePipelineParallelismTest,
LiteralTestUtil::ExpectR2Equal<float>({{0, 0}, {1, 1}}, results[3]);
}

// Naive implementation of pipeline parallelism:
// - 4 devices
// - 4 microbatches
// - no circular repeat
// - no disabled collectives
// - no collective pipelining
//
// Every stage of the pipeline is a single linear layer.
XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {
const absl::string_view kModuleStr = R"(
HloModule test
std::string GetModuleStrWithCommonComputations(
const std::string name, const std::string more_computations) {
static constexpr char kCommonComputationsStr[] = R"(
read_buffer_mb4 {
buffer = f32[4,16] parameter(0)
offset = u32[] parameter(1)
index = u32[] parameter(2)
c0 = u32[] constant(0)
c4 = u32[] constant(4)
index_ = u32[] add(index, offset)
index__ = u32[] remainder(index_, c4)
slice = f32[1,16] dynamic-slice(buffer, index__, c0),
dynamic_slice_sizes={1,16}
ROOT slice_ = f32[16] reshape(slice)
}
get_circ_buffer_index {
offset = u32[] parameter(0)
index = u32[] parameter(1)
size = u32[] parameter(2)
t0 = u32[] add(offset, index)
t1 = u32[] divide(t0, size)
t2 = u32[] multiply(t1, size)
ROOT t4 = u32[] subtract(t0, t2)
read_buffer_mb5 {
buffer = f32[5,16] parameter(0)
offset = u32[] parameter(1)
index = u32[] parameter(2)
c0 = u32[] constant(0)
c5 = u32[] constant(5)
index_ = u32[] add(index, offset)
index__ = u32[] remainder(index_, c5)
slice = f32[1,16] dynamic-slice(buffer, index__, c0),
dynamic_slice_sizes={1,16}
ROOT slice_ = f32[16] reshape(slice)
}
update_buffer_mb4 {
buffer = f32[4,16] parameter(0)
update = f32[16] parameter(1)
offset = u32[] parameter(2)
index = u32[] parameter(3)
c0 = u32[] constant(0)
c4 = u32[] constant(4)
index_ = u32[] add(index, offset)
index__ = u32[] remainder(index_, c4)
update_ = f32[1,16] reshape(update)
ROOT buffer_ = f32[4,16] dynamic-update-slice(buffer, update_, index__, c0)
}
update_buffer_mb5 {
buffer = f32[5,16] parameter(0)
update = f32[16] parameter(1)
offset = u32[] parameter(2)
index = u32[] parameter(3)
c0 = u32[] constant(0)
c5 = u32[] constant(5)
index_ = u32[] add(index, offset)
index__ = u32[] remainder(index_, c5)
update_ = f32[1,16] reshape(update)
ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
}
is_input_replica {
Expand All @@ -147,10 +180,40 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {
is_output_replica {
replica_id = u32[] replica-id()
c1 = u32[] constant(1)
ROOT predicate = pred[] compare(replica_id, c1), direction=EQ
c3 = u32[] constant(3)
ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
}
is_read_input_mb4 {
is_input_replica = pred[] call(), to_apply=is_input_replica
i = u32[] parameter(0)
c4 = u32[] constant(4)
is_input_iteration = pred[] compare(i, c4), direction=LT
ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
}
is_read_input_mb5 {
is_input_replica = pred[] call(), to_apply=is_input_replica
i = u32[] parameter(0)
c5 = u32[] constant(5)
is_input_iteration = pred[] compare(i, c5), direction=LT
ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
}
)";
return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" +
more_computations;
}

// Naive implementation of pipeline parallelism:
// - 4 devices
// - 4 microbatches
// - no circular repeat
// - no disabled collectives
// - no collective pipelining
//
// Every stage of the pipeline is a single linear layer.
XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {
constexpr char kMoreComputationsStr[] = R"(
while_condition {
tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0)
i = u32[] get-tuple-element(tuple), index=4
Expand All @@ -163,36 +226,34 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {
weights = f32[16,16] get-tuple-element(tuple), index=0
input = f32[4,16] get-tuple-element(tuple), index=1
output = f32[4,16] get-tuple-element(tuple), index=2
tmp = f32[16] get-tuple-element(tuple), index=3
prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3
i = u32[] get-tuple-element(tuple), index=4
c1 = u32[] constant(1)
c0 = u32[] constant(0)
c1 = u32[] constant(1)
c4 = u32[] constant(4)
input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index
input_slice = f32[1,16] dynamic-slice(input, input_idx, c0),
dynamic_slice_sizes={1,16}
input_slice_ = f32[16] reshape(input_slice)
// Read from buffers.
input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb4
prev_stage_slice = f32[16] collective-permute(tmp),
// Shift data to the next stage in the pipeline.
prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res),
source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}
// Select compute argument from previous stage or from input and perform
// compute.
read_input = pred[] call(), to_apply=is_input_replica
compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice)
compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1},
compute_arg = f32[16] select(read_input, input_slice, prev_stage_slice)
compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1},
rhs_contracting_dims={0}
output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index
output_slice = f32[1,16] reshape(compute_out)
output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index,
c0)
// Update buffers.
output_ = call(output, compute_res, c1, i), to_apply=update_buffer_mb4
i_ = add(i, c1)
ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(
weights, input, output_, compute_out, i_)
weights, input, output_, compute_res, i_)
}
ENTRY main {
Expand All @@ -201,11 +262,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {
cf0 = f32[] constant(0)
output = f32[4,16] broadcast(cf0), dimensions={}
tmp = f32[16] broadcast(cf0), dimensions={}
prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={}
c0 = u32[] constant(0)
tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights,
input, output, tmp, c0)
input, output, prev_iteration_compute_res, c0)
tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple),
condition=while_condition, body=while_body
Expand All @@ -218,8 +279,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {

HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
auto module,
ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations(
/*name=*/"test", kMoreComputationsStr),
config));

// This pipeline consists of 4 layers, each of which is a single linear layer.
// We assign the weights to the replicas such that the layers scale the input
Expand Down Expand Up @@ -257,93 +321,6 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) {
ErrorSpec{1e-5, 1e-5}));
}

// Assembles the text of an HLO module that bundles a fixed set of helper
// computations with caller-provided computations.
//
// The returned string is laid out as:
//   "HloModule " + name + "\n" + <common computations> + "\n" +
//   more_computations
//
// The common computations are:
//   - read_buffer_mb4 / read_buffer_mb5: dynamic-slice row
//     (index + offset) % N out of an f32[N,16] buffer and reshape it to
//     f32[16], for N = 4 and N = 5 respectively.
//   - update_buffer_mb4 / update_buffer_mb5: dynamic-update-slice an f32[16]
//     update into row (index + offset) % N of an f32[N,16] buffer.
//   - is_input_replica: true iff replica-id equals 0.
//   - is_output_replica: true iff replica-id equals 3.
//   - is_read_input_mb4 / is_read_input_mb5: is_input_replica AND (i < N).
//
// Args:
//   name: module name placed after the "HloModule " keyword.
//   more_computations: HLO text appended verbatim after the common
//     computations. Presumably this is where callers put the while
//     condition/body and the ENTRY computation; that is not checked here.
//
// Both arguments are taken by const reference so that std::string arguments
// are not copied; string literals and char arrays still bind via a temporary.
std::string GetModuleStrWithCommonComputations(
    const std::string& name, const std::string& more_computations) {
  // NOTE: everything between R"( and )" is runtime HLO text. It is kept
  // byte-for-byte; do not add comments or reformat inside the literal.
  static constexpr char kCommonComputationsStr[] = R"(
read_buffer_mb4 {
buffer = f32[4,16] parameter(0)
offset = u32[] parameter(1)
index = u32[] parameter(2)
c0 = u32[] constant(0)
c4 = u32[] constant(4)
index_ = u32[] add(index, offset)
index__ = u32[] remainder(index_, c4)
slice = f32[1,16] dynamic-slice(buffer, index__, c0),
dynamic_slice_sizes={1,16}
ROOT slice_ = f32[16] reshape(slice)
}
read_buffer_mb5 {
buffer = f32[5,16] parameter(0)
offset = u32[] parameter(1)
index = u32[] parameter(2)
c0 = u32[] constant(0)
c5 = u32[] constant(5)
index_ = u32[] add(index, offset)
index__ = u32[] remainder(index_, c5)
slice = f32[1,16] dynamic-slice(buffer, index__, c0),
dynamic_slice_sizes={1,16}
ROOT slice_ = f32[16] reshape(slice)
}
update_buffer_mb4 {
buffer = f32[4,16] parameter(0)
update = f32[16] parameter(1)
offset = u32[] parameter(2)
index = u32[] parameter(3)
c0 = u32[] constant(0)
c4 = u32[] constant(4)
index_ = u32[] add(index, offset)
index__ = u32[] remainder(index_, c4)
update_ = f32[1,16] reshape(update)
ROOT buffer_ = f32[4,16] dynamic-update-slice(buffer, update_, index__, c0)
}
update_buffer_mb5 {
buffer = f32[5,16] parameter(0)
update = f32[16] parameter(1)
offset = u32[] parameter(2)
index = u32[] parameter(3)
c0 = u32[] constant(0)
c5 = u32[] constant(5)
index_ = u32[] add(index, offset)
index__ = u32[] remainder(index_, c5)
update_ = f32[1,16] reshape(update)
ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0)
}
is_input_replica {
replica_id = u32[] replica-id()
c0 = u32[] constant(0)
ROOT predicate = pred[] compare(replica_id, c0), direction=EQ
}
is_output_replica {
replica_id = u32[] replica-id()
c3 = u32[] constant(3)
ROOT predicate = pred[] compare(replica_id, c3), direction=EQ
}
is_read_input_mb4 {
is_input_replica = pred[] call(), to_apply=is_input_replica
i = u32[] parameter(0)
c4 = u32[] constant(4)
is_input_iteration = pred[] compare(i, c4), direction=LT
ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
}
is_read_input_mb5 {
is_input_replica = pred[] call(), to_apply=is_input_replica
i = u32[] parameter(0)
c5 = u32[] constant(5)
is_input_iteration = pred[] compare(i, c5), direction=LT
ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration)
}
)";
  return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" +
         more_computations;
}

// Naive implementation of pipeline parallelism:
// - 4 devices
// - 5 microbatches
Expand Down

0 comments on commit be3181c

Please sign in to comment.