diff --git a/.pylintrc b/.pylintrc index aff36d65..c0fbd95c 100644 --- a/.pylintrc +++ b/.pylintrc @@ -474,7 +474,8 @@ disable=raw-checker-failed, broad-exception-caught, super-init-not-called, duplicate-code, - too-many-positional-arguments + too-many-positional-arguments, + too-many-lines # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index d6f9eeb2..0ad54c5f 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -68,8 +68,8 @@ check_flash_attn_enabled, check_valid_train_args, convert_loss_to_reduce_sum, + create_lora_config, ensure_loadable_dolomite_checkpoint, - get_projection_layer_names, load_latest_full_state, prepare_peft_model, prepare_universal_checkpoint_from_latest, @@ -136,13 +136,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled): args.model_name_or_path, args.output_dir ) as path: base_model_args["pretrained_model_name_or_path"] = path + base_model_args["use_padding_free_transformer"] = True model = GPTDolomiteForCausalLM.from_pretrained( **base_model_args, - use_padding_free_transformer=True, ) else: model = AutoModelForCausalLM.from_pretrained(**base_model_args) + # store the base model args so we can recall them later if saving a LoRA model + args.base_model_args = base_model_args + if len(tokenizer) > model.config.vocab_size: print( f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size" @@ -198,46 +201,14 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled): # - with the exception of granite, which handles it # in the later stanza if args.lora_r > 0: - # if lora - # Third Party - from peft import LoraConfig - - # ensure we select only the modules that exist in the model - proj_layers = get_projection_layer_names(model) - if not args.lora_target_modules: - print( - f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules" - ) - if not proj_layers: - raise RuntimeError("could not find any projection layers in the model") - args.__dict__["lora_target_modules"] = proj_layers - else: - # when the user specifies the module, we should verify that they align with what's in the model - lora_target_modules_set = set(args.lora_target_modules) - diff = lora_target_modules_set - set(proj_layers) - layers_to_target = lora_target_modules_set - diff - if len(diff) == len(args.lora_target_modules): - raise ValueError( - f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically." - ) - if diff: - print( - f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. 
Applying LoRA only to {list(layers_to_target)} modules.\033[0m" - ) - args.__dict__["lora_target_modules"] = list(layers_to_target) - - peft_config = LoraConfig( - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - r=args.lora_r, - bias="none", - task_type="CAUSAL_LM", - target_modules=args.lora_target_modules, - ) + lora_config = create_lora_config(model, args) model = prepare_peft_model( - model, peft_config, gradient_checkpointing=not args.use_dolomite + model, + lora_config, + args.distributed_training_framework, + gradient_checkpointing=not args.use_dolomite, ) - + args.lora_config = lora_config elif not args.use_dolomite: model.gradient_checkpointing_enable() diff --git a/src/instructlab/training/setup_accelerator.py b/src/instructlab/training/setup_accelerator.py index 33972b59..239367f4 100644 --- a/src/instructlab/training/setup_accelerator.py +++ b/src/instructlab/training/setup_accelerator.py @@ -3,12 +3,10 @@ # Third Party from accelerate import Accelerator -from torch.distributed.fsdp import ( # FullyShardedDataParallel as FSDP, - BackwardPrefetch, - MixedPrecision, - ShardingStrategy, -) +from peft.utils.other import fsdp_auto_wrap_policy +from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformers import PreTrainedModel import torch # First Party @@ -51,34 +49,52 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption return ds_plugin -def get_fsdp_config(args, model): +def get_fsdp_config(args, model: PreTrainedModel): # Third Party from accelerate.utils import FullyShardedDataParallelPlugin from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload + is_lora = args.lora_r > 0 block_name = model._no_split_modules[0] - fsdp_plugin = FullyShardedDataParallelPlugin( - auto_wrap_policy=partial( + wrap_policy = None + if is_lora > 0: + wrap_policy = fsdp_auto_wrap_policy(model) + else: + wrap_policy = partial( transformer_auto_wrap_policy, transformer_layer_cls={ get_module_class_from_name(model, block_name), }, - ), + ) + + # TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA + # We should have this be configurable in the future. 
+ prefetch_policy = ( + BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE + ) + fsdp_plugin = FullyShardedDataParallelPlugin( + auto_wrap_policy=wrap_policy, limit_all_gathers=True, mixed_precision_policy=MixedPrecision( param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.bfloat16, ), - backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + backward_prefetch=prefetch_policy, sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy], cpu_offload=CPUOffload(args.cpu_offload_params_fsdp), ) + + # `use_orig_params` must be disabled when using LoRA and FSDP together + # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts + if args.lora_r > 0: + fsdp_plugin.use_orig_params = False + return fsdp_plugin -def setup_accelerator(args, model, grad_accum): +def setup_accelerator(args, model: PreTrainedModel, grad_accum): if args.distributed_training_framework == "deepspeed": # Third Party from deepspeed import DeepSpeedEngine diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index b6f655bf..6a9d6f84 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from argparse import Namespace from collections import OrderedDict from contextlib import contextmanager from copy import deepcopy from functools import partial from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import importlib import inspect import logging @@ -21,7 +22,7 @@ # Third Party # pylint: disable=no-name-in-module -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from instructlab.dolomite.hf_models import ( GPTDolomiteConfig, export_to_huggingface, @@ -29,19 +30,27 @@ ) from rich.logging import RichHandler from torch import distributed as dist +from torch import nn from torch.distributed import get_rank, is_initialized from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper, ) -from transformers import PreTrainedModel +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType +from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer import numpy as np import torch import torch.nn.functional as F # First Party -from instructlab.training.config import TrainingArgs +from instructlab.training.config import ( + DistributedBackend, + QuantizeDataType, + TrainingArgs, +) def check_valid_train_args(train_args: TrainingArgs): @@ -71,6 +80,25 @@ def check_valid_train_args(train_args: TrainingArgs): "\033[33m WARNING: is_padding_free is being deprecated due to adoption of the default padding-free support in Hugging Face Transformers. As such, this flag is non-functional in 0.6.0 and beyond. If you would like to use the older Dolomite padding-free implementation, please set use_dolomite moving forward.\033[0m" ) + if ( + train_args.accelerate_full_state_at_epoch + and train_args.lora + and train_args.lora.rank > 0 + ): + raise ValueError( + "`accelerate_full_state_at_epoch` is not currently supported when training LoRA models." 
+ ) + + if ( + train_args.lora + and train_args.lora.rank > 0 + and train_args.lora.quantize_data_type != QuantizeDataType.NONE + and train_args.distributed_backend == DistributedBackend.FSDP.value + ): + raise ValueError( + "Quantization is not supported when training LoRA models with FSDP. For quantized LoRA training, please switch to DeepSpeed." + ) + def retrieve_chat_template(chat_tmpl_path): try: @@ -403,9 +431,133 @@ def patch_target_module( setattr(source, obj_name_to_patch, replace_with) +def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool: + """Checks if a module or its children are an instance of one of the provided classes. + + Args: + module (nn.Module): A PyTorch module. + wrapped_classes(Tuple): A tuple of potential classes the module could be. + + Returns: + bool: True if the module or any of its children are instances of one of `wrapped_classes`, False otherwise. + """ + if isinstance(module, wrapped_classes): + return True + + for m in module.children(): + if wraps(m, wrapped_classes): + return True + + return False + + +def create_lora_config(model: PreTrainedModel, args: Namespace) -> "peft.LoraConfig": + # if lora + # Third Party + from peft import LoraConfig + + # ensure we select only the modules that exist in the model + proj_layers = get_projection_layer_names(model) + if not args.lora_target_modules: + print( + f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules" + ) + if not proj_layers: + raise RuntimeError("could not find any projection layers in the model") + args.__dict__["lora_target_modules"] = proj_layers + else: + # when the user specifies the module, we should verify that they align with what's in the model + lora_target_modules_set = set(args.lora_target_modules) + diff = lora_target_modules_set - set(proj_layers) + layers_to_target = lora_target_modules_set - diff + if len(diff) == len(args.lora_target_modules): + raise ValueError( + f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically." + ) + if diff: + print( + f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m" + ) + args.__dict__["lora_target_modules"] = list(layers_to_target) + + return LoraConfig( + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + r=args.lora_r, + bias="none", + task_type="CAUSAL_LM", + target_modules=args.lora_target_modules, + ) + + +def save_fsdp_lora_model( + args: Namespace, + model: FSDP, + tokenizer: PreTrainedTokenizer, + accelerator: Accelerator, + output_dir: Path, +): + """Given a LoRA model wrapped by FSDP and Accelerate, save a full copy of the original + model with the trained LoRA adapters merged into the copy. + + This function creates a full copy of the model being trained and stores it in CPU memory. + If encountering OOM errors on CPU, this is likely a culprit. + + Args: + args (Namespace): Args received by the ArgumentParser. + model (FSDP): FSDP model as prepared by `accelerate.Accelerator` + accelerator (Accelerator): The given accelerator object. 
+ """ + # Third Party + from peft import LoraConfig, LoraModel + + if accelerator.distributed_type != DistributedType.FSDP: + raise RuntimeError( + "`save_fsdp_lora_model` was called when FSDP was not being used." + ) + if not wraps(model, FSDP): + raise RuntimeError( + "`save_fsdp_lora_model` was called but provided model is not an FSDP model." + ) + if not wraps(model, LoraModel): + raise RuntimeError( + "`save_fsdp_lora_model` was called but provided model is not a LoRA model." + ) + + # okay now that validation is out of the way, we are free to implement saving + lora_conf: LoraConfig = args.lora_config + sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, sd_config): + state = model.state_dict() + + # When training a LoRA with FSDP and Accelerate, you cannot directly merge the adapters into + # the model wrapped by FSDP. To get around this limitation, we get a copy of the state dict + # create an identical model on CPU, load the state dict into the CPU model, merge the adapters + # and save the model to disk. + if accelerator.is_main_process: + # remove device_map from args list so we can load the model on CPU + old_device_map = args.base_model_args.pop("device_map", None) + model_copy = AutoModelForCausalLM.from_pretrained( + **args.base_model_args, device_map="cpu" + ) + model_copy = LoraModel(model_copy, lora_conf, "default") + model_copy.load_state_dict(state) + model_copy.merge_and_unload(progressbar=True) + model_copy.save_pretrained(output_dir, safe_serialization=True) + model.config.to_json_file(f"{output_dir}/config.json") + tokenizer.save_pretrained(output_dir) + del model_copy + if old_device_map: + # return the previous device_map so it can be used later on if needed + args.base_model_args["device_map"] = old_device_map + + dist.barrier() + + def prepare_peft_model( - model, + model: PreTrainedModel, peft_config, + distributed_backend: str, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": True}, mixed_precision="bf16", @@ -413,6 +565,7 @@ def prepare_peft_model( # will guard this # Third Party from peft import ( + LoraModel, PeftConfig, PeftModel, get_peft_model, @@ -454,7 +607,11 @@ def make_inputs_require_grad(module, input, output): make_inputs_require_grad ) - model = get_peft_model(model, peft_config) + if distributed_backend == DistributedBackend.FSDP.value: + # FSDP doesn't like `get_peft_model` as it leads to dtype mismatches + model = LoraModel(model, peft_config, "default") + else: + model = get_peft_model(model, peft_config) if mixed_precision == "bf16" and getattr(model, "is_loaded_in_4bit", False): peft_module_casting_to_bf16(model) @@ -729,7 +886,7 @@ def _copy_no_lora_dict(state_dict): def save_dict_accelerate( - accelerator, + accelerator: Accelerator, state_to_save, save_directory, max_shard_size="5GB", @@ -784,6 +941,18 @@ def save_hf_format_accelerate( CONFIG_NAME = "config.json" output_config_file = output_dir / CONFIG_NAME + # XXX(osilkin): LoRA + FSDP requires a different saving path than the others + # so we set this variable and use it to avoid those paths further down. 
+ is_fsdp_lora = is_lora and accelerator.distributed_type == DistributedType.FSDP + if is_fsdp_lora: + save_fsdp_lora_model( + args=args, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + output_dir=output_dir, + ) + get_state_dict_unpatched = accelerator.get_state_dict def _get_state_dict_patched(model, unwrap=False): @@ -791,7 +960,7 @@ def _get_state_dict_patched(model, unwrap=False): accelerator.get_state_dict = _get_state_dict_patched - if accelerator.is_main_process: + if not is_fsdp_lora and accelerator.is_main_process: if is_lora: model.module.merge_adapter() model_state = model.module.state_dict() diff --git a/tests/smoketest.sh b/tests/smoketest.sh index 9bdb0df4..6918fb03 100755 --- a/tests/smoketest.sh +++ b/tests/smoketest.sh @@ -19,6 +19,7 @@ NUM_GPUS="${2:-${DEFAULT_GPUS}}" # ############### User-modifiable parameters ############### # Change these as needed MAX_BATCH_LEN=60000 +MAX_SEQ_LEN=4096 NUM_SAMPLES_TRAINED_ON=5000 # upper-bound on training dataset size. # ############### Test Functions ############### @@ -191,6 +192,31 @@ function test_standard_loop_noflashattention_nogranite () { # --is_granite } + +############################################################################## +# Validates the pathing logic for FSDP & LoRA. +# A valid run should result in a model with all adapters merged +# with the base model. +############################################################################## +function test_standard_loop_fsdp_lora() { + torchrun \ + --standalone \ + --nproc_per_node="${NUM_GPUS}" \ + main_ds.py \ + --model_name_or_path="${MODEL_NAME}" \ + --data_path="${COMPUTED_DATA_PATH}" \ + --output_dir="${CHECKPOINTS_DIR}" \ + --num_epochs=1 \ + --effective_batch_size=128 \ + --save_samples=0 \ + --checkpoint_at_epoch \ + --distributed_training_framework="${DISTRIB_FRAMEWORK}" \ + --max_batch_len="${MAX_BATCH_LEN}" \ + --lora_r=4 \ + --lora_alpha=32 \ + --lora_dropout=0.1 +} + function main () { setup_tmpdir @@ -207,6 +233,7 @@ function main () { test_standard_loop_nongranite _cleanup_saved_checkpoints test_standard_loop + test_standard_loop_fsdp_lora } main
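The hunks above move LoRA target-module discovery out of `setup_model` and into `create_lora_config` in utils.py. Below is a minimal sketch of how that helper is expected to be called; the model name and LoRA hyperparameters are illustrative assumptions, not values taken from the patch.

# Sketch: calling create_lora_config with an empty lora_target_modules list so the helper
# discovers the model's projection layers itself. Model and hyperparameters are illustrative.
from argparse import Namespace

from transformers import AutoModelForCausalLM

from instructlab.training.utils import create_lora_config

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")  # any small causal LM
args = Namespace(
    lora_r=4,
    lora_alpha=32,
    lora_dropout=0.1,
    lora_target_modules=[],  # empty -> default to every projection layer found in the model
)
lora_config = create_lora_config(model, args)
print(lora_config.target_modules)  # e.g. {'q_proj', 'k_proj', 'v_proj', 'o_proj', ...}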
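`prepare_peft_model` now takes the distributed backend and wraps the model with `peft.LoraModel` directly when FSDP is selected, instead of `get_peft_model`. The sketch below contrasts the two wrapping paths and uses the new `wraps` helper from utils.py to confirm the LoRA layers are present either way; the tiny model and target modules are again illustrative assumptions.

# Sketch of the two wrapping paths prepare_peft_model chooses between.
from peft import LoraConfig, LoraModel, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

from instructlab.training.utils import wraps

config = LoraConfig(
    r=4,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
)

# FSDP branch: wrap with LoraModel directly; the result is a plain nn.Module, not a PeftModel.
fsdp_style = LoraModel(
    AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M"), config, "default"
)
assert wraps(fsdp_style, (LoraModel,)) and not isinstance(fsdp_style, PeftModel)

# DeepSpeed / default branch: get_peft_model returns a PeftModel wrapper.
deepspeed_style = get_peft_model(
    AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M"), config
)
assert isinstance(deepspeed_style, PeftModel)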
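On the accelerator side, `get_fsdp_config` switches to PEFT's `fsdp_auto_wrap_policy`, BACKWARD_POST prefetch, and `use_orig_params=False` whenever `lora_r > 0`. A condensed sketch of the plugin those branches produce for a LoRA run follows; the function name and the sharding-strategy default are placeholders, not part of the patch.

# Condensed sketch of the FSDP plugin configuration selected for LoRA runs.
import torch
from accelerate.utils import FullyShardedDataParallelPlugin
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload


def lora_fsdp_plugin(model, sharding_strategy: str = "SHARD_GRAD_OP"):
    plugin = FullyShardedDataParallelPlugin(
        auto_wrap_policy=fsdp_auto_wrap_policy(model),  # wrap LoRA layers separately from blocks
        limit_all_gathers=True,
        mixed_precision_policy=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        # BACKWARD_POST trades some backward-pass speed for lower memory, which matters for LoRA.
        backward_prefetch=BackwardPrefetch.BACKWARD_POST,
        sharding_strategy=ShardingStrategy[sharding_strategy],
        cpu_offload=CPUOffload(offload_params=False),
    )
    plugin.use_orig_params = False  # required for LoRA + FSDP per the PEFT docs
    return plugin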
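`save_fsdp_lora_model` begins by collecting the full, unsharded state dict with CPU offload on rank 0 before rebuilding and merging the model. That gather pattern in isolation looks like the sketch below, assuming `fsdp_model` is already FSDP-wrapped inside an initialized process group.

# Generic sketch of the rank-0, CPU-offloaded state-dict gather used by save_fsdp_lora_model.
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def full_cpu_state_dict(fsdp_model: FSDP) -> dict:
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        # On rank 0 this returns full tensors resident on CPU; other ranks may get an empty dict.
        return fsdp_model.state_dict()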
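Because the merged copy is written with `save_pretrained(..., safe_serialization=True)` along with the tokenizer and config, the resulting directory should load as a plain Hugging Face checkpoint with no PEFT dependency. The path below is a placeholder for whatever output directory a given run produced.

# Illustrative load of a checkpoint produced by the FSDP + LoRA saving path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "/path/to/output_dir/hf_format/samples_0"  # placeholder
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)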