Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed May 20, 2024
1 parent 5600d5a commit b62f5f6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
6 changes: 5 additions & 1 deletion litgpt/data/lit_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,11 @@ def _dataloader(self, input_dir: str, train: bool):
from litdata.streaming import StreamingDataset, TokensLoader

dataset = StreamingDataset(
input_dir=input_dir, item_loader=TokensLoader(block_size=self.seq_length), shuffle=train, drop_last=True, seed=self.seed
input_dir=input_dir,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=train,
drop_last=True,
seed=self.seed,
)
dataloader = DataLoader(
dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
Expand Down
15 changes: 15 additions & 0 deletions tests/data/test_lit_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import sys
from unittest import mock
from unittest.mock import ANY

import pytest

Expand Down Expand Up @@ -34,3 +35,17 @@ def test_input_dir_and_splits(dl_mock, tmp_path):
dl_mock.assert_called_with(input_dir=str("s3://mydataset/data/train"), train=True)
data.val_dataloader()
dl_mock.assert_called_with(input_dir=str("s3://mydataset/data/val"), train=False)


@pytest.mark.skipif(sys.platform == "win32", reason="Needs to implement platform agnostic path/url joining")
@mock.patch("litdata.streaming.StreamingDataset")
def test_dataset_args(streaming_dataset_mock, tmp_path):
    """Verify the keyword arguments used to build the training ``StreamingDataset``.

    ``litdata.streaming.StreamingDataset`` is patched with a mock, so the test
    only inspects the arguments of the most recent constructor call made while
    ``train_dataloader()`` runs.
    """
    seed = 1000
    expected_kwargs = {
        "input_dir": str(tmp_path),
        # The item loader instance is not compared; any value is accepted.
        "item_loader": ANY,
        "shuffle": True,
        "drop_last": True,
        "seed": seed,
    }

    LitData(data_path=tmp_path, seed=seed).train_dataloader()

    streaming_dataset_mock.assert_called_with(**expected_kwargs)

0 comments on commit b62f5f6

Please sign in to comment.