Skip to content

Commit

Permalink
Merge pull request #410 from allenai/epwalsh/fine-tune-with-label-mas…
Browse files Browse the repository at this point in the history
…king

Fine-tuning with label mask
  • Loading branch information
epwalsh authored Jan 19, 2024
2 parents dcae8e8 + f3298c6 commit f36ac42
Show file tree
Hide file tree
Showing 12 changed files with 351 additions and 82 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,27 @@
```
pip install ai2-olmo
```

## Fine-tuning

To fine-tune an OLMo model you'll first need to prepare your dataset by tokenizing it and saving the tokens IDs to a flat numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.

Next, prepare your training config. There are many examples in the [`configs/`](./configs) directory that you can use as a starting point. The most important thing is to make sure the model parameters (the `model` field in the config) match up with the checkpoint you're starting from. To be safe you can always start from the config that comes with the model checkpoint. At a minimum you'll need to make the following changes to the config or provide the corresponding overrides from the command line:

- Update `load_path` to point to the checkpoint you want to start from.
- Set `reset_trainer_state` to `true`.
- Update `data.paths` to point to the `token_ids.npy` file you generated.
- Optionally update `data.label_mask_paths` to point to the `label_mask.npy` file you generated, unless you don't need special masking for the loss.
- Update `evaluators` to add/remove in-loop evaluations.

Once you're satisfied with your training config, you can launch the training job via `torchrun`. For example:

```
torchrun --nproc_per_node=8 scripts/train.py {path_to_train_config} \
--data.paths=[{path_to_data}/input_ids.npy] \
--data.label_mask_paths=[{path_to_data}/label_mask.npy] \
--load_path={path_to_checkpoint} \
--reset_trainer_state
```

Note: passing CLI overrides like `--reset_trainer_state` is only necessary if you didn't update those fields in your config.
47 changes: 24 additions & 23 deletions configs/mcli/mitchish-instruct.yml
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
run_name: olmo-7b-instruct
name: olmo-7b-instruct
image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
gpu_num: 64
#gpu_num: 8
#cluster: r12z3
cluster: r7z2
gpu_type: a100_40gb
compute:
#cluster: r12z3
cluster: r7z2
gpus: 64
gpu_type: a100_40gb
integrations:
- integration_type: git_repo
git_repo: allenai/LLM
git_branch: epwalsh/tulu-fine-tune
git_branch: main
pip_install: -e .
ssh_clone: true
command: |-
checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
learning_rate=2e-6
run_name=mitchish-mcli-2.5T-instruct-${learning_rate}-5ep-v2
# NOTE: For some reason getting S3 and R2 authentication working both from the command line and
# from Python proved to be challenging, maybe because Mosaic's server are in Australia.
# In the end I had to use separate methods to get everything working:
Expand All @@ -34,7 +38,6 @@ command: |-
# Prepare environment including AWS config files for both S3 and R2 access.
mkdir -p /root/.cache/torch
mkdir /root/checkpoint-unsharded
mkdir /root/data
mkdir /root/.aws
touch /root/.aws/credentials /root/.aws/config
echo '[s3]' >> /root/.aws/credentials
Expand All @@ -54,7 +57,6 @@ command: |-
export LOG_FILTER_TYPE=local_rank0_only
# Download checkpoint (everything except optimizer state).
checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
echo "Downloading checkpoint '${checkpoint}'..."
# Download config.
Expand All @@ -72,31 +74,30 @@ command: |-
--endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
"${checkpoint}/model.pt" /root/checkpoint-unsharded/
# Download optimizer state.
#aws s3 cp --profile=r2 --region=auto \
# --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
# "${checkpoint}/optim.pt" /root/checkpoint-unsharded/
# Now remove the aws configs so it doesn't mess with data loading / uploading checkpoints to/from S3.
rm -rf /root/.aws
# Download data (it's small enough so might as well).
echo "Downloading data..."
aws s3 cp \
s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy \
/root/data/data.npy
torchrun \
--master_addr "$MASTER_ADDR" \
--master_port "$MASTER_PORT" \
--nnodes "$NUM_NODES" \
--node_rank "$NODE_RANK" \
--nproc_per_node 8 \
scripts/train.py configs/mitchish-instruct.yaml \
--run_name=mitchish-mcli-2.5T-instruct-2e-6 \
--optimizer.learning_rate=2e-6 \
--run_name=${run_name} \
--optimizer.learning_rate=${learning_rate} \
--scheduler.grad_clip_warmup_steps=400 \
--save_overwrite \
--time_limit=169200 \
--data.paths=[/root/data/data.npy] \
--save_interval_unsharded=10000 \
--save_interval_unsharded=100000 \
--load_path=/root/checkpoint-unsharded \
--reset_optimizer_state \
--reset_trainer_state \
--reset_optimizer_state \
--compile=null \
--activation_checkpointing=fine_grained \
--fsdp.wrapping_strategy=size_based
--activation_checkpointing=whole_layer \
--fsdp.wrapping_strategy=size_based \
--max_duration=5ep
46 changes: 6 additions & 40 deletions configs/mitchish-instruct.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ compile:
optimizer:
name: adamw
learning_rate: 2e-5
weight_decay: 0.0
weight_decay: 0.1
betas:
- 0.9
- 0.999
- 0.95
metrics_log_interval: 10

scheduler:
name: linear_with_warmup
t_warmup: 100
t_warmup: 200
alpha_f: 0.001

tokenizer:
Expand Down Expand Up @@ -91,42 +91,6 @@ eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
- label: all-small-ppl-validation
data:
num_workers: 0
drop_last: true
# pin_memory: true
# prefetch_factor: 1
# persistent_workers: false
# timeout: 0
datasets:
4chan-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy
c4_100_domains-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
c4_en-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy
gab-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy
ice-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy
m2d2_s2orc-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
m2d2_wiki-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
manosphere-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy
mc4_en-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy
pile-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy
ptb-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy
twitterAEE-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
wikitext_103-validation:
- s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy

##########################
# Downstream evaluations #
##########################
Expand Down Expand Up @@ -179,4 +143,6 @@ data:
timeout: 0
generate_attention_mask: true
paths:
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/input_ids.npy
label_mask_paths:
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/label_mask.npy
1 change: 1 addition & 0 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ class PaddingDirection(StrEnum):
class DataConfig(BaseConfig):
paths: Optional[List[str]] = None
datasets: Optional[Dict[str, List[str]]] = None
label_mask_paths: Optional[List[str]] = None
pad_direction: PaddingDirection = PaddingDirection.right
generate_attention_mask: bool = False
num_workers: int = 0
Expand Down
4 changes: 3 additions & 1 deletion olmo/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, cast

from torch.utils.data import DataLoader, DistributedSampler

from ..aliases import PathOrStr
from ..config import DataConfig, TrainConfig
from ..exceptions import OlmoConfigurationError
from ..torch_util import barrier, get_global_rank, get_world_size
Expand Down Expand Up @@ -39,6 +40,7 @@ def build_memmap_dataset(
include_instance_metadata=include_instance_metadata,
pad_token_id=train_config.model.pad_token_id,
generate_attention_mask=data_config.generate_attention_mask,
label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths),
)


Expand Down
17 changes: 17 additions & 0 deletions olmo/data/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
all_input_ids = []
all_attention_mask = []
all_attention_bias = []
all_label_mask = []
all_indices = []
all_metadata = []
for x in items:
Expand Down Expand Up @@ -78,6 +79,19 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
)
)

# Pad label mask.
label_mask = x.get("label_mask") if isinstance(x, dict) else None
if label_mask is not None:
if not isinstance(label_mask, torch.Tensor):
label_mask = torch.tensor(label_mask)
all_label_mask.append(
F.pad(
label_mask.to(dtype=torch.bool),
pad_shape,
value=False,
)
)

# Indices.
index = x.get("index") if isinstance(x, dict) else None
if index is not None:
Expand All @@ -93,8 +107,11 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
out["attention_mask"] = torch.stack(all_attention_mask)
if all_attention_bias:
out["attention_bias"] = torch.stack(all_attention_bias)
if all_label_mask:
out["label_mask"] = torch.stack(all_label_mask)
if all_indices:
out["index"] = torch.stack(all_indices)
if all_metadata:
out["metadata"] = all_metadata

return out
74 changes: 58 additions & 16 deletions olmo/data/memmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class MemMapDataset(Dataset[Dict[str, Any]]):
with the same number of items as there are paths.
:param include_instance_metadata: If ``True`` (the default), each instance returned from `__getitem__` will
include the metadata from its source.
:param generate_attention_mask: If ``True``, each instance returned from ``__getitem__`` will include an
attention mask generated by masking each padding token.
:param pad_token_id: The ID of the padding token. Required if ``generate_attention_mask`` is ``True``.
:param label_mask_paths: Optional paths to ``np.bool_`` memory-mapped arrays of label masks.
"""

def __init__(
Expand All @@ -46,21 +50,30 @@ def __init__(
include_instance_metadata: bool = True,
generate_attention_mask: bool = False,
pad_token_id: Optional[int] = None,
label_mask_paths: Optional[List[PathOrStr]] = None,
):
if not paths:
raise ValueError("At least one path is required")

if generate_attention_mask and not pad_token_id:
raise ValueError("'pad_token_id' is required for 'generate_attention_mask'")

if label_mask_paths and len(label_mask_paths) != len(paths):
raise ValueError("There must be the same number of 'label_mask_paths' as there are 'paths'")

if isinstance(metadata, list):
if len(metadata) != len(paths):
raise ValueError("'metadata' should have the same length as the number of file paths")
else:
metadata = [metadata or {}] * len(paths)

self._memmap_paths = paths
self._metadata = metadata
self._label_mask_paths = label_mask_paths
self._chunk_size = chunk_size
self._mmap_offsets: Optional[List[Tuple[int, int]]] = None
self._num_instances: Optional[int] = None
self.dtype = memmap_dtype
self._item_size = self.dtype(0).itemsize
self._include_instance_metadata = include_instance_metadata
self._generate_attention_mask = generate_attention_mask
self._pad_token_id = pad_token_id
Expand Down Expand Up @@ -89,34 +102,57 @@ def offsets(self) -> List[Tuple[int, int]]:
import concurrent.futures

self._mmap_offsets = []
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for path in self._memmap_paths:
future = executor.submit(self._get_file_length, path)
futures.append(future)

path_to_length: Dict[PathOrStr, int] = {}
for future in concurrent.futures.as_completed(futures):
path_to_length: Dict[PathOrStr, int] = {}
path_to_mask_path: Dict[PathOrStr, PathOrStr] = {}
mask_path_to_length: Dict[PathOrStr, int] = {}

with concurrent.futures.ThreadPoolExecutor() as executor:
path_futures = []
mask_path_futures = []
for i, path in enumerate(self._memmap_paths):
path_futures.append(executor.submit(self._get_file_length, path))
if self._label_mask_paths is not None:
mask_path = self._label_mask_paths[i]
path_to_mask_path[path] = mask_path
mask_path_futures.append(executor.submit(self._get_file_length, mask_path, np.bool_))

for future in concurrent.futures.as_completed(path_futures):
path, length = future.result()
path_to_length[path] = length

for future in concurrent.futures.as_completed(mask_path_futures):
path, length = future.result()
mask_path_to_length[path] = length

start_offset = 0
for path in self._memmap_paths:
length = path_to_length[path]
if mask_path_to_length:
mask_path = path_to_mask_path[path]
if length != mask_path_to_length[mask_path]:
raise ValueError(f"masking file '{mask_path}' should be the same size as '{path}'")
end_offset = start_offset + length
self._mmap_offsets.append((start_offset, end_offset))
start_offset += length
return self._mmap_offsets

def _read_chunk_from_memmap(self, path: PathOrStr, index: int) -> torch.Tensor:
bytes_start = index * self._item_size * self._chunk_size
num_bytes = self._item_size * self._chunk_size
def _read_chunk_from_memmap(self, path: PathOrStr, index: int, dtype=None) -> torch.Tensor:
dtype = dtype or self.dtype
item_size = dtype(0).itemsize
bytes_start = index * item_size * self._chunk_size
num_bytes = item_size * self._chunk_size
buffer = get_bytes_range(path, bytes_start, num_bytes)
array = np.frombuffer(buffer, dtype=self.dtype)
return torch.tensor(array.astype(np.int_), dtype=torch.long)
array = np.frombuffer(buffer, dtype=dtype)
if dtype == np.bool_:
return torch.tensor(array)
else:
return torch.tensor(array.astype(np.int_), dtype=torch.long)

def _get_file_length(self, path) -> Tuple[PathOrStr, int]:
return path, file_size(path) // (self._item_size * self._chunk_size)
def _get_file_length(self, path, dtype=None) -> Tuple[PathOrStr, int]:
dtype = dtype or self.dtype
item_size = dtype(0).itemsize
return path, file_size(path) // (item_size * self._chunk_size)

def __len__(self) -> int:
if self._num_instances is None:
Expand All @@ -141,8 +177,14 @@ def __getitem__(self, index: int) -> Dict[str, Any]:

# Read the data from file.
input_ids = self._read_chunk_from_memmap(self._memmap_paths[memmap_index], memmap_local_index)

out: Dict[str, Any] = {"input_ids": input_ids}

if self._label_mask_paths is not None:
label_mask = self._read_chunk_from_memmap(
self._label_mask_paths[memmap_index], memmap_local_index, dtype=np.bool_
)
out["label_mask"] = label_mask

if self._include_instance_metadata:
metadata = self._metadata[memmap_index]
out["metadata"] = deepcopy(metadata)
Expand Down
Loading

0 comments on commit f36ac42

Please sign in to comment.