PR #16183: Handle CP decomposer when channel ID is not provided
Imported from GitHub PR #16183

When no channel ID is provided, the collective-permute operates across replicas. In that case, the collective-permute should be split into forward and backward edges, and the results then combined based on the replica-id. This is similar to how collective-permute across partitions is handled.
Copybara import of the project:

--
203888a by Shraiysh Vaishay <[email protected]>:

Handle CP decomposer when channel ID is not provided

When no channel ID is provided, the collective-permute operates across
replicas. In that case, the collective-permute should be split into
forward and backward edges, and the results then combined based on the
replica-id. This is similar to how collective-permute across partitions
is handled.

Merging this change closes #16183

COPYBARA_INTEGRATE_REVIEW=#16183 from shraiysh:replica_collective_permute_decomposer 203888a
PiperOrigin-RevId: 666426686
shraiysh authored and copybara-github committed Aug 22, 2024
1 parent 6942179 commit 0b700d7
Showing 3 changed files with 219 additions and 131 deletions.
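To illustrate the decomposition described in the commit message, below is a rough HLO sketch of what the pass now produces for a collective-permute that forms a forward cycle across four replicas with no channel ID. The shapes, operand name (data), source_target_pairs, and instruction names are illustrative and not taken from this PR's tests.

Input (cross-replica, no channel_id):

  cp = f32[128] collective-permute(data),
      source_target_pairs={{0,1},{1,2},{2,3},{3,0}}

Decomposed (sketch):

  cp.backward = f32[128] collective-permute(data),
      source_target_pairs={{3,0}}
  cp.forward = f32[128] collective-permute(data),
      source_target_pairs={{0,1},{1,2},{2,3}}
  replica = u32[] replica-id()
  bwd_recv_id = u32[] constant(0)
  compare = pred[] compare(replica, bwd_recv_id), direction=EQ
  recv-data = f32[128] select(compare, cp.backward, cp.forward)

Because there is no channel ID, both new collective-permutes stay channel-less and the predicate is keyed on replica-id() rather than partition-id(); only the receiver of the backedge (replica 0 here) takes its data from cp.backward.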
2 changes: 2 additions & 0 deletions xla/service/gpu/transforms/BUILD
@@ -347,7 +347,9 @@ xla_cc_test(
":collective_permute_cycle_decomposer",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_parser",
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:test_utils",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest",
"@tsl//tsl/platform:statusor",
64 changes: 40 additions & 24 deletions xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc
@@ -55,10 +55,6 @@ enum class CycleType { kUnknown, kForward, kBackward };
 CycleType ShouldDecomposeWithCycleType(
     const HloCollectivePermuteInstruction& collective_permute,
     int64_t threshold_in_bytes) {
-  if (!collective_permute.channel_id().has_value()) {
-    return CycleType::kUnknown;
-  }
-
   if (collective_permute.operand_count() != 1) {
     return CycleType::kUnknown;
   }
@@ -157,40 +153,60 @@ absl::Status DecomposeCollectivePermuteCycle(
   xla::FrontendAttributes cp1_attr, cp2_attr;
   TF_RETURN_IF_ERROR(GetFrontendAttributes(cp, cycle_type, cp1_attr, cp2_attr));
 
+  TF_ASSIGN_OR_RETURN(
+      CollectiveOpGroupMode mode,
+      GetCollectiveOpGroupMode(cp->channel_id().has_value(), std::nullopt));
+
   // Create the CollectivePermute instruction for the communication represented
   // by the backedge.
-  HloInstruction* cp1 =
-      computation->AddInstruction(HloInstruction::CreateCollectivePermute(
-          cp->shape(), cp->mutable_operand(0), backedge,
-          cp->channel_id().value()));
+  HloInstruction* cp1 = computation->AddInstruction(
+      HloInstruction::CreateCollectivePermute(
+          cp->shape(), cp->mutable_operand(0), backedge, cp->channel_id()),
+      "cp.backward");
   cp1->set_metadata(metadata);
   cp1->set_frontend_attributes(cp1_attr);
-  int64_t cp1_receiver = backedge.back().second;
+  int64_t bwd_recv_id = backedge.back().second;
 
   // Create the CollectivePermute instruction for the communication represented
   // by the other edges.
-  HloInstruction* cp2 =
-      computation->AddInstruction(HloInstruction::CreateCollectivePermute(
-          cp->shape(), cp->mutable_operand(0), other_edges, next_channel_id));
+  bool is_cross_partition = (mode == CollectiveOpGroupMode::kCrossPartition);
+  HloInstruction* cp2 = computation->AddInstruction(
+      HloInstruction::CreateCollectivePermute(
+          cp->shape(), cp->mutable_operand(0), other_edges,
+          is_cross_partition ? std::optional(next_channel_id) : std::nullopt),
+      "cp.forward");
+
   cp2->set_metadata(metadata);
   cp2->set_frontend_attributes(cp2_attr);
 
   // Calculate the received data as follows:
-  // partition = u32[] partition-id()
-  // constant = u32[] constant(cp1_receiver)
-  // compare0 = pred[] compare(partition, cp1_received), direction=EQ
-  // compare = pred[?] broadcast(compare0), dimensions={}
+  // %partition = u32[] partition-id()
+  // %bwd_recv_id = u32[] constant(bwd-recv-partition-id)
+  // compare = pred[] compare(%partition, %bwd_recv_id), direction=EQ
   // recv-data = type[?] select(compare, cp1_done, cp2_done)
-  HloInstruction* partition =
-      computation->AddInstruction(HloInstruction::CreatePartitionId());
+  // If the collective is across replicas, then `partition` is replaced by
+  // `replica = u32[] replica-id()`.
+  HloInstruction* partition_or_replica = nullptr;
+  switch (mode) {
+    case CollectiveOpGroupMode::kCrossReplica:
+      partition_or_replica =
+          computation->AddInstruction(HloInstruction::CreateReplicaId());
+      break;
+    case CollectiveOpGroupMode::kCrossPartition:
+      partition_or_replica =
+          computation->AddInstruction(HloInstruction::CreatePartitionId());
+      break;
+    case CollectiveOpGroupMode::kCrossReplicaAndPartition:
+    case CollectiveOpGroupMode::kFlattenedID:
+      return absl::InternalError(absl::StrFormat(
+          "Unexpected collective group mode for %s", cp->name()));
+  };
   HloInstruction* constant = computation->AddInstruction(
-      HloInstruction::CreateConstant(LiteralUtil::CreateR0(U32, cp1_receiver)));
-  HloInstruction* compare0 = computation->AddInstruction(
-      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), partition,
-                                    constant, Comparison::Direction::kEq));
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0(U32, bwd_recv_id)));
   HloInstruction* compare =
-      computation->AddInstruction(HloInstruction::CreateBroadcast(
-          ShapeUtil::MakeShape(PRED, cp1->shape().dimensions()), compare0, {}));
+      computation->AddInstruction(HloInstruction::CreateCompare(
+          ShapeUtil::MakeShape(PRED, {}), partition_or_replica, constant,
+          Comparison::Direction::kEq));
   HloInstruction* recv_data =
       computation->AddInstruction(HloInstruction::CreateTernary(
           cp1->shape(), HloOpcode::kSelect, compare, cp1, cp2));
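For contrast, here is a rough sketch of the pre-existing cross-partition output that this change leaves in place: when the original collective-permute carries a channel ID, the backward copy keeps that channel ID, the forward copy receives the next free channel ID (is_cross_partition is true, so next_channel_id is passed through), and the select is keyed on partition-id(). The channel IDs, shapes, and instruction names below are illustrative.

  cp.backward = f32[128] collective-permute(data), channel_id=1,
      source_target_pairs={{3,0}}
  cp.forward = f32[128] collective-permute(data), channel_id=2,
      source_target_pairs={{0,1},{1,2},{2,3}}
  partition = u32[] partition-id()
  bwd_recv_id = u32[] constant(0)
  compare = pred[] compare(partition, bwd_recv_id), direction=EQ
  recv-data = f32[128] select(compare, cp.backward, cp.forward)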
