Skip to content

Commit

Permalink
Format files
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 21, 2024
1 parent 3c5fd4a commit 59ffcca
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
13 changes: 9 additions & 4 deletions tools/cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def prepare_cuda_env(cuda_version: str, dryrun=False):
env["CUDA_HOME"] = cuda_path_str
env["PATH"] = f"{cuda_path_str}/bin:{env['PATH']}"
env["CMAKE_CUDA_COMPILER"] = str(cuda_path.joinpath("bin", "nvcc").resolve())
env["LD_LIBRARY_PATH"] = (
f"{cuda_path_str}/lib64:{cuda_path_str}/extras/CUPTI/lib64:{env['LD_LIBRARY_PATH']}"
)
env[
"LD_LIBRARY_PATH"
] = f"{cuda_path_str}/lib64:{cuda_path_str}/extras/CUPTI/lib64:{env['LD_LIBRARY_PATH']}"
if dryrun:
print(f"CUDA_HOME is set to {env['CUDA_HOME']}")
# step 2: test call to nvcc to confirm the version is correct
Expand Down Expand Up @@ -88,6 +88,7 @@ def setup_cuda_softlink(cuda_version: str):

def install_pytorch_nightly(cuda_version: str, env, dryrun=False):
from .torch_utils import TORCH_NIGHTLY_PACKAGES

uninstall_torch_cmd = ["pip", "uninstall", "-y"]
uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
if dryrun:
Expand Down Expand Up @@ -168,11 +169,15 @@ def install_torch_deps(cuda_version: str):
install_torch_deps(cuda_version=args.cudaver)
if args.install_torch_build_deps:
from .torch_utils import install_torch_build_deps

install_torch_deps(cuda_version=args.cudaver)
install_torch_build_deps()
if args.install_torch_nightly:
install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
if args.check_torch_nightly_version:
from .torch_utils import check_torch_nightly_version
assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."

assert (
not args.install_torch_nightly
), "Error: Can't run install torch nightly and check version in the same command."
check_torch_nightly_version(args.force_date)
1 change: 1 addition & 0 deletions tools/rocm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def install_torch_deps():

def install_pytorch_nightly(rocm_version: str, env, dryrun=False):
from .torch_utils import TORCH_NIGHTLY_PACKAGES

uninstall_torch_cmd = ["pip", "uninstall", "-y"]
uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES)
if dryrun:
Expand Down
7 changes: 5 additions & 2 deletions tools/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
CUDA/ROCM independent pytorch installation helpers.
"""

import subprocess
import importlib
import re
import subprocess
from pathlib import Path

from typing import Optional
Expand All @@ -18,9 +18,11 @@

def is_hip() -> bool:
import torch

version = torch.__version__
return "rocm" in version


def install_torch_build_deps():
# Pin cmake version to stable
# See: https://github.com/pytorch/builder/pull/1269
Expand Down Expand Up @@ -52,6 +54,7 @@ def install_torch_build_deps():
cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps
subprocess.check_call(cmd)


def get_torch_nightly_version(pkg_name: str):
pkg = importlib.import_module(pkg_name)
version = pkg.__version__
Expand All @@ -75,4 +78,4 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
force_date_str = f"User force date {force_date}" if force_date else ""
print(
f"Installed consistent torch nightly packages: {pkg_versions}. {force_date_str}"
)
)

0 comments on commit 59ffcca

Please sign in to comment.