From 321c4994e73b7383189d25efd466669b9839cb77 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Tue, 16 Jan 2024 16:25:24 -0800
Subject: [PATCH 01/13] Fine-tuning with label mask

- Add support for fine-tuning with a label mask.
- Add a script for preparing Tulu V2 for fine-tuning.
- Add fine-tuning instructions to README.
---
 README.md                          |  16 +++++
 configs/mcli/mitchish-instruct.yml |  19 ++---
 configs/mitchish-instruct.yaml     |   4 +-
 olmo/config.py                     |   1 +
 olmo/data/__init__.py              |   4 +-
 olmo/data/collator.py              |  17 +++++
 olmo/data/memmap_dataset.py        |  74 ++++++++++++-----
 olmo/train.py                      |  10 ++-
 scripts/prepare_tulu_data.py       | 111 +++++++++++++++++++++++++++++
 tests/data/collator_test.py        |  37 ++++++++++
 tests/data/memmap_dataset_test.py  |  34 +++++++++
 11 files changed, 295 insertions(+), 32 deletions(-)
 create mode 100644 scripts/prepare_tulu_data.py

diff --git a/README.md b/README.md
index dde1f8e68..d17a777eb 100644
--- a/README.md
+++ b/README.md
@@ -10,3 +10,19 @@
 ```
 pip install ai2-olmo
 ```
+
+## Fine-tuning
+
+To fine-tune an OLMo model you'll first need to prepare your dataset by tokenizing and saving it to a numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.
+
+Next, prepare your training config. There are many examples in the [`configs/`](./configs) directory. Make sure the model parameters match up with the model you're fine-tuning. To be safe you can always start from the config that comes with the model checkpoint.
+
+Then launch the training job:
+
+```
+torchrun --nproc_per_node=8 scripts/train.py {path_to_train_config} \
+  --data.paths=[{path_to_data}/input_ids.npy] \
+  --data.label_mask_paths=[{path_to_data}/label_mask.npy] \
+  --load_path={path_to_checkpoint} \
+  --reset_trainer_state
+```

diff --git a/configs/mcli/mitchish-instruct.yml b/configs/mcli/mitchish-instruct.yml
index 66c8d3bd7..1e2bb268c 100644
--- a/configs/mcli/mitchish-instruct.yml
+++ b/configs/mcli/mitchish-instruct.yml
@@ -8,10 +8,14 @@ gpu_type: a100_40gb
 integrations:
   - integration_type: git_repo
     git_repo: allenai/LLM
-    git_branch: epwalsh/tulu-fine-tune
+    git_branch: epwalsh/fine-tune-with-label-masking
     pip_install: -e .
     ssh_clone: true
 command: |-
+  checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
+  learning_rate=2e-6
+  run_name=mitchish-mcli-2.5T-instruct-${learning_rate}
+
   # NOTE: For some reason getting S3 and R2 authentication working both from the command line and
   # from Python proved to be challenging, maybe because Mosaic's servers are in Australia.
   # In the end I had to use separate methods to get everything working:
@@ -54,7 +58,6 @@ command: |-
   export LOG_FILTER_TYPE=local_rank0_only

   # Download checkpoint (everything except optimizer state).
-  checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
   echo "Downloading checkpoint '${checkpoint}'..."

   # Download config.
@@ -75,12 +78,6 @@ command: |-
   # Now remove the aws configs so it doesn't mess with data loading / uploading checkpoints to/from S3.
   rm -rf /root/.aws

-  # Download data (it's small enough so might as well).
-  echo "Downloading data..."
-  aws s3 cp \
-    s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy \
-    /root/data/data.npy
-
   torchrun \
     --master_addr "$MASTER_ADDR" \
     --master_port "$MASTER_PORT" \
     --node_rank "$NODE_RANK" \
     --nproc_per_node 8 \
     scripts/train.py configs/mitchish-instruct.yaml \
-      --run_name=mitchish-mcli-2.5T-instruct-2e-6 \
-      --optimizer.learning_rate=2e-6 \
+      --run_name=${run_name} \
+      --optimizer.learning_rate=${learning_rate} \
       --save_overwrite \
-      --time_limit=169200 \
-      --data.paths=[/root/data/data.npy] \
       --save_interval_unsharded=10000 \
       --load_path=/root/checkpoint-unsharded \
       --reset_optimizer_state \

diff --git a/configs/mitchish-instruct.yaml b/configs/mitchish-instruct.yaml
index a21f1ade6..7b52c2558 100644
--- a/configs/mitchish-instruct.yaml
+++ b/configs/mitchish-instruct.yaml
@@ -179,4 +179,6 @@ data:
   timeout: 0
   generate_attention_mask: true
   paths:
-    - s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/data.npy
+    - s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/input_ids.npy
+  label_mask_paths:
+    - s3://ai2-llm/preprocessed/tulu-v2-fine-tune/gpt-neox-20b-pii-special/label_mask.npy

diff --git a/olmo/config.py b/olmo/config.py
index b4b0576f9..387c63a2d 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -509,6 +509,7 @@ class PaddingDirection(StrEnum):
 class DataConfig(BaseConfig):
     paths: Optional[List[str]] = None
     datasets: Optional[Dict[str, List[str]]] = None
+    label_mask_paths: Optional[List[str]] = None
     pad_direction: PaddingDirection = PaddingDirection.right
     generate_attention_mask: bool = False
     num_workers: int = 0

diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py
index bc08ff863..52421b57a 100644
--- a/olmo/data/__init__.py
+++ b/olmo/data/__init__.py
@@ -1,8 +1,9 @@
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional, cast

 from torch.utils.data import DataLoader, DistributedSampler

+from ..aliases import PathOrStr
 from ..config import DataConfig, TrainConfig
 from ..exceptions import OlmoConfigurationError
 from ..torch_util import barrier, get_global_rank, get_world_size
@@ -39,6 +40,7 @@ def build_memmap_dataset(
         include_instance_metadata=include_instance_metadata,
         pad_token_id=train_config.model.pad_token_id,
         generate_attention_mask=data_config.generate_attention_mask,
+        label_mask_paths=cast(Optional[List[PathOrStr]], data_config.label_mask_paths),
     )

diff --git a/olmo/data/collator.py b/olmo/data/collator.py
index 2d81d271e..d86a0b9af 100644
--- a/olmo/data/collator.py
+++ b/olmo/data/collator.py
@@ -26,6 +26,7 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
         all_input_ids = []
         all_attention_mask = []
         all_attention_bias = []
+        all_label_mask = []
         all_indices = []
         all_metadata = []
         for x in items:
@@ -78,6 +79,19 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
                 )
             )

+            # Pad label mask.
+            label_mask = x.get("label_mask") if isinstance(x, dict) else None
+            if label_mask is not None:
+                if not isinstance(label_mask, torch.Tensor):
+                    label_mask = torch.tensor(label_mask)
+                all_label_mask.append(
+                    F.pad(
+                        label_mask.to(dtype=torch.bool),
+                        pad_shape,
+                        value=False,
+                    )
+                )
+
             # Indices.
             index = x.get("index") if isinstance(x, dict) else None
             if index is not None:
@@ -93,8 +107,11 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di
             out["attention_mask"] = torch.stack(all_attention_mask)
         if all_attention_bias:
             out["attention_bias"] = torch.stack(all_attention_bias)
+        if all_label_mask:
+            out["label_mask"] = torch.stack(all_label_mask)
         if all_indices:
             out["index"] = torch.stack(all_indices)
         if all_metadata:
             out["metadata"] = all_metadata
+
         return out

diff --git a/olmo/data/memmap_dataset.py b/olmo/data/memmap_dataset.py
index 69f2d85b9..5af73c277 100644
--- a/olmo/data/memmap_dataset.py
+++ b/olmo/data/memmap_dataset.py
@@ -35,6 +35,10 @@ class MemMapDataset(Dataset[Dict[str, Any]]):
         with the same number of items as there are paths.
     :param include_instance_metadata: If ``True`` (the default), each instance returned from `__getitem__`
         will include the metadata from its source.
+    :param generate_attention_mask: If ``True``, each instance returned from ``__getitem__`` will include an
+        attention mask generated by masking each padding token.
+    :param pad_token_id: The ID of the padding token. Required if ``generate_attention_mask`` is ``True``.
+    :param label_mask_paths: Optional paths to ``np.bool_`` memory-mapped arrays of label masks.
     """

     def __init__(
@@ -46,21 +50,30 @@ def __init__(
         include_instance_metadata: bool = True,
         generate_attention_mask: bool = False,
         pad_token_id: Optional[int] = None,
+        label_mask_paths: Optional[List[PathOrStr]] = None,
     ):
         if not paths:
             raise ValueError("At least one path is required")
+
+        if generate_attention_mask and pad_token_id is None:
+            raise ValueError("'pad_token_id' is required for 'generate_attention_mask'")
+
+        if label_mask_paths and len(label_mask_paths) != len(paths):
+            raise ValueError("There must be the same number of 'label_mask_paths' as there are 'paths'")
+
         if isinstance(metadata, list):
             if len(metadata) != len(paths):
                 raise ValueError("'metadata' should have the same length as the number of file paths")
         else:
             metadata = [metadata or {}] * len(paths)
+
         self._memmap_paths = paths
         self._metadata = metadata
+        self._label_mask_paths = label_mask_paths
         self._chunk_size = chunk_size
         self._mmap_offsets: Optional[List[Tuple[int, int]]] = None
         self._num_instances: Optional[int] = None
         self.dtype = memmap_dtype
-        self._item_size = self.dtype(0).itemsize
         self._include_instance_metadata = include_instance_metadata
         self._generate_attention_mask = generate_attention_mask
         self._pad_token_id = pad_token_id
@@ -89,34 +102,57 @@ def offsets(self) -> List[Tuple[int, int]]:
             import concurrent.futures

             self._mmap_offsets = []
-            with concurrent.futures.ThreadPoolExecutor() as executor:
-                futures = []
-                for path in self._memmap_paths:
-                    future = executor.submit(self._get_file_length, path)
-                    futures.append(future)

-                path_to_length: Dict[PathOrStr, int] = {}
-                for future in concurrent.futures.as_completed(futures):
+            path_to_length: Dict[PathOrStr, int] = {}
+            path_to_mask_path: Dict[PathOrStr, PathOrStr] = {}
+            mask_path_to_length: Dict[PathOrStr, int] = {}
+
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                path_futures = []
+                mask_path_futures = []
+                for i, path in enumerate(self._memmap_paths):
+                    path_futures.append(executor.submit(self._get_file_length, path))
+                    if self._label_mask_paths is not None:
+                        mask_path = self._label_mask_paths[i]
+                        path_to_mask_path[path] = mask_path
+                        mask_path_futures.append(executor.submit(self._get_file_length, mask_path, np.bool_))
+
+                for future in concurrent.futures.as_completed(path_futures):
                     path, length = future.result()
                     path_to_length[path] = length

+                for future in concurrent.futures.as_completed(mask_path_futures):
+                    path, length = future.result()
+                    mask_path_to_length[path] = length
+
             start_offset = 0
             for path in self._memmap_paths:
                 length = path_to_length[path]
+                if mask_path_to_length:
+                    mask_path = path_to_mask_path[path]
+                    if length != mask_path_to_length[mask_path]:
+                        raise ValueError(f"masking file '{mask_path}' should be the same size as '{path}'")
                 end_offset = start_offset + length
                 self._mmap_offsets.append((start_offset, end_offset))
                 start_offset += length
         return self._mmap_offsets

-    def _read_chunk_from_memmap(self, path: PathOrStr, index: int) -> torch.Tensor:
-        bytes_start = index * self._item_size * self._chunk_size
-        num_bytes = self._item_size * self._chunk_size
+    def _read_chunk_from_memmap(self, path: PathOrStr, index: int, dtype=None) -> torch.Tensor:
+        dtype = dtype or self.dtype
+        item_size = dtype(0).itemsize
+        bytes_start = index * item_size * self._chunk_size
+        num_bytes = item_size * self._chunk_size
         buffer = get_bytes_range(path, bytes_start, num_bytes)
-        array = np.frombuffer(buffer, dtype=self.dtype)
-        return torch.tensor(array.astype(np.int_), dtype=torch.long)
+        array = np.frombuffer(buffer, dtype=dtype)
+        if dtype == np.bool_:
+            return torch.tensor(array)
+        else:
+            return torch.tensor(array.astype(np.int_), dtype=torch.long)

-    def _get_file_length(self, path) -> Tuple[PathOrStr, int]:
-        return path, file_size(path) // (self._item_size * self._chunk_size)
+    def _get_file_length(self, path, dtype=None) -> Tuple[PathOrStr, int]:
+        dtype = dtype or self.dtype
+        item_size = dtype(0).itemsize
+        return path, file_size(path) // (item_size * self._chunk_size)

     def __len__(self) -> int:
         if self._num_instances is None:
@@ -141,8 +177,14 @@ def __getitem__(self, index: int) -> Dict[str, Any]:

         # Read the data from file.
         input_ids = self._read_chunk_from_memmap(self._memmap_paths[memmap_index], memmap_local_index)
         out: Dict[str, Any] = {"input_ids": input_ids}
+
+        if self._label_mask_paths is not None:
+            label_mask = self._read_chunk_from_memmap(
+                self._label_mask_paths[memmap_index], memmap_local_index, dtype=np.bool_
+            )
+            out["label_mask"] = label_mask
+
         if self._include_instance_metadata:
             metadata = self._metadata[memmap_index]
             out["metadata"] = deepcopy(metadata)

diff --git a/olmo/train.py b/olmo/train.py
index 2207b0552..2dfd48c95 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -470,9 +470,15 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec

     def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
         # Labels are just input IDs shifted to the left (first item is ignored).
-        labels, attention_mask = batch["input_ids"], batch.get("attention_mask")
+        labels, label_mask, attention_mask = (
+            batch["input_ids"],
+            batch.get("label_mask"),
+            batch.get("attention_mask"),
+        )
+        if label_mask is not None:
+            labels.masked_fill_(~label_mask, -100)
         if attention_mask is not None:
-            labels = labels.masked_fill(attention_mask == 0.0, -100)
+            labels.masked_fill_(attention_mask == 0.0, -100)
         return labels[..., 1:].contiguous()

     def model_forward(

diff --git a/scripts/prepare_tulu_data.py b/scripts/prepare_tulu_data.py
new file mode 100644
index 000000000..65d74de2b
--- /dev/null
+++ b/scripts/prepare_tulu_data.py
@@ -0,0 +1,111 @@
+"""
+Script for preparing the Tulu V2 data for fine-tuning an OLMo model.
+""" + +import logging +from argparse import ArgumentParser +from functools import partial +from pathlib import Path + +import datasets as ds +import numpy as np +from rich.progress import track + +from olmo.tokenizer import Tokenizer +from olmo.util import prepare_cli_environment + +log = logging.getLogger(__name__) + + +def main(opts) -> None: + tokenizer: Tokenizer + if Path(opts.tokenizer).is_file(): + tokenizer = Tokenizer.from_file(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad) + else: + tokenizer = Tokenizer.from_pretrained(opts.tokenizer, eos_token_id=opts.eos, pad_token_id=opts.pad) + + dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train") + + log.info("Tokenizing dataset...") + preprocessed = dataset.map( + partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len), + batched=False, + remove_columns=["dataset", "id", "messages"], + num_proc=opts.num_proc, # type: ignore + ) + + log.info("Counting tokens...") + total_tokens = 0 + for ex in track(preprocessed): + assert len(ex["input_ids"]) == opts.seq_len # type: ignore + total_tokens += len(ex["input_ids"]) # type: ignore + log.info(f"{total_tokens:,d}") + + log.info(f"Saving results to '{opts.output_dir}'...") + output_dir = Path(opts.output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + input_ids_file = np.memmap( + str(output_dir / "input_ids.npy"), dtype=np.uint16, mode="w+", shape=(total_tokens,) + ) + label_mask_file = np.memmap( + str(output_dir / "label_mask.npy"), dtype=np.bool_, mode="w+", shape=(total_tokens,) + ) + offset = 0 + for ex in track(preprocessed): + ex_len = len(ex["input_ids"]) # type: ignore + input_ids_file[offset : offset + ex_len] = ex["input_ids"] # type: ignore + label_mask_file[offset : offset + ex_len] = ex["label_mask"] # type: ignore + offset += ex_len + input_ids_file.flush() + label_mask_file.flush() + + log.info("Done!") + + +def preprocess(example, tokenizer: Tokenizer, max_seq_len: int): + parts = [] + for msg in example["messages"]: + parts.append(f"<|{msg['role']}|>") + parts.append(msg["content"]) + + prompt = "\n".join(parts[:-1]) + "\n" + completion = parts[-1] + + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + completion_ids = tokenizer.encode(completion, add_special_tokens=True) + + input_ids = (prompt_ids + completion_ids)[:max_seq_len] + label_mask = ([False] * len(prompt_ids) + [True] * len(completion_ids))[:max_seq_len] + + if len(input_ids) < max_seq_len: + pad_len = max_seq_len - len(input_ids) + input_ids += [tokenizer.pad_token_id] * pad_len + label_mask += [False] * pad_len + + assert len(input_ids) == len(label_mask) + + return {"input_ids": input_ids, "label_mask": label_mask} + + +def get_parser() -> ArgumentParser: + parser = ArgumentParser(description="Prepare Tulu V2 dataset") + parser.add_argument("output_dir", type=str, help="""Directory to save the results to.""") + parser.add_argument( + "-t", + "--tokenizer", + type=str, + help="""Tokenizer path or identifier.""", + default="tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json", + ) + parser.add_argument("-s", "--seq-len", type=int, help="""Max sequence length.""", default=2048) + parser.add_argument("--eos", type=int, help="""EOS token ID.""", default=0) + parser.add_argument("--pad", type=int, help="""PAD token ID.""", default=1) + parser.add_argument("-j", "--num-proc", type=int, help="""Number of workers.""", default=8) + return parser + + +if __name__ == "__main__": + prepare_cli_environment() + opts = get_parser().parse_args() + 
+    main(opts)

diff --git a/tests/data/collator_test.py b/tests/data/collator_test.py
index 2279570d5..e94451313 100644
--- a/tests/data/collator_test.py
+++ b/tests/data/collator_test.py
@@ -92,3 +92,40 @@ def test_collate_with_attention_bias(train_config, pad_direction):
             ]
         )
     ).all()
+
+
+@pytest.mark.parametrize(
+    "pad_direction",
+    [pytest.param(PaddingDirection.right, id="pad-right"), pytest.param(PaddingDirection.left, id="pad-left")],
+)
+def test_collate_with_label_mask(train_config, pad_direction):
+    train_config.data.pad_direction = pad_direction
+    collator = DataCollator.from_train_config(train_config)
+
+    inputs = [
+        {
+            "input_ids": torch.tensor([0, 1, 2, 3]),
+            "label_mask": torch.tensor([True, False, True, True]),
+        },
+        {
+            "input_ids": torch.tensor([4, 5, 6]),
+            "label_mask": torch.tensor([True, True, False]),
+        },
+    ]
+    batch = collator(inputs)  # type: ignore
+    assert batch["label_mask"] is not None
+    assert batch["label_mask"].shape == (2, 4)
+    if pad_direction == "right":
+        assert (
+            batch["label_mask"]
+            == torch.tensor(
+                [[True, False, True, True], [True, True, False, False]],
+            )
+        ).all()
+    else:
+        assert (
+            batch["label_mask"]
+            == torch.tensor(
+                [[True, False, True, True], [False, True, True, False]],
+            )
+        ).all()

diff --git a/tests/data/memmap_dataset_test.py b/tests/data/memmap_dataset_test.py
index 85a3fb0cc..e267043ee 100644
--- a/tests/data/memmap_dataset_test.py
+++ b/tests/data/memmap_dataset_test.py
@@ -22,6 +22,40 @@ def test_mmap_dataset(tmp_path: Path):
     assert ds[7]["input_ids"].tolist() == [28, 29, 30, 31]


+def test_mmap_dataset_with_label_mask(tmp_path: Path):
+    mmap1 = np.memmap(tmp_path / "mmap1.npy", mode="w+", dtype=np.uint16, shape=(16,))
+    mmap1[:] = np.array(list(range(16)), dtype=np.uint16)
+    mmap1.flush()
+
+    mask1 = [True] * 16
+    mask1[1] = False
+    mask_mmap1 = np.memmap(tmp_path / "mask_mmap1.npy", mode="w+", dtype=np.bool_, shape=(16,))
+    mask_mmap1[:] = np.array(mask1, dtype=np.bool_)
+    mask_mmap1.flush()
+
+    mmap2 = np.memmap(tmp_path / "mmap2.npy", mode="w+", dtype=np.uint16, shape=(16,))
+    mmap2[:] = np.array(list(range(16, 32)), dtype=np.uint16)
+    mmap2.flush()
+
+    mask2 = [True] * 16
+    mask2[-1] = False
+    mask_mmap2 = np.memmap(tmp_path / "mask_mmap2.npy", mode="w+", dtype=np.bool_, shape=(16,))
+    mask_mmap2[:] = np.array(mask2, dtype=np.bool_)
+    mask_mmap2.flush()
+
+    ds = MemMapDataset(
+        tmp_path / "mmap1.npy",
+        tmp_path / "mmap2.npy",
+        chunk_size=4,
+        label_mask_paths=[tmp_path / "mask_mmap1.npy", tmp_path / "mask_mmap2.npy"],
+    )
+    assert ds[0]["input_ids"].tolist() == [0, 1, 2, 3]
+    assert ds[0]["label_mask"].tolist() == [True, False, True, True]
+    assert ds[1]["input_ids"].tolist() == [4, 5, 6, 7]
+    assert ds[7]["input_ids"].tolist() == [28, 29, 30, 31]
+    assert ds[7]["label_mask"].tolist() == [True, True, True, False]
+
+
 def test_mmap_dataset_with_metadata(tokenizer: Tokenizer, tmp_path: Path, lorem_ipsum_docs: List[str]):
     chunk_size = 24


From 5c3aa571d9fba17577a979c3c2bb09df8a70f044 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Tue, 16 Jan 2024 17:57:29 -0800
Subject: [PATCH 02/13] Unmask all assistant messages, update format
---
 olmo/tokenizer.py            |  8 ++++++++
 scripts/prepare_tulu_data.py | 36 +++++++++++++++++++++++-------------
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/olmo/tokenizer.py b/olmo/tokenizer.py
index b6b934839..a833d3c21 100644
--- a/olmo/tokenizer.py
+++ b/olmo/tokenizer.py
@@ -44,6 +44,14 @@ def __init__(
     def vocab_size(self) -> int:
         return self.base_tokenizer.get_vocab_size()

+    @property
+    def eos_token(self) -> str:
+        return self.decode([self.eos_token_id], skip_special_tokens=False)
+
+    @property
+    def pad_token(self) -> str:
+        return self.decode([self.pad_token_id], skip_special_tokens=False)
+
     @classmethod
     def from_train_config(cls, config: TrainConfig) -> Tokenizer:
         tokenizer_identifier = config.tokenizer.identifier

diff --git a/scripts/prepare_tulu_data.py b/scripts/prepare_tulu_data.py
index 65d74de2b..934309772 100644
--- a/scripts/prepare_tulu_data.py
+++ b/scripts/prepare_tulu_data.py
@@ -64,19 +64,29 @@ def main(opts) -> None:

 def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
-    parts = []
-    for msg in example["messages"]:
-        parts.append(f"<|{msg['role']}|>")
-        parts.append(msg["content"])
-
-    prompt = "\n".join(parts[:-1]) + "\n"
-    completion = parts[-1]
+    input_ids = [tokenizer.eos_token_id]
+    label_mask = [False]

-    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
-    completion_ids = tokenizer.encode(completion, add_special_tokens=True)
-
-    input_ids = (prompt_ids + completion_ids)[:max_seq_len]
-    label_mask = ([False] * len(prompt_ids) + [True] * len(completion_ids))[:max_seq_len]
+    for msg in example["messages"]:
+        role_tokens = tokenizer.encode(f"<|{msg['role']}|>\n", add_special_tokens=False)
+        label_mask += [False] * len(role_tokens)
+        input_ids += role_tokens
+
+        if msg["role"] == "assistant":
+            content_tokens = tokenizer.encode(
+                msg["content"].strip() + tokenizer.eos_token + "\n", add_special_tokens=False
+            )
+            label_mask += [True] * len(content_tokens)
+            # mask out the last '\n'
+            assert content_tokens[-2] == tokenizer.eos_token_id
+            label_mask[-1] = False
+        else:
+            content_tokens = tokenizer.encode(msg["content"].strip() + "\n", add_special_tokens=False)
+            label_mask += [False] * len(content_tokens)
+        input_ids += content_tokens
+
+    input_ids = input_ids[:max_seq_len]
+    label_mask = label_mask[:max_seq_len]

     if len(input_ids) < max_seq_len:
         pad_len = max_seq_len - len(input_ids)
@@ -99,7 +109,7 @@ def get_parser() -> ArgumentParser:
         default="tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json",
     )
     parser.add_argument("-s", "--seq-len", type=int, help="""Max sequence length.""", default=2048)
-    parser.add_argument("--eos", type=int, help="""EOS token ID.""", default=0)
+    parser.add_argument("--eos", type=int, help="""EOS token ID.""", default=50279)
     parser.add_argument("--pad", type=int, help="""PAD token ID.""", default=1)
     parser.add_argument("-j", "--num-proc", type=int, help="""Number of workers.""", default=8)
     return parser

From b36f8fbdaeb248e74a1212712475551f7860d5de Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Tue, 16 Jan 2024 19:22:13 -0800
Subject: [PATCH 03/13] fix
---
 configs/mcli/mitchish-instruct.yml | 14 +++++++-------
 olmo/train.py                      |  4 ++--
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/configs/mcli/mitchish-instruct.yml b/configs/mcli/mitchish-instruct.yml
index 1e2bb268c..49a155f78 100644
--- a/configs/mcli/mitchish-instruct.yml
+++ b/configs/mcli/mitchish-instruct.yml
@@ -1,10 +1,10 @@
-run_name: olmo-7b-instruct
+name: olmo-7b-instruct
 image: mosaicml/pytorch:2.1.0_cu121-python3.10-ubuntu20.04
-gpu_num: 64
-#gpu_num: 8
-#cluster: r12z3
-cluster: r7z2
-gpu_type: a100_40gb
+compute:
+  #cluster: r12z3
+  cluster: r7z2
+  gpus: 64
+  gpu_type: a100_40gb
 integrations:
   - integration_type: git_repo
     git_repo: allenai/LLM
@@ -14,7 +14,7 @@ integrations:
     ssh_clone: true
 command: |-
   checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
   learning_rate=2e-6
-  run_name=mitchish-mcli-2.5T-instruct-${learning_rate}
+  run_name=mitchish-mcli-2.5T-instruct-${learning_rate}-v2

diff --git a/olmo/train.py b/olmo/train.py
index 2dfd48c95..dc81a08ec 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -476,9 +476,9 @@ def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
             batch.get("attention_mask"),
         )
         if label_mask is not None:
-            labels.masked_fill_(~label_mask, -100)
+            labels = labels.masked_fill(~label_mask, -100)
         if attention_mask is not None:
-            labels.masked_fill_(attention_mask == 0.0, -100)
+            labels = labels.masked_fill(attention_mask == 0.0, -100)
         return labels[..., 1:].contiguous()

     def model_forward(

From 829f090a2fe699a788017bf18847a252573e47b7 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Tue, 16 Jan 2024 19:42:06 -0800
Subject: [PATCH 04/13] Filter out examples that don't have assistant
---
 scripts/prepare_tulu_data.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/scripts/prepare_tulu_data.py b/scripts/prepare_tulu_data.py
index 934309772..b0cc884e5 100644
--- a/scripts/prepare_tulu_data.py
+++ b/scripts/prepare_tulu_data.py
@@ -26,8 +26,11 @@ def main(opts) -> None:

     dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")

+    log.info("Filtering dataset...")
+    dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc)  # type: ignore
+
     log.info("Tokenizing dataset...")
-    preprocessed = dataset.map(
+    dataset = dataset.map(
         partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len),
         batched=False,
         remove_columns=["dataset", "id", "messages"],
@@ -36,7 +39,7 @@ def main(opts) -> None:
         num_proc=opts.num_proc,  # type: ignore
     )

     log.info("Counting tokens...")
     total_tokens = 0
-    for ex in track(preprocessed):
+    for ex in track(dataset):
         assert len(ex["input_ids"]) == opts.seq_len  # type: ignore
         total_tokens += len(ex["input_ids"])  # type: ignore
     log.info(f"{total_tokens:,d}")
@@ -52,7 +55,7 @@ def main(opts) -> None:
         str(output_dir / "label_mask.npy"), dtype=np.bool_, mode="w+", shape=(total_tokens,)
     )
     offset = 0
-    for ex in track(preprocessed):
+    for ex in track(dataset):
         ex_len = len(ex["input_ids"])  # type: ignore
         input_ids_file[offset : offset + ex_len] = ex["input_ids"]  # type: ignore
         label_mask_file[offset : offset + ex_len] = ex["label_mask"]  # type: ignore
         offset += ex_len
@@ -63,6 +66,13 @@ def main(opts) -> None:
     log.info("Done!")


+def filter(example):
+    for msg in example["messages"]:
+        if msg["role"] == "assistant":
+            return True
+    return False
+
+
 def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):

From 437c838c98755642d1e30db2b129db3d2bebe2d3 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Tue, 16 Jan 2024 19:42:17 -0800
Subject: [PATCH 05/13] Clone then in place ops
---
 olmo/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/olmo/train.py b/olmo/train.py
index dc81a08ec..038ece47c 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -471,14 +471,14 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec

     def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
         # Labels are just input IDs shifted to the left (first item is ignored).
         labels, label_mask, attention_mask = (
-            batch["input_ids"],
+            batch["input_ids"].clone(),
             batch.get("label_mask"),
             batch.get("attention_mask"),
         )
         if label_mask is not None:
-            labels = labels.masked_fill(~label_mask, -100)
+            labels.masked_fill_(~label_mask, -100)
         if attention_mask is not None:
-            labels = labels.masked_fill(attention_mask == 0.0, -100)
+            labels.masked_fill_(attention_mask == 0.0, -100)
         return labels[..., 1:].contiguous()

     def model_forward(

From 8514a82ea63ed7d134f81e1413eab896efeb68f0 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Tue, 16 Jan 2024 21:37:36 -0800
Subject: [PATCH 06/13] Fix filter
---
 scripts/prepare_tulu_data.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/scripts/prepare_tulu_data.py b/scripts/prepare_tulu_data.py
index b0cc884e5..4eba35945 100644
--- a/scripts/prepare_tulu_data.py
+++ b/scripts/prepare_tulu_data.py
@@ -26,9 +26,6 @@ def main(opts) -> None:

     dataset = ds.load_dataset("allenai/tulu-v2-sft-mixture", split="train")

-    log.info("Filtering dataset...")
-    dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc)  # type: ignore
-
     log.info("Tokenizing dataset...")
     dataset = dataset.map(
         partial(preprocess, tokenizer=tokenizer, max_seq_len=opts.seq_len),
@@ -37,12 +34,17 @@ def main(opts) -> None:
         num_proc=opts.num_proc,  # type: ignore
     )

+    log.info("Filtering dataset...")
+    n = len(dataset)  # type: ignore
+    dataset = dataset.filter(filter, batched=False, num_proc=opts.num_proc)  # type: ignore
+    log.info(f"Filtered out {n - len(dataset):,d} examples")
+
     log.info("Counting tokens...")
     total_tokens = 0
     for ex in track(dataset):
         assert len(ex["input_ids"]) == opts.seq_len  # type: ignore
         total_tokens += len(ex["input_ids"])  # type: ignore
-    log.info(f"{total_tokens:,d}")
+    log.info(f"Total tokens: {total_tokens:,d}")

     log.info(f"Saving results to '{opts.output_dir}'...")
     output_dir = Path(opts.output_dir)
@@ -67,10 +69,7 @@ def main(opts) -> None:

 def filter(example):
-    for msg in example["messages"]:
-        if msg["role"] == "assistant":
-            return True
-    return False
+    return example["n_labels"] > 0


 def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
@@ -104,8 +103,9 @@ def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
         label_mask += [False] * pad_len

     assert len(input_ids) == len(label_mask)
+    n_labels = sum(label_mask)

-    return {"input_ids": input_ids, "label_mask": label_mask}
+    return {"input_ids": input_ids, "label_mask": label_mask, "n_labels": n_labels}

 def get_parser() -> ArgumentParser:

From 2899798955a009e21d7504adb347439a0cf61c87 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Wed, 17 Jan 2024 10:20:30 -0800
Subject: [PATCH 07/13] Update README
---
 README.md                          | 14 +++++++++++---
 configs/mcli/mitchish-instruct.yml |  2 +-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index d17a777eb..48a37ba72 100644
--- a/README.md
+++ b/README.md
@@ -13,11 +13,17 @@ pip install ai2-olmo

 ## Fine-tuning

-To fine-tune an OLMo model you'll first need to prepare your dataset by tokenizing and saving it to a numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.
+To fine-tune an OLMo model you'll first need to prepare your dataset by tokenizing it and saving the token IDs to a flat numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.

-Next, prepare your training config. There are many examples in the [`configs/`](./configs) directory. Make sure the model parameters match up with the model you're fine-tuning. To be safe you can always start from the config that comes with the model checkpoint.
+Next, prepare your training config. There are many examples in the [`configs/`](./configs) directory that you can use as a starting point. The most important thing is to make sure the model parameters (the `model` field in the config) match up with the checkpoint you're starting from. To be safe you can always start from the config that comes with the model checkpoint. At a minimum you'll need to make the following changes to the config or provide the corresponding overrides from the command line:
+
+- Update `load_path` to point to the checkpoint you want to start from.
+- Set `reset_trainer_state` to `true`.
+- Update `data.paths` to point to the `input_ids.npy` file you generated.
+- Optionally update `data.label_mask_paths` to point to the `label_mask.npy` file you generated, if you need special masking for the loss.
+- Update `evaluators` to add/remove in-loop evaluations.

-Then launch the training job:
+Once you're satisfied with your training config, you can launch the training job via `torchrun`. For example:

@@ -26,3 +32,5 @@ torchrun --nproc_per_node=8 scripts/train.py {path_to_train_config} \
   --data.paths=[{path_to_data}/input_ids.npy] \
   --data.label_mask_paths=[{path_to_data}/label_mask.npy] \
   --load_path={path_to_checkpoint} \
   --reset_trainer_state
 ```
+
+Note: passing CLI overrides like `--reset_trainer_state` is only necessary if you didn't update those fields in your config.

diff --git a/configs/mcli/mitchish-instruct.yml b/configs/mcli/mitchish-instruct.yml
index 49a155f78..807b9bc7f 100644
--- a/configs/mcli/mitchish-instruct.yml
+++ b/configs/mcli/mitchish-instruct.yml
@@ -13,7 +13,7 @@ integrations:
     ssh_clone: true
 command: |-
   checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
-  learning_rate=2e-6
+  learning_rate=2e-5
   run_name=mitchish-mcli-2.5T-instruct-${learning_rate}-v2

From b69ea02478b3dd360c8bad1b0e2b15813691e4c6 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Thu, 18 Jan 2024 11:56:10 -0800
Subject: [PATCH 08/13] update configs to use optimizer state
---
 configs/mcli/mitchish-instruct.yml | 14 ++++++----
 configs/mitchish-instruct.yaml     | 42 +++---------------------------
 2 files changed, 12 insertions(+), 44 deletions(-)

diff --git a/configs/mcli/mitchish-instruct.yml b/configs/mcli/mitchish-instruct.yml
index 807b9bc7f..8596eda59 100644
--- a/configs/mcli/mitchish-instruct.yml
+++ b/configs/mcli/mitchish-instruct.yml
@@ -13,8 +13,8 @@ integrations:
     ssh_clone: true
 command: |-
   checkpoint=s3://olmo-checkpoints/ai2-llm/olmo-medium/wd2gxrza/step556000-unsharded
-  learning_rate=2e-5
-  run_name=mitchish-mcli-2.5T-instruct-${learning_rate}-v2
+  learning_rate=2e-6
+  run_name=mitchish-mcli-2.5T-instruct-${learning_rate}-5ep-v2

   # NOTE: For some reason getting S3 and R2 authentication working both from the command line and
   # from Python proved to be challenging, maybe because Mosaic's servers are in Australia.
@@ -38,7 +38,6 @@ command: |-

   # Prepare environment including AWS config files for both S3 and R2 access.
  mkdir -p /root/.cache/torch
   mkdir /root/checkpoint-unsharded
-  mkdir /root/data
   mkdir /root/.aws
   touch /root/.aws/credentials /root/.aws/config
   echo '[s3]' >> /root/.aws/credentials
@@ -75,6 +74,11 @@ command: |-
     --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
     "${checkpoint}/model.pt" /root/checkpoint-unsharded/

+  # Download optimizer state.
+  aws s3 cp --profile=r2 --region=auto \
+    --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
+    "${checkpoint}/optim.pt" /root/checkpoint-unsharded/
+
   # Now remove the aws configs so it doesn't mess with data loading / uploading checkpoints to/from S3.
   rm -rf /root/.aws
@@ -90,8 +94,8 @@ command: |-
     --save_overwrite \
     --save_interval_unsharded=10000 \
     --load_path=/root/checkpoint-unsharded \
-    --reset_optimizer_state \
     --reset_trainer_state \
     --compile=null \
     --activation_checkpointing=fine_grained \
-    --fsdp.wrapping_strategy=size_based
+    --fsdp.wrapping_strategy=size_based \
+    --max_duration=5ep

diff --git a/configs/mitchish-instruct.yaml b/configs/mitchish-instruct.yaml
index 7b52c2558..ad247e7bc 100644
--- a/configs/mitchish-instruct.yaml
+++ b/configs/mitchish-instruct.yaml
@@ -43,15 +43,15 @@ compile:
 optimizer:
   name: adamw
   learning_rate: 2e-5
-  weight_decay: 0.0
+  weight_decay: 0.1
   betas:
   - 0.9
-  - 0.999
+  - 0.95
   metrics_log_interval: 10

 scheduler:
   name: linear_with_warmup
-  t_warmup: 100
+  t_warmup: 200
   alpha_f: 0.001

 tokenizer:
@@ -91,42 +91,6 @@ eval_interval: ${save_interval}
 eval_subset_num_batches: -1
 device_eval_batch_size: ${device_train_microbatch_size}
 evaluators:
-  - label: all-small-ppl-validation
-    data:
-      num_workers: 0
-      drop_last: true
-      # pin_memory: true
-      # prefetch_factor: 1
-      # persistent_workers: false
-      # timeout: 0
-      datasets:
-        4chan-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/4chan/val.npy
-        c4_100_domains-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_100_domains/val.npy
-        c4_en-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/c4_en/val.npy
-        gab-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/gab/val.npy
-        ice-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ice/val.npy
-        m2d2_s2orc-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_s2orc/val.npy
-        m2d2_wiki-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/m2d2_wiki/val.npy
-        manosphere-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/manosphere/val.npy
-        mc4_en-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/mc4_en/val.npy
-        pile-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/pile/val.npy
-        ptb-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/ptb/val.npy
-        twitterAEE-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/twitterAEE/val.npy
-        wikitext_103-validation:
-          - s3://ai2-llm/eval-data/perplexity/v2_small_gptneox20b/wikitext_103/val.npy
-
   ##########################
   # Downstream evaluations #
   ##########################

From cfbb68f1e6d63882e57a6fd02b7d8584b8ff4d20 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Fri, 19 Jan 2024 12:21:25 -0800
Subject: [PATCH 09/13] Update config
---
 configs/mcli/mitchish-instruct.yml | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/configs/mcli/mitchish-instruct.yml b/configs/mcli/mitchish-instruct.yml
index 8596eda59..84fa7b57d 100644
--- a/configs/mcli/mitchish-instruct.yml
+++ b/configs/mcli/mitchish-instruct.yml
@@ -75,9 +75,9 @@ command: |-
     "${checkpoint}/model.pt" /root/checkpoint-unsharded/

   # Download optimizer state.
-  aws s3 cp --profile=r2 --region=auto \
-    --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
-    "${checkpoint}/optim.pt" /root/checkpoint-unsharded/
+  #aws s3 cp --profile=r2 --region=auto \
+  #  --endpoint-url=https://a198dc34621661a1a66a02d6eb7c4dc3.r2.cloudflarestorage.com \
+  #  "${checkpoint}/optim.pt" /root/checkpoint-unsharded/

   # Now remove the aws configs so it doesn't mess with data loading / uploading checkpoints to/from S3.
   rm -rf /root/.aws
@@ -91,11 +91,13 @@ command: |-
     scripts/train.py configs/mitchish-instruct.yaml \
     --run_name=${run_name} \
     --optimizer.learning_rate=${learning_rate} \
+    --scheduler.grad_clip_warmup_steps=400 \
     --save_overwrite \
-    --save_interval_unsharded=10000 \
+    --save_interval_unsharded=100000 \
     --load_path=/root/checkpoint-unsharded \
     --reset_trainer_state \
+    --reset_optimizer_state \
     --compile=null \
-    --activation_checkpointing=fine_grained \
+    --activation_checkpointing=whole_layer \
     --fsdp.wrapping_strategy=size_based \
     --max_duration=5ep

From f3298c683cbd7c45972e63db09e5063b5670fb97 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Fri, 19 Jan 2024 14:26:13 -0800
Subject: [PATCH 10/13] update branch
---
 configs/mcli/mitchish-instruct.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/mcli/mitchish-instruct.yml b/configs/mcli/mitchish-instruct.yml
index 84fa7b57d..59ee0c270 100644
--- a/configs/mcli/mitchish-instruct.yml
+++ b/configs/mcli/mitchish-instruct.yml
@@ -8,7 +8,7 @@ compute:
 integrations:
   - integration_type: git_repo
     git_repo: allenai/LLM
-    git_branch: epwalsh/fine-tune-with-label-masking
+    git_branch: main
     pip_install: -e .
     ssh_clone: true
 command: |-

From 3053bfaed647b36cdb5e57ba09ee2d0bf686ef0c Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Fri, 19 Jan 2024 15:00:11 -0800
Subject: [PATCH 11/13] Update install instructions in README
---
 README.md | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 48a37ba72..3e63e51dc 100644
--- a/README.md
+++ b/README.md
@@ -7,13 +7,25 @@

 ## Installation

+First install [PyTorch](https://pytorch.org) according to the instructions specific to your operating system.
+
+To install from source (recommended for training/fine-tuning) run:
+
+```bash
+git clone https://github.com/allenai/OLMo.git
+cd OLMo
+pip install -e .
 ```
+
+Otherwise you can install the model code by itself directly from PyPI with:
+
+```bash
 pip install ai2-olmo
 ```

 ## Fine-tuning

-To fine-tune an OLMo model you'll first need to prepare your dataset by tokenizing it and saving the token IDs to a flat numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.
+To fine-tune an OLMo model using our trainer you'll first need to prepare your dataset by tokenizing it and saving the token IDs to a flat numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.

 Next, prepare your training config. There are many examples in the [`configs/`](./configs) directory that you can use as a starting point. The most important thing is to make sure the model parameters (the `model` field in the config) match up with the checkpoint you're starting from. To be safe you can always start from the config that comes with the model checkpoint. At a minimum you'll need to make the following changes to the config or provide the corresponding overrides from the command line:

 - Update `load_path` to point to the checkpoint you want to start from.
 - Set `reset_trainer_state` to `true`.

From 9b5155d0a343863e8a3227b74b3b84d6579f6ce0 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 22 Jan 2024 11:45:18 -0800
Subject: [PATCH 12/13] Track S3 upload progress by file
---
 scripts/storage_cleaner.py | 22 +++++++++-------------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py
index 22a412d93..b63d1932d 100644
--- a/scripts/storage_cleaner.py
+++ b/scripts/storage_cleaner.py
@@ -8,7 +8,6 @@
 from argparse import ArgumentParser, _SubParsersAction
 from dataclasses import dataclass
 from enum import Enum, auto
-from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
@@ -22,7 +21,7 @@
 from cached_path.schemes import S3Client
 from google.api_core.exceptions import NotFound
 from omegaconf import OmegaConf as om
-from rich.progress import Progress, TaskID, track
+from rich.progress import track

 from olmo import util
 from olmo.aliases import PathOrStr
@@ -579,30 +578,27 @@ def download_folder(self, directory_path: str, local_dest_folder: PathOrStr):
         else:
             raise ValueError(f"Path {directory_path} is not a valid directory")

+    def _upload_file(self, local_filepath: str, bucket_name: str, key: str):
+        self._s3_client.upload_file(local_filepath, bucket_name, key)
+
     def upload(self, local_src: PathOrStr, dest_path: str):
         if self.local_fs_adapter.is_file(str(local_src)):
             bucket_name, key = self._get_bucket_name_and_key(dest_path)
-            self._s3_client.upload_file(str(local_src), bucket_name, key)
+            self._upload_file(str(local_src), bucket_name, key)
         elif self.local_fs_adapter.is_dir(str(local_src)):
             local_src = Path(local_src)

-            def upload_callback(progress: Progress, upload_task: TaskID, bytes_uploaded: int):
-                progress.update(upload_task, advance=bytes_uploaded)
-
-            for file_local_path in local_src.rglob("*"):
+            local_file_paths = list(local_src.rglob("*"))
+            for file_local_path in track(local_file_paths, description=f"Uploading to {dest_path}"):
                 if file_local_path.is_dir():
                     continue

                 file_dest_path = str(file_local_path).replace(str(local_src).rstrip("/"), dest_path.rstrip("/"))
                 bucket_name, key = self._get_bucket_name_and_key(file_dest_path)

-                with Progress(transient=True) as progress:
-                    size_in_bytes = file_local_path.stat().st_size
-                    upload_task = progress.add_task(f"Uploading {key}", total=size_in_bytes)
-                    callback = partial(upload_callback, progress, upload_task)
-
-                    self._s3_client.upload_file(str(file_local_path), bucket_name, key, Callback=callback)
+                if not self._is_file(bucket_name, key):
+                    self._upload_file(str(file_local_path), bucket_name, key)
         else:
             raise ValueError(f"Local source {local_src} does not correspond to a valid file or directory")

From 5c7d9c69e497e1a7908f4bac58b1cdf42cd31762 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Mon, 22 Jan 2024 11:47:04 -0800
Subject: [PATCH 13/13] Reduce number of concurrent S3 uploads to reduce throttling
---
 scripts/storage_cleaner.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/scripts/storage_cleaner.py b/scripts/storage_cleaner.py
index b63d1932d..ca6e61a05 100644
--- a/scripts/storage_cleaner.py
+++ b/scripts/storage_cleaner.py
@@ -17,6 +17,7 @@
 import google.cloud.storage as gcs
 import torch
 import wandb
+from boto3.s3.transfer import TransferConfig
 from cached_path import add_scheme_client, cached_path, set_cache_dir
 from cached_path.schemes import S3Client
 from google.api_core.exceptions import NotFound
@@ -579,7 +580,8 @@ def download_folder(self, directory_path: str, local_dest_folder: PathOrStr):
             raise ValueError(f"Path {directory_path} is not a valid directory")

     def _upload_file(self, local_filepath: str, bucket_name: str, key: str):
-        self._s3_client.upload_file(local_filepath, bucket_name, key)
+        transfer_config = TransferConfig(max_concurrency=4)
+        self._s3_client.upload_file(local_filepath, bucket_name, key, Config=transfer_config)

     def upload(self, local_src: PathOrStr, dest_path: str):
         if self.local_fs_adapter.is_file(str(local_src)):
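
A note on the mechanics these patches implement: the label mask marks which token positions should contribute to the fine-tuning loss, and `Trainer.get_labels` turns every masked-out or padded position into the value `-100`, which is the default `ignore_index` of PyTorch's cross-entropy loss. The following is a minimal, self-contained sketch of that flow, not code from the patches themselves; the token IDs, pad ID of 1, and the tiny vocab size are made-up illustrative values.

```python
import torch
import torch.nn.functional as F

# Toy batch: positions 0-2 are the prompt, 3-4 the assistant completion, 5-6 padding (pad ID 1).
input_ids = torch.tensor([[5, 11, 7, 42, 9, 1, 1]])
label_mask = torch.tensor([[False, False, False, True, True, False, False]])
attention_mask = (input_ids != 1).to(torch.float)

# Mirrors Trainer.get_labels: clone first so the batch's input_ids aren't modified
# in place (the bug fixed in patch 05), then mark masked and padded positions.
labels = input_ids.clone()
labels.masked_fill_(~label_mask, -100)
labels.masked_fill_(attention_mask == 0.0, -100)
labels = labels[..., 1:].contiguous()  # shift left; the first token has no prediction target

# Stand-in for model logits with shape (batch, seq_len - 1, vocab); vocab size 64 is arbitrary.
logits = torch.randn(1, input_ids.shape[1] - 1, 64)

# -100 is F.cross_entropy's default ignore_index, so only unmasked positions contribute.
loss = F.cross_entropy(logits.view(-1, 64), labels.view(-1))
print(loss)
```

Only the two completion positions (token IDs 42 and 9) end up with valid labels here; everything else is ignored, which is why fine-tuning with `label_mask_paths` trains the model on assistant responses without penalizing it for prompt or padding tokens.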