Fine-tuning with label mask
- Add support for fine-tuning with a label mask.
- Add a script for preparing Tulu V2 for fine-tuning.
- Add fine-tuning instructions to README.
epwalsh committed Jan 17, 2024
1 parent 45ed078 commit 321c499
Showing 11 changed files with 295 additions and 32 deletions.
16 changes: 16 additions & 0 deletions README.md
@@ -10,3 +10,19 @@
```
pip install ai2-olmo
```

## Fine-tuning

To fine-tune an OLMo model, you'll first need to prepare your dataset by tokenizing it and saving the token IDs to a NumPy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.
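
As a rough sketch of the on-disk format (this is not the actual prepare script, and the example data, `seq_len`, and `pad_id` below are made up), the dataset boils down to two flat memory-mapped arrays of equal length: one of token IDs and one of booleans marking which tokens should count toward the loss:

```
import numpy as np

# Hypothetical tokenized chat examples: token IDs plus a boolean mask that is
# True only for the tokens to train on (e.g. the assistant's turns).
examples = [
    {"input_ids": [5, 8, 13, 21], "label_mask": [False, False, True, True]},
    {"input_ids": [3, 4, 7], "label_mask": [False, True, True]},
]

seq_len = 4  # fixed sequence length used at training time (made up here)
pad_id = 1   # assumed pad token ID

input_ids = np.memmap("input_ids.npy", dtype=np.uint16, mode="w+", shape=(len(examples) * seq_len,))
label_mask = np.memmap("label_mask.npy", dtype=np.bool_, mode="w+", shape=(len(examples) * seq_len,))

for i, ex in enumerate(examples):
    ids = (ex["input_ids"] + [pad_id] * seq_len)[:seq_len]
    mask = (ex["label_mask"] + [False] * seq_len)[:seq_len]
    input_ids[i * seq_len : (i + 1) * seq_len] = ids
    label_mask[i * seq_len : (i + 1) * seq_len] = mask

input_ids.flush()
label_mask.flush()
```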

Next, prepare your training config. There are many examples in the [`configs/`](./configs) directory. Make sure the model parameters match the model you're fine-tuning; to be safe, you can always start from the config that comes with the model checkpoint.
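
As a quick sanity check (just a sketch; the file paths below are placeholders), you can compare the `model` section of your training config against the `config.yaml` saved alongside the unsharded checkpoint:

```
import yaml

# Placeholder paths: point these at your fine-tuning config and at the config
# that ships with the checkpoint you're starting from.
with open("configs/mitchish-instruct.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("/root/checkpoint-unsharded/config.yaml") as f:
    ckpt_cfg = yaml.safe_load(f)

for key in ("d_model", "n_heads", "n_layers", "vocab_size"):
    assert train_cfg["model"][key] == ckpt_cfg["model"][key], f"mismatch on model.{key}"
```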

Then launch the training job:

```
torchrun --nproc_per_node=8 scripts/train.py {path_to_train_config} \
--data.paths=[{path_to_data}/input_ids.npy] \
--data.label_mask_paths=[{path_to_data}/label_mask.npy] \
--load_path={path_to_checkpoint} \
--reset_trainer_state
```
19 changes: 7 additions & 12 deletions configs/mcli/mitchish-instruct.yml
@@ -8,10 +8,14 @@ gpu_type: a100_40gb
integrations:
- integration_type: git_repo
git_repo: allenai/LLM
git_branch: epwalsh/tulu-fine-tune
git_branch: epwalsh/fine-tune-with-label-masking
pip_install: -e .
ssh_clone: true
command: |-
checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
learning_rate=2e-6
run_name=mitchish-mcli-2.5T-instruct-${learning_rate}
# NOTE: For some reason getting S3 and R2 authentication working both from the command line and
# from Python proved to be challenging, maybe because Mosaic's servers are in Australia.
# In the end I had to use separate methods to get everything working:
@@ -54,7 +58,6 @@ command: |-
export LOG_FILTER_TYPE=local_rank0_only
# Download checkpoint (everything except optimizer state).
checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
echo "Downloading checkpoint '${checkpoint}'..."
# Download config.
@@ -75,24 +78,16 @@ command: |-
# Now remove the aws configs so it doesn't mess with data loading / uploading checkpoints to/from S3.
rm -rf /root/.aws
# Download data (it's small enough so might as well).
echo "Downloading data..."
aws s3 cp \
s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy \
/root/data/data.npy
torchrun \
--master_addr "$MASTER_ADDR" \
--master_port "$MASTER_PORT" \
--nnodes "$NUM_NODES" \
--node_rank "$NODE_RANK" \
--nproc_per_node 8 \
scripts/train.py configs/mitchish-instruct.yaml \
--run_name=mitchish-mcli-2.5T-instruct-2e-6 \
--optimizer.learning_rate=2e-6 \
--run_name=${run_name} \
--optimizer.learning_rate=${learning_rate} \
--save_overwrite \
--time_limit=169200 \
--data.paths=[/root/data/data.npy] \
--save_interval_unsharded=10000 \
--load_path=/root/checkpoint-unsharded \
--reset_optimizer_state \
4 changes: 3 additions & 1 deletion configs/mitchish-instruct.yaml
@@ -179,4 +179,6 @@ data:
timeout: 0
generate_attention_mask: true
paths:
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/input_ids.npy
label_mask_paths:
- s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/label_mask.npy
1 change: 1 addition & 0 deletions olmo/config.py
@@ -509,6 +509,7 @@ class PaddingDirection(StrEnum):
class DataConfig(BaseConfig):
paths: Optional[List[str]] = None
datasets: Optional[Dict[str, List[str]]] = None
label_mask_paths: Optional[List[str]] = None
pad_direction: PaddingDirection = PaddingDirection.right
generate_attention_mask: bool = False
num_workers: int = 0
4 changes: 3 additions & 1 deletion olmo/data/__init__.py
@@ -1,8 +1,9 @@
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, cast

from torch.utils.data import DataLoader, DistributedSampler

from ..aliases import PathOrStr
from ..config import DataConfig, TrainConfig
from ..exceptions import OlmoConfigurationError
from ..torch_util import barrier, get_global_rank, get_world_size
@@ -39,6 +40,7 @@ def build_memmap_dataset(
include_instance_metadata=include_instance_metadata,
pad_token_id=train_config.model.pad_token_id,
generate_attention_mask=data_config.generate_attention_mask,
label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths),
)


17 changes: 17 additions & 0 deletions olmo/data/collator.py
@@ -26,6 +26,7 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
all_input_ids = []
all_attention_mask = []
all_attention_bias = []
all_label_mask = []
all_indices = []
all_metadata = []
for x in items:
@@ -78,6 +79,19 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
)
)

# Pad label mask.
label_mask = x.get("label_mask") if isinstance(x, dict) else None
if label_mask is not None:
if not isinstance(label_mask, torch.Tensor):
label_mask = torch.tensor(label_mask)
all_label_mask.append(
F.pad(
label_mask.to(dtype=torch.bool),
pad_shape,
value=False,
)
)

# Indices.
index = x.get("index") if isinstance(x, dict) else None
if index is not None:
@@ -93,8 +107,11 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
out["attention_mask"] = torch.stack(all_attention_mask)
if all_attention_bias:
out["attention_bias"] = torch.stack(all_attention_bias)
if all_label_mask:
out["label_mask"] = torch.stack(all_label_mask)
if all_indices:
out["index"] = torch.stack(all_indices)
if all_metadata:
out["metadata"] = all_metadata

return out
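
A quick illustration of the padding rule in the collator above (values made up): label masks are padded with `False`, so positions added for padding can never contribute to the loss.

```
import torch
import torch.nn.functional as F

label_mask = torch.tensor([True, True, False])   # 3 real tokens, last one already masked out
padded = F.pad(label_mask, (0, 2), value=False)  # right-pad up to a batch length of 5
# padded -> tensor([ True,  True, False, False, False])
```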
74 changes: 58 additions & 16 deletions olmo/data/memmap_dataset.py
@@ -35,6 +35,10 @@ class MemMapDataset(Dataset[Dict[str, Any]]):
with the same number of items as there are paths.
:param include_instance_metadata: If ``True`` (the default), each instance returned from `__getitem__` will
include the metadata from its source.
:param generate_attention_mask: If ``True``, each instance returned from ``__getitem__`` will include an
attention mask generated by masking each padding token.
:param pad_token_id: The ID of the padding token. Required if ``generate_attention_mask`` is ``True``.
:param label_mask_paths: Optional paths to ``np.bool_`` memory-mapped arrays of label masks.
"""

def __init__(
@@ -46,21 +50,30 @@ def __init__(
include_instance_metadata: bool = True,
generate_attention_mask: bool = False,
pad_token_id: Optional[int] = None,
label_mask_paths: Optional[List[PathOrStr]] = None,
):
if not paths:
raise ValueError("At least one path is required")

if generate_attention_mask and not pad_token_id:
raise ValueError("'pad_token_id' is required for 'generate_attention_mask'")

if label_mask_paths and len(label_mask_paths) != len(paths):
raise ValueError("There must be the same number of 'label_mask_paths' as there are 'paths'")

if isinstance(metadata, list):
if len(metadata) != len(paths):
raise ValueError("'metadata' should have the same length as the number of file paths")
else:
metadata = [metadata or {}] * len(paths)

self._memmap_paths = paths
self._metadata = metadata
self._label_mask_paths = label_mask_paths
self._chunk_size = chunk_size
self._mmap_offsets: Optional[List[Tuple[int, int]]] = None
self._num_instances: Optional[int] = None
self.dtype = memmap_dtype
self._item_size = self.dtype(0).itemsize
self._include_instance_metadata = include_instance_metadata
self._generate_attention_mask = generate_attention_mask
self._pad_token_id = pad_token_id
@@ -89,34 +102,57 @@ def offsets(self) -> List[Tuple[int, int]]:
import concurrent.futures

self._mmap_offsets = []
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for path in self._memmap_paths:
future = executor.submit(self._get_file_length, path)
futures.append(future)

path_to_length: Dict[PathOrStr, int] = {}
for future in concurrent.futures.as_completed(futures):
path_to_length: Dict[PathOrStr, int] = {}
path_to_mask_path: Dict[PathOrStr, PathOrStr] = {}
mask_path_to_length: Dict[PathOrStr, int] = {}

with concurrent.futures.ThreadPoolExecutor() as executor:
path_futures = []
mask_path_futures = []
for i, path in enumerate(self._memmap_paths):
path_futures.append(executor.submit(self._get_file_length, path))
if self._label_mask_paths is not None:
mask_path = self._label_mask_paths[i]
path_to_mask_path[path] = mask_path
mask_path_futures.append(executor.submit(self._get_file_length, mask_path, np.bool_))

for future in concurrent.futures.as_completed(path_futures):
path, length = future.result()
path_to_length[path] = length

for future in concurrent.futures.as_completed(mask_path_futures):
path, length = future.result()
mask_path_to_length[path] = length

start_offset = 0
for path in self._memmap_paths:
length = path_to_length[path]
if mask_path_to_length:
mask_path = path_to_mask_path[path]
if length != mask_path_to_length[mask_path]:
raise ValueError(f"masking file '{mask_path}' should be the same size as '{path}'")
end_offset = start_offset + length
self._mmap_offsets.append((start_offset, end_offset))
start_offset += length
return self._mmap_offsets

def _read_chunk_from_memmap(self, path: PathOrStr, index: int) -> torch.Tensor:
bytes_start = index * self._item_size * self._chunk_size
num_bytes = self._item_size * self._chunk_size
def _read_chunk_from_memmap(self, path: PathOrStr, index: int, dtype=None) -> torch.Tensor:
dtype = dtype or self.dtype
item_size = dtype(0).itemsize
bytes_start = index * item_size * self._chunk_size
num_bytes = item_size * self._chunk_size
buffer = get_bytes_range(path, bytes_start, num_bytes)
array = np.frombuffer(buffer, dtype=self.dtype)
return torch.tensor(array.astype(np.int_), dtype=torch.long)
array = np.frombuffer(buffer, dtype=dtype)
if dtype == np.bool_:
return torch.tensor(array)
else:
return torch.tensor(array.astype(np.int_), dtype=torch.long)

def _get_file_length(self, path) -> Tuple[PathOrStr, int]:
return path, file_size(path) // (self._item_size * self._chunk_size)
def _get_file_length(self, path, dtype=None) -> Tuple[PathOrStr, int]:
dtype = dtype or self.dtype
item_size = dtype(0).itemsize
return path, file_size(path) // (item_size * self._chunk_size)

def __len__(self) -> int:
if self._num_instances is None:
@@ -141,8 +177,14 @@ def __getitem__(self, index: int) -> Dict[str, Any]:

# Read the data from file.
input_ids = self._read_chunk_from_memmap(self._memmap_paths[memmap_index], memmap_local_index)

out: Dict[str, Any] = {"input_ids": input_ids}

if self._label_mask_paths is not None:
label_mask = self._read_chunk_from_memmap(
self._label_mask_paths[memmap_index], memmap_local_index, dtype=np.bool_
)
out["label_mask"] = label_mask

if self._include_instance_metadata:
metadata = self._metadata[memmap_index]
out["metadata"] = deepcopy(metadata)
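
A rough usage sketch of the dataset above (argument values are made up, and the exact constructor signature is assumed from the partial diff): each token-ID file is paired with a boolean mask file covering the same number of tokens.

```
import numpy as np

from olmo.data.memmap_dataset import MemMapDataset

dataset = MemMapDataset(
    "/root/data/input_ids.npy",  # token IDs stored as uint16
    chunk_size=2048,             # tokens per training instance (assumed value)
    memmap_dtype=np.uint16,
    label_mask_paths=["/root/data/label_mask.npy"],  # np.bool_ array, same length as the token IDs
)
item = dataset[0]
# item["input_ids"] is a LongTensor of shape (chunk_size,);
# item["label_mask"] is a BoolTensor of the same shape.
```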
10 changes: 8 additions & 2 deletions olmo/train.py
@@ -470,9 +470,15 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec

def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
# Labels are just input IDs shifted to the left (first item is ignored).
labels, attention_mask = batch["input_ids"], batch.get("attention_mask")
labels, label_mask, attention_mask = (
batch["input_ids"],
batch.get("label_mask"),
batch.get("attention_mask"),
)
if label_mask is not None:
labels.masked_fill_(~label_mask, -100)
if attention_mask is not None:
labels = labels.masked_fill(attention_mask == 0.0, -100)
labels.masked_fill_(attention_mask == 0.0, -100)
return labels[..., 1:].contiguous()

def model_forward(
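
To show what the mask does downstream (a sketch with made-up values, not the trainer's actual loss code): masked-out positions become `-100`, which `F.cross_entropy` ignores by default.

```
import torch
import torch.nn.functional as F

input_ids = torch.tensor([[10, 11, 12, 13]])
label_mask = torch.tensor([[False, False, True, True]])  # only train on the last two tokens

labels = input_ids.clone().masked_fill(~label_mask, -100)
labels = labels[..., 1:].contiguous()  # shift left, as in get_labels above
logits = torch.randn(1, 3, 50)         # fake model output: (batch, seq_len - 1, vocab)
loss = F.cross_entropy(logits.view(-1, 50), labels.view(-1), ignore_index=-100)
```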