# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import glob
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tqdm import tqdm
from typing import Iterator, List, Optional

from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import DataModule


@dataclass
class TextFiles(DataModule):
    """The TextFile data module used for pretraining.

    Reads in text data from plaintext files contained in a data folder
    and provides training and validation dataloaders that return batches of tokens.
    Every sample is set to a fixed length.
    """

    train_data_path: Path
    """The path to the data directory used for training that contains .txt files"""
    val_data_path: Optional[Path] = None
    """The path to the data directory used for validation that
    contains .txt files. Splits off data for validation from the
    training set if None."""
    seed: int = 42
    """The seed to use for shuffling the dataset."""
    num_workers: int = 4
    """The number of workers to use for data loading."""

    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
    batch_size: int = field(default=1, init=False, repr=False)
    max_seq_length: int = field(default=-1, init=False, repr=False)

    def __post_init__(self) -> None:
        """Derive the output directories that hold the tokenized train/val chunks."""
        # Normalise both inputs to `Path` so that plain strings also work with `/`.
        self.train_data_path = Path(self.train_data_path)
        self.out_path_train = self.train_data_path / "train"
        if self.val_data_path is None:
            self.out_path_val = self.train_data_path / "val"
        else:
            self.val_data_path = Path(self.val_data_path)
            self.out_path_val = self.val_data_path / "val"

    def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:
        """Store the tokenizer, batch size and sequence length used by the other methods.

        Args:
            tokenizer: Tokenizer used by `prepare_data` to encode the text files.
            batch_size: Number of samples per batch returned by the dataloaders.
            max_seq_length: Requested sequence length; stored as ``max_seq_length + 1``.
        """
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length + 1  # Increase by one because we need the next token as well

    def prepare_data(self) -> None:
        """Tokenize the .txt files into `litdata` chunks under the train/val output directories.

        A split whose output directory already exists is skipped. When no
        validation directory was given, the first training file (in sorted
        order) is held out for validation and the remaining files are used
        for training.

        Raises:
            ValueError: If a data directory contains no .txt files, if fewer than
                two training files are available when a validation file has to be
                split off, or if tokenization is required but no tokenizer is set.
        """
        from litdata import optimize

        train_dir = Path(self.train_data_path)
        train_files = _list_txt_files(train_dir)
        if not train_files:
            raise ValueError(f"No .txt files found in train data directory {train_dir}")

        if self.val_data_path is not None:
            self.val_data_path = Path(self.val_data_path)
            val_files = _list_txt_files(self.val_data_path)
            if not val_files:
                raise ValueError(f"No .txt files found in validation data directory {self.val_data_path}")
        else:
            # train/val split: use only the first file for validation, the rest for training
            if len(train_files) < 2:
                raise ValueError(
                    f"Expected at least two .txt files in {train_dir} to split off validation data, "
                    f"found {len(train_files)}"
                )
            val_file, *train_files = train_files
            val_files = [val_file]

        # It's ok to use almost all CPUs here because this runs in a single process.
        # `os.cpu_count()` may return None, and a single-core host must still get one worker.
        max_workers = max(1, (os.cpu_count() or 1) - 1)
        for files, out_dir in ((train_files, self.out_path_train), (val_files, self.out_path_val)):
            if Path(out_dir).is_dir():
                continue  # output directory already exists, nothing to tokenize
            validate_tokenizer(self.tokenizer)
            optimize(
                fn=partial(tokenize, tokenizer=self.tokenizer),
                inputs=files,
                output_dir=str(out_dir),
                num_workers=min(max_workers, len(files)),
                chunk_bytes="50MB",
            )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled streaming dataloader over the tokenized training chunks."""
        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader

        train_dataset = StreamingDataset(
            input_dir=str(self.out_path_train),
            item_loader=TokensLoader(block_size=self.max_seq_length),
            shuffle=True,
            drop_last=True,
            seed=self.seed,  # make the documented `seed` field take effect
        )

        train_dataloader = StreamingDataLoader(
            train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
        )
        return train_dataloader

    def val_dataloader(self) -> DataLoader:
        """Return a dataloader over the tokenized validation chunks."""
        from litdata.streaming import StreamingDataset, TokensLoader

        val_dataset = StreamingDataset(
            input_dir=str(self.out_path_val),
            item_loader=TokensLoader(block_size=self.max_seq_length),
            shuffle=True,
            # Consider setting to False, but we would lose some samples due to truncation when world size > 1
            drop_last=True,
            seed=self.seed,
        )
        val_dataloader = DataLoader(
            val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
        )
        return val_dataloader


def _list_txt_files(directory: Path) -> List[str]:
    """Return the sorted paths of the .txt files located directly inside `directory`."""
    return sorted(glob.glob(str(directory / "*.txt")))


def tokenize(filename: str, tokenizer: Tokenizer) -> Iterator:
    """Read one UTF-8 text file and yield its encoded content as a single item.

    Leading and trailing whitespace is stripped before encoding. The text is
    encoded with ``bos=True`` and ``eos=False``.
    """
    with open(filename, "r", encoding="utf-8") as file:
        text = file.read()
    text = text.strip()
    yield tokenizer.encode(text, bos=True, eos=False)


def validate_tokenizer(tokenizer: Optional[Tokenizer]) -> None:
    """Raise a `ValueError` with a CLI hint when no tokenizer has been provided."""
    if tokenizer is None:
        raise ValueError(
            "Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
            "please provide a valid `--tokenizer_dir` path."
        )
litgpt.data.text_files import TextFiles + + data_dir = tmp_path / "textfiles" + datamodule = TextFiles(train_data_path=data_dir) + datamodule.connect(max_seq_length=2, tokenizer=Tokenizer()) + + # simulate `datamodule.prepare_data` + train_data_dir = data_dir / "train" + train_data_dir.mkdir(parents=True) + fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]]) + datamodule.setup() + + tr_dataloader = datamodule.train_dataloader() + torch.manual_seed(123) + + actual = tree_map(torch.Tensor.tolist, list(tr_dataloader)) + # there is 1 sample per index in the data (13) + assert actual == [ + [[1999, 0, 13]], + [[0, 13, 12]], + [[1, 1999, 0]], + [[63, 0, 73]], + [[5, 0, 1]], + [[0, 73, 5]], + [[0, 23, 15]], + [[0, 1, 1999]], + [[15, 63, 0]], + [[73, 5, 0]], + [[12, 0, 23]], + [[23, 15, 63]], + [[13, 12, 0]], + ] diff --git a/tutorials/pretrain.md b/tutorials/pretrain.md index 4a8db678e1..ce8f92b0e7 100644 --- a/tutorials/pretrain.md +++ b/tutorials/pretrain.md @@ -4,7 +4,7 @@ This document explains how to pretrain LLMs using LitGPT. -## The Pretraining API +## Using the `litgpt pretrain` command You can pretrain models in LitGPT using the `litgpt pretrain` API starting with any of the available architectures listed by calling `litgpt pretrain` without any additional arguments: @@ -36,6 +36,90 @@ litgpt pretrain \ ``` + +## Pretrain on custom data + +The simplest way to get started with pretraining on a small custom dataset is by using the `TextFiles` data module, which lets you pretrain a model on a dataset from a folder containing plain text files. + + + +> [!NOTE] +> This approach adds a beginning-of-sequence token at the beginning of each text file. However, it otherwise assumes that you have already cleaned the text files, for example, removing any unwanted characters and inserting beginning-of-sequence and end-of-sequence tokens if applicable in case a text file consists of multiple documents. 
+ + +> [!WARNING] +> Using this approach is only recommended for small datasets. Since text data is highly compressible, it is often stored in compressed format, and often in file formats where documents can be loaded row by row without having to load entire files at once. In other words, storing the data in plain text files, as this `TextFiles` approach does, is only feasible because the dataset is small. +> For datasets that take up multiple gigabytes, we recommend preprocessing them with [LitData](https://github.com/Lightning-AI/litdata) and then reading them from a local directory or S3 connection using `--data LitData`. + + +For instance, assume you stored a number of text files in a `custom_pretraining_data` folder (we recommend avoiding small files and concatenating them to files of at least 50 MB for efficiency): + +```bash +~ ls -lh custom_pretraining_data +total 3225M +-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_1.txt +-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_2.txt +-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_3.txt +-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_4.txt +-rw-r--r-- 1 sebastian 50M Apr 2 18:31 combined_5.txt +... +``` + +You can then use the `TextFiles` API to pretrain a model (here a small `pythia-14m` model for illustration purposes) from scratch as follows: + +```bash +litgpt download \ + --repo_id EleutherAI/pythia-14m \ + --tokenizer_only true + +litgpt pretrain \ + --model_name pythia-14m \ + --tokenizer_dir checkpoints/EleutherAI/pythia-14m \ + --data TextFiles \ + --data.train_data_path custom_pretraining_data \ + --train.learning_rate 0.005 \ + --train.lr_warmup_steps=200 +``` + + + +## Continued pretraining on custom data + +Often, it makes sense to adopt an existing pretrained model and further pretrain it on our own custom data. The existing pretrained model can be either our own pretrained model or a model downloaded from a model hub. 
+ + +> [!NOTE] +> This approach assumes that you have already cleaned the text files, for example, removing any unwanted characters and inserting beginning-of-sequence and end-of-sequence tokens if applicable. + + + +> [!WARNING] +> Using this approach is only recommended for small datasets. Since text data is highly compressible, it is often stored in compressed format, and often in file formats where documents can be loaded row by row without having to load entire files at once. In other words, storing the data in plain text files, as this `TextFiles` approach does, is only feasible because the dataset is small. +> For datasets that take up multiple gigabytes, we recommend preprocessing them with [LitData](https://github.com/Lightning-AI/litdata) and then reading them from a local directory or S3 connection using `--data LitData --data.path path/to/your/data`. + + + +For instance, let's assume we download a Pythia model: + +```bash +litgpt download --repo_id EleutherAI/pythia-14m +``` + +Next, assume we have a custom dataset stored in text files similar to the *Pretrain on custom data* section above. 
We can further pretrain the Pythia model via the `--initial_checkpoint_dir` setting as follows: + +```bash +litgpt pretrain \ + --model_name pythia-14m \ + --initial_checkpoint_dir checkpoints/EleutherAI/pythia-14m \ + --out_dir new_phi-2_checkpoint \ + --data TextFiles \ + --data.train_data_path custom_pretraining_data \ + --train.learning_rate 0.005 \ + --train.lr_warmup_steps=200 +``` @@ -62,4 +146,4 @@ The following [Lightning Studio](https://lightning.ai/lightning-ai/studios) temp |---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| |
[Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)
[
[
](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) | | [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)[
](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | | -| | \ No newline at end of file +| |