Skip to content

Commit

Permalink
Update Trition pin (pytorch#115743)
Browse files Browse the repository at this point in the history
To include a cherry-pick of triton-lang/triton#2771 that should fix  cuda-11.8 runtime issues

Also, tweak build wheel script to update both ROCm and vanilla Trition builds version to 2.2 (even though on trunk it should probably be 3.3 already)

TODO: Remove `ROCM_TRITION_VERSION` once both trunk and ROCM version are in sync again

Pull Request resolved: pytorch#115743
Approved by: https://github.com/davidberard98
  • Loading branch information
malfet authored and pytorchmergebot committed Dec 14, 2023
1 parent 87547a2 commit 28e37d4
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
bcad9dabe15021c53b6a88296e9d7a210044f108
e28a256d71f3cf2bcc7b69d6bda73a9b855e385e
2 changes: 1 addition & 1 deletion .ci/docker/triton_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.1.0
2.2.0
25 changes: 21 additions & 4 deletions .github/scripts/build_triton_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
SCRIPT_DIR = Path(__file__).parent
REPO_DIR = SCRIPT_DIR.parent.parent

# TODO: Remove me once Triton version is again in sync for vanilla and ROCm
ROCM_TRITION_VERSION = "2.1.0"


def read_triton_pin(rocm_hash: bool = False) -> str:
triton_file = "triton.txt" if not rocm_hash else "triton-rocm.txt"
Expand All @@ -29,25 +32,37 @@ def check_and_replace(inp: str, src: str, dst: str) -> str:
return inp.replace(src, dst)


def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None:
def patch_setup_py(
path: Path,
*,
version: str,
name: str = "triton",
expected_version: Optional[str] = None,
) -> None:
with open(path) as f:
orig = f.read()
# Replace name
orig = check_and_replace(orig, 'name="triton",', f'name="{name}",')
# Replace version
if not expected_version:
expected_version = read_triton_version()
orig = check_and_replace(
orig, f'version="{read_triton_version()}",', f'version="{version}",'
orig, f'version="{expected_version}",', f'version="{version}",'
)
with open(path, "w") as f:
f.write(orig)


def patch_init_py(path: Path, *, version: str) -> None:
def patch_init_py(
path: Path, *, version: str, expected_version: Optional[str] = None
) -> None:
if not expected_version:
expected_version = read_triton_version()
with open(path) as f:
orig = f.read()
# Replace version
orig = check_and_replace(
orig, f"__version__ = '{read_triton_version()}'", f'__version__ = "{version}"'
orig, f"__version__ = '{expected_version}'", f'__version__ = "{version}"'
)
with open(path, "w") as f:
f.write(orig)
Expand Down Expand Up @@ -140,6 +155,7 @@ def build_triton(
patch_init_py(
triton_pythondir / "triton" / "__init__.py",
version=f"{version}",
expected_version=ROCM_TRITION_VERSION if build_rocm else None,
)

if build_rocm:
Expand All @@ -148,6 +164,7 @@ def build_triton(
triton_pythondir / "setup.py",
name=triton_pkg_name,
version=f"{version}",
expected_version=ROCM_TRITION_VERSION,
)
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
print("ROCm libraries setup for triton installation...")
Expand Down

0 comments on commit 28e37d4

Please sign in to comment.