Skip to content

Commit

Permalink
Merge commit '7f2c56a51f3b309fd6c1afbb0cd53347ba17721f'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Dec 5, 2024
2 parents 3a07c3e + 7f2c56a commit e7fb607
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1f20eee6dc367bd202895e3eedb03974a628ef16
86b69c31642e98f8357df62c09d118ad1da4e16a
4 changes: 2 additions & 2 deletions lib/Dialect/Triton/Transforms/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def CombineDotAddIPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFPattern : Pat<
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath, $denorm),
(Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $fastmath),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
Expand All @@ -29,7 +29,7 @@ def CombineDotAddIRevPattern : Pat<
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"res->hasOneUse()">, "dot result has a single use">)]>;
def CombineDotAddFRevPattern : Pat<
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath, $denorm),
(Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $inputPrecision, $maxNumImpreciseAcc), $d, $fastmath),
(TT_DotOp $a, $b, $d, $inputPrecision, $maxNumImpreciseAcc, (location $res)),
[(Constraint<CPred<"isZero($0)">> $c),
(Constraint<CPred<"::llvm::cast<::mlir::IntegerAttr>($0).getInt() == 0">> $maxNumImpreciseAcc),
Expand Down
2 changes: 1 addition & 1 deletion python/test/regression/test_cast_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import triton.language as tl
from triton._internal_testing import is_hip_mi300, is_cuda, is_hip

input_dtypes = ["float16", "float32", "float64"]
input_dtypes = ["bfloat16", "float16", "float32", "float64"]
if is_cuda():
input_dtypes += ["int8", "float8_e5m2"]
cc = torch.cuda.get_device_capability(0)
Expand Down
29 changes: 29 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1999,3 +1999,32 @@ tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: te
}

}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {

tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
// CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1
// CHECK: llvm.sitofp %{{.*}} : i8 to f16
%2 = arith.sitofp %1 : tensor<16x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
tt.return
}

}

// -----

#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 3072 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
// CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0
// CHECK: llvm.sitofp %{{.*}} : i8 to f16
%2 = arith.sitofp %1 : tensor<32x16xi8, #ttg.dot_op<{opIdx = 0 , parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}

}

0 comments on commit e7fb607

Please sign in to comment.