diff --git a/.gitignore b/.gitignore
index fc7af979..d8aebf8d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -157,3 +157,5 @@ cython_debug/
 
 # ruff code style
 .ruff_cache/
+
+playground/*
\ No newline at end of file
diff --git a/src/datatrove/executor/slurm.py b/src/datatrove/executor/slurm.py
index 0443cec3..e4b48977 100644
--- a/src/datatrove/executor/slurm.py
+++ b/src/datatrove/executor/slurm.py
@@ -27,6 +27,7 @@ def __init__(
         mem_per_cpu_gb: int = 2,
         workers: int = -1,
         job_name: str = "data_processing",
+        env_command: str = None,
         condaenv: str = None,
         venv_path: str = None,
         sbatch_args: dict | None = None,
@@ -54,6 +55,7 @@ def __init__(
         self.logging_dir = logging_dir
         self.time = time
         self.job_name = job_name
+        self.env_command = env_command
         self.condaenv = condaenv
         self.venv_path = venv_path
         self.depends = depends
@@ -158,9 +160,10 @@ def launch_file(self):
     def get_launch_file(self, sbatch_args: dict, run_script: str):
         args = "\n".join([f"#SBATCH --{k}={v}" for k, v in sbatch_args.items()])
 
-        env_command = (
+        env_command = self.env_command if self.env_command else (
             f"""conda init bash
-            conda activate {self.condaenv}"""
+            conda activate {self.condaenv}
+            source ~/.bashrc"""
             if self.condaenv
             else (f"source {self.venv_path}" if self.venv_path else "")
         )
@@ -172,7 +175,6 @@ def get_launch_file(self, sbatch_args: dict, run_script: str):
             f"""
 echo "Starting data processing job {self.job_name}"
 {env_command}
-source ~/.bashrc
 set -xe
 {run_script}
 """
diff --git a/src/datatrove/io/base.py b/src/datatrove/io/base.py
index da00fae4..608e371a 100644
--- a/src/datatrove/io/base.py
+++ b/src/datatrove/io/base.py
@@ -66,6 +66,15 @@ def open(self, binary=False, compression: Literal["gzip", "zst"] | None = None):
 
 @dataclass
 class BaseInputDataFolder(ABC):
+    """An input data folder
+
+    Args:
+        path (str): path to the folder
+        extension (str | list[str], optional): file extensions to filter. Defaults to None.
+        recursive (bool, optional): whether to search recursively. Defaults to True.
+        match_pattern (str, optional): pattern to match file names. Defaults to None.
+    """
+
     path: str
     extension: str | list[str] = None
     recursive: bool = True
diff --git a/src/datatrove/io/fsspec.py b/src/datatrove/io/fsspec.py
index f7af5a39..7160da6d 100644
--- a/src/datatrove/io/fsspec.py
+++ b/src/datatrove/io/fsspec.py
@@ -50,7 +50,10 @@ class FSSpecInputDataFolder(BaseInputDataFolder):
 
     def __post_init__(self):
         super().__post_init__()
-        protocol, self.path = self.path.split("://")
+        if "://" in self.path:
+            protocol, self.path = self.path.split("://")
+        else:
+            protocol = "file"
         self._fs = fsspec.filesystem(protocol, **(self.storage_options if self.storage_options else {}))
 
     def list_files(self, extension: str | list[str] = None, suffix: str = "") -> list[InputDataFile]:
diff --git a/src/datatrove/pipeline/readers/jsonl.py b/src/datatrove/pipeline/readers/jsonl.py
index b66801ea..17bca9f2 100644
--- a/src/datatrove/pipeline/readers/jsonl.py
+++ b/src/datatrove/pipeline/readers/jsonl.py
@@ -17,10 +17,12 @@ def __init__(
         data_folder: BaseInputDataFolder,
         compression: Literal["gzip", "zst"] | None = None,
         adapter: Callable = None,
+        content_key: str = "content",
         **kwargs,
     ):
         super().__init__(data_folder, **kwargs)
         self.compression = compression
+        self.content_key = content_key
         self.adapter = adapter if adapter else lambda d, path, li: d
         self.empty_warning = False
 
@@ -30,7 +32,7 @@ def read_file(self, datafile: InputDataFile):
             with self.stats.time_manager:
                 try:
                     d = json.loads(line)
-                    if not d.get("content", None):
+                    if not d.get(self.content_key, None):
                         if not self.empty_warning:
                             self.empty_warning = True
                             logger.warning("Found document without content, skipping.")