From a8c87b57063db04e59f99017197d0bd7568b5c06 Mon Sep 17 00:00:00 2001 From: Tom van der Weide Date: Wed, 27 Nov 2024 03:05:00 -0800 Subject: [PATCH] Use epath in c4.py to fix problems with no passing around an epath when expected PiperOrigin-RevId: 700632162 --- .../core/download/download_manager.py | 2 +- .../core/download/downloader.py | 4 ++- tensorflow_datasets/text/c4.py | 28 ++++++++++--------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/tensorflow_datasets/core/download/download_manager.py b/tensorflow_datasets/core/download/download_manager.py index 82362ba068b..675787b04a8 100644 --- a/tensorflow_datasets/core/download/download_manager.py +++ b/tensorflow_datasets/core/download/download_manager.py @@ -53,7 +53,7 @@ ExtractPath = epath.PathLike | resource_lib.Resource -def get_downloader(*args: Any, **kwargs: Any): +def get_downloader(*args: Any, **kwargs: Any) -> downloader._Downloader: return downloader.get_downloader(*args, **kwargs) diff --git a/tensorflow_datasets/core/download/downloader.py b/tensorflow_datasets/core/download/downloader.py index 2e4652daec6..ff5fc60e5ed 100644 --- a/tensorflow_datasets/core/download/downloader.py +++ b/tensorflow_datasets/core/download/downloader.py @@ -15,6 +15,8 @@ """Async download API with checksum verification. No business logic.""" +from __future__ import annotations + from collections.abc import Iterable, Iterator import concurrent.futures import contextlib @@ -57,7 +59,7 @@ class DownloadResult: @utils.memoize() -def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader': +def get_downloader(*args: Any, **kwargs: Any) -> _Downloader: return _Downloader(*args, **kwargs) diff --git a/tensorflow_datasets/text/c4.py b/tensorflow_datasets/text/c4.py index d45187883d8..38d4639495b 100644 --- a/tensorflow_datasets/text/c4.py +++ b/tensorflow_datasets/text/c4.py @@ -24,7 +24,6 @@ from absl import logging from etils import epath -from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf import tensorflow_datasets.public_api as tfds from tensorflow_datasets.text import c4_utils import tree @@ -499,7 +498,9 @@ def _info(self): homepage="https://github.com/google-research/text-to-text-transfer-transformer#datasets", ) - def _split_generators(self, dl_manager, pipeline): + def _split_generators( + self, dl_manager: tfds.download.DownloadManager, pipeline + ): # We will automatically download the first default CC version, but others # need to be manually downloaded. cc_versions = set(self.builder_config.cc_versions) @@ -521,8 +522,8 @@ def _split_generators(self, dl_manager, pipeline): file_paths = dl_manager.download_and_extract(files_to_download) if self.builder_config.webtextlike: - owt_path = os.path.join(dl_manager.manual_dir, _OPENWEBTEXT_URLS_ZIP) - if not tf.io.gfile.exists(owt_path): + owt_path = dl_manager.manual_dir / _OPENWEBTEXT_URLS_ZIP + if not owt_path.exists(): raise AssertionError( "For the WebText-like config, you must manually download the " "following file from {0} and place it in {1}: {2}".format( @@ -606,24 +607,25 @@ def _get_pages_pcollection(self, pipeline, file_paths, dl_manager): def download_wet_file(path, dl_dir): url = f"{_DOWNLOAD_HOST}/{path}" - out_path = f"{dl_dir}/{path}" + out_path = epath.Path(dl_dir) / path - if tf.io.gfile.exists(out_path): + if out_path.exists(): c4_utils.get_counter_inc_fn("download_wet_url")("exists") return out_path - tmp_dir = f"{out_path}.incomplete{uuid.uuid4().hex}" + tmp_dir = epath.Path( + f"{os.fspath(out_path)}.incomplete{uuid.uuid4().hex}" + ) try: - tf.io.gfile.makedirs(tmp_dir) + tmp_dir.mkdir(parents=True, exist_ok=True) downloader = tfds.download.download_manager.get_downloader() with downloader.tqdm(): # TODO(slebedev): Investigate why pytype infers Promise[Future[...]]. dl_path = downloader.download(url, tmp_dir).get().path # type: ignore - tf.io.gfile.rename(os.fspath(dl_path), out_path, overwrite=True) + dl_path = epath.Path(dl_path) + dl_path.rename(out_path) finally: - if tf.io.gfile.exists(tmp_dir): - tf.io.gfile.rmtree(tmp_dir) - + tmp_dir.rmtree(missing_ok=True) c4_utils.get_counter_inc_fn("download_wet_url")("downloaded") return out_path @@ -654,7 +656,7 @@ def download_wet_file(path, dl_dir): # Optionally filter for RealNews domains. # Output: [PageFeatures] if self.builder_config.realnewslike: - with tf.io.gfile.GFile(file_paths["realnews_domains"]) as f: + with epath.Path(file_paths["realnews_domains"]).open() as f: realnews_domains = json.load(f) pages |= beam.Filter(c4_utils.is_realnews_domain, realnews_domains)