Update the documents based on review comments.
chengjunlu committed Nov 22, 2024
1 parent eda4c40 commit 7fc674a

Showing 6 changed files with 78 additions and 64 deletions.
@@ -23,18 +23,21 @@ The encoding is characterized by parameters:
- `repeatCount` which shall be in the range [1, 8]
- `systolicDepth` For PVC/ATSM, the size is 8.
- `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
-- `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
+- `opsPerChannel` 4 for an 8-bit scalar type of the A/B operands of the DPAS instruction,
+  2 for a 16-bit scalar type, and 1 for a 32-bit scalar type.
- `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
- `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
-- `sugGroupSize` Currently only sub group size 16 is supported.
+- `threadsPerWarp_` AKA threadsPerWarp; this name avoids a conflict with
+  `getThreadsPerWarp` in the DistributedLayout interface. Currently only 16 is supported.

The values of the matrix are distributed across the threads in the subgroup in row-major order.
-- If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single rows of the matrix.
-- If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
-- If the column size of the matrix is larger than the number of the threads in the subgroup, a single row of the matrix requires multiple value name.
+- If the column size of the matrix is equal to the number of threads in the subgroup, one scalar represents one row of the matrix in a register.
+- If the column size of the matrix is less than the number of threads in the subgroup, one scalar represents multiple rows of the matrix in a register.
+- If the column size of the matrix is larger than the number of threads in the subgroup, one scalar represents a partial row of the matrix in a register.

Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.

The layout for A operand:
K = 16 (K = systolic depth * opsPerChan)
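These K values follow mechanically from `systolicDepth` and `opsPerChannel`. A minimal C++ sketch of that arithmetic (helper names are hypothetical, not from this repository, and assume the 32-bit channel implied by the 4/2/1 table above):

```
// Ops per channel as a function of the A/B scalar bit width, assuming a
// 32-bit channel (8-bit -> 4, 16-bit -> 2, 32-bit -> 1, as documented above).
constexpr unsigned opsPerChannelForBitWidth(unsigned bitWidth) {
  return 32 / bitWidth;
}

// K dimension of one DPAS instruction: systolicDepth * opsPerChannel.
constexpr unsigned dpasInstK(unsigned systolicDepth, unsigned bitWidth) {
  return systolicDepth * opsPerChannelForBitWidth(bitWidth);
}

static_assert(dpasInstK(8, 16) == 16, "Example 1: 16-bit A/B -> K = 16");
static_assert(dpasInstK(8, 32) == 8, "Example 2: 32-bit A/B -> K = 8");
static_assert(dpasInstK(8, 8) == 32, "Example 3: 8-bit A/B -> K = 32");
```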
@@ -83,7 +86,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v

Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.

The layout for A operand:
K = 8 (K = systolic depth * opsPerChan)
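Here K (8) is half the sub-group size (16), so by the distribution rules above each per-thread scalar spans two rows of the A tile. A small sketch of that relation (hypothetical helper, not repository code):

```
// Rows of an A tile covered by one per-thread scalar when values are
// distributed row-major across the sub-group (see the three rules above).
constexpr unsigned rowsPerScalar(unsigned threadsPerWarp, unsigned columns) {
  // columns == threadsPerWarp: one scalar is exactly one row (Example 1);
  // columns <  threadsPerWarp: one scalar spans several rows (this example).
  return columns < threadsPerWarp ? threadsPerWarp / columns : 1;
}
static_assert(rowsPerScalar(16, 16) == 1, "Example 1: one scalar == one row");
static_assert(rowsPerScalar(16, 8) == 2, "Example 2: two rows per scalar");
```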
@@ -102,7 +105,7 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 8.
The layouts for the C and D operands are the same as for opsPerChan=2.

Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and sugGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.

The layout for A operand:
K = 32 (K = systolic depth * opsPerChan)
@@ -121,49 +124,57 @@ The layouts for B operand is like the one of opsPerChan=2 but the K size is 32.
The layouts for the C and D operands are the same as for opsPerChan=2.

The patterns (illustrated above) repeat every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
-along the row (resp. col) dimension. And the repetitions are clustered of the size of repCluster to optimize the memory accessing.

-Suppose we have a `tt.dot` operation of the block size [64, 128] += [64, 32] * [32, 128] of hf16/bf16.
-The `warpsPerCTA` set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
-The DPAS repetitions are distributed as follows:
-
-warp[:0] warp[:1] warp[:0] warp[:1]
-|----^----|----^----|----^----|----^----|
-repCluster[1]
-<--------->
-┌────┬────┬────┬────┬────┬────┬────┬────┐
-│R0 │R1 │ │ │R4 │R5 │ │ │
-│ │ │ │ │ │ │ │ │
-├────┼────┼────┼────┼────┼────┼────┼────┤
-│R2 │R3 │ │ │R6 │R7 │ │ │
-│ │ │ │ │ │ │ │ │
-└────┴────┴────┴────┴────┴────┴────┴────┘
-
-- ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
-| | │R0 │R2 │ │R0 │R1 │ │ │R4 │R5 │ │ │
-| | │ │ │ │ │ │ │ │ │ │ │ │
-warp[0:] < repCluster[0] | ]────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
-| | │R1 │R3 │ │R2 │R3 │ │ │R6 │R7 │ │ │
-| | │ │ │ │ │ │ │ │ │ │ │ │
-- v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
-| │ │ │ │ │ │ │ │ │ │ │ │
-| │ │ │ │ │ │ │ │ │ │ │ │
-warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
-| │ │ │ │ │ │ │ │ │ │ │ │
-| │ │ │ │ │ │ │ │ │ │ │ │
-- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
-| │R4 │R6 │ │R8 │R9 │ │ │R12 │R13 │ │ │
-| │ │ │ │ │ │ │ │ │ │ │ │
-warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
-| │R5 │R7 │ │R10 │R11 │ │ │R14 │R15 │ │ │
-| │ │ │ │ │ │ │ │ │ │ │ │
-- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
-| │ │ │ │ │ │ │ │ │ │ │ │
-| │ │ │ │ │ │ │ │ │ │ │ │
-warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
-| │ │ │ │ │ │ │ │ │ │ │ │
-| │ │ │ │ │ │ │ │ │ │ │ │
-- └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘
+along the row (resp. col) dimension, and the repetitions are grouped into clusters of size repCluster to optimize memory access.
+
+Suppose we have a `tt.dot` operation with block size [64, 128] = [64, 32] * [32, 128] on f16/bf16 inputs, whose tensor layouts are defined as follows:
+```
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+
+%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
+```
+The semantics of this `tt.dot` imply the following GEMM tiling configuration:
+
+warp[:0] warp[:1] warp[:0] warp[:1]
+|----^----|----^----|----^----|----^----|
+repCluster[1]
+<--------->
+┌────┬────┬────┬────┬────┬────┬────┬────┐
+│W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+│W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+warpPerCTA = [[W0, W1], ├────┼────┼────┼────┼────┼────┼────┼────┤
+[W2, W3]] │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+│W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+└────┴────┴────┴────┴────┴────┴────┴────┘
+
+
+- ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
+| | │W0R0│W0R2│ │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+| | │W1R0│W1R2│ │ │ │ │ │ │ │ │ │
+warp[0:] < repCluster[0] | ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+| | │W0R1│W0R3│ │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+| | │W1R1│W1R3│ │ │ │ │ │ │ │ │ │
+- v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+| │W2R0│W2R2│ │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+| │W3R0│W3R2│ │ │ │ │ │ │ │ │ │
+warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+| │W2R1│W2R3│ │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+| │W3R1│W3R3│ │ │ │ │ │ │ │ │ │
+- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+| │W0R4│W0R6│ │W0R8│W0R9│W1R8│W1R9│W0 │W0 │W1 │W1 │
+| │W1R4│W1R6│ │ │ │ │ │R12 │R13 │R12 │R13 │
+warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+| │W0R5│W0R7│ │W0 │W0 │W1 │W1 │W0 │W0 │W1 │W1 │
+| │W1R5│W1R7│ │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+- ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+| │W2R4│W2R6│ │W2R8│W2R9│W3R8│W3R9│W2 │W2 │W3 │W3 │
+| │W3R4│W3R6│ │ │ │ │ │R12 │R13 │R12 │R13 │
+warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+| │W2R5│W2R7│ │W2 │W2 │W3 │W3 │W2 │W2 │W3 │W3 │
+| │W3R5│W3R7│ │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+- └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘


}];
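The repetition counts drawn in the figures can be cross-checked with a little arithmetic (an editorial sketch, not part of this commit; the instruction tiles A = 8x16, B = 16x16, C = 8x16 follow from the parameters above for opsPerChannel = 2):

```
// Per-warp DPAS repetition counts for the example:
//   block [64, 128] = [64, 32] * [32, 128], warpsPerCTA = [2, 2],
//   instruction tiles A = 8x16, B = 16x16, C = 8x16 (opsPerChannel = 2).
constexpr unsigned M = 64, N = 128, K = 32;
constexpr unsigned warpsM = 2, warpsN = 2;
constexpr unsigned repsA = (M / warpsM / 8) * (K / 16);          // 4 * 2 = 8
constexpr unsigned repsB = (K / 16) * (N / warpsN / 16);         // 2 * 4 = 8
constexpr unsigned repsC = (M / warpsM / 8) * (N / warpsN / 16); // 4 * 4 = 16
static_assert(repsA == 8 && repsB == 8 && repsC == 16,
              "matches the W*R* indices in the figures above");
```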

@@ -175,7 +186,7 @@ The DPAS repetitions are distributed as follows:
"unsigned":$opsPerChannel,
ArrayRefParameter<"unsigned">:$warpsPerCTA__,
ArrayRefParameter<"unsigned">:$repCluster,
"unsigned":$subGroupSize
"unsigned":$threadsPerWarp_
);

let extraClassDeclaration = extraDistributedDeclaration # [{
14 changes: 7 additions & 7 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
size_t rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
-unsigned threadsPerWarp = getSubGroupSize();
+unsigned threadsPerWarp = getThreadsPerWarp_();
auto shapeC = getDPASInstShapeC();
unsigned elemsNum = product<unsigned>(shapeC);
unsigned elemsPerThread = elemsNum / threadsPerWarp;
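For concreteness: on PVC `getDPASInstShapeC()` is {repeatCount, executionSize} = {8, 16}, so with 16 threads per warp each lane ends up with 8 accumulator elements (editorial arithmetic, not repository code):

```
constexpr unsigned shapeC[2] = {8, 16};   // getDPASInstShapeC() on PVC
constexpr unsigned threadsPerWarp = 16;   // sub-group size / threadsPerWarp_
constexpr unsigned elemsPerThread =
    (shapeC[0] * shapeC[1]) / threadsPerWarp; // 128 / 16 == 8
static_assert(elemsPerThread == 8, "each lane holds 8 C elements");
```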
@@ -260,7 +260,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
auto shapePerCTA = getShapePerCTA(*this, shape);
auto rep = getDPASRepetitions(shapePerCTA, opIdx);
-auto threadsPerWar = getSubGroupSize();
+auto threadsPerWar = getThreadsPerWarp_();
size_t rank = shape.size();
if (opIdx == 0) {
auto shapeA = getShapeA();
@@ -296,7 +296,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
size_t rank = getWarpsPerCTA().size();
SmallVector<unsigned> res(rank, 1);
auto executionSize = getExecutionSize();
-auto subGroupSize = getSubGroupSize();
+auto subGroupSize = getThreadsPerWarp_();
if (subGroupSize < executionSize) {
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
"smaller than the execution size");
@@ -340,7 +340,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
if (opIdx == 0) {
SmallVector<unsigned> shapeA = getDPASInstShapeA();
-unsigned subGroupSize = getSubGroupSize();
+unsigned subGroupSize = getThreadsPerWarp_();
unsigned opsPerChannel = getOpsPerChannel();

// pack the value to i16 for scalar bit width <=16.
@@ -359,7 +359,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {

if (opIdx == 1) {
auto shapeB = getShapeB();
-auto subGroupSize = getSubGroupSize();
+auto subGroupSize = getThreadsPerWarp_();
auto executionSize = getExecutionSize();
if (subGroupSize < executionSize) {
llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -394,7 +394,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
assert(rank == 2 || rank == 3);
SmallVector<unsigned> contigPerThread(rank, 1);

-unsigned threadsPerWarp = getSubGroupSize();
+unsigned threadsPerWarp = getThreadsPerWarp_();
auto instShapeC = getDPASInstShapeC();
// The software vectorization vectorized the value as C array: int a[N] -> int
// a[N][threadsPerWarp]
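The comment sketches the warp-level vectorization; schematically, element i of every lane is stored before element i+1 of any lane (an editorial illustration of the comment above, not repository code):

```
// int a[N] -> int a[N][threadsPerWarp]: flat index of lane t's element i.
constexpr unsigned N = 8, threadsPerWarp = 16;
constexpr unsigned flatIndex(unsigned i, unsigned t) {
  return i * threadsPerWarp + t;
}
static_assert(flatIndex(0, 15) + 1 == flatIndex(1, 0),
              "row-major over [N][threadsPerWarp]");
```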
@@ -506,7 +506,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
<< "systolicDepth = " << getSystolicDepth() << ", "
<< "executionSize = " << getExecutionSize() << ", "
<< "opsPerChan = " << getOpsPerChannel() << ", "
<< "threadsPerWarp = " << getSubGroupSize() << ", "
<< "threadsPerWarp = " << getThreadsPerWarp_() << ", "
<< "warpsPerCTA = [" << llvm::ArrayRef<unsigned>(warpsPerCTA) << "], "
<< "repCluster = [" << repCluster << "], "
<< "A = [" << rA << "], "
@@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
size_t totalElems = elems.size();
auto numElemsPerOperand =
product<unsigned>(dpasLayout.getDPASInstShapeC()) /
-dpasLayout.getSubGroupSize();
+product<unsigned>(dpasLayout.getThreadsPerWarp());
Type elemTy =
this->getTypeConverter()->convertType(srcType.getElementType());
VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
@@ -36,7 +36,7 @@ class DotOpDPASConversionHelper
Type i16Ty = type::i16Ty(ctx);
Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);

-unsigned threadsPerWarp = layout.getSubGroupSize();
+unsigned threadsPerWarp = product<unsigned>(layout.getThreadsPerWarp());
unsigned opsPerChannel = layout.getOpsPerChannel();
SmallVector<unsigned> shapeC = layout.getDPASInstShapeC();
unsigned elemNumC = product<unsigned>(shapeC) / threadsPerWarp;
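This mirrors the getSizePerThread arithmetic above: the accumulator of one DPAS instruction becomes a short per-lane vector (a sketch under the same PVC parameters assumed earlier, not repository code):

```
constexpr unsigned elemNumC(unsigned repeatCount, unsigned executionSize,
                            unsigned threadsPerWarp) {
  // product(shapeC) / threadsPerWarp, with shapeC = {repeatCount, executionSize}.
  return repeatCount * executionSize / threadsPerWarp;
}
static_assert(elemNumC(8, 16, 16) == 8, "8 accumulator elements per lane");
```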
3 changes: 2 additions & 1 deletion third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

@@ -168,7 +168,8 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
sizePerThreads[rank - 2] / repCluster[rank - 2],
sizePerThreads[rank - 1] / repCluster[rank - 1]};

-unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
+unsigned rowsPerElem =
+    product<unsigned>(dpasLayout.getThreadsPerWarp()) / instShapeC[1];
unsigned colsPerElem = 1;

unsigned repNumber = product<unsigned>(repCluster);
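With the PVC values assumed throughout (sub-group size 16, instShapeC = {8, 16}), rowsPerElem works out to 1, i.e. consecutive per-lane C elements step one row at a time (editorial arithmetic, not repository code):

```
constexpr unsigned threadsPerWarp = 16; // PVC sub-group size
constexpr unsigned instShapeCols = 16;  // executionSize == instShapeC[1]
static_assert(threadsPerWarp / instShapeCols == 1,
              "rowsPerElem == 1 on PVC");
```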