[RollForward][XLA][HostOffloader] Remove redundant copies to and from host for host offloaded computation outputs

The simple algorithm tracks the usages of all outputs of each host offloaded computation. For each output:
- If it is used ONLY on the host and is an output of the entry computation, its memory space is set to Host.
- If it is used ONLY on the host but is a temporary, no changes are made.
- If it is used ONLY on the host and a MoveToHost is found along the way, the MoveToHost is simply replaced with its operand (NOTE: the algorithm does not explicitly check that any MoveToHost exists, nor that all paths lead to one). A sketch of this usage walk follows below.
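For illustration, a minimal sketch of that usage walk, with a hypothetical UsableOnHost predicate standing in for the pass's real checks (IsValidDuringPureMemoryOffload plus extra checks on custom calls and async ops); like the pass's simple walk, it keeps no visited set:

#include <queue>

#include "xla/hlo/ir/hlo_instruction.h"

namespace xla {

// Hypothetical stand-in for the pass's real host-usage checks.
bool UsableOnHost(const HloInstruction* instr);

// Returns true if every transitive user of `output` can stay on the host.
bool OutputIsHostOnly(HloInstruction* output) {
  std::queue<HloInstruction*> worklist;
  worklist.push(output);
  while (!worklist.empty()) {
    HloInstruction* instr = worklist.front();
    worklist.pop();
    for (HloInstruction* user : instr->users()) {
      if (!UsableOnHost(user)) {
        return false;  // Found a device usage; keep the output on HBM.
      }
      worklist.push(user);
    }
  }
  return true;
}

}  // namespace xla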

Reverts 4df18bd

PiperOrigin-RevId: 666487090
tensorflower-gardener authored and copybara-github committed Aug 22, 2024
1 parent 21f90d1 commit 11da26a
Showing 6 changed files with 700 additions and 151 deletions.
2 changes: 2 additions & 0 deletions xla/service/BUILD
@@ -6473,6 +6473,7 @@ cc_library(
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
@@ -6501,6 +6502,7 @@ xla_cc_test(
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/container:flat_hash_set",
5 changes: 5 additions & 0 deletions xla/service/host_offload_utils.cc
@@ -239,5 +239,10 @@ std::string InstructionAndShapeIndex::ToString() const {
shape_index.ToString());
}

bool IsHostAsyncStart(const HloInstruction* instruction) {
  return instruction->opcode() == HloOpcode::kAsyncStart &&
         instruction->async_execution_thread() == HloInstruction::kHostThread;
}

} // namespace host_offload_utils
} // namespace xla
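A minimal usage sketch for IsHostAsyncStart (illustrative; this mirrors how HostOffloader::Run uses it further down in this diff):

for (HloInstruction* instr : computation->instructions()) {
  if (host_offload_utils::IsHostAsyncStart(instr)) {
    // `instr` launches an async computation on the host thread.
  }
}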
3 changes: 3 additions & 0 deletions xla/service/host_offload_utils.h
@@ -95,6 +95,9 @@ std::vector<InstructionAndShapeIndex> GetPredecessors(
// middle of a pure memory offload path.
bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction);

// Returns true if the instruction is an async-start with host thread.
bool IsHostAsyncStart(const HloInstruction* instruction);

} // namespace host_offload_utils
} // namespace xla

210 changes: 209 additions & 1 deletion xla/service/host_offloader.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "xla/service/host_offloader.h"

#include <array>
#include <cstddef>
#include <cstdint>
#include <iomanip>
#include <memory>
@@ -26,15 +27,19 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/literal_util.h"
#include "xla/service/call_graph.h"
@@ -782,17 +787,220 @@ absl::StatusOr<bool> HostOffloader::ApplySchedulingFix(
  return changed;
}

namespace {

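// Returns an error if the async computation contains any instruction other
// than parameters and its root.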
absl::Status ValidateAsyncComputationStructure(HloComputation* computation) {
  for (HloInstruction* instr : computation->instructions()) {
    if (instr->opcode() == HloOpcode::kParameter || instr->IsRoot()) {
      continue;
    }

    return absl::InternalError(
        absl::StrCat("Unexpected instruction found in async computation: ",
                     instr->ToString()));
  }

  return absl::OkStatus();
}

// Updates the memory space for all outputs of the host offloaded computation
// (associated with `call_start`) that are ONLY used on host. NOTE: We also
// remove redundant copies to host, if any.
absl::StatusOr<bool> UpdateMemorySpaceForHostOffloadedOutputs(
    HloInstruction* call_start,
    absl::flat_hash_map<size_t, std::vector<InstructionAndShapeIndex>>&
        host_instr) {
  // Keep track of MoveToHost instructions that need to be removed.
  std::vector<InstructionAndShapeIndex> to_replace;

  HloComputation* called_computation = call_start->async_wrapped_computation();
  TF_RETURN_IF_ERROR(ValidateAsyncComputationStructure(called_computation));
  HloInstruction* root = called_computation->root_instruction();
  Shape* root_shape = root->mutable_shape();

  for (auto& pair : host_instr) {
    std::vector<InstructionAndShapeIndex>& instruction_and_shape_indexes =
        pair.second;

    for (InstructionAndShapeIndex& instr_and_shape :
         instruction_and_shape_indexes) {
      // If the instruction is a MoveToHost custom call, we will replace its
      // usage with its operand below.
      if (instr_and_shape.instruction->IsCustomCall(
              host_memory_offload_annotations::kMoveToHostCustomCallTarget)) {
        to_replace.emplace_back(instr_and_shape);
        continue;
      }

      SetMemorySpace(ShapeUtil::GetMutableSubshape(
                         instr_and_shape.instruction->mutable_shape(),
                         instr_and_shape.shape_index),
                     Layout::kHostMemorySpace);
    }

    // Update the memory space for the output of the computation call itself.
    size_t index = pair.first;
    SetMemorySpace(root_shape->mutable_tuple_shapes(index),
                   Layout::kHostMemorySpace);
  }

  // Remove MoveToHost usages: replace each MoveToHost with its operand.
  for (InstructionAndShapeIndex& instr_and_shape : to_replace) {
    HloInstruction* pred = instr_and_shape.instruction->mutable_operand(0);
    TF_RETURN_IF_ERROR(instr_and_shape.instruction->ReplaceAllUsesWith(pred));
  }

  return !host_instr.empty();
}
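
// For reference, `SetMemorySpace` (used above) is an existing helper in this
// file that this diff does not show. A plausible minimal sketch (an
// assumption, not the actual implementation):
//
//   void SetMemorySpace(Shape* shape, int64_t memory_space_color) {
//     CHECK(shape->has_layout());
//     shape->mutable_layout()->set_memory_space(memory_space_color);
//   }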

// Additional checks, beyond IsValidDuringPureMemoryOffload (which is run
// separately), to determine whether the respective tensor can live on host.
bool ExtraCheckForValidUsageOnHostForHostOffloadedOutputs(
    const Shape& entry_computation_shape,
    InstructionAndShapeIndex& instruction_and_shape_index) {
  HloInstruction* instruction = instruction_and_shape_index.instruction;
  ShapeIndex& shape_index = instruction_and_shape_index.shape_index;

  // We respect the entry computation layout, so if an output is not expected
  // on host there, we bail.
  if (instruction->IsRoot() && instruction->parent()->IsEntryComputation()) {
    if (ShapeUtil::GetSubshape(entry_computation_shape, shape_index)
            .layout()
            .memory_space() != Layout::kHostMemorySpace) {
      return false;
    }
  }

  // For custom calls, we conservatively only accept MoveToHost.
  // MoveToDevice could be reconsidered here, or handled as part of a generic
  // redundant-copy removal.
  if (instruction->opcode() == HloOpcode::kCustomCall &&
      instruction->custom_call_target() !=
          host_memory_offload_annotations::kMoveToHostCustomCallTarget) {
    return false;
  }

  // TODO(b/347101407): Also consider host async computations, once we extend
  // GetSuccessors to treat them properly.
  if (instruction->opcode() == HloOpcode::kAsyncStart ||
      instruction->opcode() == HloOpcode::kAsyncDone) {
    return false;
  }

  return true;
}

} // namespace

absl::StatusOr<bool> HostOffloader::HandleRedundantCopiesBackToHost(
    const HloModule* module, HloInstruction* instruction) {
  HloAsyncInstruction* call_start = Cast<HloAsyncInstruction>(instruction);

  CHECK_EQ(call_start->users().size(), 1);
  HloInstruction* call_done = call_start->users()[0];

  absl::flat_hash_map<size_t, std::vector<InstructionAndShapeIndex>>
      host_instrs;
  const Shape& entry_computation_shape =
      module->entry_computation_layout().result_layout().shape();

  // We collect all usages per output index, stopping at any non-host
  // instruction.
  const Shape& done_shape = call_done->shape();
  for (size_t index = 0; index < done_shape.tuple_shapes_size(); index++) {
    ShapeIndex output_shape_index = {static_cast<int64_t>(index)};
    std::queue<InstructionAndShapeIndex> queue;
    queue.push(InstructionAndShapeIndex(call_done, output_shape_index));

    // async-start packs the (inputs, outputs, context) in a tuple.
    constexpr int64_t kShapeTupleOutputIndexInAsyncStart = 1;
    ShapeIndex start_shape_index = {kShapeTupleOutputIndexInAsyncStart,
                                    static_cast<int64_t>(index)};
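    // For example, if the host computation returns a 2-tuple (out0, out1),
    // the async-start shape is ((operands...), (out0, out1), context), so
    // output `index` lives at shape index {1, index} of the async-start.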

    // TODO(b/347101407): Start from async-start and trace through the
    // computation as well in GetSuccessors instead of having to manually add
    // async-done and update the async computation separately.
    host_instrs[index].push_back(
        InstructionAndShapeIndex(call_start, start_shape_index));
    host_instrs[index].push_back(
        InstructionAndShapeIndex(call_done, output_shape_index));

    bool host_only = true;
    // Keep track of whether the output of the host offloading computation is
    // also an output of the entry computation. Temporaries are conservatively
    // kept on HBM.
    //
    // TODO(b/347101407): Better to use AliasAnalysis here to trace host
    // compute outputs to entry computation outputs instead. NOTE: The current
    // algorithm only tracks accepted host offloading operations that operate
    // on the same tensor.
    bool entry_compute_output = false;

    while (!queue.empty() && host_only) {
      InstructionAndShapeIndex instruction_and_shape_index = queue.front();
      queue.pop();

      TF_ASSIGN_OR_RETURN(
          std::vector<InstructionAndShapeIndex> successors,
          host_offload_utils::GetSuccessors(InstructionAndShapeIndex(
              instruction_and_shape_index.instruction,
              instruction_and_shape_index.shape_index)));

      // Check whether any of the successors needs to be on device.
      for (InstructionAndShapeIndex& successor : successors) {
        if (!host_offload_utils::IsValidDuringPureMemoryOffload(
                successor.instruction) ||
            !ExtraCheckForValidUsageOnHostForHostOffloadedOutputs(
                entry_computation_shape, successor)) {
          host_only = false;
          break;
        }

        if (successor.instruction->IsRoot() &&
            successor.instruction->parent()->IsEntryComputation()) {
          entry_compute_output = true;
        }

        queue.push(successor);
        host_instrs[index].emplace_back(successor);
      }
    }

    if (!host_only || !entry_compute_output) {
      host_instrs.erase(index);
    }
  }

  // Update the memory space for the host offloaded outputs that never get
  // used on device.
  return UpdateMemorySpaceForHostOffloadedOutputs(call_start, host_instrs);
}

absl::StatusOr<bool> HostOffloader::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;

  // First remove redundant copies to and from host (conservatively), starting
  // from the outputs of the host offloaded computations. Iterate over all
  // instructions and look for XLA host offload annotations.
  bool changed_in_loop;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (host_offload_utils::IsHostAsyncStart(instruction)) {
        TF_ASSIGN_OR_RETURN(changed_in_loop, HandleRedundantCopiesBackToHost(
                                                 module, instruction));
        changed = changed || changed_in_loop;
      }
    }
  }

  TF_ASSIGN_OR_RETURN(const bool input_streaming_changed_module,
                      HandleInputStreaming(module->entry_computation()));
  changed = changed || input_streaming_changed_module;

  // Since we're modifying the graph as we iterate over it, any time we change
  // it, we need to re-run the loop.
  do {
    changed_in_loop = false;
    for (HloComputation* computation :
19 changes: 17 additions & 2 deletions xla/service/host_offloader.h
@@ -43,8 +43,17 @@ class HloCostAnalysis;
// tensors along each path have their memory space set as host memory space. If
// a MoveToHost custom call is paired with a DynamicUpdateSlice, the
// DynamicUpdateSlice will write into host memory space. Otherwise, a copy from
// device to host will be inserted.
//
// If an output of a host offloaded computation is only used on host, the
// memory spaces of its usages are updated to reflect that, and no copies to
// and from host are performed. Any MoveToHost instructions for outputs used
// only on host are removed.
// TODO(b/347101407): A better approach could be to remove redundant copies in
// a generalized fashion; this should also be moved out of HostOffloader.
//
// All MoveToHost and MoveToDevice custom calls are removed by the end of this
// pass.
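//
// For example (illustrative): if a host async computation's output reaches
// the entry root only through a MoveToHost custom call, the pass marks the
// output as host memory and removes the MoveToHost, avoiding a redundant
// round trip through device memory.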
class HostOffloader : public HloModulePass {
 public:
  explicit HostOffloader(int64_t host_memory_space_color)
@@ -143,6 +152,12 @@ class HostOffloader : public HloModulePass {
  absl::StatusOr<bool> ApplySchedulingFix(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Starting from the outputs of the host offloaded computation, track all
  // their usages. For outputs that are ONLY used on host, remove redundant
  // copies to and from host and update the memory space.
  absl::StatusOr<bool> HandleRedundantCopiesBackToHost(
      const HloModule* module, HloInstruction* instruction);
};

} // namespace xla
