io improvements
guipenedo committed Oct 20, 2023
1 parent a70e8c5 commit 67ba828
Showing 6 changed files with 68 additions and 14 deletions.
11 changes: 11 additions & 0 deletions src/datatrove/io/base.py
@@ -99,6 +99,17 @@ def set_lock(self, lock):
     def get_files_shard(self, rank: int, world_size: int, extension: str | list[str] = None) -> list[InputDataFile]:
         return self.list_files(extension=extension)[rank::world_size]
 
+    def get_file(self, relative_path: str) -> InputDataFile | None:
+        if self.file_exists(relative_path):
+            return self.__get_file(relative_path)
+
+    def __get_file(self, relative_path: str) -> InputDataFile:
+        return InputDataFile(path=os.path.join(self.path, relative_path), relative_path=relative_path)
+
+    @abstractmethod
+    def file_exists(self, relative_path: str) -> bool:
+        return True
+
     def _match_file(self, file_path, extension=None):
         extensions = (
             ([self.extension] if isinstance(self.extension, str) else self.extension)
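The new base-class API splits file lookup into an existence check (file_exists, which each backend must implement) and a private constructor hook (__get_file). A minimal usage sketch, not part of the commit; the folder class, path and file name are hypothetical:

    from datatrove.io import LocalInputDataFolder  # hypothetical example backend

    folder = LocalInputDataFolder(path="/data/minhash")
    remove_file = folder.get_file("000000.remove")  # an InputDataFile, or None if missing
    if remove_file:
        with remove_file.open_binary() as f:  # same pattern minhash.py uses below
            ...

Since double leading underscores trigger per-class name mangling in Python, each class's __get_file is only addressable from methods defined in that same class.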
12 changes: 10 additions & 2 deletions src/datatrove/io/fsspec.py
@@ -55,7 +55,15 @@ def __post_init__(self):
 
     def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[InputDataFile]:
         return [
-            FSSpecInputDataFile(path=path, relative_path=os.path.relpath(path, self.path), _fs=self._fs)
-            for path in self._fs.ls(self.path, detail=False)
+            self.__get_file(os.path.relpath(path, self.path))
+            for path in self._fs.ls(os.path.join(self.path, suffix), detail=False)
             if self._match_file(path, extension)
         ]
+
+    def __get_file(self, relative_path: str):
+        return FSSpecInputDataFile(
+            path=os.path.join(self.path, relative_path), relative_path=relative_path, _fs=self._fs
+        )
+
+    def file_exists(self, relative_path: str) -> bool:
+        return self._fs.isfile(os.path.join(self.path, relative_path))
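file_exists here delegates to the fsspec filesystem's isfile, which is true only for an existing regular file. A rough sketch of that behaviour, assuming the fsspec package is installed:

    import fsspec

    fs = fsspec.filesystem("file")  # local files; any fsspec protocol behaves the same
    print(fs.isfile("/tmp/example.txt"))  # True only if this regular file exists
    print(fs.isfile("/tmp"))              # False: a directory is not a file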
3 changes: 3 additions & 0 deletions src/datatrove/io/local.py
@@ -29,6 +29,9 @@ def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[InputDataFile]:
             if self._match_file(path, extension)
         ]
 
+    def file_exists(self, relative_path: str) -> bool:
+        return os.path.exists(os.path.join(self.path, relative_path))
+
 
 def get_local_file_list(path: str, recursive: bool = True) -> list[str]:
     filelist = []
28 changes: 20 additions & 8 deletions src/datatrove/io/s3.py
@@ -7,7 +7,13 @@
 
 from datatrove.io import BaseOutputDataFolder, InputDataFile
 from datatrove.io.base import BaseInputDataFolder, OutputDataFile
-from datatrove.io.utils.s3 import s3_download_file, s3_get_file_list, s3_get_file_stream, s3_upload_file
+from datatrove.io.utils.s3 import (
+    s3_download_file,
+    s3_file_exists,
+    s3_get_file_list,
+    s3_get_file_stream,
+    s3_upload_file,
+)
 
 
 @dataclass
@@ -86,18 +92,24 @@ def __post_init__(self):
             raise ValueError("S3InputDataFolder path must start with s3://")
         self._tmpdir = None
 
+    def __get_file(self, relative_path: str):
+        return S3InputDataFile(
+            path=os.path.join(self.path, relative_path),
+            local_path=os.path.join(self.local_path, relative_path),
+            relative_path=relative_path,
+            folder=self,
+            stream=self.stream,
+        )
+
+    def file_exists(self, relative_path: str) -> bool:
+        return s3_file_exists(os.path.join(self.path, relative_path))
+
     def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[InputDataFile]:
         if not self.local_path:
             self._tmpdir = tempfile.TemporaryDirectory()
             self.local_path = self._tmpdir.name
         return [
-            S3InputDataFile(
-                path=os.path.join(self.path, suffix, path),
-                local_path=os.path.join(self.local_path, suffix, path),
-                relative_path=path,
-                folder=self,
-                stream=self.stream,
-            )
+            self.__get_file(os.path.join(suffix, path))
             for path in s3_get_file_list(
                 os.path.join(self.path, suffix), match_pattern=self.match_pattern, recursive=self.recursive
             )
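Listing and single-file lookup now build S3InputDataFile through the same __get_file hook, composing paths with os.path.join. A quick sketch of the composition, with hypothetical bucket and cache locations:

    import os

    base = "s3://bucket/data"  # hypothetical S3InputDataFolder path
    local = "/tmp/cache"       # hypothetical local_path
    relative_path = os.path.join("minhash", "000000.remove")  # suffix + listed file name

    print(os.path.join(base, relative_path))   # s3://bucket/data/minhash/000000.remove
    print(os.path.join(local, relative_path))  # /tmp/cache/minhash/000000.remove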
18 changes: 18 additions & 0 deletions src/datatrove/io/utils/s3.py
@@ -19,6 +19,24 @@ def _get_s3_path_components(s3_path):
     return bucket_name, prefix
 
 
+def s3_file_exists(cloud_path):
+    """
+    Checks whether a given file path exists. Currently only checks s3. Returns False if the path is a folder.
+    @param cloud_path:
+    @return: bool
+    """
+    s3_client = boto3.client("s3")
+    bucket_name, prefix = _get_s3_path_components(cloud_path)
+
+    try:
+        s3_client.head_object(Bucket=bucket_name, Key=prefix)
+        return True
+    except ClientError as exc:
+        if exc.response["Error"]["Code"] != "404":
+            raise exc
+        return False
+
+
 def s3_upload_file(local_path, cloud_path):
     """
     @param local_path:
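head_object is a metadata-only request, so the existence check transfers no object data; a missing key surfaces as a ClientError with error code "404", and any other error (e.g. a 403 from missing permissions) is re-raised. A usage sketch, assuming valid AWS credentials and a hypothetical bucket:

    from datatrove.io.utils.s3 import s3_file_exists

    print(s3_file_exists("s3://my-bucket/data/000000.remove"))  # True if the object exists
    print(s3_file_exists("s3://my-bucket/data/"))  # False: a bare prefix ("folder") is not an object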
10 changes: 6 additions & 4 deletions src/datatrove/pipeline/dedup/minhash.py
@@ -283,15 +283,17 @@ def set_up_dl_locks(self, dl_lock, up_lock):
         self.data_folder.set_lock(dl_lock)
 
     def __call__(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
-        remove_data = self.data_folder.get_files_shard(rank, world_size, extension=".remove")
-        assert len(remove_data) <= 1, f"Must have exactly one .remove file per task. Found {len(remove_data)} files."
+        remove_file = self.data_folder.get_file(f"{rank:06d}.remove")
 
         clusters_data = self.data_folder.get_files_shard(rank, world_size, extension=".clusters")
         assert (
             not self.load_cluster_ids or len(clusters_data) <= 1
-        ), f"Must have exactly one .clusters file per task. Found {len(remove_data)} files."
+        ), f"Must have exactly one .clusters file per task. Found {len(clusters_data)} files."
 
-        with remove_data[0].open_binary() as f:
+        if not remove_file:
+            logger.warning(f"No .remove file for {rank=}.")
+            return
+        with remove_file.open_binary() as f:
             with self.exclusion_writer if self.exclusion_writer else contextlib.nullcontext() as exc_writer:
 
                 def get_next():
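The filter now derives its .remove file name from the task rank instead of sharding the directory listing, so a rank with no file logs a warning and returns early instead of raising an IndexError. The expected naming convention zero-pads the rank to six digits:

    rank = 7
    print(f"{rank:06d}.remove")  # 000007.remove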
