From 5cc39763e97a7afa7671f1ba30a8794182e98db5 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Fri, 20 Dec 2024 09:39:58 -0800 Subject: [PATCH] Fix the nightly docker build (#120) Summary: After the submodule update, the FA3 CUTLASS kernels need much more memory to build and exhaust the memory of the CI machine. Make the following changes: 1. Build FA3 with MAX_JOBS=8 and NVCC_THREADS=1. 2. Disable the xformers build, as it also requires FA3; we will try to re-enable it later. 3. Also disable the colfax and TK builds for now; they will be fixed later. Pull Request resolved: https://github.com/pytorch-labs/tritonbench/pull/120 Test Plan: CI Reviewed By: adamomainz Differential Revision: D67524390 Pulled By: xuzhao9 fbshipit-source-id: fc889fae5d51d7d2d974d2996e33e3f31d8db98e --- docker/tritonbench-nightly.dockerfile | 3 ++ install.py | 23 ++++++------- submodules/ThunderKittens | 2 +- tools/flash_attn/hopper.patch | 14 ++++++++ tools/flash_attn/install.py | 48 +++++++++++++++++++++++++++ 5 files changed, 76 insertions(+), 14 deletions(-) create mode 100644 tools/flash_attn/hopper.patch diff --git a/docker/tritonbench-nightly.dockerfile b/docker/tritonbench-nightly.dockerfile index 71d2c5d9..f5713c72 100644 --- a/docker/tritonbench-nightly.dockerfile +++ b/docker/tritonbench-nightly.dockerfile @@ -49,6 +49,9 @@ RUN cd /workspace/tritonbench && \ # which is from NVIDIA driver RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 patchelf patch +# Workaround: installing Ninja from setup.py hits "Failed to decode METADATA with UTF-8" error +RUN . 
${SETUP_SCRIPT} && pip install ninja + # Install Tritonbench RUN cd /workspace/tritonbench && \ bash .ci/tritonbench/install.sh diff --git a/install.py b/install.py index e0e227a4..f185e2f6 100644 --- a/install.py +++ b/install.py @@ -57,7 +57,7 @@ def install_fa2(compile=False): if compile: # compile from source (slow) FA2_PATH = REPO_PATH.joinpath("submodules", "flash-attention") - cmd = [sys.executable, "setup.py", "install"] + cmd = ["pip", "install", "-e", "."] subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve())) else: # Install the pre-built binary @@ -65,12 +65,6 @@ def install_fa2(compile=False): subprocess.check_call(cmd) -def install_fa3(): - FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper") - cmd = [sys.executable, "setup.py", "install"] - subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve())) - - def install_liger(): # Liger-kernel has a conflict dependency `triton` with pytorch, # so we need to install it without dependencies @@ -119,16 +113,19 @@ def setup_hip(args: argparse.Namespace): # checkout submodules checkout_submodules(REPO_PATH) # install submodules + if args.fa3 or args.all: + # we need to install fa3 above all other dependencies + logger.info("[tritonbench] installing fa3...") + from tools.flash_attn.install import install_fa3 + + install_fa3() if args.fbgemm or args.all: logger.info("[tritonbench] installing FBGEMM...") install_fbgemm() if args.fa2 or args.all: logger.info("[tritonbench] installing fa2 from source...") install_fa2(compile=True) - if args.fa3 or args.all: - logger.info("[tritonbench] installing fa3...") - install_fa3() - if args.colfax or args.all: + if args.colfax: logger.info("[tritonbench] installing colfax cutlass-kernels...") from tools.cutlass_kernels.install import install_colfax_cutlass @@ -136,7 +133,7 @@ def setup_hip(args: argparse.Namespace): if args.jax or args.all: logger.info("[tritonbench] installing jax...") install_jax() - if args.tk or args.all: + if args.tk: 
logger.info("[tritonbench] installing thunderkittens...") from tools.tk.install import install_tk @@ -144,7 +141,7 @@ def setup_hip(args: argparse.Namespace): if args.liger or args.all: logger.info("[tritonbench] installing liger-kernels...") install_liger() - if args.xformers or args.all: + if args.xformers: logger.info("[tritonbench] installing xformers...") from tools.xformers.install import install_xformers diff --git a/submodules/ThunderKittens b/submodules/ThunderKittens index 5d7107a1..cdfce886 160000 --- a/submodules/ThunderKittens +++ b/submodules/ThunderKittens @@ -1 +1 @@ -Subproject commit 5d7107a13b016811da38e781d7fc4f74140b2d03 +Subproject commit cdfce886b2660a0345a20bc5c7d52efd0db95fc0 diff --git a/tools/flash_attn/hopper.patch b/tools/flash_attn/hopper.patch new file mode 100644 index 00000000..b34a4025 --- /dev/null +++ b/tools/flash_attn/hopper.patch @@ -0,0 +1,14 @@ +diff --git a/hopper/setup.py b/hopper/setup.py +index f9f3cfd..132ce07 100644 +--- a/hopper/setup.py ++++ b/hopper/setup.py +@@ -78,7 +78,8 @@ def check_if_cuda_home_none(global_option: str) -> None: + + + def append_nvcc_threads(nvcc_extra_args): +- return nvcc_extra_args + ["--threads", "4"] ++ nvcc_threads = os.getenv("NVCC_THREADS") or "4" ++ return nvcc_extra_args + ["--threads", nvcc_threads] + + + cmdclass = {} diff --git a/tools/flash_attn/install.py b/tools/flash_attn/install.py index e69de29b..8b026af8 100644 --- a/tools/flash_attn/install.py +++ b/tools/flash_attn/install.py @@ -0,0 +1,48 @@ +import os +import subprocess +import sys + +from pathlib import Path + +REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent +CUR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__))) + + +def patch_fa3(): + patches = ["hopper.patch"] + for patch_file in patches: + patch_file_path = os.path.join(CUR_DIR, patch_file) + submodule_path = str( + REPO_PATH.joinpath("submodules", "flash-attention").absolute() + ) + try: + subprocess.check_output( + [ + "patch", + "-p1", 
+ "--forward", + "-i", + patch_file_path, + "-r", + "/tmp/rej", + ], + cwd=submodule_path, + ) + except subprocess.CalledProcessError as e: + output_str = str(e.output) + if "previously applied" in output_str: + continue + else: + print(output_str) + sys.exit(1) + + +def install_fa3(): + patch_fa3() + FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper") + env = os.environ.copy() + # cap parallel build jobs and nvcc threads so the FA3 build fits in CI machine memory + env["MAX_JOBS"] = "8" + env["NVCC_THREADS"] = "1" + cmd = ["pip", "install", "-e", "."] + subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()), env=env)