Commit

Merge branch 'main' into carmocca/no-intermediate
rasbt authored Apr 16, 2024
2 parents a76682d + 5f838a4 commit bd5dfd9
Showing 39 changed files with 755 additions and 210 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -1 +1,2 @@
* @awaelchli @carmocca @lantiga
/README.md @williamfalcon @lantiga
1 change: 1 addition & 0 deletions .github/azure-gpu-test.yml
@@ -41,6 +41,7 @@ jobs:
displayName: "Image info & NVIDIA"
- script: |
pip install --upgrade pip
pip install '.[all,test]'
displayName: 'Install dependencies'
31 changes: 31 additions & 0 deletions .github/workflows/check-links.yml
@@ -0,0 +1,31 @@
name: Check hyperlinks

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
test:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-check-links
- name: Check links
run: |
pytest --check-links README.md --check-links-ignore "http*"
pytest --check-links tutorials --check-links-ignore "http*"
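As an aside (not part of the diff above), the link check this workflow runs in CI can also be reproduced locally through pytest's Python entry point; a minimal sketch, assuming `pytest` and `pytest-check-links` are installed in the current environment:

```python
# Illustrative local reproduction of the CI link check; not part of the commit.
import pytest

# Mirrors the two workflow steps: check README.md and the tutorials folder,
# using the same ignore pattern the workflow passes to pytest-check-links.
pytest.main(["--check-links", "README.md", "--check-links-ignore", "http*"])
pytest.main(["--check-links", "tutorials", "--check-links-ignore", "http*"])
```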
16 changes: 8 additions & 8 deletions .github/workflows/cpu-tests.yml
@@ -16,7 +16,6 @@ defaults:

env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
UV_HTTP_TIMEOUT: 500

jobs:
cpu-tests:
@@ -40,14 +39,15 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install uv
run: pip install uv
cache: 'pip'
cache-dependency-path: |
pyproject.toml
- name: Install minimal dependencies
run: |
uv pip install --system .
uv pip list
pip install --upgrade pip
pip install .
pip list
# make sure all modules are still importable with only the minimal dependencies available
modules=$(
find litgpt -type f -name "*.py" | \
@@ -59,8 +59,8 @@
- name: Install all dependencies
run: |
uv pip install --system '.[all,test]'
uv pip list
pip install '.[all,test]'
pip list
- name: Run tests
run: |
311 changes: 209 additions & 102 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion extensions/thunder/README.md
@@ -532,7 +532,7 @@ def backward_fn(saved_for_backward, cotangents):
t763 = unsloth_apply_rope_backward(t757, t21, t22, 1, 8, 4) # t763: "cuda:0 f32[2, 4, 3, 16]"
```

We provide a specific [pre-training script copy](unsloth/pretrain.py) that uses this executor.
We provide a specific [pre-training script copy](pretrain.py) that uses this executor.
Given the Unsloth results below, these hand-written kernels do not seem to be worth it, showcasing the power of automated fusion compilers like [NvFuser](https://github.com/NVIDIA/Fuser).

## Examples and benchmarks
29 changes: 19 additions & 10 deletions extensions/thunder/pretrain.py
@@ -26,6 +26,7 @@
from litgpt.utils import (
CLI,
CycleIterator,
capture_hparams,
choose_logger,
chunked_cross_entropy,
copy_config_files,
@@ -97,7 +98,7 @@ def setup(
executors: If using Thunder, the executors to enable.
strategy: If desired, the strategy to use.
"""
hparams = locals()
hparams = capture_hparams()
data = TinyLlama() if data is None else data
if model_config is not None and model_name is not None:
raise ValueError("Only one of `model_name` or `model_config` can be set.")
@@ -232,6 +233,10 @@ def main(
train_time = time.perf_counter()
fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")

# Save final checkpoint
save_checkpoint(fabric, state, tokenizer_dir, out_dir / "final" / "lit_model.pth")

if fabric.device.type == "cuda":
fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")

@@ -357,15 +362,7 @@ def fit(
fabric.barrier()

if train.save_interval is not None and not is_accumulating and state["step_count"] % train.save_interval == 0:
checkpoint_file = out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth"
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
fabric.save(checkpoint_file, state)
if fabric.global_rank == 0:
save_hyperparameters(setup, checkpoint_file.parent)
if tokenizer_dir is not None:
copy_config_files(tokenizer_dir, checkpoint_file.parent)
save_config(model.config, checkpoint_file.parent)
save_checkpoint(fabric, state, tokenizer_dir, out_dir / f"step-{state['step_count']:08d}" / "lit_model.pth")


def forward_and_loss(model: nn.Module, input_ids: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
@@ -451,6 +448,18 @@ def init_out_dir(out_dir: Path) -> Path:
return out_dir


def save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file):
model = state["model"]
checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fabric.print(f"Saving checkpoint to {str(checkpoint_file)!r}")
fabric.save(checkpoint_file, state)
if fabric.global_rank == 0:
save_hyperparameters(setup, checkpoint_file.parent)
if tokenizer_dir is not None:
copy_config_files(tokenizer_dir, checkpoint_file.parent)
save_config(model.config, checkpoint_file.parent)


def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:
issues = []
unsupported = [(train, ["max_steps", "epochs"]), (eval, ["max_new_tokens"])]
10 changes: 5 additions & 5 deletions litgpt/chat/base.py
@@ -131,15 +131,15 @@ def main(

fabric = L.Fabric(devices=1, precision=precision, plugins=plugins)

check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")

checkpoint_path = checkpoint_dir / "lit_model.pth"

# Merge if this is a raw LoRA checkpoint
if (checkpoint_path / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
if (checkpoint_dir / "lit_model.pth.lora").is_file() and not checkpoint_path.is_file():
print("Merging LoRA weights with the base model. This won't take long and is a one-time-only thing.")
merge_lora(checkpoint_path)
merge_lora(checkpoint_dir)

check_valid_checkpoint_dir(checkpoint_dir)
config = Config.from_file(checkpoint_dir / "model_config.yaml")

with fabric.init_module(empty_init=True):
model = GPT(config)
26 changes: 26 additions & 0 deletions litgpt/config.py
@@ -888,6 +888,32 @@ def norm_class(self) -> Type:
copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it"
configs.append(copy)

##################
# Google CodeGemma
##################
codegemma = [
# https://huggingface.co/google/codegemma-7b-it/blob/main/config.json
dict(
name="CodeGemma-7b-it",
hf_config=dict(org="google", name="codegemma-7b-it"),
scale_embeddings=True,
vocab_size=256000,
padding_multiple=64,
n_embd=3072,
n_layer=28,
n_head=16,
head_size=256,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="GemmaMLP",
gelu_approximate="tanh",
intermediate_size=24576,
),
]
configs.extend(codegemma)


##########################
# Stability AI FreeWilly2
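As an illustration (not part of the diff), once this preset is registered it should be resolvable by name like the other entries in `litgpt/config.py`; a minimal sketch, assuming the usual `Config.from_name` lookup and that the model weights are downloaded separately:

```python
# Illustrative sketch only; "CodeGemma-7b-it" is the config name added above.
from litgpt import Config

config = Config.from_name("CodeGemma-7b-it")
print(config.n_layer, config.n_embd, config.intermediate_size)  # 28 3072 24576
```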
2 changes: 2 additions & 0 deletions litgpt/data/__init__.py
@@ -11,6 +11,7 @@
from litgpt.data.lima import LIMA
from litgpt.data.lit_data import LitData
from litgpt.data.longform import LongForm
from litgpt.data.text_files import TextFiles
from litgpt.data.tinyllama import TinyLlama
from litgpt.data.tinystories import TinyStories
from litgpt.data.openwebtext import OpenWebText
@@ -30,6 +31,7 @@
"LongForm",
"OpenWebText",
"SFTDataset",
"TextFiles",
"TinyLlama",
"TinyStories",
"get_sft_collate_fn",
133 changes: 133 additions & 0 deletions litgpt/data/text_files.py
@@ -0,0 +1,133 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import glob
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from tqdm import tqdm
from typing import Optional

from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import DataModule


@dataclass
class TextFiles(DataModule):
"""The TextFile data module used for pretraining.
Reads in text data from plaintext files contained in a data folder
and provides training and validation dataloaders that return batches of tokens.
Every sample is set to a fixed length.
"""
train_data_path: Path
"""The path to the data directory used for training that contains .txt files"""
val_data_path: Optional[Path] = None
"""The path to the data directory used for validation that
contains .txt files. Splits off data for validation from the
training set if None."""
seed: int = 42
"""The seed to use for shuffling the dataset."""
num_workers: int = 4
"""The number of workers to use for data loading."""

tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False)
batch_size: int = field(default=1, init=False, repr=False)
max_seq_length: int = field(default=-1, init=False, repr=False)

def __post_init__(self) -> None:
self.out_path_train = self.train_data_path / "train"
if self.val_data_path is None:
self.out_path_val = self.train_data_path / "val"
else:
self.out_path_val = Path(self.val_data_path) / "val"

def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None:
self.tokenizer = tokenizer
self.batch_size = batch_size
self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well

def prepare_data(self) -> None:
from litdata import optimize

train_files = sorted(glob.glob(str(self.train_data_path / "*.txt")))
assert len(train_files) > 0, f"No .txt files found in train data {train_files}"

if self.val_data_path is not None:
self.val_data_path = Path(self.val_data_path)
val_files = sorted(glob.glob(str(self.val_data_path / "*.txt")))
assert len(val_files) > 0, f"No .txt files found in validation data {val_files}"
# train/test split. let's use only shard 0 for test split, rest train
else:
assert len(train_files) > 1, f"Expected at least two .txt files in {train_files}"
val_files, *train_files = train_files
val_files = [val_files]

# It's ok to use almost all CPUs here because this runs in a single process
num_workers = os.cpu_count() - 1
use_workers = min(num_workers, len(train_files))
if not Path(self.out_path_train).is_dir():
validate_tokenizer(self.tokenizer)
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=train_files,
output_dir=str(self.out_path_train),
num_workers=use_workers,
chunk_bytes="50MB",
)
use_workers = min(num_workers, len(val_files))
if not Path(self.out_path_val).is_dir():
validate_tokenizer(self.tokenizer)
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=val_files,
output_dir=str(self.out_path_val),
num_workers=use_workers,
chunk_bytes="50MB",
)

def train_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader

train_dataset = StreamingDataset(
input_dir=str(self.out_path_train),
item_loader=TokensLoader(block_size=self.max_seq_length),
shuffle=True,
drop_last=True,
)

train_dataloader = StreamingDataLoader(
train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader

def val_dataloader(self) -> DataLoader:
from litdata.streaming import StreamingDataset, TokensLoader

val_dataset = StreamingDataset(
input_dir=str(self.out_path_val),
item_loader=TokensLoader(block_size=self.max_seq_length),
shuffle=True,
# Consider setting to False, but we would lose some samples due to truncation when world size > 1
drop_last=True,
)
val_dataloader = DataLoader(
val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return val_dataloader


def tokenize(filename: str, tokenizer: Tokenizer):
with open(filename, "r", encoding="utf-8") as file:
text = file.read()
text = text.strip()
yield tokenizer.encode(text, bos=True, eos=False)


def validate_tokenizer(tokenizer: Tokenizer) -> None:
if tokenizer is None:
raise ValueError(
"Tokenizer is None. If you are using this data module via `litgpt pretrain`, "
"please provide a valid `--tokenizer_dir` path."
)
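For context (not part of the diff), the new data module follows the usual `DataModule` lifecycle of connect → prepare_data → dataloaders; a minimal usage sketch, where the data directory and tokenizer checkpoint paths are placeholders:

```python
# Illustrative sketch of using TextFiles directly; paths below are placeholders.
from pathlib import Path

from litgpt import Tokenizer
from litgpt.data import TextFiles

data = TextFiles(train_data_path=Path("data/my_texts"))  # folder of .txt files
tokenizer = Tokenizer(Path("checkpoints/EleutherAI/pythia-160m"))
data.connect(tokenizer=tokenizer, batch_size=4, max_seq_length=512)
data.prepare_data()                     # tokenizes the .txt files via litdata.optimize
train_loader = data.train_dataloader()  # streams fixed-length blocks of tokens
```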
5 changes: 4 additions & 1 deletion litgpt/data/tinystories.py
@@ -13,6 +13,7 @@
from litgpt import Tokenizer
from litgpt.data import DataModule
from litgpt.data.alpaca import download_if_missing
from litgpt.data.text_files import validate_tokenizer


@dataclass
@@ -56,6 +57,7 @@ def prepare_data(self) -> None:
num_workers = os.cpu_count() - 1

if not Path(self.data_path_train).is_dir():
validate_tokenizer(self.tokenizer)
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=train_files,
@@ -64,6 +66,7 @@
chunk_bytes="200MB",
)
if not Path(self.data_path_val).is_dir():
validate_tokenizer(self.tokenizer)
optimize(
fn=partial(tokenize, tokenizer=self.tokenizer),
inputs=[val_file],
@@ -103,7 +106,7 @@ def val_dataloader(self) -> DataLoader:


def tokenize(filename: str, tokenizer: Tokenizer):
with open(filename, "r") as f:
with open(filename, "r", encoding="utf-8") as f:
data = json.load(f)
global_rank = int(os.environ["DATA_OPTIMIZER_GLOBAL_RANK"])
num_workers = int(os.environ["DATA_OPTIMIZER_NUM_WORKERS"])