Pretraining with plaintext files (#1235)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: awaelchli <[email protected]>
1 parent: 9475ec4
Commit: cbbea05
Showing 5 changed files with 285 additions and 2 deletions.
litgpt/data/text_files.py
@@ -0,0 +1,133 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import glob
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tqdm import tqdm
from typing import Optional

from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import DataModule


@dataclass
class TextFiles(DataModule):
    """The TextFile data module used for pretraining.

    Reads in text data from plaintext files contained in a data folder
    and provides training and validation dataloaders that return batches of tokens.
    Every sample is set to a fixed length.
    """

    train_data_path: Path
    """The path to the data directory used for training that contains .txt files"""
    val_data_path: Optional[Path] = None
    """The path to the data directory used for validation that
    contains .txt files. Splits off data for validation from the
    training set if None."""
    seed: int = 42
    """The seed to use for shuffling the dataset."""
    num_workers: int = 4
    """The number of workers to use for data loading."""

    tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
    batch_size: int = field(default=1, init=False, repr=False)
    max_seq_length: int = field(default=-1, init=False, repr=False)

    def __post_init__(self) -> None:
        self.out_path_train = self.train_data_path / "train"
        if self.val_data_path is None:
            self.out_path_val = self.train_data_path / "val"
        else:
            self.out_path_val = Path(self.val_data_path) / "val"

    def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length + 1  # Increase by one because we need the next token as well

    def prepare_data(self) -> None:
        from litdata import optimize

        train_files = sorted(glob.glob(str(self.train_data_path / "*.txt")))
        assert len(train_files) > 0, f"No .txt files found in train data {train_files}"

        if self.val_data_path is not None:
            self.val_data_path = Path(self.val_data_path)
            val_files = sorted(glob.glob(str(self.val_data_path / "*.txt")))
            assert len(val_files) > 0, f"No .txt files found in validation data {val_files}"
        # train/test split. let's use only shard 0 for test split, rest train
        else:
            assert len(train_files) > 1, f"Expected at least two .txt files in {train_files}"
            val_files, *train_files = train_files
            val_files = [val_files]

        # It's ok to use almost all CPUs here because this runs in a single process
        num_workers = os.cpu_count() - 1
        use_workers = min(num_workers, len(train_files))
        if not Path(self.out_path_train).is_dir():
            validate_tokenizer(self.tokenizer)
            optimize(
                fn=partial(tokenize, tokenizer=self.tokenizer),
                inputs=train_files,
                output_dir=str(self.out_path_train),
                num_workers=use_workers,
                chunk_bytes="50MB",
            )
        use_workers = min(num_workers, len(val_files))
        if not Path(self.out_path_val).is_dir():
            validate_tokenizer(self.tokenizer)
            optimize(
                fn=partial(tokenize, tokenizer=self.tokenizer),
                inputs=val_files,
                output_dir=str(self.out_path_val),
                num_workers=use_workers,
                chunk_bytes="50MB",
            )

    def train_dataloader(self) -> DataLoader:
        from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader

        train_dataset = StreamingDataset(
            input_dir=str(self.out_path_train),
            item_loader=TokensLoader(block_size=self.max_seq_length),
            shuffle=True,
            drop_last=True,
        )

        train_dataloader = StreamingDataLoader(
            train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
        )
        return train_dataloader

    def val_dataloader(self) -> DataLoader:
        from litdata.streaming import StreamingDataset, TokensLoader

        val_dataset = StreamingDataset(
            input_dir=str(self.out_path_val),
            item_loader=TokensLoader(block_size=self.max_seq_length),
            shuffle=True,
            # Consider setting to False, but we would lose some samples due to truncation when world size > 1
            drop_last=True,
        )
        val_dataloader = DataLoader(
            val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
        )
        return val_dataloader


def tokenize(filename: str, tokenizer: Tokenizer):
    with open(filename, "r", encoding="utf-8") as file:
        text = file.read()
    text = text.strip()
    yield tokenizer.encode(text, bos=True, eos=False)


def validate_tokenizer(tokenizer: Tokenizer) -> None:
    if tokenizer is None:
        raise ValueError(
            "Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
            "please provide a valid `--tokenizer_dir` path."
        )
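
For orientation, here is a minimal sketch of how this data module might be driven directly from Python, roughly mirroring what `litgpt pretrain` does internally. The data directory and checkpoint path below are placeholders, and the assumption that `Tokenizer` accepts a checkpoint directory follows its usual usage in this repository:

from pathlib import Path

from litgpt import Tokenizer
from litgpt.data.text_files import TextFiles

# Placeholder paths: any folder with at least two .txt files (one is split off for
# validation when val_data_path is None) and any checkpoint dir with tokenizer files.
data_dir = Path("data/my_corpus")
tokenizer = Tokenizer(Path("checkpoints/EleutherAI/pythia-160m"))

datamodule = TextFiles(train_data_path=data_dir)
datamodule.connect(tokenizer=tokenizer, batch_size=4, max_seq_length=512)
datamodule.prepare_data()  # tokenizes the .txt files into streaming chunks on first run
datamodule.setup()

train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))  # expected shape: (batch_size, max_seq_length + 1)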
@@ -0,0 +1,61 @@
import random
import string
import os

import torch

from litdata import optimize
from torch.utils._pytree import tree_map


class Tokenizer:
    bos_id = 0

    def encode(self, text, bos, eos):
        assert bos
        assert not eos
        return [self.bos_id] + [ord(c) for c in text]


def tokenize(data):
    for story in data:
        yield torch.tensor(story)


def fake_chunk(path, data):
    optimize(fn=tokenize, inputs=[data] * len(data), output_dir=str(path), num_workers=1, chunk_bytes="200MB")


def test_textfiles_datamodule(tmp_path):
    from litgpt.data.text_files import TextFiles

    data_dir = tmp_path / "textfiles"
    datamodule = TextFiles(train_data_path=data_dir)
    datamodule.connect(max_seq_length=2, tokenizer=Tokenizer())

    # simulate `datamodule.prepare_data`
    train_data_dir = data_dir / "train"
    train_data_dir.mkdir(parents=True)
    fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]])
    datamodule.setup()

    tr_dataloader = datamodule.train_dataloader()
    torch.manual_seed(123)

    actual = tree_map(torch.Tensor.tolist, list(tr_dataloader))
    # there is 1 sample per index in the data (13)
    assert actual == [
        [[1999, 0, 13]],
        [[0, 13, 12]],
        [[1, 1999, 0]],
        [[63, 0, 73]],
        [[5, 0, 1]],
        [[0, 73, 5]],
        [[0, 23, 15]],
        [[0, 1, 1999]],
        [[15, 63, 0]],
        [[73, 5, 0]],
        [[12, 0, 23]],
        [[23, 15, 63]],
        [[13, 12, 0]],
    ]
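
One way to read the expected output above: the three fake stories concatenate into a 13-token stream, and since `connect(max_seq_length=2)` stores a block size of 3 (`max_seq_length + 1`), the 13 expected samples appear to be exactly the 13 cyclic 3-token windows over that stream, returned in shuffled order. A small standalone check of that reading (an inference from the expected values, not a claim about `TokensLoader` internals):

# The three fake stories flattened into one token stream, as chunked in the test above.
stream = [12] + [0, 23, 15, 63, 0] + [73, 5, 0, 1, 1999, 0, 13]

# Cyclic 3-token windows; doubling the stream allows wrap-around windows such as [13, 12, 0].
windows = sorted((stream + stream)[i:i + 3] for i in range(len(stream)))

expected = sorted([
    [1999, 0, 13], [0, 13, 12], [1, 1999, 0], [63, 0, 73], [5, 0, 1],
    [0, 73, 5], [0, 23, 15], [0, 1, 1999], [15, 63, 0], [73, 5, 0],
    [12, 0, 23], [23, 15, 63], [13, 12, 0],
])
assert windows == expected  # same 13 windows, only the (shuffled) order differs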