Skip to content

Commit

Permalink
Fix xetla import when built as wheel (#2589)
Browse files Browse the repository at this point in the history
Fixes #2576.
  • Loading branch information
pbchekin authored Oct 29, 2024
1 parent 18f70d0 commit efce869
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 31 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ repos:
- --disable=line-too-long
# Disable import-error: not everything can be imported when pre-commit runs
- --disable=import-error
- --disable=no-name-in-module
# Disable unused-import: ruff has a corresponding check and supports "noqa: F401"
- --disable=unused-import
# Disable invalid_name: benchmarks use a lot of UPPER_SNAKE arguments
Expand Down
37 changes: 12 additions & 25 deletions benchmarks/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
# TODO: update once there is replacement for clean:
# https://github.com/pypa/setuptools/discussions/2838
from distutils import log # pylint: disable=[deprecated-module]
from distutils.dir_util import remove_tree # pylint: disable=[deprecated-module]
from distutils.command.clean import clean as _clean # pylint: disable=[deprecated-module]

from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext as _build_ext
Expand All @@ -24,10 +22,10 @@ def __init__(self, name):

class CMakeBuild():

def __init__(self, debug=False, dry_run=False):
def __init__(self, build_lib, build_temp, debug=False, dry_run=False):
self.current_dir = os.path.abspath(os.path.dirname(__file__))
self.build_temp = self.current_dir + "/build/temp"
self.extdir = self.current_dir + "/triton_kernels_benchmark"
self.build_temp = build_temp
self.extdir = build_lib + "/triton_kernels_benchmark"
self.build_type = self.get_build_type(debug)
self.cmake_prefix_paths = [torch.utils.cmake_prefix_path]
self.use_ipex = False
Expand Down Expand Up @@ -101,30 +99,20 @@ def build_extension(self):
self.check_call(["cmake"] + build_args)
self.check_call(["cmake"] + install_args)

def clean(self):
if os.path.exists(self.build_temp):
remove_tree(self.build_temp, dry_run=self.dry_run)
else:
log.warn("'%s' does not exist -- can't clean it", os.path.relpath(self.build_temp,
os.path.dirname(__file__)))


class build_ext(_build_ext):

def run(self):
cmake = CMakeBuild(debug=self.debug, dry_run=self.dry_run)
cmake = CMakeBuild(
build_lib=self.build_lib,
build_temp=self.build_temp,
debug=self.debug,
dry_run=self.dry_run,
)
cmake.run()
super().run()


class clean(_clean):

def run(self):
cmake = CMakeBuild(dry_run=self.dry_run)
cmake.clean()
super().run()


def get_git_commit_hash(length=8):
try:
cmd = ["git", "rev-parse", f"--short={length}", "HEAD"]
Expand All @@ -151,11 +139,10 @@ def get_git_commit_hash(length=8):
package_data={"triton_kernels_benchmark": ["xetla_kernel.cpython-*.so"]},
cmdclass={
"build_ext": build_ext,
"clean": clean,
},
ext_modules=[CMakeExtension("triton_kernels_benchmark")],
extra_require={
"ipex": ["numpy<=2.0", "intel-extension-for-pytorch=2.1.10"],
ext_modules=[CMakeExtension("triton_kernels_benchmark.xetla_kernel")],
extras_require={
"ipex": ["numpy<=2.0", "intel-extension-for-pytorch==2.1.10"],
"pytorch": ["torch>=2.6"],
},
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
from triton_kernels_benchmark import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/triton_kernels_benchmark/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from triton.runtime import driver

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
from triton_kernels_benchmark import xetla_kernel


@torch.jit.script
Expand Down
3 changes: 1 addition & 2 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

import triton_kernels_benchmark as benchmark_suit
from triton_kernels_benchmark.benchmark_testing import do_bench_elapsed_time, BENCHMARKING_METHOD

import xetla_kernel
from triton_kernels_benchmark import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
from triton_kernels_benchmark import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit
import xetla_kernel
from triton_kernels_benchmark import xetla_kernel

if benchmark_suit.USE_IPEX_OPTION:
import intel_extension_for_pytorch # type: ignore # noqa: F401
Expand Down

0 comments on commit efce869

Please sign in to comment.