Update the documents based on review comments.
chengjunlu committed Nov 21, 2024
1 parent 1a8a0a7 commit bfccedd
Showing 6 changed files with 34 additions and 27 deletions.
@@ -26,15 +26,16 @@ The encoding is characterized by parameters:
- `opsPerChannel` is 4 for 8-bit scalar types, 2 for 16-bit scalar types, and 1 for 32-bit scalar types.
- `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
- `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
- `sugGroupSize` Currently only sub group size 16 is supported.
- `threadsPerWarp_` AKA threadsPerWarp. The name conflicts with getThreadsPerWarp in the DistributedLayout interface,
  so the name threadsPerWarp_ is used here. Currently only 16 is supported.

The values of the matrix are distributed across the threads in the subgroup in row-major order (a sketch of this mapping follows the list):
- If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single row of the matrix.
- If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
- If the column size of the matrix is larger than the number of threads in the subgroup, a single row of the matrix requires multiple value names.
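The mapping can be illustrated with a minimal, hypothetical C++ sketch (not part of the Triton code base): element (row, col) is linearized in row-major order, the owning lane is `linear % threadsPerWarp`, and the value name index is `linear / threadsPerWarp`. The `colCases` values correspond to the three examples below.
```
// Hypothetical sketch, not dialect code: row-major distribution of a
// rows x cols tile across threadsPerWarp lanes.
#include <cstdio>

int main() {
  const unsigned threadsPerWarp = 16;
  const unsigned colCases[3] = {16, 8, 32}; // equal to, less than, larger than
  for (unsigned c = 0; c < 3; ++c) {
    unsigned cols = colCases[c];
    printf("cols = %u:\n", cols);
    // Sample a few elements of the first two rows to keep the output short.
    for (unsigned row = 0; row < 2; ++row)
      for (unsigned col = 0; col < cols; col += cols / 4) {
        unsigned linear = row * cols + col;        // row-major linearization
        printf("  elem(%u,%u) -> t%u, v%u\n", row, col,
               linear % threadsPerWarp,   // owning thread (lane)
               linear / threadsPerWarp);  // value name index
      }
  }
  return 0;
}
```
For cols = 16 each value name covers exactly one row, for cols = 8 a single value name spans two rows, and for cols = 32 each row needs two value names, matching the three cases above.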

Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and sugGroupSize=16.
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.

The layout for A operand:
K = 16 (K = systolic depth * opsPerChan)
@@ -83,7 +84,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
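As a rough cross-check for Example 1 (a sketch only, assuming the operand shapes A: repeatCount x K, B: K x executionSize, C/D: repeatCount x executionSize, with K = systolicDepth * opsPerChannel as annotated in the layouts above), the per-thread element counts work out as follows:
```
// Sketch only: per-thread element counts for Example 1.
#include <cstdio>

int main() {
  const unsigned repeatCount = 8, systolicDepth = 8, executionSize = 16;
  const unsigned opsPerChannel = 2, threadsPerWarp = 16;
  const unsigned K = systolicDepth * opsPerChannel; // 16

  unsigned elemsA = repeatCount * K;             // 8 x 16 = 128
  unsigned elemsB = K * executionSize;           // 16 x 16 = 256
  unsigned elemsC = repeatCount * executionSize; // 8 x 16 = 128

  printf("elements per thread: A=%u B=%u C/D=%u\n",
         elemsA / threadsPerWarp,  // 8
         elemsB / threadsPerWarp,  // 16
         elemsC / threadsPerWarp); // 8
  return 0;
}
```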

Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and sugGroupSize=16.
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.

The layout for A operand:
K = 8 (K = systolic depth * opsPerChan)
@@ -102,7 +103,7 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 8.
The layouts for the C and D operands are the same as for opsPerChan=2.

Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and sugGroupSize=16.
The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.

The layout for A operand:
K = 32 (K = systolic depth * opsPerChan)
@@ -121,15 +122,21 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 32.
The layouts for the C and D operands are the same as for opsPerChan=2.
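The three examples differ only in opsPerChannel, which sets K and therefore which of the three distribution cases applies to the operand columns. A small illustrative snippet (hypothetical, not dialect code):
```
// Illustrative only: K = systolicDepth * opsPerChannel for examples 1-3.
#include <cstdio>

int main() {
  const unsigned systolicDepth = 8, threadsPerWarp = 16;
  const unsigned opsPerChannel[3] = {2, 1, 4}; // examples 1, 2 and 3
  for (unsigned i = 0; i < 3; ++i) {
    unsigned K = systolicDepth * opsPerChannel[i];
    const char *distribution =
        (K == threadsPerWarp)  ? "one value name per row"
        : (K < threadsPerWarp) ? "one value name spans several rows"
                               : "several value names per row";
    printf("example %u: opsPerChannel=%u -> K=%u (%s)\n", i + 1,
           opsPerChannel[i], K, distribution);
  }
  return 0;
}
```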

The pattern (illustrated above) repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
along the row (resp. column) dimension, and the repetitions are clustered in groups of size repCluster to optimize memory access.

Suppose we have a `tt.dot` operation of the block size [64, 128] += [64, 32] * [32, 128] of hf16/bf16.
The `warpsPerCTA` set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
The DPAS repetitions are distributed as follows:
Suppose we have a `tt.dot` operation with block size [64, 128] = [64, 32] * [32, 128] of f16/bf16, and its input tensor layouts are defined as follows:
```
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>

warp[:0] warp[:1] warp[:0] warp[:1]
%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
```
The semantics of this `tt.dot` imply the following GEMM tiling configuration:

warp[:,0] warp[:,1] warp[:,0] warp[:,1]
|----^----|----^----|----^----|----^----|
repCluster[1]
<--------->
┌────┬────┬────┬────┬────┬────┬────┬────┐
│R0 │R1 │ │ │R4 │R5 │ │ │
@@ -142,25 +149,25 @@ The DPAS repetitions are distributed as follows:
- ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
| | │R0 │R2 │ │R0 │R1 │ │ │R4 │R5 │ │ │
| | │ │ │ │ │ │ │ │ │ │ │ │
warp[0:] < repCluster[0] | ]────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
warp[0,:] < repCluster[0] | ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| | │R1 │R3 │ │R2 │R3 │ │ │R6 │R7 │ │ │
| | │ │ │ │ │ │ │ │ │ │ │ │
- v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │ │ │ │ │ │ │ │ │ │ │ │
| │ │ │ │ │ │ │ │ │ │ │ │
warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
warp[1,:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │ │ │ │ │ │ │ │ │ │ │ │
| │ │ │ │ │ │ │ │ │ │ │ │
- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │R4 │R6 │ │R8 │R9 │ │ │R12 │R13 │ │ │
| │ │ │ │ │ │ │ │ │ │ │ │
warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
warp[0,:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │R5 │R7 │ │R10 │R11 │ │ │R14 │R15 │ │ │
| │ │ │ │ │ │ │ │ │ │ │ │
- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │ │ │ │ │ │ │ │ │ │ │ │
| │ │ │ │ │ │ │ │ │ │ │ │
warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
warp[1,:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
| │ │ │ │ │ │ │ │ │ │ │ │
| │ │ │ │ │ │ │ │ │ │ │ │
- └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘
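To see where the R0, R1, ... repetitions in the figure come from, here is a rough sketch (hypothetical helper variables, not the dialect's getDPASRepetitions implementation) that derives the per-warp repetition counts from the #dpas parameters of this example; it reproduces the A=8, B=8, C/D=16 counts quoted in the earlier wording of this paragraph.
```
// Sketch only: per-warp DPAS repetition counts for
//   C[64x128] = A[64x32] * B[32x128]
// with repeatCount=8, executionSize=16, opsPerChannel=2,
// repCluster=[2,2] and warpsPerCTA=[2,2].
#include <cstdio>

int main() {
  const unsigned M = 64, N = 128, K = 32;
  const unsigned repeatCount = 8, executionSize = 16;
  const unsigned systolicDepth = 8, opsPerChannel = 2;
  const unsigned repCluster[2] = {2, 2}, warpsPerCTA[2] = {2, 2};
  const unsigned instK = systolicDepth * opsPerChannel; // 16

  // One warp owns a cluster of repCluster[0] x repCluster[1] DPAS C tiles.
  unsigned warpM = repeatCount * repCluster[0];   // 16 rows
  unsigned warpN = executionSize * repCluster[1]; // 32 columns

  // How often each warp repeats its cluster across the block.
  unsigned repM = M / (warpM * warpsPerCTA[0]); // 2
  unsigned repN = N / (warpN * warpsPerCTA[1]); // 2
  unsigned repK = K / instK;                    // 2

  unsigned repsA = repM * repCluster[0] * repK;                 // 8
  unsigned repsB = repK * repN * repCluster[1];                 // 8
  unsigned repsC = repM * repCluster[0] * repN * repCluster[1]; // 16
  printf("DPAS tiles per warp: A=%u B=%u C/D=%u\n", repsA, repsB, repsC);
  return 0;
}
```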
@@ -175,7 +182,7 @@ The DPAS repetitions are distributed as follows:
"unsigned":$opsPerChannel,
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
ArrayRefParameter<"unsigned">:$repCluster,
"unsigned":$subGroupSize
"unsigned":$threadsPerWarp_
);

let extraClassDeclaration = extraDistributedDeclaration # [{
14 changes: 7 additions & 7 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
size_t rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
unsigned threadsPerWarp = getSubGroupSize();
unsigned threadsPerWarp = getThreadsPerWarp_();
auto shapeC = getDPASInstShapeC();
unsigned elemsNum = product<unsigned>(shapeC);
unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -260,7 +260,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
auto shapePerCTA = getShapePerCTA(*this, shape);
auto rep = getDPASRepetitions(shapePerCTA, opIdx);
auto threadsPerWar = getSubGroupSize();
auto threadsPerWar = getThreadsPerWarp_();
size_t rank = shape.size();
if (opIdx == 0) {
auto shapeA = getShapeA();
@@ -296,7 +296,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
size_t rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
auto executionSize = getExecutionSize();
auto subGroupSize = getSubGroupSize();
auto subGroupSize = getThreadsPerWarp_();
if (subGroupSize < executionSize) {
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
"smaller than the execution size");
@@ -340,7 +340,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
if (opIdx == 0) {
SmallVector<unsigned> shapeA = getDPASInstShapeA();
unsigned subGroupSize = getSubGroupSize();
unsigned subGroupSize = getThreadsPerWarp_();
unsigned opsPerChannel = getOpsPerChannel();

// pack the value to i16 for scalar bit width <=16.
@@ -359,7 +359,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {

if (opIdx == 1) {
auto shapeB = getShapeB();
auto subGroupSize = getSubGroupSize();
auto subGroupSize = getThreadsPerWarp_();
auto executionSize = getExecutionSize();
if (subGroupSize < executionSize) {
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -394,7 +394,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
assert(rank == 2 || rank == 3);
SmallVector<unsigned> contigPerThread(rank, 1);

unsigned threadsPerWarp = getSubGroupSize();
unsigned threadsPerWarp = getThreadsPerWarp_();
auto instShapeC = getDPASInstShapeC();
// The software vectorization vectorized the value as C array: int a[N] -> int
// a[N][threadsPerWarp]
@@ -506,7 +506,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
<< "systolicDepth = " << getSystolicDepth() << ", "
<< "executionSize = " << getExecutionSize() << ", "
<< "opsPerChan = " << getOpsPerChannel() << ", "
<< "threadsPerWarp = " << getSubGroupSize() << ", "
<< "threadsPerWarp = " << getThreadsPerWarp_() << ", "
<< "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
<< "repCluster = [" << repCluster << "], "
<< "A = [" << rA << "], "
@@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
size_t totalElems = elems.size();
auto numElemsPerOperand =
product<unsigned>(dpasLayout.getDPASInstShapeC()) /
dpasLayout.getSubGroupSize();
dpasLayout.getThreadsPerWarp_();
Type elemTy =
this->getTypeConverter()->convertType(srcType.getElementType());
VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
@@ -36,7 +36,7 @@ class DotOpDPASConversionHelper {
Type i16Ty = type::i16Ty(ctx);
Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);

unsigned threadsPerWarp = layout.getSubGroupSize();
unsigned threadsPerWarp = layout.getThreadsPerWarp_();
unsigned opsPerChannel = layout.getOpsPerChannel();
SmallVector<unsigned> shapeC = layout.getDPASInstShapeC();
unsigned elemNumC = product<unsigned>(shapeC) / threadsPerWarp;
2 changes: 1 addition & 1 deletion third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
@@ -168,7 +168,7 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
sizePerThreads[rank - 2] / repCluster[rank - 2],
sizePerThreads[rank - 1] / repCluster[rank - 1]};

unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
unsigned rowsPerElem = dpasLayout.getThreadsPerWarp_() / instShapeC[1];
unsigned colsPerElem = 1;

unsigned repNumber = product<unsigned>(repCluster);
@@ -176,10 +176,10 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
// We want to transpose matrices of N*threads_per_warpxthreads_per_warp
// shape.
if ( // X axis condition
encoding.getExecutionSize() != encoding.getSubGroupSize() ||
encoding.getExecutionSize() != encoding.getThreadsPerWarp_() ||
// Y axis conditions
(encoding.getRepeatCount() * encoding.getRepCluster()[0]) %
encoding.getSubGroupSize() !=
encoding.getThreadsPerWarp_() !=
0)
return failure();

