diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index 25b6bc5f..718634f5 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -20,6 +20,10 @@ env:
 jobs:
   build-push-docker:
     if: ${{ github.repository_owner == 'pytorch-labs' }}
+    # spec:
+    # 120G RAM
+    # 32 logical CPU cores
+    # 1T disk
     runs-on: 32-core-ubuntu
     environment: docker-s3-upload
     steps:
diff --git a/docker/tritonbench-nightly.dockerfile b/docker/tritonbench-nightly.dockerfile
index 71d2c5d9..f5713c72 100644
--- a/docker/tritonbench-nightly.dockerfile
+++ b/docker/tritonbench-nightly.dockerfile
@@ -49,6 +49,9 @@ RUN cd /workspace/tritonbench && \
 # which is from NVIDIA driver
 RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 patchelf patch
 
+# Workaround: installing Ninja from setup.py hits "Failed to decode METADATA with UTF-8" error
+RUN . ${SETUP_SCRIPT} && pip install ninja
+
 # Install Tritonbench
 RUN cd /workspace/tritonbench && \
     bash .ci/tritonbench/install.sh
diff --git a/install.py b/install.py
index e0e227a4..e99aa1c0 100644
--- a/install.py
+++ b/install.py
@@ -57,7 +57,7 @@ def install_fa2(compile=False):
     if compile:
         # compile from source (slow)
         FA2_PATH = REPO_PATH.joinpath("submodules", "flash-attention")
-        cmd = [sys.executable, "setup.py", "install"]
+        cmd = [sys.executable, "-m", "pip", "install", "-e", "."]
         subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve()))
     else:
         # Install the pre-built binary
@@ -65,12 +65,6 @@ def install_fa2(compile=False):
         subprocess.check_call(cmd)
 
 
-def install_fa3():
-    FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
-    cmd = [sys.executable, "setup.py", "install"]
-    subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))
-
-
 def install_liger():
     # Liger-kernel has a conflict dependency `triton` with pytorch,
     # so we need to install it without dependencies
@@ -119,16 +113,18 @@ def setup_hip(args: argparse.Namespace):
     # checkout submodules
     checkout_submodules(REPO_PATH)
     # install submodules
+    if args.fa3 or args.all:
+        # we need to install fa3 above all other dependencies
+        logger.info("[tritonbench] installing fa3...")
+        from tools.flash_attn.install import install_fa3
+        install_fa3()
     if args.fbgemm or args.all:
         logger.info("[tritonbench] installing FBGEMM...")
         install_fbgemm()
     if args.fa2 or args.all:
         logger.info("[tritonbench] installing fa2 from source...")
         install_fa2(compile=True)
-    if args.fa3 or args.all:
-        logger.info("[tritonbench] installing fa3...")
-        install_fa3()
-    if args.colfax or args.all:
+    if args.colfax:
         logger.info("[tritonbench] installing colfax cutlass-kernels...")
         from tools.cutlass_kernels.install import install_colfax_cutlass
 
@@ -136,7 +132,7 @@ def setup_hip(args: argparse.Namespace):
     if args.jax or args.all:
         logger.info("[tritonbench] installing jax...")
         install_jax()
-    if args.tk or args.all:
+    if args.tk:
         logger.info("[tritonbench] installing thunderkittens...")
         from tools.tk.install import install_tk
 
@@ -144,7 +140,7 @@ def setup_hip(args: argparse.Namespace):
     if args.liger or args.all:
         logger.info("[tritonbench] installing liger-kernels...")
         install_liger()
-    if args.xformers or args.all:
+    if args.xformers:
         logger.info("[tritonbench] installing xformers...")
         from tools.xformers.install import install_xformers
 
diff --git a/submodules/ThunderKittens b/submodules/ThunderKittens
index 5d7107a1..cdfce886 160000
--- a/submodules/ThunderKittens
+++ b/submodules/ThunderKittens
@@ -1 +1 @@
-Subproject commit 5d7107a13b016811da38e781d7fc4f74140b2d03
+Subproject commit cdfce886b2660a0345a20bc5c7d52efd0db95fc0
diff --git a/tools/flash_attn/hopper.patch b/tools/flash_attn/hopper.patch
new file mode 100644
index 00000000..b34a4025
--- /dev/null
+++ b/tools/flash_attn/hopper.patch
@@ -0,0 +1,14 @@
+diff --git a/hopper/setup.py b/hopper/setup.py
+index f9f3cfd..132ce07 100644
+--- a/hopper/setup.py
++++ b/hopper/setup.py
+@@ -78,7 +78,8 @@ def check_if_cuda_home_none(global_option: str) -> None:
+ 
+ 
+ def append_nvcc_threads(nvcc_extra_args):
+-    return nvcc_extra_args + ["--threads", "4"]
++    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
++    return nvcc_extra_args + ["--threads", nvcc_threads]
+ 
+ 
+ cmdclass = {}
diff --git a/tools/flash_attn/install.py b/tools/flash_attn/install.py
index e69de29b..b764c0d2 100644
--- a/tools/flash_attn/install.py
+++ b/tools/flash_attn/install.py
@@ -0,0 +1,53 @@
+import os
+import subprocess
+import sys
+
+from pathlib import Path
+
+REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
+CUR_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)))
+
+
+def patch_fa3():
+    """Apply the local patches to the flash-attention submodule.
+
+    Patches that `patch` reports as previously applied are skipped; any
+    other failure prints the `patch` output and exits with status 1.
+    """
+    patches = ["hopper.patch"]
+    submodule_path = str(REPO_PATH.joinpath("submodules", "flash-attention").absolute())
+    for patch_file in patches:
+        patch_file_path = os.path.join(CUR_DIR, patch_file)
+        try:
+            subprocess.check_output(
+                [
+                    "patch",
+                    "-p1",
+                    "--forward",
+                    "-i",
+                    patch_file_path,
+                    "-r",
+                    "/tmp/rej",
+                ],
+                cwd=submodule_path,
+            )
+        except subprocess.CalledProcessError as e:
+            output_str = e.output.decode(errors="replace")
+            if "previously applied" in output_str:
+                # already patched: keep going so later patches still apply
+                continue
+            print(output_str)
+            sys.exit(1)
+
+
+def install_fa3():
+    """Patch the flash-attention submodule and install its `hopper` package."""
+    patch_fa3()
+    FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
+    env = os.environ.copy()
+    # limit nvcc memory usage on the CI machine
+    env["MAX_JOBS"] = "8"
+    env["NVCC_THREADS"] = "1"
+    # use the running interpreter's pip so the package lands in this environment
+    cmd = [sys.executable, "-m", "pip", "install", "-e", "."]
+    subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()), env=env)