[intel] Fix getTotalElemsPerThread failures
Signed-off-by: Whitney Tsang <[email protected]>
whitneywhtsang committed Nov 17, 2024
1 parent ddc8f87 commit 1efa822
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -992,6 +992,10 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
       return amdWmmaParent.getTotalElemsPerThreadForOperand(
           shape, eltTy, getKWidth(), getOpIdx());
     }
+    if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
+      return dpasParent.getTotalElemsPerThreadForOperand(
+          shape, eltTy, getKWidth(), getOpIdx());
+    }
   }
   if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
     auto shapePerCTA = getShapePerCTA(*this, shape);
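
The added branch follows the same pattern as the AMD WMMA case just before it: try to cast the parent encoding to a specific type and, if the cast succeeds, delegate to that type's per-operand helper. The standalone sketch below illustrates only that dispatch shape. It is not the Triton code: it uses `dynamic_cast` in place of `mlir::dyn_cast`, and every class name, the function signature, and the returned numbers are hypothetical placeholders chosen so the dispatch can be observed.

```cpp
#include <cassert>

// Hypothetical stand-ins for the parent encoding attributes. These are NOT
// the Triton/MLIR classes; they exist only to show the dispatch pattern.
struct ParentEncoding {
  virtual ~ParentEncoding() = default;
};

struct WmmaLikeEncoding : ParentEncoding {
  // Placeholder value, not a real element count.
  unsigned totalElemsPerThreadForOperand() const { return 8; }
};

struct DpasLikeEncoding : ParentEncoding {
  // Placeholder value, not a real element count.
  unsigned totalElemsPerThreadForOperand() const { return 16; }
};

// A parent type for which no branch exists.
struct UnhandledEncoding : ParentEncoding {};

// Returns true and writes `out` when some branch recognizes the parent type;
// returns false when no branch matches.
bool totalElemsPerThread(const ParentEncoding *parent, unsigned &out) {
  assert(parent != nullptr && "parent encoding must be set");
  if (auto *wmma = dynamic_cast<const WmmaLikeEncoding *>(parent)) {
    out = wmma->totalElemsPerThreadForOperand();
    return true;
  }
  // Analogue of the branch this commit adds for the Intel DPAS parent.
  if (auto *dpas = dynamic_cast<const DpasLikeEncoding *>(parent)) {
    out = dpas->totalElemsPerThreadForOperand();
    return true;
  }
  return false;
}
```

In this sketch, removing the `DpasLikeEncoding` branch makes the function report that parent as unhandled, which mirrors why a dedicated branch per parent encoding type is needed.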