Commit efe78a4: update defaults and docs

rasbt committed Mar 29, 2024
1 parent 2d5ec4b
Showing 9 changed files with 246 additions and 79 deletions.
2 changes: 1 addition & 1 deletion config_hub/finetune/README.md
@@ -27,7 +27,7 @@ For more information, see the [Dealing with out-of-memory (OOM) errors](../../tu
| | | | | | | | | | |
| phi-2/lora.yaml | 2B | Alpaca 2k | 1 | 0.832 | 13.98 GB | 512 | 4 | bfloat16 | 3.82 min (1xA10G) |
| phi-2/qlora.yaml | 2B | Alpaca 2k | 1 | 0.846 | 14.27 GB | 512 | 4 | bfloat16 | 4.55 min (1xA10G) |
- | phi-2/full.yaml | 2B | Alpaca 2k | 1 | 0.937 | 14.44 GB | 512 | 4 | bfloat16 | 13.00 min (1xA10G) |
+ | phi-2/full.yaml | 2B | Alpaca 2k | 1 | 0.937 | 14.44 GB | 512 | 4 | bfloat16 | 13.00 min (2xA10G) |
| | | | | | | | | | |
| stablelm-base-alpha-3b/lora.yaml | 7B | Alpaca 2k | 4 | 1.367 | 8.58 GB | 512 | 2 | bfloat16 | 13.02 min (1xA10G) |
| stablelm-base-alpha-3b/qlora.yaml | 7B | Alpaca 2k | 4 | 1.392 | 5.24 GB | 512 | 2 | bfloat16 | 25.71 min (1xA10G) |
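The rows above map to the YAML recipes shipped under `config_hub/finetune/`. A quick way to see exactly which hyperparameters a row corresponds to is to load the file and print it; a minimal sketch, assuming PyYAML is installed:

```python
# Sketch: inspect the phi-2 LoRA recipe referenced in the table above.
# Assumes PyYAML is available; the key layout may differ between litgpt versions.
from pprint import pprint

import yaml

with open("config_hub/finetune/phi-2/lora.yaml") as f:
    pprint(yaml.safe_load(f))
```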
16 changes: 8 additions & 8 deletions litgpt/finetune/adapter.py
@@ -17,7 +17,7 @@

from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
from litgpt.args import EvalArgs, TrainArgs
- from litgpt.data import Alpaca, DataModule
+ from litgpt.data import Alpaca2k, DataModule
from litgpt.generate.base import generate
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
@@ -46,14 +46,14 @@ def setup(
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
- global_batch_size=16,
+ global_batch_size=8,
micro_batch_size=1,
- lr_warmup_steps=100,
- epochs=5,
- learning_rate=1e-3,
+ lr_warmup_steps=10,
+ epochs=1,
+ learning_rate=0.002,
max_seq_length=None,
),
- eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
+ eval: EvalArgs = EvalArgs(interval=50, max_new_tokens=100, max_iters=100),
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
) -> None:
@@ -65,15 +65,15 @@ def setup(
precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
devices: How many devices/GPUs to use.
- data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
+ data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca2k``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""

pprint(locals())
- data = Alpaca() if data is None else data
+ data = Alpaca2k() if data is None else data
devices = parse_devices(devices)

check_valid_checkpoint_dir(checkpoint_dir)
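With the new adapter defaults, the effective batch size is reached through gradient accumulation rather than larger micro-batches. A rough sketch of the arithmetic, assuming the usual convention of accumulating `global_batch_size / (devices * micro_batch_size)` micro-batches per optimizer step (not copied from litgpt internals):

```python
# Sketch of what the new adapter defaults imply for optimizer steps.
global_batch_size = 8
micro_batch_size = 1
devices = 1

grad_accum_steps = global_batch_size // (devices * micro_batch_size)
print(grad_accum_steps)        # 8 forward/backward passes per optimizer step

# lr_warmup_steps=10 therefore only covers the first 10 optimizer steps,
# i.e. roughly the first 80 training examples at these settings.
print(10 * global_batch_size)  # 80
```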
16 changes: 8 additions & 8 deletions litgpt/finetune/adapter_v2.py
@@ -17,7 +17,7 @@

from litgpt.adapter_v2 import GPT, Block, Config, adapter_filter, mark_only_adapter_v2_as_trainable
from litgpt.args import EvalArgs, TrainArgs
- from litgpt.data import Alpaca, DataModule
+ from litgpt.data import Alpaca2k, DataModule
from litgpt.generate.base import generate
from litgpt.prompts import save_prompt_style
from litgpt.tokenizer import Tokenizer
@@ -46,14 +46,14 @@ def setup(
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
- global_batch_size=16,
+ global_batch_size=8,
micro_batch_size=1,
- lr_warmup_steps=100,
- epochs=5,
- learning_rate=1e-3,
+ lr_warmup_steps=10,
+ epochs=1,
+ learning_rate=0.002,
max_seq_length=None,
),
- eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
+ eval: EvalArgs = EvalArgs(interval=50, max_new_tokens=100, max_iters=100),
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
) -> None:
@@ -65,15 +65,15 @@ def setup(
precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
devices: How many devices/GPUs to use.
- data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
+ data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca2k``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""

pprint(locals())
- data = Alpaca() if data is None else data
+ data = Alpaca2k() if data is None else data
devices = parse_devices(devices)

check_valid_checkpoint_dir(checkpoint_dir)
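Both `adapter` and `adapter_v2` expose these defaults through `setup()`, so the previous behavior can still be requested explicitly. A minimal sketch; the checkpoint path is a placeholder, and `setup()` may require further arguments (such as `out_dir`) depending on the litgpt version:

```python
# Sketch: call the adapter_v2 entry point with the pre-commit defaults
# (global batch 16, 100 warmup steps, 5 epochs, lr 1e-3) instead of the new ones.
from pathlib import Path

from litgpt.args import EvalArgs, TrainArgs
from litgpt.finetune.adapter_v2 import setup

setup(
    checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),  # placeholder path
    train=TrainArgs(
        save_interval=1000,
        log_interval=1,
        global_batch_size=16,
        micro_batch_size=1,
        lr_warmup_steps=100,
        epochs=5,
        learning_rate=1e-3,
        max_seq_length=None,
    ),
    eval=EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
)
```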
14 changes: 7 additions & 7 deletions litgpt/finetune/full.py
@@ -14,7 +14,7 @@
from torchmetrics import RunningMean

from litgpt.args import EvalArgs, TrainArgs
- from litgpt.data import Alpaca, DataModule
+ from litgpt.data import Alpaca2k, DataModule
from litgpt.generate.base import generate
from litgpt.model import GPT, Block, Config
from litgpt.prompts import save_prompt_style
@@ -46,12 +46,12 @@ def setup(
log_interval=1,
global_batch_size=16,
micro_batch_size=1,
- lr_warmup_steps=100,
- epochs=5,
- learning_rate=3e-3,
+ lr_warmup_steps=1000,
+ epochs=1,
+ learning_rate=0.0002,
max_seq_length=None,
),
- eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
+ eval: EvalArgs = EvalArgs(interval=50, max_new_tokens=100, max_iters=100),
logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
seed: int = 1337,
) -> None:
@@ -64,15 +64,15 @@ def setup(
devices: How many devices/GPUs to use
resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
from the latest checkpoint in ``out_dir``.
- data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
+ data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca2k``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""

pprint(locals())
- data = Alpaca() if data is None else data
+ data = Alpaca2k() if data is None else data
devices = parse_devices(devices)

check_valid_checkpoint_dir(checkpoint_dir)
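Full finetuning now defaults to a much longer warmup (1,000 steps) with a lower peak learning rate (2e-4). A small sketch of what linear warmup looks like under these settings; this assumes litgpt warms the learning rate up linearly and ignores whatever decay schedule follows:

```python
# Sketch: linear warmup to the new full-finetuning peak learning rate.
learning_rate = 2e-4
lr_warmup_steps = 1000

def warmup_lr(step: int) -> float:
    return learning_rate * min(1.0, step / lr_warmup_steps)

for step in (1, 100, 500, 1000):
    print(step, f"{warmup_lr(step):.2e}")
# 1 -> 2.00e-07, 100 -> 2.00e-05, 500 -> 1.00e-04, 1000 -> 2.00e-04
```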
14 changes: 7 additions & 7 deletions litgpt/finetune/lora.py
@@ -16,7 +16,7 @@
from torchmetrics import RunningMean

from litgpt.args import EvalArgs, TrainArgs
- from litgpt.data import Alpaca, DataModule
+ from litgpt.data import Alpaca2k, DataModule
from litgpt.generate.base import generate
from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
from litgpt.prompts import save_prompt_style
@@ -43,7 +43,7 @@ def setup(
precision: Optional[str] = None,
quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
devices: Union[int, str] = 1,
- lora_r: int = 8,
+ lora_r: int = 32,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_query: bool = True,
@@ -56,11 +56,11 @@
train: TrainArgs = TrainArgs(
save_interval=1000,
log_interval=1,
- global_batch_size=16,
+ global_batch_size=8,
micro_batch_size=1,
lr_warmup_steps=100,
- epochs=5,
- learning_rate=3e-4,
+ epochs=4,
+ learning_rate=0.0002,
max_seq_length=None,
),
eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
@@ -84,15 +84,15 @@ def setup(
lora_projection: Whether to apply LoRA to the output projection in the attention block.
lora_mlp: Whether to apply LoRA to the weights of the MLP in the attention block.
lora_head: Whether to apply LoRA to output head in GPT.
- data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
+ data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca2k``.
train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
logger_name: The name of the logger to send metrics to.
seed: The random seed to use for reproducibility.
"""

pprint(locals())
- data = Alpaca() if data is None else data
+ data = Alpaca2k() if data is None else data
devices = parse_devices(devices)

check_valid_checkpoint_dir(checkpoint_dir)
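Raising the default LoRA rank from 8 to 32 while keeping `lora_alpha=16` changes both the adapter size and the effective scaling: in the standard LoRA formulation the low-rank update is scaled by `alpha / r`. A small sketch of the consequences, using standard LoRA math rather than anything read out of litgpt's code:

```python
# Sketch: effect of the new default rank on a single weight matrix.
# Parameter count r * (d_in + d_out) and scaling alpha / r are the standard LoRA formulas.
d_in = d_out = 4096  # example hidden size, purely illustrative

for r, alpha in ((8, 16), (32, 16)):
    extra_params = r * (d_in + d_out)
    scaling = alpha / r
    print(f"r={r:<2} alpha={alpha} -> {extra_params:,} extra params per matrix, scaling={scaling}")

# r=8  alpha=16 -> 65,536 extra params per matrix, scaling=2.0
# r=32 alpha=16 -> 262,144 extra params per matrix, scaling=0.5
```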
2 changes: 1 addition & 1 deletion litgpt/utils.py
@@ -305,7 +305,7 @@ def get_default_supported_precision(training: bool) -> str:

if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()):
return "16-mixed" if training else "16-true"
return "bf16-mixed" if training else "bf16-true"
return "bf16-true"


def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
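After this change the helper no longer returns a mixed-precision default on bf16-capable hardware. A sketch of the resulting function; only the body is taken from the diff, the import locations are assumptions:

```python
# Sketch of get_default_supported_precision after this commit.
import torch
from lightning.fabric.accelerators import MPSAccelerator  # assumed import path


def get_default_supported_precision(training: bool) -> str:
    """Return the default precision for the available accelerator."""
    if MPSAccelerator.is_available() or (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()):
        return "16-mixed" if training else "16-true"
    return "bf16-true"
```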
90 changes: 71 additions & 19 deletions tutorials/finetune_adapter.md
@@ -2,7 +2,7 @@

Adapter, first introduced for the LLaMA model as [LLaMA-Adapter](https://arxiv.org/abs/2303.16199), is a form of prefix-tuning that prepends a learnable adaption-prompt to the inputs of the attention blocks in an LLM. In total, there are only ~500k parameters to update during finetuning in StableLM 3B, which significantly reduces the memory footprint and speeds up training.
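
For a concrete picture of the idea, here is a deliberately simplified sketch of a gated adaption prompt. It is not litgpt's implementation (the real method concatenates the prompt to the keys and values of selected layers and gates its attention scores), and every name in it is illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedAdaptionPrompt(nn.Module):
    """Simplified illustration of the adaption-prompt idea (not litgpt's code):
    a small learnable prompt acts as extra keys/values for a frozen attention
    layer, and its contribution is blended in through a zero-initialized gate."""

    def __init__(self, prompt_len: int, n_embd: int) -> None:
        super().__init__()
        self.prompt = nn.Parameter(torch.randn(prompt_len, n_embd) * 0.02)
        self.gate = nn.Parameter(torch.zeros(1))  # zero init: no effect at the start of training

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # q, k, v: (batch, seq, n_embd); single-head attention kept deliberately simple
        batch = q.size(0)
        pk = pv = self.prompt.unsqueeze(0).expand(batch, -1, -1)
        base = F.scaled_dot_product_attention(q, k, v)     # ordinary attention path
        adapt = F.scaled_dot_product_attention(q, pk, pv)  # attention over the prompt only
        return base + torch.tanh(self.gate) * adapt        # gated adaption
```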

- We are able to demonstrate instruction-finetuning LitGPT StableLM 3B on the [Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset on a **single RTX 3060 GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.
+ We are able to demonstrate instruction-finetuning LitGPT StableLM 3B on the Alpaca 2k dataset (a subset of [Alpaca](https://github.com/tatsu-lab/stanford_alpaca)) on a **single RTX 3060 GPU**. If using 8 GPUs, finetuning can be completed in under 1 hour.

If you are new to Adapter and are interested to learn more about how it works before proceeding with the finetuning guide below, you might find our article [Understanding Parameter-Efficient Finetuning of Large Language Models: From Prefix Tuning to LLaMA-Adapters](https://lightning.ai/pages/community/article/understanding-llama-adapters/) helpful.

@@ -19,46 +19,92 @@ LitGPT provides common datasets for finetuning, such as Alpaca, LIMA, Dolly, and
You can optionally [prepare your own dataset](#tune-on-your-dataset).
For more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.

- ## Running the finetuning
+ For example,

```bash
- litgpt finetune adapter \
- --data Alpaca \
- --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b
+ litgpt download --repo_id stablelm-base-alpha-3b
```

- or for Adapter V2
LitGPT provides common datasets for finetuning, such as Alpaca 2k, Alpaca, LIMA, Dolly, and more.
You can optionally [prepare your own dataset](#tune-on-your-dataset).
For more information about dataset preparation, also see the [prepare_dataset.md](./prepare_dataset.md) tutorial.

 

+ ## Running the finetuning

To finetune the default `"stablelm-base-alpha-3b"` model on Alpaca2k, run the following command:

```bash
- litgpt finetune adapter_v2 \
- --data Alpaca \
- --checkpoint_dir checkpoints/stabilityai/stablelm-base-alpha-3b
+ litgpt finetune adapter --data Alpaca2k
```

- The finetuning requires at least one GPU with ~12 GB memory.
- You can speed up training by passing the `devices` argument to the script to utilize more GPUs if available.
- Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.
- To fit Adapter V2 to 12GB memory set `--train.micro_batch_size 2`.

- For example, the following settings will let you finetune the model in under 1 hour:
+ Alternatively, you can use Adapter V2 as follows:

```bash
- --devices 4 --train.micro_batch_size 4
+ litgpt finetune adapter_v2
```

The preceding command will start the training and print outputs like the following (on an A10G GPU):

```
{'checkpoint_dir': PosixPath('checkpoints/stabilityai/stablelm-base-alpha-3b'),
'data': Alpaca2k,
'devices': 1,
'eval': EvalArgs(interval=50, max_new_tokens=100, max_iters=100),
'logger_name': 'csv',
'out_dir': PosixPath('out/finetune/adapter-v2'),
'precision': None,
'quantize': None,
'seed': 1337,
'train': TrainArgs(save_interval=1000,
log_interval=1,
global_batch_size=8,
micro_batch_size=1,
lr_warmup_steps=10,
epochs=1,
max_tokens=None,
max_steps=None,
max_seq_length=None,
tie_embeddings=None,
learning_rate=0.002,
weight_decay=0.02,
beta1=0.9,
beta2=0.95,
max_norm=None,
min_lr=6e-05)}
Seed set to 1337
Number of trainable parameters: 2,125,248
Number of non-trainable parameters: 3,637,051,392
The longest sequence length in the train data is 634, the model's maximum sequence length is 634 and context length is 4096
...
Epoch 1 | iter 1 step 0 | loss train: 1.919, val: n/a | iter time: 304.25 ms
Epoch 1 | iter 2 step 0 | loss train: 2.004, val: n/a | iter time: 88.54 ms
...
Epoch 1 | iter 1899 step 237 | loss train: 1.238, val: 1.420 | iter time: 85.90 ms
Epoch 1 | iter 1900 step 237 | loss train: 1.313, val: 1.420 | iter time: 48.38 ms
Epoch 2 | iter 1901 step 237 | loss train: 1.422, val: 1.420 | iter time: 279.63 ms
Training time: 281.17s
Memory used: 9.44 GB
Saving adapter v2 weights to 'out/finetune/adapter-v2/final/lit_model.pth.adapter_v2'
```


+ The finetuning requires at least one GPU with ~10 GB memory.
+ You can speed up training by passing the `devices` argument to the script to utilize more GPUs if available.
+ Depending on the available GPU memory, you can also tune the `micro_batch_size` parameter to utilize the GPU efficiently.

This script will save checkpoints periodically to the `out_dir` directory. If you are finetuning different models or on your own dataset, you can specify an output directory with your preferred name:

```bash
litgpt finetune adapter \
--data Alpaca \
--out_dir out/adapter/my-model-finetuned
```

or for Adapter V2

```bash
litgpt finetune adapter_v2 \
--data Alpaca \
--out_dir out/adapter_v2/my-model-finetuned
```

@@ -67,13 +113,15 @@ For instance, to fine-tune on MPS (the GPU on modern Macs), you can run

```bash
litgpt finetune adapter \
- --data Alpaca \
+ --data Alpaca2k \
--out_dir out/adapter/my-model-finetuned \
--precision 32-true
```

Note that `mps` as the accelerator will be picked up automatically by Fabric when running on a modern Mac.

 

### Quantization

Optionally, finetuning using quantization can be enabled via the `--quantize` flag, for example using the 4-bit NormalFloat data type:
@@ -90,6 +138,8 @@ litgpt finetune adapter_v2 --quantize "bnb.nf4-dq"

For additional benchmarks and resource requirements, please see the [Resource Tables](resource-tables.md).

 

## Test the model

You can test the finetuned model with your own instructions by running:
@@ -116,6 +166,8 @@ A good movie to watch on the weekend would be The Lion King, since it's a classi

If your GPU supports `bfloat16`, the script will automatically use it.

 

## Tune on your dataset

You can easily train on your own instruction dataset saved in JSON format.
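A minimal sketch of such a file, using Alpaca-style field names; see [prepare_dataset.md](./prepare_dataset.md) for the exact schema litgpt expects:

```python
# Sketch: write a tiny Alpaca-style instruction dataset to JSON.
# Field names follow the Alpaca convention and are an assumption here.
import json

samples = [
    {
        "instruction": "Recommend a movie to watch on the weekend.",
        "input": "",
        "output": "The Lion King is a classic that the whole family can enjoy.",
    },
]

with open("my_dataset.json", "w") as f:
    json.dump(samples, f, indent=2)
```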