Merge OpenAI Triton commit 76ed94d #2543

Merged
merged 10 commits into from Oct 23, 2024
1 change: 0 additions & 1 deletion bin/RegisterTritonDialects.h
@@ -85,7 +85,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUAccelerateMatmul();
mlir::registerTritonAMDGPUOptimizeEpilogue();
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUStreamPipeline();
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();

@@ -3,6 +3,7 @@

#include <memory>
#include <optional>
#include <string>

namespace mlir {

9 changes: 2 additions & 7 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -460,17 +460,12 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason.
The compiler is still free to change it for better performance.
}];
let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr<UnitAttr>:$efficient_layout);
let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
let results = (outs TT_Tensor:$result);
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)";
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
let hasVerifier = 1;
let builders = [
OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder),
[{
build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr());
}]>];
}

def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
2 changes: 1 addition & 1 deletion lib/Dialect/Triton/IR/Ops.cpp
@@ -678,7 +678,7 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op,
}

LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
if (!op.getAllowReorder() || op.getEfficientLayout())
return failure();
return canonicalizeViewOrBroadcast(op, rewriter);
}
5 changes: 2 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -2764,7 +2764,7 @@ struct CanonicalizeConvertFromReshape
return failure();
if (isExpensiveView(convert.getSrc().getType(), op.getType()))
return failure();
if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
if (!op.getAllowReorder() || op.getEfficientLayout())
return failure();

rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
@@ -2885,8 +2885,7 @@

// cvt(reshape) -> reshape
if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
if (!reshape.getAllowReorder() ||
reshape.getEfficientLayout().has_value() ||
if (!reshape.getAllowReorder() || reshape.getEfficientLayout() ||
isExpensiveView(reshape.getSrc().getType(), op.getType()))
return failure();

4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp
@@ -314,8 +314,8 @@ class TritonGPUOptimizeThreadLocalityPass
IRMapping mapping;
for (auto operand : reduce.getOperands()) {
auto viewOp = builder.create<triton::ReshapeOp>(
reduce.getLoc(), viewOpTensorType, operand, /*allowReorder=*/true);
viewOp.setEfficientLayout(true);
reduce.getLoc(), viewOpTensorType, operand,
/*allowReorder=*/true, /*efficientLayout=*/true);
mapping.map(operand, viewOp);
}

3 changes: 1 addition & 2 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -556,8 +556,7 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
RankedTensorType newDstType =
RankedTensorType::get(reshapeDstType.getShape(),
reshapeDstType.getElementType(), targetEncoding);
return reshape.getAllowReorder() &&
!reshape.getEfficientLayout().has_value() &&
return reshape.getAllowReorder() && !reshape.getEfficientLayout() &&
!triton::gpu::isExpensiveView(reshape.getSrc().getType(),
newDstType);
}
42 changes: 22 additions & 20 deletions python/setup.py
@@ -284,7 +284,8 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
except KeyError:
arch = platform.machine()
url = url_func(arch, version)
supported = {"Linux": "linux", "Darwin": "linux"}
url = url_func(supported[system], arch, version)
tmp_path = os.path.join(triton_cache_path, "nvidia", name) # path to cache the download
dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path
platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux"
@@ -500,61 +501,62 @@ def get_platform_dependent_src_path(subdir):

download_and_copy(
name="ptxas", src_path="bin/ptxas", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH",
version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda arch, version:
version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/linux-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")
(*version.split('.'))))
download_and_copy(
name="cuobjdump",
src_path="bin/cuobjdump",
dst_path="bin/cuobjdump",
variable="TRITON_CUOBJDUMP_PATH",
version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"],
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
url_func=lambda system, arch, version:
f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
)
download_and_copy(
name="nvdisasm",
src_path="bin/nvdisasm",
dst_path="bin/nvdisasm",
variable="TRITON_NVDISASM_PATH",
version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"],
url_func=lambda arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
url_func=lambda system, arch, version:
f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
)
download_and_copy(
name="cudacrt", src_path=get_platform_dependent_src_path("include"), dst_path="include",
variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda arch, version:
variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-crt-dev_linux-{arch}/{version}/download/noarch/cuda-crt-dev_linux-{arch}-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-crt-dev_{system}-{arch}/{version}/download/noarch/cuda-crt-dev_{system}-{arch}-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")
(*version.split('.'))))
download_and_copy(
name="cudart", src_path=get_platform_dependent_src_path("include"), dst_path="include",
variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda arch, version:
variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-cudart-dev_linux-{arch}/{version}/download/noarch/cuda-cudart-dev_linux-{arch}-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-cudart-dev_{system}-{arch}/{version}/download/noarch/cuda-cudart-dev_{system}-{arch}-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/linux-{arch}/cuda-cudart-dev-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/{system}-{arch}/cuda-cudart-dev-{version}-0.tar.bz2"
)(*version.split('.'))))
download_and_copy(
name="cupti", src_path=get_platform_dependent_src_path("include"), dst_path="include",
variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda arch, version:
variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"],
url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/linux-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2")
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2")
(*version.split('.'))))
download_and_copy(
name="cupti", src_path=get_platform_dependent_src_path("lib"), dst_path="lib/cupti",
variable="TRITON_CUPTI_LIB_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda arch, version:
variable="TRITON_CUPTI_LIB_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda system, arch, version:
((lambda version_major, version_minor1, version_minor2:
f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/linux-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
if int(version_major) >= 12 and int(version_minor1) >= 5 else
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2")
f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2")
(*version.split('.'))))

backends = [*BackendInstaller.copy(["intel", "nvidia", "amd"]), *BackendInstaller.copy_externals()]
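The setup.py changes above thread a new `system` argument through every `url_func`, so the Anaconda channel segment of each download URL is no longer hard-coded to `linux` (note that both `Linux` and `Darwin` hosts currently map to the `linux` channel). The nested-lambda URL builders are dense, so here is a minimal, hypothetical sketch of how one of them resolves; the helper name and version string are illustrative, not values taken from this PR.

```python
import platform

# Hypothetical stand-in for one of the url_func lambdas in setup.py:
# the version string is split into (major, minor1, minor2), and the
# major/minor parts decide which Anaconda package name is used.
def example_url_func(system, arch, version):
    version_major, version_minor1, _version_minor2 = version.split(".")
    if int(version_major) >= 12 and int(version_minor1) >= 5:
        return (f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}"
                f"/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2")
    return (f"https://anaconda.org/nvidia/cuda-nvcc/{version}"
            f"/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")

# Mirrors the mappings in download_and_copy: machine() -> channel arch,
# system() -> channel name ("Darwin" also resolves to "linux" here).
arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}.get(
    platform.machine(), platform.machine())
supported = {"Linux": "linux", "Darwin": "linux"}

# "12.6.85" is an illustrative version, not one pinned by the PR.
print(example_url_func(supported.get(platform.system(), "linux"), arch, "12.6.85"))
```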
3 changes: 2 additions & 1 deletion python/test/unit/language/test_compile_errors.py
@@ -154,7 +154,8 @@ def kernel():
try:
inner = e.value.__cause__
outer = e.value
assert "/core.py" in '\n'.join(traceback.format_tb(inner.__traceback__)), "error should point inside core.py"
assert f"{os.sep}core.py" in '\n'.join(traceback.format_tb(
inner.__traceback__)), "error should point inside core.py"

assert "at 2:4:" in str(outer), "error should point to expand_dims call"
assert "<source unavailable>" not in str(outer)
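The assertion now builds its needle with `os.sep` instead of a hard-coded `/`, so the same check works when tracebacks contain Windows-style paths. A small standalone illustration of the idea (the file paths below are made up):

```python
import os

# os.sep is "/" on POSIX and "\\" on Windows, so the needle matches the
# local path separator regardless of platform.
needle = f"{os.sep}core.py"

posix_frame = "/opt/triton/python/triton/language/core.py"    # example only
windows_frame = r"C:\triton\python\triton\language\core.py"   # example only

traceback_text = windows_frame if os.sep == "\\" else posix_frame
assert needle in traceback_text, "error should point inside core.py"
```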
2 changes: 1 addition & 1 deletion python/test/unit/language/test_subprocess.py
@@ -53,7 +53,7 @@ def test_print(func_type: str, data_type: str, device: str):
assert proc.stderr == b''
return

outs = [line for line in proc.stdout.decode("UTF-8").split("\n") if line]
outs = [line for line in proc.stdout.decode("UTF-8").splitlines() if line]
# The total number of elements in the 1-D tensor to print.
N = 128

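Switching from `split("\n")` to `splitlines()` makes the output parsing tolerant of `\r\n` line endings and avoids the trailing empty element that `split` produces. A quick standalone comparison (the sample output lines are invented, not the kernel's actual print format):

```python
# Decoded subprocess output with Windows-style line endings.
stdout = b"pid 0 idx 0\r\npid 0 idx 1\r\n".decode("UTF-8")

print(stdout.split("\n"))   # ['pid 0 idx 0\r', 'pid 0 idx 1\r', ''] -- stray \r and a trailing ''
print(stdout.splitlines())  # ['pid 0 idx 0', 'pid 0 idx 1']         -- clean lines

outs = [line for line in stdout.splitlines() if line]
assert outs == ["pid 0 idx 0", "pid 0 idx 1"]
```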
25 changes: 12 additions & 13 deletions python/test/unit/test_debug.py
@@ -10,7 +10,7 @@
for env_var in [True, False]\
])
@pytest.mark.forked
def test_device_assert(cond, opt_flag, env_var, device="cuda"):
def test_device_assert(cond, opt_flag, env_var, device):
os.environ['TRITON_DEBUG'] = str(int(env_var))
torch.zeros([1], dtype=torch.int32, device=device)

@@ -21,11 +21,11 @@ def _kernel(COND: tl.constexpr):
if not cond and (opt_flag or env_var):
with pytest.raises(RuntimeError):
_kernel[(1, )](cond, debug=opt_flag)
torch.cuda.synchronize()
getattr(torch, device).synchronize()
return

_kernel[(1, )](cond, debug=opt_flag)
torch.cuda.synchronize()
getattr(torch, device).synchronize()


@pytest.mark.parametrize("cond", [False, True])
@@ -43,19 +43,18 @@ def _kernel(COND: tl.constexpr):
_kernel[(1, )](cond)


def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func):
device = "cuda"
def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func, device):
x = torch.tensor([x], dtype=getattr(torch, x_dtype), device=device)
y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device)
z = torch.empty_like(x)
if should_overflow and debug:
with pytest.raises(RuntimeError) as exc_info:
tri_func[(1, )](x, y, z, debug=debug)
torch.cuda.synchronize()
getattr(torch, device).synchronize()
assert "device-side assert" in str(exc_info.value)
else:
tri_func[(1, )](x, y, z, debug=debug)
torch.cuda.synchronize()
getattr(torch, device).synchronize()
assert int(z) == int(ref_func(x, y))


@@ -74,13 +73,13 @@ def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref
(2**15 - 1, 1, 'int16', 'int16', True, True),
])
@pytest.mark.forked
def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):

@triton.jit
def _kernel_add(X, Y, Z):
tl.store(Z, tl.load(X) + tl.load(Y))

_test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y)
_test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y, device)


# mul overflow
@@ -95,13 +94,13 @@ def _kernel_add(X, Y, Z):
(-2**30, 2, 'int32', 'int32', True, False),
])
@pytest.mark.forked
def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):

@triton.jit
def _kernel_mul(X, Y, Z):
tl.store(Z, tl.load(X) * tl.load(Y))

_test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y)
_test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y, device)


# sub overflow
@@ -115,10 +114,10 @@ def _kernel_mul(X, Y, Z):
(-2**31, -1, 'int32', 'int32', True, False),
])
@pytest.mark.forked
def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):

@triton.jit
def _kernel_sub(X, Y, Z):
tl.store(Z, tl.load(X) - tl.load(Y))

_test_overflow(x, y, x_dtype, y_dtype, should_overflow, debug, _kernel_sub, lambda x, y: x - y)
_test_overflow(x, y, x_dtype, y_dtype, should_overflow, debug, _kernel_sub, lambda x, y: x - y, device)
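The debug tests now accept `device` as a parameter instead of hard-coding `"cuda"`, and synchronize through `getattr(torch, device)` so the same test body works for `torch.cuda`, `torch.xpu`, and so on. A minimal sketch of the pattern, assuming the torch backend module shares the device string's name:

```python
import torch

def synchronize(device: str) -> None:
    # getattr(torch, "cuda") is torch.cuda, getattr(torch, "xpu") is
    # torch.xpu, etc.; each backend module exposes synchronize().
    getattr(torch, device).synchronize()

if torch.cuda.is_available():
    x = torch.zeros([1], dtype=torch.int32, device="cuda")
    synchronize("cuda")  # equivalent to torch.cuda.synchronize()
```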
2 changes: 1 addition & 1 deletion test/Conversion/intel/tritongpu_to_gen.mlir
@@ -506,7 +506,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-NEXT: [[STRUCT2:%.*]] = llvm.insertvalue [[ARG0_1]], [[STRUCT1]][1]
// CHECK-NEXT: [[T0:%.*]] = llvm.extractvalue [[STRUCT2]][0]
// CHECK-NEXT: [[T1:%.*]] = llvm.extractvalue [[STRUCT2]][1]
%0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
%0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
// CHECK: [[RES:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
// CHECK-NEXT: [[RES1:%.*]] = llvm.insertvalue [[T0]], [[RES]][0]
// CHECK-NEXT: [[RES2:%.*]] = llvm.insertvalue [[T1]], [[RES1]][1]
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm.mlir
@@ -357,7 +357,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.mlir.undef
// CHECK: %[[T0:.*]] = llvm.extractvalue
// CHECK: %[[T1:.*]] = llvm.extractvalue
%0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
%0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue %[[T0]]
// CHECK: llvm.insertvalue %[[T1]]
12 changes: 6 additions & 6 deletions test/Triton/combine.mlir
@@ -292,15 +292,15 @@ tt.func @test_canonicalize_expand_dims(%arg0: tensor<f32>, %arg1: tensor<1xf32>)

// CHECK-LABEL: @test_canonicalize_view
tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (tensor<4x2xf32>, tensor<2x2x2xf32>, tensor<8xf32>) {
%view0 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<4x2xf32>
%view1 = tt.reshape %view0 {allow_reorder = true} : tensor<2x4xf32> -> tensor<4x2xf32>
%view0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<2x4xf32>
// CHECK: %{{.*}} = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<4x2xf32>
%view1 = tt.reshape %view0 allow_reorder : tensor<2x4xf32> -> tensor<4x2xf32>

%splat = tt.splat %arg1 : tensor<f32> -> tensor<8xf32>
// CHECK: %{{.*}} = tt.splat %arg1 : tensor<f32> -> tensor<2x2x2xf32>
%view2 = tt.reshape %splat {allow_reorder = true} : tensor<8xf32> -> tensor<2x2x2xf32>
%view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32>

%view3 = tt.reshape %arg0 {allow_reorder = true} : tensor<8xf32> -> tensor<8xf32>
%view3 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32>
// CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
%add = arith.addf %view3, %arg0 : tensor<8xf32>

@@ -329,7 +329,7 @@ tt.func @test_fold_views() -> (tensor<16x8xf32>, tensor<16x128xf32>, tensor<1x1x
%a = arith.constant dense<1.0> : tensor<1x128xf32>

// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x8xf32>
%b = tt.reshape %a {allow_reorder = true} : tensor<1x128xf32> -> tensor<16x8xf32>
%b = tt.reshape %a allow_reorder : tensor<1x128xf32> -> tensor<16x8xf32>

// CHECK-DAG: %{{.*}} = arith.constant dense<1.{{.*}}> : tensor<16x128xf32>
%c = tt.broadcast %a : tensor<1x128xf32> -> tensor<16x128xf32>
2 changes: 1 addition & 1 deletion test/Triton/invalid.mlir
@@ -38,7 +38,7 @@ tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) {

tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) {
// expected-error @+1 {{number of src and dst elements of reshape must be the same}}
%a = tt.reshape %arg0 {allow_reorder = false} : tensor<32x128xf16> -> tensor<64x32xf16>
%a = tt.reshape %arg0 : tensor<32x128xf16> -> tensor<64x32xf16>
tt.return
}

10 changes: 8 additions & 2 deletions test/Triton/ops.mlir
@@ -225,8 +225,14 @@ tt.func @inline_asm_scalar(%0: i32) {

// CHECK-LABEL: reshape
tt.func @reshape(%0: tensor<512xi32>) {
// CHECK: tt.reshape %{{.+}} {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32>
%1 = tt.reshape %0 {allow_reorder = false} : tensor<512xi32> -> tensor<16x32xi32>
// CHECK: tt.reshape %{{.+}} : tensor<512xi32> -> tensor<16x32xi32>
%1 = tt.reshape %0 : tensor<512xi32> -> tensor<16x32xi32>
// CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32>
%2 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32>
// CHECK: tt.reshape %{{.+}} allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
%3 = tt.reshape %0 allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
// CHECK: tt.reshape %{{.+}} efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
%4 = tt.reshape %0 efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
tt.return
}
