From 28e37d4f3bc40430192b8e4e11c991ffb3784eb9 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 14 Dec 2023 18:54:24 +0000 Subject: [PATCH] Update Trition pin (#115743) To include a cherry-pick of https://github.com/openai/triton/pull/2771 that should fix cuda-11.8 runtime issues Also, tweak build wheel script to update both ROCm and vanilla Trition builds version to 2.2 (even though on trunk it should probably be 3.3 already) TODO: Remove `ROCM_TRITION_VERSION` once both trunk and ROCM version are in sync again Pull Request resolved: https://github.com/pytorch/pytorch/pull/115743 Approved by: https://github.com/davidberard98 --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- .github/scripts/build_triton_wheel.py | 25 +++++++++++++++++++++---- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 8a1256c5c4865..dc4dffc8b700c 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -bcad9dabe15021c53b6a88296e9d7a210044f108 +e28a256d71f3cf2bcc7b69d6bda73a9b855e385e diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 7ec1d6db40877..ccbccc3dc6263 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -2.1.0 +2.2.0 diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 2dcd6a01e6587..693d6892ff592 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -10,6 +10,9 @@ SCRIPT_DIR = Path(__file__).parent REPO_DIR = SCRIPT_DIR.parent.parent +# TODO: Remove me once Triton version is again in sync for vanilla and ROCm +ROCM_TRITION_VERSION = "2.1.0" + def read_triton_pin(rocm_hash: bool = False) -> str: triton_file = "triton.txt" if not rocm_hash else "triton-rocm.txt" @@ -29,25 +32,37 @@ def check_and_replace(inp: str, src: str, dst: str) -> str: return inp.replace(src, dst) -def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None: +def patch_setup_py( + path: Path, + *, + version: str, + name: str = "triton", + expected_version: Optional[str] = None, +) -> None: with open(path) as f: orig = f.read() # Replace name orig = check_and_replace(orig, 'name="triton",', f'name="{name}",') # Replace version + if not expected_version: + expected_version = read_triton_version() orig = check_and_replace( - orig, f'version="{read_triton_version()}",', f'version="{version}",' + orig, f'version="{expected_version}",', f'version="{version}",' ) with open(path, "w") as f: f.write(orig) -def patch_init_py(path: Path, *, version: str) -> None: +def patch_init_py( + path: Path, *, version: str, expected_version: Optional[str] = None +) -> None: + if not expected_version: + expected_version = read_triton_version() with open(path) as f: orig = f.read() # Replace version orig = check_and_replace( - orig, f"__version__ = '{read_triton_version()}'", f'__version__ = "{version}"' + orig, f"__version__ = '{expected_version}'", f'__version__ = "{version}"' ) with open(path, "w") as f: f.write(orig) @@ -140,6 +155,7 @@ def build_triton( patch_init_py( triton_pythondir / "triton" / "__init__.py", version=f"{version}", + expected_version=ROCM_TRITION_VERSION if build_rocm else None, ) if build_rocm: @@ -148,6 +164,7 @@ def build_triton( triton_pythondir / "setup.py", name=triton_pkg_name, version=f"{version}", + expected_version=ROCM_TRITION_VERSION, ) check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True) print("ROCm libraries setup for triton installation...")