Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
Signed-off-by: Whitney Tsang <[email protected]>
  • Loading branch information
whitneywhtsang committed Sep 12, 2024
1 parent 0836791 commit 0cf05a9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
15 changes: 7 additions & 8 deletions test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,19 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
// CHECK-COUNT-2: [[SUBGROUP_SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK: [[READ1:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16>
// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[READ2:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16>
// CHECK: [[READ2:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16>
// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[READ3:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16>
// CHECK: [[READ3:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16>
// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[READ4:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16>
// CHECK: [[READ4:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16>
// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[READ5:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16>
// CHECK: [[READ5:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16>
// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[READ6:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16>
// CHECK: [[READ6:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16>
// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[READ7:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1:%.*]]) {{.*}} -> vector<8xi16>
// CHECK: [[READ7:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE1]]) {{.*}} -> vector<8xi16>
// CHECK: [[BASE2:%.*]] = llvm.getelementptr [[BASE1]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[READ8:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2:%.*]]) {{.*}} -> vector<8xi16>
// CHECK: [[BASE1:%.*]] = llvm.getelementptr [[BASE2]][[[SUBGROUP_SIZE]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, vector<8xi16>
// CHECK: [[READ8:%.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS3t([[BASE2]]) {{.*}} -> vector<8xi16>
// CHECK: [[GLUE1:%.*]] = llvm.shufflevector [[READ1]], [[READ2]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
// CHECK: [[GLUE2:%.*]] = llvm.shufflevector [[READ3]], [[READ4]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
// CHECK: [[GLUE3:%.*]] = llvm.shufflevector [[READ5]], [[READ6]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi16>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,16 +283,17 @@ class LoadStorePrefetchOpConversion
rewriter.restoreInsertionPoint(insertPoint);

constexpr unsigned maxBlockLoadi16Width = 8;
VectorType v8i16Ty = VectorType::get(maxBlockLoadi16Width, i16_ty);
VectorType decomposedVecTy =
VectorType::get(maxBlockLoadi16Width, i16_ty);
auto mod = op->template getParentOfType<mlir::ModuleOp>();
Value offset =
i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));
SmallVector<Value> values;
for (int i = 0; i < 64 / maxBlockLoadi16Width; ++i) {
auto simdRead =
rewriter.create<TritonGEN::SIMDBlockReadOp>(loc, v8i16Ty, base);
auto simdRead = rewriter.create<TritonGEN::SIMDBlockReadOp>(
loc, decomposedVecTy, base);
values.push_back(simdRead.getRes());
base = gep(ptrToSharedMemTy, v8i16Ty, base, offset);
base = gep(ptrToSharedMemTy, decomposedVecTy, base, offset);
}
auto simdRead = rewriter.create<GlueOp>(loc, v64i16Ty, values);

Expand Down

0 comments on commit 0cf05a9

Please sign in to comment.