From 610560c55af6e9dbf404288f39d68a68dccf14e7 Mon Sep 17 00:00:00 2001
From: guipenedo
Date: Fri, 29 Nov 2024 14:50:42 +0100
Subject: [PATCH] fixes for empty folders

---
 src/datatrove/io.py                    | 7 +++++--
 src/datatrove/pipeline/readers/base.py | 7 ++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/datatrove/io.py b/src/datatrove/io.py
index dd8dd329..b401a968 100644
--- a/src/datatrove/io.py
+++ b/src/datatrove/io.py
@@ -162,7 +162,7 @@ def list_files(
             ]
         )
 
-    def get_shard(self, rank: int, world_size: int, **kwargs) -> list[str]:
+    def get_shard(self, rank: int, world_size: int, **kwargs) -> list[str] | None:
         """Fetch a shard (set of files) for a given rank, assuming there are a total of `world_size` shards.
         This should be deterministic to not have any overlap among different ranks.
         Will return files [rank, rank+world_size, rank+2*world_size, ...]
@@ -175,7 +175,10 @@ def get_shard(self, rank: int, world_size: int, **kwargs) -> list[str]:
         Returns:
             a list of file paths
         """
-        return self.list_files(**kwargs)[rank::world_size]
+        all_files = self.list_files(**kwargs)
+        if len(all_files) == 0:
+            return None
+        return all_files[rank::world_size]
 
     def resolve_paths(self, paths) -> list[str] | str:
         """
diff --git a/src/datatrove/pipeline/readers/base.py b/src/datatrove/pipeline/readers/base.py
index 0531d6d6..66a6bfa9 100644
--- a/src/datatrove/pipeline/readers/base.py
+++ b/src/datatrove/pipeline/readers/base.py
@@ -222,11 +222,12 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
             if not self.paths_file
             else list(get_shard_from_paths_file(self.paths_file, rank, world_size))
         )
-        if len(files_shard) == 0:
-            if rank == 0:
-                raise RuntimeError(f"No files found on {self.data_folder.path}!")
+        if files_shard is None:
+            raise RuntimeError(f"No files found on {self.data_folder.path}!")
+        elif len(files_shard) == 0:
             # otherwise just a warning
             logger.warning(f"No files found on {self.data_folder.path} for {rank=}")
+
         if self.shuffle_files:
             random.shuffle(files_shard)
         for doc in self.read_files_shard(files_shard):
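
Below is a minimal standalone sketch of the sharding semantics this patch introduces; it re-implements the patched logic for illustration rather than calling datatrove's actual classes (the `get_shard` helper and the sample file names are hypothetical). A truly empty folder now surfaces as `None`, which the reader turns into a `RuntimeError` on every rank, while a folder with fewer files than ranks just produces empty shards for the surplus ranks (warning only).

# Standalone sketch of the patched sharding semantics (not datatrove's API):
# an empty folder yields None (hard error in the reader), while a folder with
# fewer files than ranks yields empty lists for some ranks (warning only).

def get_shard(all_files: list[str], rank: int, world_size: int) -> list[str] | None:
    # Mirrors the patched DataFolder.get_shard: None signals "no files at all".
    if len(all_files) == 0:
        return None
    return all_files[rank::world_size]


if __name__ == "__main__":
    # Empty folder: every rank gets None -> the reader raises RuntimeError.
    assert get_shard([], rank=0, world_size=4) is None

    # Two files, four ranks: ranks 2 and 3 get empty shards -> warning only.
    files = ["a.jsonl", "b.jsonl"]
    assert get_shard(files, rank=0, world_size=4) == ["a.jsonl"]
    assert get_shard(files, rank=3, world_size=4) == []

Run as a script the asserts pass on Python 3.10+, matching the list[str] | None annotation used in the patch.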