feat: retain only last checkpoint directory
Introduced a new command-line argument `--keep_last_checkpoint_only`.
When this flag is enabled, only the last checkpoint directory is kept:
each new checkpoint overwrites the previous one, and the epoch
directory is named `last_epoch`.

This flag is useful for managing disk space during model training.
Large models and datasets can consume a substantial amount of storage
per epoch, and for many training and evaluation purposes the most
recent model state is all that is needed. Overwriting previous
checkpoints also avoids clutter and keeps the file system manageable.

Given that we always pick epoch 7 during phase 1 training and do not
evaluate each epoch, it may not be worth saving every epoch. Keeping
only the last checkpoint significantly reduces the storage required.

Signed-off-by: Sébastien Han <[email protected]>
leseb committed Dec 9, 2024
1 parent 84c0f72 commit 07b4201
Showing 4 changed files with 27 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -105,6 +105,7 @@ for training jobs. There are a number of options you can specify, such as settin
| fsdp_options | The settings for controlling FSDP when it's selected as the distributed backend. |
| distributed_backend | Specifies which distributed training backend to use. Supported options are "fsdp" and "deepspeed". |
| disable_flash_attn | Disables flash attention when set to true. This allows for training on older devices. |
| keep_last_checkpoint_only | Determines whether we should only keep the last checkpoint directory - the previous checkpoint directory is always overwritten. The checkpoint directory is called `last_epoch`. |

### `DeepSpeedOptions`

5 changes: 5 additions & 0 deletions src/instructlab/training/config.py
@@ -206,3 +206,8 @@ class TrainingArgs(BaseModel):

# This field defines whether or not data processing will occur inside of `run_training()`
process_data: Optional[bool] = True

# This field specifies whether only the last checkpoint should be retained. When set to true, it
# will overwrite the previous checkpoint directory, keeping only one directory called
# "last_epoch". This works alongside the '--checkpoint_at_epoch' flag.
keep_last_checkpoint_only: Optional[bool] = False
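The new field's behavior can be illustrated in isolation. The sketch below uses a plain dataclass as a stand-in for the project's Pydantic `TrainingArgs` model (which carries many more fields); `CheckpointArgs` is a hypothetical name used only for this example:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class CheckpointArgs:
    # Stand-in for the new TrainingArgs field: defaults to False, so all
    # per-checkpoint directories are kept unless explicitly enabled.
    keep_last_checkpoint_only: Optional[bool] = False

default_args = CheckpointArgs()
space_saving_args = CheckpointArgs(keep_last_checkpoint_only=True)

print(default_args.keep_last_checkpoint_only)       # False
print(space_saving_args.keep_last_checkpoint_only)  # True
```

Because the field defaults to `False`, existing callers of `run_training()` are unaffected unless they opt in.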
12 changes: 12 additions & 0 deletions src/instructlab/training/main_ds.py
@@ -707,6 +707,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"--max_batch_len={train_args.max_batch_len}",
f"--seed={train_args.random_seed}",
f"--chat-tmpl-path={train_args.chat_tmpl_path}",
f"--keep_last_checkpoint_only={train_args.keep_last_checkpoint_only}",
]

if train_args.checkpoint_at_epoch:
@@ -787,6 +788,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"--fsdp_sharding_strategy={train_args.fsdp_options.sharding_strategy.value}"
)

if train_args.keep_last_checkpoint_only:
command.append("--keep_last_checkpoint_only")

print(f"\033[92mRunning training command as subprocess: {' '.join(command)}\033[0m")
process = None
interrupt: KeyboardInterrupt | Exception | None = None
@@ -962,6 +966,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
),
)
parser.add_argument("--disable_flash_attn", action="store_true")
parser.add_argument(
"--keep_last_checkpoint_only",
action="store_true",
help=(
"Keep only the last checkpoint directory - overwrite the previous ones. Useful for saving disk space. "
"The last checkpoint will be saved as 'last_epoch'."
),
)
args = parser.parse_args()
set_random_seed(args.seed)
main(args)
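The argparse wiring above can be exercised on its own. A `store_true` flag defaults to `False` and takes no value on the command line, which is why `run_training()` appends the bare flag only when it is enabled (a standalone sketch, not the full parser from `main_ds.py`):

```python
import argparse

# Hypothetical standalone parser mirroring the new argument.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--keep_last_checkpoint_only",
    action="store_true",
    help="Keep only the last checkpoint directory - overwrite the previous ones.",
)

# Present on the command line -> True; omitted -> False.
args = parser.parse_args(["--keep_last_checkpoint_only"])

# Mirroring run_training(): append the bare flag to the subprocess
# command only when it is enabled.
command = ["torchrun", "main_ds.py"]
if args.keep_last_checkpoint_only:
    command.append("--keep_last_checkpoint_only")
print(command)  # ['torchrun', 'main_ds.py', '--keep_last_checkpoint_only']
```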
11 changes: 9 additions & 2 deletions src/instructlab/training/utils.py
@@ -925,8 +925,13 @@ def save_hf_format_accelerate(
samples_seen,
is_lora=False,
):
# Build the subdirectory name
subdir = (
"last_epoch" if args.keep_last_checkpoint_only else f"samples_{samples_seen}"
)

log_rank_0(
f"\033[93mSaving model in huggingface format at samples_seen: {samples_seen}\033[0m",
f"\033[93mSaving model in huggingface format at: {subdir}\033[0m",
to_print=True,
)
start = time.time()
@@ -936,7 +941,9 @@
else:
convert_dolomite = True

final_output_dir = Path(args.output_dir) / "hf_format" / f"samples_{samples_seen}"
# Build the final output directory path
final_output_dir = Path(args.output_dir) / "hf_format" / subdir

if args.use_dolomite and convert_dolomite:
tmpdir = TemporaryDirectory("w") # pylint: disable=consider-using-with
output_dir = Path(tmpdir.name)
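The directory-naming logic introduced in `save_hf_format_accelerate` can be condensed into a small helper (`checkpoint_dir` is a hypothetical name used only for this sketch): with the flag set, a fixed subdirectory name is reused so each save overwrites the previous checkpoint; otherwise every save gets its own `samples_*` directory.

```python
from pathlib import Path

def checkpoint_dir(output_dir: str, keep_last_checkpoint_only: bool, samples_seen: int) -> Path:
    # Fixed name when only the last checkpoint is kept, so each save
    # overwrites the previous one; otherwise one directory per save.
    subdir = "last_epoch" if keep_last_checkpoint_only else f"samples_{samples_seen}"
    return Path(output_dir) / "hf_format" / subdir

print(checkpoint_dir("out", True, 4096))   # out/hf_format/last_epoch
print(checkpoint_dir("out", False, 4096))  # out/hf_format/samples_4096
```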
