Skip to content

Commit

Permalink
Reduce RAM usage in CI environment
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Dec 9, 2024
1 parent 95add4d commit 3878064
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .ci/tritonbench/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ tritonbench_dir=$(dirname "$(readlink -f "$0")")/../..
cd ${tritonbench_dir}

# Install Tritonbench and all its customized packages
python install.py --all
python install.py --all --ci
2 changes: 1 addition & 1 deletion .ci/tritonbench/test-install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ tritonbench_dir=$(dirname "$(readlink -f "$0")")/../..
cd ${tritonbench_dir}

# Install Tritonbench and all its customized packages
python install.py --all --test
python install.py --all --test --ci
4 changes: 2 additions & 2 deletions .github/workflows/_linux-test-h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ jobs:
if: steps.submodules_changes.outputs.fa == 'true'
run: |
. "${SETUP_SCRIPT}"
python install.py --fa2
python install.py --fa3
python install.py --fa2 --ci
python install.py --fa3 --ci
- name: Reinstall xformers (optional)
if: steps.submodules_changes.outputs.xformers == 'true'
run: |
Expand Down
21 changes: 15 additions & 6 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,30 @@ def test_fbgemm():
print("OK")


def install_fa2(compile=False):
def install_fa2(compile=False, ci=False):
if compile:
# compile from source (slow)
env = os.environ.copy()
# limit max jobs to save memory in CI
if ci:
env["MAX_JOBS"] = 4
FA2_PATH = REPO_PATH.joinpath("submodules", "flash-attention")
cmd = [sys.executable, "setup.py", "install"]
subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve()))
subprocess.check_call(cmd, cwd=str(FA2_PATH.resolve()), env=env)
else:
# Install the pre-built binary
cmd = ["pip", "install", "flash-attn", "--no-build-isolation"]
subprocess.check_call(cmd)


def install_fa3():
def install_fa3(ci=False):
FA3_PATH = REPO_PATH.joinpath("submodules", "flash-attention", "hopper")
env = os.environ.copy()
# limit max jobs to save memory in CI
if ci:
env["MAX_JOBS"] = 4
cmd = [sys.executable, "setup.py", "install"]
subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()))
subprocess.check_call(cmd, cwd=str(FA3_PATH.resolve()), env=env)


def install_liger():
Expand Down Expand Up @@ -109,6 +117,7 @@ def setup_hip(args: argparse.Namespace):
"--all", action="store_true", help="Install all custom kernel repos"
)
parser.add_argument("--test", action="store_true", help="Run tests")
parser.add_argument("--ci", action="store_true", help="Indicate running in CI environment.")
args = parser.parse_args()

if args.all and is_hip():
Expand All @@ -124,10 +133,10 @@ def setup_hip(args: argparse.Namespace):
install_fbgemm()
if args.fa2 or args.all:
logger.info("[tritonbench] installing fa2 from source...")
install_fa2(compile=True)
install_fa2(compile=True, ci=args.ci)
if args.fa3 or args.all:
logger.info("[tritonbench] installing fa3...")
install_fa3()
install_fa3(ci=args.ci)
if args.colfax or args.all:
logger.info("[tritonbench] installing colfax cutlass-kernels...")
from tools.cutlass_kernels.install import install_colfax_cutlass
Expand Down

0 comments on commit 3878064

Please sign in to comment.