From c9ee096b77f5fd6e0c1f65118d3c2ee72d57adf8 Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Thu, 21 Nov 2024 11:41:37 -0500 Subject: [PATCH] Format files --- tools/cuda_utils.py | 13 +++++++++---- tools/rocm_utils.py | 1 + tools/torch_utils.py | 7 +++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tools/cuda_utils.py b/tools/cuda_utils.py index daf51e39..7186a854 100644 --- a/tools/cuda_utils.py +++ b/tools/cuda_utils.py @@ -37,9 +37,9 @@ def prepare_cuda_env(cuda_version: str, dryrun=False): env["CUDA_HOME"] = cuda_path_str env["PATH"] = f"{cuda_path_str}/bin:{env['PATH']}" env["CMAKE_CUDA_COMPILER"] = str(cuda_path.joinpath("bin", "nvcc").resolve()) - env["LD_LIBRARY_PATH"] = ( - f"{cuda_path_str}/lib64:{cuda_path_str}/extras/CUPTI/lib64:{env['LD_LIBRARY_PATH']}" - ) + env[ + "LD_LIBRARY_PATH" + ] = f"{cuda_path_str}/lib64:{cuda_path_str}/extras/CUPTI/lib64:{env['LD_LIBRARY_PATH']}" if dryrun: print(f"CUDA_HOME is set to {env['CUDA_HOME']}") # step 2: test call to nvcc to confirm the version is correct @@ -88,6 +88,7 @@ def setup_cuda_softlink(cuda_version: str): def install_pytorch_nightly(cuda_version: str, env, dryrun=False): from .torch_utils import TORCH_NIGHTLY_PACKAGES + uninstall_torch_cmd = ["pip", "uninstall", "-y"] uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES) if dryrun: @@ -168,11 +169,15 @@ def install_torch_deps(cuda_version: str): install_torch_deps(cuda_version=args.cudaver) if args.install_torch_build_deps: from .torch_utils import install_torch_build_deps + install_torch_deps(cuda_version=args.cudaver) install_torch_build_deps() if args.install_torch_nightly: install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ) if args.check_torch_nightly_version: from .torch_utils import check_torch_nightly_version - assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command." + + assert ( + not args.install_torch_nightly + ), "Error: Can't run install torch nightly and check version in the same command." check_torch_nightly_version(args.force_date) diff --git a/tools/rocm_utils.py b/tools/rocm_utils.py index 18d705ff..0394af1a 100644 --- a/tools/rocm_utils.py +++ b/tools/rocm_utils.py @@ -34,6 +34,7 @@ def install_torch_deps(): def install_pytorch_nightly(rocm_version: str, env, dryrun=False): from .torch_utils import TORCH_NIGHTLY_PACKAGES + uninstall_torch_cmd = ["pip", "uninstall", "-y"] uninstall_torch_cmd.extend(TORCH_NIGHTLY_PACKAGES) if dryrun: diff --git a/tools/torch_utils.py b/tools/torch_utils.py index b60fbab0..d6e96ca1 100644 --- a/tools/torch_utils.py +++ b/tools/torch_utils.py @@ -2,9 +2,9 @@ CUDA/ROCM independent pytorch installation helpers. """ -import subprocess import importlib import re +import subprocess from pathlib import Path from typing import Optional @@ -18,9 +18,11 @@ def is_hip() -> bool: import torch + version = torch.__version__ return "rocm" in version + def install_torch_build_deps(): # Pin cmake version to stable # See: https://github.com/pytorch/builder/pull/1269 @@ -52,6 +54,7 @@ def install_torch_build_deps(): cmd = ["conda", "install", "-y", "-c", "conda-forge"] + conda_deps subprocess.check_call(cmd) + def get_torch_nightly_version(pkg_name: str): pkg = importlib.import_module(pkg_name) version = pkg.__version__ @@ -75,4 +78,4 @@ def check_torch_nightly_version(force_date: Optional[str] = None): force_date_str = f"User force date {force_date}" if force_date else "" print( f"Installed consistent torch nightly packages: {pkg_versions}. {force_date_str}" - ) \ No newline at end of file + )