Skip to content

Commit

Permalink
Merge pull request #84 from xingyaoww/specify_load_iters_from_ckpt
Browse files Browse the repository at this point in the history
Support specifying `--load_iters` to select which checkpoint iteration to load
  • Loading branch information
AleHD authored Nov 6, 2023
2 parents 820f102 + a68be85 commit 01fa877
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 3 deletions.
4 changes: 4 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,10 @@ def _add_checkpointing_args(parser):
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
group.add_argument('--load_iters', type=int, default=None,
help='Specify which checkpoint to load. If not '
'specified, the latest checkpoint (highest iteration '
'number) located in the load directory will be used.')
group.add_argument('--no_load_optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no_load_rng', action='store_true', default=None,
Expand Down
14 changes: 11 additions & 3 deletions megatron/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def fix_query_key_value_ordering(model, checkpoint_version):
" checkpoint version {}".format(checkpoint_version))


def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False, specify_iteration=None):
""" Load the base state_dict from the given directory
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
Expand All @@ -431,6 +431,12 @@ def _load_base_checkpoint(load_dir, use_distributed_optimizer, rank0=False):
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration, release = read_metadata(tracker_filename)
if specify_iteration is not None:
print_rank_0(
f'overriding iteration {iteration} read from checkpoint with specified iteration {specify_iteration}'
)
iteration = specify_iteration
release = iteration == 0

# Checkpoint.
if rank0:
Expand Down Expand Up @@ -495,7 +501,8 @@ def load_args_from_checkpoint(args, load_arg='load'):
model_state_dict, optim_state_dict, release = \
_load_base_checkpoint(load_dir,
use_distributed_optimizer=args.use_distributed_optimizer,
rank0=True)
rank0=True,
specify_iteration=args.load_iters)

# For args we only care about model state dict
state_dict = model_state_dict
Expand Down Expand Up @@ -572,7 +579,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
model_state_dict, optim_state_dict, release = \
_load_base_checkpoint(load_dir,
use_distributed_optimizer=args.use_distributed_optimizer,
rank0=False)
rank0=False,
specify_iteration=args.load_iters)

if model_state_dict is None:
return 0
Expand Down
2 changes: 2 additions & 0 deletions tools/checkpoint_loader_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def _load_checkpoint(queue, args):

if args.bf16:
sys.argv += ["--bf16"]
if args.load_iters is not None:
sys.argv += ["--load_iters", str(args.load_iters)]

margs = megatron.arguments.parse_args()
margs = load_args_from_checkpoint(margs)
Expand Down
4 changes: 4 additions & 0 deletions tools/checkpoint_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ def main():
help='Do not perform checking on the name and ordering of weights',
dest='checking')
parser.add_argument('--bf16', action='store_true', help='force bfloat16 weights')
parser.add_argument('--load_iters', type=int, default=None,
help='Specify which checkpoint to load. If not '
'specified, the latest checkpoint (highest iteration '
'number) located in the load directory will be used.')

known_args, _ = parser.parse_known_args()
loader = load_plugin('loader', known_args.loader)
Expand Down

0 comments on commit 01fa877

Please sign in to comment.