diff --git a/litgpt/data/lit_data.py b/litgpt/data/lit_data.py
index fefe323870..8347215fbd 100644
--- a/litgpt/data/lit_data.py
+++ b/litgpt/data/lit_data.py
@@ -52,7 +52,11 @@ def _dataloader(self, input_dir: str, train: bool):
         from litdata.streaming import StreamingDataset, TokensLoader
 
         dataset = StreamingDataset(
-            input_dir=input_dir, item_loader=TokensLoader(block_size=self.seq_length), shuffle=train, drop_last=True, seed=self.seed
+            input_dir=input_dir,
+            item_loader=TokensLoader(block_size=self.seq_length),
+            shuffle=train,
+            drop_last=True,
+            seed=self.seed,
         )
         dataloader = DataLoader(
             dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
diff --git a/tests/data/test_lit_data.py b/tests/data/test_lit_data.py
index e5c1ea1716..a2c221c119 100644
--- a/tests/data/test_lit_data.py
+++ b/tests/data/test_lit_data.py
@@ -1,6 +1,7 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 import sys
 from unittest import mock
+from unittest.mock import ANY
 
 import pytest
 
@@ -34,3 +35,17 @@ def test_input_dir_and_splits(dl_mock, tmp_path):
     dl_mock.assert_called_with(input_dir=str("s3://mydataset/data/train"), train=True)
     data.val_dataloader()
     dl_mock.assert_called_with(input_dir=str("s3://mydataset/data/val"), train=False)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Needs to implement platform agnostic path/url joining")
+@mock.patch("litdata.streaming.StreamingDataset")
+def test_dataset_args(streaming_dataset_mock, tmp_path):
+    data = LitData(data_path=tmp_path, seed=1000)
+    data.train_dataloader()
+    streaming_dataset_mock.assert_called_with(
+        input_dir=str(tmp_path),
+        item_loader=ANY,
+        shuffle=True,
+        drop_last=True,
+        seed=1000,
+    )