Merge branch 'main' into amyachev/autotuner
whitneywhtsang authored Dec 21, 2024
2 parents 5710fd1 + 3e1165f commit b67dc9a
Showing 16 changed files with 681 additions and 403 deletions.
68 changes: 34 additions & 34 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
@@ -8,6 +8,7 @@
 from triton.backends.driver import DriverBase
 from triton.runtime.cache import get_cache_manager
 from triton.runtime.build import _build, quiet
+from triton._utils import parse_list_string
 
 import torch

@@ -84,7 +85,7 @@ def get_sycl_queue(self):
 
 
 def ty_to_cpp(ty):
-    if ty[0] == "*":
+    if ty[0] == "*" or ty == "none":
         return "void*"
     return {
         "i1": "int32_t",
@@ -106,16 +107,27 @@ def ty_to_cpp(ty):
 
 
 def make_launcher(constants, signature, ids): # pylint: disable=unused-argument
-    # Record the end of regular arguments;
-    # subsequent arguments are architecture-specific descriptors.
-    arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
 
     def _extracted_type(ty):
-        if ty[0] == "*":
+        if ty[0] == "*" or ty == "none":
             return "PyObject*"
+        if ty[0] == "[":
+            if ty == "[]":
+                return "[]"
+            tys = parse_list_string(ty)
+            val = ",".join(map(_extracted_type, tys))
+            return f"[{val}]"
         return ty_to_cpp(ty)
 
     def format_of(ty):
+        if ty == "void*":
+            return "O"
+        if ty[0] == "[":
+            if ty == "[]":
+                return "()"
+            tys = parse_list_string(ty)
+            val = "".join(map(format_of, tys))
+            return f"({val})"
         return {
             "PyObject*": "O",
             "float": "f",
@@ -131,10 +143,18 @@ def format_of(ty):
             "uint64_t": "K",
         }[ty]
 
+    signature = {k: v for k, v in signature.items() if v != "constexpr"}
     args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()])
     fmt = "iiiOOOOOO" + args_format
+    signature = ",".join(signature.values()).replace("[", "").replace("]", "")
+    signature = list(filter(bool, signature.split(",")))
+    signature = dict(enumerate(signature))
     args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""
 
+    # Record the end of regular arguments;
+    # subsequent arguments are architecture-specific descriptors.
+    arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
+
     # generate glue code
     src = f"""
 #include <cstddef>
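The three signature reassignments flatten tuple-typed arguments, but only after the Python-level format string has been computed. For example, a signature of {0: "*fp32", 1: "[i32,i32]", 2: "constexpr"} first loses its constexpr entry, is joined into "*fp32,i32,i32" once the brackets are stripped, and ends up re-enumerated as {0: "*fp32", 1: "i32", 2: "i32"}; args_format therefore still sees one "(ii)" tuple while arg_decls and the generated C code see two scalar arguments.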
@@ -229,33 +249,15 @@ def format_of(ty):
   return ptr_info;
 }}
 // start sycl
-static void set_scalar_arg(
-    sycl::handler& cgh,
-    int index,
-    size_t size,
-    const void* value) {{
-  switch (size) {{
-  case sizeof(uint8_t):
-    cgh.set_arg(index, *static_cast<const uint8_t*>(value));
-    break;
-  case sizeof(uint16_t):
-    cgh.set_arg(index, *static_cast<const uint16_t*>(value));
-    break;
-  case sizeof(uint32_t):
-    cgh.set_arg(index, *static_cast<const uint32_t*>(value));
-    break;
-  case sizeof(uint64_t):
-    cgh.set_arg(index, *static_cast<const uint64_t*>(value));
-    break;
-  default:
-    assert(false && "wrong scalar size in sycl gen.");
-  }}
-}}
+template <class T>
+static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{
+  cgh.set_arg(index, *static_cast<const T *>(value));
+}}
 static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{
   std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
   RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}});
-  void *params[] = {{ {", ".join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
+  void *params[] = {{ {", ".join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
   uint32_t num_params = sizeof(params)/sizeof(params[0]);
   uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
   size_t global_range_x = gridX*threads_per_warp*num_warps;
@@ -273,8 +275,7 @@ def format_of(ty):
   assert(num_params == expected_num_params && "number of kernel param not matched");
   // Submit the imported kernel.
   auto cgf = [&](sycl::handler &cgh) {{
-    {" ".join(f"set_scalar_arg(cgh, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants]))}
-    if (shared_memory) {{
+    {" ".join(f"set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))} if (shared_memory) {{
       using share_mem_t = sycl::local_accessor<int8_t, 1>;
       share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
       cgh.set_arg(num_params, local_buffer);
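Two things are easy to miss in these hunks: the setter is now a template that forwards each scalar with its declared C++ type instead of reinterpreting it through a same-width unsigned integer, and arguments typed "none" are filtered out of both params and the setter join, so they survive only as void* placeholders in the C signature. A minimal sketch of the helper outside the f-string (single braces), with the expansion the join would produce for a hypothetical signature {0: "*fp32", 1: "i32"} and no constants:

    #include <sycl/sycl.hpp>

    // Typed setter: T is chosen at the generated call site from ty_to_cpp,
    // rather than being guessed from the scalar's byte size by a switch.
    template <class T>
    static inline void set_scalar_arg(sycl::handler &cgh, int index,
                                      const void *value) {
      cgh.set_arg(index, *static_cast<const T *>(value));
    }

    // For the hypothetical signature above, the join emits:
    //   set_scalar_arg<void*>(cgh, 0, params[0]);
    //   set_scalar_arg<int32_t>(cgh, 1, params[1]);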
@@ -336,8 +337,8 @@ def format_of(ty):
   if(kernel_ptr == nullptr) return NULL;
   sycl::kernel kernel = *kernel_ptr;
-  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
-  sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""});
+  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
+  sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""});
   if(launch_exit_hook != Py_None){{
     PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -440,9 +441,8 @@ class XPULauncher:
     def __init__(self, src, metadata): # pylint: disable=unused-argument
         ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
         constants = src.constants if hasattr(src, "constants") else {}
-        cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
-        self.constants = {cst_key(key): value for key, value in constants.items()}
-        self.signature = {cst_key(key): value for key, value in src.signature.items()}
+        self.constants = dict(constants.items())
+        self.signature = dict(src.signature.items())
         src = make_launcher(self.constants, self.signature, ids)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch
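With the signature re-indexed inside make_launcher itself, the launcher can take constants and signature from the compiled source as-is; the cst_key remapping of argument names to positions is no longer needed here.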
24 changes: 24 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -348,12 +348,18 @@
 SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
                                Value linear, ArrayRef<unsigned> shape);
 
+SmallVector<unsigned> delinearize(unsigned linear, ArrayRef<unsigned> shape,
+                                  ArrayRef<unsigned> order);
+
 Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
                 ArrayRef<unsigned> shape, ArrayRef<unsigned> order);
 
 Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
                 ArrayRef<unsigned> shape);
 
+size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
+                 ArrayRef<unsigned> order);
+
 Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
                         StringRef content);

@@ -496,6 +502,24 @@ inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
   return ret;
 }
 
+/// Extend a 2d shared object to 3d.
+///
+/// If the tensor has 3 dimensions, returns the original shared object.
+/// If the tensor shape is [M, N], returns a shared object describing shape [1, M, N].
+///
+/// This function is used to simplify processing of 2d and 3d dot operands,
+/// particularly in the conversion of the local_load operation.
+///
+/// \param rewriter
+/// \param loc
+/// \param smemObj
+/// \param shape shape of the tensor represented by smemObj
+/// \returns shared object describing the 3d tensor
+SharedMemoryObject
+getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
+                              SharedMemoryObject smemObj,
+                              ArrayRef<int64_t> shape);
+
 // -----------------------------------------------------------------------
 // Blocked layout indices
 // -----------------------------------------------------------------------
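Only the declaration and its contract appear in this header, so the following is a shape-level sketch of the documented convention, not the committed implementation; the helper name expandShapeTo3d is made up here, and the real function must also rebuild the SharedMemoryObject's strides and offsets:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Documented convention: [M, N] -> [1, M, N]; 3d shapes pass through.
    std::vector<int64_t> expandShapeTo3d(std::vector<int64_t> shape) {
      if (shape.size() == 3)
        return shape;
      assert(shape.size() == 2 && "expected a 2d or 3d dot operand");
      return {1, shape[0], shape[1]};
    }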
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -234,6 +234,12 @@ void dumpHWLayout(RankedTensorType tensorType);
 // Return a string representation of the layout of the tensor.
 std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);
 
+template <typename T>
+llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s);
+
+llvm::SmallVector<unsigned>
+expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o);
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
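The definitions are not part of this excerpt, so the following is one plausible reading of the batch-expansion semantics, written with std::vector in place of the LLVM containers and assuming a rank-2 shape [M, N] with an order that lists axes fastest-first:

    #include <cassert>
    #include <vector>

    // Illustrative only; the committed definitions may differ.
    // [M, N] -> [1, M, N]; rank-3 shapes pass through unchanged.
    template <typename T>
    std::vector<T> expandMatrixShapeWithBatch(const std::vector<T> &s) {
      if (s.size() == 3)
        return s;
      assert(s.size() == 2);
      return {T(1), s[0], s[1]};
    }

    // Shift the existing axis ids up by one for the new leading batch
    // dimension and make the batch axis (dim 0) the slowest-varying one.
    std::vector<unsigned> expandMatrixOrderWithBatch(const std::vector<unsigned> &o) {
      if (o.size() == 3)
        return o;
      assert(o.size() == 2);
      return {o[0] + 1, o[1] + 1, 0u};
    }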