Skip to content

Commit

Permalink
[TritonGEN]: Add operation for subgroup_scan_[ex|in]clusive (#1506)
Browse files Browse the repository at this point in the history
Signed-off-by: Tiotto, Ettore <[email protected]>
  • Loading branch information
etiotto authored Jun 28, 2024
1 parent 3000054 commit 9d2c1bf
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 46 deletions.
58 changes: 58 additions & 0 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,64 @@ module attributes {

// -----

// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_exclusive_addi(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_exclusive_muli(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_exclusive_maxi(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_exclusive_mini(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_exclusive_andi(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z27sub_group_scan_exclusive_ori(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_exclusive_xori(i32) -> i32 attributes {passthrough = ["convergent"]}

// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_inclusive_addi(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_inclusive_muli(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_inclusive_maxi(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_inclusive_mini(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_inclusive_andi(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z27sub_group_scan_inclusive_ori(i32) -> i32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z28sub_group_scan_inclusive_xori(i32) -> i32 attributes {passthrough = ["convergent"]}

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
} {
// Lowering test for `triton_gen.sub_group_scan`: every op below is expected to
// become an `llvm.call` (spir_funccc calling convention) to the mangled
// `sub_group_scan_<exclusive|inclusive>_<op>` function named in the directive
// directly above it. All ops consume the same i32 constant, captured as VAL,
// so the directives do not depend on the printed SSA value name.
llvm.func @triton_gen.sub_group_scan() {
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
// Exclusive scans: add, mul, min, max, and, or, xor.
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_exclusive_addi([[VAL]]) {{.*}} : (i32) -> i32
%1 = triton_gen.sub_group_scan add %0 {kind = exclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_exclusive_muli([[VAL]]) {{.*}} : (i32) -> i32
%2 = triton_gen.sub_group_scan mul %0 {kind = exclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_exclusive_mini([[VAL]]) {{.*}} : (i32) -> i32
%3 = triton_gen.sub_group_scan min %0 {kind = exclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_exclusive_maxi([[VAL]]) {{.*}} : (i32) -> i32
%4 = triton_gen.sub_group_scan max %0 {kind = exclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_exclusive_andi([[VAL]]) {{.*}} : (i32) -> i32
%5 = triton_gen.sub_group_scan and %0 {kind = exclusive} : i32
// CHECK: llvm.call spir_funccc @_Z27sub_group_scan_exclusive_ori([[VAL]]) {{.*}} : (i32) -> i32
%6 = triton_gen.sub_group_scan or %0 {kind = exclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_exclusive_xori([[VAL]]) {{.*}} : (i32) -> i32
%7 = triton_gen.sub_group_scan xor %0 {kind = exclusive} : i32

// Inclusive scans: same operators, same operand.
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_inclusive_addi([[VAL]]) {{.*}} : (i32) -> i32
%8 = triton_gen.sub_group_scan add %0 {kind = inclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_inclusive_muli([[VAL]]) {{.*}} : (i32) -> i32
%9 = triton_gen.sub_group_scan mul %0 {kind = inclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_inclusive_mini([[VAL]]) {{.*}} : (i32) -> i32
%10 = triton_gen.sub_group_scan min %0 {kind = inclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_inclusive_maxi([[VAL]]) {{.*}} : (i32) -> i32
%11 = triton_gen.sub_group_scan max %0 {kind = inclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_inclusive_andi([[VAL]]) {{.*}} : (i32) -> i32
%12 = triton_gen.sub_group_scan and %0 {kind = inclusive} : i32
// CHECK: llvm.call spir_funccc @_Z27sub_group_scan_inclusive_ori([[VAL]]) {{.*}} : (i32) -> i32
%13 = triton_gen.sub_group_scan or %0 {kind = inclusive} : i32
// CHECK: llvm.call spir_funccc @_Z28sub_group_scan_inclusive_xori([[VAL]]) {{.*}} : (i32) -> i32
%14 = triton_gen.sub_group_scan xor %0 {kind = inclusive} : i32

llvm.return
}
}

// -----

// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xordj(f64, i32) -> f64 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorfj(f32, i32) -> f32 attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z21sub_group_shuffle_xorDhj(f16, i32) -> f16 attributes {passthrough = ["convergent"]}
Expand Down
42 changes: 42 additions & 0 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,48 @@ module attributes {

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 32>>
} {
// Exercises the textual form of `triton_gen.sub_group_scan` with the add, mul,
// min, max, and, or and xor operators, in both exclusive and inclusive mode.
// Each directive expects the op to be printed exactly as it is written below.
// The operand is matched through a FileCheck capture (VAL) bound on the
// constant's definition instead of a hard-coded SSA name, so the test does not
// depend on how the printer numbers values.
// NOTE(review): the RUN line is outside this view; presumably this is a
// parse/print round-trip - confirm.
llvm.func @triton_gen.sub_group_scan() {
// CHECK-LABEL: triton_gen.sub_group_scan
%0 = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
// Exclusive scans.
// CHECK: triton_gen.sub_group_scan add [[VAL]] {kind = exclusive} : i32
%1 = triton_gen.sub_group_scan add %0 {kind = exclusive} : i32
// CHECK: triton_gen.sub_group_scan mul [[VAL]] {kind = exclusive} : i32
%2 = triton_gen.sub_group_scan mul %0 {kind = exclusive} : i32
// CHECK: triton_gen.sub_group_scan min [[VAL]] {kind = exclusive} : i32
%3 = triton_gen.sub_group_scan min %0 {kind = exclusive} : i32
// CHECK: triton_gen.sub_group_scan max [[VAL]] {kind = exclusive} : i32
%4 = triton_gen.sub_group_scan max %0 {kind = exclusive} : i32
// CHECK: triton_gen.sub_group_scan and [[VAL]] {kind = exclusive} : i32
%5 = triton_gen.sub_group_scan and %0 {kind = exclusive} : i32
// CHECK: triton_gen.sub_group_scan or [[VAL]] {kind = exclusive} : i32
%6 = triton_gen.sub_group_scan or %0 {kind = exclusive} : i32
// CHECK: triton_gen.sub_group_scan xor [[VAL]] {kind = exclusive} : i32
%7 = triton_gen.sub_group_scan xor %0 {kind = exclusive} : i32

// Inclusive scans.
// CHECK: triton_gen.sub_group_scan add [[VAL]] {kind = inclusive} : i32
%8 = triton_gen.sub_group_scan add %0 {kind = inclusive} : i32
// CHECK: triton_gen.sub_group_scan mul [[VAL]] {kind = inclusive} : i32
%9 = triton_gen.sub_group_scan mul %0 {kind = inclusive} : i32
// CHECK: triton_gen.sub_group_scan min [[VAL]] {kind = inclusive} : i32
%10 = triton_gen.sub_group_scan min %0 {kind = inclusive} : i32
// CHECK: triton_gen.sub_group_scan max [[VAL]] {kind = inclusive} : i32
%11 = triton_gen.sub_group_scan max %0 {kind = inclusive} : i32
// CHECK: triton_gen.sub_group_scan and [[VAL]] {kind = inclusive} : i32
%12 = triton_gen.sub_group_scan and %0 {kind = inclusive} : i32
// CHECK: triton_gen.sub_group_scan or [[VAL]] {kind = inclusive} : i32
%13 = triton_gen.sub_group_scan or %0 {kind = inclusive} : i32
// CHECK: triton_gen.sub_group_scan xor [[VAL]] {kind = inclusive} : i32
%14 = triton_gen.sub_group_scan xor %0 {kind = inclusive} : i32

llvm.return
}
}

// -----

llvm.func @triton_gen.sub_group_shuffle() {
// CHECK-LABEL: triton_gen.sub_group_shuffle
%0 = llvm.mlir.constant(0 : i32) : i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class TritonGEN_Attr<string name, string attrMnemonic, list<Trait> traits = []>
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different reduce kinds.
def TritonGEN_ReduceKindAttr : I32EnumAttr<"ReduceKind", "TritonGEN reduce kind",
/// Enum attribute of the different subgroup reduce kinds.
def TritonGEN_ReduceKindAttr : I32EnumAttr<"ReduceKind", "TritonGEN subgroup reduce kind",
[
I32EnumAttrCase<"ADD", 0, "add">,
I32EnumAttrCase<"MUL", 1, "mul">,
Expand All @@ -34,8 +34,17 @@ def TritonGEN_ReduceKindAttr : I32EnumAttr<"ReduceKind", "TritonGEN reduce kind"
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different shuffle kinds.
def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN shuffle kind",
/// Enum attribute of the different subgroup scan kinds.
///
/// Used as the `$scan_kind` argument of `triton_gen.sub_group_scan`, where it
/// appears in the assembly as `{kind = inclusive}` or `{kind = exclusive}`.
def TritonGEN_ScanKindAttr : I32EnumAttr<"ScanKind", "TritonGEN subgroup scan kind",
[
// Each case is <enumerator symbol, underlying i32 value, assembly keyword>.
I32EnumAttrCase<"INCLUSIVE", 0, "inclusive">,
I32EnumAttrCase<"EXCLUSIVE", 1, "exclusive">,
]> {
// Namespace used for the generated C++ enum.
let cppNamespace = "::mlir::triton::TritonGEN";
}

/// Enum attribute of the different subgroup shuffle kinds.
def TritonGEN_ShflKindAttr : I32EnumAttr<"ShflKind", "TritonGEN subgroup shuffle kind",
[
I32EnumAttrCase<"XOR", 0, "xor">,
I32EnumAttrCase<"UP", 1, "up">,
Expand Down
23 changes: 23 additions & 0 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,29 @@ def TritonGEN_SubGroupReduceOp : TritonGEN_Op<"sub_group_reduce", [
let hasVerifier = 1;
}

/// Subgroup scan operation.
///
/// Takes one integer-or-float-like `value` plus two enum attributes:
/// `reduce_kind` selects the combining operator and `scan_kind` selects
/// inclusive or exclusive mode. The result type always equals the operand
/// type (enforced by the AllTypesMatch trait).
///
/// Example:
///   %1 = triton_gen.sub_group_scan add %0 {kind = exclusive} : i32
def TritonGEN_SubGroupScanOp : TritonGEN_Op<"sub_group_scan", [
AllTypesMatch<["res", "value"]>]>,
Results<(outs SignlessIntegerOrFloatLike:$res)>,
Arguments<(ins SignlessIntegerOrFloatLike:$value,
TritonGEN_ReduceKindAttr:$reduce_kind,
TritonGEN_ScanKindAttr:$scan_kind)> {
let summary = "Subgroup scan";

let description = [{
The `triton_gen.sub_group_scan` operation is invoked by all work items in
a subgroup, each of them providing a $value. Each work item performs the
reduction operation identified by $reduce_kind. The $scan_kind attribute
indicates whether to perform an inclusive or exclusive scan. The result
of the scan operation is returned for each work item.
Note: The scan order is defined by increasing sub-group local ID within
the sub-group.
}];

// Custom syntax: `<reduce_kind> <value> {kind = <scan_kind>} : <type>`.
// Only the operand type is spelled out; the result type is not written
// because AllTypesMatch ties it to the operand type.
let assemblyFormat = [{
$reduce_kind $value ` ` `{` `kind` `=` $scan_kind `}` attr-dict `:` type($value)
}];
}

def TritonGEN_SubGroupShuffleOp : TritonGEN_Op<"sub_group_shuffle", [
AllTypesMatch<["res", "value"]>]>,
Results<(outs SignlessIntegerOrFloatLike:$res)>,
Expand Down
5 changes: 1 addition & 4 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
//
//===----------------------------------------------------------------------===//

#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"

#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"

#include "llvm/ADT/STLExtras.h"
#include <cstdint>

Expand Down
Loading

0 comments on commit 9d2c1bf

Please sign in to comment.