Fix GpuCompilerPassTest.GpuCompilerRunsTritonGemmRewriterByDefaultFromAmpere on ROCm

The Triton GEMM rewriter also runs on ROCm hardware, so the test's "AtLeastAmpere" assumption only holds for NVIDIA GPUs.

This change extends the test's expectation to ROCm devices.

PiperOrigin-RevId: 666320928
beckerhe authored and copybara-github committed Aug 22, 2024
1 parent 1aaac40 commit b1526c2
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions xla/service/gpu/gpu_compiler_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include <memory>
 #include <string>
 #include <utility>
+#include <variant>
 #include <vector>

 #include <gmock/gmock.h>
@@ -971,13 +972,20 @@ TEST_F(GpuCompilerTest, TestFlag_xla_gpu_unsafe_pipelined_loop_annotator) {
 using GpuCompilerPassTest = GpuCompilerTest;

 TEST_F(GpuCompilerPassTest,
-       GpuCompilerRunsTritonGemmRewriterByDefaultFromAmpere) {
-  auto cc = backend()
-                .default_stream_executor()
-                ->GetDeviceDescription()
-                .cuda_compute_capability();
-
-  bool expect_triton_gemm_rewriter_has_run = cc.IsAtLeastAmpere();
+       GpuCompilerRunsTritonGemmRewriterByDefaultOnSupportedGPUs) {
+  auto cuda_cc = std::get_if<stream_executor::CudaComputeCapability>(
+      &backend()
+           .default_stream_executor()
+           ->GetDeviceDescription()
+           .gpu_compute_capability());
+  auto rocm_cc = std::get_if<stream_executor::RocmComputeCapability>(
+      &backend()
+           .default_stream_executor()
+           ->GetDeviceDescription()
+           .gpu_compute_capability());
+
+  bool expect_triton_gemm_rewriter_has_run =
+      (cuda_cc && cuda_cc->IsAtLeastAmpere()) || (rocm_cc);

   constexpr absl::string_view constant_module = R"(
 HloModule noop
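The expectation computed in the updated test can be illustrated outside of XLA with a minimal standalone sketch. It assumes the compute capability is held in a `std::variant` with one alternative per vendor, which is what the `std::get_if` calls in the diff suggest. `FakeCudaCapability`, `FakeRocmCapability`, `FakeGpuCapability`, and `ExpectTritonGemmRewriter` are hypothetical stand-ins invented for this example and are not the real `stream_executor` types or any XLA API.

```cpp
#include <variant>

// Hypothetical stand-in for stream_executor::CudaComputeCapability.
struct FakeCudaCapability {
  int major_version;
  int minor_version;
  // Ampere corresponds to CUDA compute capability 8.x.
  bool IsAtLeastAmpere() const { return major_version >= 8; }
};

// Hypothetical stand-in for stream_executor::RocmComputeCapability.
struct FakeRocmCapability {};

// Assumption: the device reports exactly one of the two alternatives.
using FakeGpuCapability = std::variant<FakeCudaCapability, FakeRocmCapability>;

// Mirrors the boolean in the updated test: true for CUDA devices from
// Ampere onward, and true for any ROCm device.
bool ExpectTritonGemmRewriter(const FakeGpuCapability& cc) {
  // std::get_if returns nullptr when the variant holds the other alternative.
  const FakeCudaCapability* cuda_cc = std::get_if<FakeCudaCapability>(&cc);
  const FakeRocmCapability* rocm_cc = std::get_if<FakeRocmCapability>(&cc);
  return (cuda_cc != nullptr && cuda_cc->IsAtLeastAmpere()) ||
         rocm_cc != nullptr;
}
```

Because `std::get_if` yields a null pointer for the alternative that is not active, the CUDA branch is simply skipped on ROCm devices instead of reading a CUDA capability that does not exist there.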
