Skip to content

Commit

Permalink
Support v6e (apple#879)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ethanlm authored Dec 10, 2024
1 parent 0f03612 commit 73625c9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion axlearn/cloud/gcp/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def infer_tpu_workers(tpu_type: str) -> int:
tpu_version, tpu_cores = match.groups()
if tpu_version in {"v3", "v4", "v5p"}:
return int(tpu_cores) // 8
if tpu_version in {"v5litepod"}:
if tpu_version in {"v5litepod", "v6e"}:
return int(tpu_cores) // 4
except Exception as e: # pylint: disable=broad-except
logging.error("Failed to parse tpu_type %s: %s", tpu_type, e)
Expand Down
2 changes: 1 addition & 1 deletion axlearn/common/compiler_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,4 @@ def infer_xsc_compiler_options(
return options


_TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p")
_TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p", "v6e")

0 comments on commit 73625c9

Please sign in to comment.