diff --git a/install.py b/install.py
index 4e2b7ab9..f48721c0 100644
--- a/install.py
+++ b/install.py
@@ -77,14 +77,6 @@ def install_liger():
     subprocess.check_call(cmd)
 
 
-def install_xformers():
-    os_env = os.environ.copy()
-    os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
-    XFORMERS_PATH = REPO_PATH.joinpath("submodules", "xformers")
-    cmd = ["pip", "install", "-e", XFORMERS_PATH]
-    subprocess.check_call(cmd, env=os_env)
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(allow_abbrev=False)
     parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
@@ -145,6 +137,7 @@ def install_xformers():
         install_liger()
     if args.xformers or args.all:
         logger.info("[tritonbench] installing xformers...")
+        from tools.xformers.install import install_xformers
         install_xformers()
     if args.hstu or args.all:
         logger.info("[tritonbench] installing hstu...")
diff --git a/tools/xformers/install.py b/tools/xformers/install.py
new file mode 100644
index 00000000..4130c38e
--- /dev/null
+++ b/tools/xformers/install.py
@@ -0,0 +1,42 @@
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
+PATCH_DIR = str(
+    REPO_PATH.joinpath("submodules", "xformers")
+    .absolute()
+)
+PATCH_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "xformers.patch")
+
+
+def patch_xformers():
+    try:
+        subprocess.check_output(
+            [
+                "patch",
+                "-p1",
+                "--forward",
+                "-i",
+                PATCH_FILE,
+                "-r",
+                "/tmp/rej",
+            ],
+            cwd=PATCH_DIR,
+        )
+    except subprocess.SubprocessError as e:
+        output_str = str(e.output)
+        if "previously applied" in output_str:
+            return
+        else:
+            print(str(output_str))
+            sys.exit(1)
+
+def install_xformers():
+    patch_xformers()
+    os_env = os.environ.copy()
+    os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
+    XFORMERS_PATH = REPO_PATH.joinpath("submodules", "xformers")
+    cmd = ["pip", "install", "-e", XFORMERS_PATH]
+    subprocess.check_call(cmd, env=os_env)
diff --git a/tools/xformers/xformers.patch b/tools/xformers/xformers.patch
new file mode 100644
index 00000000..359f0355
--- /dev/null
+++ b/tools/xformers/xformers.patch
@@ -0,0 +1,22 @@
+From 1056e56f873fa6a097de3a7c1ceeeed66676ae82 Mon Sep 17 00:00:00 2001
+From: Xu Zhao
+Date: Wed, 20 Nov 2024 19:19:46 -0500
+Subject: [PATCH] Link to cuda library
+
+---
+ setup.py | 2 ++
+ 1 file changed, 2 insertions(+)
+
+diff --git a/setup.py b/setup.py
+index 6eaa50904..c804b4817 100644
+--- a/setup.py
++++ b/setup.py
+@@ -356,6 +356,8 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
+                 Path(flash_root) / "hopper",
+             ]
+         ],
++        # Without this we get an error about cuTensorMapEncodeTiled not defined
++        libraries=["cuda"],
+     )
+ ]
+ 
\ No newline at end of file
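
Usage sketch (an assumption about how the new path is exercised, not part of the diff itself; the flag, module path, and environment variable all come from the hunks above, and it presumes the tritonbench repository root as the working directory with the submodules/xformers checkout present):

    # Equivalent of running `python install.py --xformers`: install.py now lazily
    # imports tools.xformers.install and calls install_xformers(), which first
    # applies tools/xformers/xformers.patch to submodules/xformers (returning early
    # if `patch --forward` reports it was previously applied) and then pip-installs
    # the submodule in editable mode with TORCH_CUDA_ARCH_LIST="8.0;9.0;9.0a", so
    # the flash-attention 3 extension links against libcuda for cuTensorMapEncodeTiled.
    from tools.xformers.install import install_xformers

    install_xformers()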