Skip to content

Commit

Permalink
[XLA] Move hlo_traversal.h to hlo/utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701965479
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Dec 2, 2024
1 parent 28fe451 commit 695f4f8
Show file tree
Hide file tree
Showing 57 changed files with 125 additions and 130 deletions.
34 changes: 34 additions & 0 deletions xla/hlo/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ load(
"//xla:xla.bzl",
"xla_cc_test",
)
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")
load("//xla/tsl/platform:rules_cc.bzl", "cc_library")

package(
Expand Down Expand Up @@ -194,3 +195,36 @@ xla_cc_test(
"@tsl//tsl/platform:statusor",
],
)

cc_library(
name = "hlo_traversal",
srcs = ["hlo_traversal.cc"],
hdrs = ["hlo_traversal.h"],
compatible_with = get_compatible_with_portable(),
deps = [
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],
)

xla_cc_test(
name = "hlo_traversal_test",
srcs = ["hlo_traversal_test.cc"],
deps = [
":hlo_traversal",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/hlo/utils/hlo_traversal.h"

#include <algorithm>
#include <cstdint>
Expand All @@ -36,7 +36,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"

namespace xla {
namespace gpu {
namespace {

template <typename F>
Expand Down Expand Up @@ -686,5 +685,4 @@ std::vector<HloInstructionAdaptor> HloFindUseChain(HloInstructionAdaptor parent,
return result;
}

} // namespace gpu
} // namespace xla
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_GPU_HLO_TRAVERSAL_H_
#define XLA_SERVICE_GPU_HLO_TRAVERSAL_H_
#ifndef XLA_HLO_UTILS_HLO_TRAVERSAL_H_
#define XLA_HLO_UTILS_HLO_TRAVERSAL_H_

#include <functional>
#include <memory>
Expand All @@ -30,7 +30,6 @@ limitations under the License.
#include "xla/shape.h"

namespace xla {
namespace gpu {

class HloFusionAdaptor;

Expand Down Expand Up @@ -225,7 +224,6 @@ bool HloAnyOf(const HloFusionAdaptor& fusion, Pred&& pred) {
std::vector<HloInstructionAdaptor> HloFindUseChain(HloInstructionAdaptor parent,
HloInstructionAdaptor root);

} // namespace gpu
} // namespace xla

#endif // XLA_SERVICE_GPU_HLO_TRAVERSAL_H_
#endif // XLA_HLO_UTILS_HLO_TRAVERSAL_H_
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/hlo/utils/hlo_traversal.h"

#include <optional>
#include <string>
Expand All @@ -29,7 +29,6 @@ limitations under the License.
#include "xla/tests/hlo_test_base.h"

namespace xla {
namespace gpu {
namespace {

namespace m = ::xla::match;
Expand Down Expand Up @@ -707,5 +706,4 @@ TEST_F(HloTraversalTest, HloFindUseChain) {
}

} // namespace
} // namespace gpu
} // namespace xla
43 changes: 5 additions & 38 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,6 @@ cc_library(
hdrs = ["ir_emission_utils.h"],
compatible_with = get_compatible_with_portable(),
deps = [
":hlo_traversal",
":target_util",
"//xla:literal",
"//xla:shape_util",
Expand All @@ -646,6 +645,7 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:backend_config",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:buffer_assignment",
"//xla/service/llvm_ir:llvm_type_conversion_util",
"//xla/service/llvm_ir:llvm_util",
Expand All @@ -672,14 +672,14 @@ xla_cc_test(
srcs = ["ir_emission_utils_test.cc"],
deps = [
":backend_configs_cc",
":hlo_traversal",
":ir_emission_utils",
"//xla:literal",
"//xla:literal_util",
"//xla:shape_util",
"//xla:types",
"//xla/hlo/ir:backend_config",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:buffer_assignment",
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
Expand Down Expand Up @@ -2431,11 +2431,11 @@ cc_library(
compatible_with = get_compatible_with_portable(),
deps = [
":backend_configs_cc",
":hlo_traversal",
":ir_emission_utils",
":reduction_utils",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
Expand All @@ -2455,9 +2455,9 @@ xla_cc_test(
":backend_configs_cc",
":gpu_device_info_for_tests",
":hlo_fusion_analysis",
":hlo_traversal",
":ir_emission_utils",
"//xla:protobuf_util",
"//xla/hlo/utils:hlo_traversal",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/tests:hlo_test_base",
Expand Down Expand Up @@ -2574,7 +2574,6 @@ cc_library(
deps = [
":backend_configs_cc",
":hlo_fusion_analysis",
":hlo_traversal",
":ir_emission_utils",
":launch_dimensions",
":reduction_utils",
Expand All @@ -2583,6 +2582,7 @@ cc_library(
"//xla:util",
"//xla/hlo/analysis:hlo_dataflow_analysis",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:instruction_fusion",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/algorithm:container",
Expand Down Expand Up @@ -2944,39 +2944,6 @@ cc_library(
],
)

cc_library(
name = "hlo_traversal",
srcs = ["hlo_traversal.cc"],
hdrs = ["hlo_traversal.h"],
compatible_with = get_compatible_with_portable(),
deps = [
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
],
)

xla_cc_test(
name = "hlo_traversal_test",
srcs = ["hlo_traversal_test.cc"],
deps = [
":hlo_traversal",
"//xla/hlo/ir:hlo",
"//xla/service:pattern_matcher",
"//xla/service:pattern_matcher_gmock",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
],
)

xla_test(
name = "determinism_test",
srcs = ["determinism_test.cc"],
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/autotuning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ cc_library(
"//xla/hlo/pass:hlo_pass_pipeline",
"//xla/hlo/transforms:float_normalization",
"//xla/hlo/utils:hlo_query",
"//xla/hlo/utils:hlo_traversal",
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/service:algorithm_util",
"//xla/service:call_inliner",
Expand All @@ -141,7 +142,6 @@ cc_library(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:buffer_comparator",
"//xla/service/gpu:gpu_float_support",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:matmul_indexing_utils",
"//xla/service/gpu:matmul_utils",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ limitations under the License.
#include "xla/hlo/pass/hlo_pass_pipeline.h"
#include "xla/hlo/transforms/simplifiers/float_normalization.h"
#include "xla/hlo/utils/hlo_query.h"
#include "xla/hlo/utils/hlo_traversal.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"
#include "xla/primitive_util.h"
#include "xla/service/algorithm_util.h"
Expand All @@ -65,7 +66,6 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/buffer_comparator.h"
#include "xla/service/gpu/gpu_float_support.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/kernels/custom_kernel.h"
#include "xla/service/gpu/kernels/custom_kernel_fusion.h"
Expand Down
18 changes: 9 additions & 9 deletions xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ cc_library(
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service/gpu:gpu_fusible",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
Expand All @@ -44,9 +44,9 @@ cc_library(
":fusion_emitter",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:buffer_assignment",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emitter_context",
"//xla/service/gpu/runtime:copy_thunk",
"//xla/service/gpu/runtime:thunk",
Expand All @@ -72,6 +72,7 @@ cc_library(
"//xla/ffi:attribute_map",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:buffer_assignment",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
Expand All @@ -80,7 +81,6 @@ cc_library(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:cublas_cudnn",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:ir_emitter_context",
"//xla/service/gpu:kernel_arguments",
Expand Down Expand Up @@ -222,10 +222,10 @@ cc_library(
":triton",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:buffer_assignment",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status:statusor",
Expand All @@ -242,9 +242,9 @@ cc_library(
"//xla:status_macros",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service/gpu:gpu_fusible",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
Expand Down Expand Up @@ -338,9 +338,9 @@ cc_library(
"//xla:shape_util",
"//xla:status_macros",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:ir_emitter_context",
"//xla/service/gpu:kernel_arguments",
Expand Down Expand Up @@ -466,9 +466,9 @@ cc_library(
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/hlo/utils:hlo_traversal",
"//xla/service/gpu:gpu_fusible",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:reduction_utils",
Expand Down Expand Up @@ -499,8 +499,8 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:reduction_utils",
Expand Down Expand Up @@ -577,8 +577,8 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:hlo_traversal",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu/fusions/ir:xla_gpu",
"//xla/service/gpu/fusions/mlir:computation_partitioner",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/utils/hlo_traversal.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/runtime/copy_thunk.h"
#include "xla/service/gpu/runtime/thunk.h"
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/fusions/custom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/utils/hlo_traversal.h"
#include "xla/literal.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/custom_call_status.h"
Expand All @@ -50,7 +51,6 @@ limitations under the License.
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/fusions/fusion_emitter.h"
#include "xla/service/gpu/hlo_fusion_analysis.h"
#include "xla/service/gpu/hlo_traversal.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/ir_emitter_context.h"
#include "xla/service/gpu/kernel_arguments.h"
Expand Down
Loading

0 comments on commit 695f4f8

Please sign in to comment.