Commit

Merge pull request #295 from RobotSail/lora-4
Implement saving FSDP with LoRA
mergify[bot] authored Nov 13, 2024
2 parents ff36e64 + b945f56 commit 8a49747
Showing 5 changed files with 244 additions and 60 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
@@ -474,7 +474,8 @@ disable=raw-checker-failed,
broad-exception-caught,
super-init-not-called,
duplicate-code,
too-many-positional-arguments
too-many-positional-arguments,
too-many-lines

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
51 changes: 11 additions & 40 deletions src/instructlab/training/main_ds.py
@@ -68,8 +68,8 @@
check_flash_attn_enabled,
check_valid_train_args,
convert_loss_to_reduce_sum,
create_lora_config,
ensure_loadable_dolomite_checkpoint,
get_projection_layer_names,
load_latest_full_state,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
@@ -136,13 +136,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
args.model_name_or_path, args.output_dir
) as path:
base_model_args["pretrained_model_name_or_path"] = path
base_model_args["use_padding_free_transformer"] = True
model = GPTDolomiteForCausalLM.from_pretrained(
**base_model_args,
use_padding_free_transformer=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(**base_model_args)

# store the base model args so we can recall them later if saving a LoRA model
args.base_model_args = base_model_args

if len(tokenizer) > model.config.vocab_size:
print(
f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
@@ -198,46 +201,14 @@ def setup_model(args, tokenizer, train_loader, grad_accum, flash_enabled):
# - with the exception of granite, which handles it
# in the later stanza
if args.lora_r > 0:
# if lora
# Third Party
from peft import LoraConfig

# ensure we select only the modules that exist in the model
proj_layers = get_projection_layer_names(model)
if not args.lora_target_modules:
print(
f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
)
if not proj_layers:
raise RuntimeError("could not find any projection layers in the model")
args.__dict__["lora_target_modules"] = proj_layers
else:
# when the user specifies the module, we should verify that they align with what's in the model
lora_target_modules_set = set(args.lora_target_modules)
diff = lora_target_modules_set - set(proj_layers)
layers_to_target = lora_target_modules_set - diff
if len(diff) == len(args.lora_target_modules):
raise ValueError(
f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
)
if diff:
print(
f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
)
args.__dict__["lora_target_modules"] = list(layers_to_target)

peft_config = LoraConfig(
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
r=args.lora_r,
bias="none",
task_type="CAUSAL_LM",
target_modules=args.lora_target_modules,
)
lora_config = create_lora_config(model, args)
model = prepare_peft_model(
model, peft_config, gradient_checkpointing=not args.use_dolomite
model,
lora_config,
args.distributed_training_framework,
gradient_checkpointing=not args.use_dolomite,
)

args.lora_config = lora_config
elif not args.use_dolomite:
model.gradient_checkpointing_enable()

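The inline LoRA setup removed above (projection-layer discovery, target-module validation, and construction of the LoraConfig) is replaced by a single call to create_lora_config, imported with the other training utilities at the top of main_ds.py. The helper's body is not part of this diff; the sketch below is a reconstruction inferred from the removed lines, and the instructlab.training.utils import path and the exact signature are assumptions rather than the code as merged.

# Hypothetical reconstruction of create_lora_config; the real helper is not
# shown in this diff and may differ in detail.
# Third Party
from peft import LoraConfig

# First Party (assumed module path)
from instructlab.training.utils import get_projection_layer_names


def create_lora_config(model, args) -> LoraConfig:
    # Restrict LoRA to projection modules that actually exist in the model.
    proj_layers = get_projection_layer_names(model)
    if not args.lora_target_modules:
        print(
            "WARNING: lora_target_modules was not specified, defaulting to all "
            "of the model's projection modules"
        )
        if not proj_layers:
            raise RuntimeError("could not find any projection layers in the model")
        target_modules = proj_layers
    else:
        # Keep only the requested modules that are present in the model.
        requested = set(args.lora_target_modules)
        missing = requested - set(proj_layers)
        target_modules = list(requested - missing)
        if len(missing) == len(requested):
            raise ValueError(
                f"None of the requested modules exist in the model. "
                f"Requested: {args.lora_target_modules}; available: {proj_layers}. "
                "Consider omitting lora_target_modules so they are discovered automatically."
            )
        if missing:
            print(
                f"WARNING: modules {list(missing)} are not present in the model; "
                f"applying LoRA only to {target_modules}."
            )
    return LoraConfig(
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        r=args.lora_r,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )

With the helper factored out, setup_model only needs to build the config, keep it on args.lora_config for the save path, and hand it to prepare_peft_model along with the distributed training framework, which is what the surviving lines above do.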
38 changes: 27 additions & 11 deletions src/instructlab/training/setup_accelerator.py
@@ -3,12 +3,10 @@

# Third Party
from accelerate import Accelerator
from torch.distributed.fsdp import ( # FullyShardedDataParallel as FSDP,
BackwardPrefetch,
MixedPrecision,
ShardingStrategy,
)
from peft.utils.other import fsdp_auto_wrap_policy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import PreTrainedModel
import torch

# First Party
@@ -51,34 +49,52 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption
return ds_plugin


def get_fsdp_config(args, model):
def get_fsdp_config(args, model: PreTrainedModel):
# Third Party
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload

is_lora = args.lora_r > 0
block_name = model._no_split_modules[0]

fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=partial(
wrap_policy = None
if is_lora:
wrap_policy = fsdp_auto_wrap_policy(model)
else:
wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
get_module_class_from_name(model, block_name),
},
),
)

# TODO(osilkin): BACKWARD_POST trades memory utilization for processing time, which is important for systems utilizing LoRA
# We should have this be configurable in the future.
prefetch_policy = (
BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrap_policy,
limit_all_gathers=True,
mixed_precision_policy=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
backward_prefetch=prefetch_policy,
sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy],
cpu_offload=CPUOffload(args.cpu_offload_params_fsdp),
)

# `use_orig_params` must be disabled when using LoRA and FSDP together
# Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
if args.lora_r > 0:
fsdp_plugin.use_orig_params = False

return fsdp_plugin


def setup_accelerator(args, model, grad_accum):
def setup_accelerator(args, model: PreTrainedModel, grad_accum):
if args.distributed_training_framework == "deepspeed":
# Third Party
from deepspeed import DeepSpeedEngine
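For context, the plugin built by get_fsdp_config is what ultimately gets handed to Accelerate. The sketch below shows one way to wire it up; the build_accelerator name and the exact keyword arguments are illustrative assumptions, not lines from this commit.

# Usage sketch only; not part of this commit. The plugin configured above
# (LoRA-aware wrap policy, BACKWARD_POST prefetch for LoRA runs, and
# use_orig_params disabled when lora_r > 0) is passed to accelerate's Accelerator.
# Third Party
from accelerate import Accelerator


def build_accelerator(args, model, grad_accum):  # hypothetical helper name
    # get_fsdp_config is the function defined in the diff above.
    fsdp_plugin = get_fsdp_config(args, model)
    return Accelerator(
        fsdp_plugin=fsdp_plugin,  # carries the wrap, prefetch, and sharding settings
        gradient_accumulation_steps=grad_accum,
    )

A subsequent accelerator.prepare(...) call then wraps the model with FSDP according to those settings before training starts.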
(Diffs for the remaining 2 changed files are not shown.)
