Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hot fix for benchmark_driver.py after #3043 #3060

Merged
merged 1 commit into from
Dec 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 34 additions & 34 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from triton.backends.driver import DriverBase
from triton.runtime.cache import get_cache_manager
from triton.runtime.build import _build, quiet
from triton._utils import parse_list_string

import torch

Expand Down Expand Up @@ -84,7 +85,7 @@ def get_sycl_queue(self):


def ty_to_cpp(ty):
if ty[0] == "*":
if ty[0] == "*" or ty == "none":
return "void*"
return {
"i1": "int32_t",
Expand All @@ -106,16 +107,27 @@ def ty_to_cpp(ty):


def make_launcher(constants, signature, ids): # pylint: disable=unused-argument
# Record the end of regular arguments;
# subsequent arguments are architecture-specific descriptors.
arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

def _extracted_type(ty):
if ty[0] == "*":
if ty[0] == "*" or ty == "none":
return "PyObject*"
if ty[0] == "[":
if ty == "[]":
return "[]"
tys = parse_list_string(ty)
val = ",".join(map(_extracted_type, tys))
return f"[{val}]"
return ty_to_cpp(ty)

def format_of(ty):
if ty == "void*":
return "O"
if ty[0] == "[":
if ty == "[]":
return "()"
tys = parse_list_string(ty)
val = "".join(map(format_of, tys))
return f"({val})"
return {
"PyObject*": "O",
"float": "f",
Expand All @@ -131,10 +143,18 @@ def format_of(ty):
"uint64_t": "K",
}[ty]

signature = {k: v for k, v in signature.items() if v != "constexpr"}
args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()])
fmt = "iiiOOOOOO" + args_format
signature = ",".join(signature.values()).replace("[", "").replace("]", "")
signature = list(filter(bool, signature.split(",")))
signature = dict(enumerate(signature))
args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""

# Record the end of regular arguments;
# subsequent arguments are architecture-specific descriptors.
arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())

# generate glue code
src = f"""
#include <cstddef>
Expand Down Expand Up @@ -229,33 +249,15 @@ def format_of(ty):
return ptr_info;
}}
// start sycl
static void set_scalar_arg(
sycl::handler& cgh,
int index,
size_t size,
const void* value) {{
switch (size) {{
case sizeof(uint8_t):
cgh.set_arg(index, *static_cast<const uint8_t*>(value));
break;
case sizeof(uint16_t):
cgh.set_arg(index, *static_cast<const uint16_t*>(value));
break;
case sizeof(uint32_t):
cgh.set_arg(index, *static_cast<const uint32_t*>(value));
break;
case sizeof(uint64_t):
cgh.set_arg(index, *static_cast<const uint64_t*>(value));
break;
default:
assert(false && "wrong scalar size in sycl gen.");
}}
template <class T>
static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{
cgh.set_arg(index, *static_cast<const T *>(value));
}}
static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{

std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}});
void *params[] = {{ {", ".join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
void *params[] = {{ {", ".join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
uint32_t num_params = sizeof(params)/sizeof(params[0]);
uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
size_t global_range_x = gridX*threads_per_warp*num_warps;
Expand All @@ -273,8 +275,7 @@ def format_of(ty):
assert(num_params == expected_num_params && "number of kernel param not matched");
// Submit the imported kernel.
auto cgf = [&](sycl::handler &cgh) {{
{" ".join(f"set_scalar_arg(cgh, {idx}, sizeof({ty_to_cpp(item)}), params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants]))}
if (shared_memory) {{
{" ".join(f"set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))} if (shared_memory) {{
using share_mem_t = sycl::local_accessor<int8_t, 1>;
share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
cgh.set_arg(num_params, local_buffer);
Expand Down Expand Up @@ -336,8 +337,8 @@ def format_of(ty):
if(kernel_ptr == nullptr) return NULL;
sycl::kernel kernel = *kernel_ptr;

{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""});
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""});

if(launch_exit_hook != Py_None){{
PyObject* args = Py_BuildValue("(O)", launch_metadata);
Expand Down Expand Up @@ -440,9 +441,8 @@ class XPULauncher:
def __init__(self, src, metadata): # pylint: disable=unused-argument
ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
constants = src.constants if hasattr(src, "constants") else {}
cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i
self.constants = {cst_key(key): value for key, value in constants.items()}
self.signature = {cst_key(key): value for key, value in src.signature.items()}
self.constants = dict(constants.items())
self.signature = dict(src.signature.items())
src = make_launcher(self.constants, self.signature, ids)
mod = compile_module_from_src(src, "__triton_launcher")
self.launch = mod.launch
Expand Down