litdata backed tinystories #1186

Merged · 6 commits · Mar 27, 2024
213 changes: 86 additions & 127 deletions litgpt/data/tinystories.py
@@ -1,23 +1,18 @@
"""https://github.com/karpathy/llama2.c/blob/b3c4b6/tinystories.py"""

# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import glob
import json
import os
import random
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data import DataLoader
from tqdm import tqdm

from litgpt import Tokenizer
from litgpt.data import DataModule
from litgpt.data.alpaca import download_if_missing
from litgpt.data.base import DataModule
from litgpt.tokenizer import Tokenizer


@dataclass
@@ -27,55 +22,97 @@ class TinyStories(DataModule):
Provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length.
"""

path: Path = Path("data/")
"""Path to the data directory where data will be downloaded and preprocessed"""
num_workers: int = 0
"""How many DataLoader processes to use for loading."""
data_path: Path = Path("data/tinystories")
"""The path to the data directory, containing two folders 'train' and 'val'
which are the output of the preprocessing step."""
seed: int = 42
"""The random seed for creating the train/val splits and shuffling the dataset."""
"""The seed to use for shuffling the dataset."""
num_workers: int = 8
"""The number of workers to use for the dataloaders."""

tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)
train_dataset: Optional[torch.utils.data.Dataset] = field(default=None, init=False, repr=False)
test_dataset: Optional[torch.utils.data.Dataset] = field(default=None, init=False, repr=False)

def __post_init__(self) -> None:
self.data_path_train = self.data_path / "train"
self.data_path_val = self.data_path / "val"

def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = max_seq_length
self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well
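The extra token matters because every block served by the dataloaders below holds max_seq_length + 1 tokens. A minimal sketch of how a downstream training step would typically split such a block into inputs and shifted targets; this consumption is not part of the diff itself:

# hedged sketch: consuming one block of max_seq_length + 1 tokens in a training step
batch = next(iter(datamodule.train_dataloader()))  # shape: (batch_size, max_seq_length + 1)
input_ids = batch[:, :-1].contiguous()  # tokens the model conditions on
targets = batch[:, 1:].contiguous()     # the same tokens shifted left by one position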

def prepare_data(self) -> None:
download(self.path)
assert self.tokenizer is not None
pretokenize(self.path, self.tokenizer)

def setup(self, stage: str = "") -> None:
# the .bin files are right along the .json files
bin_dir = self.path / "TinyStories_all_data"
shard_filenames = sorted(glob.glob(str(bin_dir / "*.bin")))
assert len(shard_filenames) > 0, f"No bin files found in {bin_dir}"
assert len(shard_filenames) > 1, f"Expected at least two bins in {bin_dir}"
from litdata import optimize

download(self.data_path)

files = sorted(glob.glob(str(self.data_path / "TinyStories_all_data" / "*.json")))
assert len(files) > 0, f"No json files found in {files}"
assert len(files) > 1, f"Expected at least two json files in {files}"
# train/test split. let's use only shard 0 for test split, rest train
va_files, *train_files = shard_filenames
# shuffle the training files
random.Random(self.seed).shuffle(train_files)
self.train_dataset = ConcatDataset([PretokDataset(f, self.max_seq_length) for f in train_files])
self.val_dataset = PretokDataset(shard_filenames[0], self.max_seq_length)
val_files, *train_files = files
num_workers = os.cpu_count() - 1

if not Path(self.data_path_train).is_dir():
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=train_files,
output_dir=str(self.data_path_train),
num_workers=num_workers,
chunk_bytes="200MB",
)
if not Path(self.data_path_val).is_dir():
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=val_files,
output_dir=str(self.data_path_val),
num_workers=num_workers,
chunk_bytes="200MB",
)
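Note that both optimize calls are guarded by an is_dir check, so the tokenization pass is skipped whenever the output folders already exist. A small sketch for forcing a rebuild by removing them; the paths follow the __post_init__ defaults above:

import shutil
from pathlib import Path

# the optimize() calls above are skipped when these folders exist,
# so deleting them triggers a fresh tokenization pass on the next run
for split_dir in (Path("data/tinystories/train"), Path("data/tinystories/val")):
    if split_dir.is_dir():
        shutil.rmtree(split_dir)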

def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True, num_workers=self.num_workers
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader

train_dataset = StreamingDataset(
input_dir=str(self.data_path_train),
item_loader=TokensLoader(block_size=self.max_seq_length),
shuffle=True,
drop_last=True,
)
train_dataloader = StreamingDataLoader(
train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader

def val_dataloader(self) -> DataLoader:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
pin_memory=True,
shuffle=True, # llama2.c shuffles validation too
num_workers=self.num_workers,
from litdata.streaming import StreamingDataset, TokensLoader

val_dataset = StreamingDataset(
input_dir=str(self.data_path_val),
item_loader=TokensLoader(block_size=self.max_seq_length),
shuffle=True,
# Consider setting to False, but we would lose some samples due to truncation when world size > 1
drop_last=True,
)
val_dataloader = DataLoader(
val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return val_dataloader
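Because TokensLoader serves contiguous fixed-size blocks out of the flattened token stream, the optimized splits can also be inspected directly, outside the datamodule. A minimal sketch, assuming the preprocessing has already populated data/tinystories/val and using an illustrative block size of 513 (max_seq_length 512 plus one):

from litdata.streaming import StreamingDataset, TokensLoader

# each item is a 1-D tensor of exactly block_size tokens taken from the flattened stream
val_dataset = StreamingDataset(
    input_dir="data/tinystories/val",
    item_loader=TokensLoader(block_size=513),
)
print(len(val_dataset), val_dataset[0].shape)  # number of blocks, torch.Size([513])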


def tokenize(filename: str, tokenizer: Tokenizer):
with open(filename, "r") as f:
data = json.load(f)
global_rank = int(os.environ["DATA_OPTIMIZER_GLOBAL_RANK"])
num_workers = int(os.environ["DATA_OPTIMIZER_NUM_WORKERS"])
local_rank = global_rank % num_workers
for example in tqdm(data, position=local_rank):
text = example["story"]
text = text.strip() # get rid of leading/trailing whitespace
tokens = tokenizer.encode(text, bos=True, eos=False) # encode the text, use BOS
yield tokens
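The tokenize generator yields one BOS-prefixed token sequence per story and is normally driven by litdata's optimize, which provides the two DATA_OPTIMIZER_* variables to each worker. A sketch for exercising it standalone, mirroring what the updated test does; the shard path and tokenizer checkpoint directory are placeholders:

import os
from litgpt.tokenizer import Tokenizer
from litgpt.data.tinystories import tokenize

# set manually what litdata's optimize() would normally export for each worker
os.environ["DATA_OPTIMIZER_GLOBAL_RANK"] = "0"
os.environ["DATA_OPTIMIZER_NUM_WORKERS"] = "1"

tokenizer = Tokenizer("checkpoints/EleutherAI/pythia-14m")  # placeholder checkpoint dir
for tokens in tokenize("data/tinystories/TinyStories_all_data/data00.json", tokenizer):
    print(len(tokens))  # one tokenized, BOS-prefixed story per yield
    break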


_URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
@@ -84,98 +121,20 @@ def val_dataloader(self) -> DataLoader:
def download(data_dir: Path):
data_dir.mkdir(exist_ok=True)

data_dir = data_dir / "TinyStories_all_data"
shard_filenames = sorted(glob.glob(str(data_dir / "*.json")))
if shard_filenames:
print(f"{data_dir} already exists, skipping unpacking...")
return

# download the TinyStories dataset, unless it's already downloaded
data_filename = data_dir / "TinyStories_all_data.tar.gz"
download_if_missing(data_filename, _URL, stream=True, mode="wb")
print("Download done.")

# unpack the tar.gz file into all the data shards (json files)
data_dir = data_dir / "TinyStories_all_data"
data_dir.mkdir(exist_ok=True)
print(f"Unpacking {data_filename}...")
os.system(f"tar -xzf {data_filename} -C {data_dir}")
shard_filenames = sorted(glob.glob(str(data_dir / "*.json")))
if shard_filenames:
print(f"{data_dir} already exists, skipping unpacking...")
else:
data_dir.mkdir(exist_ok=True)
print(f"Unpacking {data_filename}...")
os.system(f"tar -xzf {data_filename} -C {data_dir}")
shard_filenames = sorted(glob.glob(str(data_dir / "*.json")))

print(f"Number of shards: {len(shard_filenames)}")
# print a single example just for debugging and such
# with open(shard_filenames[0], "r") as f:
# data = json.load(f)
# print(f"Example story:\n{data[0]}")


def process_shard(args, tokenizer):
shard_id, shard = args
with open(shard, "r") as f:
data = json.load(f)
all_tokens = []
for example in tqdm(data, position=shard_id):
text = example["story"]
text = text.strip() # get rid of leading/trailing whitespace
tokens = tokenizer.encode(text, bos=True, eos=False) # encode the text, use BOS
all_tokens.extend(tokens)
# convert to uint16 nparray
all_tokens = np.array(all_tokens, dtype=np.uint16)
# just save the tokenized file in the same dir
tokenized_filename = shard.replace(".json", ".bin")
# write the bytes
with open(tokenized_filename, "wb") as f:
f.write(all_tokens.tobytes())
# calculate the average sequence length (they are separated by BOS=1)
bos_id = tokenizer.bos_id
assert bos_id >= 0 # uint16 is unsigned
bos_tokens = (all_tokens == tokenizer.bos_id).sum()
assert bos_tokens > 0
avg_seq_len = all_tokens.size / bos_tokens
print(
f"Saved {tokenized_filename}, tokens: {all_tokens.size}, bos: {bos_tokens}, average seqlen: {avg_seq_len:.2f}"
)


def pretokenize(data_dir: Path, tokenizer: Tokenizer):
bins_path = str(data_dir / "TinyStories_all_data" / "*.bin")
shard_filenames = sorted(glob.glob(bins_path))
if shard_filenames:
print("Already pretokenized.")
return
# iterate the shards and tokenize all of them one by one
jsons_path = str(data_dir / "TinyStories_all_data" / "*.json")
shard_filenames = sorted(glob.glob(jsons_path))
if not shard_filenames:
raise ValueError(f"No json files found in {jsons_path!r}. Did you run `python tinystories.py download`?")
# process all the shards in a process pool
fun = partial(process_shard, tokenizer=tokenizer)
with ProcessPoolExecutor() as executor:
executor.map(fun, enumerate(shard_filenames))
print("Done.")


class PretokDataset(torch.utils.data.Dataset):
"""Loads a pre-tokenized array from disk and returns chunks of `max_seq_length` length."""

def __init__(self, filepath: str, max_seq_len: int):
super().__init__()
self.filepath = filepath
# open the dataset for reading but keep it on disk with memmap
self.shard = np.memmap(filepath, dtype=np.uint16, mode="r")
self.shard_length = len(self.shard)
self.length = self.shard_length // max_seq_len
assert max_seq_len > 1
self.max_seq_len = max_seq_len

def __len__(self) -> int:
return self.length

def __getitem__(self, ix: int) -> torch.Tensor:
if ix < 0:
raise NotImplementedError
start = ix * self.max_seq_len
end = start + self.max_seq_len + 1
if end > self.shard_length:
raise IndexError
# calling .astype will copy the data into a new numpy array, now in RAM
chunk = torch.from_numpy((self.shard[start:end]).astype(np.int64))
return chunk
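Taken together, the reworked module is driven like the other litgpt data modules. A rough end-to-end sketch; the tokenizer checkpoint directory is a placeholder and the sequence length is illustrative:

from pathlib import Path

from litgpt.data.tinystories import TinyStories
from litgpt.tokenizer import Tokenizer

datamodule = TinyStories(data_path=Path("data/tinystories"), num_workers=4)
datamodule.connect(tokenizer=Tokenizer("checkpoints/EleutherAI/pythia-14m"), batch_size=8, max_seq_length=512)
datamodule.prepare_data()  # download and litdata preprocessing, skipped if already done
datamodule.setup()
batch = next(iter(datamodule.train_dataloader()))
print(batch.shape)  # (8, 513): 512 input positions plus one extra token for the shifted targets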
86 changes: 46 additions & 40 deletions tests/data/test_tinystories.py
@@ -1,44 +1,45 @@
import json
from contextlib import redirect_stdout
from io import StringIO

import numpy as np
import pytest
import torch
from litdata import optimize
from litdata.streaming import StreamingDataset, TokensLoader
from torch.utils._pytree import tree_map
from torch.utils.data import ConcatDataset

from litgpt.data.tinystories import PretokDataset, TinyStories, process_shard

def fake_chunk(path, data):
def fn(_):
for story in data:
yield torch.tensor(story)

def fake_bin(tmp_path, data, name):
all_tokens = np.array(data, dtype=np.uint16)
data_path = tmp_path / f"{name}.bin"
with open(data_path, "wb") as f:
f.write(all_tokens.tobytes())
return data_path
optimize(fn=fn, inputs=[None] * len(data), output_dir=str(path), num_workers=1, chunk_bytes="200MB")


@pytest.mark.xfail(raises=IndexError, strict=True) # requires https://github.com/Lightning-AI/litdata/pull/77
@pytest.mark.parametrize(
("max_seq_len", "expected"),
[
(2, [[0, 23, 15], [15, 63, 0], [0, 73, 5], [5, 0, 1], [1, 1999, 0]]),
(5, [[0, 23, 15, 63, 0, 73], [73, 5, 0, 1, 1999, 0]]),
(2, [[0, 23, 15], [63, 0, 73], [5, 0, 1], [1999, 0, 13]]),
(5, [[0, 23, 15, 63, 0, 73], [5, 0, 1, 1999, 0, 13]]),
(6, [[0, 23, 15, 63, 0, 73, 5]]),
(7, [[0, 23, 15, 63, 0, 73, 5, 0]]),
],
)
def test_pretok_dataset(tmp_path, max_seq_len, expected):
fake_data = [0, 23, 15, 63, 0, 73, 5, 0, 1, 1999, 0, 13]
assert len(fake_data) == 12
bin_path = fake_bin(tmp_path, fake_data, "data")
fake_chunk(tmp_path, [fake_data])

dataset = PretokDataset(str(bin_path), max_seq_len)
dataset = StreamingDataset(
input_dir=str(tmp_path), item_loader=TokensLoader(block_size=max_seq_len + 1), shuffle=False, drop_last=False
)
actual = tree_map(torch.Tensor.tolist, list(dataset))
assert actual == expected


def test_process_shard(tmp_path):
def test_tokenize(tmp_path, monkeypatch):
from litgpt.data.tinystories import tokenize

story1, story2 = "foo bar", " fun "
data = [{"story": story1}, {"story": story2}]
shard_path = tmp_path / "data.json"
@@ -53,38 +54,43 @@ def encode(self, text, bos, eos):
assert not eos
return [self.bos_id] + [ord(c) for c in text]

out = StringIO()
with redirect_stdout(out):
process_shard((0, str(shard_path)), Tokenizer())

text = out.getvalue()
assert text.endswith("data.bin, tokens: 12, bos: 2, average seqlen: 6.00\n")
assert shard_path.with_suffix(".bin").exists()
monkeypatch.setenv("DATA_OPTIMIZER_GLOBAL_RANK", "0")
monkeypatch.setenv("DATA_OPTIMIZER_NUM_WORKERS", "1")
data = tokenize(str(shard_path), Tokenizer())
assert list(data) == [[0, 102, 111, 111, 32, 98, 97, 114], [0, 102, 117, 110]]


def test_tinystories_datamodule(tmp_path):
datamodule = TinyStories(tmp_path, seed=42)
datamodule.connect(max_seq_length=2)
from litgpt.data.tinystories import TinyStories

data_dir = tmp_path / "TinyStories_all_data"
data_dir.mkdir()
fake_bin(data_dir, [12], "0")
fake_bin(data_dir, [0, 23, 15, 63, 0], "1")
fake_bin(data_dir, [73, 5, 0, 1, 1999, 0, 13], "2")
data_dir = tmp_path / "tinystories"

datamodule.setup()
datamodule = TinyStories(data_dir, seed=42)
datamodule.connect(max_seq_length=2)

assert isinstance(datamodule.train_dataset, ConcatDataset)
assert len(datamodule.train_dataset.datasets) == 2
assert isinstance(datamodule.train_dataset.datasets[0], PretokDataset)
# unordered because it shuffled
assert datamodule.train_dataset.datasets[0].filepath == str(data_dir / "2.bin")
assert datamodule.train_dataset.datasets[1].filepath == str(data_dir / "1.bin")
# simulate `datamodule.prepare_data`
train_data_dir = data_dir / "train"
train_data_dir.mkdir(parents=True)
fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]])

assert isinstance(datamodule.val_dataset, PretokDataset)
assert datamodule.val_dataset.filepath == str(data_dir / "0.bin")
datamodule.setup()

tr_dataloader = datamodule.train_dataloader()
torch.manual_seed(0)
actual = tree_map(torch.Tensor.tolist, list(tr_dataloader))
assert actual == [[[0, 1, 1999]], [[15, 63, 0]], [[1999, 0, 13]], [[0, 23, 15]], [[73, 5, 0]]]
# there is 1 sample per index in the data (13)
assert actual == [
[[1999, 0, 13]],
[[0, 13, 12]],
[[1, 1999, 0]],
[[63, 0, 73]],
[[5, 0, 1]],
[[0, 73, 5]],
[[0, 23, 15]],
[[0, 1, 1999]],
[[15, 63, 0]],
[[73, 5, 0]],
[[12, 0, 23]],
[[23, 15, 63]],
[[13, 12, 0]],
]