diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml index ce7016c672..bb3b85c3e6 100644 --- a/.github/workflows/cpu-tests.yml +++ b/.github/workflows/cpu-tests.yml @@ -37,7 +37,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -46,9 +46,7 @@ jobs: - name: Install minimal dependencies run: | - # uv pip install . is not yet supported, only `-e .` - # https://github.com/astral-sh/uv/issues/1896 - uv pip install --system -e . + uv pip install --system . uv pip list # make sure all modules are still importable with only the minimal dependencies available modules=$( @@ -61,7 +59,7 @@ jobs: - name: Install all dependencies run: | - uv pip install --system -e '.[all,test]' 'lm_eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@115206dc89dad67b8b' + uv pip install --system '.[all,test]' 'lm_eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@115206dc89dad67b8b' uv pip list - name: Run tests diff --git a/README.md b/README.md index cfea1a2155..c06e792578 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ ✅  Optimized and efficient code: Flash Attention v2, multi-GPU support via fully-sharded data parallelism, [optional CPU offloading](tutorials/oom.md#do-sharding-across-multiple-gpus), and [TPU and XLA support](extensions/xla). -✅  [Pretraining](tutorials/pretraining.md), [finetuning](tutorials/finetune.md), and [inference](tutorials/inference.md) in various precision settings: FP32, FP16, BF16, and FP16/FP32 mixed. +✅  [Pretraining](tutorials/pretrain.md), [finetuning](tutorials/finetune.md), and [inference](tutorials/inference.md) in various precision settings: FP32, FP16, BF16, and FP16/FP32 mixed. ✅  [Configuration files](config_hub) for great out-of-the-box performance. @@ -35,13 +35,12 @@ ✅  [Quantization](tutorials/quantize.md): 4-bit floats, 8-bit integers, and double quantization. -✅  [Exporting](https://github.com/Lightning-AI/litgpt/blob/wip/tutorials/convert_lit_models.md) to other popular model weight formats. +✅  [Exporting](tutorials/convert_lit_models.md) to other popular model weight formats. -✅  Many popular datasets for [pretraining](tutorials/pretrain_tinyllama.md) and [finetuning](tutorials/prepare_dataset.md), and [support for custom datasets](tutorials/prepare_dataset.md#preparing-custom-datasets-for-instruction-finetuning). +✅  Many popular datasets for [pretraining](tutorials/pretrain.md) and [finetuning](tutorials/prepare_dataset.md), and [support for custom datasets](tutorials/prepare_dataset.md#preparing-custom-datasets-for-instruction-finetuning). ✅  Readable and easy-to-modify code to experiment with the latest research ideas. -  
  @@ -59,8 +58,6 @@ The following [Lightning Studio](https://lightning.ai/lightning-ai/studios) temp - -  
  @@ -107,9 +104,17 @@ For more information, refer to the [download](tutorials/download_model_weights.m   + +> [!NOTE] +> We recommend starting with the **[Zero to LitGPT: Getting Started with Pretraining, Finetuning, and Using LLMs](tutorials/0_to_litgpt.md)** if you are looking to get started with using LitGPT. + + + +  + ## Finetuning and pretraining -LitGPT supports [pretraining](tutorials/pretrain_tinyllama.md) and [finetuning](tutorials/finetune.md) to optimize models on excisting or custom datasets. Below is an example showing how to finetune a model with LoRA: +LitGPT supports [pretraining](tutorials/pretrain.md) and [finetuning](tutorials/finetune.md) to optimize models on excisting or custom datasets. Below is an example showing how to finetune a model with LoRA: ```bash # 1) Download a pretrained model @@ -134,7 +139,7 @@ LitGPT also allows users to use configuration files in YAML format instead of sp ```bash litgpt finetune lora \ - --config https://github.com/Lightning-AI/litgpt/blob/wip/config_hub/finetune/llama-2-7b/lora.yaml + --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/llama-2-7b/lora.yaml ``` For added convenience, you can also manually override config file setting via the CLI: @@ -146,7 +151,7 @@ litgpt finetune lora \ --lora_r 4 ``` -You can browse the available configuration files [here](https://github.com/Lightning-AI/litgpt/tree/main/config_hub). +You can browse the available configuration files [here](config_hub).   @@ -324,8 +329,14 @@ If you have general questions about building with LitGPT, please [join our Disco ## Tutorials, how-to guides, and docs + +> [!NOTE] +> We recommend starting with the **[Zero to LitGPT: Getting Started with Pretraining, Finetuning, and Using LLMs](tutorials/0_to_litgpt.md)** if you are looking to get started with using LitGPT. + +Tutorials and in-depth feature documentation can be found below: + - Finetuning, incl. LoRA, QLoRA, and Adapters ([tutorials/finetune.md](tutorials/finetune.md)) -- Pretraining ([tutorials/pretrain_tinyllama.md](tutorials/pretrain_tinyllama.md)) +- Pretraining ([tutorials/pretrain.md](tutorials/pretrain.md)) - Model evaluation ([tutorials/evaluation.md](tutorials/evaluation.md)) - Supported and custom datasets ([tutorials/prepare_dataset.md](tutorials/prepare_dataset.md)) - Quantization ([tutorials/quantize.md](tutorials/quantize.md)) @@ -401,4 +412,3 @@ If you use LitGPT in your research, please cite the following work: ## License LitGPT is released under the [Apache 2.0](https://github.com/Lightning-AI/litgpt/blob/main/LICENSE) license. - diff --git a/config_hub/finetune/README.md b/config_hub/finetune/README.md index 8aed153105..fc82e0854b 100644 --- a/config_hub/finetune/README.md +++ b/config_hub/finetune/README.md @@ -1,6 +1,6 @@ ## Config files -The table below lists the performances you can expect from the provided config files. Note that you can achieve lower memory consumption by lowering the micro batch size as needed. In addition, you can lower the rank (`lora_r`) in the LoRA configuration files and disable LoRA for certain layers (for example, setting `lora_projection` and other LoRA layer-specific parameters to `false`). +The table below lists the performances you can expect from the provided config files. Note that you can achieve lower memory consumption by lowering the micro batch size as needed. In addition, you can lower the rank (`lora_r`) in the LoRA configuration files and disable LoRA for certain layers (for example, setting `lora_projection` and other LoRA layer-specific parameters to `false`). For more information, see the [Dealing with out-of-memory (OOM) errors](../../tutorials/oom.md) on lowering the memory requirements.   @@ -11,29 +11,56 @@ For more information, see the [Dealing with out-of-memory (OOM) errors](../../tu | falcon-7b/lora.yaml | 7B | Alpaca 2k | 4 | 0.945 | 16.69 GB | 512 | 2 | bfloat16 | 24.88 min (1xA10G) | | falcon-7b/qlora.yaml | 7B | Alpaca 2k | 4 | 0.993 | 9.44 GB | 512 | 2 | bfloat16 | 50.76 min (1xA10G) | | | | | | | | | | | | -| gemma-2b/lora.yaml | 2B | Alpaca 2k | 3 | 1.476 | 12.62 GB | 512 | 2 | bfloat16 | 18.31 min (1xA10G) | -| gemma-2b/qlora.yaml | 2B | Alpaca 2k | 3 | 1.626 | 11.51 GB | 512 | 2 | bfloat16 | 25.29 min (1xA10G) | -| gemma-2b/full.yaml | 2B | Alpaca 2k | 0.35 | 1.046 | 18.47 GB | 512 | 2 | bfloat16 | 16.79 min (2xA10G) | +| gemma-2b/lora.yaml | 2B | Alpaca 2k | 2 | 1.476 | 12.62 GB | 512 | 2 | bfloat16 | 9.29 min (1xA10G) | +| gemma-2b/qlora.yaml | 2B | Alpaca 2k | 2 | 0.981 | 11.59 GB | 512 | 2 | bfloat16 | 12.90 min (1xA10G) | +| gemma-2b/full.yaml | 2B | Alpaca 2k | 0.35 | 0.990 | 17.43 GB | 512 | 1 | bfloat16 | 13.61 min (4xA10G) | +| | | | | | | | | | | +| gemma-7b/lora.yaml | 7B | Alpaca 2k | 2 | 0.903 | 25.30 GB | 512 | 1 | bfloat16 | 11.47 min (1xA100) | +| gemma-7b/qlora.yaml | 7B | Alpaca 2k | 2 | 0.951 | 17.31 GB | 512 | 1 | bfloat16 | 23.46 min (1xA100) | | | | | | | | | | | | | llama-2-7b/lora.yaml | 7B | Alpaca 2k | 4 | 0.802 | 19.77 GB | 512 | 2 | bfloat16 | 32.75 min (A10G) | | llama-2-7b/qlora.yaml | 7B | Alpaca 2k | 4 | 0.814 | 13.68 GB | 512 | 2 | bfloat16 | 45.68 min (A10G) | | llama-2-7b/full.yaml | 7B | Alpaca 2k | 1 | 0.941 | 26.81 GB | 512 | 4 | bfloat16 | 1.78 min (4xA100) | | | | | | | | | | | | -| mistral-7b/lora.yaml | 7B | Alpaca 2k | 4 | 0.796 | 20.65 GB | 512 | 2 | bfloat16 | 31.04 min (1xA10G) | -| mistral-7b/qlora.yaml | 7B | Alpaca 2k | 4 | 0.803 | 14.29 GB | 512 | 2 | bfloat16 | 44.69 min (1xA10G) | +| mistral-7b/lora.yaml (v0.1) | 7B | Alpaca 2k | 4 | 0.796 | 20.65 GB | 512 | 2 | bfloat16 | 31.04 min (1xA10G) | +| mistral-7b/qlora.yaml (v0.1) | 7B | Alpaca 2k | 4 | 0.803 | 14.29 GB | 512 | 2 | bfloat16 | 44.69 min (1xA10G) | +| | | | | | | | | | | +| mistral-7b-v0.2/lora.yaml | 7B | Alpaca 2k | 4 | 0.801 | 20.65 GB | 512 | 2 | bfloat16 | 30.96 min (1xA10G) | +| mistral-7b-v0.2/qlora.yaml | 7B | Alpaca 2k | 4 | 0.813 | 14.29 GB | 512 | 2 | bfloat16 | 44.68 min (1xA10G) | | | | | | | | | | | | | phi-2/lora.yaml | 2B | Alpaca 2k | 1 | 0.832 | 13.98 GB | 512 | 4 | bfloat16 | 3.82 min (1xA10G) | | phi-2/qlora.yaml | 2B | Alpaca 2k | 1 | 0.846 | 14.27 GB | 512 | 4 | bfloat16 | 4.55 min (1xA10G) | | phi-2/full.yaml | 2B | Alpaca 2k | 1 | 0.937 | 14.44 GB | 512 | 4 | bfloat16 | 13.00 min (1xA10G) | | | | | | | | | | | | -| stablelm-base-alpha-3b/lora.yaml | 7B | Alpaca 2k | 4 | 1.367 | 8.58 GB | 512 | 2 | bfloat16 | 13.02 min (1xA10G) | -| stablelm-base-alpha-3b/qlora.yaml | 7B | Alpaca 2k | 4 | 1.392 | 5.24 GB | 512 | 2 | bfloat16 | 25.71 min (1xA10G) | -| stablelm-base-alpha-3b/full.yaml | 7B | Alpaca 2k | 1 | 1.494 | 21.23 GB | 512 | 1 | bfloat16 | 72.72 min (2xA10G) | +| stablelm-base-alpha-3b/lora.yaml | 3B | Alpaca 2k | 4 | 1.367 | 8.58 GB | 512 | 2 | bfloat16 | 13.02 min (1xA10G) | +| stablelm-base-alpha-3b/qlora.yaml | 3B | Alpaca 2k | 4 | 1.392 | 5.24 GB | 512 | 2 | bfloat16 | 25.71 min (1xA10G) | +| stablelm-base-alpha-3b/full.yaml | 3B | Alpaca 2k | 1 | 1.494 | 21.23 GB | 512 | 1 | bfloat16 | 72.72 min (2xA10G) | | | | | | | | | | | | | tiny-llama/lora.yaml | 1.1B | Alpaca 2k | 3 | 1.038 | 13.50 GB | 512 | 8 | bfloat16 | 8.06 min (1xA10G) | | tiny-llama/qlora.yaml | 1.1B | Alpaca 2k | 3 | 1.056 | 16.24 GB | 512 | 8 | bfloat16 | 8.74 min (1xA10G) | | tiny-llama/full.yaml | 1.1B | Alpaca 2k | 1 | 1.105 | 14.10 GB | 512 | 4 | bfloat16 | 2.59 min (1xA10G) |   +## Extending the context length If you require a longer sequence length than the one used in a given config file, you can either edit the `max_seq_length` in the config file or pass an additional argument when running the finetuning command, for example, `--max_seq_length 4096` to override the sequence length provided in the config file. + +  +## Training on GPUs without bfloat16 support + +If you are training on GPUs without bfloat-16 support, you need to change the `precision` option to `16-true` (16-bit floating point precision) or `16-mixed` (16/32-bit mixed precision) training: + +```bash +litgpt finetune lora \ + --config config_hub/finetune/phi-2/lora.yaml \ + --precision 16-true +``` +or + +```bash +litgpt finetune lora \ + --config config_hub/finetune/phi-2/lora.yaml \ + --precision 16-mixed +``` + +Note that `16-true` is more compute and memory-efficient, but it can sometimes lead to training convergence issues. In this case, it's recommended to use `16-mixed`. diff --git a/config_hub/finetune/gemma-2b/full.yaml b/config_hub/finetune/gemma-2b/full.yaml index 509a2675e4..77f20658ca 100644 --- a/config_hub/finetune/gemma-2b/full.yaml +++ b/config_hub/finetune/gemma-2b/full.yaml @@ -9,7 +9,7 @@ out_dir: out/finetune/full-gemma-2b precision: bf16-true # How many devices/GPUs to use. (type: Union[int, str], default: 1) -devices: 1 +devices: 4 # Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. data: @@ -32,7 +32,7 @@ train: log_interval: 1 # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) - global_batch_size: 6 + global_batch_size: 16 # Number of samples per data-parallel rank (type: int, default: 4) micro_batch_size: 1 @@ -41,13 +41,13 @@ train: lr_warmup_steps: 100 # Number of epochs to train on (type: Optional[int], default: 5) - epochs: 3 + epochs: 1 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: # Limits the number of optimizer steps to run. (type: Optional[int], default: null) - max_steps: + max_steps: 50 # Limits the length of samples. Off by default (type: Optional[int], default: null) max_seq_length: 512 diff --git a/config_hub/finetune/gemma-2b/lora.yaml b/config_hub/finetune/gemma-2b/lora.yaml index 72d56fc22b..c9f912a47c 100644 --- a/config_hub/finetune/gemma-2b/lora.yaml +++ b/config_hub/finetune/gemma-2b/lora.yaml @@ -15,7 +15,7 @@ quantize: devices: 1 # The LoRA rank. (type: int, default: 8) -lora_r: 16 +lora_r: 8 # The LoRA alpha. (type: int, default: 16) lora_alpha: 16 @@ -71,7 +71,7 @@ train: lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) - epochs: 4 + epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: diff --git a/config_hub/finetune/gemma-2b/qlora.yaml b/config_hub/finetune/gemma-2b/qlora.yaml index 4c26c9cee8..dc15fe90d3 100644 --- a/config_hub/finetune/gemma-2b/qlora.yaml +++ b/config_hub/finetune/gemma-2b/qlora.yaml @@ -71,7 +71,7 @@ train: lr_warmup_steps: 200 # Number of epochs to train on (type: Optional[int], default: 5) - epochs: 4 + epochs: 2 # Total number of tokens to train on (type: Optional[int], default: null) max_tokens: diff --git a/config_hub/finetune/gemma-7b/lora.yaml b/config_hub/finetune/gemma-7b/lora.yaml new file mode 100644 index 0000000000..d7d56f5b5c --- /dev/null +++ b/config_hub/finetune/gemma-7b/lora.yaml @@ -0,0 +1,122 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/google/gemma-7b + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/qlora-gemma-7b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 16 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.1 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: true + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: true + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: true + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: true + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + val_split_fraction: 0.03847 + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 800 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 6 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 1 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 200 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 2 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.0003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.0 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 25 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 diff --git a/config_hub/finetune/gemma-7b/qlora.yaml b/config_hub/finetune/gemma-7b/qlora.yaml new file mode 100644 index 0000000000..7d4a2c634c --- /dev/null +++ b/config_hub/finetune/gemma-7b/qlora.yaml @@ -0,0 +1,122 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/google/gemma-7b + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/qlora-gemma-7b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: bnb.nf4 + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 16 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.1 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: true + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: true + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: true + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: true + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + val_split_fraction: 0.03847 + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 800 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 6 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 1 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 200 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 2 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.0003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.0 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 25 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 diff --git a/config_hub/finetune/mistral-7b-v0.2/lora.yaml b/config_hub/finetune/mistral-7b-v0.2/lora.yaml new file mode 100644 index 0000000000..aad8f7c986 --- /dev/null +++ b/config_hub/finetune/mistral-7b-v0.2/lora.yaml @@ -0,0 +1,121 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/unsloth/Mistral-7B-v0.2 + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/lora-mistral-7b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 32 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.05 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: false + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: false + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: false + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 8 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 2 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 10 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 4 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run. (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples. Off by default (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.0003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.0 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 100 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 diff --git a/config_hub/finetune/mistral-7b-v0.2/qlora.yaml b/config_hub/finetune/mistral-7b-v0.2/qlora.yaml new file mode 100644 index 0000000000..e2f5c3aafc --- /dev/null +++ b/config_hub/finetune/mistral-7b-v0.2/qlora.yaml @@ -0,0 +1,123 @@ + +# The path to the base model's checkpoint directory to load for finetuning. (type: , default: checkpoints/stabilityai/stablelm-base-alpha-3b) +checkpoint_dir: checkpoints/unsloth/Mistral-7B-v0.2 + +# Directory in which to save checkpoints and logs. (type: , default: out/lora) +out_dir: out/finetune/qlora-mistral-7b + +# The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null) +precision: bf16-true + +# If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information. (type: Optional[Literal['nf4', 'nf4-dq', 'fp4', 'fp4-dq', 'int8-training']], default: null) +quantize: bnb.nf4 + +# How many devices/GPUs to use. (type: Union[int, str], default: 1) +devices: 1 + +# The LoRA rank. (type: int, default: 8) +lora_r: 32 + +# The LoRA alpha. (type: int, default: 16) +lora_alpha: 16 + +# The LoRA dropout value. (type: float, default: 0.05) +lora_dropout: 0.05 + +# Whether to apply LoRA to the query weights in attention. (type: bool, default: True) +lora_query: true + +# Whether to apply LoRA to the key weights in attention. (type: bool, default: False) +lora_key: false + +# Whether to apply LoRA to the value weights in attention. (type: bool, default: True) +lora_value: true + +# Whether to apply LoRA to the output projection in the attention block. (type: bool, default: False) +lora_projection: false + +# Whether to apply LoRA to the weights of the MLP in the attention block. (type: bool, default: False) +lora_mlp: false + +# Whether to apply LoRA to output head in GPT. (type: bool, default: False) +lora_head: false + +# Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``. +data: + class_path: litgpt.data.Alpaca2k + init_args: + mask_prompt: false + val_split_fraction: 0.05 + prompt_style: alpaca + ignore_index: -100 + seed: 42 + num_workers: 4 + download_dir: data/alpaca2k + +# Training-related arguments. See ``litgpt.args.TrainArgs`` for details +train: + + # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000) + save_interval: 200 + + # Number of iterations between logging calls (type: int, default: 1) + log_interval: 1 + + # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 128) + global_batch_size: 8 + + # Number of samples per data-parallel rank (type: int, default: 4) + micro_batch_size: 2 + + # Number of iterations with learning rate warmup active (type: int, default: 100) + lr_warmup_steps: 10 + + # Number of epochs to train on (type: Optional[int], default: 5) + epochs: 4 + + # Total number of tokens to train on (type: Optional[int], default: null) + max_tokens: + + # Limits the number of optimizer steps to run (type: Optional[int], default: null) + max_steps: + + # Limits the length of samples (type: Optional[int], default: null) + max_seq_length: 512 + + # Whether to tie the embedding weights with the language modeling head weights (type: Optional[bool], default: null) + tie_embeddings: + + # (type: float, default: 0.0003) + learning_rate: 0.0002 + + # (type: float, default: 0.02) + weight_decay: 0.0 + + # (type: float, default: 0.9) + beta1: 0.9 + + # (type: float, default: 0.95) + beta2: 0.95 + + # (type: Optional[float], default: null) + max_norm: + + # (type: float, default: 6e-05) + min_lr: 6.0e-05 + +# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details +eval: + + # Number of optimizer steps between evaluation calls (type: int, default: 100) + interval: 100 + + # Number of tokens to generate (type: Optional[int], default: 100) + max_new_tokens: 100 + + # Number of iterations (type: int, default: 100) + max_iters: 100 + +# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: csv) +logger_name: csv + +# The random seed to use for reproducibility. (type: int, default: 1337) +seed: 1337 diff --git a/config_hub/pretrain/debug.yaml b/config_hub/pretrain/debug.yaml index 346f4111b8..77ad6b13ad 100644 --- a/config_hub/pretrain/debug.yaml +++ b/config_hub/pretrain/debug.yaml @@ -20,7 +20,7 @@ initial_checkpoint_dir: resume: false # Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``. -data: OpenWebText +data: TinyStories # Training-related arguments. See ``litgpt.args.TrainArgs`` for details train: diff --git a/extensions/thunder/README.md b/extensions/thunder/README.md index e1e5f7bdb7..df2d0461a7 100644 --- a/extensions/thunder/README.md +++ b/extensions/thunder/README.md @@ -40,28 +40,9 @@ print(forward_trace) @torch.no_grad() @no_autocast() def augmented_forward_fn(*args): - # args: "Collection" - t0, \ - t1, \ - t2, \ - t3, \ - t4, \ - t5, \ - t6, \ - t7, \ - t8, \ - t9, \ - t10, \ - t11, \ - t12, \ - t13, \ - t14, \ - t15, \ - t16, \ - t17, \ - t18, \ - t19, \ - = args + # args: "Collection" + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, \ + t18, t19, = args del args t24 = torch.nn.functional.embedding(t0, t19, None, None, 2.0, False, False) # t24: "cuda:0 f32[2, 5, 4096]" t20 = torch_slice_prim_impl(t1, [0, 0], [5, 128], [1, 1]) # t20: "cuda:0 f32[5, 128]" @@ -245,92 +226,21 @@ print(backward_trace) @torch.no_grad() @no_autocast() def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, \ - C1, \ - = saved_for_backward + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, C1, = saved_for_backward clear_collection(saved_for_backward) del saved_for_backward - t178, \ - = cotangents + t178, = cotangents clear_collection(cotangents) del cotangents - t0, \ - t101, \ - t104, \ - t105, \ - t114, \ - t136, \ - t138, \ - t139, \ - t140, \ - t141, \ - t142, \ - t144, \ - t146, \ - t15, \ - t152, \ - t155, \ - t156, \ - t157, \ - t158, \ - t16, \ - t164, \ - t166, \ - t17, \ - t172, \ - t175, \ - t176, \ - t18, \ - t24, \ - t3, \ - t30, \ - t33, \ - t34, \ - t4, \ - t43, \ - t49, \ - t5, \ - t51, \ - t6, \ - t65, \ - t67, \ - t68, \ - t69, \ - t7, \ - t70, \ - t71, \ - t73, \ - t75, \ - t8, \ - t81, \ - t84, \ - t85, \ - t86, \ - t87, \ - t9, \ - t93, \ - t95, \ - = C0 + t0, t101, t104, t105, t114, t136, t138, t139, t140, t141, t142, t144, t146, \ + t15, t152, t155, t156, t157, t158, t16, t164, t166, t17, t172, t175, t176, t18, \ + t24, t3, t30, t33, t34, t4, t43, t49, t5, t51, t6, t65, t67, t68, t69, t7, t70, \ + t71, t73, t75, t8, t81, t84, t85, t86, t87, t9, t93, t95, = C0 clear_collection(C0) del C0 - b1, \ - b2, \ - b41, \ - b91, \ - f101, \ - f106, \ - f40, \ - f42, \ - f51, \ - f56, \ - f6, \ - f90, \ - f92, \ - i0, \ - i23, \ - i73, \ + b1, b2, b41, b91, f101, f106, f40, f42, f51, f56, f6, f90, f92, i0, i23, i73, \ = C1 clear_collection(C1) del C1 @@ -528,7 +438,7 @@ We provide ready-to-use Fabric strategies that integrate Thunder DDP|FSDP. Under ```python model = thunder.distributed.ddp(model) -# or +# or # model = thunder.distributed.fsdp(model) model = thunder.jit(model) diff --git a/extensions/thunder/strategies/thunder_ddp.py b/extensions/thunder/strategies/thunder_ddp.py index 2afa7290e1..4efbe27c60 100644 --- a/extensions/thunder/strategies/thunder_ddp.py +++ b/extensions/thunder/strategies/thunder_ddp.py @@ -45,17 +45,35 @@ def __init__( cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, + jit: bool = True, executors: Optional[Tuple[Union["Executor", str], ...]] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, **kwargs: Any, ): + r"""Strategy for Replicated Data Parallel provided by Lightning Thunder. + + .. warning:: This is an :ref:`experimental ` feature. + + Arguments: + jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually + jitting a function that includes the model. + + executors: The list of Thunder executors to enable. They can be either string aliases for the executors + or the actual executor instances. + + \**kwargs: See available parameters in :func:`thunder.distributed.ddp`. + + """ if not _THUNDER_AVAILABLE: raise ModuleNotFoundError(str(_THUNDER_AVAILABLE)) super().__init__(accelerator=accelerator, checkpoint_io=checkpoint_io, precision=precision) self.parallel_devices = parallel_devices self.cluster_environment: Optional[ClusterEnvironment] = cluster_environment + if not jit and executors is not None: + raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`") + self.jit = jit self.executors = _validate_executors(executors) self._num_nodes = 1 self._process_group_backend: Optional[str] = process_group_backend @@ -111,8 +129,25 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> Module: import thunder - module = thunder.distributed.ddp(module, **self._ddp_kwargs) - + if (cd := thunder.compile_data(module)) is not None: + # the module was already jitted + if thunder.compile_stats(module).last_traces is not None: + raise RuntimeError( + "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the" + " DDP transform. Remove the `forward` call before `fabric.setup()`" + ) + assert cd.is_module # sanity check + ddp_module = thunder.distributed.ddp(cd.fn, **self._ddp_kwargs) + # update the compile data state + cd.fn = ddp_module + assert hasattr(cd, "_processed_function") # sanity check + cd._processed_function = ddp_module + cd.process_group_for_ddp = ddp_module.process_group_for_ddp + return module + else: + module = thunder.distributed.ddp(module, **self._ddp_kwargs) + if not self.jit: + return module return thunder.jit(module, executors=self.executors) @override diff --git a/extensions/thunder/strategies/thunder_fsdp.py b/extensions/thunder/strategies/thunder_fsdp.py index 6fd2200d70..d4e60c0085 100644 --- a/extensions/thunder/strategies/thunder_fsdp.py +++ b/extensions/thunder/strategies/thunder_fsdp.py @@ -54,12 +54,54 @@ def __init__( cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, precision: Optional[Precision] = None, + jit: bool = True, + executors: Optional[Tuple[Union["Executor", str], ...]] = None, sharding_strategy: "_FSDP_TYPE" = "ZERO3", bucketing_strategy: "_BUCKETING_STRATEGY" = "NONE", - executors: Optional[Tuple[Union["Executor", str], ...]] = None, state_dict_type: Literal["full", "sharded"] = "sharded", **kwargs: Any, ): + r"""Strategy for Fully Sharded Data Parallel provided by Lightning Thunder. + + .. warning:: This is an :ref:`experimental ` feature. + + Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model + size, whilst using efficient communication to reduce overhead. In practice, this means we can remain + at parity with PyTorch DDP, whilst scaling our model sizes dramatically. + + Arguments: + jit: Whether to automatically call ``thunder.jit(model)`` if necessary. Disable this if you are manually + jitting a function that includes the model. + + executors: The list of Thunder executors to enable. They can be either string aliases for the executors + or the actual executor instances. + + sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination + of them: + + - ``"ZERO3"``: Shards model parameters, gradients, and optimizer states (default). + - ``"ZERO2"``: Shards gradients and optimizer states only. Model parameters get replicated. + + Also accepts a :class:`thunder.distributed.FSDPType` enum value. + + bucketing_strategy: Enables combining the collective operations for sets of layers. + + - ``"NONE"``: No bucketing (default). + - ``"LAYER"``: Create buckets per layer class. + - ``"BLOCK"``: Create buckets per layer block. + + Also accepts a :class:`thunder.distributed.FSDPBucketingStrategy` enum value. + + state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint. + + - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file + (default). + - ``"sharded"``: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is + a folder with as many files as the world size. + + \**kwargs: See available parameters in :func:`thunder.distributed.fsdp`. + + """ if not _TORCH_GREATER_EQUAL_2_2: raise ImportError("Thunder's FSDP strategy requires PyTorch 2.2 or higher.") if not _THUNDER_AVAILABLE: @@ -77,6 +119,9 @@ def __init__( if isinstance(bucketing_strategy, str) else bucketing_strategy ) + if not jit and executors is not None: + raise ValueError(f"Passing executors={executors} doesn't have an effect with `jit={jit}`") + self.jit = jit self.executors = _validate_executors(executors) self._state_dict_type = state_dict_type self._fsdp_kwargs = kwargs @@ -115,16 +160,37 @@ def setup_environment(self) -> None: def setup_module(self, module: Module) -> Module: import thunder - module = thunder.distributed.fsdp( - module, - device=self.root_device, - sharding_strategy=self.sharding_strategy, - bucketing_strategy=self.bucketing_strategy, - **self._fsdp_kwargs, - ) - - # NOTE @IvanYaschuck says that `fsdp(jit(model))` could be supported in the future so that the user owns the `jit` call. - # we would still `jit(fsdp(undo_jit(jit(model))))` internally + if (cd := thunder.compile_data(module)) is not None: + # the module was already jitted + if thunder.compile_stats(module).last_traces is not None: + raise RuntimeError( + "You already called `thunder.jit()` and generated an execution trace. It's too late to apply the" + " FSDP transform. Remove the `forward` call before `fabric.setup()`" + ) + assert cd.is_module # sanity check + fsdp_module = thunder.distributed.fsdp( + cd.fn, + device=self.root_device, + sharding_strategy=self.sharding_strategy, + bucketing_strategy=self.bucketing_strategy, + **self._fsdp_kwargs, + ) + # update the compile data state + cd.fn = fsdp_module + assert hasattr(cd, "_processed_function") # sanity check + cd._processed_function = fsdp_module + cd.process_group_for_ddp = fsdp_module.process_group_for_ddp + return module + else: + module = thunder.distributed.fsdp( + module, + device=self.root_device, + sharding_strategy=self.sharding_strategy, + bucketing_strategy=self.bucketing_strategy, + **self._fsdp_kwargs, + ) + if not self.jit: + return module return thunder.jit(module, executors=self.executors) @override diff --git a/extensions/thunder/unsloth/executor.py b/extensions/thunder/unsloth/executor.py index a638af079f..5b13c4dee2 100644 --- a/extensions/thunder/unsloth/executor.py +++ b/extensions/thunder/unsloth/executor.py @@ -48,8 +48,7 @@ def unsloth_cross_entropy_meta(logits: TensorProxy, labels: TensorProxy) -> Tupl def unsloth_cross_entropy_backward_impl(dlosses: Tensor, logits: Tensor, labels: Tensor, logsumexp: Tensor) -> Tensor: - # clone() because the kernel writes the grads in the logits. - # If it works, we can remove this it, but it's not a thing we generally anticipate and support right now. + # clone() because the kernel writes the grads in the logits return kernels.cross_entropy_loss._cross_entropy_backward_impl(dlosses, logits.clone(), logsumexp, labels) @@ -152,17 +151,10 @@ def unsloth_cross_entropy_grad( """ -def swiglu_forward_meta(e: TensorProxy, g: TensorProxy) -> TensorProxy: - return TensorProxy(like=e) - - -def swiglu_forward(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor: +def swiglu(e: torch.Tensor, g: torch.Tensor) -> torch.Tensor: return torch.nn.functional.silu(e) * g -swiglu = unsloth_ex.register_operator("swiglu", meta=swiglu_forward_meta, fn=swiglu_forward) - - from litgpt.model import LLaMAMLP as OriginalLLaMAMLP @@ -170,16 +162,20 @@ class ThunderLLaMAMLP(OriginalLLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - # There's no `register_operator` for Modules and `swiglu_forward` is not a torch symbol that we can register to - # For now, some duplication and monkey patching is required - fn = swiglu if thunder.core.interpreter.is_jitting() else swiglu_forward - x = fn(x_fc_1, x_fc_2) + x = swiglu(x_fc_1, x_fc_2) return self.proj(x) litgpt.model.LLaMAMLP = ThunderLLaMAMLP +def swiglu_forward_meta(e: TensorProxy, g: TensorProxy) -> TensorProxy: + return TensorProxy(like=e) + + +litgpt_swiglu = unsloth_ex.register_operator("litgpt_swiglu", meta=swiglu_forward_meta, fn=swiglu, replaces=swiglu) + + unsloth_swiglu_forward = unsloth_ex.register_operator( "unsloth_swiglu_forward", meta=swiglu_forward_meta, fn=lambda *args: kernels.swiglu_fg_kernel(*args) ) @@ -217,7 +213,7 @@ def unsloth_swiglu_grad(e: TensorProxy, g: TensorProxy) -> TensorProxy: unsloth_ex.register_implementation( - swiglu, + litgpt_swiglu, checker=swiglu_to_unsloth_checker, execution_transform=unsloth_swiglu_forward, grad_transform=unsloth_swiglu_grad, diff --git a/extensions/thunder/unsloth/pretrain.py b/extensions/thunder/unsloth/pretrain.py index 3526dd498b..3bb0166d38 100644 --- a/extensions/thunder/unsloth/pretrain.py +++ b/extensions/thunder/unsloth/pretrain.py @@ -439,11 +439,10 @@ def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resu def jit(fn: Callable) -> Any: import thunder + from executor import unsloth_ex from thunder.executors.sdpaex import sdpa_ex from thunder.executors.torch_compile import torch_compile_executor - from executor import unsloth_ex - return thunder.jit( fn, executors=[sdpa_ex, unsloth_ex, torch_compile_executor, thunder.nvfuser_executor, thunder.pytorch_executor] ) diff --git a/extensions/xla/README.md b/extensions/xla/README.md index d71a0e0f2c..6182f24d54 100644 --- a/extensions/xla/README.md +++ b/extensions/xla/README.md @@ -78,7 +78,7 @@ export PJRT_DEVICE=TPU > An extensive guide on setup and available options can be found [here](https://cloud.google.com/tpu/docs/v4-users-guide). Since a new machine was created, you may need to download pretrained weights. -They can be copied to the machine using `gcloud compute tpus tpu-vm scp`, or you can follow the steps described in our [downloading guide](download_model_weights.md). +They can be copied to the machine using `gcloud compute tpus tpu-vm scp`, or you can follow the steps described in our [downloading guide](../../tutorials/download_model_weights.md). It is also recommended to set up a persistent disk from which to load checkpoints. Follow [this guide](https://cloud.google.com/tpu/docs/setup-persistent-disk#setting_up_a_tpu_vm_and_a_persistent_disk) to do so. diff --git a/litgpt/adapter.py b/litgpt/adapter.py index 3bed9100c6..295470c932 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -66,6 +66,8 @@ def forward( mask = None x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + if self.config.scale_embeddings: + x = x * (self.config.n_embd**0.5) for block in self.transformer.h: x = block(x, cos, sin, mask, input_pos) x = self.transformer.ln_f(x) @@ -149,7 +151,8 @@ def scaled_dot_product_attention( return y + self.gating_factor * ay def reset_parameters(self) -> None: - torch.nn.init.zeros_(self.gating_factor) + if hasattr(self, "gating_factor"): + torch.nn.init.zeros_(self.gating_factor) def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with older checkpoints.""" diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index d7161e3b4f..665527f053 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -9,7 +9,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -80,6 +80,35 @@ def __init__(self, config: Config) -> None: self.max_seq_length = self.config.block_size self.mask_cache: Optional[torch.Tensor] = None + def forward( + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 + ) -> Union[torch.Tensor, List[torch.Tensor]]: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") + + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + if self.config.scale_embeddings: + x = x * (self.config.n_embd**0.5) + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos) + x = self.transformer.ln_f(x) + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] + return self.lm_head(x) # (b, t, vocab_size) + @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: return cls(Config.from_name(name, **kwargs)) @@ -181,6 +210,8 @@ def __init__(self, config: Config) -> None: self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) + self.config = config + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { @@ -199,7 +230,7 @@ class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 + x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 return self.proj(x) diff --git a/litgpt/args.py b/litgpt/args.py index d6ce527d36..b227ffe3f6 100644 --- a/litgpt/args.py +++ b/litgpt/args.py @@ -1,5 +1,5 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. - +import math from dataclasses import dataclass from typing import Optional @@ -16,8 +16,10 @@ class TrainArgs: """Number of samples between optimizer steps across data-parallel ranks""" micro_batch_size: int = 4 """Number of samples per data-parallel rank""" - lr_warmup_steps: int = 100 + lr_warmup_steps: Optional[int] = 100 """Number of iterations with learning rate warmup active""" + lr_warmup_fraction: Optional[float] = None + """The fraction of an epoch to use for learning rate warmup""" epochs: Optional[int] = None """Number of epochs to train on""" # TODO: `pretrain` is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs? @@ -38,6 +40,14 @@ class TrainArgs: max_norm: Optional[float] = None min_lr: float = 6e-5 + def __post_init__(self) -> None: + if self.lr_warmup_fraction and self.lr_warmup_steps: + raise ValueError( + "Can't provide both `--train.lr_warmup_fraction` and `--train.lr_warmup_steps`. Choose one." + ) + if self.lr_warmup_fraction and not (0 <= self.lr_warmup_fraction <= 1): + raise ValueError("`--train.lr_warmup_fraction` must be between 0 and 1.") + def gradient_accumulation_iters(self, devices: int) -> int: """Number of iterations between gradient synchronizations""" gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size @@ -50,6 +60,14 @@ def batch_size(self, devices: int) -> int: assert batch_size > 0 return batch_size + def warmup_iters(self, devices: int, max_iters: int, train_dataloader) -> int: + """Number of iterations to warm up the learning rate.""" + if self.lr_warmup_fraction: + return min(max_iters, math.ceil(self.lr_warmup_fraction * len(train_dataloader))) + if self.lr_warmup_steps: + return min(max_iters, self.lr_warmup_steps * self.gradient_accumulation_iters(devices)) + return 0 + @dataclass class EvalArgs: diff --git a/litgpt/chat/base.py b/litgpt/chat/base.py index eb31205d5d..81229ddd2a 100644 --- a/litgpt/chat/base.py +++ b/litgpt/chat/base.py @@ -132,7 +132,6 @@ def main( fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" diff --git a/litgpt/config.py b/litgpt/config.py index e188d0feff..caad1454b9 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -1387,6 +1387,24 @@ def norm_class(self) -> Type: copy["name"] = c["name"].format(kind) copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) configs.append(copy) +configs.append( + # https://huggingface.co/unsloth/mistral-7b-v0.2/blob/main/config.json + dict( + name="Mistral-7B-v0.2", + hf_config=dict(org="unsloth", name="Mistral-7B-v0.2"), + padded_vocab_size=32000, + block_size=32768, + n_layer=32, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + norm_eps=1e-05, + mlp_class_name="LLaMAMLP", + intermediate_size=14336, + ) +) configs.append( # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/config.json dict( diff --git a/litgpt/data/json_data.py b/litgpt/data/json_data.py index 8ce5b63368..541678b93f 100644 --- a/litgpt/data/json_data.py +++ b/litgpt/data/json_data.py @@ -18,8 +18,8 @@ class JSON(DataModule): """Loads JSON or JSONL data for supervised finetuning.""" json_path: Path - """A path to a JSON file or a directory with `train.json` and `val.json` containing the data. - The file(s) should contain a list of samples (dicts). Each dict must have the keys 'instruction' and 'output', + """A path to a JSON file or a directory with `train.json` and `val.json` containing the data. + The file(s) should contain a list of samples (dicts). Each dict must have the keys 'instruction' and 'output', and can optionally have a key 'input' (see Alpaca).""" mask_prompt: bool = False """Whether to mask the prompt section from the label (with ``ignore_index``).""" diff --git a/litgpt/data/tinystories.py b/litgpt/data/tinystories.py index 40ab0a40ff..90fce42341 100644 --- a/litgpt/data/tinystories.py +++ b/litgpt/data/tinystories.py @@ -1,23 +1,18 @@ -"""https://github.com/karpathy/llama2.c/blob/b3c4b6/tinystories.py""" - +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import glob import json import os -import random -from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass, field from functools import partial from pathlib import Path from typing import Optional -import numpy as np -import torch -from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.data import DataLoader from tqdm import tqdm +from litgpt import Tokenizer +from litgpt.data import DataModule from litgpt.data.alpaca import download_if_missing -from litgpt.data.base import DataModule -from litgpt.tokenizer import Tokenizer @dataclass @@ -27,155 +22,119 @@ class TinyStories(DataModule): Provides training and validation dataloaders that return batches of tokens. Every sample is set to a fixed length. """ - path: Path = Path("data/") - """Path to the data directory where data will be downloaded and preprocessed""" - num_workers: int = 0 - """How many DataLoader processes to use for loading.""" + data_path: Path = Path("data/tinystories") + """The path to the data directory, containing two folders 'train' and 'val' + which are the output of the preprocessing step.""" seed: int = 42 - """The random seed for creating the train/val splits and shuffling the dataset.""" + """The seed to use for shuffling the dataset.""" + num_workers: int = 8 + """The number of workers to use for the dataloaders.""" tokenizer: Optional[Tokenizer] = field(default=None, init=False, repr=False) batch_size: int = field(default=1, init=False, repr=False) max_seq_length: int = field(default=-1, init=False, repr=False) - train_dataset: Optional[torch.utils.data.Dataset] = field(default=None, init=False, repr=False) - test_dataset: Optional[torch.utils.data.Dataset] = field(default=None, init=False, repr=False) + + def __post_init__(self) -> None: + self.data_path_train = self.data_path / "train" + self.data_path_val = self.data_path / "val" def connect(self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: int = -1) -> None: self.tokenizer = tokenizer self.batch_size = batch_size - self.max_seq_length = max_seq_length + self.max_seq_length = max_seq_length + 1 # Increase by one because we need the next token as well def prepare_data(self) -> None: - download(self.path) - assert self.tokenizer is not None - pretokenize(self.path, self.tokenizer) - - def setup(self, stage: str = "") -> None: - # the .bin files are right along the .json files - bin_dir = self.path / "TinyStories_all_data" - shard_filenames = sorted(glob.glob(str(bin_dir / "*.bin"))) - assert len(shard_filenames) > 0, f"No bin files found in {bin_dir}" - assert len(shard_filenames) > 1, f"Expected at least two bins in {bin_dir}" + from litdata import optimize + + download(self.data_path) + + files = sorted(glob.glob(str(self.data_path / "TinyStories_all_data" / "*.json"))) + assert len(files) > 0, f"No json files found in {files}" + assert len(files) > 1, f"Expected at least two json files in {files}" # train/test split. let's use only shard 0 for test split, rest train - va_files, *train_files = shard_filenames - # shuffle the training files - random.Random(self.seed).shuffle(train_files) - self.train_dataset = ConcatDataset([PretokDataset(f, self.max_seq_length) for f in train_files]) - self.val_dataset = PretokDataset(shard_filenames[0], self.max_seq_length) + val_file, *train_files = files + num_workers = os.cpu_count() - 1 + + if not Path(self.data_path_train).is_dir(): + optimize( + fn=partial(tokenize, tokenizer=self.tokenizer), + inputs=train_files, + output_dir=str(self.data_path_train), + num_workers=num_workers, + chunk_bytes="200MB", + ) + if not Path(self.data_path_val).is_dir(): + optimize( + fn=partial(tokenize, tokenizer=self.tokenizer), + inputs=[val_file], + output_dir=str(self.data_path_val), + num_workers=1, # there's only 1 file + chunk_bytes="200MB", + ) def train_dataloader(self) -> DataLoader: - return DataLoader( - self.train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True, num_workers=self.num_workers + from litdata.streaming import StreamingDataLoader, StreamingDataset, TokensLoader + + train_dataset = StreamingDataset( + input_dir=str(self.data_path_train), + item_loader=TokensLoader(block_size=self.max_seq_length), + shuffle=True, + drop_last=True, + ) + train_dataloader = StreamingDataLoader( + train_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True ) + return train_dataloader def val_dataloader(self) -> DataLoader: - return DataLoader( - self.val_dataset, - batch_size=self.batch_size, - pin_memory=True, - shuffle=True, # llama2.c shuffles validation too - num_workers=self.num_workers, + from litdata.streaming import StreamingDataset, TokensLoader + + val_dataset = StreamingDataset( + input_dir=str(self.data_path_val), + item_loader=TokensLoader(block_size=self.max_seq_length), + shuffle=True, + # Consider setting to False, but we would lose some samples due to truncation when world size > 1 + drop_last=True, ) + val_dataloader = DataLoader( + val_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True + ) + return val_dataloader + + +def tokenize(filename: str, tokenizer: Tokenizer): + with open(filename, "r") as f: + data = json.load(f) + global_rank = int(os.environ["DATA_OPTIMIZER_GLOBAL_RANK"]) + num_workers = int(os.environ["DATA_OPTIMIZER_NUM_WORKERS"]) + local_rank = global_rank % num_workers + for example in tqdm(data, position=local_rank): + text = example["story"] + text = text.strip() # get rid of leading/trailing whitespace + tokens = tokenizer.encode(text, bos=True, eos=False) # encode the text, use BOS + yield tokens _URL = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz" def download(data_dir: Path): - data_dir.mkdir(exist_ok=True) - - # download the TinyStories dataset, unless it's already downloaded - data_filename = data_dir / "TinyStories_all_data.tar.gz" - download_if_missing(data_filename, _URL, stream=True, mode="wb") - print("Download done.") + data_dir.mkdir(exist_ok=True, parents=True) - # unpack the tar.gz file into all the data shards (json files) + data_tar = data_dir / "TinyStories_all_data.tar.gz" data_dir = data_dir / "TinyStories_all_data" shard_filenames = sorted(glob.glob(str(data_dir / "*.json"))) if shard_filenames: print(f"{data_dir} already exists, skipping unpacking...") - else: - data_dir.mkdir(exist_ok=True) - print(f"Unpacking {data_filename}...") - os.system(f"tar -xzf {data_filename} -C {data_dir}") - shard_filenames = sorted(glob.glob(str(data_dir / "*.json"))) - - print(f"Number of shards: {len(shard_filenames)}") - # print a single example just for debugging and such - # with open(shard_filenames[0], "r") as f: - # data = json.load(f) - # print(f"Example story:\n{data[0]}") + return + # download the TinyStories dataset, unless it's already downloaded + download_if_missing(data_tar, _URL, stream=True, mode="wb") -def process_shard(args, tokenizer): - shard_id, shard = args - with open(shard, "r") as f: - data = json.load(f) - all_tokens = [] - for example in tqdm(data, position=shard_id): - text = example["story"] - text = text.strip() # get rid of leading/trailing whitespace - tokens = tokenizer.encode(text, bos=True, eos=False) # encode the text, use BOS - all_tokens.extend(tokens) - # convert to uint16 nparray - all_tokens = np.array(all_tokens, dtype=np.uint16) - # just save the tokenized file in the same dir - tokenized_filename = shard.replace(".json", ".bin") - # write the bytes - with open(tokenized_filename, "wb") as f: - f.write(all_tokens.tobytes()) - # calculate the average sequence length (they are separated by BOS=1) - bos_id = tokenizer.bos_id - assert bos_id >= 0 # uint16 is unsigned - bos_tokens = (all_tokens == tokenizer.bos_id).sum() - assert bos_tokens > 0 - avg_seq_len = all_tokens.size / bos_tokens - print( - f"Saved {tokenized_filename}, tokens: {all_tokens.size}, bos: {bos_tokens}, average seqlen: {avg_seq_len:.2f}" - ) - - -def pretokenize(data_dir: Path, tokenizer: Tokenizer): - bins_path = str(data_dir / "TinyStories_all_data" / "*.bin") - shard_filenames = sorted(glob.glob(bins_path)) - if shard_filenames: - print("Already pretokenized.") - return - # iterate the shards and tokenize all of them one by one - jsons_path = str(data_dir / "TinyStories_all_data" / "*.json") - shard_filenames = sorted(glob.glob(jsons_path)) - if not shard_filenames: - raise ValueError(f"No json files found in {jsons_path!r}. Did you run `python tinystories.py download`?") - # process all the shards in a process pool - fun = partial(process_shard, tokenizer=tokenizer) - with ProcessPoolExecutor() as executor: - executor.map(fun, enumerate(shard_filenames)) - print("Done.") - - -class PretokDataset(torch.utils.data.Dataset): - """Loads a pre-tokenized array from disk and returns chunks of `max_seq_length` length.""" - - def __init__(self, filepath: str, max_seq_len: int): - super().__init__() - self.filepath = filepath - # open the dataset for reading but keep it on disk with memmap - self.shard = np.memmap(filepath, dtype=np.uint16, mode="r") - self.shard_length = len(self.shard) - self.length = self.shard_length // max_seq_len - assert max_seq_len > 1 - self.max_seq_len = max_seq_len - - def __len__(self) -> int: - return self.length - - def __getitem__(self, ix: int) -> torch.Tensor: - if ix < 0: - raise NotImplementedError - start = ix * self.max_seq_len - end = start + self.max_seq_len + 1 - if end > self.shard_length: - raise IndexError - # calling .astype will copy the data into a new numpy array, now in RAM - chunk = torch.from_numpy((self.shard[start:end]).astype(np.int64)) - return chunk + # unpack the tar.gz file into all the data shards (json files) + data_dir.mkdir(exist_ok=False) + tar_command = f"tar -xzf {data_tar} -C {data_dir}" + print(tar_command) + os.system(tar_command) + shard_filenames = sorted(glob.glob(str(data_dir / "*.json"))) + print(f"Number of shards: {len(shard_filenames)}") diff --git a/litgpt/finetune/adapter.py b/litgpt/finetune/adapter.py index a88fd8e2c3..9326793e2b 100644 --- a/litgpt/finetune/adapter.py +++ b/litgpt/finetune/adapter.py @@ -46,8 +46,8 @@ def setup( train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, - global_batch_size=128, - micro_batch_size=4, + global_batch_size=16, + micro_batch_size=1, lr_warmup_steps=100, epochs=5, learning_rate=1e-3, @@ -75,6 +75,8 @@ def setup( pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) + + check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") precision = precision or get_default_supported_precision(training=True) @@ -120,7 +122,6 @@ def main( eval: EvalArgs, ) -> None: validate_args(train, eval) - check_valid_checkpoint_dir(checkpoint_dir) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) @@ -365,7 +366,7 @@ def save_adapter_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] - unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings"])] + unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py index 2b29ecc228..3c4634e354 100644 --- a/litgpt/finetune/adapter_v2.py +++ b/litgpt/finetune/adapter_v2.py @@ -46,8 +46,8 @@ def setup( train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, - global_batch_size=128, - micro_batch_size=4, + global_batch_size=16, + micro_batch_size=1, lr_warmup_steps=100, epochs=5, learning_rate=1e-3, @@ -75,6 +75,8 @@ def setup( pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) + + check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") precision = precision or get_default_supported_precision(training=True) @@ -120,7 +122,6 @@ def main( eval: EvalArgs, ) -> None: validate_args(train, eval) - check_valid_checkpoint_dir(checkpoint_dir) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) @@ -365,7 +366,7 @@ def save_adapter_v2_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_pa def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] - unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings"])] + unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py index 086756e6b5..3a2e2a7176 100644 --- a/litgpt/finetune/full.py +++ b/litgpt/finetune/full.py @@ -44,7 +44,7 @@ def setup( train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, - global_batch_size=64, + global_batch_size=16, micro_batch_size=1, lr_warmup_steps=100, epochs=5, @@ -74,6 +74,8 @@ def setup( pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) + + check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file(checkpoint_dir / "model_config.yaml") precision = precision or get_default_supported_precision(training=True) @@ -109,7 +111,6 @@ def main( eval: EvalArgs, ) -> None: validate_args(train, eval) - check_valid_checkpoint_dir(checkpoint_dir) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) @@ -340,7 +341,7 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] - unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings"])] + unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py index 22f6b54ef9..bb60b2d180 100644 --- a/litgpt/finetune/lora.py +++ b/litgpt/finetune/lora.py @@ -56,8 +56,8 @@ def setup( train: TrainArgs = TrainArgs( save_interval=1000, log_interval=1, - global_batch_size=128, - micro_batch_size=4, + global_batch_size=16, + micro_batch_size=1, lr_warmup_steps=100, epochs=5, learning_rate=3e-4, @@ -94,6 +94,8 @@ def setup( pprint(locals()) data = Alpaca() if data is None else data devices = parse_devices(devices) + + check_valid_checkpoint_dir(checkpoint_dir) config = Config.from_file( checkpoint_dir / "model_config.yaml", lora_r=lora_r, @@ -150,7 +152,6 @@ def main( eval: EvalArgs, ) -> None: validate_args(train, eval) - check_valid_checkpoint_dir(checkpoint_dir) tokenizer = Tokenizer(checkpoint_dir) train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train) @@ -396,7 +397,7 @@ def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Pa def validate_args(train: TrainArgs, eval: EvalArgs) -> None: issues = [] - unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings"])] + unsupported = [(train, ["max_tokens", "max_norm", "tie_embeddings", "lr_warmup_fraction"])] for args, names in unsupported: for name in names: if getattr(args, name) is not None: diff --git a/litgpt/generate/adapter.py b/litgpt/generate/adapter.py index e6f35be4bc..104b3e20b0 100644 --- a/litgpt/generate/adapter.py +++ b/litgpt/generate/adapter.py @@ -60,7 +60,6 @@ def main( fabric.launch() check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" diff --git a/litgpt/generate/adapter_v2.py b/litgpt/generate/adapter_v2.py index b8e6eff0c8..c7aeee8a91 100644 --- a/litgpt/generate/adapter_v2.py +++ b/litgpt/generate/adapter_v2.py @@ -60,7 +60,6 @@ def main( fabric.launch() check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index d76a5bbb84..6488717429 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -133,7 +133,6 @@ def main( fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" diff --git a/litgpt/generate/full.py b/litgpt/generate/full.py index e602e6ef51..608115a5e1 100644 --- a/litgpt/generate/full.py +++ b/litgpt/generate/full.py @@ -59,7 +59,6 @@ def main( fabric.launch() check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = finetuned_path diff --git a/litgpt/generate/sequentially.py b/litgpt/generate/sequentially.py index cce1d8d3c9..f804c4cffc 100644 --- a/litgpt/generate/sequentially.py +++ b/litgpt/generate/sequentially.py @@ -158,7 +158,6 @@ def main( print(f"Using {total_devices} devices", file=sys.stderr) check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") checkpoint_path = checkpoint_dir / "lit_model.pth" diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index 3c6c8daf2c..5c56dd1c09 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -137,7 +137,6 @@ def main( fabric.launch() check_valid_checkpoint_dir(checkpoint_dir) - config = Config.from_file(checkpoint_dir / "model_config.yaml") model_file = "lit_model.pth" diff --git a/litgpt/lora.py b/litgpt/lora.py index fd54d6f771..69e83fcbc6 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -122,8 +122,8 @@ def __init__( # Actual trainable parameters if r > 0: - self.lora_A = nn.Parameter(torch.zeros((r, in_features))) - self.lora_B = nn.Parameter(torch.zeros((out_features, r))) + self.lora_A = nn.Parameter(torch.empty((r, in_features))) + self.lora_B = nn.Parameter(torch.empty((out_features, r))) self.scaling = self.lora_alpha / self.r self.reset_parameters() @@ -144,11 +144,8 @@ def merge(self) -> None: if self.r > 0 and not self.merged: pretrained_dtype = self.linear.weight.data.dtype lora_data = self.get_lora_AB() - # if the pretrained weights and LoRA weights are of the same dtype - simply sum them - if pretrained_dtype == lora_data.dtype: - self.linear.weight.data += lora_data # if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result - elif pretrained_dtype == torch.uint8: + if pretrained_dtype == torch.uint8: import bitsandbytes as bnb weight = self.linear.weight @@ -159,6 +156,10 @@ def merge(self) -> None: # assign updated weights and quantize by moving to CUDA device self.linear.weight = bnb.nn.Params4bit(weight_data, requires_grad=False, **weight.__dict__) self.linear.weight.cuda(weight.device) + # if the pretrained weights and LoRA weights are of compatible dtypes - simply sum them + elif torch.finfo(pretrained_dtype).max >= torch.finfo(lora_data.dtype).max: + # self.linear might be on CPU and lora_data on CUDA + self.linear.weight.data += lora_data.to(device=self.linear.weight.data.device) else: raise NotImplementedError( f"Cannot merge the pretrained weights of type {pretrained_dtype}" @@ -185,6 +186,7 @@ def __init__( in_features: int, out_features: int, # ↓ the remaining part is for LoRA + head_size: int, n_head: int, n_query_groups: int, r: int = 0, @@ -204,6 +206,7 @@ def __init__( Args: in_features: number of input features of the pretrained weights out_features: number of output features of the pretrained weights + head_size: size of a single attention head n_head: number of attention heads n_query_groups: number of query groups (see diagram in `litgpt/config.py`) r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of @@ -232,17 +235,18 @@ def __init__( # ⚬ r: 2 # ⚬ enable_lora: [True, False, True] if r > 0 and any(enable_lora): - self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features))) # (4, 128) + self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features))) # (4, 128) enable_q, enable_k, enable_v = enable_lora - self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups) # qkv_shapes will be used to split a tensor with weights correctly qkv_shapes = ( - self.linear.in_features * enable_q, - self.kv_embd_size * enable_k, - self.kv_embd_size * enable_v, + # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`) + # might not be equal to `head_size * n_head`, thus we use it directly here + head_size * n_head * enable_q, + head_size * n_query_groups * enable_k, + head_size * n_query_groups * enable_v, ) self.qkv_shapes = [s for s in qkv_shapes if s] - self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r)) # (256, 2)) + self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r)) # (256, 2)) # Notes about shapes above # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices; # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in @@ -541,6 +545,8 @@ def forward( mask = None x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + if self.config.scale_embeddings: + x = x * (self.config.n_embd**0.5) for block in self.transformer.h: x = block(x, cos, sin, mask, input_pos) x = self.transformer.ln_f(x) @@ -594,6 +600,7 @@ def __init__(self, config: Config) -> None: enable_lora=(config.lora_query, config.lora_key, config.lora_value), bias=config.bias, # for MQA/GQA support + head_size=config.head_size, n_head=config.n_head, n_query_groups=config.n_query_groups, ) @@ -686,6 +693,8 @@ def __init__(self, config: Config) -> None: lora_dropout=config.lora_dropout, ) + self.config = config + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base checkpoints.""" mapping = { @@ -704,7 +713,7 @@ class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 + x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 return self.proj(x) diff --git a/litgpt/model.py b/litgpt/model.py index f2626b0e88..fe71c60b80 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -139,6 +139,12 @@ def clear_kv_cache(self) -> None: class Block(nn.Module): def __init__(self, config: Config) -> None: super().__init__() + if not config.parallel_residual and config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." + ) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) self.attn = CausalSelfAttention(config) self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps) @@ -154,18 +160,30 @@ def forward( mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: - n_1 = self.norm_1(x) - h = self.attn(n_1, cos, sin, mask, input_pos) + """ + Non-parallel residual Parallel residual + ┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True, + │ ↓ │ ↓ ↓ the output from `norm_1` is reused + │ norm_1 │ norm_1 ───► norm_2 + │ ↓ │ ↓ ↓ + │ attn │ attn mlp + │ ↓ │ ↓ │ + ┌─ └► + └► + ◄───────────┘ + │ norm_2 + │ ↓ + │ mlp + │ ↓ + └───► + + """ + + x_normed = self.norm_1(x) + attention_output = self.attn(x_normed, cos, sin, mask, input_pos) + if self.config.parallel_residual: - n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) - x = self.mlp(n_2) + h + x + x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x) + x = self.mlp(x_normed) + attention_output + x else: - if self.config.shared_attention_norm: - raise NotImplementedError( - "No checkpoint amongst the ones we support uses this configuration" - " (non-parallel residual and shared attention norm)." - ) - x = h + x + x = attention_output + x x = self.mlp(self.norm_2(x)) + x return x diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py index b6f8560861..eef48bdc0c 100644 --- a/litgpt/pretrain.py +++ b/litgpt/pretrain.py @@ -20,6 +20,7 @@ from litgpt import Tokenizer from litgpt.args import EvalArgs, TrainArgs +from litgpt.config import name_to_config from litgpt.data import DataModule, TinyLlama from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP from litgpt.utils import ( @@ -91,7 +92,8 @@ def setup( if model_config is not None and model_name is not None: raise ValueError("Only one of `model_name` or `model_config` can be set.") elif model_config is None and model_name is None: - model_name = "tiny-llama-1.1b" + available_models = "\n".join(sorted(name_to_config)) + raise ValueError(f"Please specify --model_name . Available values:\n{available_models}") config = Config.from_name(model_name) if model_config is None else model_config devices = parse_devices(devices) out_dir = init_out_dir(out_dir) @@ -242,7 +244,8 @@ def fit( total_t0 = time.perf_counter() val_loss = "n/a" - warmup_iters = train.lr_warmup_steps * train.gradient_accumulation_iters(devices) + warmup_iters = train.warmup_iters(devices, max_iters, train_dataloader) + for train_data in train_iterator: if state["iter_num"] >= max_iters: break diff --git a/litgpt/scripts/merge_lora.py b/litgpt/scripts/merge_lora.py index 6f2e5ea588..2bedfa743e 100644 --- a/litgpt/scripts/merge_lora.py +++ b/litgpt/scripts/merge_lora.py @@ -9,7 +9,7 @@ import yaml from litgpt.lora import GPT, Config, lora_filter, merge_lora_weights -from litgpt.utils import CLI, check_valid_checkpoint_dir, lazy_load +from litgpt.utils import CLI, check_valid_checkpoint_dir def merge_lora( @@ -22,7 +22,7 @@ def merge_lora( Args: checkpoint_dir: Path to the checkpoint directory with trained LoRA weights, which is the output of - ``litgpt finetune --method lora``. + ``litgpt finetune lora``. pretrained_checkpoint_dir: Optional path to the checkpoint directory with the weights of the base model corresponding to the LoRA checkpoint. By default, this will automatically be inferred from the metadata in the given `checkpoint_dir` directory. Only set this if the base model's checkpoint directory @@ -30,7 +30,7 @@ def merge_lora( precision: Optional precision setting to instantiate the model weights in. By default, this will automatically be inferred from the metadata in the given ``checkpoint_dir`` directory. """ - check_valid_checkpoint_dir(checkpoint_dir, lora=True) + check_valid_checkpoint_dir(checkpoint_dir, model_filename="lit_model.pth.lora") if pretrained_checkpoint_dir is not None: check_valid_checkpoint_dir(pretrained_checkpoint_dir) if (checkpoint_dir / "lit_model.pth").is_file(): @@ -43,16 +43,16 @@ def merge_lora( fabric = L.Fabric(devices=1, precision=precision, accelerator="cpu") config = Config.from_file(checkpoint_dir / "model_config.yaml", **lora_params) - with fabric.init_module(empty_init=True): + with fabric.init_module(), torch.device("meta"): model = GPT(config) lora_path = checkpoint_dir / "lit_model.pth.lora" - pretrained_checkpoint = lazy_load(pretrained_checkpoint_dir / "lit_model.pth") - lora_checkpoint = lazy_load(lora_path) + pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / "lit_model.pth"), mmap=True) + lora_checkpoint = torch.load(str(lora_path), mmap=True) # Merge LoRA weights into the base model pretrained_checkpoint.update(lora_checkpoint.get("model", lora_checkpoint)) - model.load_state_dict(pretrained_checkpoint) + model.load_state_dict(pretrained_checkpoint, assign=True) merge_lora_weights(model) # Remove LoRA parameters and the LoRA linear substring diff --git a/litgpt/utils.py b/litgpt/utils.py index 3f145da105..fb6a86c107 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -52,8 +52,7 @@ def reset_parameters(module: nn.Module) -> None: mod.reset_parameters() -def check_valid_checkpoint_dir(checkpoint_dir: Path, lora: bool = False) -> None: - model_filename = "lit_model.pth.lora" if lora else "lit_model.pth" +def check_valid_checkpoint_dir(checkpoint_dir: Path, model_filename: str = "lit_model.pth") -> None: files = { model_filename: (checkpoint_dir / model_filename).is_file(), "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(), diff --git a/pyproject.toml b/pyproject.toml index 66b4b8f4a9..9aec513f1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ license = { file = "LICENSE" } dependencies = [ "torch>=2.2.0", - "lightning==2.3.0.dev20240318", + "lightning==2.3.0.dev20240324", "jsonargparse[signatures]>=4.27.6", ] diff --git a/tests/data/fixtures/alpaca.json b/tests/data/fixtures/alpaca.json index 10f50be6cb..a2f3b56649 100644 --- a/tests/data/fixtures/alpaca.json +++ b/tests/data/fixtures/alpaca.json @@ -114,4 +114,4 @@ "input": "Recording about the history of Waterloo", "output": "The Battle of Waterloo, fought on June 18, 1815, was the last major battle of the Napoleonic Wars. It marked the end of Napoleon Bonaparte's rule. The battle was fought between Napoleon Bonaparte's French forces and an allied force of British, Dutch, Belgian and Prussian troops, led by the Duke of Wellington. It was fought near the town of Waterloo in the former Austrian Netherlands, now present-day Belgium. The battle was a defeat for Napoleon's forces and ended Napoleon's rule over France. Napoleon himself was exiled to St. Helena, where he would die in 1821. The battle marked the end of French domination of Europe and the beginning of a period of peace in the region. It was also the last major battle of the Napoleonic Wars and is remembered as one of the most important battles in history. The victory of the allied forces at Waterloo helped to bring an end to French imperial ambitions in Europe. It also established the British as one of the most powerful nations in Europe. The battle was fought with a combination of infantry, cavalry, and artillery tactics, and showed the beginning of the changing nature of warfare. Although the French forces greatly outnumbered the allies, their strategic and tactical mistake led to the loss of the battle. This defeat signaled the end of French imperial power in Europe. The legacy of Waterloo still stands today, and it cemented the British Empire's position for the next one hundred years." } -] \ No newline at end of file +] diff --git a/tests/data/fixtures/longform_train.json b/tests/data/fixtures/longform_train.json index 457fb8173a..7c67fb3972 100644 --- a/tests/data/fixtures/longform_train.json +++ b/tests/data/fixtures/longform_train.json @@ -101,4 +101,4 @@ "source": "Wikipedia", "subset": "chatbot" } -] \ No newline at end of file +] diff --git a/tests/data/fixtures/longform_val.json b/tests/data/fixtures/longform_val.json index dd2f94e1ec..f08c02f3e7 100644 --- a/tests/data/fixtures/longform_val.json +++ b/tests/data/fixtures/longform_val.json @@ -53,4 +53,4 @@ "source": "Natural Instructions", "subset": "task1658 billsum summarization" } -] \ No newline at end of file +] diff --git a/tests/data/test_tinystories.py b/tests/data/test_tinystories.py index 8e9a94dd0a..bfb009c1a9 100644 --- a/tests/data/test_tinystories.py +++ b/tests/data/test_tinystories.py @@ -1,29 +1,26 @@ import json -from contextlib import redirect_stdout -from io import StringIO -import numpy as np import pytest import torch +from litdata import optimize +from litdata.streaming import StreamingDataset, TokensLoader from torch.utils._pytree import tree_map -from torch.utils.data import ConcatDataset -from litgpt.data.tinystories import PretokDataset, TinyStories, process_shard +def tokenize(data): + for story in data: + yield torch.tensor(story) -def fake_bin(tmp_path, data, name): - all_tokens = np.array(data, dtype=np.uint16) - data_path = tmp_path / f"{name}.bin" - with open(data_path, "wb") as f: - f.write(all_tokens.tobytes()) - return data_path + +def fake_chunk(path, data): + optimize(fn=tokenize, inputs=[data] * len(data), output_dir=str(path), num_workers=1, chunk_bytes="200MB") @pytest.mark.parametrize( ("max_seq_len", "expected"), [ - (2, [[0, 23, 15], [15, 63, 0], [0, 73, 5], [5, 0, 1], [1, 1999, 0]]), - (5, [[0, 23, 15, 63, 0, 73], [73, 5, 0, 1, 1999, 0]]), + (2, [[0, 23, 15], [63, 0, 73], [5, 0, 1], [1999, 0, 13]]), + (5, [[0, 23, 15, 63, 0, 73], [5, 0, 1, 1999, 0, 13]]), (6, [[0, 23, 15, 63, 0, 73, 5]]), (7, [[0, 23, 15, 63, 0, 73, 5, 0]]), ], @@ -31,14 +28,18 @@ def fake_bin(tmp_path, data, name): def test_pretok_dataset(tmp_path, max_seq_len, expected): fake_data = [0, 23, 15, 63, 0, 73, 5, 0, 1, 1999, 0, 13] assert len(fake_data) == 12 - bin_path = fake_bin(tmp_path, fake_data, "data") + fake_chunk(tmp_path, [fake_data]) - dataset = PretokDataset(str(bin_path), max_seq_len) + dataset = StreamingDataset( + input_dir=str(tmp_path), item_loader=TokensLoader(block_size=max_seq_len + 1), shuffle=False, drop_last=False + ) actual = tree_map(torch.Tensor.tolist, list(dataset)) assert actual == expected -def test_process_shard(tmp_path): +def test_tokenize(tmp_path, monkeypatch): + from litgpt.data.tinystories import tokenize + story1, story2 = "foo bar", " fun " data = [{"story": story1}, {"story": story2}] shard_path = tmp_path / "data.json" @@ -53,38 +54,43 @@ def encode(self, text, bos, eos): assert not eos return [self.bos_id] + [ord(c) for c in text] - out = StringIO() - with redirect_stdout(out): - process_shard((0, str(shard_path)), Tokenizer()) - - text = out.getvalue() - assert text.endswith("data.bin, tokens: 12, bos: 2, average seqlen: 6.00\n") - assert shard_path.with_suffix(".bin").exists() + monkeypatch.setenv("DATA_OPTIMIZER_GLOBAL_RANK", "0") + monkeypatch.setenv("DATA_OPTIMIZER_NUM_WORKERS", "1") + data = tokenize(str(shard_path), Tokenizer()) + assert list(data) == [[0, 102, 111, 111, 32, 98, 97, 114], [0, 102, 117, 110]] def test_tinystories_datamodule(tmp_path): - datamodule = TinyStories(tmp_path, seed=42) - datamodule.connect(max_seq_length=2) + from litgpt.data.tinystories import TinyStories - data_dir = tmp_path / "TinyStories_all_data" - data_dir.mkdir() - fake_bin(data_dir, [12], "0") - fake_bin(data_dir, [0, 23, 15, 63, 0], "1") - fake_bin(data_dir, [73, 5, 0, 1, 1999, 0, 13], "2") + data_dir = tmp_path / "tinystories" - datamodule.setup() + datamodule = TinyStories(data_dir, seed=42) + datamodule.connect(max_seq_length=2) - assert isinstance(datamodule.train_dataset, ConcatDataset) - assert len(datamodule.train_dataset.datasets) == 2 - assert isinstance(datamodule.train_dataset.datasets[0], PretokDataset) - # unordered because it shuffled - assert datamodule.train_dataset.datasets[0].filepath == str(data_dir / "2.bin") - assert datamodule.train_dataset.datasets[1].filepath == str(data_dir / "1.bin") + # simulate `datamodule.prepare_data` + train_data_dir = data_dir / "train" + train_data_dir.mkdir(parents=True) + fake_chunk(train_data_dir, [[12], [0, 23, 15, 63, 0], [73, 5, 0, 1, 1999, 0, 13]]) - assert isinstance(datamodule.val_dataset, PretokDataset) - assert datamodule.val_dataset.filepath == str(data_dir / "0.bin") + datamodule.setup() tr_dataloader = datamodule.train_dataloader() torch.manual_seed(0) actual = tree_map(torch.Tensor.tolist, list(tr_dataloader)) - assert actual == [[[0, 1, 1999]], [[15, 63, 0]], [[1999, 0, 13]], [[0, 23, 15]], [[73, 5, 0]]] + # there is 1 sample per index in the data (13) + assert actual == [ + [[1999, 0, 13]], + [[0, 13, 12]], + [[1, 1999, 0]], + [[63, 0, 73]], + [[5, 0, 1]], + [[0, 73, 5]], + [[0, 23, 15]], + [[0, 1, 1999]], + [[15, 63, 0]], + [[73, 5, 0]], + [[12, 0, 23]], + [[23, 15, 63]], + [[13, 12, 0]], + ] diff --git a/tests/test_adapter.py b/tests/test_adapter.py index ab1a918ec9..cb9ac7b019 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -14,6 +14,7 @@ from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo.backends import debugging +from transformers.models.gemma import GemmaConfig, GemmaForCausalLM import litgpt.adapter as gpt_adapter import litgpt.finetune.adapter as module @@ -21,6 +22,7 @@ from litgpt.adapter import GPT, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca +from litgpt.scripts.convert_hf_checkpoint import copy_weights_hf_llama def test_config_identical(): @@ -232,3 +234,44 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca logs = stdout.getvalue() assert "of trainable parameters: 168" in logs assert "of non-trainable parameters: 1,888" in logs + + +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) +def test_against_hf_gemma(model_name): + device = torch.device("cpu") + dtype = torch.float32 + T = 5 + ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) + theirs_config = GemmaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + tie_word_embeddings=True, + hidden_act="gelu_pytorch_tanh", + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = GemmaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + # Gemma weights are shipped without `lm_head.weight` + theirs_state_dict.pop("lm_head.weight") + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index 716428bf24..67f0689c05 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -13,12 +13,13 @@ from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo.backends import debugging +from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM import litgpt.config as config_module import litgpt.finetune.adapter_v2 as module -from litgpt.adapter_v2 import GPT, Config, adapter_filter from litgpt.adapter_v2 import GPT as AdapterV2GPT +from litgpt.adapter_v2 import Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.model import GPT as BaseGPT @@ -195,7 +196,7 @@ def test_against_hf_mixtral(): theirs_state_dict = theirs_model.state_dict() state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) - ours_model = GPT(ours_config).to(device) + ours_model = AdapterV2GPT(ours_config).to(device) # strict=False because missing keys due to adapter weights not contained in state dict ours_model.load_state_dict(state_dict, strict=False) @@ -207,6 +208,47 @@ def test_against_hf_mixtral(): torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) +def test_against_hf_gemma(model_name): + device = torch.device("cpu") + dtype = torch.float32 + T = 5 + ours_config = Config.from_name(model_name, n_layer=2, n_head=16, n_embd=32, intermediate_size=86) + theirs_config = GemmaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + tie_word_embeddings=True, + hidden_act="gelu_pytorch_tanh", + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = GemmaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + # Gemma weights are shipped without `lm_head.weight` + theirs_state_dict.pop("lm_head.weight") + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = AdapterV2GPT(ours_config).to(device) + ours_model.load_state_dict(state_dict, strict=False) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + @RunIf(min_cuda_gpus=1) def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path): if not _BITSANDBYTES_AVAILABLE: diff --git a/tests/test_args.py b/tests/test_args.py new file mode 100644 index 0000000000..0b13c83976 --- /dev/null +++ b/tests/test_args.py @@ -0,0 +1,36 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import pytest + +from litgpt.args import TrainArgs + + +def test_compute_warmup_iters(): + # warmup disabled + train = TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=0) + assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(10)) == 0 + + # lr_warmup_steps and lr_warmup_fraction both are not allowed + with pytest.raises(ValueError, match="Can't provide both `--train.lr_warmup_fraction`"): + TrainArgs(lr_warmup_steps=1, lr_warmup_fraction=0.2) + + # lr_warmup_fraction invalid range + with pytest.raises(ValueError, match=" must be between 0 and 1"): + TrainArgs(lr_warmup_steps=0, lr_warmup_fraction=1.1) + + # lr_warmup_steps + train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=100, lr_warmup_fraction=0) + assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(10)) == 100 + # lr_warmup_steps multiplied by accumulation factor + train.global_batch_size = 4 + assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(10)) == 400 + assert train.warmup_iters(devices=2, max_iters=1000, train_dataloader=range(10)) == 200 + # lr_warmup_steps truncated by max iters + assert train.warmup_iters(devices=1, max_iters=120, train_dataloader=range(10)) == 120 + + # lr_warmup_fraction + train = TrainArgs(global_batch_size=1, micro_batch_size=1, lr_warmup_steps=0, lr_warmup_fraction=0.3) + assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(100)) == 30 + # lr_warmup_fraction truncated by max iters + assert train.warmup_iters(devices=1, max_iters=20, train_dataloader=range(100)) == 20 + # lr_warmup_fraction rounds up + assert train.warmup_iters(devices=1, max_iters=1000, train_dataloader=range(5)) == 2 diff --git a/tests/test_config_hub.py b/tests/test_config_hub.py index 163a6531dc..4ad634ca9b 100644 --- a/tests/test_config_hub.py +++ b/tests/test_config_hub.py @@ -36,7 +36,7 @@ @pytest.mark.parametrize(("script_file", "config_file"), all_pairs) -def test_config_help(script_file, config_file): +def test_config_help(script_file, config_file, monkeypatch): """Test that configs validate against the signature in the scripts.""" script_file = Path(__file__).parent.parent / script_file assert script_file.is_file() @@ -48,10 +48,11 @@ def test_config_help(script_file, config_file): module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - module.main = Mock() - module.Tokenizer = Mock() - module.BitsandbytesPrecision = Mock(return_value=Precision()) - module.Config = Mock(return_value=Config.from_name("pythia-14m")) + monkeypatch.setattr(module, "main", Mock()) + monkeypatch.setattr(module, "Tokenizer", Mock()) + monkeypatch.setattr(module, "BitsandbytesPrecision", Mock(return_value=Precision()), raising=False) + monkeypatch.setattr(module, "Config", Mock(return_value=Config.from_name("pythia-14m"))) + monkeypatch.setattr(module, "check_valid_checkpoint_dir", Mock(), raising=False) with mock.patch("sys.argv", [script_file.name, "--config", str(config_file), "--devices", "1"]): CLI(module.setup) diff --git a/tests/test_lora.py b/tests/test_lora.py index c45a4f0bf8..3a6eeb8de3 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -15,6 +15,7 @@ from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo.backends import debugging from torch.nn import functional as F +from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM import litgpt.config as config_module @@ -233,7 +234,7 @@ def __init__(self, *args, **kwargs): original_linear = torch.nn.Linear # Our bnb does this sort of monkey patching torch.nn.Linear = MyLinear - layer = LoRAQKVLinear(1, 1, 1, 1) + layer = LoRAQKVLinear(1, 1, 1, 1, 1) assert isinstance(layer.linear, original_linear) torch.nn.Linear = original_linear @@ -323,6 +324,7 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap @torch.inference_mode() +@pytest.mark.parametrize("head_size", (1, 2, 4)) @pytest.mark.parametrize("n_head", (1, 2, 3, 6, 12)) @pytest.mark.parametrize( "enable_lora", @@ -336,9 +338,11 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap (True, True, True), ], ) -def test_lora_qkv_linear_compare_conv1d(n_head, enable_lora): +def test_lora_qkv_linear_compare_conv1d(head_size, n_head, enable_lora): C = 12 - layer = LoRAQKVLinear(C, 3 * C, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora) + layer = LoRAQKVLinear( + C, 3 * C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora + ) x = torch.randn((1, 1, C)) a = F.linear(x, layer.lora_A).transpose(-2, -1) # after_A b = layer.lora_B.data.unsqueeze(-1) @@ -370,7 +374,8 @@ def test_lora_linear_weights_merged_status(rank, expected_merged): ((0, True, False), (1, True, True), (0, False, False), (1, False, False)), ) def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merged): - layer = LoRAQKVLinear(10, 3 * 10, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora) + C = 10 + layer = LoRAQKVLinear(C, 3 * C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora) assert not layer.merged layer.merge() assert layer.merged == expected_merged @@ -523,6 +528,9 @@ def test_against_hf_mixtral(): n_query_groups=2, intermediate_size=86, n_expert=4, + lora_r=1, + lora_key=True, + lora_value=True, ) T = 5 theirs_config = MixtralConfig( @@ -544,7 +552,10 @@ def test_against_hf_mixtral(): state_dict = {} copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) ours_model = LoRAGPT(ours_config).to(device) - ours_model.load_state_dict(state_dict) + keys = ours_model.load_state_dict(state_dict, strict=False) + assert not keys.unexpected_keys + for k in keys.missing_keys: + assert lora_filter(k, None) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304], [23, 345, 65, 123, 321]], dtype=torch.int32, device=device) @@ -554,6 +565,61 @@ def test_against_hf_mixtral(): torch.testing.assert_close(ours_y, theirs_y) +@torch.inference_mode() +@pytest.mark.parametrize("model_name", ["gemma-2b", "gemma-7b"]) +def test_against_hf_gemma(model_name): + device = torch.device("cpu") + dtype = torch.float32 + T = 5 + ours_config = Config.from_name( + model_name, + n_layer=2, + n_head=16, + n_embd=32, + head_size=4, + intermediate_size=86, + lora_r=1, + lora_query=True, + lora_key=True, + lora_value=True, + ) + theirs_config = GemmaConfig( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + max_position_embeddings=T, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + attention_bias=ours_config.bias, + tie_word_embeddings=True, + hidden_act="gelu_pytorch_tanh", + ) + assert ours_config.intermediate_size == theirs_config.intermediate_size + + theirs_model = GemmaForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + # Gemma weights are shipped without `lm_head.weight` + theirs_state_dict.pop("lm_head.weight") + state_dict = {} + copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict) + ours_model = LoRAGPT(ours_config).to(device) + keys = ours_model.load_state_dict(state_dict, strict=False) + assert not keys.unexpected_keys + for k in keys.missing_keys: + assert lora_filter(k, None) + + # test end to end + x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + + @RunIf(min_cuda_gpus=1) def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_path): if not _BITSANDBYTES_AVAILABLE: diff --git a/tests/test_thunder_ddp.py b/tests/test_thunder_ddp.py index 5ccc853eea..566e883ac3 100644 --- a/tests/test_thunder_ddp.py +++ b/tests/test_thunder_ddp.py @@ -10,13 +10,19 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) +from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy + + +@RunIf(thunder=True) +def test_thunder_strategy_input_parsing(): + with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"): + ThunderDDPStrategy(jit=False, executors=("python",)) + @RunIf(min_cuda_gpus=2, thunder=True, standalone=True) @pytest.mark.parametrize("strategy", ["ddp", "thunder_ddp"]) def test_no_backward_sync(strategy): if strategy == "thunder_ddp": - from extensions.thunder.strategies.thunder_ddp import ThunderDDPStrategy - strategy = ThunderDDPStrategy() fabric = Fabric(devices=2, accelerator="cuda", strategy=strategy) @@ -47,3 +53,37 @@ def test_no_backward_sync(strategy): # rank0 rank1 allreduce1 rank0 rank1 allreduce2 assert model.weight.grad.item() == (9.0 if i == 3 else 22.5) model.weight.grad = None + + +@RunIf(min_cuda_gpus=2, thunder=True, standalone=True) +@pytest.mark.parametrize("jit", (False, True)) +def test_jit_before_setup(jit): + import thunder + + fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderDDPStrategy(jit=jit)) + fabric.launch() + + x = torch.randn(1, 1, device=fabric.device) + model = torch.nn.Linear(1, 2, bias=False, device=fabric.device) + + tmodel = thunder.jit(model) + fmodel = fabric.setup(tmodel) + fmodel(x) + + assert "all_reduce" in thunder.last_backward_traces(tmodel)[-1].python() + + +@RunIf(min_cuda_gpus=1, thunder=True) +def test_setup_already_traced(): + import thunder + + device = torch.device("cuda") + x = torch.randn(1, 1, device=device) + model = torch.nn.Linear(1, 2, bias=False, device=device) + + strategy = ThunderDDPStrategy() + + tmodel = thunder.jit(model) + tmodel(x) + with pytest.raises(RuntimeError, match="already called"): + strategy.setup_module(tmodel) diff --git a/tests/test_thunder_fsdp.py b/tests/test_thunder_fsdp.py index fed938aba6..8b9c0f4340 100644 --- a/tests/test_thunder_fsdp.py +++ b/tests/test_thunder_fsdp.py @@ -6,11 +6,10 @@ import pytest import torch +from conftest import RunIf from lightning.fabric import Fabric from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 -from conftest import RunIf - # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) @@ -29,6 +28,9 @@ def test_thunder_strategy_input_parsing(): assert strategy.executors == (pythonex,) assert strategy.sharding_strategy is FSDPType.ZERO3 + with pytest.raises(ValueError, match="doesn't have an effect with `jit=False"): + ThunderFSDPStrategy(jit=False, executors=("python",)) + @RunIf(thunder=True) def test_validate_executors(): @@ -310,3 +312,37 @@ def test_save_load_sharded_checkpoint(tmp_path): actual["buf"] = actual["buf"].to(device="cpu") torch.testing.assert_close(actual, expected) assert state["primitive"] == 123 + + +@RunIf(min_cuda_gpus=2, thunder=True, standalone=True) +@pytest.mark.parametrize("jit", (False, True)) +def test_jit_before_setup(jit): + import thunder + + fabric = Fabric(devices=2, accelerator="cuda", strategy=ThunderFSDPStrategy(jit=jit)) + fabric.launch() + + x = torch.randn(1, 1, device=fabric.device) + model = torch.nn.Linear(1, 2, bias=False, device=fabric.device) + + tmodel = thunder.jit(model) + fmodel = fabric.setup(tmodel) + fmodel(x) + + assert "all_gather" in thunder.last_traces(tmodel)[-1].python() + + +@RunIf(min_cuda_gpus=1, thunder=True) +def test_setup_already_traced(): + import thunder + + device = torch.device("cuda") + x = torch.randn(1, 1, device=device) + model = torch.nn.Linear(1, 2, bias=False, device=device) + + strategy = ThunderFSDPStrategy() + + tmodel = thunder.jit(model) + tmodel(x) + with pytest.raises(RuntimeError, match="already called"): + strategy.setup_module(tmodel) diff --git a/tests/test_unsloth_executor.py b/tests/test_unsloth_executor.py index 3ce6907ffd..797d1f6f53 100644 --- a/tests/test_unsloth_executor.py +++ b/tests/test_unsloth_executor.py @@ -3,8 +3,8 @@ from conftest import RunIf from litgpt import GPT, Config -from litgpt.utils import chunked_cross_entropy from litgpt.model import apply_rope, build_rope_cache +from litgpt.utils import chunked_cross_entropy @RunIf(min_cuda_gpus=1, thunder=True) @@ -85,9 +85,9 @@ def test_unsloth_swiglu(): import thunder from thunder.core.transforms import grad - from extensions.thunder.unsloth.executor import unsloth_ex, ThunderLLaMAMLP - from litgpt.model import LLaMAMLP + from extensions.thunder.unsloth.executor import ThunderLLaMAMLP, unsloth_ex from litgpt import Config + from litgpt.model import LLaMAMLP config = Config.from_name("Llama-2-7b-hf") with torch.device("cuda"): diff --git a/tutorials/0_to_litgpt.md b/tutorials/0_to_litgpt.md new file mode 100644 index 0000000000..65d19cef10 --- /dev/null +++ b/tutorials/0_to_litgpt.md @@ -0,0 +1,550 @@ +# Zero to LitGPT: Getting Started with Pretraining, Finetuning, and Using LLMs + + + +This tutorial walks you through the main features and usage patterns for ⚡️LitGPT, a library for pretraining, finetuning, and using LLMs that focuses on an efficient user experience while being developer-friendly. + +The topics, following the installation of LitGPT, are in chronological order, reflecting the steps in an LLM lifecycle: Pretraining → Finetuning → Inference. + +  + + + +  + + + +  + +However, it is also possible, and even common, to use and deploy models with LitGPT without pretraining and finetuning. So, if you are not interested in pretraining and finetuning, please feel free to skip these sections. + + + + + +  +## Install LitGPT + +LitGPT is available as a Python library from the PyPI package repository, and we recommend installing it using Python's `pip` installer module, including all required package dependencies: + +```bash +pip install 'litgpt[all]' +``` + +Alternatively, if you are a researcher or developer planning to make changes to LitGPT, you can clone the GitHub repository and install it from a local folder as follows: + +``` +git clone https://github.com/Lightning-AI/litgpt.git +cd litgpt +pip install -e '.[all]' +``` + + +  +## Pretrain LLMs + +Finetuning LLMs requires substantial compute resources and time commitment. For that reason, most researchers and practitioners prefer to skip this step and continue with the Download pretrained model weights section instead. + +However, if you feel adventurous and want to pretrain your own LLM, here's how. + +First, we have to decide which type of model architecture we want to use. We list the available architectures by using the `pretrain` command without any additional arguments: + +```bash +litgpt pretrain +``` + +This prints a list of all available model architectures in alphabetical order: + +``` +Camel-Platypus2-13B +Camel-Platypus2-70B +CodeLlama-13b-Python-hf +... +tiny-llama-1.1b +vicuna-13b-v1.3 +vicuna-13b-v1.5 +vicuna-13b-v1.5-16k +vicuna-33b-v1.3 +vicuna-7b-v1.3 +vicuna-7b-v1.5 +vicuna-7b-v1.5-16k +``` + +Suppose we want to pretraining the 1.1B parameter small `tiny-llama-1.1b` model. Before starting finetuning, we must also choose and download a tokenizer. + +We can download a tokenizer via the `download` command. Note that running `litgpt download` without any additional arguments will also print a list of all available models and tokenizers to download. + +To filter for specific models, e.g., TinyLlama, we can use the `grep` command in our terminal: + +```bash +litgpt download | grep TinyLlama +``` + +This prints + +``` +TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T +TinyLlama/TinyLlama-1.1B-Chat-v1.0 +``` + +Let's now download the tokenizer corresponding to `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T` that we can then use to pretrain the TinyLlama model, which saves the download tokenizer to a `checkpoints/` folder by default: + +``` +litgpt download \ + --repo_id TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \ + --tokenizer_only true +``` + +  + + + +  + +Next, we can pretrain the model on the OpenWebText dataset with the default setting as follows: + +```bash +litgpt pretrain \ + --model_name tiny-llama-1.1b \ + --data OpenWebText \ + --tokenizer_dir checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T +``` + +If you are interested in additional settings, you can use the help command as follows: + +``` +litgpt pretrain --help +``` + +  + +> [!TIP] +> Above, we only covered the most basic commands for pretraining a model using LitGPT. We highly recommend checking the resources below if you are interested in pretraining a model. + +  + +**More information and additional resources** + +- [tutorials/pretraimd](./pretrain.md): General information about pretraining in LitGPT +- [tutorials/pretrain_tinyllama](./pretrain_tinyllama.md): A tutorial for finetuning a 1.1B TinyLlama model on 3 trillion tokens +- [config_hub/pretrain](../config_hub/pretrain): Pre-made config files for pretraining that work well out of the box +- Project templates in reproducible environments with multi-GPU and multi-node support: + - [Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) + - [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) + - [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) + + +  +## Download pretrained model weights + +Most practical use cases, like LLM inference (/chat) or finetuning, involve using pretrained model weights. LitGPT supports a large number of model weights, which can be listed by executing the `download` command without any additional arguments: + +```bash +litgpt download +``` + +This will print a (long) list of all supported pretrained models (abbreviated for readability below): + +``` +.. +google/gemma-2b +... +meta-llama/Llama-2-7b-hf +... +microsoft/phi-2 +... +mistralai/Mixtral-8x7B-Instruct-v0.1 +... +``` + +To download the model weights, provide one of the model strings above as a `--repo_id` argument: + +```bash +litgpt download --repo_id microsoft/phi-2 +``` + +``` +model-00001-of-00002.safetensors: 100%|████████████████████████████████| 5.00G/5.00G [00:40<00:00, 124MB/s] +model-00002-of-00002.safetensors: 100%|████████████████████████████████| 564M/564M [00:01<00:00, 330MB/s] +tokenizer.json: 100%|██████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 54.0MB/s] +... +Converting checkpoint files to LitGPT format. +Processing checkpoints/microsoft/phi-2/model-00001-of-00002.bin +... +Saving converted checkpoint to checkpoints/microsoft/phi-2 +``` + + +  + +> [!TIP] +> Note that some models, such as Llama 2, require that you accept Meta AI's terms of service for this model, and you need to use a special access token via the `litgpt download ... --access_token ...` option. For more information, visit the respective Model Hub website, e.g., [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf). The access token can be created under your Model Hub in the `Profile > Access Tokens` menu. + +  + + +By default, the weights are going to be stored in a `./checkpoints` subdirectory: + +```bash +ls -lh checkpoints/microsoft/phi-2/ +``` + +``` +total 11G +-rw-r--r-- 1 sebastian sebastian 863 Mar 19 21:14 config.json +-rw-r--r-- 1 sebastian sebastian 124 Mar 19 21:14 generation_config.json +-rw-r--r-- 1 sebastian sebastian 5.2G Mar 19 21:15 lit_model.pth +-rw-r--r-- 1 sebastian sebastian 4.7G Mar 19 21:15 model-00001-of-00002.bin +-rw-r--r-- 1 sebastian sebastian 538M Mar 19 21:15 model-00002-of-00002.bin +-rw-r--r-- 1 sebastian sebastian 528 Mar 19 21:15 model_config.yaml +-rw-r--r-- 1 sebastian sebastian 2.1M Mar 19 21:14 tokenizer.json +-rw-r--r-- 1 sebastian sebastian 7.2K Mar 19 21:14 tokenizer_config.json +``` + +The model is now ready for inference and chat, for example, using the `chat` command on the checkpoint directory: + +```bash +litgpt chat --checkpoint_dir checkpoints/microsoft/phi-2 +``` + +``` +Now chatting with phi-2. +To exit, press 'Enter' on an empty prompt. + +Seed set to 1234 +>> Prompt: Why are LLMs so useful? +>> Reply: When building applications or operating systems, you can use LLMs to know how a computer should respond to your commands. This can make your programs run faster and more efficiently. + +Time for inference: 1.26 sec total, 27.81 tokens/sec, 35 tokens + +>> Prompt: +``` + + +  +**More information and additional resources** + +- [tutorials/download_model_weights](download_model_weights.md): A more comprehensive download tutorial, tips for GPU memory limitations, and more + + +  +## Finetune LLMs + +LitGPT supports several methods of supervised instruction finetuning, which allows you to finetune models to follow instructions. + +Datasets for Instruction-finetuning are usually formatted in the following way: + +  + + + +  + +Alternatively, datasets for instruction finetuning can also contain an `'input'` field: + +In an instruction-finetuning context, "full" finetuning means updating all model parameters as opposed to only a subset. Adapter and LoRA (short for low-rank adaptation) are methods for parameter-efficient finetuning that only require updating a small fraction of the model weights. + +  + + + +  + +Parameter-efficient finetuning is much more resource-efficient and cheaper than full finetuning, and it often results in the same good performance on downstream tasks. + +In the following example, we will use LoRA for finetuning, which is one of the most popular LLM finetuning methods. (For more information on how LoRA works, please see [Code LoRA from Scratch](https://lightning.ai/lightning-ai/studios/code-lora-from-scratch).) + +Before we start, we have to download a model as explained in the previous "Download pretrained model" section above: + +```bash +litgpt download --repo_id microsoft/phi-2 +``` + +The LitGPT interface can be used via command line arguments and configuration files. We recommend starting with the configuration files from the [config_hub](../config_hub) and either modifying them directly or overriding specific settings via the command line. For example, we can use the following setting to train the downloaded 2.7B parameter `microsoft/phi-2` model, where we set `--max_steps 5` for a quick test run. + +If you have downloaded or cloned the LitGPT repository, you can provide the `config` file via a relative path: + +```bash +litgpt finetune lora \ + --config config_hub/finetune/phi-2/lora.yaml \ + --train.max_steps 5 +``` + +Alternatively, you can provide a URL: + +```bash +litgpt finetune lora \ + --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/finetune/phi-2/lora.yaml \ + --train.max_steps 5 +``` + + +  + + +> [!TIP] +> Note that the config file above will finetune the model on the `Alpaca2k` dataset on 1 GPU and save the resulting files in an `out/finetune/lora-phi-2` directory. All of these settings can be changed via a respective command line argument or by changing the config file. +> To see more options, execute `litgpt finetune lora --help`. + +  + +Running the previous finetuning command will initiate the finetuning process, which should only take about a minute on a GPU due to the `--train.max_steps 5` setting. + +``` +{'checkpoint_dir': PosixPath('checkpoints/microsoft/phi-2'), + 'data': Alpaca2k(mask_prompt=False, + val_split_fraction=0.03847, + prompt_style=, + ignore_index=-100, + seed=42, + num_workers=4, + download_dir=PosixPath('data/alpaca2k')), + 'devices': 1, + 'eval': EvalArgs(interval=100, max_new_tokens=100, max_iters=100), + 'logger_name': 'csv', + 'lora_alpha': 16, + 'lora_dropout': 0.05, + 'lora_head': True, + 'lora_key': True, + 'lora_mlp': True, + 'lora_projection': True, + 'lora_query': True, + 'lora_r': 8, + 'lora_value': True, + 'out_dir': PosixPath('out/finetune/lora-phi-2'), + 'precision': 'bf16-true', + 'quantize': None, + 'seed': 1337, + 'train': TrainArgs(save_interval=800, + log_interval=1, + global_batch_size=8, + micro_batch_size=4, + lr_warmup_steps=10, + epochs=1, + max_tokens=None, + max_steps=5, + max_seq_length=512, + tie_embeddings=None, + learning_rate=0.0002, + weight_decay=0.0, + beta1=0.9, + beta2=0.95, + max_norm=None, + min_lr=6e-05)} +Seed set to 1337 +Number of trainable parameters: 12,226,560 +Number of non-trainable parameters: 2,779,683,840 +The longest sequence length in the train data is 512, the model's maximum sequence length is 512 and context length is 2048 +Validating ... +Recommend a movie for me to watch during the weekend and explain the reason. +Below is an instruction that describes a task. Write a response that appropriately completes the request. + +### Instruction: +Recommend a movie for me to watch during the weekend and explain the reason. + +### Response: +I recommend you watch "Parasite" because it's a critically acclaimed movie that won multiple awards, including the Academy Award for Best Picture. It's a thought-provoking and suspenseful film that will keep you on the edge of your seat. The movie also tackles social and economic inequalities, making it a must-watch for anyone interested in meaningful storytelling. + +/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MeanMetric was called before the ``update`` method which may lead to errors, as metric states have not yet been updated. + warnings.warn(*args, **kwargs) # noqa: B028 +Missing logger folder: out/finetune/lora-phi-2/logs/csv +Epoch 1 | iter 1 step 0 | loss train: 1.646, val: n/a | iter time: 820.31 ms +Epoch 1 | iter 2 step 1 | loss train: 1.660, val: n/a | iter time: 548.72 ms (step) +Epoch 1 | iter 3 step 1 | loss train: 1.687, val: n/a | iter time: 300.07 ms +Epoch 1 | iter 4 step 2 | loss train: 1.597, val: n/a | iter time: 595.27 ms (step) +Epoch 1 | iter 5 step 2 | loss train: 1.640, val: n/a | iter time: 260.75 ms +Epoch 1 | iter 6 step 3 | loss train: 1.703, val: n/a | iter time: 568.22 ms (step) +Epoch 1 | iter 7 step 3 | loss train: 1.678, val: n/a | iter time: 511.70 ms +Epoch 1 | iter 8 step 4 | loss train: 1.741, val: n/a | iter time: 514.14 ms (step) +Epoch 1 | iter 9 step 4 | loss train: 1.689, val: n/a | iter time: 423.59 ms +Epoch 1 | iter 10 step 5 | loss train: 1.524, val: n/a | iter time: 603.03 ms (step) +Training time: 11.20s +Memory used: 13.90 GB +Saving LoRA weights to 'out/finetune/lora-phi-2/final/lit_model.pth.lora' +Saved merged weights to 'out/finetune/lora-phi-2/final/lit_model.pth' +``` + +Notice that the LoRA script saves both the LoRA weights (`'out/finetune/lora-phi-2/final/lit_model.pth.lora'`) and the LoRA weight merged back into the original model (`'out/finetune/lora-phi-2/final/lit_model.pth'`) for convenience. This allows us to use the finetuned model via the `chat` function directly: + +```bash +litgpt chat --checkpoint_dir out/finetune/lora-phi-2/final/ +``` + +``` +Now chatting with phi-2. +To exit, press 'Enter' on an empty prompt. + +Seed set to 1234 +>> Prompt: Why are LLMs so useful? +>> Reply: LLMs are useful because they can be trained to perform various natural language tasks, such as language translation, text generation, and question-answering. They are also able to understand the context of the input data, which makes them particularly useful for tasks such as sentiment analysis and text summarization. Additionally, because LLMs can learn from large amounts of data, they are able to generalize well and perform well on new data. + +Time for inference: 2.15 sec total, 39.57 tokens/sec, 85 tokens + +>> Prompt: +``` + + + +  + +**More information and additional resources** + +- [tutorials/prepare_dataset](prepare_dataset.md): A summary of all out-of-the-box supported datasets in LitGPT and utilities for preparing custom datasets +- [tutorials/finetune](finetune.md): An overview of the different finetuning methods supported in LitGPT +- [tutorials/finetune_full](finetune_full.md): A tutorial on full-parameter finetuning +- [tutorials/finetune_lora](finetune_lora.md): Options for parameter-efficient finetuning with LoRA and QLoRA +- [tutorials/finetune_adapter](finetune_adapter.md): A description of the parameter-efficient Llama-Adapter methods supported in LitGPT +- [tutorials/oom](oom.md): Tips for dealing with out-of-memory (OOM) errors +- [config_hub/finetune](../config_hub/finetune): Pre-made config files for finetuning that work well out of the box + +  +## LLM inference + +To use a downloaded or finetuned model for chat, you only need to provide the corresponding checkpoint directory containing the model and tokenizer files. For example, to chat with the phi-2 model from Microsoft, download it as follows, as described in the "Download pretrained model" section: + +```bash +litgpt download --repo_id microsoft/phi-2 +``` + +``` +model-00001-of-00002.safetensors: 100%|████████████████████████████████| 5.00G/5.00G [00:40<00:00, 124MB/s] +model-00002-of-00002.safetensors: 100%|████████████████████████████████| 564M/564M [00:01<00:00, 330MB/s] +tokenizer.json: 100%|██████████████████████████████████████████████████| 2.11M/2.11M [00:00<00:00, 54.0MB/s] +... +Converting checkpoint files to LitGPT format. +Processing checkpoints/microsoft/phi-2/model-00001-of-00002.bin +... +Saving converted checkpoint to checkpoints/microsoft/phi-2 +``` + + + + + +Then, chat with the model using the following command: + +```bash +litgpt chat --checkpoint_dir checkpoints/microsoft/phi-2 +``` + +``` +Now chatting with phi-2. +To exit, press 'Enter' on an empty prompt. + +Seed set to 1234 +>> Prompt: What is the main difference between a large language model and a traditional search engine? +>> Reply: A large language model uses deep learning algorithms to analyze and generate natural language, while a traditional search engine uses algorithms to retrieve information from web pages. + +Time for inference: 1.14 sec total, 26.26 tokens/sec, 30 tokens +``` + +> [!TIP] +> Most model weights are already represented in an efficient bfloat16 format. However, if the model currently exceeds your GPU memory, you can try to pass the `--precision bf16-true` option. In addition, you can check the quantization documentation for further optimization, which is linked below. + + +  +**More information and additional resources** + +- [tutorials/inference](inference.md): Chat and inference tutorial +- [tutorials/quantize](quantize.md): Quantizing models to reduce GPU memory requirements + + + + +  +## Converting LitGPT model weights to `safetensors` format + +Sometimes, it can be useful to convert LitGPT model weights for third-party and external tools. For example, we can convert a LitGPT model to the Hugging Face format and save it via `.safetensors` files. + +The `--checkpoint_dir` argument provided below points to a directory corresponding to a downloaded or finetuned model (see the *Download pretrained model* or *Finetune LLMs* sections above for more information): + + +```bash +litgpt convert from_litgpt \ + --checkpoint_dir checkpoints/microsoft/phi-2 \ + --output_dir out/converted_model/ +``` + +Certain tools like the `.from_pretrained` method in Hugging Face `transformers` also require the original `config.json` file that originally came with the downloaded model: + +```bash +cp checkpoints/microsoft/phi-2/config.json out/converted_model/config.json +``` + +You can now load the model into a Hugging Face transformers model and safe it in a `.safetensors` format as follows: + +```bash +import torch +from transformers import AutoModel + +# Load model +state_dict = torch.load('out/converted_model/model.pth') +model = AutoModel.from_pretrained( + "microsoft/phi-2", state_dict=state_dict +) + +# Save .safetensors files +model.save_pretrained("out/converted_model/") +``` + +``` +⚡ ~/litgpt ls -lh out/converted_model +total 16G +-rwxr--r-- 1 sebastian sebastian 891 Mar 20 17:08 config.json +-rw-r--r-- 1 sebastian sebastian 4.7G Mar 20 17:08 model-00001-of-00003.safetensors +-rw-r--r-- 1 sebastian sebastian 4.7G Mar 20 17:09 model-00002-of-00003.safetensors +-rw-r--r-- 1 sebastian sebastian 601M Mar 20 17:09 model-00003-of-00003.safetensors +-rw-r--r-- 1 sebastian sebastian 5.2G Mar 20 16:30 model.pth +-rw-r--r-- 1 sebastian sebastian 33K Mar 20 17:09 model.safetensors.index.json +``` + +You can then use the model with external tools, for example, Eleuther AI's [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) (see the `lm_eval` installation instructions [here](https://github.com/EleutherAI/lm-evaluation-harness?tab=readme-ov-file#install)). + +The LM Evaluation Harness requires a tokenizer to be present in the model checkpoint folder, which we can copy from the original download checkpoint: + +```bash +# Copy the tokenizer needed by the Eval Harness +cp checkpoints/microsoft/phi-2/tokenizer* +out/converted_model +``` + +Then, we can run the Evaluation Harness as follows: + +```bash +lm_eval --model hf \ + --model_args pretrained="out/converted_model" \ + --tasks "hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge" \ + --device "cuda:0" \ + --batch_size 4 +``` + +  + +> [!TIP] +> The Evaluation Harness tasks above are those used in Open LLM Leaderboard. You can find a list all supported tasks [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md). + + + +  +**More information and additional resources** + +- [tutorials/convert_lit_models](./convert_lit_models.md): Tutorial on converting LitGPT weights + + + +  + +## Get involved! + +We appreciate your feedback and contributions. If you have feature requests, questions, or want to contribute code or config files, please don't hesitate to use the [GitHub Issue](https://github.com/Lightning-AI/litgpt/issues) tracker. + +We welcome all individual contributors, regardless of their level of experience or hardware. Your contributions are valuable, and we are excited to see what you can accomplish in this collaborative and supportive environment. + +  + +> [!TIP] +> Unsure about contributing? Check out our [How to Contribute to LitGPT](https://lightning.ai/pages/community/tutorial/how-to-contribute-to-litgpt/) guide. + +  + +If you have general questions about building with LitGPT, please [join our Discord](https://discord.gg/VptPCZkGNa). diff --git a/tutorials/convert_hf_checkpoint.md b/tutorials/convert_hf_checkpoint.md index 39bbc9823d..7081ae7c46 100644 --- a/tutorials/convert_hf_checkpoint.md +++ b/tutorials/convert_hf_checkpoint.md @@ -26,13 +26,13 @@ checkpoints/ To disable the automatic conversion, which is useful for development and debugging purposes, you can run the `litgpt/scripts/download.py` with the `--convert_checkpoint false` flag. This will only download the checkpoint files but do not convert them for use in LitGPT: ```bash -rm -rf checkpoints/EleutherAI/pythia-14m +rm -rf checkpoints/EleutherAI/pythia-14m litgpt download \ --repo_id EleutherAI/pythia-14m \ --convert_checkpoint false - -ls checkpoints/EleutherAI/pythia-14m + +ls checkpoints/EleutherAI/pythia-14m ``` ``` @@ -52,4 +52,3 @@ The required files `model_config.yaml` and `lit_model.pth` files can then be man litgpt convert to_litgpt \ --checkpoint_dir checkpoints/EleutherAI/pythia-14m ``` - diff --git a/tutorials/convert_lit_models.md b/tutorials/convert_lit_models.md index 3525e8a40e..301abfc161 100644 --- a/tutorials/convert_lit_models.md +++ b/tutorials/convert_lit_models.md @@ -66,7 +66,7 @@ For convenience, we first specify an environment variable (optional) to avoid co export repo_id=TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T ``` -Instead of using TinyLlama, you can replace the `repo_id` target with any other model repository +Instead of using TinyLlama, you can replace the `repo_id` target with any other model repository specifier that is currently supported by LitGPT. You can get a list of supported repository specifier by running `litgpt/scripts/download.py` without any additional arguments. @@ -147,4 +147,4 @@ lm_eval --model hf \ --tasks "hellaswag,gsm8k,truthfulqa_mc2,mmlu,winogrande,arc_challenge" \ --device "cuda:0" \ --batch_size 4 -``` \ No newline at end of file +``` diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index a058619046..55c214a01c 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -146,8 +146,25 @@ togethercomputer/RedPajama-INCITE-Chat-7B-v0.1 togethercomputer/RedPajama-INCITE-Instruct-3B-v1 togethercomputer/RedPajama-INCITE-Instruct-7B-v0.1 Trelis/Llama-2-7b-chat-hf-function-calling-v2 +unsloth/Mistral-7B-v0.2 ``` +  + +> [!TIP] +> To sort the list above by model name after the `/`, use `litgpt download | sort -f -t'/' -k2`. + +  + +> [!NOTE] +> If you want to adopt a model variant that is not listed in the table above but has a similar architecture as one of the supported models, you can use this model by by using the `--model_name` argument as shown below: +> ```bash +> litgpt download \ +> --repo_id NousResearch/Hermes-2-Pro-Mistral-7B \ +> --model_name Mistral-7B-v0.1 +> ``` + +   ### 2. Download Model Weights @@ -209,7 +226,7 @@ litgpt chat --checkpoint_dir checkpoints/$repo_id   ## Specific Models -Note that certain models require that you've been granted access to the weights on the Hugging Face Hub. +Note that certain models require that you've been granted access to the weights on the Hugging Face Hub. For example, to get access to the Gemma 2B model, you can do so by following the steps at https://huggingface.co/google/gemma-2b. After access is granted, you can find your HF hub token in https://huggingface.co/settings/tokens. @@ -249,7 +266,7 @@ litgpt download \   ## Converting Checkpoints Manually -For development purposes, for example, when adding or experimenting with new model configurations, it may be beneficial to split the weight download and model conversion into two separate steps. +For development purposes, for example, when adding or experimenting with new model configurations, it may be beneficial to split the weight download and model conversion into two separate steps. You can do this by passing the `--convert_checkpoint false` option to the download script: diff --git a/tutorials/images/0_to_litgpt/4-commands.webp b/tutorials/images/0_to_litgpt/4-commands.webp new file mode 100644 index 0000000000..aac24a13b3 Binary files /dev/null and b/tutorials/images/0_to_litgpt/4-commands.webp differ diff --git a/tutorials/images/0_to_litgpt/finetune.webp b/tutorials/images/0_to_litgpt/finetune.webp new file mode 100644 index 0000000000..b61b3fa698 Binary files /dev/null and b/tutorials/images/0_to_litgpt/finetune.webp differ diff --git a/tutorials/images/0_to_litgpt/instruction-1.webp b/tutorials/images/0_to_litgpt/instruction-1.webp new file mode 100644 index 0000000000..bf21f6adf7 Binary files /dev/null and b/tutorials/images/0_to_litgpt/instruction-1.webp differ diff --git a/tutorials/images/0_to_litgpt/instruction-2.webp b/tutorials/images/0_to_litgpt/instruction-2.webp new file mode 100644 index 0000000000..64f9f6bf21 Binary files /dev/null and b/tutorials/images/0_to_litgpt/instruction-2.webp differ diff --git a/tutorials/images/0_to_litgpt/pretrain.webp b/tutorials/images/0_to_litgpt/pretrain.webp new file mode 100644 index 0000000000..838a077f1c Binary files /dev/null and b/tutorials/images/0_to_litgpt/pretrain.webp differ diff --git a/tutorials/images/0_to_litgpt/usage.webp b/tutorials/images/0_to_litgpt/usage.webp new file mode 100644 index 0000000000..5b555cee43 Binary files /dev/null and b/tutorials/images/0_to_litgpt/usage.webp differ diff --git a/tutorials/inference.md b/tutorials/inference.md index 81cefe6816..4675624149 100644 --- a/tutorials/inference.md +++ b/tutorials/inference.md @@ -1,6 +1,6 @@ # Inference -We demonstrate how to run inference (next token prediction) with the GPT base model in the [`generate.py`](generate.py) script: +We demonstrate how to run inference (next token prediction) with the GPT base model in the [`generate.py`](../litgpt/generate/base.py) script: ```bash litgpt generate base --prompt "Hello, my name is" --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b diff --git a/tutorials/oom.md b/tutorials/oom.md index c12573da10..c02ee5b2fd 100644 --- a/tutorials/oom.md +++ b/tutorials/oom.md @@ -34,7 +34,7 @@ However, your hardware may not support such large context lengths. Here's what y * For the finetuning scripts, you can trim the length of the samples in your dataset. All the finetuning scripts expose a `--data.max_seq_length=...` argument. This might also be useful in cases where sample lengths are highly unbalanced, as the presence of a single very long sample would incur a larger memory usage for all other - shorter samples. For example, the median length of the samples in Alpaca is 110 tokens. Truncating the Alpaca dataset to 256 max tokens reduces the memory requirements of a Falcon 7B model from 23.52 GB to 15.73 GB. For more information about the dataset truncation, please see the *Truncating datasets* section in the [prepare_datasets.md](prepare_datasets.md) tutorial. + shorter samples. For example, the median length of the samples in Alpaca is 110 tokens. Truncating the Alpaca dataset to 256 max tokens reduces the memory requirements of a Falcon 7B model from 23.52 GB to 15.73 GB. For more information about the dataset truncation, please see the *Truncating datasets* section in the [prepare_dataset.md](prepare_dataset.md) tutorial. Keep in mind that reducing the context length will affect the modelling performance on text sequences longer than the limit. diff --git a/tutorials/prepare_dataset.md b/tutorials/prepare_dataset.md index 51df8b020c..7f7cf238ae 100644 --- a/tutorials/prepare_dataset.md +++ b/tutorials/prepare_dataset.md @@ -79,8 +79,7 @@ For comparison, the Falcon 7B model requires 23.52 GB of memory for the original ### Alpaca-GPT4 - -The Alpaca-GPT4 was built by using the prompts of the original Alpaca dataset and generate the responses via GPT 4. The +The Alpaca-GPT4 was built by using the prompts of the original Alpaca dataset and generate the responses via GPT 4. The dataset consists of 52,000 instructions and responses. The original [Alpaca-GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) dataset can be used as follows: @@ -126,12 +125,11 @@ litgpt finetune lora \ --train.max_seq_length 256 ``` -   ### Deita -The Deita dataset (short for Data-Efficient Instruction Tuning for Alignment) is a collection of 9500 prompts and responses, as described in the [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) paper. +The Deita dataset (short for Data-Efficient Instruction Tuning for Alignment) is a collection of 9500 prompts and responses, as described in the [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) paper. Using Falcon 7b as an example, we can use the dataset as follows: ```bash @@ -162,7 +160,6 @@ litgpt finetune lora \ --train.max_seq_length 512 ``` -   ### Dolly @@ -281,7 +278,6 @@ litgpt finetune lora \ However, you can also select individual subsets via comma-separated strings as follows: - ```bash litgpt finetune lora \ --data FLAN \ @@ -385,5 +381,4 @@ Note that you only need to modify a small fraction of the code file, namely the In addition to the finetuning dataset described above, LitGPT also supports several datasets for pretraining. The pretraining datasets are described in more detail in the following separate tutorial documents: -- [Pretrain Llama 2 on OpenWebText](./pretrain_openwebtext.md) - [Pretrain TinyLlama on Slimpajama and Starcoder](./pretrain_tinyllama.md) diff --git a/tutorials/pretrain.md b/tutorials/pretrain.md new file mode 100644 index 0000000000..4a8db678e1 --- /dev/null +++ b/tutorials/pretrain.md @@ -0,0 +1,65 @@ +# Pretrain LLMs with LitGPT + + +This document explains how to pretrain LLMs using LitGPT. + +  +## The Pretraining API + +You can pretrain models in LitGPT using the `litgpt pretrain` API starting with any of the available architectures listed by calling `litgpt pretrain` without any additional arguments: + +```bash +litgpt pretrain +``` + +Shown below is an abbreviated list: + +``` +ValueError: Please specify --model_name . Available values: +Camel-Platypus2-13B +... +Gemma-2b +... +Llama-2-7b-hf +... +Mixtral-8x7B-v0.1 +... +pythia-14m +``` + +For demonstration purposes, we can pretrain a small 14 million-parameter Pythia model on the small TinyStories dataset using the [debug.yaml config file](https://github.com/Lightning-AI/litgpt/blob/main/config_hub/pretrain/debug.yaml) as follows: + +```bash +litgpt pretrain \ + --model_name pythia-14m \ + --config https://raw.githubusercontent.com/Lightning-AI/litgpt/main/config_hub/pretrain/debug.yaml +``` + + + + +  +## Pretrain a 1.1B TinyLlama model + +You can find an end-to-end LitGPT tutorial for pretraining a TinyLlama model using LitGPT [here](pretrain_tinyllama.md). + + +  +## Optimize LitGPT pretraining with Lightning Thunder + +[Lightning Thunder](https://github.com/Lightning-AI/lightning-thunder) is a source-to-source compiler for PyTorch, which is fully compatible with LitGPT. In experiments, Thunder resulted in a 40% speed-up compared to using regular PyTorch when finetuning a 7B Llama 2 model. + +For more information, see the [Lightning Thunder extension README](https://github.com/Lightning-AI/lightning-thunder). + + +  +## Project templates + +The following [Lightning Studio](https://lightning.ai/lightning-ai/studios) templates provide LitGPT pretraining projects in reproducible environments with multi-GPU and multi-node support: +  + +| | | +|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +|

[Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)
[

](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) | [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)

[

](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) | +| [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)

[

](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | | +| | \ No newline at end of file diff --git a/tutorials/pretrain_tinyllama.md b/tutorials/pretrain_tinyllama.md index 245ec48ab7..f4976ee097 100644 --- a/tutorials/pretrain_tinyllama.md +++ b/tutorials/pretrain_tinyllama.md @@ -5,6 +5,7 @@ This tutorial will walk you through pretraining [TinyLlama](https://github.com/j > [!TIP] > To get started with zero setup, clone the [TinyLlama studio on Lightning AI](https://lightning.ai/lightning-ai/studios/llm-pretrain-tinyllama-1-1b). +  ## What's TinyLlama? [TinyLlama](https://github.com/jzhang38/TinyLlama/) is architecturally the same as Meta AI's LLama 2, but only has 1.1B parameters and is instead trained on multiple epochs on a mix of [SlimPajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) and [Starcoder](https://huggingface.co/datasets/bigcode/starcoderdata) datasets. @@ -26,6 +27,7 @@ Here is a quick fact sheet: (this table was sourced from the author's [README](https://github.com/jzhang38/TinyLlama/)) +  ## Download datasets You can download the data using git lfs: @@ -42,6 +44,7 @@ git clone https://huggingface.co/datasets/bigcode/starcoderdata data/starcoderda Around 1.2 TB of disk space is required to store both datasets. +  ## Prepare the datasets for training In order to start pretraining litgpt on it, you need to read, tokenize, and write the data in binary chunks. This will leverage the `litdata` optimization pipeline and streaming dataset. @@ -95,6 +98,7 @@ python litgpt/data/prepare_slimpajama.py \ If you want to run on a small slice of the datasets first, pass the flag `--fast_dev_run=true` to the commands above. In the above we are assuming that you will be using the same tokenizer as used in LlaMA/TinyLlama, but any trained [SentencePiece](https://github.com/google/sentencepiece) tokenizer with a 32000 vocabulary size will do here. +  ## Pretraining Running the pretraining script with its default settings requires at least 8 A100 GPUs. @@ -139,6 +143,7 @@ Last, logging is kept minimal in the script, but for long-running experiments we As an example, we included WandB (set `--logger_name=wandb`) to show how you can integrate any experiment tracking framework. For reference, [here are the loss curves for our reproduction](https://api.wandb.ai/links/awaelchli/y7pzdpwy). +  ## Resume training The checkpoints saved during pretraining contain all the information to resume if needed. @@ -151,6 +156,7 @@ litgpt pretrain \ ``` **Important:** Each checkpoint is a directory. Point to the directory, not the 'lit_model.pth' file inside of it. +  ## Export checkpoints After training is completed, you can convert the checkpoint to a format that can be loaded for evaluation, inference, finetuning etc. @@ -172,3 +178,16 @@ checkpoints/tiny-llama/final ``` You can then use this checkpoint folder to run [evaluation](evaluation.md), [inference](inference.md), [finetuning](finetune_lora.md) or [process the checkpoint further](convert_lit_models.md). + + +  +## Project templates + +The following [Lightning Studio](https://lightning.ai/lightning-ai/studios) templates provide LitGPT pretraining projects in reproducible environments with multi-GPU and multi-node support: +  + +| | | +|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +|

[Prepare the TinyLlama 1T token dataset](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset)
[

](https://lightning.ai/lightning-ai/studios/prepare-the-tinyllama-1t-token-dataset) | [Pretrain LLMs - TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b)

[

](https://lightning.ai/lightning-ai/studios/pretrain-llms-tinyllama-1-1b) | +| [Continued Pretraining with TinyLlama 1.1B](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b)

[

](https://lightning.ai/lightning-ai/studios/continued-pretraining-with-tinyllama-1-1b) | | +| | \ No newline at end of file