Skip to content

Commit

Permalink
fix num_workers in TinyStories
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt authored Apr 2, 2024
1 parent 449eb29 commit 7b06c0a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions litgpt/data/tinystories.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ class TinyStories(DataModule):
which are the output of the preprocessing step."""
seed: int = 42
"""The seed to use for shuffling the dataset."""
num_workers: int = 8
"""The number of workers to use for the dataloaders."""
num_workers: Optional[int] = None,
"""The number of workers to use for the dataloaders.
Sets the number of workers equal to the number of available CPUs by default."""

tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
Expand All @@ -53,7 +54,8 @@ def prepare_data(self) -> None:
assert len(files) > 1, f"Expected at least two json files in {files}"
# train/test split. let's use only shard 0 for test split, rest train
val_file, *train_files = files
num_workers = os.cpu_count() - 1
if None:
num_workers = os.cpu_count() - 1

if not Path(self.data_path_train).is_dir():
optimize(
Expand Down

0 comments on commit 7b06c0a

Please sign in to comment.