From 67ba828cc2dc72e221c3e24395b455a659eba9c6 Mon Sep 17 00:00:00 2001
From: guipenedo
Date: Fri, 20 Oct 2023 16:01:01 +0200
Subject: [PATCH] io improvements

---
 src/datatrove/io/base.py                | 11 +++++++++++
 src/datatrove/io/fsspec.py              | 12 ++++++++++--
 src/datatrove/io/local.py               |  3 +++
 src/datatrove/io/s3.py                  | 28 ++++++++++++++++++++--------
 src/datatrove/io/utils/s3.py            | 18 ++++++++++++++++++
 src/datatrove/pipeline/dedup/minhash.py | 10 ++++++----
 6 files changed, 68 insertions(+), 14 deletions(-)

diff --git a/src/datatrove/io/base.py b/src/datatrove/io/base.py
index 2d5a236e..f3cca243 100644
--- a/src/datatrove/io/base.py
+++ b/src/datatrove/io/base.py
@@ -99,6 +99,17 @@ def set_lock(self, lock):
     def get_files_shard(self, rank: int, world_size: int, extension: str | list[str] = None) -> list[InputDataFile]:
         return self.list_files(extension=extension)[rank::world_size]
 
+    def get_file(self, relative_path: str) -> InputDataFile | None:
+        if self.file_exists(relative_path):
+            return self._get_file(relative_path)
+
+    def _get_file(self, relative_path: str) -> InputDataFile:
+        return InputDataFile(path=os.path.join(self.path, relative_path), relative_path=relative_path)
+
+    @abstractmethod
+    def file_exists(self, relative_path: str) -> bool:
+        return True
+
     def _match_file(self, file_path, extension=None):
         extensions = (
             ([self.extension] if isinstance(self.extension, str) else self.extension)
diff --git a/src/datatrove/io/fsspec.py b/src/datatrove/io/fsspec.py
index 378023ba..9f34eef5 100644
--- a/src/datatrove/io/fsspec.py
+++ b/src/datatrove/io/fsspec.py
@@ -55,7 +55,15 @@ def __post_init__(self):
 
     def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[InputDataFile]:
         return [
-            FSSpecInputDataFile(path=path, relative_path=os.path.relpath(path, self.path), _fs=self._fs)
-            for path in self._fs.ls(self.path, detail=False)
+            self._get_file(os.path.relpath(path, self.path))
+            for path in self._fs.ls(os.path.join(self.path, suffix), detail=False)
             if self._match_file(path, extension)
         ]
+
+    def _get_file(self, relative_path: str):
+        return FSSpecInputDataFile(
+            path=os.path.join(self.path, relative_path), relative_path=relative_path, _fs=self._fs
+        )
+
+    def file_exists(self, relative_path: str) -> bool:
+        return self._fs.isfile(os.path.join(self.path, relative_path))
diff --git a/src/datatrove/io/local.py b/src/datatrove/io/local.py
index 8a0f81fe..f242d64f 100644
--- a/src/datatrove/io/local.py
+++ b/src/datatrove/io/local.py
@@ -29,6 +29,9 @@ def list_files(self, extension: str | list[str] = None, suffix: str = "") -> lis
             if self._match_file(path, extension)
         ]
 
+    def file_exists(self, relative_path: str) -> bool:
+        return os.path.exists(os.path.join(self.path, relative_path))
+
 
 def get_local_file_list(path: str, recursive: bool = True) -> list[str]:
     filelist = []
diff --git a/src/datatrove/io/s3.py b/src/datatrove/io/s3.py
index ec6482f1..745f45f1 100644
--- a/src/datatrove/io/s3.py
+++ b/src/datatrove/io/s3.py
@@ -7,7 +7,13 @@
 
 from datatrove.io import BaseOutputDataFolder, InputDataFile
 from datatrove.io.base import BaseInputDataFolder, OutputDataFile
-from datatrove.io.utils.s3 import s3_download_file, s3_get_file_list, s3_get_file_stream, s3_upload_file
+from datatrove.io.utils.s3 import (
+    s3_download_file,
+    s3_file_exists,
+    s3_get_file_list,
+    s3_get_file_stream,
+    s3_upload_file,
+)
 
 
 @dataclass
@@ -86,18 +92,24 @@ def __post_init__(self):
             raise ValueError("S3InputDataFolder path must start with s3://")
         self._tmpdir = None
 
+    def _get_file(self, relative_path: str):
+        return S3InputDataFile(
+            path=os.path.join(self.path, relative_path),
+            local_path=os.path.join(self.local_path, relative_path),
+            relative_path=relative_path,
+            folder=self,
+            stream=self.stream,
+        )
+
+    def file_exists(self, relative_path: str) -> bool:
+        return s3_file_exists(os.path.join(self.path, relative_path))
+
     def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[InputDataFile]:
         if not self.local_path:
             self._tmpdir = tempfile.TemporaryDirectory()
             self.local_path = self._tmpdir.name
         return [
-            S3InputDataFile(
-                path=os.path.join(self.path, suffix, path),
-                local_path=os.path.join(self.local_path, suffix, path),
-                relative_path=path,
-                folder=self,
-                stream=self.stream,
-            )
+            self._get_file(os.path.join(suffix, path))
             for path in s3_get_file_list(
                 os.path.join(self.path, suffix), match_pattern=self.match_pattern, recursive=self.recursive
             )
diff --git a/src/datatrove/io/utils/s3.py b/src/datatrove/io/utils/s3.py
index 03cb183a..3294bd73 100644
--- a/src/datatrove/io/utils/s3.py
+++ b/src/datatrove/io/utils/s3.py
@@ -19,6 +19,24 @@ def _get_s3_path_components(s3_path):
     return bucket_name, prefix
 
 
+def s3_file_exists(cloud_path):
+    """
+    Checks whether the given s3 path points to an existing file. Returns False for folder (prefix) paths.
+    @param cloud_path:
+    @return: bool
+    """
+    s3_client = boto3.client("s3")
+    bucket_name, prefix = _get_s3_path_components(cloud_path)
+
+    try:
+        s3_client.head_object(Bucket=bucket_name, Key=prefix)
+        return True
+    except ClientError as exc:
+        if exc.response["Error"]["Code"] != "404":
+            raise exc
+        return False
+
+
 def s3_upload_file(local_path, cloud_path):
     """
     @param local_path:
diff --git a/src/datatrove/pipeline/dedup/minhash.py b/src/datatrove/pipeline/dedup/minhash.py
index c36abfa0..3823623a 100644
--- a/src/datatrove/pipeline/dedup/minhash.py
+++ b/src/datatrove/pipeline/dedup/minhash.py
@@ -283,15 +283,17 @@ def set_up_dl_locks(self, dl_lock, up_lock):
         self.data_folder.set_lock(dl_lock)
 
     def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
-        remove_data = self.data_folder.get_files_shard(rank, world_size, extension=".remove")
-        assert len(remove_data) <= 1, f"Must have exactly one .remove file per task. Found {len(remove_data)} files."
+        remove_file = self.data_folder.get_file(f"{rank:06d}.remove")
         clusters_data = self.data_folder.get_files_shard(rank, world_size, extension=".clusters")
         assert (
             not self.load_cluster_ids or len(clusters_data) <= 1
-        ), f"Must have exactly one .clusters file per task. Found {len(remove_data)} files."
+        ), f"Must have exactly one .clusters file per task. Found {len(clusters_data)} files."
 
-        with remove_data[0].open_binary() as f:
+        if not remove_file:
+            logger.warning(f"No .remove file for {rank=}.")
+            return
+        with remove_file.open_binary() as f:
             with self.exclusion_writer if self.exclusion_writer else contextlib.nullcontext() as exc_writer:
 
                 def get_next():
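
Note on the new API: the sketch below condenses the get_file/file_exists pattern this patch introduces into a standalone example, including why the private helper uses a single leading underscore. It is a rough sketch, assuming Python 3.10+ and boto3; the class names, bucket, and key layout are made up for illustration and are not datatrove's actual API.

import os
from abc import ABC, abstractmethod
from dataclasses import dataclass

import boto3
from botocore.exceptions import ClientError


@dataclass
class DataFile:
    path: str
    relative_path: str


class BaseInputFolder(ABC):
    def __init__(self, path: str):
        self.path = path

    def get_file(self, relative_path: str) -> DataFile | None:
        # Only build a DataFile when the object actually exists, so callers
        # can branch on None instead of catching errors on open().
        if self.file_exists(relative_path):
            return self._get_file(relative_path)

    # Single leading underscore on purpose: a double underscore would be
    # name-mangled to _BaseInputFolder__get_file, so subclass overrides
    # would silently never be called from get_file().
    def _get_file(self, relative_path: str) -> DataFile:
        return DataFile(path=os.path.join(self.path, relative_path), relative_path=relative_path)

    @abstractmethod
    def file_exists(self, relative_path: str) -> bool:
        ...


class S3InputFolder(BaseInputFolder):
    def file_exists(self, relative_path: str) -> bool:
        # head_object is a cheap metadata-only request; a 404 means "no such
        # key" (a bare folder/prefix also 404s, matching the patch semantics).
        bucket, _, prefix = self.path.removeprefix("s3://").partition("/")
        try:
            boto3.client("s3").head_object(Bucket=bucket, Key=f"{prefix}/{relative_path}")
            return True
        except ClientError as exc:
            if exc.response["Error"]["Code"] != "404":
                raise
            return False


if __name__ == "__main__":
    # Mirrors the minhash change: fetch this task's .remove file by name and
    # skip gracefully when it is missing, instead of sharding list_files().
    rank = 0
    folder = S3InputFolder("s3://my-bucket/minhash/remove_ids")  # hypothetical bucket
    remove_file = folder.get_file(f"{rank:06d}.remove")
    if remove_file is None:
        print(f"No .remove file for {rank=}.")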