diff --git a/xla/tests/collective_pipeline_parallelism_test.cc b/xla/tests/collective_pipeline_parallelism_test.cc index ebdd21cb0c0cb..c431ee439fdf4 100644 --- a/xla/tests/collective_pipeline_parallelism_test.cc +++ b/xla/tests/collective_pipeline_parallelism_test.cc @@ -117,26 +117,59 @@ XLA_TEST_F(CollectivePipelineParallelismTest, LiteralTestUtil::ExpectR2Equal({{0, 0}, {1, 1}}, results[3]); } -// Naive implementation of pipeline parallelism: -// - 4 devices -// - 4 microbatches -// - no circular repeat -// - no disabled collectives -// - no collective pipelining -// -// Every stage of the pipeline is a single linear layer. -XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { - const absl::string_view kModuleStr = R"( - HloModule test +std::string GetModuleStrWithCommonComputations( + const std::string name, const std::string more_computations) { + static constexpr char kCommonComputationsStr[] = R"( + read_buffer_mb4 { + buffer = f32[4,16] parameter(0) + offset = u32[] parameter(1) + index = u32[] parameter(2) + c0 = u32[] constant(0) + c4 = u32[] constant(4) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c4) + slice = f32[1,16] dynamic-slice(buffer, index__, c0), + dynamic_slice_sizes={1,16} + ROOT slice_ = f32[16] reshape(slice) + } - get_circ_buffer_index { - offset = u32[] parameter(0) - index = u32[] parameter(1) - size = u32[] parameter(2) - t0 = u32[] add(offset, index) - t1 = u32[] divide(t0, size) - t2 = u32[] multiply(t1, size) - ROOT t4 = u32[] subtract(t0, t2) + read_buffer_mb5 { + buffer = f32[5,16] parameter(0) + offset = u32[] parameter(1) + index = u32[] parameter(2) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + slice = f32[1,16] dynamic-slice(buffer, index__, c0), + dynamic_slice_sizes={1,16} + ROOT slice_ = f32[16] reshape(slice) + } + + update_buffer_mb4 { + buffer = f32[4,16] parameter(0) + update = f32[16] parameter(1) + offset = u32[] parameter(2) + index = u32[] parameter(3) + c0 = u32[] constant(0) + c4 = u32[] constant(4) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c4) + update_ = f32[1,16] reshape(update) + ROOT buffer_ = f32[4,16] dynamic-update-slice(buffer, update_, index__, c0) + } + + update_buffer_mb5 { + buffer = f32[5,16] parameter(0) + update = f32[16] parameter(1) + offset = u32[] parameter(2) + index = u32[] parameter(3) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + update_ = f32[1,16] reshape(update) + ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) } is_input_replica { @@ -147,10 +180,40 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { is_output_replica { replica_id = u32[] replica-id() - c1 = u32[] constant(1) - ROOT predicate = pred[] compare(replica_id, c1), direction=EQ + c3 = u32[] constant(3) + ROOT predicate = pred[] compare(replica_id, c3), direction=EQ } + is_read_input_mb4 { + is_input_replica = pred[] call(), to_apply=is_input_replica + i = u32[] parameter(0) + c4 = u32[] constant(4) + is_input_iteration = pred[] compare(i, c4), direction=LT + ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) + } + + is_read_input_mb5 { + is_input_replica = pred[] call(), to_apply=is_input_replica + i = u32[] parameter(0) + c5 = u32[] constant(5) + is_input_iteration = pred[] compare(i, c5), direction=LT + ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) + } + )"; + return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" + + more_computations; +} + +// Naive implementation of pipeline parallelism: +// - 4 devices +// - 4 microbatches +// - no circular repeat +// - no disabled collectives +// - no collective pipelining +// +// Every stage of the pipeline is a single linear layer. +XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { + constexpr char kMoreComputationsStr[] = R"( while_condition { tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0) i = u32[] get-tuple-element(tuple), index=4 @@ -163,36 +226,34 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { weights = f32[16,16] get-tuple-element(tuple), index=0 input = f32[4,16] get-tuple-element(tuple), index=1 output = f32[4,16] get-tuple-element(tuple), index=2 - tmp = f32[16] get-tuple-element(tuple), index=3 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3 i = u32[] get-tuple-element(tuple), index=4 - c1 = u32[] constant(1) c0 = u32[] constant(0) + c1 = u32[] constant(1) c4 = u32[] constant(4) - input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index - input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), - dynamic_slice_sizes={1,16} - input_slice_ = f32[16] reshape(input_slice) + // Read from buffers. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb4 - prev_stage_slice = f32[16] collective-permute(tmp), + // Shift data to the next stage in the pipeline. + prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + // Select compute argument from previous stage or from input and perform + // compute. read_input = pred[] call(), to_apply=is_input_replica - compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice) - - compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + compute_arg = f32[16] select(read_input, input_slice, prev_stage_slice) + compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1}, rhs_contracting_dims={0} - output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index - output_slice = f32[1,16] reshape(compute_out) - output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index, - c0) + // Update buffers. + output_ = call(output, compute_res, c1, i), to_apply=update_buffer_mb4 i_ = add(i, c1) ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple( - weights, input, output_, compute_out, i_) + weights, input, output_, compute_res, i_) } ENTRY main { @@ -201,11 +262,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { cf0 = f32[] constant(0) output = f32[4,16] broadcast(cf0), dimensions={} - tmp = f32[16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} c0 = u32[] constant(0) tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights, - input, output, tmp, c0) + input, output, prev_iteration_compute_res, c0) tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple), condition=while_condition, body=while_body @@ -218,8 +279,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); // This pipeline consists of 4 layers, each of which is a single linear layer. // We assign the weights to the replicas such that the layers scale the input @@ -257,93 +321,6 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { ErrorSpec{1e-5, 1e-5})); } -std::string GetModuleStrWithCommonComputations( - const std::string name, const std::string more_computations) { - static constexpr char kCommonComputationsStr[] = R"( - read_buffer_mb4 { - buffer = f32[4,16] parameter(0) - offset = u32[] parameter(1) - index = u32[] parameter(2) - c0 = u32[] constant(0) - c4 = u32[] constant(4) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c4) - slice = f32[1,16] dynamic-slice(buffer, index__, c0), - dynamic_slice_sizes={1,16} - ROOT slice_ = f32[16] reshape(slice) - } - - read_buffer_mb5 { - buffer = f32[5,16] parameter(0) - offset = u32[] parameter(1) - index = u32[] parameter(2) - c0 = u32[] constant(0) - c5 = u32[] constant(5) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c5) - slice = f32[1,16] dynamic-slice(buffer, index__, c0), - dynamic_slice_sizes={1,16} - ROOT slice_ = f32[16] reshape(slice) - } - - update_buffer_mb4 { - buffer = f32[4,16] parameter(0) - update = f32[16] parameter(1) - offset = u32[] parameter(2) - index = u32[] parameter(3) - c0 = u32[] constant(0) - c4 = u32[] constant(4) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c4) - update_ = f32[1,16] reshape(update) - ROOT buffer_ = f32[4,16] dynamic-update-slice(buffer, update_, index__, c0) - } - - update_buffer_mb5 { - buffer = f32[5,16] parameter(0) - update = f32[16] parameter(1) - offset = u32[] parameter(2) - index = u32[] parameter(3) - c0 = u32[] constant(0) - c5 = u32[] constant(5) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c5) - update_ = f32[1,16] reshape(update) - ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) - } - - is_input_replica { - replica_id = u32[] replica-id() - c0 = u32[] constant(0) - ROOT predicate = pred[] compare(replica_id, c0), direction=EQ - } - - is_output_replica { - replica_id = u32[] replica-id() - c3 = u32[] constant(3) - ROOT predicate = pred[] compare(replica_id, c3), direction=EQ - } - - is_read_input_mb4 { - is_input_replica = pred[] call(), to_apply=is_input_replica - i = u32[] parameter(0) - c4 = u32[] constant(4) - is_input_iteration = pred[] compare(i, c4), direction=LT - ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) - } - - is_read_input_mb5 { - is_input_replica = pred[] call(), to_apply=is_input_replica - i = u32[] parameter(0) - c5 = u32[] constant(5) - is_input_iteration = pred[] compare(i, c5), direction=LT - ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) - } - )"; - return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" + - more_computations; -} - // Naive implementation of pipeline parallelism: // - 4 devices // - 5 microbatches