From 5214d1b56edac4b171e48edf8833cbf3ab038d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sun, 24 Mar 2024 21:36:12 +0100 Subject: [PATCH 1/5] litdata backed tinystories --- litgpt/data/tinystories.py | 219 ++++++++++++++------------------- tests/data/test_tinystories.py | 81 ++++++------ 2 files changed, 128 insertions(+), 172 deletions(-) diff --git a/litgpt/data/tinystories.py b/litgpt/data/tinystories.py index 40ab0a40ff..d3641a94c5 100644 --- a/litgpt/data/tinystories.py +++ b/litgpt/data/tinystories.py @@ -1,23 +1,20 @@ -"""https://github.com/karpathy/llama2.c/blob/b3c4b6/tinystories.py""" - +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import glob import json import os -import random -from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import Optional -import numpy as np -import torch -from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.data import DataLoader from tqdm import tqdm +from litgpt import Tokenizer +from litgpt.data import DataModule from litgpt.data.alpaca import download_if_missing -from litgpt.data.base import DataModule -from litgpt.tokenizer import Tokenizer + +_URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz" @dataclass @@ -27,155 +24,119 @@ class TinyStories(DataModule): Provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length. """ - path: Path = Path("data/") - """Path to the data directory where data will be downloaded and preprocessed""" - num_workers: int = 0 - """How many DataLoader processes to use for loading.""" + data_path: Path = Path("data/tinystories") + """The path to the data directory, containing two folders 'train' and 'val' + which are the output of the preprocessing step. 
The path can also be a remote path (e.g., s3://).""" seed: int = 42 - """The random seed for creating the train/val splits and shuffling the dataset.""" + """The seed to use for shuffling the dataset.""" + num_workers: int = 8 + """The number of workers to use for the dataloaders.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) - train_dataset: Optional[torch.utils.data.Dataset] = field(default=None, init=False, repr=False) - test_dataset: Optional[torch.utils.data.Dataset] = field(default=None, init=False, repr=False) - def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: + def __post_init__(self) -> None: + # Could be a remote path (s3://) or a local path + self.data_path_train = str(self.data_path).rstrip("/") + "/train" + self.data_path_val = str(self.data_path).rstrip("/") + "/val" + + def connect( + self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1 + ) -> None: self.tokenizer = tokenizer self.batch_size = batch_size - self.max_seq_length = max_seq_length + self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well def prepare_data(self) -> None: - download(self.path) - assert self.tokenizer is not None - pretokenize(self.path, self.tokenizer) - - def setup(self, stage: str = "") -> None: - # the .bin files are right along the .json files - bin_dir = self.path / "TinyStories_all_data" - shard_filenames = sorted(glob.glob(str(bin_dir / "*.bin"))) - assert len(shard_filenames) > 0, f"No bin files found in {bin_dir}" - assert len(shard_filenames) > 1, f"Expected at least two bins in {bin_dir}" + from litdata import optimize + + download(self.data_path) + + files = sorted(glob.glob(str(self.data_path / "TinyStories_all_data" / "*.json"))) + assert len(files) > 0, f"No json files found in {files}" + assert len(files) > 1, f"Expected at least two json files in {files}" # train/test split. 
let's use only shard 0 for test split, rest train - va_files, *train_files = shard_filenames - # shuffle the training files - random.Random(self.seed).shuffle(train_files) - self.train_dataset = ConcatDataset([PretokDataset(f, self.max_seq_length) for f in train_files]) - self.val_dataset = PretokDataset(shard_filenames[0], self.max_seq_length) + val_files, *train_files = files + num_workers = os.cpu_count() - 1 + + if not Path(self.data_path_train).is_dir(): + optimize( + fn=partial(tokenize, tokenizer=self.tokenizer), + inputs=train_files, + output_dir=self.data_path_train, + num_workers=num_workers, + chunk_bytes="200MB", + ) + if not Path(self.data_path_val).is_dir(): + optimize( + fn=partial(tokenize, tokenizer=self.tokenizer), + inputs=val_files, + output_dir=self.data_path_val, + num_workers=num_workers, + chunk_bytes="200MB", + ) def train_dataloader(self) -> DataLoader: - return DataLoader( - self.train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True, num_workers=self.num_workers + from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader + + train_dataset = StreamingDataset( + input_dir=self.data_path_train, + item_loader=TokensLoader(block_size=self.max_seq_length), + shuffle=True, + drop_last=True, ) + train_dataloader = StreamingDataLoader( + train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True + ) + return train_dataloader def val_dataloader(self) -> DataLoader: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - pin_memory=True, - shuffle=True, # llama2.c shuffles validation too - num_workers=self.num_workers, + from litdata.streaming import StreamingDataset, TokensLoader + + val_dataset = StreamingDataset( + input_dir=self.data_path_val, + item_loader=TokensLoader(block_size=self.max_seq_length), + shuffle=True, + # Consider setting to False, but we would lose some samples due to truncation when world size > 1 + drop_last=True, + ) + val_dataloader = DataLoader( + val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) + return val_dataloader -_URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz" +def tokenize(filename: str, tokenizer: Tokenizer): + with open(filename, "r") as f: + data = json.load(f) + global_rank = int(os.environ["DATA_OPTIMIZER_GLOBAL_RANK"]) + num_workers = int(os.environ["DATA_OPTIMIZER_NUM_WORKERS"]) + local_rank = global_rank % num_workers + for example in tqdm(data, position=local_rank): + text = example["story"] + text = text.strip() # get rid of leading/trailing whitespace + tokens = tokenizer.encode(text, bos=True, eos=False) # encode the text, use BOS + yield tokens def download(data_dir: Path): data_dir.mkdir(exist_ok=True) + data_dir = data_dir / "TinyStories_all_data" + shard_filenames = sorted(glob.glob(str(data_dir / "*.json"))) + if shard_filenames: + print(f"{data_dir} already exists, skipping unpacking...") + return + # download the TinyStories dataset, unless it's already downloaded data_filename = data_dir / "TinyStories_all_data.tar.gz" download_if_missing(data_filename, _URL, stream=True, mode="wb") print("Download done.") # unpack the tar.gz file into all the data shards (json files) - data_dir = data_dir / "TinyStories_all_data" + data_dir.mkdir(exist_ok=True) + print(f"Unpacking {data_filename}...") + os.system(f"tar -xzf {data_filename} -C {data_dir}") shard_filenames = sorted(glob.glob(str(data_dir / "*.json"))) - if 
shard_filenames: - print(f"{data_dir} already exists, skipping unpacking...") - else: - data_dir.mkdir(exist_ok=True) - print(f"Unpacking {data_filename}...") - os.system(f"tar -xzf {data_filename} -C {data_dir}") - shard_filenames = sorted(glob.glob(str(data_dir / "*.json"))) - print(f"Number of shards: {len(shard_filenames)}") - # print a single example just for debugging and such - # with open(shard_filenames[0], "r") as f: - # data = json.load(f) - # print(f"Example story:\n{data[0]}") - - -def process_shard(args, tokenizer): - shard_id, shard = args - with open(shard, "r") as f: - data = json.load(f) - all_tokens = [] - for example in tqdm(data, position=shard_id): - text = example["story"] - text = text.strip() # get rid of leading/trailing whitespace - tokens = tokenizer.encode(text, bos=True, eos=False) # encode the text, use BOS - all_tokens.extend(tokens) - # convert to uint16 nparray - all_tokens = np.array(all_tokens, dtype=np.uint16) - # just save the tokenized file in the same dir - tokenized_filename = shard.replace(".json", ".bin") - # write the bytes - with open(tokenized_filename, "wb") as f: - f.write(all_tokens.tobytes()) - # calculate the average sequence length (they are separated by BOS=1) - bos_id = tokenizer.bos_id - assert bos_id >= 0 # uint16 is unsigned - bos_tokens = (all_tokens == tokenizer.bos_id).sum() - assert bos_tokens > 0 - avg_seq_len = all_tokens.size / bos_tokens - print( - f"Saved {tokenized_filename}, tokens: {all_tokens.size}, bos: {bos_tokens}, average seqlen: {avg_seq_len:.2f}" - ) - - -def pretokenize(data_dir: Path, tokenizer: Tokenizer): - bins_path = str(data_dir / "TinyStories_all_data" / "*.bin") - shard_filenames = sorted(glob.glob(bins_path)) - if shard_filenames: - print("Already pretokenized.") - return - # iterate the shards and tokenize all of them one by one - jsons_path = str(data_dir / "TinyStories_all_data" / "*.json") - shard_filenames = sorted(glob.glob(jsons_path)) - if not shard_filenames: - raise ValueError(f"No json files found in {jsons_path!r}. 
Did you run `python tinystories.py download`?") - # process all the shards in a process pool - fun = partial(process_shard, tokenizer=tokenizer) - with ProcessPoolExecutor() as executor: - executor.map(fun, enumerate(shard_filenames)) - print("Done.") - - -class PretokDataset(torch.utils.data.Dataset): - """Loads a pre-tokenized array from disk and returns chunks of `max_seq_length` length.""" - - def __init__(self, filepath: str, max_seq_len: int): - super().__init__() - self.filepath = filepath - # open the dataset for reading but keep it on disk with memmap - self.shard = np.memmap(filepath, dtype=np.uint16, mode="r") - self.shard_length = len(self.shard) - self.length = self.shard_length // max_seq_len - assert max_seq_len > 1 - self.max_seq_len = max_seq_len - - def __len__(self) -> int: - return self.length - - def __getitem__(self, ix: int) -> torch.Tensor: - if ix < 0: - raise NotImplementedError - start = ix * self.max_seq_len - end = start + self.max_seq_len + 1 - if end > self.shard_length: - raise IndexError - # calling .astype will copy the data into a new numpy array, now in RAM - chunk = torch.from_numpy((self.shard[start:end]).astype(np.int64)) - return chunk diff --git a/tests/data/test_tinystories.py b/tests/data/test_tinystories.py index bce34a797e..e78dc1c050 100644 --- a/tests/data/test_tinystories.py +++ b/tests/data/test_tinystories.py @@ -1,45 +1,51 @@ import json -from contextlib import redirect_stdout -from io import StringIO -import numpy as np import pytest import torch +from litdata import optimize +from litdata.streaming import StreamingDataset, TokensLoader from torch.utils._pytree import tree_map -from torch.utils.data import ConcatDataset -def fake_bin(tmp_path, data, name): - all_tokens = np.array(data, dtype=np.uint16) - data_path = tmp_path / f"{name}.bin" - with open(data_path, "wb") as f: - f.write(all_tokens.tobytes()) - return data_path +def fake_chunk(path, data): + def fn(_): + for story in data: + yield torch.tensor(story) + optimize( + fn=fn, + inputs=[None] * len(data), + output_dir=str(path), + num_workers=1, + chunk_bytes="200MB", + ) @pytest.mark.parametrize( ("max_seq_len", "expected"), [ - (2, [[0, 23, 15], [15, 63, 0], [0, 73, 5], [5, 0, 1], [1, 1999, 0]]), - (5, [[0, 23, 15, 63, 0, 73], [73, 5, 0, 1, 1999, 0]]), + (2, [[0, 23, 15], [63, 0, 73], [5, 0, 1], [1999, 0, 13]]), + (5, [[0, 23, 15, 63, 0, 73], [5, 0, 1, 1999, 0, 13]]), (6, [[0, 23, 15, 63, 0, 73, 5]]), (7, [[0, 23, 15, 63, 0, 73, 5, 0]]), ], ) def test_pretok_dataset(tmp_path, max_seq_len, expected): - from litgpt.data.tinystories import PretokDataset - fake_data = [0, 23, 15, 63, 0, 73, 5, 0, 1, 1999, 0, 13] assert len(fake_data) == 12 - bin_path = fake_bin(tmp_path, fake_data, "data") - - dataset = PretokDataset(str(bin_path), max_seq_len) + fake_chunk(tmp_path, [fake_data]) + + dataset = StreamingDataset( + input_dir=str(tmp_path), + item_loader=TokensLoader(block_size=max_seq_len + 1), + shuffle=False, + drop_last=False, + ) actual = tree_map(torch.Tensor.tolist, list(dataset)) assert actual == expected -def test_process_shard(tmp_path): - from litgpt.data.tinystories import process_shard +def test_tokenize(tmp_path, monkeypatch): + from litgpt.data.tinystories import tokenize story1, story2 = "foo bar", " fun " data = [{"story": story1}, {"story": story2}] @@ -55,40 +61,29 @@ def encode(self, text, bos, eos): assert not eos return [self.bos_id] + [ord(c) for c in text] - out = StringIO() - with redirect_stdout(out): - process_shard((0, str(shard_path)), Tokenizer()) + 
monkeypatch.setenv("DATA_OPTIMIZER_GLOBAL_RANK", "0") + monkeypatch.setenv("DATA_OPTIMIZER_NUM_WORKERS", "1") + data = tokenize(str(shard_path), Tokenizer()) + assert list(data) == [[0, 102, 111, 111, 32, 98, 97, 114], [0, 102, 117, 110]] - text = out.getvalue() - assert text.endswith("data.bin, tokens: 12, bos: 2, average seqlen: 6.00\n") - assert shard_path.with_suffix(".bin").exists() def test_tinystories_datamodule(tmp_path): - from litgpt.data.tinystories import PretokDataset, TinyStories + from litgpt.data.tinystories import TinyStories + data_dir = tmp_path / "tinystories" - datamodule = TinyStories(tmp_path, seed=42) + datamodule = TinyStories(data_dir, seed=42) datamodule.connect(max_seq_length=2) - data_dir = tmp_path / "TinyStories_all_data" - data_dir.mkdir() - fake_bin(data_dir, [12], "0") - fake_bin(data_dir, [0, 23, 15, 63, 0], "1") - fake_bin(data_dir, [73, 5, 0, 1, 1999, 0, 13], "2") + # simulate `datamodule.prepare_data` + train_data_dir = data_dir / "train" + train_data_dir.mkdir(parents=True) + fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]]) datamodule.setup() - assert isinstance(datamodule.train_dataset, ConcatDataset) - assert len(datamodule.train_dataset.datasets) == 2 - assert isinstance(datamodule.train_dataset.datasets[0], PretokDataset) - # unordered because it shuffled - assert datamodule.train_dataset.datasets[0].filepath == str(data_dir / "2.bin") - assert datamodule.train_dataset.datasets[1].filepath == str(data_dir / "1.bin") - - assert isinstance(datamodule.val_dataset, PretokDataset) - assert datamodule.val_dataset.filepath == str(data_dir / "0.bin") - tr_dataloader = datamodule.train_dataloader() torch.manual_seed(0) actual = tree_map(torch.Tensor.tolist, list(tr_dataloader)) - assert actual == [[[0, 1, 1999]], [[15, 63, 0]], [[1999, 0, 13]], [[0, 23, 15]], [[73, 5, 0]]] + # there is 1 sample per index in the data (13) + assert actual == [[[1999, 0, 13]], [[0, 13, 12]], [[1, 1999, 0]], [[63, 0, 73]], [[5, 0, 1]], [[0, 73, 5]], [[0, 23, 15]], [[0, 1, 1999]], [[15, 63, 0]], [[73, 5, 0]], [[12, 0, 23]], [[23, 15, 63]], [[13, 12, 0]]] From dc4fcc2d76e16dd66980d7ccde5879a8de62cb38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sun, 24 Mar 2024 21:40:56 +0100 Subject: [PATCH 2/5] Fmt --- litgpt/data/tinystories.py | 4 +--- tests/data/test_tinystories.py | 32 +++++++++++++++++++------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/litgpt/data/tinystories.py b/litgpt/data/tinystories.py index d3641a94c5..8527f55992 100644 --- a/litgpt/data/tinystories.py +++ b/litgpt/data/tinystories.py @@ -41,9 +41,7 @@ def __post_init__(self) -> None: self.data_path_train = str(self.data_path).rstrip("/") + "/train" self.data_path_val = str(self.data_path).rstrip("/") + "/val" - def connect( - self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1 - ) -> None: + def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: self.tokenizer = tokenizer self.batch_size = batch_size self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well diff --git a/tests/data/test_tinystories.py b/tests/data/test_tinystories.py index e78dc1c050..3ddf8ccb18 100644 --- a/tests/data/test_tinystories.py +++ b/tests/data/test_tinystories.py @@ -11,13 +11,8 @@ def fake_chunk(path, data): def fn(_): for story in data: yield torch.tensor(story) - optimize( - fn=fn, - inputs=[None] * len(data), - 
output_dir=str(path), - num_workers=1, - chunk_bytes="200MB", - ) + + optimize(fn=fn, inputs=[None] * len(data), output_dir=str(path), num_workers=1, chunk_bytes="200MB") @pytest.mark.parametrize( @@ -35,10 +30,7 @@ def test_pretok_dataset(tmp_path, max_seq_len, expected): fake_chunk(tmp_path, [fake_data]) dataset = StreamingDataset( - input_dir=str(tmp_path), - item_loader=TokensLoader(block_size=max_seq_len + 1), - shuffle=False, - drop_last=False, + input_dir=str(tmp_path), item_loader=TokensLoader(block_size=max_seq_len + 1), shuffle=False, drop_last=False ) actual = tree_map(torch.Tensor.tolist, list(dataset)) assert actual == expected @@ -67,9 +59,9 @@ def encode(self, text, bos, eos): assert list(data) == [[0, 102, 111, 111, 32, 98, 97, 114], [0, 102, 117, 110]] - def test_tinystories_datamodule(tmp_path): from litgpt.data.tinystories import TinyStories + data_dir = tmp_path / "tinystories" datamodule = TinyStories(data_dir, seed=42) @@ -86,4 +78,18 @@ def test_tinystories_datamodule(tmp_path): torch.manual_seed(0) actual = tree_map(torch.Tensor.tolist, list(tr_dataloader)) # there is 1 sample per index in the data (13) - assert actual == [[[1999, 0, 13]], [[0, 13, 12]], [[1, 1999, 0]], [[63, 0, 73]], [[5, 0, 1]], [[0, 73, 5]], [[0, 23, 15]], [[0, 1, 1999]], [[15, 63, 0]], [[73, 5, 0]], [[12, 0, 23]], [[23, 15, 63]], [[13, 12, 0]]] + assert actual == [ + [[1999, 0, 13]], + [[0, 13, 12]], + [[1, 1999, 0]], + [[63, 0, 73]], + [[5, 0, 1]], + [[0, 73, 5]], + [[0, 23, 15]], + [[0, 1, 1999]], + [[15, 63, 0]], + [[73, 5, 0]], + [[12, 0, 23]], + [[23, 15, 63]], + [[13, 12, 0]], + ] From be629a6d811f7ffbcf787f2cbc4f4a51a9ea8701 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sun, 24 Mar 2024 21:57:48 +0100 Subject: [PATCH 3/5] Bad merge --- tests/data/test_tinystories.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/data/test_tinystories.py b/tests/data/test_tinystories.py index 9db8eb2a34..3ddf8ccb18 100644 --- a/tests/data/test_tinystories.py +++ b/tests/data/test_tinystories.py @@ -6,8 +6,6 @@ from litdata.streaming import StreamingDataset, TokensLoader from torch.utils._pytree import tree_map -from litgpt.data.tinystories import PretokDataset, TinyStories, process_shard - def fake_chunk(path, data): def fn(_): @@ -38,7 +36,7 @@ def test_pretok_dataset(tmp_path, max_seq_len, expected): assert actual == expected -def test_tokenize(tmp_path): +def test_tokenize(tmp_path, monkeypatch): from litgpt.data.tinystories import tokenize story1, story2 = "foo bar", " fun " From 85a667991a8d575be0c47b755674df62e96307de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 25 Mar 2024 16:20:16 +0100 Subject: [PATCH 4/5] Begone s3 --- litgpt/data/tinystories.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/litgpt/data/tinystories.py b/litgpt/data/tinystories.py index 8527f55992..b494f3e9ef 100644 --- a/litgpt/data/tinystories.py +++ b/litgpt/data/tinystories.py @@ -14,8 +14,6 @@ from litgpt.data import DataModule from litgpt.data.alpaca import download_if_missing -_URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz" - @dataclass class TinyStories(DataModule): @@ -26,7 +24,7 @@ class TinyStories(DataModule): data_path: Path = Path("data/tinystories") """The path to the data directory, containing two folders 'train' and 'val' - which are the output of the preprocessing step. 
The path can also be a remote path (e.g., s3://).""" + which are the output of the preprocessing step.""" seed: int = 42 """The seed to use for shuffling the dataset.""" num_workers: int = 8 @@ -37,9 +35,8 @@ class TinyStories(DataModule): max_seq_length: int = field(default=-1, init=False, repr=False) def __post_init__(self) -> None: - # Could be a remote path (s3://) or a local path - self.data_path_train = str(self.data_path).rstrip("/") + "/train" - self.data_path_val = str(self.data_path).rstrip("/") + "/val" + self.data_path_train = self.data_path / "train" + self.data_path_val = self.data_path / "val" def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: self.tokenizer = tokenizer @@ -62,7 +59,7 @@ def prepare_data(self) -> None: optimize( fn=partial(tokenize, tokenizer=self.tokenizer), inputs=train_files, - output_dir=self.data_path_train, + output_dir=str(self.data_path_train), num_workers=num_workers, chunk_bytes="200MB", ) @@ -70,7 +67,7 @@ def prepare_data(self) -> None: optimize( fn=partial(tokenize, tokenizer=self.tokenizer), inputs=val_files, - output_dir=self.data_path_val, + output_dir=str(self.data_path_val), num_workers=num_workers, chunk_bytes="200MB", ) @@ -79,7 +76,7 @@ def train_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader train_dataset = StreamingDataset( - input_dir=self.data_path_train, + input_dir=str(self.data_path_train), item_loader=TokensLoader(block_size=self.max_seq_length), shuffle=True, drop_last=True, @@ -93,7 +90,7 @@ def val_dataloader(self) -> DataLoader: from litdata.streaming import StreamingDataset, TokensLoader val_dataset = StreamingDataset( - input_dir=self.data_path_val, + input_dir=str(self.data_path_val), item_loader=TokensLoader(block_size=self.max_seq_length), shuffle=True, # Consider setting to False, but we would lose some samples due to truncation when world size > 1 @@ -118,6 +115,9 @@ def tokenize(filename: str, tokenizer: Tokenizer): yield tokens +_URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz" + + def download(data_dir: Path): data_dir.mkdir(exist_ok=True) From 7a9789743e980af579bf4324086c0608af954482 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Mar 2024 02:54:43 +0100 Subject: [PATCH 5/5] xfail --- tests/data/test_tinystories.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/data/test_tinystories.py b/tests/data/test_tinystories.py index 3ddf8ccb18..c67cc8e081 100644 --- a/tests/data/test_tinystories.py +++ b/tests/data/test_tinystories.py @@ -15,6 +15,7 @@ def fn(_): optimize(fn=fn, inputs=[None] * len(data), output_dir=str(path), num_workers=1, chunk_bytes="200MB") +@pytest.mark.xfail(raises=IndexError, strict=True) # requires https://github.com/Lightning-AI/litdata/pull/77 @pytest.mark.parametrize( ("max_seq_len", "expected"), [
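
A minimal usage sketch of the litdata-backed TinyStories module introduced by this series, assuming litdata is installed and following the imports and calls shown in the patches and tests above. The tokenizer checkpoint path, batch size, sequence length, and worker count below are illustrative assumptions, not values taken from the patches.

    from pathlib import Path

    from litgpt import Tokenizer
    from litgpt.data.tinystories import TinyStories

    # Illustrative values: any downloaded tokenizer checkpoint directory works here.
    tokenizer = Tokenizer(Path("checkpoints/EleutherAI/pythia-160m"))

    data = TinyStories(data_path=Path("data/tinystories"), num_workers=4)
    # connect() stores the tokenizer and adds 1 to max_seq_length so every streamed
    # sample also carries the extra token needed as the shifted training target.
    data.connect(tokenizer=tokenizer, batch_size=8, max_seq_length=512)

    # First run: downloads TinyStories_all_data.tar.gz, unpacks the json shards, and
    # tokenizes them via litdata.optimize into data/tinystories/train and .../val.
    data.prepare_data()
    data.setup()

    train_loader = data.train_dataloader()
    batch = next(iter(train_loader))      # token ids, shape (batch_size, 513)
    x, y = batch[:, :-1], batch[:, 1:]    # inputs and next-token targets

Subsequent runs skip the download and tokenization because prepare_data() only calls optimize when the train/ and val/ output directories do not yet exist.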