Skip to content

Commit

Permalink
QoL tweaks (#32)
Browse files Browse the repository at this point in the history
* tweaks

* fix lint
  • Loading branch information
thomwolf authored Oct 31, 2023
1 parent 5c0b035 commit e21a372
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,5 @@ cython_debug/

# ruff code style
.ruff_cache/

playground/*
8 changes: 5 additions & 3 deletions src/datatrove/executor/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
mem_per_cpu_gb: int = 2,
workers: int = -1,
job_name: str = "data_processing",
env_command: str = None,
condaenv: str = None,
venv_path: str = None,
sbatch_args: dict | None = None,
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
self.logging_dir = logging_dir
self.time = time
self.job_name = job_name
self.env_command = env_command
self.condaenv = condaenv
self.venv_path = venv_path
self.depends = depends
Expand Down Expand Up @@ -158,9 +160,10 @@ def launch_file(self):
def get_launch_file(self, sbatch_args: dict, run_script: str):
args = "\n".join([f"#SBATCH --{k}={v}" for k, v in sbatch_args.items()])

env_command = (
env_command = self.env_command if self.env_command else (
f"""conda init bash
conda activate {self.condaenv}"""
conda activate {self.condaenv}
source ~/.bashrc"""
if self.condaenv
else (f"source {self.venv_path}" if self.venv_path else "")
)
Expand All @@ -172,7 +175,6 @@ def get_launch_file(self, sbatch_args: dict, run_script: str):
f"""
echo "Starting data processing job {self.job_name}"
{env_command}
source ~/.bashrc
set -xe
{run_script}
"""
Expand Down
9 changes: 9 additions & 0 deletions src/datatrove/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def open(self, binary=False, compression: Literal["gzip", "zst"] | None = None):

@dataclass
class BaseInputDataFolder(ABC):
"""An input data folder
Args:
path (str): path to the folder
extension (str | list[str], optional): file extensions to filter. Defaults to None.
recursive (bool, optional): whether to search recursively. Defaults to True.
match_pattern (str, optional): pattern to match file names. Defaults to None.
"""

path: str
extension: str | list[str] = None
recursive: bool = True
Expand Down
5 changes: 4 additions & 1 deletion src/datatrove/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ class FSSpecInputDataFolder(BaseInputDataFolder):

def __post_init__(self):
    """Derive the fsspec protocol from ``self.path`` and create the filesystem.

    A path of the form ``"<protocol>://<rest>"`` is split: the protocol picks
    the fsspec backend and ``self.path`` is rewritten to ``<rest>``. A plain
    path with no ``"://"`` falls back to the local ``"file"`` protocol instead
    of raising.

    Raises:
        Whatever ``fsspec.filesystem`` raises for an unknown protocol or bad
        storage options.
    """
    super().__post_init__()
    if "://" in self.path:
        # maxsplit=1 so a remainder that itself contains "://" (e.g. chained
        # URLs) does not make tuple unpacking fail with ValueError.
        protocol, self.path = self.path.split("://", 1)
    else:
        # No explicit scheme: treat the path as a local filesystem path.
        protocol = "file"
    # storage_options (credentials, endpoint URLs, ...) are forwarded verbatim
    # to the backend; an absent/None value becomes an empty kwargs dict.
    self._fs = fsspec.filesystem(protocol, **(self.storage_options if self.storage_options else {}))

def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[InputDataFile]:
Expand Down
4 changes: 3 additions & 1 deletion src/datatrove/pipeline/readers/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ def __init__(
data_folder: BaseInputDataFolder,
compression: Literal["gzip", "zst"] | None = None,
adapter: Callable = None,
content_key: str = "content",
**kwargs,
):
super().__init__(data_folder, **kwargs)
self.compression = compression
self.content_key = content_key
self.adapter = adapter if adapter else lambda d, path, li: d
self.empty_warning = False

Expand All @@ -30,7 +32,7 @@ def read_file(self, datafile: InputDataFile):
with self.stats.time_manager:
try:
d = json.loads(line)
if not d.get("content", None):
if not d.get(self.content_key, None):
if not self.empty_warning:
self.empty_warning = True
logger.warning("Found document without content, skipping.")
Expand Down

0 comments on commit e21a372

Please sign in to comment.