Skip to content

Commit

Permalink
Patch xformers
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 21, 2024
1 parent 9f7e919 commit 4f60cb2
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
9 changes: 1 addition & 8 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,6 @@ def install_liger():
subprocess.check_call(cmd)


def install_xformers():
os_env = os.environ.copy()
os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
XFORMERS_PATH = REPO_PATH.joinpath("submodules", "xformers")
cmd = ["pip", "install", "-e", XFORMERS_PATH]
subprocess.check_call(cmd, env=os_env)


if __name__ == "__main__":
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--fbgemm", action="store_true", help="Install FBGEMM GPU")
Expand Down Expand Up @@ -145,6 +137,7 @@ def install_xformers():
install_liger()
if args.xformers or args.all:
logger.info("[tritonbench] installing xformers...")
from tools.xformers.install import install_xformers
install_xformers()
if args.hstu or args.all:
logger.info("[tritonbench] installing hstu...")
Expand Down
42 changes: 42 additions & 0 deletions tools/xformers/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import subprocess
import sys
from pathlib import Path

REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
PATCH_DIR = str(
REPO_PATH.joinpath("submodules", "xformers")
.absolute()
)
PATCH_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "xformers.patch")


def patch_xformers():
try:
subprocess.check_output(
[
"patch",
"-p1",
"--forward",
"-i",
PATCH_FILE,
"-r",
"/tmp/rej",
],
cwd=PATCH_DIR,
)
except subprocess.SubprocessError as e:
output_str = str(e.output)
if "previously applied" in output_str:
return
else:
print(str(output_str))
sys.exit(1)

def install_xformers():
patch_xformers()
os_env = os.environ.copy()
os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a"
XFORMERS_PATH = REPO_PATH.joinpath("submodules", "xformers")
cmd = ["pip", "install", "-e", XFORMERS_PATH]
subprocess.check_call(cmd, env=os_env)
22 changes: 22 additions & 0 deletions tools/xformers/xformers.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
From 1056e56f873fa6a097de3a7c1ceeeed66676ae82 Mon Sep 17 00:00:00 2001
From: Xu Zhao <[email protected]>
Date: Wed, 20 Nov 2024 19:19:46 -0500
Subject: [PATCH] Link to cuda library

---
setup.py | 2 ++
1 file changed, 2 insertions(+)

diff --git a/setup.py b/setup.py
index 6eaa50904..c804b4817 100644
--- a/setup.py
+++ b/setup.py
@@ -356,6 +356,8 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args):
Path(flash_root) / "hopper",
]
],
+ # Without this we get and error about cuTensorMapEncodeTiled not defined
+ libraries=["cuda"],
)
]

0 comments on commit 4f60cb2

Please sign in to comment.