Commit
tokenizer is only required if optimization has to run
awaelchli committed Apr 8, 2024
1 parent 2219c2d commit 12332a1
Showing 2 changed files with 13 additions and 12 deletions.
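In effect, the commit moves the tokenizer check out of the top of prepare_data and into the branches that actually call litdata's optimize, so data preparation no longer demands a tokenizer when the preprocessed output directories already exist on disk. A minimal sketch of the resulting control flow, using simplified stand-ins (prepare_split, out_path) rather than the real class fields in litgpt/data/text_files.py:

from pathlib import Path


def validate_tokenizer(tokenizer) -> None:
    # Same guard as before; only the call site moves in this commit.
    if tokenizer is None:
        raise ValueError(
            "Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
            "please provide a valid `--tokenizer_dir` path."
        )


def prepare_split(out_path: Path, tokenizer) -> None:
    # New control flow: the tokenizer is validated only inside the branch
    # that actually tokenizes, so it is skipped entirely when the chunked
    # output directory is already present.
    if not out_path.is_dir():
        validate_tokenizer(tokenizer)
        ...  # litdata's optimize(fn=..., inputs=..., ...) would run here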
16 changes: 10 additions & 6 deletions litgpt/data/text_files.py
@@ -64,16 +64,11 @@ def prepare_data(self) -> None:
         val_files, *train_files = train_files
         val_files = [val_files]
 
-        if self.tokenizer is None:
-            raise ValueError(
-                "Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
-                "please provide a valid `--tokenizer_dir` path."
-            )
-
         # It's ok to use almost all CPUs here because this runs in a single process
         num_workers = os.cpu_count() - 1
         use_workers = min(num_workers, len(train_files))
         if not Path(self.out_path_train).is_dir():
+            validate_tokenizer(self.tokenizer)
             optimize(
                 fn=partial(tokenize, tokenizer=self.tokenizer),
                 inputs=train_files,
@@ -83,6 +78,7 @@ def prepare_data(self) -> None:
             )
         use_workers = min(num_workers, len(val_files))
         if not Path(self.out_path_val).is_dir():
+            validate_tokenizer(self.tokenizer)
             optimize(
                 fn=partial(tokenize, tokenizer=self.tokenizer),
                 inputs=val_files,
@@ -127,3 +123,11 @@ def tokenize(filename: str, tokenizer: Tokenizer):
         text = file.read()
     text = text.strip()
     yield tokenizer.encode(text, bos=True, eos=False)
+
+
+def validate_tokenizer(tokenizer: Tokenizer) -> None:
+    if tokenizer is None:
+        raise ValueError(
+            "Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
+            "please provide a valid `--tokenizer_dir` path."
+        )
9 changes: 3 additions & 6 deletions litgpt/data/tinystories.py
@@ -13,6 +13,7 @@
 from litgpt import Tokenizer
 from litgpt.data import DataModule
 from litgpt.data.alpaca import download_if_missing
+from litgpt.data.text_files import validate_tokenizer
 
 
 @dataclass
@@ -46,12 +47,6 @@ def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, ma
     def prepare_data(self) -> None:
         from litdata import optimize
 
-        if self.tokenizer is None:
-            raise ValueError(
-                "Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
-                "please provide a valid `--tokenizer_dir` path."
-            )
-
         download(self.data_path)
 
         files = sorted(glob.glob(str(self.data_path / "TinyStories_all_data" / "*.json")))
@@ -62,6 +57,7 @@ def prepare_data(self) -> None:
         num_workers = os.cpu_count() - 1
 
         if not Path(self.data_path_train).is_dir():
+            validate_tokenizer(self.tokenizer)
             optimize(
                 fn=partial(tokenize, tokenizer=self.tokenizer),
                 inputs=train_files,
@@ -70,6 +66,7 @@ def prepare_data(self) -> None:
                 chunk_bytes="200MB",
             )
         if not Path(self.data_path_val).is_dir():
+            validate_tokenizer(self.tokenizer)
             optimize(
                 fn=partial(tokenize, tokenizer=self.tokenizer),
                 inputs=[val_file],
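Because validate_tokenizer now lives in litgpt/data/text_files.py, other data modules can reuse the same guard before running litdata's optimize. A hypothetical sketch of how a further data module might follow the pattern (the helper name and its import path come from this commit; the surrounding function, paths, worker count, and chunk size are illustrative only):

from functools import partial
from pathlib import Path

from litdata import optimize

from litgpt.data.text_files import validate_tokenizer


def prepare_split(out_dir: Path, files: list, tokenizer, tokenize) -> None:
    # Require a tokenizer only when the preprocessed chunks are missing
    # and optimize() actually has to run.
    if not out_dir.is_dir():
        validate_tokenizer(tokenizer)
        optimize(
            fn=partial(tokenize, tokenizer=tokenizer),
            inputs=files,
            output_dir=str(out_dir),
            num_workers=1,
            chunk_bytes="50MB",
        )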
