diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 6b7f1e78e7..d5e003d90e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -488,17 +488,20 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion assert(to_vector(conversion->getInDimNames()) == to_vector(conversion->getOutDimNames())); auto dims = conversion->getInDimNames(); - if (llvm::is_contained(dims, str_attr("block"))) { + if (llvm::is_contained(dims, kBlock)) { // Case 1: Transfer between values in different CTAs. // This requires moving values through distributed shared memory. return rewriter.notifyMatchFailure( op, "NYI: Transfer between different CTAs"); - } else if (llvm::is_contained(dims, str_attr("warp"))) { - return rewriter.notifyMatchFailure( - op, "NYI: Transfer between different warps"); - } else if (llvm::is_contained(dims, str_attr("lane"))) { + } else if (llvm::is_contained(dims, kWarp)) { // Case 2: Transfer between values in the same CTA, in which case we move // values through shared memory. + // TODO: Implement + return failure(); + } else if (llvm::is_contained(dims, kLane)) { + // Case 3. Transfer between values in the same warp, in which case we try + // to move values using warp shuffles, though if the pattern is + // complicated enough we may fall back to using shared memory // If the operation is a supported sub-group shuffle, perform via shuffle // operations. if (intel::cvtIsSubGroupShuffle(srcTy, dstTy)) { @@ -513,15 +516,17 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } // TODO(jlebar): Implement me. return failure(); - } else if (llvm::is_contained(dims, str_attr("register"))) { + } else if (llvm::is_contained(dims, kRegister) || + dstLayout.getInDimSize(kRegister) != + srcLayout.getInDimSize(kRegister)) { // Case 4. Transfer between values in the same thread, in which case we // simply reorder the elements of adaptor.getSrc(). return transferWithinThread( op, dstLayout.getFreeVariableMasks()[kRegister], dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); } else { - // The two layouts are equivalent. We should probably remove these in - // RemoveLayoutConversion. + // Cast 5. The two layouts are equivalent. We should probably remove + // these in RemoveLayoutConversion. rewriter.replaceOp(op, adaptor.getSrc()); return success(); }