Commit 3bee61a
Merge commit '69ba2b749e754febf21302794ca6d57382cbc6f0'
whitneywhtsang committed Aug 1, 2024
2 parents f85beb2 + 69ba2b7
Showing 12 changed files with 82 additions and 100 deletions.
4 changes: 0 additions & 4 deletions include/triton/Analysis/Utility.h
@@ -181,8 +181,6 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape, ArrayRef<int64_t> dstShape);
// If shape is empty, it means no shared memory is needed.
unsigned getNumScratchElements(ArrayRef<unsigned> shape);

bool maybeSharedAllocationOp(Operation *op);

bool supportMFMA(triton::DotOp op);

bool supportWMMA(triton::DotOp op);
@@ -191,8 +189,6 @@ bool supportMMA(triton::DotOp op, int version);

bool supportMMA(Value value, int version);

bool isSingleValue(Value value);

bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
47 changes: 8 additions & 39 deletions lib/Analysis/Allocation.cpp
@@ -88,24 +88,18 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
return repShape;
}

// TODO: extend beyond scalars
static SmallVector<unsigned> getRepShapeForAtomicRMW(triton::AtomicRMWOp op) {
// Both `atomic_cas` and `atomic_rmw` need a single scratch element if returning
// a scalar value because Triton's block-based programming model ensures that
// all threads in each block see the same return value, even those threads that
// do not participate in the atomic operation
static SmallVector<unsigned> getRepShapeForAtomic(Type retTy) {
SmallVector<unsigned> smemShape;
if (isa<RankedTensorType>(op.getPtr().getType())) {
// do nothing or just assert because shared memory is not used in tensor up
// to now
} else {
// need only bytes for scalar
// always vec = 1 and elemsPerThread = 1 for scalar?
if (!isa<RankedTensorType>(retTy)) {
smemShape.push_back(1);
}
return smemShape;
}

static SmallVector<unsigned> getRepShapeForAtomicCAS(triton::AtomicCASOp op) {
return SmallVector<unsigned>{1};
}

ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
RankedTensorType dstTy) {
// Initialize vector sizes and stride
@@ -185,14 +179,6 @@ class AllocationAnalysis {

/// Initializes explicitly defined shared memory values for a given operation.
void getExplicitValueSize(Operation *op) {
// Values returned from scf.yield will not be allocated even though they
// have the shared encoding.
// For example: %a = scf.if -> yield
// %a must be allocated elsewhere by other operations.
// FIXME(Keren): extract and insert are always alias for now
if (!maybeSharedAllocationOp(op))
return;

// XXX(Keren): Why this hard-coded alignment?
size_t kAlignment = 8;
for (Value result : op->getResults()) {
@@ -274,14 +260,14 @@
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
} else if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (dyn_cast<RankedTensorType>(value.getType())) {
// nothing to do
} else {
auto smemShape = getRepShapeForAtomicRMW(atomicRMWOp);
auto smemShape = getRepShapeForAtomic(op->getResult(0).getType());
auto elems = getNumScratchElements(smemShape);
auto elemTy =
cast<triton::PointerType>(value.getType()).getPointeeType();
@@ -292,23 +278,6 @@
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
} else if (auto atomicCASOp = dyn_cast<triton::AtomicCASOp>(op)) {
// only scalar requires scratch memory
// make it explicit for readability
auto value = op->getOperand(0);
if (dyn_cast<RankedTensorType>(value.getType())) {
// nothing to do
} else {
auto smemShape = getRepShapeForAtomicCAS(atomicCASOp);
auto elems = getNumScratchElements(smemShape);
auto elemTy =
cast<triton::PointerType>(value.getType()).getPointeeType();
auto bytes = isa<triton::PointerType>(elemTy)
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
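The comment in getRepShapeForAtomic above explains the consolidation: whether the op is tt.atomic_rmw or tt.atomic_cas, a scalar atomic needs exactly one scratch element so the returned value can be broadcast to every thread in the block, and AllocationAnalysis then sizes the buffer as one element of the pointee type (or kPtrBitWidth / 8 bytes for pointer elements). A minimal sketch of such a kernel from the user's side (hypothetical kernel, not part of this commit):

import torch
import triton
import triton.language as tl

@triton.jit
def bump_counter(counter_ptr, out_ptr, BLOCK: tl.constexpr):
    # counter_ptr is a scalar pointer: one thread performs the atomic, and the
    # old value is broadcast through a single shared-memory scratch element so
    # that every thread stores the same result below.
    old = tl.atomic_add(counter_ptr, 1)
    offs = tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, old)

counter = torch.zeros(1, dtype=torch.int32, device="cuda")
out = torch.empty(128, dtype=torch.int32, device="cuda")
bump_counter[(1, )](counter, out, BLOCK=128)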
14 changes: 0 additions & 14 deletions lib/Analysis/Utility.cpp
@@ -408,20 +408,6 @@ unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
return product<unsigned>(shape);
}

bool maybeSharedAllocationOp(Operation *op) {
// TODO(Keren): This function can be replaced by adding
// MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to
// query the memory effects of the op.
auto *dialect = op->getDialect();
return dialect &&
(dialect->getTypeID() == TypeID::get<TritonGPUDialect>() ||
dialect->getTypeID() ==
TypeID::get<triton::nvidia_gpu::TritonNvidiaGPUDialect>() ||
dialect->getTypeID() == TypeID::get<triton::TritonDialect>() ||
dialect->getTypeID() == TypeID::get<arith::ArithDialect>() ||
dialect->getTypeID() == TypeID::get<tensor::TensorDialect>());
}

static bool supportMFMAGranularity(int m, int n, int k) {
// these limitations are dtype dependent, in future we may relax them
const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
26 changes: 13 additions & 13 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -433,6 +433,19 @@ static std::optional<Attribute> inferSrcEncoding(triton::ReshapeOp op,
op.getAllowReorder());
}

static bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = dyn_cast<RankedTensorType>(value.getType()))
return tensorTy.getNumElements() == 1;
// TODO: Handle other cases.
// For example, when ptr is a tensor of single value.
// It means that ptr is a resultant of broadcast or generated through
// a chain of broadcast and other operations.
// Rematerialize it without considering contiguous memory access pattern is
// fine.
return true;
}

std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding) {
if (isa<triton::ScanOp>(op)) {
// Scan only supports blocked encoding at the moment.
@@ -490,19 +503,6 @@ std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding) {
return std::nullopt;
}

bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = dyn_cast<RankedTensorType>(value.getType()))
return tensorTy.getNumElements() == 1;
// TODO: Handle other cases.
// For example, when ptr is a tensor of single value.
// It means that ptr is a resultant of broadcast or generated through
// a chain of broadcast and other operations.
// Rematerialize it without considering contiguous memory access pattern is
// fine.
return true;
}

bool isExpensiveLoadOrStore(Operation *op) {
// Case 1: Pointer of tensor is always expensive
auto operandType = op->getOperand(0).getType();
13 changes: 10 additions & 3 deletions python/test/unit/runtime/test_cache.py
@@ -53,6 +53,13 @@ def kernel_nospec(X, i, BLOCK: tl.constexpr):
tl.store(X, i)


@triton.jit(do_not_specialize_on_alignment=["i"])
def kernel_nospec_on_alignment(X, i, BLOCK: tl.constexpr):
i = i + 1
i = function_1(i)
tl.store(X, i)


@triton.jit
def kernel_with_combine_fn(X, BLOCK: tl.constexpr):
i = tl.arange(0, BLOCK)
@@ -162,7 +169,7 @@ def inc_counter(*args, **kwargs):
assert counter == 1


@pytest.mark.parametrize('mode', ['enable', 'disable'])
@pytest.mark.parametrize('mode', ['enable', 'disable', 'disable_on_alignment'])
def test_specialize(mode, device, fresh_triton_cache):
counter = 0

@@ -172,8 +179,8 @@ def inc_counter(*args, **kwargs):

JITFunction.cache_hook = inc_counter
x = torch.empty(1, dtype=torch.int32, device=device)
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 3, 'disable': 1}[mode]
function = {'enable': kernel, 'disable': kernel_nospec, 'disable_on_alignment': kernel_nospec_on_alignment}[mode]
target = {'enable': 3, 'disable': 1, 'disable_on_alignment': 2}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1, )](x, i, BLOCK=512)
assert counter == target
2 changes: 1 addition & 1 deletion python/triton/backends/__init__.py
@@ -7,7 +7,7 @@


def _load_module(name, path):
spec = importlib.util.spec_from_file_location(name[:-3], path)
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
34 changes: 23 additions & 11 deletions python/triton/runtime/jit.py
@@ -230,10 +230,12 @@ def _normalize_ty(ty) -> str:
class KernelParam:
"""Represents a parameter (name plus metadata) to a @jit'ed function."""

def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool):
def __init__(self, num: int, param: inspect.Parameter, do_not_specialize: bool,
do_not_specialize_on_alignment: bool):
self.num = num
self._param = param
self.do_not_specialize = do_not_specialize
self.do_not_specialize_on_alignment = do_not_specialize_on_alignment

@cached_property
def name(self):
@@ -273,13 +275,13 @@ def has_default(self):
return self._param.default != inspect.Parameter.empty


def compute_spec_key(v):
def compute_spec_key(v, align):

if hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0):
if align and hasattr(v, "data_ptr") and (v.data_ptr() % 16 == 0):
return "D"
elif isinstance(v, int):
# bool is a subclass of int, so we don't check explicitly above.
if (v % 16 == 0):
if align and (v % 16 == 0):
return "D"
elif v == 1:
return "1"
Expand Down Expand Up @@ -368,7 +370,10 @@ def create_function_from_signature(sig, kparams):
else:
non_constexpr_vals.append(name)
if not kp.do_not_specialize:
specialisations.append('compute_spec_key(%s)' % name)
if not kp.do_not_specialize_on_alignment:
specialisations.append('compute_spec_key(%s, align=True)' % name)
else:
specialisations.append('compute_spec_key(%s, align=False)' % name)
if kp.annotation_type:
signature_types.append('"%s"' % kp.annotation_type)
else:
Expand Down Expand Up @@ -480,7 +485,7 @@ def is_divisible_by_16(x):
divisible_by_16 = {
param.num
for param, arg in zip(self.params, args)
if is_divisible_by_16(arg) and not param.do_not_specialize
if is_divisible_by_16(arg) and not param.do_not_specialize and not param.do_not_specialize_on_alignment
}
equal_to_1 = {
param.num
@@ -673,15 +678,17 @@ def run(self, *args, grid, warmup, **kwargs):
self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
return kernel

def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None, repr=None,
launch_metadata=None):
def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_on_alignment=None, debug=None,
noinline=None, repr=None, launch_metadata=None):
do_not_specialize = do_not_specialize if do_not_specialize else []
do_not_specialize_on_alignment = do_not_specialize_on_alignment if do_not_specialize_on_alignment else []

self.fn = fn
self.module = fn.__module__
self.version = version
self.signature = inspect.signature(fn)
self.do_not_specialize = do_not_specialize
self.do_not_specialize_on_alignment = do_not_specialize_on_alignment
self.starting_line_number = inspect.getsourcelines(fn)[1]
self.repr = lambda _: fn.__name__ if repr is None else repr(_)
self.launch_metadata = launch_metadata
@@ -690,8 +697,9 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinlin

self.params = []
for i, param in enumerate(self.signature.parameters.values()):
dns = do_not_specialize and (i in do_not_specialize or param.name in do_not_specialize)
self.params.append(KernelParam(i, param, dns))
dns = i in do_not_specialize or param.name in do_not_specialize
dns_oa = i in do_not_specialize_on_alignment or param.name in do_not_specialize_on_alignment
self.params.append(KernelParam(i, param, dns, dns_oa))

# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
@@ -809,6 +817,7 @@ def jit(
repr: Optional[Callable] = None,
launch_metadata: Optional[Callable] = None,
do_not_specialize: Optional[Iterable[int]] = None,
do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
@@ -822,6 +831,7 @@ def jit(
repr: Optional[Callable] = None,
launch_metadata: Optional[Callable] = None,
do_not_specialize: Optional[Iterable[int]] = None,
do_not_specialize_on_alignment: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
Expand All @@ -847,13 +857,15 @@ def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
if os.getenv("TRITON_INTERPRET", "0") == "1":
from .interpreter import InterpretedFunction
return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize, debug=debug,
return InterpretedFunction(fn, version=version, do_not_specialize=do_not_specialize,
do_not_specialize_on_alignment=do_not_specialize_on_alignment, debug=debug,
noinline=noinline, repr=repr, launch_metadata=launch_metadata)
else:
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
do_not_specialize_on_alignment=do_not_specialize_on_alignment,
debug=debug,
noinline=noinline,
repr=repr,
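For reference, a minimal usage sketch of the new do_not_specialize_on_alignment option (kernel and values are illustrative, mirroring the updated test above). Unlike do_not_specialize, it keeps value-based specialization (the i == 1 case still gets its own cache entry) but drops the divisible-by-16 alignment key, which is why test_specialize expects two compilations for the values 1 through 32 rather than three:

import torch
import triton
import triton.language as tl

@triton.jit(do_not_specialize_on_alignment=["i"])
def store_i(X, i, BLOCK: tl.constexpr):
    tl.store(X, i)

x = torch.empty(1, dtype=torch.int32, device="cuda")
# i == 1 still produces its own specialization key, but 16 and 32 are no
# longer marked "divisible by 16", so they share one compiled kernel with
# 2, 4, and 8: two cache entries in total.
for i in [1, 2, 4, 8, 16, 32]:
    store_i[(1, )](x, i, BLOCK=512)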
1 change: 0 additions & 1 deletion test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -9,7 +9,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.br
// CHECK: rocdl.barrier
// CHECK: llvm.load
// CHECK: rocdl.barrier
// CHECK: llvm.store
%0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (!tt.ptr<f32>, f32, i1) -> f32
tt.store %arg0, %0 : !tt.ptr<f32>
2 changes: 0 additions & 2 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -472,7 +472,6 @@ struct AtomicCASOpConversion
BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
barrier();
Value ret = load(valueElemTy, atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
}
}
@@ -637,7 +636,6 @@ struct AtomicRMWOpConversion
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
barrier();
Value ret = load(valueElemTy, atomPtr);
barrier();
rewriter.replaceOp(op, {ret});
}
}
13 changes: 13 additions & 0 deletions third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
@@ -24,6 +24,19 @@ namespace ttgi = mlir::triton::gpu::intel;

namespace mlir::triton::gpu::intel {

static bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = dyn_cast<RankedTensorType>(value.getType()))
return tensorTy.getNumElements() == 1;
// TODO: Handle other cases.
// For example, when ptr is a tensor of single value.
// It means that ptr is a resultant of broadcast or generated through
// a chain of broadcast and other operations.
// Rematerialize it without considering contiguous memory access pattern is
// fine.
return true;
}

std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding) {
if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op))
return encoding;