Skip to content

Commit

Permalink
[XPU][TritonGEN] Revamp SIMD block memory access operations (#2756)
Browse files Browse the repository at this point in the history
Revamp `triton_gen.simdblock[read|write]` operations:

- Rename to `triton_gen.sub_group_block_[read|write]`
- Implement type verification in signature
- Represent scalar block memory accesses with a scalar type instead of `vector<1xty>`
- Revamp ASM format

---------

Signed-off-by: victor-eds <[email protected]>
  • Loading branch information
victor-eds authored Nov 20, 2024
1 parent 61f5381 commit 7551a90
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 121 deletions.
16 changes: 0 additions & 16 deletions test/TritonGEN/tritongen-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -428,19 +428,3 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
llvm.return
}

// -----

llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
// expected-error @+1 {{'triton_gen.simdblockread' op unsupported vector type}}
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16>
llvm.return
}

// -----

llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val: vector<64xi16>) {
// expected-error @+1 {{'triton_gen.simdblockwrite' op unsupported vector type}}
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>)
llvm.return
}
21 changes: 15 additions & 6 deletions test/TritonGEN/tritongen-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -241,20 +241,29 @@ llvm.func @triton_gen.dpas.bf16_accum(%c: vector<8xbf16>, %a : vector<8xi16>, %b

// CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, will_return}

llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) {
llvm.func @triton_gen.sub_group_block_read(%ptr: !llvm.ptr<3>) {
// CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<3>) {
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<2xi16>
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<2xi16>
%ret = triton_gen.sub_group_block_read %ptr : !llvm.ptr<3> -> vector<2xi16>
llvm.return
}

// -----

// CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(!llvm.ptr<3>, vector<2xi16>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, will_return}

llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<2xi16>) -> ()
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<2xi16>)
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, vector<2xi16>
llvm.return
}

// -----

// Scalar case: a block write of a plain i32 value through an address-space-1
// pointer. The expected callee name carries no vector-length suffix
// ("..._write_ui") and the value parameter is passed as a scalar i32.
llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<1>, %val : i32) {
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<1>, %arg1: i32) {
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS1jj(%arg0, %arg1) {{.*}} : (!llvm.ptr<1>, i32) -> ()
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, i32
llvm.return
}
16 changes: 8 additions & 8 deletions test/TritonGEN/tritongen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,16 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base
llvm.return
}

llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) {
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr) {
// CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<2xi16>
triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<2xi16>
llvm.func @triton_gen.sub_group_block_read(%ptr : !llvm.ptr<1>) {
// CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<1>) {
// CHECK-NEXT: triton_gen.sub_group_block_read %arg0 : !llvm.ptr<1> -> vector<2xi16>
triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<2xi16>
llvm.return
}

llvm.func @triton_gen.simdblockwrite(%ptr : !llvm.ptr, %val : vector<2xi16>) {
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr, %arg1: vector<2xi16>) {
// CHECK-NEXT: triton_gen.simdblockwrite %arg0, %arg1 : (!llvm.ptr, vector<2xi16>)
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr, vector<2xi16>)
llvm.func @triton_gen.sub_group_block_write(%ptr : !llvm.ptr<3>, %val : i32) {
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: i32) {
// CHECK-NEXT: triton_gen.sub_group_block_write %arg0, %arg1 : !llvm.ptr<3>, i32
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, i32
llvm.return
}
102 changes: 76 additions & 26 deletions third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -314,46 +314,96 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
let hasVerifier = 1;
}

def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
)> {

let summary = "simd block read";
// Element type constraint shared by the sub-group block read and write ops:
// matches any one of the I8, I16, I32 or I64 integer types.
def TritonGEN_SubGroupBlockMemoryAccessElementType
: AnyTypeOf<[I8, I16, I32, I64],
"Valid sub-group block memory access element type">;

// Value type constraint for a sub-group block memory access. Accepts:
//   - a scalar of a valid element type,
//   - a fixed-length vector of 2, 4 or 8 valid elements, or
//   - a fixed-length vector of 16 elements, restricted to i8.
def TritonGEN_SubGroupBlockMemoryAccessType
: AnyTypeOf<[TritonGEN_SubGroupBlockMemoryAccessElementType,
FixedVectorOfLengthAndType<[2, 4, 8],
[TritonGEN_SubGroupBlockMemoryAccessElementType]>,
// Vectors of length 16 only allowed for i8 for now.
FixedVectorOfLengthAndType<[16], [I8]>],
"Valid sub-group block memory access type">;

// Pointer type constraint for sub-group block memory accesses: the type must
// satisfy LLVM_AnyPointer and its address space must equal either
// kCrossWorkgroup or kWorkgroup (described below as the global and local
// OpenCL address spaces). Pointers in any other address space are rejected.
def TritonGEN_SubGroupBlockMemoryAccessPointerType
: Type<And<[LLVM_AnyPointer.predicate,
Or<[CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
".getAddressSpace() == " #
"static_cast<unsigned>(kCrossWorkgroup)">,
CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
".getAddressSpace() == " #
"static_cast<unsigned>(kWorkgroup)">]>]>,
"LLVM pointer in local or global OpenCL address space",
"::mlir::LLVM::LLVMPointerType">;

def TritonGEN_SubGroupBlockReadOp: TritonGEN_Op<"sub_group_block_read"> {
let summary = "Sub-group block read.";

let description = [{
The `triton_gen.simdblockread` operation performs simd block read from
a start address without laneId offset. The parameters are:
$ptr - the base address to read data
The `triton_gen.sub_group_block_read` reads a scalar or vector for each
work-item in the sub-group from pointer `ptr` as a block operation.
The data is read strided, so the first value is read from:
```
ptr[sub_group_local_id]
```
and the second one from:
```
ptr[sub_group_local_id + sub_group_size]
```
etc.

`ptr` must be aligned to the size of the element type of `res`.

Example:
```mlir
%0 = triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<4xi32>
```
}];

let arguments = (ins
Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemRead]>:$ptr);

let results = (outs TritonGEN_SubGroupBlockMemoryAccessType:$res);

let assemblyFormat = [{
operands ` ` attr-dict `:` functional-type(operands, results)
$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)
}];

let hasVerifier = 1;
}

def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
Arguments<(ins
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val
)> {

def TritonGEN_SubGroupBlockWriteOp : TritonGEN_Op<"sub_group_block_write"> {
let summary = "Sub-group block write.";

let description = [{
The `triton_gen.simdblockwrite` operation performs simd block write to
a start address without laneId offset. The parameters are:
$ptr - the base address to be written
$val - the value vector to write
The `triton_gen.sub_group_block_write` writes a scalar or vector for each
work-item in the sub-group to pointer `ptr` as a block operation.
The data is written strided, so the first value is written to:
```
ptr[sub_group_local_id]
```
and the second one to:
```
ptr[sub_group_local_id + sub_group_size]
```
etc.

`ptr` must be aligned to the size of the element type of `val`.

Example:
```mlir
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, vector<4xi32>
```
}];

// Operands of the block write:
//   $ptr - destination pointer; the op stores through it, so it is annotated
//          with the MemWrite effect (a MemRead-only, result-less op could be
//          treated as trivially dead and erased).
//   $val - scalar or vector value written by each work-item.
let arguments = (ins
Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemWrite]>:$ptr,
TritonGEN_SubGroupBlockMemoryAccessType:$val);

let results = (outs);

let assemblyFormat = [{
operands ` ` attr-dict `:` `(` type(operands) `)`
$ptr `,` $val attr-dict `:` qualified(type($ptr)) `,` type($val)
}];

let hasVerifier = 1;
}

#endif // TRITONGEN_OPS
28 changes: 0 additions & 28 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,6 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
return success();
}

static LogicalResult verifySIMDBlockTy(Operation *op, VectorType vecTy) {
unsigned numElems = vecTy.getNumElements();
IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());

// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 &&
(elemTy.getWidth() != 8 || numElems != 16))
return op->emitOpError("unsupported vector type");

return success();
}

//===----------------------------------------------------------------------===//
// gen.sub_group_reduce
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -438,19 +426,3 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {

return success();
}

//===----------------------------------------------------------------------===//
// gen.simdblockread
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::SIMDBlockReadOp::verify() {
return verifySIMDBlockTy(*this, getRes().getType());
}

//===----------------------------------------------------------------------===//
// gen.simdblockwrite
//===----------------------------------------------------------------------===//

LogicalResult TritonGEN::SIMDBlockWriteOp::verify() {
return verifySIMDBlockTy(*this, getVal().getType());
}
75 changes: 43 additions & 32 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "mlir/Target/LLVMIR/TypeToLLVM.h"

#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/identity.h"
#include "llvm/IR/Attributes.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/ModRef.h"
Expand Down Expand Up @@ -935,69 +937,77 @@ struct TritonMatrix2DBlockPrefetchLowering
};

template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
OpType, TritonGEN::SIMDBlockReadOp,
TritonGEN::SIMDBlockWriteOp>::value>>
static std::string getSIMDBlockManglingName(OpType op, VectorType vecTy) {
OpType, TritonGEN::SubGroupBlockReadOp,
TritonGEN::SubGroupBlockWriteOp>::value>>
static std::string getSubGroupBlockManglingName(OpType op, Type type) {
constexpr bool isWrite =
std::is_same<OpType, TritonGEN::SIMDBlockWriteOp>::value;
std::is_same<OpType, TritonGEN::SubGroupBlockWriteOp>::value;
const LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
const unsigned numElems = vecTy.getNumElements();
// Note: OCL builtin name here differs from regular mangling.
std::string funcName = "intel_sub_group_block_";
if constexpr (isWrite)
funcName += "write";
else
funcName += "read";
funcName += "_u" + intel::getTypeMangling(vecTy.getElementType()) +
(numElems == 1 ? "" : std::to_string(numElems));
funcName =
"_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
std::to_string(ptrTy.getAddressSpace()) +
intel::getTypeMangling(vecTy.getElementType(), /*isUnsigned=*/true);
Type elementType =
TypeSwitch<Type, Type>(type)
.Case([](VectorType vecType) { return vecType.getElementType(); })
// Scalar case
.Default(llvm::identity<Type>());
const unsigned numElems =
TypeSwitch<Type, unsigned>(type)
.Case([](VectorType vecType) { return vecType.getNumElements(); })
// Scalar case
.Default(0u);
funcName += "_u" + intel::getTypeMangling(elementType) +
(numElems ? std::to_string(numElems) : "");
funcName = "_Z" + std::to_string(funcName.size()) + funcName + "PU3AS" +
std::to_string(ptrTy.getAddressSpace()) +
intel::getTypeMangling(elementType, /*isUnsigned=*/true);
if constexpr (isWrite)
funcName += intel::getTypeMangling(vecTy, /*isUnsigned=*/true);
funcName += intel::getTypeMangling(type, /*isUnsigned=*/true);
return funcName;
}

struct TritonSIMDBlockReadLowering
: public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockReadOp> {
struct TritonSubGroupBlockReadLowering
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupBlockReadOp> {
using ConvertOpToLLVMPattern<
TritonGEN::SIMDBlockReadOp>::ConvertOpToLLVMPattern;
TritonGEN::SubGroupBlockReadOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(TritonGEN::SIMDBlockReadOp op, OpAdaptor adaptor,
matchAndRewrite(TritonGEN::SubGroupBlockReadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
VectorType vecTy = op.getRes().getType();
Type type = op.getRes().getType();

std::string funcName = getSIMDBlockManglingName(op, vecTy);
std::string funcName = getSubGroupBlockManglingName(op, type);
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
/*argMem=*/LLVM::ModRefInfo::Ref,
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
auto funcAttrs = noUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
LLVM::CallOp call = createDeviceFunctionCall(
rewriter, funcName, vecTy, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});
rewriter, funcName, type, {ptrTy}, {op.getPtr()}, {}, funcAttrs, {});

rewriter.replaceOp(op, call.getResult());
return success();
}
};

struct TritonSIMDBlockWriteLowering
: public ConvertOpToLLVMPattern<TritonGEN::SIMDBlockWriteOp> {
struct TritonSubGroupBlockWriteLowering
: public ConvertOpToLLVMPattern<TritonGEN::SubGroupBlockWriteOp> {
using ConvertOpToLLVMPattern<
TritonGEN::SIMDBlockWriteOp>::ConvertOpToLLVMPattern;
TritonGEN::SubGroupBlockWriteOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(TritonGEN::SIMDBlockWriteOp op, OpAdaptor adaptor,
matchAndRewrite(TritonGEN::SubGroupBlockWriteOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *ctx = rewriter.getContext();
LLVM::LLVMPointerType ptrTy = op.getPtr().getType();
VectorType vecTy = op.getVal().getType();
Type type = op.getVal().getType();

std::string funcName = getSIMDBlockManglingName(op, vecTy);
std::string funcName = getSubGroupBlockManglingName(op, type);

auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
/*other=*/LLVM::ModRefInfo::NoModRef,
Expand All @@ -1006,7 +1016,7 @@ struct TritonSIMDBlockWriteLowering
auto funcAttrs = noUnwindWillReturnAttrs;
funcAttrs.memEffectsAttr = memAttr;
LLVM::CallOp call = createDeviceFunctionCall(
rewriter, funcName, void_ty(ctx), {ptrTy, vecTy},
rewriter, funcName, void_ty(ctx), {ptrTy, type},
{op.getPtr(), op.getVal()}, {}, funcAttrs);

rewriter.replaceOp(op, call);
Expand Down Expand Up @@ -1071,12 +1081,13 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface {

void mlir::triton::populateTritonGENToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<
TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering,
TritonSubGroupReduceLowering, TritonSubGroupScanLowering,
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering,
TritonSIMDBlockReadLowering, TritonSIMDBlockWriteLowering>(converter);
patterns
.add<TritonGENSplitBarrierSignalLowering,
TritonGENSplitBarrierWaitLowering, TritonSubGroupReduceLowering,
TritonSubGroupScanLowering, TritonMatrixDPASLowering,
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
TritonMatrix2DBlockPrefetchLowering, TritonSubGroupBlockReadLowering,
TritonSubGroupBlockWriteLowering>(converter);
}

void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {
Expand Down
Loading

0 comments on commit 7551a90

Please sign in to comment.