Skip to content

Commit

Permalink
Reduce NVCC threads
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Dec 19, 2024
1 parent 481a5ac commit f5ef4d6
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 9 deletions.
11 changes: 2 additions & 9 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,6 @@ def install_fa2(compile=False):
subprocess.check_call(cmd)


def install_fa3():
FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
env = os.environ.copy()
# nvcc will now spawn cicc and will cost ~1G memory
env["MAX_JOBS"] = "8"
cmd = [sys.executable, "setup.py", "install"]
subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()), env=env)


def install_liger():
# Liger-kernel has a conflict dependency `triton` with pytorch,
# so we need to install it without dependencies
Expand Down Expand Up @@ -130,6 +121,8 @@ def setup_hip(args: argparse.Namespace):
install_fa2(compile=True)
if args.fa3 or args.all:
logger.info("[tritonbench] installing fa3...")
from tools.flash_attn.install import install_fa3

install_fa3()
if args.colfax:
logger.info("[tritonbench] installing colfax cutlass-kernels...")
Expand Down
14 changes: 14 additions & 0 deletions tools/flash_attn/hopper.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
diff --git a/hopper/setup.py b/hopper/setup.py
index f9f3cfd..132ce07 100644
--- a/hopper/setup.py
+++ b/hopper/setup.py
@@ -78,7 +78,8 @@ def check_if_cuda_home_none(global_option: str) -> None:


def append_nvcc_threads(nvcc_extra_args):
- return nvcc_extra_args + ["--threads", "4"]
+ nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+ return nvcc_extra_args + ["--threads", NVCC_THREADS]


cmdclass = {}
44 changes: 44 additions & 0 deletions tools/flash_attn/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import subprocess
import sys

from pathlib import Path

REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
CUR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)))

def patch_fa3():
patches = ["hopper.patch"]
for patch_file in patches:
patch_file_path = os.path.join(CUR_DIR, patch_file)
submodule_path = str(REPO_PATH.joinpath("submodules", "flash-attention").absolute())
try:
subprocess.check_output(
[
"patch",
"-p1",
"--forward",
"-i",
patch_file_path,
"-r",
"/tmp/rej",
],
cwd=submodule_path,
)
except subprocess.SubprocessError as e:
output_str = str(e.output)
if "previously applied" in output_str:
return
else:
print(str(output_str))
sys.exit(1)

def install_fa3():
patch_fa3()
FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
env = os.environ.copy()
# nvcc will spawn cicc process and will cost ~1G memory
env["MAX_JOBS"] = "8"
env["NVCC_THREADS"] = "2"
cmd = [sys.executable, "setup.py", "install"]
subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()), env=env)

0 comments on commit f5ef4d6

Please sign in to comment.