Skip to content

Commit

Permalink
[mlir][nvgpu] Improve verifier of ldmatrix (llvm#77807)
Browse files Browse the repository at this point in the history
This PR improves the verifier of the `nvgpu.ldmatrix` op so that the
`nvgpu-to-nvvm` lowering no longer crashes on results that are not 2-D vectors.
  • Loading branch information
grypp authored Jan 12, 2024
1 parent 2e78c22 commit 2491867
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ LogicalResult LdMatrixOp::verify() {
if (isTranspose && !(elementBitWidth == 16))
return emitError()
<< "nvgpu.ldmatrix transpose works only at 16b granularity";
if (resShape.size() != 2) {
return emitError() << "results must be 2 dimensional vector";
}
if (!(resShape[1] == numElementsPer32b))
return emitError() << "expected vector register shape[1] = "
<< numElementsPer32b;
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/NVGPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x1xf
}
// -----

// Negative test: an `nvgpu.ldmatrix` whose result is a 1-D vector must be
// rejected by the op verifier. The declared function result type matches the
// value actually returned, so the type mismatch on `return` cannot add a
// second, unexpected diagnostic next to the one checked below.
func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4xf32> {
  %c0 = arith.constant 0 : index
  // expected-error @+1 {{results must be 2 dimensional vector}}
  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4xf32>
  return %a : vector<4xf32>
}
// -----

func.func @ldmatrix_type_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x2xf16> {
%c0 = arith.constant 0 : index
// expected-error @+1 {{'nvgpu.ldmatrix' op failed to verify that srcMemref and res have same element type}}
Expand Down

0 comments on commit 2491867

Please sign in to comment.