Skip to content

Commit

Permalink
Reuse XPUUtils, ty_to_cpp, serialize_args from main intel driver for benchmarks (#3063)
Browse files Browse the repository at this point in the history

Part of #2540

Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Dec 26, 2024
1 parent 7107b3d commit 7404f7f
Showing 1 changed file with 1 addition and 104 deletions.
105 changes: 1 addition & 104 deletions benchmarks/triton_kernels_benchmark/benchmark_driver.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
from pathlib import Path

from triton.backends.compiler import GPUTarget
from triton.backends.driver import DriverBase
from triton._utils import parse_list_string
from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER
from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER, XPUUtils, ty_to_cpp, serialize_args

import torch

Expand All @@ -14,58 +13,11 @@

COMPILATION_HELPER.inject_pytorch_dep()


class XPUUtils:
    """Singleton holder for the helpers compiled from the adjacent ``driver.c``.

    ``__new__`` caches the first instance on the class as ``instance`` and
    returns that same object for every later construction.

    NOTE(review): Python still invokes ``__init__`` on each ``XPUUtils()``
    call, so the compile/initialisation steps below run again every time;
    whether ``compile_module_from_src`` caches its result is not visible
    here -- TODO confirm.

    Attributes set by ``__init__``:
        load_binary: ``load_binary`` callable exported by the compiled module.
        get_device_properties: ``get_device_properties`` callable exported by
            the compiled module.
        context: result of ``init_context`` for the current SYCL queue.
        device_count: result of ``init_devices`` for the current SYCL queue;
            its first element is compared against zero below.
        current_device: ``0`` when ``device_count[0] > 0``, otherwise ``-1``.
    """

    def __new__(cls):
        # Hand back the cached object if one exists; otherwise create it and
        # remember it on the class.
        try:
            return cls.instance
        except AttributeError:
            cls.instance = super().__new__(cls)
            return cls.instance

    def __init__(self):
        # driver.c is looked up next to this file (symlinks resolved).
        here = os.path.dirname(os.path.realpath(__file__))
        driver_source = Path(os.path.join(here, "driver.c")).read_text(encoding="utf-8")
        spirv_utils = compile_module_from_src(driver_source, "spirv_utils")

        self.load_binary = spirv_utils.load_binary
        self.get_device_properties = spirv_utils.get_device_properties
        self.context = spirv_utils.init_context(self.get_sycl_queue())
        self.device_count = spirv_utils.init_devices(self.get_sycl_queue())

        # -1 marks "no usable device".
        if self.device_count[0] > 0:
            self.current_device = 0
        else:
            self.current_device = -1

    def get_current_device(self):
        """Return the device index chosen in ``__init__`` (``-1`` if none)."""
        return self.current_device

    def get_sycl_queue(self):
        """Return the SYCL queue of the current ``torch.xpu`` stream."""
        stream = torch.xpu.current_stream()
        return stream.sycl_queue


# ------------------------
# Launcher
# ------------------------


def ty_to_cpp(ty):
    """Translate a Triton type string into the C/C++ type name used for it.

    Args:
        ty: Type string such as ``"i32"``, ``"fp16"``, ``"*fp32"`` or
            ``"none"``.

    Returns:
        ``"void*"`` when ``ty`` starts with ``"*"`` or equals ``"none"``;
        otherwise the mapped scalar type name.

    Raises:
        KeyError: ``ty`` is not one of the known scalar type strings.
        IndexError: ``ty`` is an empty sequence.
    """
    is_pointer = ty[0] == "*"
    if is_pointer or ty == "none":
        return "void*"

    # 1-bit integers are widened to 32 bits.
    scalar_types = {"i1": "int32_t", "u1": "uint32_t"}
    for bits in (8, 16, 32, 64):
        scalar_types[f"i{bits}"] = f"int{bits}_t"
        scalar_types[f"u{bits}"] = f"uint{bits}_t"
    # Every floating-point type narrower than 64 bits maps to ``float``.
    for float_ty in ("fp16", "bf16", "fp32", "f32"):
        scalar_types[float_ty] = "float"
    scalar_types["fp64"] = "double"

    return scalar_types[ty]


def make_launcher(constants, signature, ids): # pylint: disable=unused-argument

def _extracted_type(ty):
Expand Down Expand Up @@ -341,61 +293,6 @@ def format_of(ty):
return src


def serialize_kernel_metadata(arg, args_dict):
    """Copy launch-relevant fields of a kernel metadata object into ``args_dict``.

    Args:
        arg: Object exposing ``num_warps``, ``threads_per_warp``, ``shared``,
            ``name`` and ``build_flags`` attributes.
        args_dict: Mutable mapping updated in place with the keys
            ``num_warps``, ``threads_per_warp``, ``shared_memory``,
            ``kernel_name``, ``spv_name`` and ``build_flags``.

    Returns:
        None. ``args_dict`` is modified in place.
    """
    # (destination key, source attribute) pairs that are copied verbatim.
    direct_fields = (
        ("num_warps", "num_warps"),
        ("threads_per_warp", "threads_per_warp"),
        ("shared_memory", "shared"),
        ("kernel_name", "name"),
    )
    for key, attribute in direct_fields:
        args_dict[key] = getattr(arg, attribute)

    # The SPIR-V file name is derived from the kernel name.
    args_dict["spv_name"] = f"{arg.name}.spv"
    args_dict["build_flags"] = arg.build_flags


def serialize_args(args, constants, signature):
    """Dump the arguments of a kernel launch to disk for the SPIR-V Runner.

    The output directory is taken from the ``TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS``
    environment variable and created if missing. Each tensor argument is
    copied to the CPU and saved as ``tensor_<i>.pt``; a description of the
    launch is written to ``args_data.json`` in the same directory.

    Args:
        args: Launch argument sequence. ``args[0:3]`` are recorded as
            ``gridX``/``gridY``/``gridZ``; ``args[3]`` is skipped (presumably
            the stream -- TODO confirm); the remaining entries are scanned.
            Entries that are neither a ``KernelMetadata`` object, a
            ``torch.Tensor`` nor a number are ignored.
        constants: Container of kernel-argument indices; a scalar whose index
            is in it is left out of the dumped argument list.
        signature: Mapping from kernel-argument index to its type string,
            stored as ``ctype`` for each dumped argument.

    Returns:
        None. The results are the files written to the dump directory.

    Raises:
        TypeError: ``TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS`` is not set.
    """
    import json
    import numbers

    dir_path = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS")
    if dir_path is None:
        # Previously this surfaced as an opaque TypeError from os.path.exists(None);
        # keep the exception type but say what is actually wrong.
        raise TypeError("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS must be set to the directory "
                        "where the SPIR-V Runner data should be written")
    if not os.path.exists(dir_path):
        # exist_ok guards against another process creating the directory
        # between the check above and this call.
        os.makedirs(dir_path, exist_ok=True)
        print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}")

    args_dict = {"gridX": args[0], "gridY": args[1], "gridZ": args[2]}
    args_dict["argument_list"] = []
    # "karg_cnt" is the running kernel-argument index used to look up
    # `signature` and `constants`; the other two only number the output names.
    counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0}

    for arg in args[4:]:
        # Matched by class name rather than isinstance -- the class is not
        # imported here.
        if type(arg).__name__ == "KernelMetadata":
            serialize_kernel_metadata(arg, args_dict)

        if isinstance(arg, torch.Tensor):
            cpu_tensor = arg.cpu()
            tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt")
            with open(tensor_path, "wb") as f:
                torch.save(cpu_tensor, f)
            args_dict["argument_list"].append({
                "name": f"tensor_{counts['tensors']}",
                "type": "tensor",
                "dtype": str(arg.dtype),
                "ctype": signature[counts["karg_cnt"]],
            })
            counts["karg_cnt"] += 1
            counts["tensors"] += 1

        if isinstance(arg, numbers.Number):
            # Constant scalars are not dumped, but they still consume a
            # kernel-argument index and a scalar number.
            if counts["karg_cnt"] not in constants:
                args_dict["argument_list"].append({
                    "name": f"scalarArg_{counts['scalars']}",
                    "type": "scalar",
                    "value": arg,
                    "ctype": signature[counts["karg_cnt"]],
                })
            counts["karg_cnt"] += 1
            counts["scalars"] += 1

    # Dump argument info as a JSON file
    json_path = os.path.join(dir_path, "args_data.json")
    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(args_dict, json_file, indent=4)


class XPULauncher:

def __init__(self, src, metadata): # pylint: disable=unused-argument
Expand Down

0 comments on commit 7404f7f

Please sign in to comment.