Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[docker] Fix the nightly docker build #120

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ env:
jobs:
build-push-docker:
if: ${{ github.repository_owner == 'pytorch-labs' }}
# spec:
# 120G RAM
# 32 logical CPU cores
# 1T disk
runs-on: 32-core-ubuntu
environment: docker-s3-upload
steps:
Expand Down
3 changes: 3 additions & 0 deletions docker/tritonbench-nightly.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ RUN cd /workspace/tritonbench && \
# which is from NVIDIA driver
RUN sudo apt update && sudo apt-get install -y libnvidia-compute-550 patchelf patch

# Workaround: installing Ninja from setup.py hits "Failed to decode METADATA with UTF-8" error
RUN . ${SETUP_SCRIPT} && pip install ninja

# Install Tritonbench
RUN cd /workspace/tritonbench && \
bash .ci/tritonbench/install.sh
Expand Down
22 changes: 9 additions & 13 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,14 @@ def install_fa2(compile=False):
if compile:
# compile from source (slow)
FA2_PATH = REPO_PATH.joinpath("submodules", "flash-attention")
cmd = [sys.executable, "setup.py", "install"]
cmd = ["pip", "install", "-e", "."]
subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve()))
else:
# Install the pre-built binary
cmd = ["pip", "install", "flash-attn", "--no-build-isolation"]
subprocess.check_call(cmd)


def install_fa3():
FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
cmd = [sys.executable, "setup.py", "install"]
subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))


def install_liger():
# Liger-kernel has a conflict dependency `triton` with pytorch,
# so we need to install it without dependencies
Expand Down Expand Up @@ -119,32 +113,34 @@ def setup_hip(args: argparse.Namespace):
# checkout submodules
checkout_submodules(REPO_PATH)
# install submodules
if args.fa3 or args.all:
# we need to install fa3 above all other dependencies
logger.info("[tritonbench] installing fa3...")
from tools.flash_attn.install import install_fa3
install_fa3()
if args.fbgemm or args.all:
logger.info("[tritonbench] installing FBGEMM...")
install_fbgemm()
if args.fa2 or args.all:
logger.info("[tritonbench] installing fa2 from source...")
install_fa2(compile=True)
if args.fa3 or args.all:
logger.info("[tritonbench] installing fa3...")
install_fa3()
if args.colfax or args.all:
if args.colfax:
logger.info("[tritonbench] installing colfax cutlass-kernels...")
from tools.cutlass_kernels.install import install_colfax_cutlass

install_colfax_cutlass()
if args.jax or args.all:
logger.info("[tritonbench] installing jax...")
install_jax()
if args.tk or args.all:
if args.tk:
logger.info("[tritonbench] installing thunderkittens...")
from tools.tk.install import install_tk

install_tk()
if args.liger or args.all:
logger.info("[tritonbench] installing liger-kernels...")
install_liger()
if args.xformers or args.all:
if args.xformers:
logger.info("[tritonbench] installing xformers...")
from tools.xformers.install import install_xformers

Expand Down
2 changes: 1 addition & 1 deletion submodules/ThunderKittens
Submodule ThunderKittens updated 383 files
14 changes: 14 additions & 0 deletions tools/flash_attn/hopper.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
diff --git a/hopper/setup.py b/hopper/setup.py
index f9f3cfd..132ce07 100644
--- a/hopper/setup.py
+++ b/hopper/setup.py
@@ -78,7 +78,8 @@ def check_if_cuda_home_none(global_option: str) -> None:


def append_nvcc_threads(nvcc_extra_args):
- return nvcc_extra_args + ["--threads", "4"]
+ nvcc_threads = os.getenv("NVCC_THREADS") or "4"
+ return nvcc_extra_args + ["--threads", nvcc_threads]


cmdclass = {}
44 changes: 44 additions & 0 deletions tools/flash_attn/install.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import subprocess
import sys

from pathlib import Path

# Absolute path of this script; both constants below are derived from it.
_THIS_FILE = os.path.abspath(__file__)
# Three directory levels above this file (used below as the repository root).
REPO_PATH = Path(_THIS_FILE).parent.parent.parent
# Directory containing this script, where the patch files are looked up.
CUR_DIR = os.path.dirname(_THIS_FILE)

def patch_fa3():
    """Apply the local patch files to the flash-attention submodule.

    Every file named in ``patches`` is looked up in ``CUR_DIR`` and applied
    with ``patch -p1 --forward`` using ``submodules/flash-attention`` as the
    working directory. A patch whose output contains "previously applied" is
    skipped, so the function can be re-run on an already patched tree.

    Raises:
        SystemExit: with status 1 when ``patch`` fails for any other reason;
            the captured ``patch`` output is printed first.
    """
    patches = ["hopper.patch"]
    # The working directory is the same for every patch, so resolve it once.
    submodule_path = str(
        REPO_PATH.joinpath("submodules", "flash-attention").absolute()
    )
    for patch_file in patches:
        patch_file_path = os.path.join(CUR_DIR, patch_file)
        try:
            subprocess.check_output(
                [
                    "patch",
                    "-p1",
                    "--forward",
                    "-i",
                    patch_file_path,
                    "-r",
                    "/tmp/rej",
                ],
                cwd=submodule_path,
            )
        except subprocess.CalledProcessError as e:
            # check_output captured raw bytes; decode them so the substring
            # test and the printed diagnostics work on readable text.
            output_str = (e.output or b"").decode("utf-8", errors="replace")
            if "previously applied" in output_str:
                # This patch is already in place; keep going so that any
                # remaining patches in the list are still applied.
                continue
            print(output_str)
            sys.exit(1)

def install_fa3():
    """Patch the flash-attention submodule, then pip-install its ``hopper`` package.

    Runs ``patch_fa3()`` first, then executes ``pip install -e .`` inside
    ``submodules/flash-attention/hopper`` with a copy of the current
    environment in which ``MAX_JOBS`` and ``NVCC_THREADS`` are overridden.

    Raises:
        subprocess.CalledProcessError: if the ``pip`` command exits non-zero.
    """
    patch_fa3()
    hopper_dir = REPO_PATH.joinpath("submodules", "flash-attention", "hopper").resolve()
    # limit nvcc memory usage on the CI machine
    build_env = {**os.environ, "MAX_JOBS": "8", "NVCC_THREADS": "1"}
    subprocess.check_call(
        ["pip", "install", "-e", "."],
        cwd=str(hopper_dir),
        env=build_env,
    )
Loading