Reuse kernels launcher from main intel driver for benchmarks #3070

Merged: 2 commits, Jan 2, 2025
7 changes: 1 addition & 6 deletions benchmarks/triton_kernels_benchmark/__init__.py
@@ -1,7 +1,2 @@
 from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, BENCHMARKING_METHOD  # type: ignore # noqa: F401
-
-if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
-    from triton.runtime import driver
-    from . import benchmark_driver
-    # replace the launcher with the profiler hook.
-    driver.active.launcher_cls = benchmark_driver.XPULauncher
+from . import benchmark_driver  # type: ignore # noqa: F401
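With this change, importing the benchmark package no longer swaps in a custom launcher class; it simply imports benchmark_driver for its side effect. A sketch of the intended flow (a hypothetical illustration using names from this PR; the final assert only holds when BENCHMARKING_METHOD selects the upstream PyTorch profiler):

import os

import triton_kernels_benchmark  # imports benchmark_driver as a side effect

# benchmark_driver sets INJECT_PYTORCH=True when BENCHMARKING_METHOD is
# "UPSTREAM_PYTORCH_PROFILER"; the main intel driver then builds its launcher
# with torch as a dependency and ATen's RECORD_FUNCTION profiling hook.
assert os.environ.get("INJECT_PYTORCH") == "True"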
308 changes: 3 additions & 305 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
@@ -1,308 +1,6 @@
 import os
 
-from triton._utils import parse_list_string
-from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER, ty_to_cpp, serialize_args
 from .benchmark_testing import BENCHMARKING_METHOD
 
-# ------------------------
-# Utils
-# ------------------------
-
-COMPILATION_HELPER.inject_pytorch_dep()
-
-# ------------------------
-# Launcher
-# ------------------------
-
-
-def make_launcher(constants, signature, ids):  # pylint: disable=unused-argument
-
-    def _extracted_type(ty):
-        if ty[0] == "*" or ty == "none":
-            return "PyObject*"
-        if ty[0] == "[":
-            if ty == "[]":
-                return "[]"
-            tys = parse_list_string(ty)
-            val = ",".join(map(_extracted_type, tys))
-            return f"[{val}]"
-        return ty_to_cpp(ty)
-
-    def format_of(ty):
-        if ty == "void*":
-            return "O"
-        if ty[0] == "[":
-            if ty == "[]":
-                return "()"
-            tys = parse_list_string(ty)
-            val = "".join(map(format_of, tys))
-            return f"({val})"
-        return {
-            "PyObject*": "O",
-            "float": "f",
-            "double": "d",
-            "long": "l",
-            "int8_t": "b",
-            "int16_t": "h",
-            "int32_t": "i",
-            "int64_t": "L",
-            "uint8_t": "B",
-            "uint16_t": "H",
-            "uint32_t": "I",
-            "uint64_t": "K",
-        }[ty]
-
-    signature = {k: v for k, v in signature.items() if v != "constexpr"}
-    args_format = "".join([format_of(_extracted_type(ty)) for ty in signature.values()])
-    fmt = "iiiOOOOOO" + args_format
-    signature = ",".join(signature.values()).replace("[", "").replace("]", "")
-    signature = list(filter(bool, signature.split(",")))
-    signature = dict(enumerate(signature))
-    args_list = ", " + ", ".join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""
-
-    # Record the end of regular arguments;
-    # subsequent arguments are architecture-specific descriptors.
-    arg_decls = ", ".join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items())
-
-    # generate glue code
-    src = f"""
-#include <cstddef>
-#include <string>
-#include <iostream>
-#include <iomanip>
-#include <level_zero/ze_api.h>
-#include <sycl/sycl.hpp>
-#include <ATen/record_function.h>
-
-#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
-#include <Python.h>
-#include <stdio.h>
-#include <numpy/arrayobject.h>
-
-static inline void gpuAssert(ze_result_t code, const char *file, int line)
-{{
-  if (code != ZE_RESULT_SUCCESS)
-  {{
-    const char* prefix = "Triton Error [ZE]: ";
-    std::string str = std::to_string(code);
-    char err[1024] = {{0}};
-    strcat(err, prefix);
-    strcat(err, str.c_str());
-    PyErr_SetString(PyExc_RuntimeError, err);
-  }}
-}}
-
-#define ZE_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
-
-typedef struct _DevicePtrInfo {{
-  void* dev_ptr;
-  bool valid;
-}} DevicePtrInfo;
-
-static inline void checkDevicePointer(DevicePtrInfo *ptr_info, int idx, const sycl::queue &queue) {{
-  if (!ptr_info->dev_ptr || !ptr_info->valid) {{
-    return;
-  }}
-  auto context = queue.get_context();
-  auto handle = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(context);
-  ze_memory_allocation_properties_t prop;
-  prop.stype = ZE_STRUCTURE_TYPE_MEMORY_ALLOCATION_PROPERTIES;
-  prop.pNext = nullptr;
-  ze_device_handle_t device;
-  auto res = zeMemGetAllocProperties((ze_context_handle_t)handle, ptr_info->dev_ptr, &prop, &device);
-  if (res != ZE_RESULT_SUCCESS) {{
-    PyErr_Format(PyExc_ValueError,
-                 "Cannot get memory properties for pointer argument (at %d, err=%d)", idx, res);
-    ptr_info->valid = false;
-  }} else if (prop.type != ZE_MEMORY_TYPE_DEVICE) {{
-    PyErr_Format(PyExc_ValueError,
-                 "Pointer argument (at %d) doesn't reference XPU device memory (cpu tensor?)", idx);
-    ptr_info->valid = false;
-  }}
-}}
-
-static inline DevicePtrInfo getPointer(PyObject *obj, int idx, const sycl::queue &queue) {{
-  DevicePtrInfo ptr_info;
-  ptr_info.dev_ptr = 0;
-  ptr_info.valid = true;
-  if (PyLong_Check(obj)) {{
-    ptr_info.dev_ptr = PyLong_AsVoidPtr(obj);
-    checkDevicePointer(&ptr_info, idx, queue);
-    return ptr_info;
-  }}
-  if (obj == Py_None) {{
-    // valid nullptr
-    return ptr_info;
-  }}
-  PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
-  if(ptr){{
-    PyObject *empty_tuple = PyTuple_New(0);
-    PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
-    Py_DECREF(empty_tuple);
-    Py_DECREF(ptr);
-    if (!PyLong_Check(ret)) {{
-      PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
-      ptr_info.valid = false;
-      return ptr_info;
-    }}
-    ptr_info.dev_ptr = PyLong_AsVoidPtr(ret);
-    if(!ptr_info.dev_ptr) {{
-      return ptr_info;
-    }}
-    checkDevicePointer(&ptr_info, idx, queue);
-    Py_DECREF(ret); // Thanks ChatGPT!
-    return ptr_info;
-  }}
-  PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
-  ptr_info.valid = false;
-  return ptr_info;
-}}
-// start sycl
-template <class T>
-static inline void set_scalar_arg(sycl::handler &cgh, int index, const void *value) {{
-  cgh.set_arg(index, *static_cast<const T *>(value));
-}}
-static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {", " + arg_decls if len(arg_decls) > 0 else ""}) {{
-
-  std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
-  RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {{}});
-  void *params[] = {{ {", ".join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
-  uint32_t num_params = sizeof(params)/sizeof(params[0]);
-  uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
-  size_t global_range_x = gridX*threads_per_warp*num_warps;
-  size_t global_range_y = gridY;
-  size_t global_range_z = gridZ;
-  size_t local_range_x = num_warps*threads_per_warp;
-  size_t local_range_y = 1;
-  size_t local_range_z = 1;
-  sycl::range<3> global_range(global_range_z, global_range_y, global_range_x);
-  sycl::range<3> local_range(local_range_z, local_range_y, local_range_x);
-  sycl::nd_range<3> parallel_work_size(global_range, local_range);
-  if (shared_memory) {{
-    expected_num_params -= 1;
-  }}
-  assert(num_params == expected_num_params && "number of kernel param not matched");
-  // Submit the imported kernel.
-  auto cgf = [&](sycl::handler &cgh) {{
-    {" ".join(f"set_scalar_arg<{ty_to_cpp(item)}>(cgh, {idx}, params[{idx}]);" for idx, item in enumerate([signature[i] for i in signature if i not in constants and signature[i] != "none"]))}
-    if (shared_memory) {{
-      using share_mem_t = sycl::local_accessor<int8_t, 1>;
-      share_mem_t local_buffer = share_mem_t(shared_memory, cgh);
-      cgh.set_arg(num_params, local_buffer);
-      cgh.parallel_for(parallel_work_size, kernel_ptr);
-    }} else {{
-      cgh.parallel_for(parallel_work_size, kernel_ptr);
-    }}
-  }};
-  auto event = stream.submit(cgf);
-}}
-// end sycl
-static PyObject* launch(PyObject* self, PyObject* args) {{
-
-  int gridX, gridY, gridZ;
-  PyObject *launch_enter_hook = NULL;
-  PyObject *launch_exit_hook = NULL;
-  PyObject *kernel_metadata = NULL;
-  PyObject *launch_metadata = NULL;
-  PyObject *py_obj_stream;
-  PyObject *py_kernel;
-
-  {" ".join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
-  if(!PyArg_ParseTuple(args, \"{fmt}\", &gridX, &gridY, &gridZ, &py_obj_stream, &py_kernel,
-                       &kernel_metadata, &launch_metadata,
-                       &launch_enter_hook, &launch_exit_hook {args_list})) {{
-    return NULL;
-  }}
-
-  // extract kernel metadata
-  int num_warps = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_warps"));
-  int num_ctas = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "num_ctas"));
-  int shared_memory = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "shared"));
-  int threads_per_warp = PyLong_AsLong(PyObject_GetAttrString(kernel_metadata, "threads_per_warp"));
-
-  // extract cluster dims
-  PyObject *clusterDim = PyObject_GetAttrString(kernel_metadata, "cluster_dims");
-  if (!PyTuple_Check(kernel_metadata)) {{
-    PyErr_SetString(PyExc_TypeError, "kernel_metadata.cluster_dims must be a tuple");
-    return NULL;
-  }}
-  int clusterDimX = PyLong_AsLong(PyTuple_GetItem(clusterDim, 0));
-  int clusterDimY = PyLong_AsLong(PyTuple_GetItem(clusterDim, 1));
-  int clusterDimZ = PyLong_AsLong(PyTuple_GetItem(clusterDim, 2));
-  // extract launch metadata
-  if (launch_enter_hook != Py_None){{
-    PyObject* args = Py_BuildValue("(O)", launch_metadata);
-    PyObject* ret = PyObject_CallObject(launch_enter_hook, args);
-    Py_DECREF(args);
-    if (!ret)
-      return NULL;
-  }}
-
-  void * pStream = PyLong_AsVoidPtr(py_obj_stream);
-  //error check
-  if(pStream == nullptr || py_kernel == nullptr) return NULL;
-
-  sycl::queue stream = *(static_cast<sycl::queue*>(pStream));
-  sycl::kernel* kernel_ptr = reinterpret_cast<sycl::kernel*>(PyCapsule_GetPointer(py_kernel, "kernel"));
-  if(kernel_ptr == nullptr) return NULL;
-  sycl::kernel kernel = *kernel_ptr;
-
-  {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}, stream); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" or ty == "none" else "" for i, ty in signature.items()])};
-  sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel {"," + ", ".join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" or ty == "none" else f"_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ""});
-
-  if(launch_exit_hook != Py_None){{
-    PyObject* args = Py_BuildValue("(O)", launch_metadata);
-    PyObject* ret = PyObject_CallObject(launch_exit_hook, args);
-    Py_DECREF(args);
-    if (!ret)
-      return NULL;
-  }}
-  if (PyErr_Occurred()) {{
-    return NULL;
-  }}
-
-  // return None
-  Py_INCREF(Py_None);
-  return Py_None;
-}}
-
-static PyMethodDef ModuleMethods[] = {{
-  {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}},
-  {{NULL, NULL, 0, NULL}} // sentinel
-}};
-
-static struct PyModuleDef ModuleDef = {{
-  PyModuleDef_HEAD_INIT,
-  \"__triton_launcher\",
-  NULL, //documentation
-  -1, //size
-  ModuleMethods
-}};
-
-PyMODINIT_FUNC PyInit___triton_launcher(void) {{
-  PyObject *m = PyModule_Create(&ModuleDef);
-  if(m == NULL) {{
-    return NULL;
-  }}
-  PyModule_AddFunctions(m, ModuleMethods);
-  return m;
-}}
-"""
-    return src
-
-
-class XPULauncher:
-
-    def __init__(self, src, metadata):  # pylint: disable=unused-argument
-        ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
-        constants = src.constants if hasattr(src, "constants") else {}
-        self.constants = dict(constants.items())
-        self.signature = dict(src.signature.items())
-        src = make_launcher(self.constants, self.signature, ids)
-        mod = compile_module_from_src(src, "__triton_launcher")
-        self.launch = mod.launch
-
-    def __call__(self, *args, **kwargs):
-        # Serialize KernelArguments for SPIR-V Runner
-        serialize_kernel_args = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS", None)
-        if serialize_kernel_args:
-            serialize_args(args, self.constants, self.signature)
-        self.launch(*args, **kwargs)
+if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
+    os.environ["INJECT_PYTORCH"] = "True"
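Taken together, the kept and added lines above reduce benchmark_driver.py to a thin shim. For reference, the resulting file as implied by the diff:

import os

from .benchmark_testing import BENCHMARKING_METHOD

if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
    os.environ["INJECT_PYTORCH"] = "True"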
15 changes: 7 additions & 8 deletions third_party/intel/backend/driver.py
@@ -70,9 +70,6 @@ class CompilationHelper:
     _include_dir: list[str]
     libraries: list[str]
 
-    # for benchmarks
-    _build_with_pytorch_dep: bool = False
-
     def __init__(self):
         self._library_dir = None
         self._include_dir = None
@@ -81,11 +78,9 @@ def __init__(self):
         if os.name != "nt":
             self.libraries += ["sycl"]
 
+    @property
     def inject_pytorch_dep(self):
-        # must be called before any cached properties (if pytorch is needed)
-        if self._build_with_pytorch_dep is False:
-            self._build_with_pytorch_dep = True
-            self.libraries += ['torch']
+        return os.environ.get("INJECT_PYTORCH", "False") == "True"
Review comment (Contributor, PR author):

Working through environment variables is more stable: with the previous approach, this module (driver.py) is loaded by Python at least twice (somewhere in the Triton driver there is a manual module load), which recreates the CompilationHelper instance and discards the effect of an earlier inject_pytorch_dep call. The issue only surfaces when reusing the make_launcher function, so the previous version still worked correctly in its original setting.
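A minimal, self-contained sketch of the failure mode described above (file and variable names are hypothetical stand-ins): loading the same source twice yields two independent module objects, so a flag set on the first module's singleton is invisible after the re-load, while an environment variable survives because os.environ is process-wide state.

import importlib.util
import os
import tempfile
import textwrap

# A stand-in for driver.py with a module-level singleton.
src = textwrap.dedent("""
    class CompilationHelper:
        _build_with_pytorch_dep = False

    COMPILATION_HELPER = CompilationHelper()
""")
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(src)
    path = f.name


def load(name, path):
    # The kind of manual module loading the Triton driver performs.
    spec = importlib.util.spec_from_file_location(name, path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


first = load("driver_first", path)
first.COMPILATION_HELPER._build_with_pytorch_dep = True  # old approach

second = load("driver_second", path)  # a re-load recreates the singleton
assert second.COMPILATION_HELPER._build_with_pytorch_dep is False  # flag lost

os.environ["INJECT_PYTORCH"] = "True"  # new approach: visible to every copy
assert os.environ["INJECT_PYTORCH"] == "True"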


     @cached_property
     def _compute_compilation_options_lazy(self):
@@ -103,7 +98,7 @@ def _compute_compilation_options_lazy(self):
             include_dir += [os.path.join(dirname, "include")]
             library_dir += [os.path.join(dirname, "lib")]
 
-        if self._build_with_pytorch_dep:
+        if self.inject_pytorch_dep:
             import torch
 
             torch_path = torch.utils.cmake_prefix_path
@@ -112,6 +107,7 @@ def format_of(ty):
                 os.path.join(torch_path, "../../include/torch/csrc/api/include"),
             ]
             library_dir += [os.path.join(torch_path, "../../lib")]
+            self.libraries += ['torch']
 
         self._library_dir = library_dir
         self._include_dir = include_dir
@@ -276,6 +272,7 @@ def format_of(ty):
 #include <iomanip>
 #include <level_zero/ze_api.h>
 #include <sycl/sycl.hpp>
+{ "#include <ATen/record_function.h>" if COMPILATION_HELPER.inject_pytorch_dep else "" }
 
 #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
 #include <Python.h>
@@ -370,6 +367,8 @@ def format_of(ty):
 static void sycl_kernel_launch(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_warps, int threads_per_warp, int shared_memory, sycl::queue& stream, sycl::kernel& kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
 
   std::string kernel_name = kernel_ptr.get_info<sycl::info::kernel::function_name>();
+  { 'RECORD_FUNCTION("XPU Triton kernel:" + kernel_name, {});' if COMPILATION_HELPER.inject_pytorch_dep else "" }
+
   void *params[] = {{ {', '.join(f"&arg{i}" for i, ty in signature.items() if i not in constants and ty != "none")} }};
   uint32_t num_params = sizeof(params)/sizeof(params[0]);
   uint32_t expected_num_params = kernel_ptr.get_info<sycl::info::kernel::num_args>();
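With INJECT_PYTORCH set, the launcher wraps every kernel launch in ATen's RECORD_FUNCTION, so launches surface as named regions in the PyTorch profiler. A usage sketch, assuming an XPU-enabled PyTorch build that exposes ProfilerActivity.XPU; the kernel and sizes are illustrative, not from this PR:

import torch
import triton
import triton.language as tl
from torch.profiler import ProfilerActivity, profile


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.rand(1024, device="xpu")
y = torch.rand(1024, device="xpu")
out = torch.empty_like(x)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
    add_kernel[(4, )](x, y, out, x.numel(), BLOCK=256)

# Each launch appears as "XPU Triton kernel:<name>" in the trace.
print(prof.key_averages().table())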