Commit
Merge pull request #235 from JamesKunstle/fullstate_saving
Adds Accelerate full-state checkpointing (optimizer, lr_scheduler, and model parameters).
JamesKunstle authored Oct 1, 2024
2 parents 386c6c0 + 07aa037 commit 8b252d8
Showing 3 changed files with 153 additions and 14 deletions.
3 changes: 2 additions & 1 deletion src/instructlab/training/config.py
@@ -169,7 +169,8 @@ class TrainingArgs(BaseModel):
warmup_steps: int
is_padding_free: bool
random_seed: int = 42
checkpoint_at_epoch: bool = False
checkpoint_at_epoch: bool = True
accelerate_full_state_at_epoch: bool = True

mock_data: Optional[bool] = False
mock_data_len: int = 0
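
For orientation (not part of the diff): a minimal sketch of the two per-epoch checkpointing knobs in TrainingArgs after this change. The TrainingArgs construction is hypothetical, and all other required fields are assumed to be supplied as before.

# Sketch only: checkpointing options as they stand after this PR.
checkpoint_options = {
    "checkpoint_at_epoch": True,             # now defaults to True: write an HF-format checkpoint each epoch
    "accelerate_full_state_at_epoch": True,  # new field: also save optimizer/lr_scheduler/parameter state via Accelerate
}
# train_args = TrainingArgs(..., **checkpoint_options)  # hypothetical usage; remaining fields omitted
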
53 changes: 40 additions & 13 deletions src/instructlab/training/main_ds.py
@@ -43,9 +43,11 @@
apply_gradient_checkpointing,
convert_loss_to_reduce_sum,
ensure_loadable_granite_checkpoint,
load_latest_full_state,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
retrieve_chat_template,
save_checkpoint,
save_hf_format_accelerate,
set_random_seed,
setup_logger,
@@ -316,6 +318,10 @@ def train(
batch_size = args.effective_batch_size // grad_accum
samples_seen = 0

if hasattr(args, "samples_seen"):
print(f"\033[93mUpdating 'samples_seen' {args.samples_seen}\033[0m")
samples_seen = args.samples_seen

if args.save_samples > 0:
args.save_samples = (args.save_samples // batch_size) * batch_size
(
@@ -335,7 +341,7 @@
)

global_grad_norm = None
for epoch in range(args.num_epochs):
for epoch in range(args.current_epoch, args.num_epochs):
if args.sampler in ("multipack"):
train_loader.batch_sampler.set_epoch(epoch)
elif args.sampler in ("distributed"):
@@ -346,6 +352,7 @@
if local_rank == 0:
inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}")

# blast through the batches in the train loader up to the last step within the epoch.
for batch in train_loader:
if global_step <= args.last_step:
# in the case of resuming, last_step > 0
@@ -437,13 +444,14 @@ def train(
if args.save_samples > 0 and (
global_step * batch_size % args.save_samples == 0
):
save_hf_format_accelerate(
args,
model,
tokenizer,
accelerator,
samples_seen,
save_checkpoint(
args=args,
accelerator=accelerator,
model=model,
tokenizer=tokenizer,
samples_seen=samples_seen,
is_lora=bool(args.lora_r),
hf_format=True,
)

# if (
@@ -461,13 +469,16 @@ def train(
inner_pb.update(1)
torch.cuda.empty_cache()
if args.checkpoint_at_epoch:
save_hf_format_accelerate(
args,
model,
tokenizer,
accelerator,
samples_seen,
save_checkpoint(
args=args,
accelerator=accelerator,
model=model,
tokenizer=tokenizer,
samples_seen=samples_seen,
is_lora=bool(args.lora_r),
full_state=args.accelerate_full_state_at_epoch,
hf_format=True,
epoch=epoch,
)

if args.save_last:
@@ -588,6 +599,8 @@ def main(args):
args, tokenizer, train_loader, grad_accum
)

load_latest_full_state(args=args, accelerator=accelerator)

train(
args,
model,
@@ -661,6 +674,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.checkpoint_at_epoch:
command.append("--checkpoint_at_epoch")

if train_args.accelerate_full_state_at_epoch:
command.append("--accelerate_full_state_at_epoch")

if train_args.mock_data:
command.append("--mock_data")
if train_args.mock_len:
@@ -775,6 +791,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
parser.add_argument("--data_path", type=str)
parser.add_argument("--output_dir", type=str)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument(
"--current_epoch",
type=int,
default=0,
help="Helpful flag for resuming on a later epoch. Sets dataloader correctly.",
)
parser.add_argument(
"--last_step",
type=int,
@@ -820,6 +842,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
action="store_true",
help="Save a model checkpoint after finishing an epoch.",
)
parser.add_argument(
"--accelerate_full_state_at_epoch",
action="store_true",
help="Save full model state using Accelerate after finishing an epoch.",
)
parser.add_argument("--log_level", type=str, default="INFO")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--mock_data", action="store_true")
111 changes: 111 additions & 0 deletions src/instructlab/training/utils.py
@@ -781,3 +781,114 @@ def set_random_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def save_checkpoint(
args,
accelerator: Accelerator,
model,
tokenizer,
samples_seen,
is_lora: bool,
epoch: int = None,
hf_format: bool = True,
full_state: bool = False,
) -> None:
if hf_format:
save_hf_format_accelerate(
args=args,
model=model,
accelerator=accelerator,
tokenizer=tokenizer,
samples_seen=samples_seen,
is_lora=is_lora,
)

if full_state:
save_full_state(
args=args,
accelerator=accelerator,
is_lora=is_lora,
epoch=epoch,
samples_seen=samples_seen,
)


def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int):
"""
Saves model, optimizer, and lr_scheduler state.
TODO: save model config - decided not to do this.
TODO: save tokenizer - decided not to do this.
TODO: handle LoRA
TODO: handle granite
"""
if is_lora:
raise NotImplementedError("Can't save full state for LoRA at the moment.")

# if args.is_granite:
# raise NotImplementedError("Can't save full state for Granite models yet.")

output_dir = Path(args.output_dir) / "full_state" / f"epoch_{epoch}"
log_rank_0(f"\033[93mSaving full model state in {output_dir}\033[0m", to_print=True)

# patch FSDP state dict method so it works correctly.
def _get_state_dict_patched(model, unwrap=False):
return get_state_dict_unpatched(model, unwrap=unwrap)

if args.distributed_training_framework == "fsdp":
get_state_dict_unpatched = accelerator.get_state_dict
accelerator.get_state_dict = _get_state_dict_patched

accelerator.save_state(
output_dir=output_dir,
# max_shard_size="5GB",
# safe_serialization=True,
)

# save metadata file for current training status
if accelerator.is_main_process:
# TODO: should we set the global_step here rather than calculating global_step
# based on samples_seen?
metadata = {"current_epoch": epoch, "samples_seen": samples_seen}
torch.save(metadata, output_dir / "training_metadata.json")
log_rank_0(f"\033[93mSaving training state: {metadata}\033[0m", to_print=True)

log_rank_0(f"\033[93mModel state saved in: {output_dir}\033[0m", to_print=True)

# cleanup
if args.distributed_training_framework == "fsdp":
accelerator.get_state_dict = get_state_dict_unpatched


def load_latest_full_state(args, accelerator) -> None:
"""
Loads accelerator state from most recently saved checkpoint
in `output_dir/full_state`.
"""
output_dir = Path(args.output_dir) / "full_state"

if not output_dir.is_dir():
return

# picks checkpoint with the largest number of samples seen, by name.
checkpoint_list = sorted(list(output_dir.iterdir()), reverse=True)

if len(checkpoint_list) == 0:
log_rank_0(
f"\033[93mNo checkpoints to load from: {output_dir}\033[0m", to_print=True
)
return

latest = checkpoint_list[0]

log_rank_0(f"\033[93mLoading state from: {latest}\033[0m", to_print=True)
accelerator.load_state(latest)

training_metadata = torch.load(latest / "training_metadata.json")
log_rank_0(
f"\033[93mTraining metadata loaded: {training_metadata}\033[0m", to_print=True
)

# previous epoch is basis for current epoch.
args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1
args.__dict__["samples_seen"] = training_metadata["samples_seen"]
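
The helpers above build on Accelerate's state round-trip. As a reference point, here is a minimal self-contained sketch (not taken from this repo) of the underlying save_state/load_state calls; the model, optimizer, scheduler, and directory names are illustrative.

from pathlib import Path

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)

# Mirrors the layout used by save_full_state(): <output_dir>/full_state/epoch_<N>
ckpt_dir = Path("output_dir") / "full_state" / "epoch_0"
accelerator.save_state(output_dir=str(ckpt_dir))  # writes model, optimizer, scheduler, and RNG state

# Later, after rebuilding and prepare()-ing the same objects, restore everything:
accelerator.load_state(str(ckpt_dir))
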
