From 7404f7fee758f370c58845b59bca6a081a6264d8 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 26 Dec 2024 16:43:21 +0100 Subject: [PATCH] Reuse `XPUUtils`, `ty_to_cpp`, `serialize_args` from main intel driver for benchmarks (#3063) Part of https://github.com/intel/intel-xpu-backend-for-triton/issues/2540 Signed-off-by: Anatoly Myachev --- .../benchmark_driver.py | 105 +----------------- 1 file changed, 1 insertion(+), 104 deletions(-) diff --git a/benchmarks/triton_kernels_benchmark/benchmark_driver.py b/benchmarks/triton_kernels_benchmark/benchmark_driver.py index 996b28b7ea..893e84db18 100644 --- a/benchmarks/triton_kernels_benchmark/benchmark_driver.py +++ b/benchmarks/triton_kernels_benchmark/benchmark_driver.py @@ -1,10 +1,9 @@ import os -from pathlib import Path from triton.backends.compiler import GPUTarget from triton.backends.driver import DriverBase from triton._utils import parse_list_string -from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER +from triton.backends.intel.driver import compile_module_from_src, COMPILATION_HELPER, XPUUtils, ty_to_cpp, serialize_args import torch @@ -14,58 +13,11 @@ COMPILATION_HELPER.inject_pytorch_dep() - -class XPUUtils: - - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(XPUUtils, cls).__new__(cls) - return cls.instance - - def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - mod = compile_module_from_src( - Path(os.path.join(dirname, "driver.c")).read_text(encoding="utf-8"), "spirv_utils") - self.load_binary = mod.load_binary - self.get_device_properties = mod.get_device_properties - self.context = mod.init_context(self.get_sycl_queue()) - self.device_count = mod.init_devices(self.get_sycl_queue()) - self.current_device = 0 if self.device_count[0] > 0 else -1 - - def get_current_device(self): - return self.current_device - - def get_sycl_queue(self): - return torch.xpu.current_stream().sycl_queue - - # ------------------------ # Launcher # ------------------------ -def ty_to_cpp(ty): - if ty[0] == "*" or ty == "none": - return "void*" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u1": "uint32_t", - "u8": "uint8_t", - "u16": "uint16_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - def make_launcher(constants, signature, ids): # pylint: disable=unused-argument def _extracted_type(ty): @@ -341,61 +293,6 @@ def format_of(ty): return src -def serialize_kernel_metadata(arg, args_dict): - args_dict["num_warps"] = arg.num_warps - args_dict["threads_per_warp"] = arg.threads_per_warp - args_dict["shared_memory"] = arg.shared - args_dict["kernel_name"] = arg.name - args_dict["spv_name"] = f"{arg.name}.spv" - args_dict["build_flags"] = arg.build_flags - - -def serialize_args(args, constants, signature): - import numbers - dir_path = os.getenv("TRITON_XPU_DUMP_SPIRV_KERNEL_ARGS") - if not os.path.exists(dir_path): - os.makedirs(dir_path) - print(f"Path to directory consisting of SPIR-V Runner data: {dir_path}") - - cnt = 0 - args_dict = {"gridX": args[cnt], "gridY": args[cnt + 1], "gridZ": args[cnt + 2]} - args_dict["argument_list"] = [] - counts = {"tensors": 0, "scalars": 0, "karg_cnt": 0} - cnt = 4 - for arg in args[cnt:]: - if type(arg).__name__ == "KernelMetadata": - serialize_kernel_metadata(arg, args_dict) - - if isinstance(arg, torch.Tensor): - cpu_tensor = arg.cpu() - tensor_path = os.path.join(dir_path, f"tensor_{counts['tensors']}.pt") - with open(tensor_path, "wb") as f: - torch.save(cpu_tensor, f) - new_arg = { - "name": f"tensor_{counts['tensors']}", "type": "tensor", "dtype": str(arg.dtype), "ctype": - signature[counts["karg_cnt"]] - } - args_dict["argument_list"].append(new_arg) - counts["karg_cnt"] += 1 - counts["tensors"] += 1 - - if isinstance(arg, numbers.Number): - if counts["karg_cnt"] not in constants: - new_arg = { - "name": f"scalarArg_{counts['scalars']}", "type": "scalar", "value": args[cnt], "ctype": - signature[counts["karg_cnt"]] - } - args_dict["argument_list"].append(new_arg) - counts["karg_cnt"] += 1 - counts["scalars"] += 1 - cnt += 1 - # Dump argument info as a JSON file - json_path = os.path.join(dir_path, "args_data.json") - with open(json_path, "w", encoding="utf-8") as json_file: - import json - json.dump(args_dict, json_file, indent=4) - - class XPULauncher: def __init__(self, src, metadata): # pylint: disable=unused-argument