Move pretrain into package (#992)
carmocca authored Mar 5, 2024
1 parent a2fa3c0 commit b9e57e4
Showing 8 changed files with 21 additions and 117 deletions.
2 changes: 1 addition & 1 deletion lit_gpt/args.py
@@ -20,7 +20,7 @@ class TrainArgs:
"""Number of iterations with learning rate warmup active"""
epochs: Optional[int] = None
"""Number of epochs to run"""
# TODO: pretrain/pretrain is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs?
# TODO: `pretrain` is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs?
max_tokens: Optional[int] = None
"""Total number of tokens to train on"""
max_steps: Optional[int] = None
12 changes: 1 addition & 11 deletions pretrain/pretrain.py → lit_gpt/pretrain.py
@@ -1,13 +1,7 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

"""
This script is adapted from TinyLlama:
https://github.com/jzhang38/TinyLlama/blob/main/pretrain/tinyllama.py
"""

import math
import os
import sys
import time
from datetime import timedelta
from functools import partial
@@ -25,10 +19,6 @@
from torchmetrics.aggregation import RunningMean
from typing_extensions import Literal

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from lit_gpt import Tokenizer
from lit_gpt.args import EvalArgs, TrainArgs
from lit_gpt.data import LitDataModule, TinyLlama
@@ -330,7 +320,7 @@ def get_lr(learning_rate: float, it: int, warmup_iters: int, max_iters: int, min


def init_weights(module: nn.Module, n_layer: int, n_embd: int):
# Follows GPT-NeoX: https://arxiv.org/abs/2204.06745
# Copied from https://github.com/jzhang38/TinyLlama/blob/bf12224/lit_gpt/model.py#L40-L54
if isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / n_embd))
elif isinstance(module, nn.Linear):
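
For context, the hunk above is truncated inside `init_weights`. Below is a minimal sketch of a GPT-NeoX-style "small init" in the spirit of the TinyLlama code referenced in the new comment; the helper name and the note about projection layers are illustrative assumptions, not the file's actual code.

```python
import math

import torch.nn as nn


def init_weights_sketch(module: nn.Module, n_layer: int, n_embd: int) -> None:
    # "Small init": std = sqrt(2 / (5 * n_embd)) for embeddings and linear layers.
    std = math.sqrt(2.0 / 5 / n_embd)
    if isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=std)
    elif isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    # In the TinyLlama recipe, residual/projection weights additionally get a
    # depth-scaled std (roughly 1 / sqrt(n_embd) / n_layer) in a separate pass.
```
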
12 changes: 6 additions & 6 deletions requirements-all.txt
@@ -5,13 +5,13 @@ scipy # required by bitsandbytes
sentencepiece # llama-based models
tokenizers # pythia, falcon, redpajama
datasets # eval
requests # scripts/prepare_*
zstandard # scripts/prepare_redpajama.py, scripts/prepare_starcoder.py
pandas # scripts/prepare_csv.py, scripts/prepare_starcoder.py
requests # lit_gpt.data
zstandard # scripts/prepare_starcoder.py
pandas # scripts/prepare_starcoder.py
pyarrow # scripts/prepare_starcoder.py
tensorboard # pretrain/pretrain.py
torchmetrics # pretrain/pretrain.py
tensorboard # lit_gpt.pretrain
torchmetrics # lit_gpt.pretrain
# eval
git+https://github.com/EleutherAI/lm-evaluation-harness.git@115206dc89dad67b8beaa90051fb52db77f0a529
# scripts/prepare_slimpajama.py, scripts/prepare_starcoder.py, pretrain/tinyllama.py
# scripts/prepare_slimpajama.py, scripts/prepare_starcoder.py
lightning[data] @ git+https://github.com/Lightning-AI/lightning@ed367ca675861cdf40dbad2e4d66f7eee2ec50af
2 changes: 1 addition & 1 deletion scripts/convert_pretrained_checkpoint.py
@@ -25,7 +25,7 @@ def convert_checkpoint(checkpoint_file: Path, tokenizer_dir: Path, config_name:
with the tokenizer and model config, which then can be loaded by other scripts for inference, evaluation, etc.
Args:
checkpoint_file: Path to a checkpoint file scripts produced by the scripts in ``lit_gpt/pretrain/``.
checkpoint_file: Path to a checkpoint file scripts produced by ``lit_gpt.pretrain``.
tokenizer_dir: A path to the folder that holds the tokenizer configuration files that were used to train
the model. All files with a name starting with 'tokenizer' will be copied to the output folder.
config_name: The name of the model loaded with the ``lit_gpt.Config``. The configuration will be saved as a
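
As a conceptual illustration of what the docstring above describes (exporting only the model state-dict), here is a hedged sketch; it is not the script's implementation, the paths and filename are placeholders, and the checkpoint's "model" key is an assumption about the Fabric checkpoint layout.

```python
import os

import torch

ckpt = torch.load("out/step-00060500.pth", map_location="cpu")  # placeholder path
# Fabric-style checkpoints typically nest the weights under a "model" key (assumption).
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
os.makedirs("out/converted", exist_ok=True)
torch.save(state_dict, "out/converted/lit_model.pth")  # output location is an assumption
```
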
6 changes: 3 additions & 3 deletions tests/test_config_hub.py
@@ -8,9 +8,9 @@


@pytest.mark.parametrize(["script_file", "config_file"], [
("pretrain/pretrain.py", "pretrain/debug.yaml"),
("pretrain/pretrain.py", "pretrain/tinyllama.yaml"),
("pretrain/pretrain.py", "pretrain/tinystories.yaml"),
("lit_gpt/pretrain.py", "pretrain/debug.yaml"),
("lit_gpt/pretrain.py", "pretrain/tinyllama.yaml"),
("lit_gpt/pretrain.py", "pretrain/tinystories.yaml"),
])
def test_config_help(script_file, config_file, monkeypatch, tmp_path):
"""Test that configs validate against the signature in the scripts."""
8 changes: 4 additions & 4 deletions tests/test_pretrain.py
@@ -14,21 +14,21 @@
@RunIf(min_cuda_gpus=2, standalone=True)
# Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
def test_pretrain_tiny_llama(tmp_path, monkeypatch):
import pretrain.pretrain as module
def test_pretrain(tmp_path, monkeypatch):
from lit_gpt import pretrain
from lit_gpt.args import EvalArgs, TrainArgs
from lit_gpt.config import Config

model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)

dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
dataloader = DataLoader(dataset)
module.get_dataloaders = Mock(return_value=(dataloader, dataloader))
pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))

out_dir = tmp_path / "out"
stdout = StringIO()
with redirect_stdout(stdout):
module.setup(
pretrain.setup(
devices=2,
model=model_config,
out_dir=out_dir,
86 changes: 0 additions & 86 deletions tutorials/pretrain_openwebtext.md

This file was deleted.

10 changes: 5 additions & 5 deletions tutorials/pretrain_tinyllama.md
@@ -100,18 +100,18 @@ In the above we are assuming that you will be using the same tokenizer as used i
Running the pretraining script with its default settings requires at least 8 A100 GPUs.

```bash
python pretrain/pretrain.py
python lit_gpt/pretrain.py
```

The script will save checkpoints periodically to the folder `out/`.
By default, the `pretrain/pretrain.py` script will pretrain the model with FSDP in
By default, the `pretrain` script will pretrain the model with FSDP in
`bfloat16` mixed precision and gradient accumulation.
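
For illustration, a minimal sketch of what that default corresponds to in Lightning Fabric, which `lit_gpt.pretrain` builds on; the real script configures more (activation checkpointing, sharding policy, loggers), so treat this as a simplified approximation rather than its actual code.

```python
import lightning as L
from lightning.fabric.strategies import FSDPStrategy

# FSDP sharding with bf16 mixed precision across 8 devices, as described above.
fabric = L.Fabric(devices=8, strategy=FSDPStrategy(), precision="bf16-mixed")
# Gradient accumulation is then a training-loop concern, e.g. skipping gradient
# synchronization on micro-batches via fabric.no_backward_sync(model, enabled=...).
```
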

Note that the `pretrain/pretrain.py` is not actually a model-specific training script, so feel free to change
Note that `pretrain` is not actually a model-specific training script, so feel free to change
the configuration and size by passing a different string to the model name argument, for example:

```shell
python pretrain/pretrain.py --model.name Gemma-2b
python lit_gpt/pretrain.py --model.name Gemma-2b
```

The currently supported model names are contained in the [config.py](https://github.com/Lightning-AI/lit-gpt/lit_gpt/config.py) file.
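
A quick way to list those names is sketched below; it assumes `lit_gpt.config` exposes its registry as a module-level `configs` list of dicts with a `"name"` key, which was the layout at the time of this commit but should be treated as an assumption.

```python
from lit_gpt.config import configs  # module-level registry (attribute name is an assumption)

print(sorted({c["name"] for c in configs}))
```
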
@@ -145,7 +145,7 @@ The checkpoints saved during pretraining contain all the information to resume i
Simply rerun the script with the `--resume` argument:

```bash
python pretrain/pretrain.py --resume out/tiny-llama-1.1b/step-00060500.pth
python lit_gpt/pretrain.py --resume out/tiny-llama-1.1b/step-00060500.pth
```

## Export checkpoints
