Commit

Merge branch 'mitchish65-2' of https://github.com/allenai/LLM into mitchish65-2
dirkgr committed Jan 25, 2024
2 parents 98e93e4 + 1690d8b commit 0348696
Showing 13 changed files with 374 additions and 95 deletions.
36 changes: 36 additions & 0 deletions README.md
@@ -7,6 +7,42 @@

## Installation

First install [PyTorch](https://pytorch.org) according to the instructions specific to your operating system.

To install from source (recommended for training/fine-tuning) run:

```bash
git clone https://github.com/allenai/OLMo.git
cd OLMo
pip install -e .
```

Otherwise you can install the model code by itself directly from PyPI with:

```bash
pip install ai2-olmo
```

## Fine-tuning

To fine-tune an OLMo model using our trainer you'll first need to prepare your dataset by tokenizing it and saving the token IDs to a flat numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.
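
A minimal sketch of that preparation step is below. It assumes a Hugging Face tokenizer matching the checkpoint's vocabulary; the tokenizer name, example texts, and output file names are placeholders, and the real logic lives in `scripts/prepare_tulu_data.py`.

```python
# Hedged sketch: concatenate token IDs (and an optional label mask) into flat,
# memory-mapped arrays. Tokenizer name, texts, and file names are placeholders.
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # assumed tokenizer
examples = ["First training document.", "Second training document."]  # your data here

token_ids: list[int] = []
label_mask: list[bool] = []
for text in examples:
    ids = tokenizer(text)["input_ids"] + [tokenizer.eos_token_id]
    token_ids.extend(ids)
    # Mark every token as contributing to the loss; set prompt positions to False
    # if you want instruction-tuning-style masking.
    label_mask.extend([True] * len(ids))

# Write flat memory-mapped arrays. uint16 assumes the vocabulary fits in 16 bits.
ids_mm = np.memmap("input_ids.npy", dtype=np.uint16, mode="w+", shape=(len(token_ids),))
ids_mm[:] = np.asarray(token_ids, dtype=np.uint16)
ids_mm.flush()

mask_mm = np.memmap("label_mask.npy", dtype=np.bool_, mode="w+", shape=(len(label_mask),))
mask_mm[:] = np.asarray(label_mask, dtype=np.bool_)
mask_mm.flush()
```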

Next, prepare your training config. There are many examples in the [`configs/`](./configs) directory that you can use as a starting point. The most important thing is to make sure the model parameters (the `model` field in the config) match up with the checkpoint you're starting from. To be safe you can always start from the config that comes with the model checkpoint. At a minimum you'll need to make the following changes to the config or provide the corresponding overrides from the command line (a sketch of baking these changes into a config file follows the list):

- Update `load_path` to point to the checkpoint you want to start from.
- Set `reset_trainer_state` to `true`.
- Update `data.paths` to point to the `token_ids.npy` file you generated.
- If you need special masking for the loss, update `data.label_mask_paths` to point to the `label_mask.npy` file you generated; otherwise leave it unset.
- Update `evaluators` to add/remove in-loop evaluations.
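
If you'd rather bake these changes into a new config file than pass them on the command line, one way to do it is sketched below with `omegaconf`. This is not the project's own config tooling, and the file names and paths are placeholders.

```python
# Hedged sketch: merge the overrides listed above into a copy of a base config.
from omegaconf import OmegaConf

base = OmegaConf.load("configs/mitchish-instruct.yaml")  # config shipped with the checkpoint
overrides = OmegaConf.from_dotlist([
    "load_path=/path/to/checkpoint-unsharded",
    "reset_trainer_state=true",
    "data.paths=[/path/to/input_ids.npy]",
    "data.label_mask_paths=[/path/to/label_mask.npy]",
])
OmegaConf.save(OmegaConf.merge(base, overrides), "configs/my-finetune.yaml")
```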

Once you're satisfied with your training config, you can launch the training job via `torchrun`. For example:

```bash
torchrun --nproc_per_node=8 scripts/train.py {path_to_train_config} \
--data.paths=[{path_to_data}/input_ids.npy] \
--data.label_mask_paths=[{path_to_data}/label_mask.npy] \
--load_path={path_to_checkpoint} \
--reset_trainer_state
```

Note: passing CLI overrides like `--reset_trainer_state` is only necessary if you didn't update those fields in your config.
47 changes: 24 additions & 23 deletions configs/mcli/mitchish-instruct.yml
@@ -1,17 +1,21 @@
run_name: olmo-7b-instruct
name: olmo-7b-instruct
image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
gpu_num: 64
#gpu_num: 8
#cluster: r12z3
cluster: r7z2
gpu_type: a100_40gb
compute:
#cluster: r12z3
cluster: r7z2
gpus: 64
gpu_type: a100_40gb
integrations:
- integration_type: git_repo
git_repo: allenai/LLM
git_branch: epwalsh/tulu-fine-tune
git_branch: main
pip_install: -e .
ssh_clone: true
command: |-
checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
learning_rate=2e-6
run_name=mitchish-mcli-2.5T-instruct-${learning_rate}-5ep-v2
# NOTE: For some reason getting S3 and R2 authentication working both from the command line and
# from Python proved to be challenging, maybe because Mosaic's servers are in Australia.
# In the end I had to use separate methods to get everything working:
@@ -34,7 +38,6 @@ command: |-
# Prepare environment including AWS config files for both S3 and R2 access.
mkdir -p /root/.cache/torch
mkdir /root/checkpoint-unsharded
mkdir /root/data
mkdir /root/.aws
touch /root/.aws/credentials /root/.aws/config
echo '[s3]' >> /root/.aws/credentials
@@ -54,7 +57,6 @@ command: |-
export LOG_FILTER_TYPE=local_rank0_only
# Download checkpoint (everything except optimizer state).
checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
echo "Downloading checkpoint '${checkpoint}'..."
# Download config.
@@ -72,31 +74,30 @@ command: |-
--endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
"${checkpoint}/model.pt" /root/checkpoint-unsharded/
# Download optimizer state.
#aws s3 cp --profile=r2 --region=auto \
# --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
# "${checkpoint}/optim.pt" /root/checkpoint-unsharded/
# Now remove the AWS configs so they don't mess with data loading / uploading checkpoints to/from S3.
rm -rf /root/.aws
# Download data (it's small enough so might as well).
echo "Downloading data..."
aws s3 cp \
s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy \
/root/data/data.npy
torchrun \
--master_addr "$MASTER_ADDR" \
--master_port "$MASTER_PORT" \
--nnodes "$NUM_NODES" \
--node_rank "$NODE_RANK" \
--nproc_per_node 8 \
scripts/train.py configs/mitchish-instruct.yaml \
--run_name=mitchish-mcli-2.5T-instruct-2e-6 \
--optimizer.learning_rate=2e-6 \
--run_name=${run_name} \
--optimizer.learning_rate=${learning_rate} \
--scheduler.grad_clip_warmup_steps=400 \
--save_overwrite \
--time_limit=169200 \
--data.paths=[/root/data/data.npy] \
--save_interval_unsharded=10000 \
--save_interval_unsharded=100000 \
--load_path=/root/checkpoint-unsharded \
--reset_optimizer_state \
--reset_trainer_state \
--reset_optimizer_state \
--compile=null \
--activation_checkpointing=fine_grained \
--fsdp.wrapping_strategy=size_based
--activation_checkpointing=whole_layer \
--fsdp.wrapping_strategy=size_based \
--max_duration=5ep
46 changes: 6 additions & 40 deletions configs/mitchish-instruct.yaml
@@ -43,15 +43,15 @@ compile:
optimizer:
name: adamw
learning_rate: 2e-5
weight_decay: 0.0
weight_decay: 0.1
betas:
- 0.9
- 0.999
- 0.95
metrics_log_interval: 10

scheduler:
name: linear_with_warmup
t_warmup: 100
t_warmup: 200
alpha_f: 0.001

tokenizer:
@@ -91,42 +91,6 @@ eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
- label: all-small-ppl-validation
data:
num_workers: 0
drop_last: true
# pin_memory: true
# prefetch_factor: 1
# persistent_workers: false
# timeout: 0
datasets:
4chan-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy
c4_100_domains-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
c4_en-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy
gab-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy
ice-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy
m2d2_s2orc-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
m2d2_wiki-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
manosphere-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy
mc4_en-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy
pile-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy
ptb-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy
twitterAEE-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
wikitext_103-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy

##########################
# Downstream evaluations #
##########################
@@ -179,4 +143,6 @@ data:
timeout: 0
generate_attention_mask: true
paths:
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/input_ids.npy
label_mask_paths:
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/label_mask.npy
1 change: 1 addition & 0 deletions olmo/config.py
@@ -515,6 +515,7 @@ class PaddingDirection(StrEnum):
class DataConfig(BaseConfig):
paths: Optional[List[str]] = None
datasets: Optional[Dict[str, List[str]]] = None
label_mask_paths: Optional[List[str]] = None
pad_direction: PaddingDirection = PaddingDirection.right
generate_attention_mask: bool = False
num_workers: int = 0
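
For illustration, a hedged sketch of how the new field might be populated. The import path matches the file shown above, but direct keyword construction and the example paths are assumptions.

```python
# Hedged sketch: point the data config at the token IDs and the new label mask.
from olmo.config import DataConfig

data_cfg = DataConfig(
    paths=["/data/tulu/input_ids.npy"],
    label_mask_paths=["/data/tulu/label_mask.npy"],  # new field added in this commit
    generate_attention_mask=True,
)
```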
4 changes: 3 additions & 1 deletion olmo/data/__init__.py
@@ -1,8 +1,9 @@
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, cast

from torch.utils.data import DataLoader, DistributedSampler

from ..aliases import PathOrStr
from ..config import DataConfig, TrainConfig
from ..exceptions import OlmoConfigurationError
from ..torch_util import barrier, get_global_rank, get_world_size
@@ -39,6 +40,7 @@ def build_memmap_dataset(
include_instance_metadata=include_instance_metadata,
pad_token_id=train_config.model.pad_token_id,
generate_attention_mask=data_config.generate_attention_mask,
label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths),
)


17 changes: 17 additions & 0 deletions olmo/data/collator.py
@@ -26,6 +26,7 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
all_input_ids = []
all_attention_mask = []
all_attention_bias = []
all_label_mask = []
all_indices = []
all_metadata = []
for x in items:
@@ -78,6 +79,19 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
)
)

# Pad label mask.
label_mask = x.get("label_mask") if isinstance(x, dict) else None
if label_mask is not None:
if not isinstance(label_mask, torch.Tensor):
label_mask = torch.tensor(label_mask)
all_label_mask.append(
F.pad(
label_mask.to(dtype=torch.bool),
pad_shape,
value=False,
)
)

# Indices.
index = x.get("index") if isinstance(x, dict) else None
if index is not None:
@@ -93,8 +107,11 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
out["attention_mask"] = torch.stack(all_attention_mask)
if all_attention_bias:
out["attention_bias"] = torch.stack(all_attention_bias)
if all_label_mask:
out["label_mask"] = torch.stack(all_label_mask)
if all_indices:
out["index"] = torch.stack(all_indices)
if all_metadata:
out["metadata"] = all_metadata

return out
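
For context on how a collated `label_mask` is typically consumed, here is a hedged sketch of a loss computation that ignores masked positions. It is illustrative only and not the project's actual training loop.

```python
# Hedged sketch: next-token cross-entropy that skips positions where label_mask
# is False (e.g. prompt tokens during instruction tuning).
import torch
import torch.nn.functional as F

def masked_next_token_loss(
    logits: torch.Tensor,      # (batch, seq_len, vocab_size)
    input_ids: torch.Tensor,   # (batch, seq_len)
    label_mask: torch.Tensor,  # (batch, seq_len), bool
) -> torch.Tensor:
    # Tokens at position i predict the token at position i + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:].clone()
    shift_mask = label_mask[:, 1:]
    # -100 is F.cross_entropy's default ignore_index, so masked positions drop out.
    shift_labels[~shift_mask] = -100
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
```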