diff --git a/backend/find_tensorflow.py b/backend/find_tensorflow.py index aa75d5ecb4..6d7ce5087d 100644 --- a/backend/find_tensorflow.py +++ b/backend/find_tensorflow.py @@ -114,9 +114,9 @@ def get_tf_requirement(tf_version: str = "") -> dict: extra_requires = [] extra_select = {} - if not (tf_version == "" or tf_version in SpecifierSet(">=2.12")): + if not (tf_version == "" or tf_version in SpecifierSet(">=2.12", prereleases=True)): extra_requires.append("protobuf<3.20") - if tf_version == "" or tf_version in SpecifierSet(">=1.15"): + if tf_version == "" or tf_version in SpecifierSet(">=1.15", prereleases=True): extra_select["mpi"] = [ "horovod", "mpi4py", @@ -138,9 +138,9 @@ def get_tf_requirement(tf_version: str = "") -> dict: ], **extra_select, } - elif tf_version in SpecifierSet("<1.15") or tf_version in SpecifierSet( - ">=2.0,<2.1" - ): + elif tf_version in SpecifierSet( + "<1.15", prereleases=True + ) or tf_version in SpecifierSet(">=2.0,<2.1", prereleases=True): return { "cpu": [ f"tensorflow=={tf_version}",