diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 8a1256c5c48654..dc4dffc8b700c9 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -bcad9dabe15021c53b6a88296e9d7a210044f108 +e28a256d71f3cf2bcc7b69d6bda73a9b855e385e diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 7ec1d6db408777..ccbccc3dc62631 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -2.1.0 +2.2.0 diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 2dcd6a01e6587b..693d6892ff5921 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -10,6 +10,9 @@ SCRIPT_DIR = Path(__file__).parent REPO_DIR = SCRIPT_DIR.parent.parent +# TODO: Remove me once Triton version is again in sync for vanilla and ROCm +ROCM_TRITION_VERSION = "2.1.0" + def read_triton_pin(rocm_hash: bool = False) -> str: triton_file = "triton.txt" if not rocm_hash else "triton-rocm.txt" @@ -29,25 +32,37 @@ def check_and_replace(inp: str, src: str, dst: str) -> str: return inp.replace(src, dst) -def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None: +def patch_setup_py( + path: Path, + *, + version: str, + name: str = "triton", + expected_version: Optional[str] = None, +) -> None: with open(path) as f: orig = f.read() # Replace name orig = check_and_replace(orig, 'name="triton",', f'name="{name}",') # Replace version + if not expected_version: + expected_version = read_triton_version() orig = check_and_replace( - orig, f'version="{read_triton_version()}",', f'version="{version}",' + orig, f'version="{expected_version}",', f'version="{version}",' ) with open(path, "w") as f: f.write(orig) -def patch_init_py(path: Path, *, version: str) -> None: +def patch_init_py( + path: Path, *, version: str, expected_version: Optional[str] = None +) -> None: + if not expected_version: + expected_version = read_triton_version() with open(path) as f: orig = f.read() # Replace version orig = check_and_replace( - orig, f"__version__ = '{read_triton_version()}'", f'__version__ = "{version}"' + orig, f"__version__ = '{expected_version}'", f'__version__ = "{version}"' ) with open(path, "w") as f: f.write(orig) @@ -140,6 +155,7 @@ def build_triton( patch_init_py( triton_pythondir / "triton" / "__init__.py", version=f"{version}", + expected_version=ROCM_TRITION_VERSION if build_rocm else None, ) if build_rocm: @@ -148,6 +164,7 @@ def build_triton( triton_pythondir / "setup.py", name=triton_pkg_name, version=f"{version}", + expected_version=ROCM_TRITION_VERSION, ) check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True) print("ROCm libraries setup for triton installation...")