
Commit

Merge commit 'fe18b9b43a67392e082128cdcdc2a85e7dbdf26b'
whitneywhtsang committed Aug 1, 2024
2 parents 3bee61a + fe18b9b commit 27dea53
Showing 5 changed files with 19 additions and 8 deletions.
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -859,7 +859,7 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
SameVariadicOperandSize,
TypesMatchWith<"infer pointer type from the result type",
"result", "base",
"getPointerType(getElementTypeOfTensorPointerType($_self))">]> {
"getPointerType(getElementTypeOfTensorPointerType($_self), getAddressSpace($_self))">]> {
let summary = "Make a tensor pointer type with meta information of the parent tensor and the block specified";

let description = [{
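For context, `TypesMatchWith` verifies that applying the quoted C++ transform to the `result` type reproduces the `base` type, so the inferred base pointer of `tt.make_tensor_ptr` now carries the result's address space instead of a hard-coded 1. A minimal sketch of what the updated transform evaluates to; the wrapper function is purely illustrative, only the helpers it calls come from this commit:

#include "triton/Dialect/Triton/IR/Types.h"

// Illustrative wrapper around the expression used in the updated constraint.
// For a result such as !tt.ptr<tensor<32x32xf16>, 3>, the inferred base type
// is now !tt.ptr<f16, 3> rather than !tt.ptr<f16, 1>.
static mlir::Type inferBaseFromResult(mlir::Type resultType) {
  return mlir::triton::getPointerType(
      mlir::triton::getElementTypeOfTensorPointerType(resultType),
      mlir::triton::getAddressSpace(resultType));
}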
4 changes: 3 additions & 1 deletion include/triton/Dialect/Triton/IR/Types.h
@@ -22,7 +22,9 @@ unsigned getPointeeBitWidth(Type type);

Type getPointeeType(Type type);

-Type getPointerType(Type type);
+Type getPointerType(Type type, int addressSpace = 1);
+
+int getAddressSpace(Type type);

Type getElementTypeOfTensorPointerType(Type type);

11 changes: 10 additions & 1 deletion lib/Dialect/Triton/IR/Types.cpp
@@ -157,7 +157,16 @@ Type getPointerTypeSameShape(Type type) {
}
}

-Type getPointerType(Type type) { return PointerType::get(type, 1); }
+// upstream Triton only uses address space 1 for Pointer Type
+Type getPointerType(Type type, int addressSpace) {
+  return PointerType::get(type, addressSpace);
+}
+
+int getAddressSpace(Type type) {
+  if (auto ptrType = dyn_cast<PointerType>(type))
+    return ptrType.getAddressSpace();
+  return 1;
+}

bool isTensorPointerType(Type type) {
if (auto ptrType = dyn_cast<PointerType>(type))
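The comment above records that upstream Triton only ever builds pointer types in address space 1, which is why the new parameter defaults to 1 and why `getAddressSpace` falls back to 1 for non-pointer types. A minimal sketch of the intended round trip, assuming an `MLIRContext` from the surrounding pass; the function name and variables are illustrative:

#include <cassert>

#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Types.h"

// Illustrative use of the two helpers added in this commit.
static void addressSpaceRoundTrip(mlir::MLIRContext *ctx) {
  mlir::Type f16Ty = mlir::Float16Type::get(ctx);
  // Default keeps the upstream behavior: address space 1 (global memory).
  mlir::Type globalPtr = mlir::triton::getPointerType(f16Ty);
  // Explicit address space, e.g. 3 for shared local memory in the test below.
  mlir::Type sharedPtr = mlir::triton::getPointerType(f16Ty, 3);
  assert(mlir::triton::getAddressSpace(globalPtr) == 1);
  assert(mlir::triton::getAddressSpace(sharedPtr) == 3);
  // Non-pointer types report the default address space.
  assert(mlir::triton::getAddressSpace(f16Ty) == 1);
}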
4 changes: 2 additions & 2 deletions test/TritonIntelGPU/distribute-to-warps.mlir
@@ -32,13 +32,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
// CHECK: scf.for {{.*}} iter_args([[ARG10:%.*]] = [[CST]], [[ARG11:%.*]] = [[TPTR1]], [[ARG12:%.*]] = [[TPTR2]])
// CHECK-DAG: [[LOAD1:%.*]] = tt.load [[ARG11]] : !tt.ptr<tensor<32x32xf16, [[WARP2]]>>
// CHECK-DAG: [[LOAD2:%.*]] = tt.load [[ARG12]] : !tt.ptr<tensor<32x32xf16, [[WARP2]]>>
-// CHECK: [[ALLOC1:%.*]] = triton_intel_gpu.alloc : <f16>
+// CHECK: [[ALLOC1:%.*]] = triton_intel_gpu.alloc : <f16, 3>
// CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr [[ALLOC1]], {{.*}} {order = array<i32: 1, 0>} : <tensor<32x32xf16, [[WARP2]]>, 3>
// CHECK: tt.store [[PTR1]], [[LOAD1]] : !tt.ptr<tensor<32x32xf16, [[WARP2]]>, 3>
// CHECK: gpu.barrier
// CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr [[ALLOC1]], {{.*}} {order = array<i32: 1, 0>} : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[WARP1]]}>>, 3>
// CHECK: [[LOAD3:%.*]] = tt.load [[PTR2]] : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[WARP1]]}>>, 3>
-// CHECK: [[ALLOC2:%.*]] = triton_intel_gpu.alloc : <f16>
+// CHECK: [[ALLOC2:%.*]] = triton_intel_gpu.alloc : <f16, 3>
// CHECK: [[PTR3:%.*]] = tt.make_tensor_ptr [[ALLOC2]], {{.*}} {order = array<i32: 1, 0>} : <tensor<32x32xf16, [[WARP2]]>, 3>
// CHECK: tt.store [[PTR3]], [[LOAD2]] : !tt.ptr<tensor<32x32xf16, [[WARP2]]>, 3>
// CHECK: gpu.barrier
@@ -288,9 +288,9 @@ void distributeConvertLayoutOp(ttg::ConvertLayoutOp op, Value warpId,

// FIXME: allocOp may carry the size info.
OpBuilder b(op);
-  auto baseType = tt::PointerType::get(
-      oldSrcType.getElementType(),
-      triton::TritonGEN::TritonGENMemorySpace::kCrossWorkgroup);
+  auto baseType =
+      tt::PointerType::get(oldSrcType.getElementType(),
+                           triton::TritonGEN::TritonGENMemorySpace::kWorkgroup);
auto base = b.create<ttgi::AllocOp>(loc, baseType);

SmallVector<Value> shape;
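The last hunk moves the scratch allocation's base pointer from the cross-workgroup (global) memory space to the workgroup (shared local) space. Together with the inference change above, the `tt.make_tensor_ptr` ops built on that base now carry the same address space, which is what the updated `<f16, 3>` expectations in distribute-to-warps.mlir check. The sketch below reuses names from the hunk and assumes `kWorkgroup` maps to 3, as those expectations imply:

// Hedged sketch, reusing names from the hunk above (not part of the diff):
// the base pointer now lives in the workgroup memory space, and the helper
// added in Types.cpp reads that space back off the type.
auto baseType =
    tt::PointerType::get(oldSrcType.getElementType(),
                         triton::TritonGEN::TritonGENMemorySpace::kWorkgroup);
assert(tt::getAddressSpace(baseType) ==
       triton::TritonGEN::TritonGENMemorySpace::kWorkgroup); // needs <cassert>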
