Skip to content

Commit

Permalink
feat: Implement hermetic Python version matching system Python version
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 637137789
  • Loading branch information
vam-google authored and copybara-github committed May 25, 2024
1 parent 7dac65a commit 1bb87a7
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 108 deletions.
4 changes: 3 additions & 1 deletion third_party/py/python_init_repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ load("//third_party/py:python_repo.bzl", "python_repository")
def python_init_repositories(
requirements = {},
local_wheel_workspaces = [],
local_wheel_dist_folder = None):
local_wheel_dist_folder = None,
default_python_version = None):
python_repository(
name = "python_version_repo",
requirements_versions = requirements.keys(),
requirements_locks = requirements.values(),
local_wheel_workspaces = local_wheel_workspaces,
local_wheel_dist_folder = local_wheel_dist_folder,
default_python_version = default_python_version,
)
py_repositories()
148 changes: 95 additions & 53 deletions third_party/py/python_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,13 @@ To set wheel name, add "--repo_env=WHEEL_NAME=tensorflow_cpu"
"""

DEFAULT_VERSION = "3.11"
WARNING = """
HERMETIC_PYTHON_VERSION variable was not set correctly, using default version.
Python {} will be used.
To select Python version, either set HERMETIC_PYTHON_VERSION env variable in
your shell:
export HERMETIC_PYTHON_VERSION=3.12
OR pass it as an argument to bazel command directly or inside your .bazelrc
file:
--repo_env=HERMETIC_PYTHON_VERSION=3.12
""".format(DEFAULT_VERSION)

content = """TF_PYTHON_VERSION = "{version}"
HERMETIC_PYTHON_VERSION = "{version}"
WHEEL_NAME = "{wheel_name}"
WHEEL_COLLAB = "{wheel_collab}"
REQUIREMENTS = "{requirements}"
REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}"
"""

def _python_repository_impl(ctx):
ctx.file("BUILD", "")
version_legacy = ctx.os.environ.get("TF_PYTHON_VERSION", "")
version = ctx.os.environ.get("HERMETIC_PYTHON_VERSION", "")
if not version:
version = version_legacy
else:
version_legacy = version
version = _get_python_version(ctx)

ctx.file("BUILD", "")
wheel_name = ctx.os.environ.get("WHEEL_NAME", "tensorflow")
wheel_collab = ctx.os.environ.get("WHEEL_COLLAB", False)
if not version:
print(WARNING) # buildifier: disable=print
version = DEFAULT_VERSION
else:
print("Using hermetic Python %s" % version) # buildifier: disable=print

requirements = None
for i in range(0, len(ctx.attr.requirements_locks)):
Expand All @@ -62,11 +34,14 @@ Please check python_init_repositories() in your WORKSPACE file.
))

requirements_with_local_wheels = str(requirements)
if ctx.attr.local_wheel_workspaces:

local_wheels_dir = ctx.os.environ.get("LOCAL_WHEELS_DIR", "")
if ctx.attr.local_wheel_workspaces or local_wheels_dir:
local_wheel_requirements = _get_injected_local_wheels(
ctx,
version,
ctx.attr.local_wheel_workspaces,
local_wheels_dir,
)
requirements_content = [ctx.read(requirements)] + local_wheel_requirements
merged_requirements_content = "\n".join(requirements_content)
Expand All @@ -82,7 +57,14 @@ Please check python_init_repositories() in your WORKSPACE file.

ctx.file(
"py_version.bzl",
content.format(
"""
TF_PYTHON_VERSION = "{version}"
HERMETIC_PYTHON_VERSION = "{version}"
WHEEL_NAME = "{wheel_name}"
WHEEL_COLLAB = "{wheel_collab}"
REQUIREMENTS = "{requirements}"
REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}"
""".format(
version = version,
wheel_name = wheel_name,
wheel_collab = wheel_collab,
Expand All @@ -91,32 +73,70 @@ Please check python_init_repositories() in your WORKSPACE file.
),
)

def _get_injected_local_wheels(ctx, py_version, local_wheel_workspaces):
local_wheel_requirements = []
py_ver_marker = "-cp%s-" % py_version.replace(".", "")
wheels = {}
def _get_python_version(ctx):
print_warning = False

for local_wheel_workspace in local_wheel_workspaces:
local_wheel_workspace_path = ctx.path(local_wheel_workspace)
dist_folder = ctx.attr.local_wheel_dist_folder
dist_wheels = local_wheel_workspace_path.dirname.get_child(dist_folder).readdir()
version = ctx.os.environ.get("HERMETIC_PYTHON_VERSION", "")
if not version:
version = ctx.os.environ.get("TF_PYTHON_VERSION", "")
if not version:
print_warning = True
if ctx.attr.default_python_version == "system":
python_version_result = ctx.execute(["python3", "--version"])
if python_version_result.return_code == 0:
version = python_version_result.stdout
else:
fail("""
Cannot match hermetic Python version to system Python version.
System Python was not found.""")
else:
version = ctx.attr.default_python_version

for wheel in dist_wheels:
bn = wheel.basename
if not bn.endswith(".whl") or bn.find(py_ver_marker) < 0:
continue
version = _parse_python_version(version)

name_components = bn.split("-")
package_name = name_components[0]
for name_component in name_components[1:]:
if name_component[0].isdigit():
break
package_name += "-" + name_component
if print_warning:
print("""
HERMETIC_PYTHON_VERSION variable was not set correctly, using default version.
Python {} will be used.
To select Python version, either set HERMETIC_PYTHON_VERSION env variable in
your shell:
export HERMETIC_PYTHON_VERSION=3.12
OR pass it as an argument to bazel command directly or inside your .bazelrc
file:
--repo_env=HERMETIC_PYTHON_VERSION=3.12
""".format(version)) # buildifier: disable=print

print("Using hermetic Python %s" % version) # buildifier: disable=print
return version

latest_wheel = wheels.get(package_name, None)
def _parse_python_version(version_str):
if version_str.startswith("Python "):
py_ver_chunks = version_str[7:].split(".")
return "%s.%s" % (py_ver_chunks[0], py_ver_chunks[1])
return version_str

def _get_injected_local_wheels(
ctx,
py_version,
local_wheel_workspaces,
local_wheels_dir):
local_wheel_requirements = []
py_ver_marker = "-cp%s-" % py_version.replace(".", "")
wheels = {}

if not latest_wheel or latest_wheel.basename < wheel.basename:
wheels[package_name] = wheel
if local_wheel_workspaces:
for local_wheel_workspace in local_wheel_workspaces:
local_wheel_workspace_path = ctx.path(local_wheel_workspace)
dist_folder = ctx.attr.local_wheel_dist_folder
dist_folder_path = local_wheel_workspace_path.dirname.get_child(dist_folder)
if dist_folder_path.exists:
dist_wheels = dist_folder_path.readdir()
_process_dist_wheels(dist_wheels, wheels, py_ver_marker)
if local_wheels_dir:
dist_folder_path = ctx.path(local_wheels_dir)
if dist_folder_path.exists:
dist_wheels = dist_folder_path.readdir()
_process_dist_wheels(dist_wheels, wheels, py_ver_marker)

for wheel_name, wheel_path in wheels.items():
local_wheel_requirements.append(
Expand Down Expand Up @@ -147,6 +167,10 @@ python_repository = repository_rule(
mandatory = False,
default = "dist",
),
"default_python_version": attr.string(
mandatory = False,
default = DEFAULT_VERSION,
),
},
environ = [
"TF_PYTHON_VERSION",
Expand All @@ -156,6 +180,24 @@ python_repository = repository_rule(
],
)

def _process_dist_wheels(dist_wheels, wheels, py_ver_marker):
for wheel in dist_wheels:
bn = wheel.basename
if not bn.endswith(".whl") or bn.find(py_ver_marker) < 0:
continue

name_components = bn.split("-")
package_name = name_components[0]
for name_component in name_components[1:]:
if name_component[0].isdigit():
break
package_name += "-" + name_component

latest_wheel = wheels.get(package_name, None)

if not latest_wheel or latest_wheel.basename < wheel.basename:
wheels[package_name] = wheel

def _custom_python_interpreter_impl(ctx):
version = ctx.attr.version
strip_prefix = ctx.attr.strip_prefix.format(version = version)
Expand Down
4 changes: 3 additions & 1 deletion third_party/tsl/third_party/py/python_init_repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ load("//third_party/py:python_repo.bzl", "python_repository")
def python_init_repositories(
requirements = {},
local_wheel_workspaces = [],
local_wheel_dist_folder = None):
local_wheel_dist_folder = None,
default_python_version = None):
python_repository(
name = "python_version_repo",
requirements_versions = requirements.keys(),
requirements_locks = requirements.values(),
local_wheel_workspaces = local_wheel_workspaces,
local_wheel_dist_folder = local_wheel_dist_folder,
default_python_version = default_python_version,
)
py_repositories()
Loading

0 comments on commit 1bb87a7

Please sign in to comment.