Skip to content

Commit

Permalink
Use epath in c4.py to fix problems with not passing around an epath wh…
Browse files Browse the repository at this point in the history
…en expected

PiperOrigin-RevId: 700632162
  • Loading branch information
tomvdw authored and The TensorFlow Datasets Authors committed Nov 27, 2024
1 parent 14f2854 commit a8c87b5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
2 changes: 1 addition & 1 deletion tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
ExtractPath = epath.PathLike | resource_lib.Resource


def get_downloader(*args: Any, **kwargs: Any):
def get_downloader(*args: Any, **kwargs: Any) -> downloader._Downloader:
return downloader.get_downloader(*args, **kwargs)


Expand Down
4 changes: 3 additions & 1 deletion tensorflow_datasets/core/download/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""Async download API with checksum verification. No business logic."""

from __future__ import annotations

from collections.abc import Iterable, Iterator
import concurrent.futures
import contextlib
Expand Down Expand Up @@ -57,7 +59,7 @@ class DownloadResult:


@utils.memoize()
def get_downloader(*args: Any, **kwargs: Any) -> '_Downloader':
def get_downloader(*args: Any, **kwargs: Any) -> _Downloader:
return _Downloader(*args, **kwargs)


Expand Down
28 changes: 15 additions & 13 deletions tensorflow_datasets/text/c4.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from absl import logging
from etils import epath
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
import tensorflow_datasets.public_api as tfds
from tensorflow_datasets.text import c4_utils
import tree
Expand Down Expand Up @@ -499,7 +498,9 @@ def _info(self):
homepage="https://github.com/google-research/text-to-text-transfer-transformer#datasets",
)

def _split_generators(self, dl_manager, pipeline):
def _split_generators(
self, dl_manager: tfds.download.DownloadManager, pipeline
):
# We will automatically download the first default CC version, but others
# need to be manually downloaded.
cc_versions = set(self.builder_config.cc_versions)
Expand All @@ -521,8 +522,8 @@ def _split_generators(self, dl_manager, pipeline):
file_paths = dl_manager.download_and_extract(files_to_download)

if self.builder_config.webtextlike:
owt_path = os.path.join(dl_manager.manual_dir, _OPENWEBTEXT_URLS_ZIP)
if not tf.io.gfile.exists(owt_path):
owt_path = dl_manager.manual_dir / _OPENWEBTEXT_URLS_ZIP
if not owt_path.exists():
raise AssertionError(
"For the WebText-like config, you must manually download the "
"following file from {0} and place it in {1}: {2}".format(
Expand Down Expand Up @@ -606,24 +607,25 @@ def _get_pages_pcollection(self, pipeline, file_paths, dl_manager):

def download_wet_file(path, dl_dir):
url = f"{_DOWNLOAD_HOST}/{path}"
out_path = f"{dl_dir}/{path}"
out_path = epath.Path(dl_dir) / path

if tf.io.gfile.exists(out_path):
if out_path.exists():
c4_utils.get_counter_inc_fn("download_wet_url")("exists")
return out_path

tmp_dir = f"{out_path}.incomplete{uuid.uuid4().hex}"
tmp_dir = epath.Path(
f"{os.fspath(out_path)}.incomplete{uuid.uuid4().hex}"
)
try:
tf.io.gfile.makedirs(tmp_dir)
tmp_dir.mkdir(parents=True, exist_ok=True)
downloader = tfds.download.download_manager.get_downloader()
with downloader.tqdm():
# TODO(slebedev): Investigate why pytype infers Promise[Future[...]].
dl_path = downloader.download(url, tmp_dir).get().path # type: ignore
tf.io.gfile.rename(os.fspath(dl_path), out_path, overwrite=True)
dl_path = epath.Path(dl_path)
dl_path.rename(out_path)
finally:
if tf.io.gfile.exists(tmp_dir):
tf.io.gfile.rmtree(tmp_dir)

tmp_dir.rmtree(missing_ok=True)
c4_utils.get_counter_inc_fn("download_wet_url")("downloaded")
return out_path

Expand Down Expand Up @@ -654,7 +656,7 @@ def download_wet_file(path, dl_dir):
# Optionally filter for RealNews domains.
# Output: [PageFeatures]
if self.builder_config.realnewslike:
with tf.io.gfile.GFile(file_paths["realnews_domains"]) as f:
with epath.Path(file_paths["realnews_domains"]).open() as f:
realnews_domains = json.load(f)
pages |= beam.Filter(c4_utils.is_realnews_domain, realnews_domains)

Expand Down

0 comments on commit a8c87b5

Please sign in to comment.