-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Patch when installing the xformers (#61)
Summary: This is to patch xformers as its FA3 extension build will fail due to lack of linking to libcuda.so: facebookresearch/xformers#1157 Fixes #20 Pull Request resolved: #61 Reviewed By: FindHao Differential Revision: D66273474 Pulled By: xuzhao9 fbshipit-source-id: 81898ccd005750937ac3cfd639c2303975ef1abe
- Loading branch information
1 parent
e2bbc48
commit abb7ac6
Showing
3 changed files
with
64 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
import subprocess | ||
import sys | ||
from pathlib import Path | ||
|
||
REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent | ||
PATCH_DIR = str(REPO_PATH.joinpath("submodules", "xformers").absolute()) | ||
PATCH_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "xformers.patch") | ||
|
||
|
||
def patch_xformers(): | ||
try: | ||
subprocess.check_output( | ||
[ | ||
"patch", | ||
"-p1", | ||
"--forward", | ||
"-i", | ||
PATCH_FILE, | ||
"-r", | ||
"/tmp/rej", | ||
], | ||
cwd=PATCH_DIR, | ||
) | ||
except subprocess.SubprocessError as e: | ||
output_str = str(e.output) | ||
if "previously applied" in output_str: | ||
return | ||
else: | ||
print(str(output_str)) | ||
sys.exit(1) | ||
|
||
|
||
def install_xformers(): | ||
patch_xformers() | ||
os_env = os.environ.copy() | ||
os_env["TORCH_CUDA_ARCH_LIST"] = "8.0;9.0;9.0a" | ||
XFORMERS_PATH = REPO_PATH.joinpath("submodules", "xformers") | ||
cmd = ["pip", "install", "-e", XFORMERS_PATH] | ||
subprocess.check_call(cmd, env=os_env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
From 1056e56f873fa6a097de3a7c1ceeeed66676ae82 Mon Sep 17 00:00:00 2001 | ||
From: Xu Zhao <[email protected]> | ||
Date: Wed, 20 Nov 2024 19:19:46 -0500 | ||
Subject: [PATCH] Link to cuda library | ||
|
||
--- | ||
setup.py | 2 ++ | ||
1 file changed, 2 insertions(+) | ||
|
||
diff --git a/setup.py b/setup.py | ||
index 6eaa50904..c804b4817 100644 | ||
--- a/setup.py | ||
+++ b/setup.py | ||
@@ -356,6 +356,8 @@ def get_flash_attention3_extensions(cuda_version: int, extra_compile_args): | ||
Path(flash_root) / "hopper", | ||
] | ||
], | ||
+ # Without this we get and error about cuTensorMapEncodeTiled not defined | ||
+ libraries=["cuda"], | ||
) | ||
] | ||
|