Skip to content

Commit

Permalink
build small pip wheels for CUDA 11.8 (pytorch#114620)
Browse files Browse the repository at this point in the history
As discussed, we would like to start building all wheels using the CUDA PyPI dependencies.
Adding the "small wheel" workflow for CUDA 11.8 as it's already used for 12.1U1.

CC @malfet @atalman

Pull Request resolved: pytorch#114620
Approved by: https://github.com/atalman, https://github.com/malfet
  • Loading branch information
ptrblck authored and pytorchmergebot committed Nov 30, 2023
1 parent 2ab2e8e commit 386b9c2
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 24 deletions.
66 changes: 42 additions & 24 deletions .github/scripts/generate_binary_build_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,34 @@

CPU_AARCH64_ARCH = ["cpu-aarch64"]

PYTORCH_EXTRA_INSTALL_REQUIREMENTS = (
"nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950
"nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'"
)
PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
"11.8": (
"nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950
"nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu11==8.7.0.84; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu11==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64'"
),
"12.1": (
"nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950
"nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu12==8.9.2.26; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu12==2.19.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'"
),
}


def get_nccl_submodule_version() -> str:
Expand Down Expand Up @@ -65,15 +80,17 @@ def get_nccl_submodule_version() -> str:
return f"{d['NCCL_MAJOR']}.{d['NCCL_MINOR']}.{d['NCCL_PATCH']}"


def get_nccl_wheel_version() -> str:
def get_nccl_wheel_version(arch_version: str) -> str:
import re

requrements = map(str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS))
return [x for x in requrements if x.startswith("nvidia-nccl-cu")][0].split("==")[1]
requirements = map(
str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version])
)
return [x for x in requirements if x.startswith("nvidia-nccl-cu")][0].split("==")[1]


def validate_nccl_dep_consistency() -> None:
wheel_ver = get_nccl_wheel_version()
def validate_nccl_dep_consistency(arch_version: str) -> None:
wheel_ver = get_nccl_wheel_version(arch_version)
submodule_ver = get_nccl_submodule_version()
if wheel_ver != submodule_ver:
raise RuntimeError(
Expand Down Expand Up @@ -298,7 +315,7 @@ def generate_wheels_matrix(
)

# 12.1 linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install
if arch_version == "12.1" and os == "linux":
if arch_version in ["12.1", "11.8"] and os == "linux":
ret.append(
{
"python_version": python_version,
Expand All @@ -310,7 +327,7 @@ def generate_wheels_matrix(
"devtoolset": "",
"container_image": WHEEL_CONTAINER_IMAGES[arch_version],
"package_type": package_type,
"pytorch_extra_install_requirements": PYTORCH_EXTRA_INSTALL_REQUIREMENTS,
"pytorch_extra_install_requirements": PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version], # fmt: skip
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}".replace( # noqa: B950
".", "_"
),
Expand All @@ -333,12 +350,13 @@ def generate_wheels_matrix(
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}".replace(
".", "_"
),
"pytorch_extra_install_requirements": PYTORCH_EXTRA_INSTALL_REQUIREMENTS
if os != "linux"
else "",
"pytorch_extra_install_requirements":
PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.1"] # fmt: skip
if os != "linux" else "",
}
)
return ret


validate_nccl_dep_consistency()
validate_nccl_dep_consistency("12.1")
validate_nccl_dep_consistency("11.8")

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 386b9c2

Please sign in to comment.